From 97cbb4b4b8d596a362184d590bd0f6f0746f5505 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sun, 12 Apr 2026 20:50:44 +0800 Subject: [PATCH 001/192] feat: add vpto backend --- .../skills/generate-vpto-release-doc/SKILL.md | 86 + .../scripts/generate_release_vpto_spec.py | 171 + .../skills/llvm-test-tool-fallback/SKILL.md | 16 + .../pto-a5-installed-impl-trace/SKILL.md | 198 + .../SKILL.md | 37 + .codex/skills/ptoas-build-and-abs/SKILL.md | 101 + .../skills/ptoas-npu-validation-a5/SKILL.md | 335 + .../skills/ptoas-vpto-llvm-artifacts/SKILL.md | 319 + .gitignore | 5 + docs/PTO_IR_manual.md | 43 +- docs/isa/01-pipeline-sync.md | 462 ++ docs/isa/02-dma-copy.md | 602 ++ docs/isa/03-vector-load-store.md | 595 ++ docs/isa/04-predicate-load-store.md | 135 + docs/isa/05-materialization-predicate.md | 322 + docs/isa/06-unary-vector-ops.md | 172 + docs/isa/07-binary-vector-ops.md | 293 + docs/isa/08-vec-scalar-ops.md | 236 + docs/isa/09-conversion-ops.md | 257 + docs/isa/10-reduction-ops.md | 244 + docs/isa/11-compare-select.md | 182 + docs/isa/12-data-rearrangement.md | 230 + docs/isa/13-dsa-sfu-ops.md | 229 + docs/isa/14-shared-arith.md | 99 + docs/isa/15-shared-scf.md | 97 + docs/release/vpto-spec-v0.1.md | 4883 +++++++++++ docs/release/vpto-spec-v0.2.md | 5072 ++++++++++++ docs/sample.pto | 57 + docs/tilelang-dsl-guide.md | 2921 +++++++ docs/tilelang-dsl-syntax-sugar-proposals.md | 404 + docs/vpto-spec.md | 974 +++ include/PTO/IR/PTOOps.td | 60 +- include/PTO/IR/PTOTypeDefs.td | 17 +- include/PTO/IR/VPTOOps.td | 1471 ++++ include/PTO/IR/VPTOTypeDefs.td | 53 + include/PTO/Transforms/HIVMIntrinsicNaming.h | 60 + include/PTO/Transforms/Passes.h | 7 +- include/PTO/Transforms/Passes.td | 98 + include/PTO/Transforms/VPTOLLVMEmitter.h | 43 + include/PTO/Transforms/VPTOLowering.h | 241 + include/pto-c/Dialect/PTO.h | 3 + lib/Bindings/Python/CMakeLists.txt | 2 + lib/Bindings/Python/PTOModule.cpp | 18 +- lib/CAPI/Dialect/PTO.cpp | 13 + lib/PTO/IR/CMakeLists.txt | 2 + lib/PTO/IR/PTO.cpp | 412 +- lib/PTO/IR/VPTO.cpp | 2924 +++++++ lib/PTO/Transforms/CMakeLists.txt | 16 + lib/PTO/Transforms/HIVMIntrinsicNaming.cpp | 561 ++ lib/PTO/Transforms/PTOToVPTO.cpp | 604 ++ lib/PTO/Transforms/PTOToVPTOLowering.cpp | 7290 +++++++++++++++++ lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp | 114 + lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp | 337 + lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 756 ++ lib/PTO/Transforms/PTOViewToMemref.cpp | 2 +- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 4562 +++++++++++ python/pto/dialects/pto.py | 872 ++ scripts/batch_compile_output_cpp.sh | 464 ++ scripts/compile_pto_to_vpto_llvm.sh | 116 + scripts/ptoas_env.sh | 83 + test/dsl/abs.py | 34 + test/dsl/strict_vecscope.py | 42 + test/dsl/template_abs.py | 48 + test/lit.cfg.py | 85 + tools/ptoas/ptoas.cpp | 327 +- 65 files changed, 41468 insertions(+), 46 deletions(-) create mode 100644 .codex/skills/generate-vpto-release-doc/SKILL.md create mode 100644 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py create mode 100644 .codex/skills/llvm-test-tool-fallback/SKILL.md create mode 100644 .codex/skills/pto-a5-installed-impl-trace/SKILL.md create mode 100644 .codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md create mode 100644 .codex/skills/ptoas-build-and-abs/SKILL.md create mode 100644 .codex/skills/ptoas-npu-validation-a5/SKILL.md create mode 100644 .codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md create mode 100644 docs/isa/01-pipeline-sync.md create mode 100644 docs/isa/02-dma-copy.md create mode 100644 docs/isa/03-vector-load-store.md create mode 100644 docs/isa/04-predicate-load-store.md create mode 100644 docs/isa/05-materialization-predicate.md create mode 100644 docs/isa/06-unary-vector-ops.md create mode 100644 docs/isa/07-binary-vector-ops.md create mode 100644 docs/isa/08-vec-scalar-ops.md create mode 100644 docs/isa/09-conversion-ops.md create mode 100644 docs/isa/10-reduction-ops.md create mode 100644 docs/isa/11-compare-select.md create mode 100644 docs/isa/12-data-rearrangement.md create mode 100644 docs/isa/13-dsa-sfu-ops.md create mode 100644 docs/isa/14-shared-arith.md create mode 100644 docs/isa/15-shared-scf.md create mode 100644 docs/release/vpto-spec-v0.1.md create mode 100644 docs/release/vpto-spec-v0.2.md create mode 100644 docs/sample.pto create mode 100644 docs/tilelang-dsl-guide.md create mode 100644 docs/tilelang-dsl-syntax-sugar-proposals.md create mode 100644 docs/vpto-spec.md create mode 100644 include/PTO/IR/VPTOOps.td create mode 100644 include/PTO/IR/VPTOTypeDefs.td create mode 100644 include/PTO/Transforms/HIVMIntrinsicNaming.h create mode 100644 include/PTO/Transforms/VPTOLLVMEmitter.h create mode 100644 include/PTO/Transforms/VPTOLowering.h create mode 100644 lib/PTO/IR/VPTO.cpp create mode 100644 lib/PTO/Transforms/HIVMIntrinsicNaming.cpp create mode 100644 lib/PTO/Transforms/PTOToVPTO.cpp create mode 100644 lib/PTO/Transforms/PTOToVPTOLowering.cpp create mode 100644 lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp create mode 100644 lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp create mode 100644 lib/PTO/Transforms/PTOValidateVPTOIR.cpp create mode 100644 lib/PTO/Transforms/VPTOLLVMEmitter.cpp create mode 100755 scripts/batch_compile_output_cpp.sh create mode 100755 scripts/compile_pto_to_vpto_llvm.sh create mode 100644 scripts/ptoas_env.sh create mode 100644 test/dsl/abs.py create mode 100644 test/dsl/strict_vecscope.py create mode 100644 test/dsl/template_abs.py create mode 100644 test/lit.cfg.py diff --git a/.codex/skills/generate-vpto-release-doc/SKILL.md b/.codex/skills/generate-vpto-release-doc/SKILL.md new file mode 100644 index 000000000..fde4ceb97 --- /dev/null +++ b/.codex/skills/generate-vpto-release-doc/SKILL.md @@ -0,0 +1,86 @@ +--- +name: generate-vpto-release-doc +description: Generate or refresh `docs/release/vpto-spec-v*.md` by merging `docs/vpto-spec.md` with `docs/isa/*.md`, following the release-doc naming and layout. Use when the user asks to create or update a merged VPTO release spec, inline ISA Markdown into one release document, add TOC and version bullets, move `Quick Reference by Category` to the end, or strip update, appendix, and correspondence content from the merged release doc. +--- + +# Generate VPTO Release Doc + +Use this skill when the task is specifically about: +- creating a new merged release document under `docs/release/` +- refreshing an existing `vpto-spec-v*.md` release doc from `docs/vpto-spec.md` and `docs/isa/*.md` +- keeping the merged release doc aligned with the naming and structure used in `docs/release/vpto-spec-v0.1.md` + +## Canonical Workflow + +1. Pick the target version and output path. + +Default output path: + +```bash +docs/release/vpto-spec-v.md +``` + +2. Run the bundled generator script. + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py --version 0.2 +``` + +If you need an explicit note for the new version bullet: + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py \ + --version 0.2 \ + --version-note 'Merge `docs/vpto-spec.md` with `docs/isa/*.md`; add TOC; move `Quick Reference by Category` to the end; remove update, appendix, and correspondence content' +``` + +3. Review the generated file before finalizing. + +Check these invariants: +- exactly one `#` level title in the whole file +- `[toc]` is present near the top +- the top version bullet for the requested version was added +- `## Quick Reference by Category` is the final top-level section +- no `Updated:` / review-status boilerplate remains at the beginning +- no appendix sections remain +- no `## Correspondence Categories` section remains +- no `CCE correspondence` / builtin-mapping blocks remain + +4. If the user wants extra release-note wording, patch only the version bullets or other small wording around the generated content. Prefer rerunning the script over hand-merging large sections. + +## Source Mapping + +Use `docs/vpto-spec.md` for: +- `Part I: Architecture Overview` +- `Part II: Notation Convention` +- `C-Style Semantics Convention` +- `Template Placeholder Conventions` +- `Instruction Groups` +- `Supported Data Types` +- `Common Patterns` +- `Quick Reference by Category` + +Use `docs/isa/*.md` for: +- the inlined `Detailed ISA Group Reference` + +## Merge Rules + +The merged release document should: +- keep the release-doc title and version-bullet style +- preserve the `Instruction Groups` summary table +- inline `docs/isa/*.md` under `Detailed ISA Group Reference` +- convert `docs/isa/*.md` links into in-document anchors like `#isa-03-vector-load-store` +- demote the inlined ISA headings by two levels so the merged TOC stays stable +- place `Quick Reference by Category` at the end + +The merged release document must remove: +- beginning-of-file update/review metadata from `docs/vpto-spec.md` +- `## Correspondence Categories` +- all `CCE correspondence` blocks and related builtin/token mapping lines +- the sentence `For detailed semantics, C-style pseudocode, and CCE mappings, see the individual group documentation files.` +- appendix sections + +## Notes + +- The script assumes the source headings in `docs/vpto-spec.md` keep their current names. If extraction fails, inspect the heading names there before patching the script. +- The script is deterministic and is the preferred path for regenerating large merged release docs. diff --git a/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py b/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py new file mode 100644 index 000000000..f03e7d592 --- /dev/null +++ b/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +"""Generate merged VPTO release spec from docs/vpto-spec.md and docs/isa/*.md.""" + +from __future__ import annotations + +import argparse +import re +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[4] +DOCS_DIR = ROOT / "docs" +SOURCE_SPEC = DOCS_DIR / "vpto-spec.md" +ISA_DIR = DOCS_DIR / "isa" +RELEASE_DIR = DOCS_DIR / "release" + +TITLE = "# PTO micro Instruction Spec \u2014 Draft (A5)" +DEFAULT_VERSION_NOTES = { + "0.1": "Doc Init", + "0.2": "Update micro Instruction latency and throughput", + "0.3": "Refresh VPTO ISA specification", +} + +KEEP_SECTIONS = [ + "## Part I: Architecture Overview", + "## Part II: Notation Convention", + "## Instruction Groups", + "## Supported Data Types", + "## Common Patterns", + "## Quick Reference by Category", +] + +ISA_LINK_RE = re.compile(r"\[([^\]]+)\]\((?:\.\./)?(?:isa/)?([0-9]{2}-[A-Za-z0-9-]+)\.md\)") + + +def extract_sections(markdown: str) -> dict[str, str]: + headings = list(re.finditer(r"^## .*$", markdown, flags=re.MULTILINE)) + sections: dict[str, str] = {} + for index, match in enumerate(headings): + heading = match.group(0).strip() + start = match.start() + end = headings[index + 1].start() if index + 1 < len(headings) else len(markdown) + sections[heading] = markdown[start:end].strip() + "\n" + return sections + + +def rewrite_isa_links(text: str) -> str: + return ISA_LINK_RE.sub(lambda m: f"[{m.group(1)}](#isa-{m.group(2).lower()})", text) + + +def trim_trailing_rule(text: str) -> str: + return re.sub(r"\n---\s*\Z", "\n", text.strip() + "\n").rstrip() + + +def strip_unwanted_lines(text: str) -> str: + lines = text.splitlines() + kept: list[str] = [] + skip_correspondence = False + for line in lines: + if re.match(r"^## Correspondence Categories\b", line): + skip_correspondence = True + continue + if skip_correspondence: + if re.match(r"^## ", line): + skip_correspondence = False + else: + continue + if line.startswith("> **Status:**") or line.startswith("> **Base:**") or line.startswith("> **Additions from:**") or line.startswith("> **Updated:**"): + continue + if "For detailed semantics, C-style pseudocode, and CCE mappings" in line: + continue + if "CCE correspondence" in line or "builtin mapping" in line.lower(): + continue + kept.append(line) + text = "\n".join(kept).strip() + "\n" + text = re.sub(r"\n## Appendix [A-Z]:.*\Z", "\n", text, flags=re.DOTALL) + return text + + +def demote_headings(text: str, levels: int = 2) -> str: + def replace(match: re.Match[str]) -> str: + hashes = match.group(1) + heading = match.group(2) + new_level = min(6, len(hashes) + levels) + return f"{'#' * new_level} {heading}" + + return re.sub(r"^(#{1,6})\s+(.*)$", replace, text, flags=re.MULTILINE) + + +def render_version_bullets(version: str, version_note: str | None) -> str: + notes = dict(DEFAULT_VERSION_NOTES) + if version_note: + notes[version] = version_note + elif version not in notes: + notes[version] = "Release refresh" + + def key_fn(item: str) -> tuple[int, ...]: + return tuple(int(part) for part in item.split(".")) + + lines = [f"- v{ver}: {notes[ver]}" for ver in sorted(notes, key=key_fn, reverse=True)] + return "\n".join(lines) + + +def build_release_doc(version: str, version_note: str | None) -> str: + source_text = strip_unwanted_lines(SOURCE_SPEC.read_text()) + sections = extract_sections(source_text) + + missing = [name for name in KEEP_SECTIONS if name not in sections] + if missing: + raise SystemExit(f"missing expected headings in docs/vpto-spec.md: {missing}") + + content_sections = [trim_trailing_rule(rewrite_isa_links(sections[name])) for name in KEEP_SECTIONS[:-1]] + + isa_blocks: list[str] = ["## Detailed ISA Group Reference"] + for isa_path in sorted(ISA_DIR.glob("*.md")): + isa_text = rewrite_isa_links(isa_path.read_text().strip() + "\n") + isa_blocks.append(trim_trailing_rule(demote_headings(isa_text))) + + quick_reference = trim_trailing_rule(rewrite_isa_links(sections["## Quick Reference by Category"])) + + parts = [ + TITLE, + "", + render_version_bullets(version, version_note), + "", + "[toc]", + "", + "---", + "", + "\n\n".join(content_sections), + "\n\n".join(isa_blocks), + quick_reference, + "", + ] + return "\n".join(part for part in parts if part is not None) + + +def validate_release_doc(text: str) -> None: + if text.count("# PTO micro Instruction Spec") != 1: + raise SystemExit("expected exactly one top-level title") + if "\n[toc]\n" not in text: + raise SystemExit("missing [toc] near top") + if re.search(r"^## Quick Reference by Category\b", text, flags=re.MULTILINE) is None: + raise SystemExit("missing Quick Reference by Category") + if re.search(r"^## Quick Reference by Category\b[\s\S]*\Z", text, flags=re.MULTILINE) is None: + raise SystemExit("Quick Reference by Category must be present at end") + if re.search(r"^## Appendix\b", text, flags=re.MULTILINE): + raise SystemExit("appendix content must not remain") + if "Updated:" in text or "review" in text.splitlines()[:8]: + raise SystemExit("beginning metadata must not remain") + if "## Correspondence Categories" in text or "CCE correspondence" in text: + raise SystemExit("correspondence content must not remain") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--version", required=True, help="Release version, e.g. 0.2") + parser.add_argument("--version-note", help="Version bullet text for the requested version") + parser.add_argument("--output", help="Explicit output path") + args = parser.parse_args() + + output = Path(args.output) if args.output else RELEASE_DIR / f"vpto-spec-v{args.version}.md" + output.parent.mkdir(parents=True, exist_ok=True) + + text = build_release_doc(args.version, args.version_note) + validate_release_doc(text) + output.write_text(text) + + +if __name__ == "__main__": + main() diff --git a/.codex/skills/llvm-test-tool-fallback/SKILL.md b/.codex/skills/llvm-test-tool-fallback/SKILL.md new file mode 100644 index 000000000..e71bf178d --- /dev/null +++ b/.codex/skills/llvm-test-tool-fallback/SKILL.md @@ -0,0 +1,16 @@ +--- +name: llvm-test-tool-fallback +description: When `lit` or `FileCheck` is missing from the current shell, look for the corresponding LLVM test tools in the environment or existing LLVM workspace before treating it as a repo issue. +--- + +# LLVM Test Tool Fallback + +Use this skill when: +- `python3 -m lit` fails because `lit` is missing +- `FileCheck` is not in `PATH` +- a test command fails only because LLVM test tools are not available in the current shell + +Rule: +- do not stop at `command not found` +- first try to find `lit` / `FileCheck` from the environment's LLVM toolchain or an existing LLVM workspace +- treat missing `lit` / `FileCheck` as an environment-tool issue, not as a PTOAS regression diff --git a/.codex/skills/pto-a5-installed-impl-trace/SKILL.md b/.codex/skills/pto-a5-installed-impl-trace/SKILL.md new file mode 100644 index 000000000..b67cb0128 --- /dev/null +++ b/.codex/skills/pto-a5-installed-impl-trace/SKILL.md @@ -0,0 +1,198 @@ +--- +name: pto-a5-installed-impl-trace +description: Guide LLVM IR discovery for A5 VPTO lowering from the installed CANN/PTO implementation under ASCEND_HOME_PATH. Use when the user does not yet know which `llvm.hivm.*` intrinsic, builtin wrapper, or operand contract a VPTO/A5 op should lower to. +--- + +# PTO A5 Installed Implementation Trace + +Use this skill when the task is specifically about: +- checking what an A5 PTO op really does on the installed machine +- mapping PTO/A5 behavior to builtins or LLVM/HIVM intrinsics +- tracing PTO wrappers down to CCE builtin wrappers such as `__builtin_cce_*` +- deciding whether repo-local lowering is correct or only a guess +- resolving conflicts between generated repo IR and installed PTO headers +- tracing `Cmp`, `Cmps`, predicate, pack, store, or typed vector behavior + +This skill answers: +- what LLVM IR a VPTO op should lower to +- what the authoritative intrinsic name is +- what operand list or mask form the installed toolchain expects +- whether repo-local lowering or emission diverges from installed behavior + +This skill does not answer: +- how to build or link a finished LLVM-path artifact end to end +- how to package `.o`, `fatobj`, or `.so` +- how to run board validation + +## Strong Rule + +If you are about to change repo code for an A5 op, stop and inspect the +installed PTO implementation first. Treat the installed PTO library under +`ASCEND_HOME_PATH` as the semantic source of truth. + +Only make a repo-local substitution after you have confirmed one of: +- the installed PTO headers already express that replacement relationship +- the frontend/compiler intrinsic contract proves two forms are equivalent at + the intrinsic layer + +Do not guess behavior from repo-local lowering, emitter code, or from what +"seems plausible" for an intrinsic sequence. + +Do not start from repo-local lowering when the question is about real A5 +behavior. The installed PTO implementation under `ASCEND_HOME_PATH` is the +first source of truth. + +## Required Search Order + +Always follow this order: + +1. `source /usr/local/Ascend/cann/set_env.sh` +2. confirm `ASCEND_HOME_PATH` +3. inspect installed PTO dispatch headers: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/common/pto_instr_impl.hpp` +4. inspect the matching A5 implementation: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/npu/a5/T*.hpp` +5. inspect typed helpers: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/npu/a5/utils.hpp` +6. inspect builtin wrapper headers when the question is about the real compiler-facing builtin: + - `$ASCEND_HOME_PATH/tools/bisheng_compiler/lib/clang/*/include/__clang_cce_vector_intrinsics.h` + - `$ASCEND_HOME_PATH/tools/bisheng_compiler/lib/clang/*/include/npu_arch_*/__clang_cce_vector_intrinsics.h` +7. inspect intrinsic name availability directly from the installed compiler binary before guessing LLVM/HIVM spellings: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.'` + - narrow to the op under investigation, for example: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.(vneg|vrsqrt|vnot|vmov)'` +8. only then compare against repo-local code under `lib/PTO/Transforms/` + +## Practical Fast Path + +For VPTO LLVM emission work, prefer this concrete order instead of jumping +straight to ad hoc compiler probes: + +1. confirm the op exists in installed PTO/A5 headers +2. confirm the builtin wrapper shape in installed Clang headers +3. confirm the intrinsic name family with: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.'` +4. patch repo-local emitter/lowering as little as possible +5. generate real repo-driven LLVM IR through the existing VPTO validation path: + - `source scripts/ptoas_env.sh` + - `WORK_SPACE=/tmp/ CASE_NAME= DEVICE=SIM COMPILE_ONLY=1 test/vpto/scripts/run_host_vpto_validation.sh` +6. inspect: + - `//*.ll` + - `//validation.log` +7. only after seeing the real generated `.ll` and Bisheng failure should you + refine the call shape + +This route is preferred because it preserves the real PTOAS lowering context, +the real case structure, and the exact driver invocation used by the repo. + +## Probe Strategy + +Use probes in this order: + +1. installed headers +2. `strings bisheng` +3. repo-generated VPTO LLVM IR from `run_host_vpto_validation.sh` +4. only then minimal handwritten `.ll` probes +5. handwritten `.cce` frontend probes are last resort + +Handwritten `.ll` probes are acceptable for quick ABI sanity checks such as: +- whether Bisheng recognizes a specific `llvm.hivm.*` name +- whether a guessed argument count immediately crashes or verifies + +But they are not the primary source of truth for semantic or frontend wrapper +behavior. + +## Avoid These Traps + +Do not default to handwritten `.cce` probes when repo-driven IR is available. +On this machine, bare `.cce` probes often fail before reaching the real +question because they are missing the exact frontend driver mode, target +features, wrapper setup, or host/device compilation context used by the repo. + +In particular, treat these as warning signs that you have started too low in +the stack: +- errors around `[aicore]` +- errors around `__cce_half` +- builtin alias attribute failures +- missing target feature or wrapper environment failures + +When these happen, step back to the repo-driven compile-only flow instead of +trying to repair the ad hoc frontend invocation from scratch. + +## Trace By The Real Type Split + +Do not infer the active implementation from the final storage type alone. +Follow the source element type and the installed dispatch branch. + +Example: +- for `Cmp` with `f32 -> ui8`, inspect the `sizeof(src) == 4` branch, not the + `ui8` destination branch +- for scalar or packed outputs, treat pack/store ops separately from compare + predicate generation + +Typical A5 compare split: +- 32-bit source elements -> `TCmp_32B` / `TCmps_32B` +- 16-bit source elements -> 16-bit branch +- 8-bit source elements -> 8-bit branch + +## What To Extract + +When tracing an op, capture: +- the installed PTO entrypoint that handles it +- the exact typed branch that matches the user case +- the builtins used in order +- any typed helper that explains `pset/plt` or store packing selection +- the compiler builtin wrapper if it is visible in installed Clang headers + +For compare-family questions, separate: +- predicate generation +- compare builtin +- predicate pack/interleave +- predicate store + +Stop at the builtin wrapper layer if the lower compiler implementation is not +available. That is still enough to answer questions such as: +- `pset_b32 -> __builtin_cce_pset_b32` +- `plt_b32 -> __builtin_cce_plt_b32_v300` + +## When The Builtin Name Is Still Not Enough + +If the installed PTO headers tell you the wrapper builtin but that still does +not answer the LLVM/HIVM operand contract, do not guess from repo-local +lowering. Extend the trace using the generated repo testcase first, and only +after that the real compiler frontend: + +1. run an existing repo case with: + - `WORK_SPACE=/tmp/ CASE_NAME= DEVICE=SIM COMPILE_ONLY=1 test/vpto/scripts/run_host_vpto_validation.sh` +2. inspect the generated `.ll` and `validation.log` +3. if the repo-generated LLVM IR still leaves the contract ambiguous, inspect + the testcase build flags from: + - `/build/CMakeFiles/.dir/flags.make` + - `/build/CMakeFiles/.dir/build.make` +4. rerun the same `bisheng` compile with `-v` and `-save-temps` +5. inspect: + - `*.ccei` for the exact installed PTO wrapper call sequence + - `strings *.bc | rg 'llvm.hivm\\.'` to see which HIVM intrinsics survived +6. if needed, rerun the same frontend compile with `-S`, `-emit-llvm`, or the + equivalent `cc1` invocation from `-v` to inspect the real LLVM IR emitted by + the compiler frontend before instruction selection + +This is the required fallback when the question is really: +- what exact `llvm.hivm.*` intrinsic shape the compiler expects +- whether a hand-written LLVM IR call shape is valid +- whether a selector failure is caused by a guessed mask/value form + +Prefer this real-frontend route over inventing mask constants or argument +shapes from memory. + +## Reporting Back + +When you use this skill, report: +- the exact installed header paths inspected +- whether `strings $ASCEND_HOME_PATH/bin/bisheng` confirmed the intrinsic name +- which typed branch was the authoritative one +- the builtin sequence observed there +- the builtin wrapper name if you found one in the installed Clang headers +- whether repo-generated `.ll` matched the guessed call shape +- whether repo-local lowering matches or diverges +- the first concrete mismatch, if any diff --git a/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md b/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md new file mode 100644 index 000000000..b50bb36fc --- /dev/null +++ b/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md @@ -0,0 +1,37 @@ +--- +name: ptoas-bisheng-asm-from-object-cmd +description: Use when you need assembly for a PTOAS VPTO case that already compiles to a device object. First find the exact command that produced the `.o`, then derive the `.s` command by replacing `-c` with `-S`. Do not guess a fresh Bisheng command line. +metadata: + short-description: Derive `.s` from real `.o` command +--- + +# PTOAS Bisheng ASM From Object Command + +Use this skill when the task is to inspect generated assembly for a VPTO case and the case already has a known `.o` build path. + +## Rule + +- Do not invent a new `bisheng` command. +- First find the exact command that built the `.o`. +- Then derive the `.s` command from that exact command by changing `-c` to `-S`. +- Keep the rest of the arguments unchanged unless the original command already wrote to a conflicting output path. + +## Preferred Sources + +- Validation script logs +- Build scripts such as `test/vpto/scripts/run_host_vpto_validation.sh` +- Saved shell history or generated compile traces in the case workspace + +## Procedure + +1. Locate the real `.o` compile command for the target case. +2. Copy that command exactly. +3. Replace `-c` with `-S`. +4. Point `-o` to a `.s` path. +5. Run the derived command. +6. Inspect the generated assembly instead of guessing from LLVM IR. + +## Anti-Pattern + +- Do not hand-write a new `bisheng -S ...` command from memory. +- Do not drop flags such as `--target`, `-march`, `--cce-aicore-arch`, `--cce-aicore-only`, `-O2`, include paths, or wrapper options that were present in the real `.o` command. diff --git a/.codex/skills/ptoas-build-and-abs/SKILL.md b/.codex/skills/ptoas-build-and-abs/SKILL.md new file mode 100644 index 000000000..bbfd993e3 --- /dev/null +++ b/.codex/skills/ptoas-build-and-abs/SKILL.md @@ -0,0 +1,101 @@ +--- +name: ptoas-build-and-abs +description: Rebuild PTOAS in the repo build directory and compile the Abs sample to inspect generated VPTO output. Use when the user asks to build ptoas, rebuild the current build tree, or run/check the Abs sample output. +--- + +# PTOAS Build And Abs + +Use this skill when the task is specifically about: +- rebuilding `ptoas` in this repo +- doing a full repo build in the repo-local `build/` directory +- compiling `test/samples/Abs/abs.py` +- inspecting the generated VPTO text for `Abs` + +## Canonical Commands + +### 1. Configure the repo-local build directory + +`do_cmake.sh` is the canonical entrypoint. It always targets `./build`. + +```bash +./do_cmake.sh --llvm /data/mouliangyu/projects/github.com/llvm/llvm-project/install +``` + +If `do_cmake.sh` fails because `build/` has a generator mismatch between old Makefiles/Ninja metadata, do not guess. State that `build/` is inconsistent and ask before cleaning the generated build metadata in `build/`. + +### 2. Build + +For just the CLI: + +```bash +CCACHE_DISABLE=1 ninja -C build ptoas +``` + +For a full repo build: + +```bash +CCACHE_DISABLE=1 ninja -C build +``` + +If the user asked for "full build", prefer the full command above. If they only want to run `Abs`, building `ptoas` is usually enough. + +### 3. Prepare runtime environment + +Before running `runop.sh`, always: + +```bash +source env.sh +``` + +This sets `PYTHONPATH`, `LD_LIBRARY_PATH`, and the MLIR/PTO python roots needed by the samples. + +### 4. Compile `Abs` to VPTO text + +Use `runop.sh` with explicit `PTOAS_BIN`, explicit output directory, and A5 backend flags: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-abs-vpto \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-print-ir' \ +./test/samples/runop.sh -t Abs +``` + +Expected outputs: +- `/tmp/ptoas-abs-vpto/Abs/abs-pto-ir.pto` +- `/tmp/ptoas-abs-vpto/Abs/abs-pto.cpp` + +Despite the `.cpp` suffix, on the VPTO backend this file contains the emitted VPTO textual IR. + +## Inspection + +The main file to show the user is: + +```bash +sed -n '1,260p' /tmp/ptoas-abs-vpto/Abs/abs-pto.cpp +``` + +For quick sanity checks, look for: +- `vpto.copy_gm_to_ubuf` +- `src_strides = [32, 1]` +- `trace_offsets = [0, 0]` +- `trace_sizes = [32, 32]` +- `cce_aiv_loop_hint` +- `llvm.loop.aivector_scope` +- `vpto.vlds` +- `vpto.vabs` +- `vpto.vsts` +- `vpto.copy_ubuf_to_gm` + +## Reporting Back + +When you ran `Abs`, report: +- whether `ptoas` had to be rebuilt +- the exact generated file path for the VPTO text +- whether the output contains the expected copy-family metadata and vec-scope carrier attrs + +If the build fails, include the first concrete blocker: +- generator mismatch in `build/` +- link failure in `ptoas` +- missing runtime env because `env.sh` was not sourced +- missing sample output file diff --git a/.codex/skills/ptoas-npu-validation-a5/SKILL.md b/.codex/skills/ptoas-npu-validation-a5/SKILL.md new file mode 100644 index 000000000..735cde327 --- /dev/null +++ b/.codex/skills/ptoas-npu-validation-a5/SKILL.md @@ -0,0 +1,335 @@ +--- +name: ptoas-npu-validation-a5 +description: Generate and run PTOAS-based A5 test/npu_validation or test/vpto validations, build the testcase binaries, and validate runtime output on NPU or simulator. Use when the user wants NPU run validation, golden/compare checks, or runtime troubleshooting for A5. +--- + +# PTOAS NPU Validation A5 + +Use this skill when the task is specifically about: +- generating `test/npu_validation` projects from PTOAS output +- running `test/vpto/scripts/run_host_vpto_validation.sh` +- running `test/vpto` board validation or simulator validation +- building testcase binaries for A5 +- running NPU or simulator validation +- generating golden inputs and checking results with `compare.py` +- diagnosing runtime blockers such as missing device access or `aclrtSetDevice` + +This skill is the main entry point for runtime validation. + +Do not use this skill as the primary entry point when the task is only about: +- exporting LLVM IR or LLVM bitcode +- validating the `bisheng` handoff +- assembling a fat object or replacement kernel library from the LLVM path + +When this validation flow needs a custom LLVM IR or LLVM BC artifact, use +`ptoas-vpto-llvm-artifacts` first to build that artifact, then return here to +run the testcase. + +## Important Constraint + +The `npu_validation` flow still depends on an EmitC-generated sample export to +materialize the host-side testcase skeleton. + +For the existing automation, this EmitC export step is not something the user +must run manually first. The provided host-validation scripts already do it for +you. + +Specifically: +- `run_host_npu_validation.sh` automatically invokes `test/samples/runop.sh` + first +- that export is written under `WORK_SPACE/emitc/...` +- `run_host_npu_validation_case.sh` then uses that generated EmitC `*-pto.cpp` + as the input to `generate_testcase.py` + +Even when the final kernel under validation comes from the VPTO/LLVM path, the +current scripts do not generate a standalone host runner from VPTO MLIR or +LLVM IR directly. The canonical automated flow is: + +1. `run_host_npu_validation.sh` automatically exports the sample through the + default EmitC path to get `*-pto.cpp` +2. `run_host_npu_validation_case.sh` runs `generate_testcase.py` on that + generated EmitC kernel to create the testcase directory, host `main.cpp`, + kernel wrapper source, `launch.cpp`, and build system +3. if LLVM/VPTO validation is desired, `run_host_npu_validation_case.sh` + optionally calls `build_llvm_ir_kernel_so.sh` to rebuild and replace only + the final `lib_kernel.so` +4. the generated testcase binary is then run against that replacement kernel + library + +In other words: +- the scripts automatically do the EmitC export step before testcase + generation +- EmitC is still required to produce the host/testcase scaffolding +- LLVM/VPTO replaces the device kernel library, not the host testcase +- feeding raw VPTO textual MLIR directly into `generate_testcase.py` is not a + supported path + +## Automation Entry Points + +Use these scripts as the default automation entry points instead of rebuilding +the flow by hand: + +- `test/vpto/scripts/run_host_vpto_validation.sh` + - top-level driver for curated VPTO `kernel.pto` board/simulator validation + - consumes hand-authored VPTO cases under `test/vpto/cases/...` + - handles lowering, LLVM-path device object build, host build, golden, and compare + - is the default entry point when the user asks to run VPTO board validation directly + - when it fails at runtime, follow this skill's troubleshooting guidance instead of treating the first `aclrtSetDevice` failure as a final product regression + +- `test/npu_validation/scripts/run_host_npu_validation.sh` + - top-level driver for host/NPU validation + - automatically runs `test/samples/runop.sh` first + - automatically writes the EmitC export under `WORK_SPACE/emitc/...` + - discovers testcase names from `test/samples//npu_validation/...` + - dispatches each testcase to `run_host_npu_validation_case.sh` + +- `test/npu_validation/scripts/run_host_npu_validation_case.sh` + - per-testcase execution driver + - consumes the already-generated EmitC kernel from `WORK_SPACE/emitc/...` + - runs `generate_testcase.py` + - configures and builds the testcase + - when `KERNEL_MODE=llvm`, calls `build_llvm_ir_kernel_so.sh` to replace the + device kernel shared library + - runs the testcase binary and then `compare.py` + +- `test/npu_validation/scripts/build_llvm_ir_kernel_so.sh` + - helper used by the case runner for LLVM/VPTO validation + - assumes the EmitC-derived testcase and host wrapper already exist + - rebuilds only the replacement `lib_kernel.so` + - its internal `runop.sh` export may return non-zero because another sample + in the same family failed, but the script intentionally continues if the + requested testcase's LLVM IR artifact was still produced + +## Preconditions + +Before running `npu_validation` or `test/vpto`, make sure: +- `ptoas` is already built in `./build` +- `bisheng` is in `PATH` or available through CANN `set_env.sh` +- `PTO_ISA_ROOT` points to a `pto-isa` checkout with: + - `include/` + - `tests/common/` +- the shell can read `/dev/davinci*` if you intend to execute on real hardware + +Example: + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +``` + +Useful runtime check: + +```bash +source /usr/local/Ascend/cann/set_env.sh +python3 - <<'PY' +import ctypes +lib = ctypes.cdll.LoadLibrary('libascendcl.so') +aclInit = lib.aclInit; aclInit.argtypes=[ctypes.c_char_p]; aclInit.restype=ctypes.c_int +aclrtGetDeviceCount = lib.aclrtGetDeviceCount; aclrtGetDeviceCount.argtypes=[ctypes.c_void_p]; aclrtGetDeviceCount.restype=ctypes.c_int +aclrtSetDevice = lib.aclrtSetDevice; aclrtSetDevice.argtypes=[ctypes.c_int]; aclrtSetDevice.restype=ctypes.c_int +cnt = ctypes.c_uint(0) +print('aclInit', aclInit(None)) +print('aclrtGetDeviceCount', aclrtGetDeviceCount(ctypes.byref(cnt)), cnt.value) +print('aclrtSetDevice', aclrtSetDevice(0)) +PY +``` + +Interpretation: +- `aclInit` succeeds +- `aclrtGetDeviceCount` should report at least one device if the runtime can enumerate hardware +- if `aclrtSetDevice(0)` fails with `507033` (`ACL_ERROR_RT_DEV_SETUP_ERROR`), the user context can see a device but cannot open a usable runtime context + +This interpretation applies equally to: + +- `test/npu_validation` +- `test/vpto` + +When `test/vpto/scripts/run_host_vpto_validation.sh` hits `aclrtSetDevice`, do not immediately report a testcase regression. First treat it as a runtime-environment blocker and follow the checks in this skill. + +## Canonical Flow + +### 1. Generate the PTOAS kernel + +Use the default EmitC-style output, because `npu_validation` consumes `*-pto.cpp`. + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-abs-emitc \ +./test/samples/runop.sh -t Abs +``` + +Expected output: +- `/tmp/ptoas-abs-emitc/Abs/abs-pto.cpp` +- this EmitC kernel is also the required host/testcase input for the later + LLVM/VPTO replacement flow + +### 2. Generate the `npu_validation` testcase + +```bash +python3 test/npu_validation/scripts/generate_testcase.py \ + --input /tmp/ptoas-abs-emitc/Abs/abs-pto.cpp \ + --testcase abs \ + --output-root /tmp/ptoas-npu-validation-run \ + --run-mode sim \ + --soc-version dav_3102 \ + --aicore-arch dav-c310-vec +``` + +Expected output directory: +- `/tmp/ptoas-npu-validation-run/Abs/abs` + +### 3. Configure and build + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cmake -S /tmp/ptoas-npu-validation-run/Abs/abs \ + -B /tmp/ptoas-npu-validation-run/Abs/abs/build \ + -DSOC_VERSION=dav_3102 \ + -DENABLE_SIM_GOLDEN=ON +cmake --build /tmp/ptoas-npu-validation-run/Abs/abs/build --parallel +``` + +Typical build expectations: +- `libabs_kernel.so` builds +- `abs` builds +- `abs_sim` may also build if the simulator runtime is available + +If you need to replace the default `libabs_kernel.so` with one assembled from +an LLVM IR or LLVM BC path, build that artifact with +`ptoas-vpto-llvm-artifacts` and place it first in `LD_LIBRARY_PATH` when +running `./build/abs`. + +Important: +- the LLVM/VPTO path does not bypass EmitC testcase generation +- `build_llvm_ir_kernel_so.sh` assumes the testcase was already generated from + the EmitC export and reuses its host wrapper/build artifacts + +### 4. Generate golden inputs + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 ./golden.py +``` + +Expected files: +- `v1.bin` +- `v2.bin` + +For the generated `Abs` testcase, `golden.py` does not emit `golden_v2.bin`, +but `compare.py` expects it. Build the oracle explicitly from the input: + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 - <<'PY' +import numpy as np +v1 = np.fromfile('v1.bin', dtype=np.float32) +np.abs(v1).astype(np.float32).tofile('golden_v2.bin') +PY +``` + +Expected additional file: +- `golden_v2.bin` + +## Running + +### NPU run + +Only attempt this on a shell that can actually see `/dev/davinci*`. + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cd /tmp/ptoas-npu-validation-run/Abs/abs +LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs +``` + +For the repo's automated host-validation flow, prefer the script's default +remote runner: + +```bash +HOST_RUNNER='ssh root@localhost' +``` + +This is already the default in `run_host_npu_validation.sh` / +`run_host_npu_validation_case.sh`, and it is the preferred way to reach a root +context on the local machine when passwordless root SSH is already configured. + +Use that path first instead of assuming `sudo` is available or passwordless. + +If you are not using the repo scripts and your environment explicitly supports +`sudo`, you may still retry manually with: + +```bash +sudo bash -lc ' + cd /tmp/ptoas-npu-validation-run/Abs/abs + source /usr/local/Ascend/cann/set_env.sh >/dev/null 2>&1 + LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs +' +``` + +Observed runtime result on this machine for the `Abs` testcase: +- normal user run failed at `aclrtSetDevice(0)` with `507033` +- root-context execution is expected to go through the script default + `ssh root@localhost` path when available +- `python3 ./compare.py` then reported `[INFO] compare passed` + +Observed runtime result on this machine for the VPTO LLVM-path host validation +of `PyPTOIRParser/paged_attention_example_kernel_online_update`: +- `test/npu_validation/scripts/run_host_npu_validation.sh` passed end-to-end +- the replacement kernel library from `build_llvm_ir_kernel_so.sh` was loaded + successfully +- `compare.py` reported `[INFO] compare passed` +- during the LLVM artifact export step, `runop.sh` returned non-zero because + `paged_attention_example_kernel_softmax_prepare` failed in the same sample + batch, but the requested `online_update` LLVM IR was still generated and the + validation flow remained valid + +### Simulator run + +If `abs_sim` links successfully, run it with simulator libraries in `LD_LIBRARY_PATH`. + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cd /tmp/ptoas-npu-validation-run/Abs/abs +LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/aarch64-linux/simulator/dav_3510/lib:${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs_sim +``` + +Treat simulator execution as optional. Depending on the local CANN install, the +simulator binary may link successfully but still fail at runtime due to missing +simulator services or runtime symbols. + +## Compare + +After generating `golden_v2.bin` and running the NPU binary, compare with: + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 ./compare.py +``` + +Expected success output: +- `[INFO] compare passed` + +## Known Failure Modes + +- `generate_testcase.py` fails because the input is not a PTOAS EmitC `*-pto.cpp` kernel +- configure fails because `PTO_ISA_ROOT` is unset or points to the wrong checkout +- `abs_sim` fails to link because simulator runtime symbols are missing +- `./build/abs` fails at `aclInit(nullptr)` because the shell does not have usable Ascend runtime access +- non-`sudo` `./build/abs` fails at `aclrtSetDevice(0)` with `507033`, meaning the user context sees the device but cannot open a usable runtime context +- `compare.py` reports `golden_v2.bin` missing because the testcase generation did not create it automatically + +## Reporting Back + +When you use this skill, report: +- the generated testcase directory +- whether `libabs_kernel.so`, `abs`, and `abs_sim` built +- whether `golden.py` generated input bins and whether `golden_v2.bin` had to be created explicitly +- whether NPU execution worked directly or required elevated privileges +- whether `compare.py` passed +- the first concrete blocker for NPU or simulator execution diff --git a/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md b/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md new file mode 100644 index 000000000..d24d19c85 --- /dev/null +++ b/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md @@ -0,0 +1,319 @@ +--- +name: "ptoas-vpto-llvm-artifacts" +description: "Guide the PTOAS VPTO compile-and-link workflow: inspect VPTO MLIR, export LLVM IR or LLVM bitcode, validate the Bisheng handoff, and assemble device objects, fat objects, or shared kernel libraries. Use when the user asks how to build, export, compile, or link VPTO LLVM-path artifacts for A5." +--- + +# PTOAS VPTO LLVM Artifacts + +Use this skill when the task is specifically about: +- printing or inspecting VPTO intermediate MLIR +- exporting PTOAS A5 kernels as LLVM IR or LLVM bitcode through the VPTO backend +- checking whether the export is textual LLVM IR or real LLVM bitcode +- compiling the exported artifact with `bisheng` +- assembling a device object, fat relocatable object, or shared kernel library from the LLVM path +- helping with an "LLVM IR path build", "LLVM IR path compile", or "VPTO MLIR" request + +This skill answers: +- how to build or export the artifact +- how to hand the artifact to Bisheng +- how to continue from `.ll` / `.bc` to `.o` / `fatobj` / `.so` +- where each stage output is written + +This skill does not answer: +- which `llvm.hivm.*` intrinsic a VPTO op should lower to +- what the authoritative intrinsic name or operand contract is +- whether the repo-local emitter guessed the wrong LLVM IR form + +Those questions belong to `pto-a5-installed-impl-trace`. + +## Strong Rule + +Treat this skill as a compile-and-link workflow guide, not as the authority for +discovering intrinsic mappings. If the task turns into "what should this VPTO +op lower to" or "is this `llvm.hivm.*` form correct", switch to +`pto-a5-installed-impl-trace`. + +This is not the primary entry point for: +- generating `test/npu_validation` testcases +- running on hardware, handling `aclrtSetDevice`, or deciding whether `sudo` is needed +- `golden.py` / `compare.py` result checks +- discovering the authoritative LLVM IR shape for a VPTO op + +If the end goal is runtime validation, use `ptoas-npu-validation-a5` as the main +skill and call this skill only when that flow needs a custom LLVM IR or LLVM BC +kernel artifact. + +## Preconditions + +Before using this path, make sure: +- `ptoas` is already built in `./build` +- `bisheng` is available through CANN `set_env.sh` +- `env.sh` can be sourced from the repo root +- for the fatobj path, you already have a generated testcase directory that + contains a wrapper source such as `abs_kernel.cpp` and a built `launch.cpp.o` + +Load the repo environment before running examples: + +```bash +set +u +source env.sh +set -u +``` + +Use the `set +u` form when the caller shell has `set -u`, because `env.sh` +appends to variables such as `PYTHONPATH` and `LD_LIBRARY_PATH`. + +## Inspect VPTO MLIR + +Use this when you need to look at the VPTO-stage IR before deciding whether to +continue to textual LLVM IR, LLVM bitcode, or the full artifact assembly flow. + +Canonical flag: + +```bash +--vpto-print-ir +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-ir \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-print-ir' \ +./test/samples/runop.sh -t Abs +``` + +Use this output to: +- confirm the lowering has reached the VPTO dialect you expect +- inspect whether a transformation issue appears before LLVM export +- compare the VPTO MLIR path against the later LLVM IR or bitcode output + +## Export Paths + +### LLVM bitcode export + +Use: + +```bash +--pto-backend=vpto --vpto-emit-hivm-bc +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-hivm-bc \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-emit-hivm-bc' \ +./test/samples/runop.sh -t Abs +``` + +Typical outputs: +- `/tmp/ptoas-vpto-hivm-bc/Abs/abs-pto-ir.pto` +- `/tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp` + +Important: +- the payload is written to `*-pto.cpp` even in bitcode mode +- that file is LLVM bitcode, not C++ source + +Bitcode checks: + +```bash +file /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp +xxd -l 16 /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp +"$LLVM_ROOT/bin/llvm-dis" /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp -o - | sed -n '1,80p' +``` + +Expected signs: +- `file` reports `LLVM IR bitcode` +- the header starts with `42 43 c0 de` +- `llvm-dis` shows HiVM/LLVM content + +### Textual LLVM IR export + +Use: + +```bash +--pto-backend=vpto --vpto-emit-hivm-llvm +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-hivm-llvm \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-emit-hivm-llvm' \ +./test/samples/runop.sh -t Abs +``` + +Typical output: +- `/tmp/ptoas-vpto-hivm-llvm/Abs/abs-pto.cpp` + +Important: +- despite the `.cpp` suffix, this file is textual LLVM IR +- compile it with `-x ir` + +Suggested progression: +- start with `--vpto-print-ir` when the user wants the intermediate VPTO form +- use `--vpto-emit-hivm-llvm` when the user wants textual LLVM IR +- use `--vpto-emit-hivm-bc` when the user wants real LLVM bitcode + +## Compile The Export With Bisheng + +Load the CANN environment first: + +```bash +source /usr/local/Ascend/cann/set_env.sh +``` + +### Compile bitcode to a device object + +Preferred: + +```bash +bisheng \ + --target=hiipu64-hisilicon-cce \ + -march=dav-c310-vec \ + --cce-aicore-arch=dav-c310-vec \ + --cce-aicore-only \ + -O2 \ + -c -x ir /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp \ + -o /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.o +``` + +Alternative: +- copy or rename the payload to `.bc` +- compile without relying on the misleading `.cpp` suffix + +### Compile textual LLVM IR to a device object + +```bash +bisheng \ + --target=hiipu64-hisilicon-cce \ + -march=dav-c310-vec \ + --cce-aicore-arch=dav-c310-vec \ + --cce-aicore-only \ + -O2 \ + -c -x ir /tmp/ptoas-vpto-hivm-llvm/Abs/abs-pto.cpp \ + -o /tmp/abs_ir_path_artifacts/kernel_from_llvm_ir.o +``` + +Checks: +- keep `-march` and `--cce-aicore-arch` aligned with the intended testcase arch +- for the LLVM IR path, the resulting object should not retain unresolved + `llvm.hivm.*` symbols + +## If You Need The Real Compiler-Expected Intrinsic Shape + +This is outside the main purpose of this skill. + +When a hand-written LLVM IR path fails in instruction selection or appears to +miscompile, use this trace order: + +1. confirm the installed PTO wrapper path first with `pto-a5-installed-impl-trace` +2. generate the normal testcase kernel source through the working emitc path +3. inspect testcase compile flags from: + - `/build/CMakeFiles/.dir/flags.make` + - `/build/CMakeFiles/.dir/build.make` +4. rerun that same `bisheng` compile with `-v` and `-save-temps` +5. inspect: + - `*.ccei` to confirm the wrapper builtin sequence + - `strings *.bc | rg 'llvm.hivm\\.'` to see which HIVM intrinsics survive +6. if builtin names still are not enough, extract the exact frontend-produced + LLVM IR by replaying the `cc1` invocation from `-v` with `-emit-llvm -S` + +Use this when you need to answer questions such as: +- is the intrinsic name correct but the mask form wrong +- did the compiler expect a `plt/pset` result instead of a literal mask +- is the LLVM IR path missing hidden frontend-generated structure or attrs + +This is the preferred way to align repo-local LLVM emission with the real +compiler contract. + +## Assemble Fat Objects And Shared Libraries + +Use this only when the validation flow needs a replacement kernel library built +from the LLVM path. The canonical example below uses the generated `Abs` +testcase, but the pattern is the same for other testcases: take the testcase +wrapper source, embed the device object, pack it with `cce-ld`, then link the +shared kernel library. + +Required testcase artifacts: +- a wrapper source such as `/tmp/ptoas-npu-validation-run/Abs/abs/abs_kernel.cpp` +- a built launch object such as + `/tmp/ptoas-npu-validation-run/Abs/abs/build/CMakeFiles/abs_kernel.dir/launch.cpp.o` + +### 1. Build the host stub object + +```bash +/usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/bin/bisheng -cc1 \ + -triple aarch64-unknown-linux-gnu \ + -fcce-is-host \ + -fcce-fatobj-compile \ + -fcce-include-aibinary /tmp/abs_ir_path_artifacts/kernel_from_llvm_ir.o \ + -fcce-device-module-id a55ab1efc0defeed \ + -fcce-aicore-arch dav-c310-vec \ + -x cce /tmp/ptoas-npu-validation-run/Abs/abs/abs_kernel.cpp \ + -o /tmp/abs_ir_path_artifacts/kernel_host_stub.o +``` + +### 2. Pack the fat relocatable object + +```bash +/usr/local/Ascend/cann-9.0.0/bin/cce-ld \ + /usr/local/Ascend/cann-9.0.0/bin/ld.lld \ + -x \ + -cce-lite-bin-module-id a55ab1efc0defeed \ + -cce-aicore-arch=dav-c310-vec \ + -r \ + -o /tmp/abs_ir_path_artifacts/kernel_fat.o \ + -cce-stub-dir /usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/lib/clang/15.0.5/include/cce_stub \ + -cce-install-dir /usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/bin \ + -cce-inputs-number 1 \ + /tmp/abs_ir_path_artifacts/kernel_host_stub.o +``` + +The module id must match between: +- `-fcce-device-module-id` +- `-cce-lite-bin-module-id` + +### 3. Link the shared kernel library + +```bash +mkdir -p /tmp/abs_ir_path_artifacts/link_try +cd /tmp/abs_ir_path_artifacts/link_try +/usr/local/Ascend/cann-9.0.0/bin/bisheng \ + -fPIC -s -Wl,-z,relro -Wl,-z,now --cce-fatobj-link \ + -shared -Wl,-soname,libabs_kernel.so \ + -o libabs_kernel.so \ + /tmp/abs_ir_path_artifacts/kernel_fat.o \ + /tmp/ptoas-npu-validation-run/Abs/abs/build/CMakeFiles/abs_kernel.dir/launch.cpp.o +``` + +This skill stops at producing the replacement artifact. To run the testcase +with that library and validate outputs, switch back to `ptoas-npu-validation-a5`. + +## Failure Modes + +Report the first concrete blocker: +- `--vpto-print-ir`, `--vpto-emit-hivm-bc`, or `--vpto-emit-hivm-llvm` used without `--pto-backend=vpto` +- `--vpto-emit-hivm-bc` or `--vpto-emit-hivm-llvm` used without `--pto-backend=vpto` +- `env.sh` was not sourced, or failed under `set -u` +- `bisheng` was not found or CANN environment was not loaded +- a bitcode payload was treated as source because it kept a misleading suffix +- the testcase wrapper or `launch.cpp.o` is missing for the fatobj path +- the module ids used for stub creation and `cce-ld` packing do not match + +## Reporting Back + +When you use this skill, report: +- whether the user-facing artifact of interest was VPTO MLIR, textual LLVM IR, or LLVM bitcode +- the exact `ptoas` flags used +- whether the export was VPTO MLIR, LLVM bitcode, or textual LLVM IR +- the exact output path that contains the exported payload +- whether `llvm-dis`, `file`, or direct inspection confirmed the payload type +- whether `bisheng` produced a device object +- whether the flow also produced a fat relocatable object or shared kernel library +- which step was the first blocker, if the full artifact chain did not complete diff --git a/.gitignore b/.gitignore index 44c61b02a..4fbeb8f5e 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,11 @@ dist/ /remote_npu_validation_results*.tsv /npu_validation/ test/samples/**/npu_validation/ +!test/samples/**/npu_validation/ +test/samples/**/npu_validation/* +!test/samples/**/npu_validation/golden.py +!test/samples/**/npu_validation/*/ +!test/samples/**/npu_validation/*/golden.py /tmp_gen* # IDE/editor diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index 35f737e91..39d969fa1 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -99,15 +99,20 @@ number of scalar FP4 elements. Operation support is still opt-in. Defining the type in PTO IR does not by itself imply that any particular operation accepts it. -### 2.2 `!pto.ptr` +### 2.2 `!pto.ptr` -A pointer to global memory. +A typed pointer. `memorySpace` is optional and defaults to `gm`. | Parameter | Type | Description | |-----------|------|-------------| | `elementType` | `element-type(i1/i8/i16/i32/f16/f32/bf16...)` | Element type pointed to | +| `memorySpace` | `gm` or `ub` | Pointer address space alias (`gm` -> global memory, `ub` -> vector/UB memory) | -**Syntax:** `!pto.ptr` +**Syntax:** `!pto.ptr` or `!pto.ptr` + +Pointer conversions are modeled explicitly with [`pto.castptr`](#ptocastptr). +Between two `!pto.ptr` types, casts are only legal when both pointers stay in +the same PTO memory space. --- @@ -334,6 +339,38 @@ result = ptr + offset // offset is in elements, not bytes %ptr_off = pto.addptr %base, %offset : !pto.ptr -> !pto.ptr ``` +##### `pto.castptr` - Explicit Pointer Cast + +**Summary:** Performs an explicit cast between integer addresses and `!pto.ptr`, +or between two `!pto.ptr` types. + +**Semantics:** + +```mlir +%p0 = pto.castptr %addr : i64 -> !pto.ptr +%p1 = pto.castptr %p0 : !pto.ptr -> !pto.ptr +%addr2 = pto.castptr %p1 : !pto.ptr -> i64 +``` + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `input` | `integer` or `!pto.ptr<...>` | Source value to cast | + +**Results:** `integer` or `!pto.ptr<...>` + +**Constraints & Verification:** + +- Integer-to-integer casts are rejected; use normal integer cast ops instead +- Pointer-to-pointer casts are only legal when source and destination stay in + the same PTO memory space (`gm` or `ub`) +- The operation is pure (no side effects) + +**Hardware Mapping:** + +- No hardware pipeline (representation conversion only) + ##### `pto.make_tensor_view` - Create Tensor View **Summary:** Constructs a global tensor view from a pointer, declaring the physical base and strides (no allocation, no data movement). diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md new file mode 100644 index 000000000..3040ec3d1 --- /dev/null +++ b/docs/isa/01-pipeline-sync.md @@ -0,0 +1,462 @@ +# 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +## Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +## Intra-Core Sync Patterns & Examples + +### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +#### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +#### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +#### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +## Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | +| Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | None | +| Post-loop teardown | `wait_flag` to drain all primed signals | None | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +## Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` diff --git a/docs/isa/02-dma-copy.md b/docs/isa/02-dma-copy.md new file mode 100644 index 000000000..8d867af08 --- /dev/null +++ b/docs/isa/02-dma-copy.md @@ -0,0 +1,602 @@ +# 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](01-pipeline-sync.md)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +## Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +## Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +## DMA Transfer Execution + +### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +## Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +## Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +## Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +## Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +## Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` diff --git a/docs/isa/03-vector-load-store.md b/docs/isa/03-vector-load-store.md new file mode 100644 index 000000000..bb840e44b --- /dev/null +++ b/docs/isa/03-vector-load-store.md @@ -0,0 +1,595 @@ +# 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +## Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +## Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +## Contiguous Loads + +### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +## Dual Loads (Deinterleave) + +### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +## Gather (Indexed) Loads + +### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +## Contiguous Stores + +### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | +| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +## Dual Stores (Interleave) + +### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | +| `INTLV` | `b8`, `b16`, `b32` | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +## Scatter (Indexed) Stores + +### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +## Alignment State Stores + +### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +## Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. diff --git a/docs/isa/04-predicate-load-store.md b/docs/isa/04-predicate-load-store.md new file mode 100644 index 000000000..9c3bed11d --- /dev/null +++ b/docs/isa/04-predicate-load-store.md @@ -0,0 +1,135 @@ +# 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +## Predicate Loads + +### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +## Predicate Stores + +### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +## Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/05-materialization-predicate.md b/docs/isa/05-materialization-predicate.md new file mode 100644 index 000000000..e6ee34975 --- /dev/null +++ b/docs/isa/05-materialization-predicate.md @@ -0,0 +1,322 @@ +# 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +## Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +## Scalar Materialization + +### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +## Predicate Generation + +### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +## Predicate Pack/Unpack + +### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +## Predicate Logical Ops + +### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +## Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/06-unary-vector-ops.md b/docs/isa/06-unary-vector-ops.md new file mode 100644 index 000000000..2706ac39b --- /dev/null +++ b/docs/isa/06-unary-vector-ops.md @@ -0,0 +1,172 @@ +# 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +## Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +## Arithmetic + +### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +## Transcendental + +### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +## Activation + +### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +## Bitwise + +### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +## Movement + +## Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/07-binary-vector-ops.md b/docs/isa/07-binary-vector-ops.md new file mode 100644 index 000000000..0ab4ae554 --- /dev/null +++ b/docs/isa/07-binary-vector-ops.md @@ -0,0 +1,293 @@ +# 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +## Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +## Arithmetic + +### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +## Bitwise + +### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +## Shift + +### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +## Carry Operations + +### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +## Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/08-vec-scalar-ops.md b/docs/isa/08-vec-scalar-ops.md new file mode 100644 index 000000000..9ef60d3cb --- /dev/null +++ b/docs/isa/08-vec-scalar-ops.md @@ -0,0 +1,236 @@ +# 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +## Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +## Arithmetic + +### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +## Shift + +### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +## Carry Operations + +### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +## Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md new file mode 100644 index 000000000..c3674f523 --- /dev/null +++ b/docs/isa/09-conversion-ops.md @@ -0,0 +1,257 @@ +# 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +## Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +## `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. `%result` + uses an integer element type, and the scalar `%index` type matches that + result element type. + +--- + +## `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. This is typically used in even/odd placement forms such +as `32 -> 16` or `16 -> 32` style conversions. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | + +--- + +### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +### A5 Supported Forms + +The forms below are expressed in PTO surface syntax. Source/target type +combinations not listed here should not currently be assumed to be supported +on A5. + +#### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +#### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +#### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +#### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | | Y | | + +--- + +### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +## `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, "RND" : !pto.vreg -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector and `RND` selects the + truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `O` is supported for avoiding + double-rounding errors during staged conversions. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, "R" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +## Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, "F" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/10-reduction-ops.md b/docs/isa/10-reduction-ops.md new file mode 100644 index 000000000..b2fb20894 --- /dev/null +++ b/docs/isa/10-reduction-ops.md @@ -0,0 +1,244 @@ +# 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +## Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +## Full Vector Reductions + +### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +## Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +## Prefix Operations + +### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +## Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/11-compare-select.md b/docs/isa/11-compare-select.md new file mode 100644 index 000000000..bc28f2fd1 --- /dev/null +++ b/docs/isa/11-compare-select.md @@ -0,0 +1,182 @@ +# 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +## Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +## Comparison Operations + +### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +## Selection Operations + +### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +## Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +## Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/12-data-rearrangement.md b/docs/isa/12-data-rearrangement.md new file mode 100644 index 000000000..359e7c306 --- /dev/null +++ b/docs/isa/12-data-rearrangement.md @@ -0,0 +1,230 @@ +# 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +## Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +## Interleave / Deinterleave + +### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +## Compress / Expand + +### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +## Pack / Unpack + +### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +## Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +## V2 Interleave Forms + +### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md new file mode 100644 index 000000000..731fa71b2 --- /dev/null +++ b/docs/isa/13-dsa-sfu-ops.md @@ -0,0 +1,229 @@ +# 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +## Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +## Fused Activation Ops (vreg→vreg) + +### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +## Fused Compute+Convert Ops + +### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + + +## Extended Arithmetic + +### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +## Index Generation + +### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +## Sorting Operations + +### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +## Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +## Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/14-shared-arith.md b/docs/isa/14-shared-arith.md new file mode 100644 index 000000000..6c703dc55 --- /dev/null +++ b/docs/isa/14-shared-arith.md @@ -0,0 +1,99 @@ +# 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +## Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +## Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +## Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +## Typical Patterns + +### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +## Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` diff --git a/docs/isa/15-shared-scf.md b/docs/isa/15-shared-scf.md new file mode 100644 index 000000000..12a637fd7 --- /dev/null +++ b/docs/isa/15-shared-scf.md @@ -0,0 +1,97 @@ +# 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +## Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +## Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +## Typical Patterns + +### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +## Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value diff --git a/docs/release/vpto-spec-v0.1.md b/docs/release/vpto-spec-v0.1.md new file mode 100644 index 000000000..ab7c5f6b4 --- /dev/null +++ b/docs/release/vpto-spec-v0.1.md @@ -0,0 +1,4883 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/u8 | 32 | 256 | +| i16/u16/f16/bf16 | 16 | 128 | +| i32/u32/f32 | 8 | 64 | +| i64/u64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +It is not a dedicated `pto` op. In the PTO micro Instruction, this scope is modeled as a specialized `scf.for` loop annotated with `llvm.loop.aivector_scope`. This gives the compiler a natural structural boundary for identifying the code block that must be lowered into a discrete VF hardware instruction sequence. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +scf.for %dummy = %c0 to %c1 step %c1 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} +``` + +### Example: Abs + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +scf.for %dummy = %c0 to %c1 step %c1 { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} {llvm.loop.aivector_scope} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### Core Types + +#### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `s8` / `u8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `s16` / `u16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `s32` / `u32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `s64` / `u64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | +| `f8e4m3` | 8 | FP8 (4-bit exponent, 3-bit mantissa) | +| `f8e5m2` | 8 | FP8 (5-bit exponent, 2-bit mantissa) | + +#### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +#### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +#### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through pointer construction, pointer arithmetic, structured control flow, and PTO memory ops: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +scf.for %arg2 = %c0 to %c1 step %c1 { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} {llvm.loop.aivector_scope} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +#### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit), not an integer vector. + +**Mask Granularity:** + +The mask is 256 bits in length, where each bit controls 1 byte of data. This means mask granularity varies by element type: + +| Element Type | Bits/Element | Mask Bits per Element | +|--------------|--------------|----------------------| +| `f32`/`i32` | 32 | 4 bits | +| `f16`/`bf16`/`i16` | 16 | 2 bits | +| `f8`/`i8` | 8 | 1 bit | + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out, %base_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/u8 +// N = 128 for i16/u16/f16/bf16 +// N = 64 for i32/u32/f32 +// N = 32 for i64/u64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"ROUND_MODE"` | Rounding mode: `ROUND_R \| ROUND_A \| ROUND_F \| ROUND_C \| ROUND_Z` | +| `"SAT_MODE"` | Saturation: `RS_ENABLE \| RS_DISABLE` | +| `"PART_MODE"` | Half selector: `PART_EVEN \| PART_ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldx2`, `pto.vgather2`, `pto.vsts`, `pto.vstx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 7 | `pto.plds`, `pto.pld`, `pto.pldi`, `pto.psts`, `pto.pst`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 9 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrec`, `pto.vrelu`, `pto.vnot`, `pto.vbcnt`, `pto.vcls` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 8 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 3 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 5 | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr`, `pto.vselrv2` | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 4 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 5 | `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +scf.for %dummy = %c0 to %c1 step %c1 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | +| Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | None | +| Post-loop teardown | `wait_flag` to drain all primed signals | None | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + +--- + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + +--- + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. + +**Distribution modes:** + +| Mode | Description | C Semantics | +|------|-------------|-------------| +| `NORM` | Contiguous 256B load | `dst[i] = UB[base + i * sizeof(T)]` | +| `BRC_B8/B16/B32` | Broadcast single element | `dst[i] = UB[base]` for all i | +| `US_B8/B16` | Upsample (duplicate each element) | `dst[2*i] = dst[2*i+1] = UB[base + i]` | +| `DS_B8/B16` | Downsample (every 2nd element) | `dst[i] = UB[base + 2*i]` | +| `UNPK_B8/B16/B32` | Unpack (zero-extend to wider type) | `dst_i32[i] = (uint32_t)UB_i16[base + 2*i]` | +| `SPLT4CHN_B8` | Split 4-channel (RGBA → R plane) | Extract every 4th byte | +| `SPLT2CHN_B8/B16` | Split 2-channel | Extract every 2nd element | +| `DINTLV_B32` | Deinterleave 32-bit | Even elements only | +| `BLK` | Block load | Blocked access pattern | + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out, %base_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align, !pto.ptr` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value, `%align_out` is the updated alignment + state, and `%base_out` is the post-update base pointer state exposed in SSA + form. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. Both the alignment state and the base address + advance across the stream, and the PTO micro Instruction representation exposes those updates as SSA results. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2, %ub2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldx2` + +- **syntax:** `%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + +**Distribution modes:** `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` + +```c +// DINTLV_B32: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +--- + +#### Strided Loads + +##### `pto.vsld` + +- **syntax:** `%result = pto.vsld %source[%offset], "STRIDE" : !pto.ptr -> !pto.vreg` +- **semantics:** Strided load with fixed stride pattern. +- **inputs:** + `%source` is the UB base pointer and `%offset` is the displacement encoded + with the selected fixed stride mode. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + This is a deprecated compatibility family. The selected stride token + determines which sub-elements are read from each source block. + +**Stride modes:** `STRIDE_S3_B16`, `STRIDE_S4_B64`, `STRIDE_S8_B32`, `STRIDE_S2_B64` + +--- + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %offset, %mask : !pto.ptr, i32, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer, `%offset` is the packed stride/control word, + and `%mask` controls which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + `%offset` is not a plain byte displacement; it encodes the block stride and + repeat pattern. If a block is masked off, the corresponding destination block + is zeroed and MUST NOT raise an address overflow exception for that block. + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Byte-granularity indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains per-block byte offsets, + and `%active_lanes` bounds the number of active gathered blocks. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a block gather, not a byte-per-lane gather. `%source` MUST be 32-byte + aligned, each participating offset MUST describe a 32-byte-aligned block, and + inactive blocks are zero-filled. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i]]; // byte-addressed +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. Narrowing/packing modes may only preserve a subset of the + source bits. Merge-channel modes reinterpret the source vector as channel + planes and interleave them on store. + +**Distribution modes:** + +| Mode | Description | C Semantics | +|------|-------------|-------------| +| `NORM_B8/B16/B32` | Contiguous store | `UB[base + i] = src[i]` | +| `PK_B16/B32` | Pack/narrowing store | `UB_i16[base + 2*i] = truncate_16(src_i32[i])` | +| `MRG4CHN_B8` | Merge 4 channels (R,G,B,A → RGBA) | Interleave 4 planes | +| `MRG2CHN_B8/B16` | Merge 2 channels | Interleave 2 planes | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstx2` + +- **syntax:** `pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. + +**Distribution modes:** `INTLV_B8`, `INTLV_B16`, `INTLV_B32` + +```c +// INTLV_B32: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +--- + +#### Strided Stores + +##### `pto.vsst` + +- **syntax:** `pto.vsst %value, %dest[%offset], "STRIDE" : !pto.vreg, !pto.ptr` +- **semantics:** Strided store with fixed stride pattern. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, and `%offset` + / `STRIDE` select the fixed strided layout. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + This is a deprecated compatibility family. The stride token, not the vector + lane number alone, determines which destination elements are written. + +--- + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %offset, %mask : !pto.vreg, !pto.ptr, i32, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the packed stride/control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + `%offset` is a control word, not a plain byte displacement. This is a + deprecated compatibility family kept for surface coverage. + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vsta` + +- **syntax:** `pto.vsta %value, %dest[%offset] : !pto.align, !pto.ptr, index` +- **semantics:** Flush alignment state to memory. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base pointer, + and `%offset` is the flush displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The flush address MUST match the post-updated address expected by the + preceding unaligned-store stream. After the flush, the corresponding store + alignment state is consumed. + +--- + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family uses the same buffered-tail semantics as `pto.vsta` but keeps the + scalar-offset form explicit. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. + +--- + +##### `pto.vstu` +- **syntax:** `%align_out, %base_out = pto.vstu %align_in, %base_in, %value, %dest, %mode : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, index -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with explicit threaded alignment/base state. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%base_in` is the current + stream base, `%value` is the vector to store, `%dest` is the UB base pointer, + and `%mode` selects the post-update behavior. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the + post-update base pointer state. +- **constraints and limitations:** + This op models a stateful unaligned-store sequence in SSA form. A final + `pto.vsta` / `pto.vstas` / `pto.vstar` is still required to flush the trailing + buffered bytes. + +--- + +##### `pto.vstus` +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %base_in, %value, %dest, %offset : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, i32 -> !pto.align, !pto.ptr` +- **semantics:** Scalar-offset unaligned store with threaded state. +- **inputs:** + Same roles as `pto.vstu`, but `%offset` is provided explicitly as the scalar + displacement. +- **outputs:** + Updated alignment state and base state. +- **constraints and limitations:** + The same final flush requirement and state-threading constraints as + `pto.vstu` apply. + +--- + +##### `pto.vstur` +- **syntax:** `%align_out = pto.vstur %align_in, %value, %dest : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Register-update unaligned store form. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%dest` is the UB base pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This op updates only the residual alignment state. A matching flush op is + still required to emit the trailing bytes. + +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Flush alignment state with scalar offset. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstu` + +- **syntax:** `%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, "MODE" : !pto.align, index, !pto.vreg, !pto.ptr -> !pto.align, index` +- **semantics:** Unaligned store with align + offset state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset_in` is the current + logical byte/element displacement, `%value` is the vector being stored, and + `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated alignment/tail state and `%offset_out` is the + next offset after applying the selected post-update rule. +- **constraints and limitations:** + The alignment state MUST be threaded in program order. A terminating flush + form such as `pto.vstar`/`pto.vstas` is still required to commit the buffered + tail bytes. + +**Mode tokens:** `POST_UPDATE`, `NO_POST_UPDATE` + +--- + +##### `pto.vstus` + +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %offset, %value, %base, "MODE" : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with scalar offset and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the next + base pointer when the lowering chooses a post-update form. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width and update mode MUST match the selected form, and a later + flush op is still required. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + This form exposes only the evolving state; it does not by itself guarantee + that all buffered bytes have reached memory. A compatible final flush is still + required unless the surrounding sequence is known to be complete. + +--- + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.mask` +- **semantics:** Load predicate register with scalar offset. + +**Distribution modes:** `NORM`, `US`, `DS` + +**Example:** +```mlir +%mask = pto.plds %ub[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask +``` + +--- + +##### `pto.pld` + +- **syntax:** `%result = pto.pld %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with areg offset. + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source, %offset, "DIST" : !pto.ptr, i32 -> !pto.mask` +- **semantics:** Load predicate register with immediate offset. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset] : !pto.mask, !pto.ptr` +- **semantics:** Store predicate register with scalar offset. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0] : !pto.mask, !pto.ptr +``` + +--- + +##### `pto.pst` + +- **syntax:** `pto.pst %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with areg offset. + +**Distribution modes:** `NORM`, `PK` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest, %offset, "DIST" : !pto.mask, !pto.ptr, i32` +- **semantics:** Store predicate register with immediate offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align state update. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0] : !pto.mask, !pto.ptr + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input {position = "POSITION"} : T|!pto.vreg -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source element or scalar position is duplicated. The + current PTO micro Instruction representation models that selector as an attribute rather than a + separate operand. + +```c +for (int i = 0; i < N; i++) + dst[i] = input_scalar_or_element; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate predicate from pattern. + +**Patterns:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate tail mask — first N lanes active. + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate predicate state together with updated scalar state. +``` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +**Part tokens:** `LOWER`, `HIGHER` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] & src1[i]; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] | src1[i]; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] ^ src1[i]; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = ~src[i]; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +#### Predicate Movement + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src[i]; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +##### `pto.pdintlv_b8` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate deinterleave. + +--- + +##### `pto.pintlv_b16` + +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate interleave. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. Integer + overflow on the most-negative signed value follows the target-defined + behavior. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vrsqrt` + +- **syntax:** `%result = pto.vrsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds reciprocal-square-root values per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +##### `pto.vrec` + +- **syntax:** `%result = pto.vrec %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the reciprocal per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vbcnt` + +- **syntax:** `%result = pto.vbcnt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = __builtin_popcount(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the population count for each active lane. +- **constraints and limitations:** Integer element types only. The count is + over the source element width, not over the full vector register. + +--- + +##### `pto.vcls` + +- **syntax:** `%result = pto.vcls %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = count_leading_sign_bits(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the leading-sign-bit count per active lane. +- **constraints and limitations:** Integer element types only. This operation is + sign-aware, so signed interpretation matters. + +--- + +#### Movement + +##### `pto.vmov` + +- **syntax:** `%result = pto.vmov %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Vector register copy. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is a copy of the source vector. +- **constraints and limitations:** Predicated `pto.vmov` behaves like a masked + copy, while the unpredicated form behaves like a full-register copy. + +--- + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Reciprocal for division +%sum_rcp = pto.vrec %sum, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/u8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/u8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, it SHOULD be treated as an unsigned integer + operation. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + borrow[i] = (src0[i] < src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%borrow` marks lanes + that borrowed. +- **constraints and limitations:** This operation SHOULD be treated as an + unsigned 32-bit carry-chain family unless and until the verifier states + otherwise. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + +--- + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each active lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Inactive lanes follow the predication + behavior defined for this family. On the current surface, inactive lanes are + treated as zeroing lanes. + +--- + +##### `pto.vsubs` + +- **syntax:** `%result = pto.vsubs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] - scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Integer or floating-point legality depends on + the selected type family in lowering. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common numeric cases. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vands` + +- **syntax:** `%result = pto.vands %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] & scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vors` + +- **syntax:** `%result = pto.vors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] | scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxors` + +- **syntax:** `%result = pto.vxors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] ^ scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **constraints and limitations:** This is the scalar-extended carry-chain + family. Treat it as an unsigned integer operation unless the verifier states a + wider legal domain. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow-in and borrow-out. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - borrow_in[i]; + borrow_out[i] = (src0[i] < src1[i] + borrow_in[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%borrow_in` is the + incoming borrow predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%borrow` is the + borrow-out predicate. +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and SHOULD be treated as an unsigned integer operation. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.vreg<64xi32> +``` + +--- + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%result` is the destination vector register value. +- `round_mode`, `sat`, and `part` control rounding, saturation, and lane-part + selection in attribute form. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input {round_mode = "ROUND_MODE", sat = "SAT_MODE", part = "PART_MODE"} : !pto.vreg -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + dst[i] = convert(src[i], T0, T1, round_mode); +``` + +- **inputs:** + `%input` is the source vector; attributes select rounding, saturation, and + even/odd placement when the conversion changes width. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. `PART_EVEN` / + `PART_ODD` is only meaningful for width-changing forms that pack two source + streams into one destination register. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `ROUND_R` | Round to nearest, ties to even (default) | +| `ROUND_A` | Round away from zero | +| `ROUND_F` | Round toward negative infinity (floor) | +| `ROUND_C` | Round toward positive infinity (ceil) | +| `ROUND_Z` | Round toward zero (truncate) | +| `ROUND_O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `RS_ENABLE` | Saturate on overflow | +| `RS_DISABLE` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes (for width-changing conversions) + +| Mode | Description | +|------|-------------| +| `PART_EVEN` | Output to even-indexed lanes | +| `PART_ODD` | Output to odd-indexed lanes | + +--- + +##### A5 Supported Conversions + +**Float-Float (vcvtff):** +- f32 ↔ f16 +- f32 ↔ bf16 +- f16 ↔ bf16 + +**Float-Int (vcvtfi):** +- f16 → i16, f16 → i32 +- f32 → i16, f32 → i32 +- bf16 → i32 + +**Int-Float (vcvtif):** +- i16 → f16 +- i32 → f32 + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_ODD"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, "ROUND_MODE" : !pto.vreg -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], round_mode); +``` + +- **inputs:** + `%input` is the floating-point source vector and `ROUND_MODE` selects the + truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `ROUND_O` is supported for avoiding + double-rounding errors during staged conversions. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, "ROUND_R" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled {round_mode = "ROUND_R", sat = "RS_ENABLE"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input {round_mode = "ROUND_R"} + : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, "ROUND_F" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +``` + +--- + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. Result value + index in lane 0. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst_val[0] = mx; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** This family computes both the extremum and + location information, but the exact packing of that information into the + destination vector depends on the chosen form. If all predicate bits are zero, + the result follows the zero-filled convention. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. Result value + index in lane 0. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst_val[0] = mn; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** As with `pto.vcmax`, the exact value/index + packing depends on the chosen form and MUST be preserved consistently. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; // reversed from vsel +``` + +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This family preserves reversed-select + semantics. If the concrete lowering uses an implicit predicate source, that + predicate source MUST be documented by the surrounding IR pattern. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Slide / Shift + +##### `pto.vslide` + +- **syntax:** `%result = pto.vslide %src0, %src1, %amt : !pto.vreg, !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Concatenate two vectors and extract N-element window at offset. + +```c +// Conceptually: tmp[0..2N-1] = {src1, src0} +// dst[i] = tmp[amt + i] +if (amt >= 0) + for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src0[i - amt] : src1[N - amt + i]; +``` + +**Use case:** Sliding window operations, shift register patterns. + +- **inputs:** `%src0` and `%src1` provide the concatenated source window and + `%amt` selects the extraction offset. +- **outputs:** `%result` is the extracted destination window. +- **constraints and limitations:** `pto.vslide` operates on the logical + concatenation of `%src1` and `%src0`. The source order and extraction offset + MUST be preserved exactly. + +--- + +##### `pto.vshift` + +- **syntax:** `%result = pto.vshift %src, %amt : !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Single-source slide (shift with zero fill). + +```c +for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src[i - amt] : 0; +``` + +- **inputs:** `%src` is the source vector and `%amt` is the slide amount. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** This surface represents the single-source + slide/shift family. Zero-fill versus other fill behavior MUST match the + selected form. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %mask : !pto.mask -> !pto.vreg` +- **semantics:** Expand — scatter front elements to active positions. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src_front[j++]; + else dst[i] = 0; +``` + +- **inputs:** `%mask` is the expansion/placement predicate. +- **outputs:** `%result` is the expanded vector image. +- **constraints and limitations:** The source-front stream is implicit in the + current surface. Lane placement for active and inactive positions MUST be + preserved exactly. + +--- + +#### Permutation + +##### `pto.vperm` + +- **syntax:** `%result = pto.vperm %src, %index : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** In-register permute (table lookup). **Not** memory gather. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[index[i] % N]; +``` + +**Note:** This operates on register contents, unlike `pto.vgather2` which reads from UB memory. + +- **inputs:** `%src` is the source vector and `%index` supplies per-lane source + indices. +- **outputs:** `%result` is the permuted vector. +- **constraints and limitations:** This is an in-register permutation family. + `%index` values outside the legal range follow the wrap/clamp behavior of the + selected form. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Register select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; +``` + +- **inputs:** `%src0` and `%src1` are source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src0, %src1, %part : !pto.vreg, !pto.vreg, index -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrowing pack — two wide vectors to one narrow vector. + +```c +// e.g., two vreg<64xi32> → one vreg<128xi16> +for (int i = 0; i < N; i++) { + dst[i] = truncate(src0[i]); + dst[N + i] = truncate(src1[i]); +} +``` + +- **inputs:** `%src0` and `%src1` are wide source vectors and `%part` selects + the packing submode. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion. Source + values that do not fit the destination width follow the truncation semantics + of the selected packing mode. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Sliding window sum +%prev_window = pto.vslide %curr, %prev, %c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, i16 -> !pto.vreg<64xf32> +%window_sum = pto.vadd %curr, %prev_window, %all + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide0_i32, %wide1_i32, %c0 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, index -> !pto.vreg<128xi16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +--- + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. +- **outputs:** `%result` is the fused `exp(input - max)` vector. +- **constraints and limitations:** Floating-point element types only. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaddrelu` + +- **syntax:** `%result = pto.vaddrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused add + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] + src1[i], 0); +``` + +- **inputs:** `%lhs` and `%rhs` are the two addends. +- **outputs:** `%result` is the fused add-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vsubrelu` + +- **syntax:** `%result = pto.vsubrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused sub + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] - src1[i], 0); +``` + +- **inputs:** `%lhs` is the minuend and `%rhs` is the subtrahend. +- **outputs:** `%result` is the fused sub-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaddreluconv` + +- **syntax:** `%result = pto.vaddreluconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused add + ReLU + type conversion (HW fusion). + +```c +// f32→f16 variant: +for (int i = 0; i < 64; i++) + dst_f16[i] = f32_to_f16(max(src0_f32[i] + src1_f32[i], 0)); + +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(max(src0_f16[i] + src1_f16[i], 0)); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused add/ReLU/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. Rounding, saturation, and packing rules follow the + semantics of this fused operation, not an arbitrary sequence of standalone + ops. + +--- + +##### `pto.vmulconv` + +- **syntax:** `%result = pto.vmulconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused mul + type conversion (HW fusion). + +```c +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(src0_f16[i] * src1_f16[i]); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused mul/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/u32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### UB-to-UB Operations + +##### `pto.vtranspose` + +- **syntax:** `pto.vtranspose %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** UB-to-UB transpose operation (not vreg-to-vreg). + +**Note:** This operates on UB memory directly, not on vector registers. + +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is not a `vreg -> vreg` op even though + it lives in the `pto.v*` namespace. Its correctness depends on the control + word and UB layout contract. + +--- + +#### Sorting Operations + +##### `pto.vsort32` + +- **syntax:** `pto.vsort32 %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** Sort 32 elements in UB. +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is a UB-to-UB accelerator helper, not a + pure vector-register op. + +--- + +##### `pto.vmrgsort` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr x4, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. This page uses the shorter mnemonic + `pto.vmrgsort`, while the current implementation summary still refers to + `pto.vmrgsort4`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Fused residual add + ReLU +%residual = pto.vaddrelu %conv_out, %skip_connection : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + +--- + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + +--- + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | also used for `__VEC_SCOPE__` dummy-loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- the `__VEC_SCOPE__` contract in PTO micro Instruction is modeled as a specialized `scf.for` annotated with `llvm.loop.aivector_scope` +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +--- + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `u8` | 8 | 256 | Signed/unsigned 8-bit integer | +| `i16` / `u16` | 16 | 128 | Signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `u32` | 32 | 64 | Signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `u64` | 64 | 32 | Signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Fused add + ReLU +%fused = pto.vaddrelu %a, %b : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC_*` dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_*` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/docs/release/vpto-spec-v0.2.md b/docs/release/vpto-spec-v0.2.md new file mode 100644 index 000000000..3c1e31419 --- /dev/null +++ b/docs/release/vpto-spec-v0.2.md @@ -0,0 +1,5072 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/u8 | 32 | 256 | +| i16/u16/f16/bf16 | 16 | 128 | +| i32/u32/f32 | 8 | 64 | +| i64/u64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +- `vreg`: `!pto.vreg` + Fixed-width VPTO vector type with total width exactly 256 bytes. +- `mask`: `!pto.mask` + Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. +- `align`: `!pto.align` +- `buf`: buffer-like LLVM pointer type accepted by the dialect +- `buf_like`: `memref<...>` or `!llvm.ptr` for stateless/predicate + `vld*/vst*` families +- `idx`: `index` +- `i32`: `i32` +- `i64`: `i64` + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `s8` / `u8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `s16` / `u16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `s32` / `u32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `s64` / `u64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | +| `f8e4m3` | 8 | FP8 (4-bit exponent, 3-bit mantissa) | +| `f8e5m2` | 8 | FP8 (5-bit exponent, 2-bit mantissa) | + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through pointer construction, pointer arithmetic, structured control flow, and PTO memory ops: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out, %base_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/u8 +// N = 128 for i16/u16/f16/bf16 +// N = 64 for i32/u32/f32 +// N = 32 for i64/u64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"ROUND_MODE"` | Rounding mode: `ROUND_R \| ROUND_A \| ROUND_F \| ROUND_C \| ROUND_Z` | +| `"SAT_MODE"` | Saturation: `RS_ENABLE \| RS_DISABLE` | +| `"PART_MODE"` | Half selector: `PART_EVEN \| PART_ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldx2`, `pto.vgather2`, `pto.vsts`, `pto.vstx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 7 | `pto.plds`, `pto.pld`, `pto.pldi`, `pto.psts`, `pto.pst`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 9 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrec`, `pto.vrelu`, `pto.vnot`, `pto.vbcnt`, `pto.vcls` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 8 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 3 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 5 | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr`, `pto.vselrv2` | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 4 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 5 | `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | +| Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | None | +| Post-loop teardown | `wait_flag` to drain all primed signals | None | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV_B32` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM_B32` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_*`** on **`RV_VSTI`** are **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV_B32` | `RV_VLDI` | **9** | +| `DINTLV_B16` | `RV_VLDI` | **9** | +| `DINTLV_B8` | `RV_VLDI` | **9** | +| `BRC_B32` | `RV_VLD` | **9** | +| `BRC_B8` | `RV_VLD` | **9** | +| `BRC_B16` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV_B32` | `RV_VSTI` | **12** | +| `INTLV_B16` | `RV_VSTI` | **12** | +| `INTLV_B8` | `RV_VSTI` | **12** | +| `UNPK_B8` | `RV_VLD` | **9** | +| `UNPK_B16` | `RV_VLD` | **9** | +| `UNPK_B32` | `RV_VLD` | **9** | +| `NORM_B32` | `RV_VSTI` | **9** | +| `NORM_B16` | `RV_VSTI` | **9** | +| `NORM_B8` | `RV_VSTI` | **9** | +| `PK_B32` | `RV_VSTI` | **9** | +| `PK_B16` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK_B8`, `UNPK_B16`, `UNPK_B32` | **9** cycles | +| `DINTLV_B32` | **9** cycles (`RV_VLDI`) | +| `DINTLV_B16`, `DINTLV_B8` | **9** cycles (same `RV_VLDI` + `dist:DINTLV_*` path as `DINTLV_B32`) | +| `BRC_B32` | **9** cycles | +| `BRC_B8`, `BRC_B16` | **9** cycles (`RV_VLD`) | +| `BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US_*`, `DS_*`, `SPLT*` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM_B8`, `NORM_B16`, `NORM_B32` | **9** cycles (`RV_VSTI`) | +| `PK_B16`, `PK_B32` | **9** cycles | +| `INTLV_B32` (`pto.vstx2`) | **12** cycles | +| `INTLV_B16`, `INTLV_B8` | **12** cycles (same interleave store path as `INTLV_B32`) | +| `MRG4CHN_B8`, `MRG2CHN_*` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. + +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM` | Contiguous 256B load | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC_B32` | Broadcast single element | `dst[i] = UB[base]` for all i | **9** cycles | +| `BRC_B8`, `BRC_B16` | Broadcast first lane element | Same idea at B8/B16 width | **9** cycles | +| `US_B8/B16` | Upsample (duplicate each element) | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS_B8/B16` | Downsample (every 2nd element) | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK_B8/B16/B32` | Unpack (zero-extend to wider type) | `dst_i32[i] = (uint32_t)UB_i16[base + 2*i]` | **9** cycles | +| `SPLT4CHN_B8` | Split 4-channel (RGBA → R plane) | Extract every 4th byte | **9** cycles | +| `SPLT2CHN_B8/B16` | Split 2-channel | Extract every 2nd element | **9** cycles | +| `DINTLV_B32` | Deinterleave 32-bit | Even elements only | **9** cycles | +| `DINTLV_B16`, `DINTLV_B8` | Deinterleave 16-bit / 8-bit | Pair lanes from interleaved UB | **9** cycles | +| `BDINTLV` | Block deinterleave | (see PTO headers for exact tiling) | **9** cycles | +| `BLK` | Block load | Blocked / tiled access pattern (see PTO headers) | **9** cycles (`dist:BRC_BLK` on `RV_VLD`) | + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out, %base_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align, !pto.ptr` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value, `%align_out` is the updated alignment + state, and `%base_out` is the post-update base pointer state exposed in SSA + form. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. Both the alignment state and the base address + advance across the stream, and the PTO micro Instruction representation exposes those updates as SSA results. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2, %ub2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldx2` + +- **syntax:** `%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. +- **Latency:** **`DINTLV_B32` → 9** cycles on `RV_VLDI`. **`DINTLV_B16` / `DINTLV_B8` → 9** cycles on `RV_VLDI`. **`BDINTLV` → 9** cycles on `RV_VLDI`. + +**Distribution modes:** `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` + +```c +// DINTLV_B32: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +--- + +#### Strided Loads + +##### `pto.vsld` + +- **syntax:** `%result = pto.vsld %source[%offset], "STRIDE" : !pto.ptr -> !pto.vreg` +- **semantics:** Strided load with fixed stride pattern. +- **inputs:** + `%source` is the UB base pointer and `%offset` is the displacement encoded + with the selected fixed stride mode. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + This is a deprecated compatibility family. The selected stride token + determines which sub-elements are read from each source block. +- **Latency:** **9** cycles. + +**Stride modes:** `STRIDE_S3_B16`, `STRIDE_S4_B64`, `STRIDE_S8_B32`, `STRIDE_S2_B64` + +--- + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %offset, %mask : !pto.ptr, i32, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer, `%offset` is the packed stride/control word, + and `%mask` controls which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + `%offset` is not a plain byte displacement; it encodes the block stride and + repeat pattern. If a block is masked off, the corresponding destination block + is zeroed and MUST NOT raise an address overflow exception for that block. +- **Latency:** **9** cycles. + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Byte-granularity indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains per-block byte offsets, + and `%active_lanes` bounds the number of active gathered blocks. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a block gather, not a byte-per-lane gather. `%source` MUST be 32-byte + aligned, each participating offset MUST describe a 32-byte-aligned block, and + inactive blocks are zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i]]; // byte-addressed +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. Narrowing/packing modes may only preserve a subset of the + source bits. Merge-channel modes reinterpret the source vector as channel + planes and interleave them on store. + +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM_B8/B16/B32` | Contiguous store | `UB[base + i] = src[i]` | **9** cycles | +| `PK_B16/B32` | Pack/narrowing store | `UB_i16[base + 2*i] = truncate_16(src_i32[i])` | **9** cycles | +| `MRG4CHN_B8` | Merge 4 channels (R,G,B,A → RGBA) | Interleave 4 planes | **9** cycles | +| `MRG2CHN_B8/B16` | Merge 2 channels | Interleave 2 planes | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstx2` + +- **syntax:** `pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. +- **Latency:** **`INTLV_B32` / `INTLV_B16` / `INTLV_B8` → 12** cycles on `RV_VSTI`. + +**Distribution modes:** `INTLV_B8`, `INTLV_B16`, `INTLV_B32` + +```c +// INTLV_B32: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +--- + +#### Strided Stores + +##### `pto.vsst` + +- **syntax:** `pto.vsst %value, %dest[%offset], "STRIDE" : !pto.vreg, !pto.ptr` +- **semantics:** Strided store with fixed stride pattern. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, and `%offset` + / `STRIDE` select the fixed strided layout. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + This is a deprecated compatibility family. The stride token, not the vector + lane number alone, determines which destination elements are written. +- **Latency:** **9** cycles. + +--- + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %offset, %mask : !pto.vreg, !pto.ptr, i32, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the packed stride/control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + `%offset` is a control word, not a plain byte displacement. This is a + deprecated compatibility family kept for surface coverage. +- **Latency:** **9** cycles. + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vsta` + +- **syntax:** `pto.vsta %value, %dest[%offset] : !pto.align, !pto.ptr, index` +- **semantics:** Flush alignment state to memory. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base pointer, + and `%offset` is the flush displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The flush address MUST match the post-updated address expected by the + preceding unaligned-store stream. After the flush, the corresponding store + alignment state is consumed. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family uses the same buffered-tail semantics as `pto.vsta` but keeps the + scalar-offset form explicit. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstu` +- **syntax:** `%align_out, %base_out = pto.vstu %align_in, %base_in, %value, %dest, %mode : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, index -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with explicit threaded alignment/base state. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%base_in` is the current + stream base, `%value` is the vector to store, `%dest` is the UB base pointer, + and `%mode` selects the post-update behavior. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the + post-update base pointer state. +- **constraints and limitations:** + This op models a stateful unaligned-store sequence in SSA form. A final + `pto.vsta` / `pto.vstas` / `pto.vstar` is still required to flush the trailing + buffered bytes. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstus` +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %base_in, %value, %dest, %offset : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, i32 -> !pto.align, !pto.ptr` +- **semantics:** Scalar-offset unaligned store with threaded state. +- **inputs:** + Same roles as `pto.vstu`, but `%offset` is provided explicitly as the scalar + displacement. +- **outputs:** + Updated alignment state and base state. +- **constraints and limitations:** + The same final flush requirement and state-threading constraints as + `pto.vstu` apply. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` +- **syntax:** `%align_out = pto.vstur %align_in, %value, %dest : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Register-update unaligned store form. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%dest` is the UB base pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This op updates only the residual alignment state. A matching flush op is + still required to emit the trailing bytes. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstu` + +- **syntax:** `%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, "MODE" : !pto.align, index, !pto.vreg, !pto.ptr -> !pto.align, index` +- **semantics:** Unaligned store with align + offset state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset_in` is the current + logical byte/element displacement, `%value` is the vector being stored, and + `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated alignment/tail state and `%offset_out` is the + next offset after applying the selected post-update rule. +- **constraints and limitations:** + The alignment state MUST be threaded in program order. A terminating flush + form such as `pto.vstar`/`pto.vstas` is still required to commit the buffered + tail bytes. +- **Latency:** **9** cycles. + +**Mode tokens:** `POST_UPDATE`, `NO_POST_UPDATE` + +--- + +##### `pto.vstus` + +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %offset, %value, %base, "MODE" : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with scalar offset and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the next + base pointer when the lowering chooses a post-update form. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width and update mode MUST match the selected form, and a later + flush op is still required. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + This form exposes only the evolving state; it does not by itself guarantee + that all buffered bytes have reached memory. A compatible final flush is still + required unless the surrounding sequence is known to be complete. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is paired with `f32` +vector compares or selects. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.mask` +- **semantics:** Load predicate register with scalar offset. + +**Distribution modes:** `NORM`, `US`, `DS` + +**Example:** +```mlir +%mask = pto.plds %ub[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask +``` + +--- + +##### `pto.pld` + +- **syntax:** `%result = pto.pld %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with areg offset. + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source, %offset, "DIST" : !pto.ptr, i32 -> !pto.mask` +- **semantics:** Load predicate register with immediate offset. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset] : !pto.mask, !pto.ptr` +- **semantics:** Store predicate register with scalar offset. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0] : !pto.mask, !pto.ptr +``` + +--- + +##### `pto.pst` + +- **syntax:** `pto.pst %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with areg offset. + +**Distribution modes:** `NORM`, `PK` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest, %offset, "DIST" : !pto.mask, !pto.ptr, i32` +- **semantics:** Store predicate register with immediate offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align state update. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0] : !pto.mask, !pto.ptr + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input {position = "POSITION"} : T|!pto.vreg -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source element or scalar position is duplicated. The + current PTO micro Instruction representation models that selector as an attribute rather than a + separate operand. + +```c +for (int i = 0; i < N; i++) + dst[i] = input_scalar_or_element; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate predicate from pattern. + +**Patterns:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate tail mask — first N lanes active. + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate predicate state together with updated scalar state. + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +**Part tokens:** `LOWER`, `HIGHER` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] & src1[i]; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] | src1[i]; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] ^ src1[i]; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = ~src[i]; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +#### Predicate Movement + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src[i]; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +##### `pto.pdintlv_b8` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate deinterleave. + +--- + +##### `pto.pintlv_b16` + +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate interleave. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrsqrt` | `RV_VSQRT` / `RV_VDIV` | **17** / **17** | **22** / **22** | — | +| `pto.vrec` | `RV_VDIV` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. Integer + overflow on the most-negative signed value follows the target-defined + behavior. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vrsqrt` + +- **syntax:** `%result = pto.vrsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds reciprocal-square-root values per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +##### `pto.vrec` + +- **syntax:** `%result = pto.vrec %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the reciprocal per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vbcnt` + +- **syntax:** `%result = pto.vbcnt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = __builtin_popcount(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the population count for each active lane. +- **constraints and limitations:** Integer element types only. The count is + over the source element width, not over the full vector register. + +--- + +##### `pto.vcls` + +- **syntax:** `%result = pto.vcls %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = count_leading_sign_bits(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the leading-sign-bit count per active lane. +- **constraints and limitations:** Integer element types only. This operation is + sign-aware, so signed interpretation matters. + +--- + +#### Movement + +##### `pto.vmov` + +- **syntax:** `%result = pto.vmov %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Vector register copy. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is a copy of the source vector. +- **constraints and limitations:** Predicated `pto.vmov` behaves like a masked + copy, while the unpredicated form behaves like a full-register copy. + +--- + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Reciprocal for division +%sum_rcp = pto.vrec %sum, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/u8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/u8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, it SHOULD be treated as an unsigned integer + operation. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + borrow[i] = (src0[i] < src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%borrow` marks lanes + that borrowed. +- **constraints and limitations:** This operation SHOULD be treated as an + unsigned 32-bit carry-chain family unless and until the verifier states + otherwise. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each active lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Inactive lanes follow the predication + behavior defined for this family. On the current surface, inactive lanes are + treated as zeroing lanes. + +--- + +##### `pto.vsubs` + +- **syntax:** `%result = pto.vsubs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] - scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Integer or floating-point legality depends on + the selected type family in lowering. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common numeric cases. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vands` + +- **syntax:** `%result = pto.vands %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] & scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vors` + +- **syntax:** `%result = pto.vors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] | scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxors` + +- **syntax:** `%result = pto.vxors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] ^ scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **constraints and limitations:** This is the scalar-extended carry-chain + family. Treat it as an unsigned integer operation unless the verifier states a + wider legal domain. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow-in and borrow-out. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - borrow_in[i]; + borrow_out[i] = (src0[i] < src1[i] + borrow_in[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%borrow_in` is the + incoming borrow predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%borrow` is the + borrow-out predicate. +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and SHOULD be treated as an unsigned integer operation. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%result` is the destination vector register value. +- `round_mode`, `sat`, and `part` control rounding, saturation, and lane-part + selection in attribute form. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input {round_mode = "ROUND_MODE", sat = "SAT_MODE", part = "PART_MODE"} : !pto.vreg -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + dst[i] = convert(src[i], T0, T1, round_mode); +``` + +- **inputs:** + `%input` is the source vector; attributes select rounding, saturation, and + even/odd placement when the conversion changes width. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. `PART_EVEN` / + `PART_ODD` is only meaningful for width-changing forms that pack two source + streams into one destination register. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `ROUND_R` | Round to nearest, ties to even (default) | +| `ROUND_A` | Round away from zero | +| `ROUND_F` | Round toward negative infinity (floor) | +| `ROUND_C` | Round toward positive infinity (ceil) | +| `ROUND_Z` | Round toward zero (truncate) | +| `ROUND_O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `RS_ENABLE` | Saturate on overflow | +| `RS_DISABLE` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes (for width-changing conversions) + +| Mode | Description | +|------|-------------| +| `PART_EVEN` | Output to even-indexed lanes | +| `PART_ODD` | Output to odd-indexed lanes | + +--- + +##### A5 Supported Conversions + +**Float-Float (vcvtff):** +- f32 ↔ f16 +- f32 ↔ bf16 +- f16 ↔ bf16 + +**Float-Int (vcvtfi):** +- f16 → i16, f16 → i32 +- f32 → i16, f32 → i32 +- bf16 → i32 + +**Int-Float (vcvtif):** +- i16 → f16 +- i32 → f32 + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_ODD"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, "ROUND_MODE" : !pto.vreg -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], round_mode); +``` + +- **inputs:** + `%input` is the floating-point source vector and `ROUND_MODE` selects the + truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `ROUND_O` is supported for avoiding + double-rounding errors during staged conversions. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, "ROUND_R" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled {round_mode = "ROUND_R", sat = "RS_ENABLE"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input {round_mode = "ROUND_R"} + : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, "ROUND_F" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. Result value + index in lane 0. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst_val[0] = mx; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** This family computes both the extremum and + location information, but the exact packing of that information into the + destination vector depends on the chosen form. If all predicate bits are zero, + the result follows the zero-filled convention. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. Result value + index in lane 0. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst_val[0] = mn; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** As with `pto.vcmax`, the exact value/index + packing depends on the chosen form and MUST be preserved consistently. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; // reversed from vsel +``` + +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This family preserves reversed-select + semantics. If the concrete lowering uses an implicit predicate source, that + predicate source MUST be documented by the surrounding IR pattern. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Slide / Shift + +##### `pto.vslide` + +- **syntax:** `%result = pto.vslide %src0, %src1, %amt : !pto.vreg, !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Concatenate two vectors and extract N-element window at offset. + +```c +// Conceptually: tmp[0..2N-1] = {src1, src0} +// dst[i] = tmp[amt + i] +if (amt >= 0) + for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src0[i - amt] : src1[N - amt + i]; +``` + +**Use case:** Sliding window operations, shift register patterns. + +- **inputs:** `%src0` and `%src1` provide the concatenated source window and + `%amt` selects the extraction offset. +- **outputs:** `%result` is the extracted destination window. +- **constraints and limitations:** `pto.vslide` operates on the logical + concatenation of `%src1` and `%src0`. The source order and extraction offset + MUST be preserved exactly. + +--- + +##### `pto.vshift` + +- **syntax:** `%result = pto.vshift %src, %amt : !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Single-source slide (shift with zero fill). + +```c +for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src[i - amt] : 0; +``` + +- **inputs:** `%src` is the source vector and `%amt` is the slide amount. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** This surface represents the single-source + slide/shift family. Zero-fill versus other fill behavior MUST match the + selected form. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %mask : !pto.mask -> !pto.vreg` +- **semantics:** Expand — scatter front elements to active positions. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src_front[j++]; + else dst[i] = 0; +``` + +- **inputs:** `%mask` is the expansion/placement predicate. +- **outputs:** `%result` is the expanded vector image. +- **constraints and limitations:** The source-front stream is implicit in the + current surface. Lane placement for active and inactive positions MUST be + preserved exactly. + +--- + +#### Permutation + +##### `pto.vperm` + +- **syntax:** `%result = pto.vperm %src, %index : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** In-register permute (table lookup). **Not** memory gather. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[index[i] % N]; +``` + +**Note:** This operates on register contents, unlike `pto.vgather2` which reads from UB memory. + +- **inputs:** `%src` is the source vector and `%index` supplies per-lane source + indices. +- **outputs:** `%result` is the permuted vector. +- **constraints and limitations:** This is an in-register permutation family. + `%index` values outside the legal range follow the wrap/clamp behavior of the + selected form. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Register select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; +``` + +- **inputs:** `%src0` and `%src1` are source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src0, %src1, %part : !pto.vreg, !pto.vreg, index -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrowing pack — two wide vectors to one narrow vector. + +```c +// e.g., two vreg<64xi32> → one vreg<128xi16> +for (int i = 0; i < N; i++) { + dst[i] = truncate(src0[i]); + dst[N + i] = truncate(src1[i]); +} +``` + +- **inputs:** `%src0` and `%src1` are wide source vectors and `%part` selects + the packing submode. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion. Source + values that do not fit the destination width follow the truncation semantics + of the selected packing mode. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Sliding window sum +%prev_window = pto.vslide %curr, %prev, %c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, i16 -> !pto.vreg<64xf32> +%window_sum = pto.vadd %curr, %prev_window, %all + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide0_i32, %wide1_i32, %c0 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, index -> !pto.vreg<128xi16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. +- **outputs:** `%result` is the fused `exp(input - max)` vector. +- **constraints and limitations:** Floating-point element types only. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaddrelu` + +- **syntax:** `%result = pto.vaddrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused add + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] + src1[i], 0); +``` + +- **inputs:** `%lhs` and `%rhs` are the two addends. +- **outputs:** `%result` is the fused add-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vsubrelu` + +- **syntax:** `%result = pto.vsubrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused sub + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] - src1[i], 0); +``` + +- **inputs:** `%lhs` is the minuend and `%rhs` is the subtrahend. +- **outputs:** `%result` is the fused sub-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaddreluconv` + +- **syntax:** `%result = pto.vaddreluconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused add + ReLU + type conversion (HW fusion). + +```c +// f32→f16 variant: +for (int i = 0; i < 64; i++) + dst_f16[i] = f32_to_f16(max(src0_f32[i] + src1_f32[i], 0)); + +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(max(src0_f16[i] + src1_f16[i], 0)); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused add/ReLU/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. Rounding, saturation, and packing rules follow the + semantics of this fused operation, not an arbitrary sequence of standalone + ops. + +--- + +##### `pto.vmulconv` + +- **syntax:** `%result = pto.vmulconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused mul + type conversion (HW fusion). + +```c +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(src0_f16[i] * src1_f16[i]); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused mul/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/u32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### UB-to-UB Operations + +##### `pto.vtranspose` + +- **syntax:** `pto.vtranspose %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** UB-to-UB transpose operation (not vreg-to-vreg). + +**Note:** This operates on UB memory directly, not on vector registers. + +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is not a `vreg -> vreg` op even though + it lives in the `pto.v*` namespace. Its correctness depends on the control + word and UB layout contract. + +--- + +#### Sorting Operations + +##### `pto.vsort32` + +- **syntax:** `pto.vsort32 %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** Sort 32 elements in UB. +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is a UB-to-UB accelerator helper, not a + pure vector-register op. + +--- + +##### `pto.vmrgsort` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr x4, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. This page uses the shorter mnemonic + `pto.vmrgsort`, while the current implementation summary still refers to + `pto.vmrgsort4`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Fused residual add + ReLU +%residual = pto.vaddrelu %conv_out, %skip_connection : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `u8` | 8 | 256 | Signed/unsigned 8-bit integer | +| `i16` / `u16` | 16 | 128 | Signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `u32` | 32 | 64 | Signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `u64` | 64 | 32 | Signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Fused add + ReLU +%fused = pto.vaddrelu %a, %b : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC_*` dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_*` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/docs/sample.pto b/docs/sample.pto new file mode 100644 index 000000000..956b7ba4c --- /dev/null +++ b/docs/sample.pto @@ -0,0 +1,57 @@ +module attributes {pto.target_arch = "a5"} { + func.func @abs_kernel_2d(%arg0: memref, %arg1: memref) { + %c4096_i64 = arith.constant 4096 : i64 + %c0_i64 = arith.constant 0 : i64 + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [%c32, %c32], strides: [%c32, %c1] {layout = #pto.layout} : memref to memref> + %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [%c32, %c32], strides: [%c32, %c1] {layout = #pto.layout} : memref to memref> + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %0 = builtin.unrealized_conversion_cast %memspacecast : memref to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %1 = llvm.extractvalue %0[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %2 = llvm.inttoptr %c0_i64 : i64 to !llvm.ptr<6> + %3 = arith.index_castui %c32 : index to i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c4_i64 = arith.constant 4 : i64 + %4 = arith.muli %3, %c32_i64 : i64 + %5 = arith.muli %c1_i64, %4 : i64 + %6 = arith.muli %5, %c4_i64 : i64 + %7 = arith.muli %4, %c4_i64 : i64 + %8 = arith.muli %3, %c4_i64 : i64 + %c128_i64 = arith.constant 128 : i64 + %9 = llvm.getelementptr %1[%c0_i64] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i8 + a5vm.set_loop2_stride_outtoub %6, %c4096_i64 : i64, i64 + a5vm.set_loop1_stride_outtoub %7, %c4096_i64 : i64, i64 + a5vm.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + a5vm.copy_gm_to_ubuf %9, %2, %3, %3, %c0_i64, %3, %8, %c0_i64, %c0_i64, %c0_i64, %c128_i64, %c128_i64 {a5vm.element_type = "u32", data_select_bit = false, layout = "nd", ub_pad = false} : !llvm.ptr<1>, !llvm.ptr<6>, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64 + a5vm.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + a5vm.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + %10 = llvm.inttoptr %c4096_i64 : i64 to !llvm.ptr<6> + %c0 = arith.constant 0 : index + %11 = arith.muli %c32, %c32 : index + %c64 = arith.constant 64 : index + %12 = arith.index_castui %11 : index to i32 + pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = a5vm.plt_b32 %arg4 : i32 -> !a5vm.mask, i32 + %17 = a5vm.vlds %2[%arg3] : !llvm.ptr<6> -> !a5vm.vreg<64xf32> + %18 = a5vm.vabs %17, %mask {mode = "MODE_ZEROING"} : !a5vm.vreg<64xf32>, !a5vm.mask -> !a5vm.vreg<64xf32> + a5vm.vsts %18, %10[%arg3], %mask : !a5vm.vreg<64xf32>, !llvm.ptr<6>, !a5vm.mask + scf.yield %scalar_out : i32 + } + } + a5vm.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + a5vm.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + %memspacecast_1 = memref.memory_space_cast %arg1 : memref to memref + %13 = builtin.unrealized_conversion_cast %memspacecast_1 : memref to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.extractvalue %13[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %15 = llvm.getelementptr %14[%c0_i64] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i8 + a5vm.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + a5vm.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + a5vm.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + a5vm.copy_ubuf_to_gm %10, %15, %3, %3, %c0_i64, %c32_i64, %8, %c0_i64, %c128_i64, %c128_i64 {a5vm.element_type = "u32", layout = "nd"} : !llvm.ptr<6>, !llvm.ptr<1>, i64, i64, i64, i64, i64, i64, i64, i64 + a5vm.pipe_barrier "PIPE_ALL" + return + } +} diff --git a/docs/tilelang-dsl-guide.md b/docs/tilelang-dsl-guide.md new file mode 100644 index 000000000..2dc224e50 --- /dev/null +++ b/docs/tilelang-dsl-guide.md @@ -0,0 +1,2921 @@ +# TileLang Python DSL Guide + +The TileLang Python DSL provides a high-level, Pythonic interface for authoring vector compute kernels targeting the Ascend NPU hardware. This guide is intended for library developers and performance engineers who need to write efficient, hardware-aware kernels using the PTO micro instruction set. + +The DSL is designed to generate MLIR function libraries rather than direct binary executables. These MLIR libraries are intended to be consumed by other compilation frameworks that transform high-level tile semantics into low-level vector operations. This enables library developers to focus on hardware-aware kernel authoring while relying on upstream compilers for tile-level optimizations and code generation. + +## Quick Start + +**Note on mask pattern enums**: For brevity, examples in this guide use `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). You can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +Here's a minimal example of a tile scaling kernel using the new Tile type: + +```python +import pto + +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def tile_scale(input_tensor: pto.TensorView, # Input tensor view (shape: 256x128, f32, GM) + output_tensor: pto.TensorView, # Output tensor view (same shape and type) + scale_factor: pto.f32): # Scaling factor + # Access tensor properties + rows, cols = input_tensor.shape # (256, 128) + dtype = input_tensor.element_type # pto.f32 + + # Create a temporary tile in UB for computation + ub_tile = pto.tile((rows, cols), dtype, pto.MemorySpace.UB) + + # Load input tensor from GM to UB using high-level DMA operation + pto.dma_load(input_tensor, ub_tile) + + # Vector computation: scale all elements in the tile + all_mask = pto.make_mask(dtype, PAT.ALL) + + # Process tile in row-major order + for row in range(0, rows): + # Process each row in vector chunks + # Vector width is hardware-defined: 256 bytes / element size + # For f32: 256/4 = 64 lanes, for f16: 256/2 = 128 lanes + vector_lanes = pto.get_lanes(dtype) # Compute vector lanes based on element type (e.g., 64 for f32, 128 for f16) + for col_start in range(0, cols, vector_lanes): + # Load vector using element-indexing syntax (no manual byte calculation) + vec = pto.vlds(ub_tile[row, col_start:]) + + # Scale vector + scaled = pto.vmuls(vec, scale_factor, all_mask) + + # Store result back using element-indexing syntax + pto.vsts(scaled, ub_tile[row, col_start:], all_mask) + + # Store result from UB back to GM output tensor using high-level DMA operation + pto.dma_store(ub_tile, output_tensor) +``` + +This example demonstrates: +1. **TensorView parameters** in kernel declaration +2. **TensorView property access** (shape, element_type) +3. **Tile creation** for temporary buffers +4. **High-level DMA operations** (`dma_load`/`dma_store`) for data movement +5. **Implicit tile→UBRef conversion** in vector load/store operations +6. **Automatic DMA parameter inference** from tensor slices and tile properties + +For an even more concise example showing pure computation on UB tiles (assuming data is already in UB): + +```python +@pto.vkernel(target="a5", op="elementwise", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) +def ub_tile_computation(a: pto.Tile, # UB tile + b: pto.Tile, # UB tile + c: pto.Tile): # UB tile (output) + dtype = a.element_type + + # All tiles are in UB memory space + all_mask = pto.make_mask(dtype, PAT.ALL) + rows, cols = a.shape + + # Element-wise: c = a + b * 2.0 + for i in range(0, rows * cols, 64): + # Load vectors from UB tiles using element-indexing syntax + vec_a = pto.vlds(a[i:]) # Implicit tile→UBRef with automatic offset calculation + vec_b = pto.vlds(b[i:]) + + # Compute: b * 2.0 + scaled_b = pto.vmuls(vec_b, 2.0, all_mask) + + # Compute: a + scaled_b + result = pto.vadd(vec_a, scaled_b, all_mask) + + # Store result to output tile using element-indexing syntax + pto.vsts(result, c[i:], all_mask) +``` + +## Core Concepts + +### Kernel Declaration + +Kernels are defined using the `@pto.vkernel` decorator with enhanced matching capabilities for PTO operations. The decorator specifies matching criteria for target architecture, operation type, data types, and additional constraints, along with a priority for disambiguation when multiple kernels match. + +#### Basic Syntax + +```python +@pto.vkernel( + target="a5", # Target architecture + op="matmul", # PTO operation name to match + dtypes=[(pto.f16, pto.f16, pto.f32)], # Type signatures + constraints=[ # Additional constraints + AnyOf(k_dim_aligned_64, continuous_memory), + Not(requires_ub_memory) + ], + priority=100 # Priority for selection +) +def matmul_fallback(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # kernel implementation +``` + +#### Decorator Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | +| `op` | `str` | Yes | Name of the PTO operation to match (e.g., `"matmul"`, `"conv2d"`, `"add"`). | +| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands (inputs and outputs) in order. | +| `constraints` | `List[Constraint]` | No | Additional constraints that must be satisfied for the kernel to be selected. Can include logical combinations (`AnyOf`, `AllOf`, `Not`). Default: empty list. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Higher values have higher priority. Default: `0`. | +| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | + +#### Type Matching Rules + +The `dtypes` parameter supports flexible type matching: + +1. **Concrete Types**: Exact type matches using DSL scalar types: + - `pto.f16`, `pto.f32`, `pto.bf16` + - `pto.i8`, `pto.i16`, `pto.i32`, `pto.i64` + - `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` + +2. **Type Wildcards**: Generic type patterns: + - `pto.AnyFloat`: Matches any floating-point type (`f16`, `bf16`, `f32`) + - `pto.AnyInt`: Matches any integer type (`i8`, `i16`, `i32`, `i64`) + - `pto.AnyType`: Matches any scalar type + - `pto.AnyMask`: Matches any mask type (`mask_b8`, `mask_b16`, `mask_b32`) + +3. **Type Variables**: Named type variables that enforce consistency within a signature: + ```python + T = pto.TypeVar('T') # Define a type variable + + @pto.vkernel( + target="a5", + op="elementwise", + dtypes=[(T, T, T)], # All three operands must have the same type + constraints=[] + ) + def elementwise_same_type(x: pto.Tile, y: pto.Tile, out: pto.Tile) -> None: + # x, y, and out must have identical element types + pass + ``` + +4. **Mixed Signatures**: Multiple type signatures for the same operation: + ```python + @pto.vkernel( + target="a5", + op="add", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), # Float addition + (pto.AnyInt, pto.AnyInt, pto.AnyInt) # Integer addition + ] + ) + def generic_add(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Supports both float and integer types + pass + ``` + +#### Constraint System + +Constraints are compile-time predicates that refine kernel selection. The system supports logical combinations of constraints. + +##### Predefined Constraints + +| Constraint | Description | +|------------|-------------| +| `k_dim_aligned_64` | K dimension is aligned to 64 elements (for matmul kernels). | +| `continuous_memory` | Operands reside in contiguous memory regions. | +| `requires_ub_memory` | Operation requires Unified Buffer memory (vs. Global Memory). | +| `tensor_rank(rank)` | Operand tensor has specified rank (e.g., `tensor_rank(2)` for 2D tensors). | +| `broadcastable` | Operands are broadcastable according to NumPy-style broadcasting rules. | +| `static_shape` | All tensor dimensions are known at compile time (no dynamic shapes). | + +##### Logical Constraint Combinators + +| Combinator | Description | Example | +|------------|-------------|---------| +| `AnyOf(c1, c2, ...)` | At least one of the constraints must be satisfied. | `AnyOf(k_dim_aligned_64, continuous_memory)` | +| `AllOf(c1, c2, ...)` | All constraints must be satisfied. | `AllOf(tensor_rank(2), static_shape)` | +| `Not(c)` | The constraint must not be satisfied. | `Not(requires_ub_memory)` | + +##### Custom Constraints + +Users can define custom constraints using predicate functions: + +```python +# Define a custom constraint +def large_batch(batch_size: pto.i32) -> pto.Constraint: + """Batch size must be ≥ 1024.""" + return pto.Constraint(lambda op: op.batch_size >= batch_size) + +@pto.vkernel( + target="a5", + op="matmul", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], + constraints=[large_batch(1024)] +) +def large_batch_matmul(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized for large batch sizes + pass +``` + +#### Kernel Selection Mechanism + +When a PTO operation needs implementation, the system performs the following matching process: + +1. **Target Filtering**: Select kernels with matching `target` architecture. +2. **Operation Filtering**: Select kernels with matching `op` name. +3. **Type Matching**: For each kernel's `dtypes` list, check if any signature matches the operation's operand types: + - Concrete types must match exactly. + - Wildcard types match according to their category. + - Type variables must be consistent within the signature. +4. **Constraint Validation**: For each matching kernel, evaluate all `constraints`. If any constraint fails, the kernel is rejected. +5. **Priority Selection**: From the remaining kernels, select the one with the highest `priority` value. +6. **Fallback**: If no kernel matches, compilation fails with an error. + +#### Examples + +##### Matmul with Multiple Implementations + +```python +# High-performance kernel for aligned K dimension +@pto.vkernel( + target="a5", + op="matmul", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[k_dim_aligned_64], + priority=200 +) +def matmul_aligned_k(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized implementation for aligned K + pass + +# General-purpose fallback +@pto.vkernel( + target="a5", + op="matmul", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], + constraints=[], + priority=100 +) +def matmul_general(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Generic implementation + pass +``` + +##### Elementwise Operation with Type Polymorphism + +```python +@pto.vkernel( + target="a5", + op="add", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), + (pto.AnyInt, pto.AnyInt, pto.AnyInt) + ], + constraints=[broadcastable] +) +def polymorphic_add(a: pto.Tile, b: pto.Tile, out: pto.Tile) -> None: + # Single implementation handles both float and integer types + dtype = a.element_type + all_mask = pto.make_mask(dtype, PAT.ALL) + # ... implementation using generic vector operations + pass +``` + +##### Constrained Convolution Kernel + +```python +@pto.vkernel( + target="a5", + op="conv2d", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[ + AllOf( + tensor_rank(4), # NHWC format + static_shape, # No dynamic dimensions + Not(requires_ub_memory) # GM memory preferred + ) + ], + priority=150 +) +def conv2d_nhwc_f16_f32(input: pto.Tile, filter: pto.Tile, output: pto.Tile) -> None: + # Optimized for NHWC layout with static shapes + pass +``` + +### Value Model + +The DSL operates on symbolic values, not Python runtime values: +- **Constants**: Python literals that are typed to machine types +- **Operation results**: Values produced by DSL operations +- **Block arguments**: Values introduced by control flow structures + +### Memory Spaces + +The DSL supports different memory spaces: +- `MemorySpace.GM`: Global Memory +- `MemorySpace.UB`: Unified Buffer (local storage for vector computation) + +## Type System + +### Scalar Types + +| DSL Type | Description | Bit Width | +|----------|-------------|-----------| +| `pto.i1` | Boolean | 1 | +| `pto.i8` | 8-bit integer | 8 | +| `pto.i16` | 16-bit integer | 16 | +| `pto.i32` | 32-bit integer | 32 | +| `pto.i64` | 64-bit integer | 64 | +| `pto.f16` | Half precision float | 16 | +| `pto.bf16` | Brain float 16 | 16 | +| `pto.f32` | Single precision float | 32 | + +Python literals are automatically typed: +- `bool` → `pto.i1` +- `int` → Context-dependent (typically `pto.i32` or `pto.i64`) +- `float` → `pto.f32` + +For explicit typing, use type constructors: +```python +x = pto.i32(1024) # Explicit i32 constant +y: pto.i32 = 1024 # Type annotation +``` + +### Vector Types + +Vector registers have fixed 256-byte width: + +```python +v64_f32 = pto.vreg(64, pto.f32) # 64 lanes of f32 (64 * 32b = 2048b) +v128_f16 = pto.vreg(128, pto.f16) # 128 lanes of f16 (128 * 16b = 2048b) +``` + +Constraint: `lanes × bitwidth(element_type) = 2048` + +### Typed Masks + +Masks are typed by their bit granularity: + +| DSL Type | VPTO Type | Description | +|----------|-----------|-------------| +| `pto.mask_b8` | `!pto.mask` | 8-bit granularity mask | +| `pto.mask_b16` | `!pto.mask` | 16-bit granularity mask | +| `pto.mask_b32` | `!pto.mask` | 32-bit granularity mask | + +Mask operations must match the vector element family: +- `f32` vectors use `mask_b32` +- `f16` vectors use `mask_b16` +- `i8` vectors use `mask_b8` + +```python +# Correct: f32 vector with b32 mask +mask32 = pto.make_mask(pto.f32, PAT.ALL) +vec_f32 = pto.vlds(ptr, offset) +out = pto.vabs(vec_f32, mask32) + +# Error: mismatched mask granularity +mask16 = pto.make_mask(pto.f16, PAT.ALL) +out = pto.vabs(vec_f32, mask16) # Type error! +``` + +### Pointer Types + +Pointers combine element type and memory space: + +```python +from pto import MemorySpace + +ptr_gm = pto.ptr(pto.f32, MemorySpace.GM) # GM pointer to f32 +ptr_ub = pto.ptr(pto.f16, MemorySpace.UB) # UB pointer to f16 +``` + +The `MemorySpace` enum provides type-safe memory space specification: + +| Enum Value | Description | +|------------|-------------| +| `MemorySpace.GM` | Global Memory (off-chip HBM/DDR) | +| `MemorySpace.UB` | Unified Buffer (on-chip SRAM, 256KB) | + +This replaces string literals (`MemorySpace.GM`/`MemorySpace.UB`) with compile-time checked enums. + +### Pointer Type Aliases + +For clarity in API documentation, the following type aliases are used: + +| Alias | Equivalent Type | Description | +|-------|----------------|-------------| +| `GMPtr` | `ptr(..., MemorySpace.GM)` | Pointer to Global Memory | +| `UBPtr` | `ptr(..., MemorySpace.UB)` | Pointer to Unified Buffer | +| `UBRef` | `Union[MemRefType, UBPtr]` | UB buffer or pointer (accepted by load/store ops) | +| `Tile` | `pto.tile(...)` | Tile buffer with layout and configuration | + +### MemRef Types + +For buffer-like authoring, use memref types: + +```python +buf1d = pto.memref(256, pto.f32, MemorySpace.UB) # 1D: 256-element f32 buffer in UB +buf2d = pto.memref((256, 128), pto.f32, MemorySpace.UB) # 2D: 256x128 f32 buffer in UB +``` + +- **1D shapes**: Use a scalar integer (e.g., `256`) +- **Multi-dimensional shapes**: Use a tuple (e.g., `(256, 128)`) + +MemRefs are used for stateless load/store operations that accept `buf_like` operands in VPTO. + + +### TensorView Types + +TensorView types represent views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. + +### TensorView Type Definition + +TensorView types are parameterized by shape and element type: + +```python +# Kernel parameter using TensorView +@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tensor: pto.TensorView, # GM tensor view + output_tensor: pto.TensorView, # GM tensor view + tile_buf: pto.Tile # UB tile +): + # Access tensor view properties + rows, cols = input_tensor.shape # (dynamic or static) + dtype = input_tensor.element_type # e.g., pto.f32 + strides = input_tensor.strides # stride in elements +``` + +**Important Notes:** +- TensorView is a **read-only descriptor** for GM data (though DMA store operations can write to it) +- Shape can be **static** (compile-time constants) or **dynamic** (determined at runtime) +- Strides are expressed in elements, not bytes +- Memory space is always GM (Global Memory) + +### TensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Tensor dimensions (2D only in current profile) | +| `element_type` | `Type` | Element data type (e.g., `pto.f32`, `pto.f16`) | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from base pointer (internal) | + +### Padding Mode Enum + +Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: + +| Enum Value | Description | +|------------|-------------| +| `PadMode.PadNull` | No padding (out-of-bounds access is invalid) | +| `PadMode.PadFirstElem` | Pad using the first element of the source | +| `PadMode.PadValue` | Pad using a specified value (requires `pad_value` parameter) | + +**Usage:** +```python +from pto import PadMode + +# Load with zero padding +pto.dma_load(src_partition, dst_tile, + pad_mode=PadMode.PadValue, + pad_value=pto.f32(0.0)) + +# Load with first-element padding +pto.dma_load(src_partition, dst_tile, pad_mode=PadMode.PadFirstElem) + +# Load without padding (default) +pto.dma_load(src_partition, dst_tile) # pad_mode=PadMode.PadNull +``` + +### Slicing Syntax + +TensorView supports Python slicing syntax to create logical partitions: + +```python +# Create a partition from a tensor view +partition = tensor_view[row_start:row_end, col_start:col_end] + +# Example: extract a 16x16 tile from a larger tensor +tile_view = large_tensor[0:16, 0:16] + +# Dynamic offsets and sizes +start_row = pto.i32(0) +start_col = pto.i32(0) +dynamic_partition = tensor_view[start_row:start_row+16, start_col:start_col+16] +``` + +**Constraints:** +- Slicing returns a new TensorView representing the logical partition +- The partition must be within the original tensor bounds +- Slices can be static (constant bounds) or dynamic (runtime values) + +### Alignment Type + +The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. + +### Tile Types + +Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. + +#### Tile Type Definition + +```python +# Create a tile with shape, element type, and memory space +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB) + +# With explicit configuration +config = pto.tile_config( + b_layout=pto.BLayout.ROW_MAJOR, + s_layout=pto.SLayout.NONE_BOX, + s_fractal_size=pto.i32(16), + pad_value=pto.PadValue.ZERO +) +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, config=config) + +# With valid shape (actual data dimensions within tile) +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, valid_shape=(240, 120)) +``` + +**Important Notes on Shape and Valid Shape:** +- **Static Shape Requirement**: The `shape` parameter must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. +- **Valid Shape Constraints**: The `valid_shape` parameter can be either static (compile-time constant) or dynamic (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. This allows for variable-sized data within a fixed tile allocation. +- **Default Behavior**: When `valid_shape` is not specified, it defaults to the full `shape`. + +#### Tile Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | **Static** full tile dimensions (compile-time constant) | +| `element_type` | `Type` | Element data type (e.g., `pto.f32`) | +| `memory_space` | `MemorySpace` | Memory space (GM, UB, etc.) | +| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within tile (can be static/compile-time or dynamic/runtime). Must be ≤ shape in each dimension. | +| `config` | `TileConfig` | Layout and padding configuration | + +#### Tile Configuration + +The tile configuration includes layout and padding information: + +```python +# Layout enums +pto.BLayout.ROW_MAJOR # 0: row-major base layout +pto.BLayout.COL_MAJOR # 1: column-major base layout + +pto.SLayout.NONE_BOX # 0: no secondary layout +pto.SLayout.ROW_MAJOR # 1: row-major secondary layout +pto.SLayout.COL_MAJOR # 2: column-major secondary layout + +pto.PadValue.NULL # 0: no padding +pto.PadValue.ZERO # 1: zero padding +pto.PadValue.MAX # 2: maximum value padding +pto.PadValue.MIN # 3: minimum value padding +``` + +#### Tile Shape Concepts + +- **Static Physical Shape**: The `shape` parameter represents the **static physical dimensions** of the tile allocated in memory. This must be a **compile-time constant** because tile memory allocation is fixed during compilation. The shape determines the total memory footprint and cannot change at runtime. + +- **Valid Shape**: The `valid_shape` parameter represents the logical dimensions of actual data within the tile. It can be either **static** (compile-time constant) or **dynamic** (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. When `valid_shape` is not specified, it defaults to the full `shape`. + +- **Key Distinction**: + - `shape`: **Static, compile-time** - Fixed tile allocation + - `valid_shape`: **Static or Dynamic** - Actual data region (must be ≤ shape) + +- **Constraints**: + - `valid_shape[i] ≤ shape[i]` for each dimension i + - `shape` must be compile-time constants + - `valid_shape` can be compile-time constants or runtime values + +- **Use Cases**: + - Fixed-size tile buffers with variable data (e.g., batch processing with different input sizes) + - Padding scenarios where physical allocation is larger than actual data + - Partial tile utilization in tiled algorithms + +- **Fractal Layout**: The `s_fractal_size` in tile configuration specifies the size of fractal blocks for secondary layout. This is used for optimized memory access patterns in matrix operations. + +- **Padding Behavior**: The `pad_value` determines how out-of-bounds accesses are handled when reading beyond `valid_shape` but within `shape`. Padding values are used for accesses in the padded region (between valid_shape and shape). + +> **⚠️ Important: Shape Constraints** +> +> The tile `shape` must be **compile-time constants**. `valid_shape` can be compile-time constants or determined at runtime, but must satisfy `valid_shape[i] ≤ shape[i]` for all dimensions i. + +### Tile Operations + +#### Basic Access Operations + +```python +# Get tile properties +shape = tile.shape # (256, 128) +elem_type = tile.element_type # pto.f32 +mem_space = tile.memory_space # MemorySpace.UB +valid_shape = tile.valid_shape # (240, 120) or same as shape + +# Get configuration properties +config = tile.config +b_layout = config.b_layout # pto.BLayout.ROW_MAJOR +s_layout = config.s_layout # pto.SLayout.NONE_BOX +s_fractal = config.s_fractal_size # pto.i32(16) +pad = config.pad_value # pto.PadValue.ZERO + +# Dynamic properties +rank = tile.rank # 2 +num_elements = tile.num_elements # 32768 (256 * 128) +valid_elements = tile.valid_elements # 28800 (240 * 120) +``` + +#### Layout and Stride Queries + +```python +# Get layout descriptors +layout_desc = tile.layout_descriptor # Returns layout description object + +# Get strides (in elements) +strides = tile.strides # (128, 1) for row-major 256x128 + +# Get byte strides +byte_strides = tile.byte_strides # (512, 4) for f32 row-major + +# Get base offset (in bytes) +offset = tile.offset # pto.i64(0) or specified offset +``` + +#### Conversion Operations + +Tiles support both explicit and implicit conversion to UBRef. When a tile is used in operations expecting a UBRef (e.g., `pto.vlds`, `pto.vsts`), it is automatically converted. + +```python +# Convert to UBRef (implicit in vector operations) +ub_ref = tile.to_ubref() # Explicit conversion +# or use tile as UBRef directly in vector ops +vec = pto.vlds(tile, offset) # Implicit conversion + +# Convert to typed pointer +ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) + +# Convert to MemRef (for compatibility) +memref = tile.to_memref() # Returns pto.memref((256, 128), pto.f32, MemorySpace.UB) + +# Extract slice of tile +slice_tile = tile.slice((0, 0), (64, 128)) # 64x128 slice from top-left corner + +# Reshape tile (logical reshape, no data movement) +reshaped = tile.reshape((32768,)) # 1D reshape of 256x128 tile +``` + +#### Kernel Parameter Usage + +```python +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tile: pto.Tile, # Tile parameter + output_tile: pto.Tile, # Another tile parameter + scale: pto.f32 +): + # Convert tiles to UBRef for vector operations + ub_in = input_tile.to_ubref() + ub_out = output_tile.to_ubref() + + # Or use tiles directly (implicit conversion) + all_mask = pto.make_mask(pto.f32, PAT.ALL) + for i in range(0, 256, 64): + # tile implicitly converts to UBRef in vlds with element-indexing syntax + vec = pto.vlds(input_tile[i, 0:]) # Load from row i, columns 0 to vector_lanes-1 + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, output_tile[i, 0:], all_mask) # Store to same position +``` + +#### Tile Creation from Existing Buffers + +```python +# Create tile from existing pointer with shape +ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +tile = pto.tile_from_ptr(ptr, (256, 128), pto.f32) + +# Create tile from memref +memref = pto.memref((256, 128), pto.f32, MemorySpace.UB) +tile = pto.tile_from_memref(memref) + +# Create tile with explicit stride +tile = pto.tile_with_strides((256, 128), pto.f32, MemorySpace.UB, + strides=(256, 1)) # Column-major strides +``` + +## Control Flow + +### Vector Scopes + +The TileLang DSL supports implicit vector scope inference, allowing developers to write vector operations directly without explicit `pto.vecscope()` blocks. The compiler automatically groups consecutive, data-dependent vector operations into implicit vector scopes during lowering. + +#### Implicit Scope Inference + +**Note:** The explicit `pto.vecscope()` construct is deprecated. Vector operations are automatically grouped into implicit scopes by the compiler's Scope Inference Pass. + +When you write vector operations like `pto.vlds`, `pto.vadd`, `pto.vsts` directly in your code, the compiler's **Scope Inference Pass** analyzes the control flow graph and automatically creates vector scopes: + +```python +# No explicit vecscope needed - compiler infers scope boundaries +vec = pto.vlds(outer_ptr, offset) +result = pto.vadd(vec, vec, all_mask) +pto.vsts(result, dst_ptr, offset, all_mask) +``` + +The compiler automatically groups these three operations into a single implicit vector scope because they form a data-dependent chain. + +**Scope boundary rules:** +1. **Control flow boundaries**: Branches (`if`/`else`), loops (`for`/`while`), and function calls create implicit scope boundaries +2. **Scalar operations**: Non-vector operations (e.g., scalar arithmetic, pointer arithmetic) create boundaries +3. **Explicit strict_vecscope**: User-defined `strict_vecscope` blocks create hard boundaries + +#### Explicit Scope Boundaries with `strict_vecscope` + +For precise control over scope boundaries, use explicit `strict_vecscope` blocks. These create hard boundaries that prevent the compiler from merging operations across the block boundary: + +```python +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + # Operations inside this block are isolated from outside + # Compiler will not merge operations across this boundary + for i in range(lb, ub, 64): + vec = pto.vlds(s, i) + pto.vsts(vec, d, i, all_mask) +``` + +**Use cases for strict_vecscope:** +- Performance optimization: Isolate critical vector computation regions +- Debugging: Create explicit boundaries to isolate vector operations +- Resource management: Control vector register allocation boundaries +- Compatibility: Ensure deterministic scope placement for hardware constraints + +### Loops + +Counted loops use Python's `range` syntax: + +```python +for i in range(lb, ub, step): + # Loop body + mask, rem = pto.make_mask(pto.f32, remaining) + # ... +``` + +Loop-carried state is automatically handled through variable updates within the loop. + +### Conditionals + +`if` statements support value merging: + +```python +flag: pto.i1 = some_condition +step: pto.i32 = 0 + +if flag: + step = pto.i32(64) +else: + step = pto.i32(128) + +# 'step' here is the merged result from both branches +``` + +Variables defined in only one branch are local to that branch. + +## Operations + +The DSL provides operations grouped by functionality. All operations use the `pto.` prefix. Operations are organized by functional families following the VPTO instruction set architecture. + +### Pointer Construction + +Operations for creating and manipulating typed pointers. + +#### `pto.castptr(offset: pto.i64, ptr_type: Type) -> PtrType` + +**Description**: Creates a pointer with the specified offset and type. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `offset` | `pto.i64` | Byte offset from base address | +| `ptr_type` | `Type` | Target pointer type (e.g., `pto.ptr(pto.f32, MemorySpace.GM)`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `ptr` | `PtrType` | Typed pointer value | + +**Example**: +```python +ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +``` + +#### `pto.addptr(ptr: PtrType, offset: pto.i64) -> PtrType` + +**Description**: Adds an offset to an existing pointer. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Source pointer | +| `offset` | `pto.i64` | Byte offset to add | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `new_ptr` | `PtrType` | Pointer with offset applied | + +**Example**: +```python +next_ptr = pto.addptr(ub_ptr, 4096) +``` + +### Synchronization & Buffer Control + +Operations for pipeline synchronization and buffer management. + +#### `pto.set_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Sets a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.wait_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Waits for a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.pipe_barrier(pipes: PIPE) -> None` + +**Description**: Executes a barrier across specified pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipes` | `PIPE` | Pipeline specification (e.g., `PIPE.ALL`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE + +pto.pipe_barrier(PIPE.ALL) +``` + +#### `pto.get_buf(op_type: SyncOpType, buf_id: pto.i32, mode: pto.i32 = 0) -> None` + +**Description**: Acquires a buffer for producer-consumer synchronization. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `op_type` | `SyncOpType` | Operation type (e.g., `SyncOpType.TLOAD`) | +| `buf_id` | `pto.i32` | Buffer identifier | +| `mode` | `pto.i32` | Acquisition mode (default: 0) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import SyncOpType + +# Acquire buffer for DMA load operation +pto.get_buf(SyncOpType.TLOAD, 0) +``` + +#### `pto.rls_buf(op_type: SyncOpType, buf_id: pto.i32, mode: pto.i32 = 0) -> None` + +**Description**: Releases a previously acquired buffer. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `op_type` | `SyncOpType` | Operation type (e.g., `SyncOpType.TLOAD`) | +| `buf_id` | `pto.i32` | Buffer identifier | +| `mode` | `pto.i32` | Release mode (default: 0) | + +**Returns**: None (side-effect操作) + +**Example**: +```python +from pto import SyncOpType + +# Release buffer for DMA load operation +pto.rls_buf(SyncOpType.TLOAD, 0) +``` + +### Low-level DMA Programming (Legacy) + +**Note**: These low-level DMA programming operations are automatically handled by `pto.dma_load` and `pto.dma_store` in most cases. They expose hardware DMA engine parameters directly and should only be used when the automatic inference provided by the high-level API is insufficient for specific optimization needs. + +This section contains both DMA configuration operations (setting loop strides and sizes) and DMA execution operations (copying data). Prefer the high-level `pto.dma_load` and `pto.dma_store` operations which automatically infer all parameters from TensorView slices and Tile properties. + +#### When to Use Low-level DMA Programming + +Consider using these low-level operations only in the following scenarios: + +1. **Performance micro-optimization**: When specific DMA parameter tuning is required for performance-critical code +2. **Non-standard access patterns**: When TensorView slicing syntax cannot express the desired memory access pattern +3. **Hardware-specific optimizations**: When targeting specific DMA engine characteristics not captured by the high-level API + +For 99% of use cases, `pto.dma_load` and `pto.dma_store` with TensorView slicing provide sufficient control and are much easier to use correctly. + +#### Manual Configuration Example + +```python +# Manual DMA configuration (discouraged for normal use) +pto.set_loop2_stride_outtoub(32, 128) # Outer loop strides +pto.set_loop1_stride_outtoub(1, 32) # Inner loop strides +pto.set_loop_size_outtoub(16, 16) # Transfer size +pto.copy_gm_to_ubuf(gm_ptr, ub_ptr, ...) + +# Equivalent using high-level API (recommended) +pto.dma_load(input_tensor[0:16, 0:16], ub_tile) +# All loop strides and sizes automatically inferred +``` + +#### `pto.set_loop2_stride_outtoub(stride0: pto.i64, stride1: pto.i64) -> None` + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `stride0` | `pto.i64` | First dimension stride | +| `stride1` | `pto.i64` | Second dimension stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_outtoub(stride0: pto.i64, stride1: pto.i64) -> None` + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `stride0` | `pto.i64` | First dimension stride | +| `stride1` | `pto.i64` | Second dimension stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_outtoub(size0: pto.i64, size1: pto.i64) -> None` + +**Description**: Configures DMA transfer size for GM → UB transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `size0` | `pto.i64` | First dimension size | +| `size1` | `pto.i64` | Second dimension size | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop_size_outtoub(1, 1) +``` + +#### `pto.set_loop2_stride_ubtoout(stride0: pto.i64, stride1: pto.i64) -> None` + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `stride0` | `pto.i64` | First dimension stride | +| `stride1` | `pto.i64` | Second dimension stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_ubtoout(stride0: pto.i64, stride1: pto.i64) -> None` + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `stride0` | `pto.i64` | First dimension stride | +| `stride1` | `pto.i64` | Second dimension stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_ubtoout(size0: pto.i64, size1: pto.i64) -> None` + +**Description**: Configures DMA transfer size for UB → GM transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `size0` | `pto.i64` | First dimension size | +| `size1` | `pto.i64` | Second dimension size | + +**Returns**: None (side-effect operation) + +#### DMA Execution Operations + +**Note**: These operations execute DMA transfers but require manual configuration of DMA parameters (loop strides, loop sizes) using the `set_loop*_stride_*` and `set_loop_size_*` operations described above. The high-level `pto.dma_load` and `pto.dma_store` operations automatically handle both configuration and execution. + +The following operations provide direct control over DMA transfers but require manual stride and size configuration. Prefer the high-level Tile Data Movement operations for most use cases. + +#### `pto.copy_gm_to_ubuf(src: GMPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, transpose: pto.i1, pad_left: pto.i64, pad_right: pto.i64, pad_value: pto.i64) -> None` + +**Description**: Copies data from Global Memory (GM) to Unified Buffer (UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `GMPtr` | Source GM pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `transpose` | `pto.i1` | Transpose flag | +| `pad_left` | `pto.i64` | Left padding size | +| `pad_right` | `pto.i64` | Right padding size | +| `pad_value` | `pto.i64` | Padding value | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.copy_gm_to_ubuf(gm_ptr, ub_ptr, 0, 32, 128, 0, 0, False, 0, 128, 128) +``` + +#### `pto.copy_ubuf_to_ubuf(src: UBPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` + +**Description**: Copies data within Unified Buffer (UB → UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | + +**Returns**: None (side-effect operation) + +#### `pto.copy_ubuf_to_gm(src: UBPtr, dst: GMPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` + +**Description**: Copies data from Unified Buffer (UB) to Global Memory (GM). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `GMPtr` | Destination GM pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.copy_ubuf_to_gm(ub_ptr, gm_ptr, 0, 32, 128, 0, 128, 128) +``` + +### Tile Data Movement Operations + +High-level operations for moving data between TensorView partitions (GM) and Tile buffers (UB), as well as between Tile buffers. These operations **automatically handle all low-level DMA configuration** and provide an intuitive interface based on tile semantics. + +#### Automatic DMA Parameter Inference + +The `pto.dma_load` and `pto.dma_store` operations automatically infer DMA transfer parameters (loop strides, loop sizes) from: + +1. **TensorView slices** - Python slicing syntax captures stride information: + ```python + # Contiguous slice: [0:16, 0:16] + pto.dma_load(input_tensor[0:16, 0:16], ub_tile) + + # Strided slice: [0:64:2, 0:32] → stride=2 in first dimension + pto.dma_load(input_tensor[0:64:2, 0:32], ub_tile) + ``` + +2. **Tile properties** - Layout and memory space determine destination patterns: + ```python + # Row-major vs column-major layouts affect stride computation + row_major_tile = pto.tile((16, 16), pto.f32, pto.MemorySpace.UB, b_layout=pto.BLayout.ROW_MAJOR) + col_major_tile = pto.tile((16, 16), pto.f32, pto.MemorySpace.UB, b_layout=pto.BLayout.COL_MAJOR) + ``` + +3. **Transpose and padding requirements** - Specified via operation parameters. + +#### Benefits of Automatic Inference + +- **Simplified API**: No need to manually call `set_loop*_stride_*` and `set_loop_size_*` operations +- **Reduced errors**: Automatic parameter validation and consistency checking +- **Hardware abstraction**: Focus on data movement semantics, not DMA engine details +- **Portable code**: Same TileLang code works across different DMA implementations + +For advanced use cases requiring manual DMA parameter control, see the [Low-level DMA Programming (Legacy)](#low-level-dma-programming-legacy) section. + +#### `pto.dma_load(src: TensorView, dst: Tile, pad_mode: PadMode = PadMode.PadNull, pad_value: ScalarType = None, left_padding: Index = 0, right_padding: Index = 0, init_out_buffer: bool = False) -> None` + +**Description**: Loads data from a TensorView partition (GM) into a Tile buffer (UB). This maps to `pto.copy_gm_to_ubuf` operation in VPTO IR. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `TensorView` | Source tensor view partition (must be in GM) | +| `dst` | `Tile` | Destination tile buffer (must be in UB memory space) | +| `pad_mode` | `PadMode` | Padding mode (PadNull, PadFirstElem, PadValue) | +| `pad_value` | `ScalarType` | Padding value (required if `pad_mode == PadValue`) | +| `left_padding` | `Index` | Left padding element count | +| `right_padding` | `Index` | Right padding element count | +| `init_out_buffer` | `bool` | Initialize output buffer before loading | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Destination tile must have `memory_space = MemorySpace.UB` +- Element types of source and destination must have same bitwidth +- Source partition shape must match destination tile valid shape (after accounting for padding) + +**Example**: +```python +# Load a 16x16 partition into a UB tile +pto.dma_load(input_tensor[0:16, 0:16], ub_tile) + +# Load with zero padding +pto.dma_load(input_tensor[0:16, 0:16], ub_tile, + pad_mode=PadMode.PadValue, + pad_value=pto.f32(0.0), + left_padding=2, + right_padding=2) +``` + +#### `pto.dma_store(src: Tile, dst: TensorView, pad_mode: PadMode = PadMode.PadNull, pad_value: ScalarType = None, left_padding: Index = 0, right_padding: Index = 0) -> None` + +**Description**: Stores data from a Tile buffer (UB) to a TensorView partition (GM). This maps to `pto.copy_ubuf_to_gm` operation in VPTO IR. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile buffer (must be in UB memory space) | +| `dst` | `TensorView` | Destination tensor view partition (must be in GM) | +| `pad_mode` | `PadMode` | Padding mode (PadNull, PadFirstElem, PadValue) | +| `pad_value` | `ScalarType` | Padding value (required if `pad_mode == PadValue`) | +| `left_padding` | `Index` | Left padding element count | +| `right_padding` | `Index` | Right padding element count | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Source tile must have `memory_space = MemorySpace.UB` +- Element types of source and destination must have same bitwidth +- Source tile valid shape must match destination partition shape (after accounting for padding) + +**Example**: +```python +# Store a UB tile to a GM partition +pto.dma_store(ub_tile, output_tensor[0:16, 0:16]) + +# Store with padding +pto.dma_store(ub_tile, output_tensor[0:16, 0:16], + pad_mode=PadMode.PadValue, + pad_value=pto.f32(0.0), + left_padding=1, + right_padding=1) +``` + +#### `pto.dma_copy(src: Tile, dst: Tile, src_offset: tuple[Index, Index] = (0, 0), dst_offset: tuple[Index, Index] = (0, 0), copy_shape: tuple[Index, Index] = None) -> None` + +**Description**: Copies data between Tile buffers within Unified Buffer (UB → UB). This maps to `pto.copy_ubuf_to_ubuf` operation in VPTO IR. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile buffer (must be in UB memory space) | +| `dst` | `Tile` | Destination tile buffer (must be in UB memory space) | +| `src_offset` | `tuple[Index, Index]` | Offset within source tile (row, col) in elements | +| `dst_offset` | `tuple[Index, Index]` | Offset within destination tile (row, col) in elements | +| `copy_shape` | `tuple[Index, Index]` | Shape of region to copy (rows, cols) in elements. If None, copies the maximum valid region starting from offsets. | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Both tiles must have `memory_space = MemorySpace.UB` +- Element types of source and destination must match +- Source and destination regions must be within tile valid shapes + +**Example**: +```python +# Copy entire tile +pto.dma_copy(src_tile, dst_tile) + +# Copy subregion: copy 8x8 block from (2,2) in src to (0,0) in dst +pto.dma_copy(src_tile, dst_tile, + src_offset=(2, 2), + dst_offset=(0, 0), + copy_shape=(8, 8)) +``` + +**Note**: These high-level operations automatically handle DMA stride and size configuration based on tile shapes, layouts, and offsets. For low-level control, see the [Low-level DMA Programming (Legacy)](#low-level-dma-programming-legacy) section. + +#### VPTO IR Mapping + +The high-level DMA operations in TileLang DSL map to corresponding operations in VPTO IR: + +| TileLang DSL Operation | VPTO IR Operation | Description | +|------------------------|-------------------|-------------| +| `pto.dma_load` | `pto.copy_gm_to_ubuf` | Loads data from GM tensor view to UB tile buffer | +| `pto.dma_store` | `pto.copy_ubuf_to_gm` | Stores data from UB tile buffer to GM tensor view | +| `pto.dma_copy` | `pto.copy_ubuf_to_ubuf` | Copies data between UB tile buffers | + +These mappings allow the TileLang compiler to generate efficient VPTO IR code while providing a higher-level, more intuitive API for developers. The compiler automatically handles the conversion between Tile/TensorView abstractions and the low-level pointer/stride representation required by VPTO IR operations. + + +### Address Generation Syntax Sugar + +To simplify address calculation and reduce manual byte offset computation errors, TileLang DSL provides syntactic sugar for vector load/store operations using element-based indexing. This syntax automatically computes the byte offset based on tile shape, element type, and layout. + +#### Indexing Syntax + +The syntax supports two indexing modes for different operations: + +1. **Vector-range indexing** (for vector load/store operations): + - **Row-major layout (default)**: `tile[row_index, col_start:]` + - `row_index`: Row index (0-based) + - `col_start:`: Starting column index followed by colon, indicating a vector-width contiguous region starting from this column + - The colon (`:`) indicates an implicit vector-width range determined by hardware vector size (256 bytes) and element type + + - **Column-major layout**: `tile[row_start:, col_index]` + - `row_start:`: Starting row index followed by colon, indicating a vector-width contiguous region starting from this row + - `col_index`: Column index (0-based) + - Used for column-major tiles (`BLayout.COL_MAJOR`) where elements are stored column-wise + + - **1D tile indexing**: `tile[start:]` (or equivalently `tile[0, start:]` for row-major or `tile[start:, 0]` for column-major) + - `start:`: Starting element index followed by colon + +2. **Single-element indexing** (for scalar load operations like `pto.vsld`): + - **Row-major layout (default)**: `tile[row_index, col_index]` + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + + - **Column-major layout**: `tile[row_index, col_index]` (same syntax) + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Same syntax as row-major; the layout determines how the offset is computed + + - **1D tile indexing**: `tile[pos]` + - `pos`: Element index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + +#### Vector Width Calculation + +The number of elements loaded/stored in a single vector operation is determined by: + +``` +vector_lanes = 256 // element_size_bytes(element_type) +``` + +**Convenience API**: Use `pto.get_lanes(dtype)` to compute vector lanes for a given element type (e.g., `pto.get_lanes(pto.f32)` returns 64, `pto.get_lanes(pto.f16)` returns 128). + +Where `element_size_bytes` is: +- 1 byte for `i8` +- 2 bytes for `i16`, `f16`, `bf16` +- 4 bytes for `i32`, `f32` +- 8 bytes for `i64` + +#### Offset Computation + +The byte offset is automatically computed based on tile layout: + +- **Row-major layout** (`BLayout.ROW_MAJOR`): + ``` + offset = (row_index * stride_row + col_start) * element_size_bytes + ``` + where `stride_row` is the row stride in elements (typically `tile.shape[1]` for contiguous tiles). + +- **Column-major layout** (`BLayout.COL_MAJOR`): + - For syntax `tile[row_start:, col_index]`: + ``` + offset = (col_index * stride_col + row_start) * element_size_bytes + ``` + - For backward compatibility with traditional offset calculation: + ``` + offset = (col_start * stride_col + row_index) * element_size_bytes + ``` + where `stride_col` is the column stride in elements (typically `tile.shape[0]` for contiguous tiles), `row_start` is the starting row index, and `col_index` is the column index. + +**Note**: +- For single-element indexing (`tile[row, col]` or `tile[pos]`), the same offset formulas apply with `col_start` replaced by `col_index` (or `start` replaced by `pos` for 1D tiles). +- For column-major vector-range indexing (`tile[row_start:, col_index]`), the offset formula uses `row_start` as the starting position along the contiguous dimension. +- The compiler automatically handles the appropriate substitution based on the indexing syntax and tile layout. + +#### Constraints + +1. **Boundary checks**: The requested region must be within tile bounds: + - **For vector-range indexing** (`:` syntax): + - **Row-major layout** (`tile[row_index, col_start:]`): + - `row_index < tile.shape[0]` and `col_start + vector_lanes <= tile.shape[1]` + - **Column-major layout** (`tile[row_start:, col_index]`): + - `row_start + vector_lanes <= tile.shape[0]` and `col_index < tile.shape[1]` + - **1D tile indexing**: `tile[start:]` + - `start + vector_lanes <= tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + - **For single-element indexing** (no `:` syntax): + - 2D: `row_index < tile.shape[0]` and `col_index < tile.shape[1]` (same for both layouts) + - 1D: `pos < tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + +2. **Alignment**: The computed offset must satisfy hardware alignment requirements for the operation. + +3. **Full vectors only**: The `:` syntax always loads/stores a full vector width. For partial vectors, use the traditional byte offset approach with explicit mask handling. + +4. **Single-element operations**: The single-element indexing syntax (`tile[row, col]` or `tile[pos]`) is only supported for scalar load operations like `pto.vsld`. For other operations, use vector-range indexing with `:` syntax. + +#### Supported Operations + +The indexing syntax is supported for all vector load and store operations with the following syntax mapping: + +- **Vector-range indexing** (`tile[row, col:]` or `tile[start:]`): + - Load operations: `vlds`, `vldas`, `vldus`, `vplds`, `vldx2` + - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstx2` + +- **Single-element indexing** (`tile[row, col]` or `tile[pos]`): + - Load operations: `vsld` (scalar load with broadcast) + +#### Examples + +The following examples use row-major layout syntax. For column-major tiles, use `tile[row_start:, col_index]` syntax instead of `tile[row_index, col_start:]`. + +```python +# 2D tile indexing (row-major layout) +vec = pto.vlds(tile[i, j:]) # Load vector from row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[i, j:], mask) # Store vector with mask + +# 1D tile indexing +vec = pto.vlds(tile[k:]) # Load vector from elements k to k+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store vector with mask + +# Dual load with indexing +vec1, vec2 = pto.vldx2(tile_a[i, j:], tile_b[i, j:]) + +# Aligned load with indexing +vec = pto.vldas(tile[i, j:], align) + +# Scalar load (broadcast) +vec = pto.vsld(tile[i, j]) # Load scalar at tile[i,j] and broadcast to vector +``` + +#### Comparison with Manual Offset Calculation + +**Traditional approach (error-prone):** +```python +# Manual byte offset calculation for f32 tile +rows, cols = tile.shape +row_offset = i * cols * 4 # Hard-coded 4 bytes for f32 +col_offset = j * 4 +offset = row_offset + col_offset +vec = pto.vlds(tile, offset) +``` + +**New syntax (type-safe):** +```python +# Automatic offset calculation +vec = pto.vlds(tile[i, j:]) # Compiler computes correct offset for any element type +``` + +The syntax sugar eliminates manual byte calculations, reduces errors, and makes code generic across different element types (e.g., the same kernel works for both `f16` and `f32` without modification). + +### Vector Load Operations + +Operations for loading data from memory into vector registers. + +#### `pto.vlds(buf: UBRef, offset: Index) -> VRegType` +#### `pto.vlds(tile[row, col:]) -> VRegType` +#### `pto.vlds(tile[start:]) -> VRegType` + +**Description**: Stateless vector load from buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the requested vector region must be within tile bounds and satisfy alignment requirements + +**Examples**: +```python +# Traditional byte-offset syntax +vec = pto.vlds(ub_ptr, lane * 256) + +# New element-indexing syntax +vec = pto.vlds(tile[i, j:]) # Load from row i, columns j to j+vector_lanes-1 +vec = pto.vlds(tile[k:]) # Load from 1D tile, elements k to k+vector_lanes-1 + +# Generic kernel that works for both f16 and f32 +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_scale(src: pto.Tile, dst: pto.Tile, scale: pto.f32): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): # vector_lanes computed from element type + # No manual byte calculation needed! + vec = pto.vlds(src[i, j:]) + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, dst[i, j:], all_mask) +``` + +#### `pto.vldas(buf: UBRef, offset: Index, align: pto.align) -> VRegType` +#### `pto.vldas(tile[row, col:], align: pto.align) -> VRegType` +#### `pto.vldas(tile[start:], align: pto.align) -> VRegType` + +**Description**: Aligned vector load with explicit alignment carrier. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | +| `align` | `pto.align` | Alignment specification | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `align` | `pto.align` | Alignment specification | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Examples**: +```python +# Byte-offset syntax +vec = pto.vldas(ub_ptr, offset, align) + +# Element-indexing syntax +vec = pto.vldas(tile[i, j:], align) +vec = pto.vldas(tile[k:], align) +``` + +#### `pto.vldus(buf: UBRef, offset: Index) -> VRegType` +#### `pto.vldus(tile[row, col:]) -> VRegType` +#### `pto.vldus(tile[start:]) -> VRegType` + +**Description**: Unaligned vector load. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Examples**: +```python +# Byte-offset syntax +vec = pto.vldus(ub_ptr, offset) + +# Element-indexing syntax +vec = pto.vldus(tile[i, j:]) +vec = pto.vldus(tile[k:]) +``` + +#### `pto.vplds(buf: UBRef, offset: Index, pred: MaskType) -> VRegType` +#### `pto.vplds(tile[row, col:], pred: MaskType) -> VRegType` +#### `pto.vplds(tile[start:], pred: MaskType) -> VRegType` + +**Description**: Predicated vector load stateless. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | +| `pred` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `pred` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Examples**: +```python +# Byte-offset syntax +vec = pto.vplds(ub_ptr, offset, mask) + +# Element-indexing syntax +vec = pto.vplds(tile[i, j:], mask) +vec = pto.vplds(tile[k:], mask) +``` + +#### `pto.vldx2(buf1: UBRef, buf2: UBRef, offset: Index) -> (VRegType, VRegType)` +#### `pto.vldx2(tile1[row, col:], tile2[row, col:]) -> (VRegType, VRegType)` +#### `pto.vldx2(tile1[start:], tile2[start:]) -> (VRegType, VRegType)` + +**Description**: Dual vector load from two buffers. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf1` | `UBRef` | First buffer or pointer | +| `buf2` | `UBRef` | Second buffer or pointer | +| `offset` | `Index` | Byte offset (applied to both buffers) | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile1[row, col:]` | `Tile` with indexing | First 2D tile with row index and starting column | +| `tile2[row, col:]` | `Tile` with indexing | Second 2D tile with row index and starting column | +| _or_ | | | +| `tile1[start:]` | `Tile` with indexing | First 1D tile with starting element index | +| `tile2[start:]` | `Tile` with indexing | Second 1D tile with starting element index | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec1` | `VRegType` | Vector from first buffer | +| `vec2` | `VRegType` | Vector from second buffer | + +**Examples**: +```python +# Byte-offset syntax +vec1, vec2 = pto.vldx2(ub_ptr1, ub_ptr2, offset) + +# Element-indexing syntax +vec1, vec2 = pto.vldx2(tile_a[i, j:], tile_b[i, j:]) +vec1, vec2 = pto.vldx2(tile_a[k:], tile_b[k:]) +``` + +#### `pto.vsld(buf: UBRef, offset: Index) -> VRegType` +#### `pto.vsld(tile[row, col]) -> VRegType` +#### `pto.vsld(tile[pos]) -> VRegType` + +**Description**: Scalar load to vector (broadcast scalar to all lanes). Supports both byte-offset and element-indexing syntax. The element-indexing syntax loads a single element (not a vector) and broadcasts it to all lanes. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | +| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Vector with scalar broadcast to all lanes | + +**Examples**: +```python +# Byte-offset syntax +vec = pto.vsld(ub_ptr, offset) + +# Element-indexing syntax +vec = pto.vsld(tile[i, j]) # Load single element at (i,j) and broadcast +vec = pto.vsld(tile[k]) # Load single element at position k and broadcast +``` + +### Predicate Operations + +Operations for creating and manipulating typed masks. + +**Recommended API**: For most use cases, prefer the unified `pto.make_mask()` function which automatically selects the appropriate mask granularity based on element type and supports both tail processing (remaining element count) and pattern-based mask generation. This eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` (tail processing) and `pset_b8`/`pset_b16`/`pset_b32` (pattern generation) operations. + +**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +#### `pto.pset_b8(pattern: pto.MaskPattern) -> pto.mask_b8` + +**Description**: Creates an 8-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +**Constraints**: +- Used with `i8` vector operations + +**Example**: +```python +mask8 = pto.make_mask(pto.i8, PAT.ALL) +``` + +#### `pto.pset_b16(pattern: pto.MaskPattern) -> pto.mask_b16` + +**Description**: Creates a 16-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations + +**Example**: +```python +mask16 = pto.make_mask(pto.f16, PAT.ALL) +``` + +#### `pto.pset_b32(pattern: pto.MaskPattern) -> pto.mask_b32` + +**Description**: Creates a 32-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | + +**Constraints**: +- Used with `f32`/`i32` vector operations + +**Example**: +```python +mask32 = pto.make_mask(pto.f32, PAT.ALL) +``` + +#### `pto.pge_b8(vec: VRegType, scalar: ScalarType) -> pto.mask_b8` + +**Description**: Creates 8-bit mask where vector elements ≥ scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +**Constraints**: +- Vector element type must be `i8` or compatible + +#### `pto.pge_b16(vec: VRegType, scalar: ScalarType) -> pto.mask_b16` + +**Description**: Creates 16-bit mask where vector elements ≥ scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +**Constraints**: +- Vector element type must be `f16`/`bf16`/`i16` + +#### `pto.pge_b32(vec: VRegType, scalar: ScalarType) -> pto.mask_b32` + +**Description**: Creates 32-bit mask where vector elements ≥ scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | + +**Constraints**: +- Vector element type must be `f32`/`i32` + +**Example**: +```python +mask = pto.pge_b32(vec_f32, pto.f32(0.0)) +``` + +#### `pto.plt_b8(vec: VRegType, scalar: ScalarType) -> pto.mask_b8` + +**Description**: Creates 8-bit mask where vector elements < scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +#### `pto.plt_b16(vec: VRegType, scalar: ScalarType) -> pto.mask_b16` + +**Description**: Creates 16-bit mask where vector elements < scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +#### `pto.plt_b32(vec: VRegType, scalar: ScalarType) -> (pto.mask_b32, pto.i32)` + +**Description**: Creates 32-bit mask where vector elements < scalar, returns mask and remaining count. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | +| `remaining` | `pto.i32` | Remaining element count | + +**Example**: +```python +mask, remaining = pto.plt_b32(vec_f32, pto.f32(10.0)) +``` + +#### `pto.make_mask(element_type: Type, value: pto.i32 | pto.MaskPattern) -> MaskType | (MaskType, pto.i32)` + +**Description**: Creates a mask with appropriate bitwidth (8, 16, or 32) based on element type, automatically inferring whether to perform tail processing or pattern-based mask generation based on the `value` parameter type. This convenience function eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` and `pset_b8`/`pset_b16`/`pset_b32` operations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `element_type` | `Type` | Element type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | +| `value` | `pto.i32` \| `pto.MaskPattern` | Either:
- Remaining element count (as `pto.i32`) for tail processing
- Mask pattern enum value for fixed mask generation (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Generated mask with appropriate granularity | +| `remaining` | `pto.i32` | Updated remaining element count (only returned when `value` is a `pto.i32` for tail processing) | + +**Constraints**: +- The `element_type` must be one of: `f32`, `i32`, `f16`, `bf16`, `i16`, `i8` +- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`, 16-bit for `f16`/`bf16`/`i16`, 8-bit for `i8` +- The function infers the operation mode from the `value` parameter type at compile time: + - `pto.i32` value → tail processing mode (returns `(mask, updated_remaining)`) + - `pto.MaskPattern` enum value → pattern mode (returns `mask` only) + +**Implementation Note**: This function is a DSL macro that performs type-based dispatch at compile time: +- When `value` is a `pto.i32` expression: expands to corresponding `plt_b` instruction (`plt_b32`, `plt_b16`, or `plt_b8`) +- When `value` is a `pto.MaskPattern` enum value: expands to corresponding `pset_b` instruction (`pset_b32`, `pset_b16`, or `pset_b8`) + +**Example**: +```python +# Tail processing with f32 vectors: value is pto.i32 → expands to plt_b32 +mask_f32, remaining_f32 = pto.make_mask(pto.f32, remaining_elements) + +# Tail processing with f16 vectors: value is pto.i32 → expands to plt_b16 +mask_f16, remaining_f16 = pto.make_mask(pto.f16, remaining_elements) + +# Tail processing with i8 vectors: value is pto.i32 → expands to plt_b8 +mask_i8, remaining_i8 = pto.make_mask(pto.i8, remaining_elements) + +# Pattern-based mask with f32 vectors: value is MaskPattern enum → expands to pset_b32 +mask_all_f32 = pto.make_mask(pto.f32, PAT.ALL) + +# Pattern-based mask with f16 vectors: value is MaskPattern enum → expands to pset_b16 +mask_even_f16 = pto.make_mask(pto.f16, PAT.EVEN) + +# Pattern-based mask with i8 vectors: value is MaskPattern enum → expands to pset_b8 +mask_all_i8 = pto.make_mask(pto.i8, PAT.ALL) + +# Type annotations help clarify expected parameter types +remaining: pto.i32 = 1024 +mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing +mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode +``` + +#### `pto.ppack(mask: MaskType) -> pto.i32` + +**Description**: Packs mask bits into a 32-bit integer. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `packed` | `pto.i32` | Packed mask bits | + +#### `pto.punpack(packed: pto.i32) -> MaskType` + +**Description**: Unpacks 32-bit integer to mask (granularity determined by context). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `packed` | `pto.i32` | Packed mask bits | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Unpacked mask | + +#### `pto.pnot(mask: MaskType) -> MaskType` + +**Description**: Logical negation of mask bits. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `negated` | `MaskType` | Negated mask | + +#### `pto.psel(mask: MaskType, true_val: ScalarType, false_val: ScalarType) -> ScalarType` + +**Description**: Selects between two scalar values based on mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Selection mask | +| `true_val` | `ScalarType` | Value selected when mask bit is 1 | +| `false_val` | `ScalarType` | Value selected when mask bit is 0 | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `ScalarType` | Selected scalar value | + +### Unary Vector Operations + +Element-wise unary operations on vector registers. + +#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Absolute value of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Absolute values | + +**Constraints**: +- Mask granularity must match vector element type (e.g., `f32` requires `mask_b32`) + +**Example**: +```python +abs_vec = pto.vabs(vec_f32, mask32) +``` + +#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Exponential of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Exponential values | + +#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Natural logarithm of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Natural logarithm values | + +#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Square root of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Square root values | + +#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: ReLU activation (max(0, x)) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated values | + +#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bitwise NOT of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise NOT values | + +#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex addition of vector elements (treating pairs as complex numbers). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex addition result | + +#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex maximum of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex maximum result | + +### Binary Vector Operations + +Element-wise binary operations on vector registers. + +#### `pto.vadd(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise addition of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum of vectors | + +**Example**: +```python +sum_vec = pto.vadd(vec_a, vec_b, mask32) +``` + +#### `pto.vsub(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise subtraction of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference of vectors | + +#### `pto.vmul(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise multiplication of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Product of vectors | + +#### `pto.vdiv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise division of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Quotient of vectors | + +#### `pto.vmax(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise maximum | + +#### `pto.vmin(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise minimum | + +#### `pto.vand(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise AND of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise AND result | + +#### `pto.vor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise OR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise OR result | + +#### `pto.vxor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise XOR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise XOR result | + +#### `pto.vshl(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift left (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshr(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift right (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +### Vector-Scalar Operations + +Operations between vectors and scalars. + +#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector multiplied by scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar multiplier | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Scaled vector | + +**Example**: +```python +scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) +``` + +#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector plus scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar addend | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Maximum values | + +#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Minimum values | + +#### `pto.vlrelu(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Leaky ReLU activation (max(αx, x)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Alpha coefficient | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Leaky ReLU activated values | + +#### `pto.vshls(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector shift left by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshrs(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector shift right by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +### Carry & Select Operations + +Operations with carry propagation and selection. + +#### `pto.vaddc(vec1: VRegType, vec2: VRegType, carry_in: ScalarType, mask: MaskType) -> (VRegType, ScalarType)` + +**Description**: Vector addition with carry input and output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `carry_in` | `ScalarType` | Input carry bit | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum vector | +| `carry_out` | `ScalarType` | Output carry bit | + +#### `pto.vsubc(vec1: VRegType, vec2: VRegType, borrow_in: ScalarType, mask: MaskType) -> (VRegType, ScalarType)` + +**Description**: Vector subtraction with borrow input and output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `borrow_in` | `ScalarType` | Input borrow bit | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference vector | +| `borrow_out` | `ScalarType` | Output borrow bit | + +#### `pto.vsel(mask: MaskType, true_vec: VRegType, false_vec: VRegType) -> VRegType` + +**Description**: Vector select based on mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Selection mask | +| `true_vec` | `VRegType` | Vector selected when mask bit is 1 | +| `false_vec` | `VRegType` | Vector selected when mask bit is 0 | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Selected vector | + +**Example**: +```python +result = pto.vsel(mask32, scaled_vec, original_vec) +``` + +### Data Rearrangement + +Operations for rearranging data within vectors. + +#### `pto.pdintlv_b8(mask: pto.mask_b8) -> pto.mask_b8` + +**Description**: Deinterleave 8-bit mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.mask_b8` | Input 8-bit mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `pto.mask_b8` | Deinterleaved mask | + +#### `pto.pintlv_b16(mask: pto.mask_b16) -> pto.mask_b16` + +**Description**: Interleave 16-bit mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.mask_b16` | Input 16-bit mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `pto.mask_b16` | Interleaved mask | + +#### `pto.vintlv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Interleave two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Interleaved vector | + +#### `pto.vdintlv(vec: VRegType, mask: MaskType) -> (VRegType, VRegType)` + +**Description**: Deinterleave vector into two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec1` | `VRegType` | First deinterleaved vector | +| `vec2` | `VRegType` | Second deinterleaved vector | + +### Conversion & Special Operations + +Type conversion and specialized operations. + +#### `pto.vtrc(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Truncate vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Truncated vector | + +#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType) -> VRegType` + +**Description**: Type conversion of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `to_type` | `Type` | Target element type | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Converted vector | + +#### `pto.vbitsort(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bitonic sort of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sorted vector | + +#### `pto.vmrgsort4(vec1: VRegType, vec2: VRegType, vec3: VRegType, vec4: VRegType, mask: MaskType) -> VRegType` + +**Description**: 4-way merge sort of vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `vec3` | `VRegType` | Third input vector | +| `vec4` | `VRegType` | Fourth input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Merged and sorted vector | + +### Stateless Store Operations + +Operations for storing data from vector registers to memory (stateless). + +#### `pto.vsts(vec: VRegType, buf: UBRef, offset: Index, mask: MaskType) -> None` +#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType) -> None` + +**Description**: Stateless vector store to buffer. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | Destination buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the destination vector region must be within tile bounds and satisfy alignment requirements + +**Examples**: +```python +# Byte-offset syntax +pto.vsts(vec_f32, ub_ptr, lane * 256, mask32) + +# Element-indexing syntax +pto.vsts(vec, tile[i, j:], mask) # Store to row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store to 1D tile, elements k to k+vector_lanes-1 + +# In a generic kernel +@pto.vkernel(target="a5", op="copy", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_store(src: pto.Tile, dst: pto.Tile): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): + vec = pto.vlds(src[i, j:]) + pto.vsts(vec, dst[i, j:], all_mask) # No manual offset calculation +``` + +#### `pto.psts(mask: MaskType, buf: UBRef, offset: Index) -> None` +#### `pto.psts(mask: MaskType, tile[row, col:]) -> None` +#### `pto.psts(mask: MaskType, tile[start:]) -> None` + +**Description**: Predicate store to buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +#### `pto.vsst(scalar: ScalarType, buf: UBRef, offset: Index, mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` + +**Description**: Scalar to vector store (broadcast scalar to all lanes). Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vstx2(vec1: VRegType, vec2: VRegType, buf1: UBRef, buf2: UBRef, offset: Index, mask: MaskType) -> None` +#### `pto.vstx2(vec1: VRegType, vec2: VRegType, tile1[row, col:], tile2[row, col:], mask: MaskType) -> None` +#### `pto.vstx2(vec1: VRegType, vec2: VRegType, tile1[start:], tile2[start:], mask: MaskType) -> None` + +**Description**: Dual vector store to two buffers. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First vector to store | +| `vec2` | `VRegType` | Second vector to store | +| `buf1` | `UBRef` | First destination buffer | +| `buf2` | `UBRef` | Second destination buffer | +| `offset` | `Index` | Byte offset (applied to both buffers) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First vector to store | +| `vec2` | `VRegType` | Second vector to store | +| `tile1[row, col:]` | `Tile` with indexing | First 2D tile with row index and starting column (vector-width range) | +| `tile2[row, col:]` | `Tile` with indexing | Second 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First vector to store | +| `vec2` | `VRegType` | Second vector to store | +| `tile1[start:]` | `Tile` with indexing | First 1D tile with starting element index (vector-width range) | +| `tile2[start:]` | `Tile` with indexing | Second 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vsta(vec: VRegType, buf: UBRef, offset: Index, align: pto.align, mask: MaskType) -> None` +#### `pto.vsta(vec: VRegType, tile[row, col:], align: pto.align, mask: MaskType) -> None` +#### `pto.vsta(vec: VRegType, tile[start:], align: pto.align, mask: MaskType) -> None` + +**Description**: Aligned vector store with explicit alignment carrier. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | +| `align` | `pto.align` | Alignment specification | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `align` | `pto.align` | Alignment specification | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `align` | `pto.align` | Alignment specification | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +### Stateful Store Operations + +Operations for storing data with stateful semantics. + +#### `pto.pstu(mask: MaskType, buf: UBRef, offset: Index) -> None` + +**Description**: Predicate stateful store. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | + +**Returns**: None (side-effect operation) + +#### `pto.vstu(vec: VRegType, buf: UBRef, offset: Index, mask: MaskType) -> None` + +**Description**: Vector stateful store. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vstus(align_in: AlignType, offset: i32, vec: VRegType, buf: UBRef) -> AlignType` + +**Description**: No-post unaligned vector store with scalar offset. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming unaligned-store state | +| `offset` | `i32` | Stream advance offset in elements | +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | UB destination base pointer | + +**Returns**: Updated align-state token for a later flush op such as `pto.vstas`. + +#### `pto.vstur(align_in: AlignType, vec: VRegType, buf: UBRef, mode: str) -> AlignType` + +**Description**: Unaligned vector store using the SPR-AR-driven stateful form. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming unaligned-store state | +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | UB destination base pointer | +| `mode` | `str` | `POST_UPDATE` or `NO_POST_UPDATE` | + +**Returns**: Updated align-state token for a later flush op such as `pto.vstar`. + +## Examples + +### Simple Vector Copy + +```python +@pto.vkernel(...) +def vector_copy(src: pto.memref(256, pto.f32, MemorySpace.UB), + dst: pto.memref(256, pto.f32, MemorySpace.UB)): + all_mask = pto.make_mask(pto.f32, PAT.ALL) + for offset in range(0, 256, 64): + vec = pto.vlds(src, offset) + pto.vsts(vec, dst, offset, all_mask) +``` + +### Conditional Computation + +```python +@pto.vkernel(...) +def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), + dst: pto.ptr(pto.f32, MemorySpace.GM), + threshold: pto.f32): + # ... setup ... + + with pto.strict_vecscope(ub_in, ub_out, threshold) as (vin, vout, thresh): + for i in range(0, 1024, 64): + vec = pto.vlds(vin, i) + + # Compare with threshold + mask = pto.pge_b32(vec, thresh) + + # Scale values above threshold + scaled = pto.vmuls(vec, pto.f32(2.0), mask) + + # Keep original values below threshold + result = pto.vsel(mask, scaled, vec) + + pto.vsts(result, vout, i, all_mask) +``` + +### Loop with Carry + +```python +@pto.vkernel(...) +def prefix_sum(src: pto.ptr(pto.i32, MemorySpace.UB), + dst: pto.ptr(pto.i32, MemorySpace.UB)): + all_mask = pto.make_mask(pto.i32, PAT.ALL) + carry = pto.i32(0) + + for i in range(0, 256, 64): + vec = pto.vlds(src, i) + result, carry = pto.vaddcs(vec, carry, all_mask) + pto.vsts(result, dst, i, all_mask) +``` + +## Common Errors + +### Typed Mask Mismatch + +``` +Error: f32 vector operation cannot consume mask_b16 +``` + +**Solution:** Ensure mask granularity matches vector element size: +- `f32` vectors use `mask_b32` +- `f16` vectors use `mask_b16` +- `i8` vectors use `mask_b8` + +### Strict Scope Implicit Capture + +``` +Error: strict_vecscope body cannot capture outer value 'ub_in' implicitly +``` + +**Solution:** Pass all required values in the capture list: + +```python +# Wrong: +with pto.strict_vecscope() as (): + vec = pto.vlds(ub_in, offset) # ub_in from outer scope + +# Correct: +with pto.strict_vecscope(ub_in) as (ub): + vec = pto.vlds(ub, offset) +``` + +### Untyped Loop Carried State + +``` +Error: loop-carried value must have explicit machine type +``` + +**Solution:** Add type annotations to loop-carried variables: + +```python +# Wrong: +remaining = 1024 # Plain Python int +for i in range(0, N, step): + mask, remaining = pto.make_mask(pto.f32, remaining) + +# Correct: +remaining: pto.i32 = 1024 +# or +remaining = pto.i32(1024) +``` + +## Compatibility Notes + +The current experimental implementation in `python/pto/dialects/pto.py` differs from this specification in several ways: + +1. **Mask types**: The experimental version uses untyped `mask` instead of `mask_b8`/`mask_b16`/`mask_b32` +2. **Barrier operation**: Uses `pto.barrier()` instead of `pto.pipe_barrier()` +3. **MemRef support**: Does not yet support `pto.memref()` types +4. **Operation coverage**: Implements only a subset of operations + +When implementing new code, follow this specification. The experimental implementation will be updated to match over time. + +## Next Steps + +- Explore the ISA documentation in `docs/isa/` for detailed operation semantics +- Check `test/samples/` for example kernels +- Refer to `docs/vpto-spec.md` for the underlying VPTO instruction specification + +For compiler developers, see `docs/PTO_IR_manual.md` for MLIR-level details. diff --git a/docs/tilelang-dsl-syntax-sugar-proposals.md b/docs/tilelang-dsl-syntax-sugar-proposals.md new file mode 100644 index 000000000..8a60466d9 --- /dev/null +++ b/docs/tilelang-dsl-syntax-sugar-proposals.md @@ -0,0 +1,404 @@ +# TileLang DSL Syntax Sugar Proposals + +## Overview + +This document proposes syntax sugar enhancements for the TileLang Python DSL to improve programming ergonomics while maintaining close correspondence with the underlying VPTO IR. The current DSL design closely mirrors VPTO instructions, which can lead to verbose and error-prone code. These proposals aim to provide higher-level abstractions that compile down to the existing VPTO operations. + +## Current Usability Challenges + +### 1. **Low-Level Pointer Operations** +```python +# Current: manual byte offset management +ub_in = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +ub_out = pto.castptr(4096, pto.ptr(pto.f32, MemorySpace.UB)) +next_ptr = pto.addptr(ub_ptr, 4096) +``` +**Problem**: Users must manage byte offsets and memory spaces manually. + +### 2. **Verbose Copy Operations** +The `pto.copy_ubuf_to_ubuf` operation has 7 parameters: +- `src_offset`, `src_stride0`, `src_stride1` +- `dst_offset`, `dst_stride0`, `dst_stride1` + +**Problem**: Correctly setting stride parameters is error-prone, especially for multi-dimensional data. + +### 3. **Precise Mask Type Matching** +```python +# Must ensure mask granularity matches element type +mask32 = pto.pset_b32("PAT_ALL") # f32 requires b32 mask +mask16 = pto.pset_b16("PAT_ALL") # f16 requires b16 mask +``` +**Problem**: Type error messages are not intuitive and easy to confuse. + +### 4. **Strict Vector Scope Requirements** +```python +# strict_vecscope requires explicit capture of all variables +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + # Can only use captured variables +``` +**Problem**: Increases boilerplate code, especially when multiple variables need capture. + +### 5. **Manual Synchronization Management** +```python +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` +**Problem**: Easy to forget synchronization or use wrong event IDs. + +### 6. **Byte Offsets vs. Element Indices** +```python +# Need to calculate byte offsets +vec = pto.vlds(ub_ptr, lane * 256) # Assuming f32, 4 bytes per element +``` +**Problem**: Users must understand underlying memory layout. + +## Proposed Syntax Sugar Enhancements + +### 1. **Array View Abstraction** + +#### Current API +```python +# Low-level pointer operations +ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +vec = pto.vlds(ub_ptr, 64 * 4) # Load 64th f32 element +``` + +#### Proposed Syntax Sugar +```python +# Create array views +ub_array = pto.ub_array(256, pto.f32, base_offset=0) # 256-element f32 UB array +gm_array = pto.gm_array(1024, pto.f32, src) # GM pointer array view + +# Element access with automatic offset calculation +element = ub_array[64] # Get 64th element (auto-calculates byte offset) +slice = ub_array[128:256] # Slice operation + +# Array assignment (compiles to appropriate copy operations) +ub_array[0:64] = gm_array[0:64] # Compiles to copy_gm_to_ubuf + +# Multi-dimensional arrays +ub_2d = pto.ub_array((256, 128), pto.f32) # 2D array +row = ub_2d[32, :] # Row slice +col = ub_2d[:, 64] # Column slice +``` + +#### Implementation Notes +- `ub_array[64]` → `pto.vlds(ub_ptr, 64 * sizeof(f32))` +- `ub_array[0:64] = gm_array[0:64]` → Appropriate `copy_gm_to_ubuf` call with stride calculations +- Array views are compile-time constructs with no runtime overhead + +### 2. **Simplified Copy Operations** + +#### Current API +```python +pto.copy_gm_to_ubuf(src, dst, 0, 32, 128, 0, 0, False, 0, 128, 128) +``` + +#### Proposed Syntax Sugar +```python +# Full array copy +pto.copy_gm_to_ub(gm_array, ub_array) + +# Slice copy with automatic stride calculation +pto.copy_gm_to_ub(gm_array[0:64], ub_array[128:192]) + +# Copy with element count +pto.copy_gm_to_ub(gm_array, ub_array, count=64) + +# Transpose copy +pto.copy_gm_to_ub(gm_array, ub_array, transpose=True) + +# Multi-dimensional copy with automatic stride inference +pto.copy_gm_to_ub(gm_2d[0:32, :], ub_2d[:, 0:64]) + +# Chained operations +(pto.copy_gm_to_ub(gm_array, ub_array) + .then(pto.copy_ub_to_ub(ub_array, ub_temp)) + .then(pto.copy_ub_to_gm(ub_temp, dst_array))) +``` + +### 3. **Automatic Mask Inference** + +#### Current API +```python +# Must specify mask type explicitly +mask32 = pto.pset_b32("PAT_ALL") +vec_f32 = pto.vlds(ptr, offset) +out = pto.vabs(vec_f32, mask32) +``` + +#### Proposed Syntax Sugar +```python +# Automatic mask type inference +mask = pto.pset("PAT_ALL") # Inferred as mask_b32 for f32 vectors +out = pto.vabs(vec_f32, mask) # Type-safe, auto-matched + +# Vector method syntax (more Pythonic) +out = vec_f32.abs(mask="PAT_ALL") +out = vec_f32.add(other_vec, mask=pto.pset("PAT_EVEN")) +out = vec_f32.max(scalar, mask="PAT_ALL") + +# Mask creation from comparison +mask = vec_f32 >= pto.f32(0.0) # Creates appropriate mask_b32 +mask = vec_f32 < threshold # Auto-infers mask type + +# Mask operations with auto-typing +combined = mask1 & mask2 # Bitwise AND with type preservation +inverted = ~mask # Logical NOT +``` + +### 4. **Simplified Synchronization Primitives** + +#### Current API +```python +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +# ... computation ... +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### Proposed Syntax Sugar +```python +# Context manager for automatic synchronization +with pto.sync_between(PIPE.MTE2, PIPE.V, event=EVENT.ID0): + # set_flag called on entry, wait_flag on exit + pto.copy_gm_to_ub(src, dst) + compute_block() + +# Decorator for function-level synchronization +@pto.synchronized(from_pipe=PIPE.MTE2, to_pipe=PIPE.V) +def compute_block(): + # Automatic synchronization before and after + pass + +# Pipeline synchronization chain +with pto.pipeline([ + (PIPE.MTE2, PIPE.V, EVENT.ID0), + (PIPE.V, PIPE.MTE3, EVENT.ID1), + (PIPE.MTE3, PIPE.S, EVENT.ID2) +]): + # Multi-stage synchronization + stage1() + stage2() + stage3() +``` + +### 5. **Element-Level Indexing Operations** + +#### Current API +```python +# Byte offset calculation required +vec = pto.vlds(ub_ptr, lane * 256) # Need to know f32 is 4 bytes +``` + +#### Proposed Syntax Sugar +```python +# Element-level indexing +vec = pto.vlde(ub_array, lane) # Automatic byte offset calculation +pto.vste(vec, ub_array, lane) # Element-level store + +# Array view methods +vec = ub_array.load_element(lane) +ub_array.store_element(lane, vec) + +# Batch operations +vectors = ub_array.load_elements([0, 64, 128, 192]) +ub_array.store_elements([256, 320, 384], vectors) + +# Strided access +stride = ub_array.load_stride(start=0, end=1024, step=64) +``` + +### 6. **Type Inference Simplification** + +#### Current API +```python +# Explicit type annotations required +remaining: pto.i32 = 1024 +# or +remaining = pto.i32(1024) +``` + +#### Proposed Syntax Sugar +```python +# Automatic type inference for constants +remaining = pto.constant(1024) # Inferred as i32 or i64 from context +step = pto.constant(64, type=pto.i32) # Explicit type specification + +# Typed range with automatic inference +for i in pto.range(0, 1024, 64): # i automatically gets correct machine type + # i is pto.i32 + +# Function argument type inference +@pto.vkernel +def kernel(x): # Type inferred from usage + return x * pto.constant(2) # x type inferred from multiplication + +# Variable type inference from operations +result = pto.constant(10) + pto.constant(20) # result is pto.i32 +``` + +### 7. **More Flexible Vector Scopes** + +#### Current API +```python +# Explicit capture required +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + for i in range(lb, ub, step): + vec = pto.vlds(s, i) + pto.vsts(vec, d, i, mask) +``` + +#### Proposed Syntax Sugar +```python +# Automatic variable capture +with pto.vector_scope(): + # Variables used in scope are automatically captured + for i in pto.range(start, end, step): + vec = src_array.load_element(i) + dst_array.store_element(i, vec.abs()) + +# Decorator for vectorized functions +@pto.vectorize +def compute_element(src, dst, index): + vec = src.load_element(index) + dst.store_element(index, vec.abs()) + +# Apply vectorized function across range +pto.vector_map(compute_element, src_array, dst_array, range(0, 1024, 64)) + +# Lambda support +pto.vector_map(lambda x: x.abs(), src_array, dst_array) +``` + +### 8. **Built-in Utility Functions** + +#### Common Pattern Encapsulation +```python +# Vector map/reduce operations +result = pto.vector_map(abs, src_array, dst_array) # Element-wise mapping +sum = pto.vector_reduce(add, array) # Reduction +max_val = pto.vector_reduce(max, array) # Maximum reduction + +# Vector zip/unzip +zipped = pto.vector_zip(src1, src2, dst) # Interleave +unzipped1, unzipped2 = pto.vector_unzip(src, dst1, dst2) # Deinterleave + +# Mathematical functions +result = pto.vector_sin(array) +result = pto.vector_exp(array) +result = pto.vector_relu(array) +result = pto.vector_sigmoid(array) + +# Statistical operations +mean = pto.vector_mean(array) +variance = pto.vector_variance(array) +min_val, max_val = pto.vector_minmax(array) + +# Linear algebra (small-scale) +dot_product = pto.vector_dot(vec1, vec2) +norm = pto.vector_norm(array) +``` + +## Implementation Strategy + +These syntax sugar enhancements can be implemented through: + +1. **Python Decorators and Context Managers**: For synchronization and vector scopes +2. **Wrapper Classes**: `UBArray`, `GMArray`, `Vector` classes that encapsulate low-level operations +3. **Operator Overloading**: Support for `[]`, `:`, arithmetic operators on wrapper classes +4. **Type Inference System**: Context-based machine type inference +5. **Compile-time Transformation**: Conversion of high-level syntax to low-level VPTO operations before IR generation + +## Compatibility with VPTO IR + +**Key Principle**: All syntax sugar must ultimately lower to existing VPTO operations. + +### Lowering Examples + +| Syntax Sugar | VPTO IR Equivalent | +|--------------|-------------------| +| `ub_array[64]` | `pto.vlds(ub_ptr, 64 * sizeof(f32))` | +| `pto.copy_gm_to_ub(src_array, dst_array)` | Appropriate `copy_gm_to_ubuf` call with calculated strides | +| `with pto.sync_between(...):` | `set_flag` + `wait_flag` pair | +| `mask = vec_f32 >= pto.f32(0.0)` | `pto.pge_b32(vec_f32, pto.f32(0.0))` | +| `vec_f32.abs(mask="PAT_ALL")` | `pto.vabs(vec_f32, pto.pset_b32("PAT_ALL"))` | + +## Prioritization + +### High Priority (Immediate Value) +1. Array view abstraction +2. Simplified copy operations +3. Automatic mask inference + +### Medium Priority (Significant Ergonomics Improvement) +4. Element-level indexing +5. Type inference simplification +6. Flexible vector scopes + +### Low Priority (Advanced Features) +7. Enhanced synchronization primitives +8. Built-in utility functions + +## Migration Path + +The existing low-level API will remain available for performance-critical code or direct VPTO IR correspondence. Syntax sugar will be provided as an optional layer that can be mixed with low-level operations. + +```python +# Mixed usage example +@pto.vkernel +def mixed_kernel(src: pto.ptr(pto.f32, MemorySpace.GM), + dst: pto.ptr(pto.f32, MemorySpace.GM)): + # Low-level: manual pointer setup + ub_in = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) + + # High-level: array view for computation + ub_array = pto.ub_array(256, pto.f32, base_ptr=ub_in) + + # Mixed: low-level copy, high-level computation + pto.copy_gm_to_ubuf(src, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + with pto.vector_scope(): + for i in pto.range(0, 256, 64): + vec = ub_array.load_element(i) + result = vec.abs(mask="PAT_ALL") + ub_array.store_element(i, result) + + # Low-level: copy back + pto.copy_ubuf_to_gm(ub_in, dst, 0, 32, 128, 0, 128, 128) +``` + +## Next Steps + +1. **Prototype Implementation**: Start with array view abstraction and simplified copy operations +2. **User Feedback**: Gather feedback from performance engineers on the proposed syntax +3. **Gradual Rollout**: Implement enhancements in phases, starting with high-priority items +4. **Documentation**: Update DSL guide with syntax sugar examples and migration guides +5. **Testing**: Ensure all syntax sugar correctly lowers to VPTO IR and maintains performance + +These enhancements will significantly improve the TileLang DSL's usability while maintaining the close correspondence with underlying VPTO IR that performance engineers require. + +1. 软件流水线(Software Pipelining)的表达成本 +在 NPU 上写 Vector 级算子,最难的往往不是数值计算,而是利用 UB (Unified Buffer) 进行 Double/Multi-Buffering(乒乓缓存),并手动排布内存搬运与计算的流水线。 + +现状挑战:如果开发者全靠手写 set_flag、wait_flag,以及手动维护 Ping-Pong 缓冲的偏移量,代码会迅速膨胀且极易死锁或读写冲突。 + +优化建议:DSL 在保留底层原语的同时,可以提供稍微高级一点的流水线抽象。例如,引入 pto.CircularBuffer(tile, num_stages=2) 的概念,让开发者可以专注于“当前 stage 的计算”,而由底层生成器自动完成不同 stage 的指针轮转和 Flag 同步。 + +2. Python 宿主变量 vs MLIR SSA 变量的心智模型边界 +因为 DSL 的本质是用 Python 元编程来生成 MLIR(静态图),开发者在写代码时很容易混淆“Python 运行期的值”和“NPU 运行期的值”。 + +现状挑战:手册中提到“变量的自动合并”(比如 if 分支产生合并),这涉及到复杂的 SSA 转换。特别是在 for 循环中,**循环携带状态(Loop-carried state)**的处理往往是个痛点。如果开发者在循环外定义了一个 Python 列表或字典,在循环内去修改它,这在生成 MLIR 的 scf.for 时是无法正确映射的。 + +优化建议:需要有极其明确的类型系统提示或语法边界,强制区分编译期求值的变量(Meta-variables)和生成的 MLIR Value。可以考虑借鉴 Triton 的方式,提供类似 tl.constexpr 的装饰或类型,让开发者清楚哪些分支在生成 MLIR 时会被静态展开,哪些会真正生成 scf.if。 + +3. 地址计算(Address Generation)的易错性 +即使是对底层开发者,手动计算字节偏移也是痛苦且容易出 Bug 的。 + +现状挑战:i * cols * 4 这种强依赖 f32 占用 4 字节的硬编码,在泛型算子开发中会带来负担(比如想写一个同时兼容 f16 和 f32 的模板算子)。 + +优化建议:提供基于语义的视图(View)操作。保留控制力不代表必须算字节。可以提供类似 tile.get_vector_slice(row_idx, vec_idx) 的接口,它在内部自动 Emit(发射)对应的 MLIR 乘法和加法指令来计算 offset。这不仅防呆,还能让生成的 MLIR 结构更规范。 + +4. Mask 的隐式推导(针对边界处理) +NPU 算子经常要处理尾部不对齐的数据(Tail processing)。 + +优化建议:虽然底层需要具体的 Mask 寄存器配置(如 PAT_ALL),但在 for 循环的最后一步边界处理时,能否提供一个类似 pto.make_mask(remaining_elements) 的宏/内联函数?让它在生成 MLIR 时,自动展开为对应的硬件 plt_b32 等指令,这样可以大幅减少手写冗长边界判断的样板代码。 \ No newline at end of file diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md new file mode 100644 index 000000000..efb383feb --- /dev/null +++ b/docs/vpto-spec.md @@ -0,0 +1,974 @@ +# PTO micro Instruction Spec — Merged Draft (A5) + +> **Status:** DRAFT for review +> **Base:** [vpto-spec.md](https://github.com/mouliangyu/PTOAS/blob/feature-vpto-backend/docs/vpto-spec.md) (2026-03-20) +> **Updated:** 2026-03-27 + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](isa/03-vector-load-store.md) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### Runtime Query Operations + +PTO micro Instruction also provides scalar runtime-query ops for inspecting the +current execution instance. These ops are pure, have no side effects, and may +be used in ordinary scalar control-flow or address computation. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current vector subblock ID. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the + current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the number of vector subblocks visible to the current + execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference +# Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is available in the linked files. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](isa/01-pipeline-sync.md) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](isa/02-dma-copy.md) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](isa/03-vector-load-store.md) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](isa/04-predicate-load-store.md) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](isa/05-materialization-predicate.md) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](isa/06-unary-vector-ops.md) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](isa/07-binary-vector-ops.md) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](isa/08-vec-scalar-ops.md) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](isa/09-conversion-ops.md) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](isa/10-reduction-ops.md) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](isa/11-compare-select.md) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](isa/12-data-rearrangement.md) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](isa/13-dsa-sfu-ops.md) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](isa/14-shared-arith.md) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](isa/15-shared-scf.md) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +*For detailed semantics, C-style pseudocode, and CCE mappings, see the individual group documentation files.* + +--- + +## Appendix: Discussion Points + +### Part I + +1. **mem_bar as pto op:** Should `pto.mem_bar` be a formal pto dialect op, or is there an existing mechanism? +2. **UB size parameterization:** Is 256KB always fixed, or should spec allow for architecture variants? +3. **MERGING predication:** Intentionally omitted (SW-emulated, perf overhead). Revisit if needed later. + +### Part II + +1. **Predication in C semantics:** Should every op's C code explicitly show the `if (mask[i])` guard, or assume all-active and note predication separately? +2. **VLane terminology:** Using "VLane" instead of "DataBlock" — confirm this naming is preferred. + +### Part 3A + +1. **pto.vdupi:** Is this distinct from `pto.vdup` with an immediate operand, or can `pto.vdup` handle both? +2. **Predicate ops (pand/por/pxor and predicate movement forms):** These need MLIR op definitions and verifier rules. Confirm priority. + +### Part 3B + +1. **Section 10 removals:** 4 interleave ops removed (not on A5). If multi-arch support is needed later, these would need conditional inclusion. + +### Part 3C + +2. **Store dist family completeness:** `vsts` currently covers `NORM`, `1PT`, `PK`, `PK4`, `MRG4CHN`, and `MRG2CHN`, while `vstsx2` covers `INTLV`. Confirm whether the surface constraints for these families are already sufficiently clear and complete. +3. **vcvt width-changing pattern:** The even/odd + `vor` pattern for forms such as `f32 -> f16` is the standard compiler lowering. Confirm this is the intended representation in the spec. +4. **Stateful store ops (Section 14):** These are complex with SSA state threading. Are they all needed for A5, or can some be simplified? diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 5c96f3af2..0fc4dbabd 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -49,7 +49,7 @@ def TileBufOrMemRef : def ScalarPtrOrMemRef : TypeConstraint< CPred<"::mlir::pto::isScalarPtrOrMemRef($_self)">, - "Ptr or MemRef in GM">; + "Ptr or GM MemRef">; def ScalarType : AnyTypeOf<[AnySignlessInteger, AnyFloat], "numeric (integer/float)">; @@ -72,6 +72,8 @@ class PTO_DpsOp traits = []> class PTO_Op traits = []> : Op; +include "PTO/IR/VPTOOps.td" + //===----------------------------------------------------------------------===// // Pointer/View Ops (for your front-end IR) //===----------------------------------------------------------------------===// @@ -101,6 +103,31 @@ def AddPtrOp : PTO_Op<"addptr", [ }]; } +def CastPtrOp : PTO_Op<"castptr", [Pure]> { + let summary = "Cast between integer and !pto.ptr, or between !pto.ptr types"; + let description = [{ + Performs an explicit pointer-domain cast. + + Supported cases: + - integer -> !pto.ptr + - !pto.ptr -> integer + - !pto.ptr -> !pto.ptr + - memref<..., space> -> !pto.ptr (extract the aligned base ptr) + + Pointer-to-pointer casts must stay within the same PTO memory space. Cross + space casts such as gm <-> ub are rejected by the verifier. + }]; + + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + //===----------------------------------------------------------------------===// // Scalar pointer load/store //===----------------------------------------------------------------------===// @@ -1967,8 +1994,10 @@ def SetFlagOp : PTO_Op<"set_flag"> { PTO_EventAttr:$event_id ); let results = (outs); - let assemblyFormat = [{ - `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -1980,8 +2009,10 @@ def WaitFlagOp : PTO_Op<"wait_flag"> { PTO_EventAttr:$event_id ); let results = (outs); - let assemblyFormat = [{ - `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -2015,6 +2046,9 @@ def WaitFlagDynOp : PTO_Op<"wait_flag_dyn"> { // Buffer-ID Synchronization (A5) //===----------------------------------------------------------------------===// +def PTO_PipeLikeAttr + : AnyAttrOf<[PTO_PipeEventTypeAttr, PTO_SyncOpTypeAttr, PTO_PipeAttr]>; + def GetBufOp : PTO_Op<"get_buf"> { let summary = "Acquire a buffer-id token for a sync op type (A5)"; let description = [{ @@ -2031,7 +2065,7 @@ def GetBufOp : PTO_Op<"get_buf"> { }]; let arguments = (ins - PTO_PipeEventTypeLikeAttr:$op_type, + PTO_PipeLikeAttr:$op_type, I32Attr:$buf_id, DefaultValuedAttr:$mode ); @@ -2040,8 +2074,10 @@ def GetBufOp : PTO_Op<"get_buf"> { let hasVerifier = 1; - let assemblyFormat = [{ - `[` $op_type `,` $buf_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -2055,7 +2091,7 @@ def RlsBufOp : PTO_Op<"rls_buf"> { }]; let arguments = (ins - PTO_PipeEventTypeLikeAttr:$op_type, + PTO_PipeLikeAttr:$op_type, I32Attr:$buf_id, DefaultValuedAttr:$mode ); @@ -2064,8 +2100,10 @@ def RlsBufOp : PTO_Op<"rls_buf"> { let hasVerifier = 1; - let assemblyFormat = [{ - `[` $op_type `,` $buf_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 9947acbeb..23e053081 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -18,18 +18,25 @@ include "mlir/Interfaces/DataLayoutInterfaces.td" include "PTO/IR/PTODialect.td" include "PTO/IR/PTOAttrs.td" -// ---- !pto.ptr ---- +// ---- !pto.ptr ---- def PtrType : TypeDef { let mnemonic = "ptr"; let parameters = (ins - "mlir::Type":$elementType + "mlir::Type":$elementType, + "mlir::pto::AddressSpaceAttr":$memorySpace ); - let assemblyFormat = "`<` $elementType `>`"; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; let builders = [ TypeBuilder<(ins "Type":$elementType), [{ - return Base::get($_ctxt, elementType); + return Base::get($_ctxt, elementType, + mlir::pto::AddressSpaceAttr::get($_ctxt, + mlir::pto::AddressSpace::GM)); + }]>, + TypeBuilder<(ins "Type":$elementType, + "mlir::pto::AddressSpaceAttr":$memorySpace), [{ + return Base::get($_ctxt, elementType, memorySpace); }]> ]; } @@ -275,3 +282,5 @@ def F4E2M1x2Type : TypeDef($_self)">, + "PTO low-level vector type">; +def PTO_MaskTypeConstraint : Type($_self)">, + "PTO low-level mask type">; +def PTO_B8MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB8()">, + "PTO low-level b8 mask type">; +def PTO_B16MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB16()">, + "PTO low-level b16 mask type">; +def PTO_B32MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB32()">, + "PTO low-level b32 mask type">; +def PTO_AlignTypeConstraint : Type($_self)">, + "PTO low-level align type">; + +def PTO_BufferType : Type< + CPred<"::llvm::isa<::mlir::pto::PtrType>($_self)">, + "pointer-like buffer type">; +def PTO_BufferLikeType : AnyTypeOf<[AnyMemRef, PTO_BufferType], + "memref or pointer-like buffer type">; + +def VecScopeOp : PTO_Op<"vecscope", [SingleBlock, NoTerminator]> { + let summary = "Structured region container for one VPTO vector scope"; + let description = [{ + `pto.vecscope` marks a structured vector-scope interval without overloading + a dummy carrier loop with scope metadata. Lowering and emission passes may + use the region boundary to preserve loop shape while treating the enclosed + body as one VPTO vector interval. + }]; + + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = "$body attr-dict"; +} + +def StrictVecScopeOp : PTO_Op<"strict_vecscope", [SingleBlock, NoTerminator, + IsolatedFromAbove]> { + let summary = "Structured VPTO vector scope with explicit captures only"; + let description = [{ + `pto.strict_vecscope` is the strict form of `pto.vecscope`. Values used by + the body must be passed explicitly through op operands and corresponding + block arguments; implicit SSA capture from the surrounding scope is + rejected. + }]; + + let arguments = (ins Variadic:$captures); + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = [{ + `(` $captures `)` $body attr-dict `:` functional-type($captures, results) + }]; +} + +class PTO_BinaryI64ConfigOp : PTO_Op { + let arguments = (ins + I64:$first, + I64:$second + ); + + let results = (outs); + + let assemblyFormat = [{ + $first `,` $second attr-dict `:` type($first) `,` type($second) + }]; +} + +def PTO_SetLoop2StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_outtoub">; +def PTO_SetLoop1StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_outtoub">; +def PTO_SetLoopSizeOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop_size_outtoub">; +def PTO_SetLoop2StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_ubtoout">; +def PTO_SetLoop1StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_ubtoout">; +def PTO_SetLoopSizeUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop_size_ubtoout">; + +def PTO_CopyGmToUbufOp : PTO_Op<"copy_gm_to_ubuf", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$left_padding_count, + I64:$right_padding_count, + I1:$data_select_bit, + I64:$l2_cache_ctl, + I64:$gm_stride, + I64:$ub_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` + $left_padding_count `,` $right_padding_count `,` $data_select_bit `,` $l2_cache_ctl `,` $gm_stride `,` $ub_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($left_padding_count) `,` + type($right_padding_count) `,` type($data_select_bit) `,` type($l2_cache_ctl) `,` type($gm_stride) `,` type($ub_stride) + }]; +} + +def PTO_CopyUbufToUbufOp : PTO_Op<"copy_ubuf_to_ubuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` $src_stride `,` $dst_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($src_stride) `,` type($dst_stride) + }]; +} + +def PTO_VldsOp : PTO_Op<"vlds", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + OptionalAttr:$dist + ); + + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_VldsPostOp : PTO_Op<"vlds_post", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + OptionalAttr:$dist + ); + + let results = (outs PTO_VectorType:$result, + PTO_BufferLikeType:$updated_source); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) `,` type($updated_source) + }]; +} + +def PTO_Vldsx2Op : PTO_Op<"vldsx2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($low) `,` type($high) + }]; +} + +def PTO_VldasOp : PTO_Op<"vldas", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source + ); + + let results = (outs PTO_AlignTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_InitAlignOp : PTO_Op<"init_align", []> { + let arguments = (ins); + + let results = (outs PTO_AlignTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + attr-dict `:` type($result) + }]; +} + +def PTO_SprclrOp : PTO_Op<"sprclr", []> { + let arguments = (ins StrAttr:$spr); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $spr attr-dict + }]; +} + +def PTO_VldusOp : PTO_Op<"vldus", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_AlignTypeConstraint:$align + ); + + let results = (outs + PTO_VectorType:$result, + PTO_AlignTypeConstraint:$updated_align + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $align attr-dict `:` type($source) `,` type($align) `->` type($result) `,` type($updated_align) + }]; +} + +def PTO_UvldOp : PTO_Op<"uvld", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset + ); + + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_VbrOp : PTO_Op<"vbr", [Pure]> { + let arguments = (ins AnyType:$value); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value attr-dict `:` type($value) `->` type($result) + }]; +} + +def PTO_VdupOp : PTO_Op<"vdup", [Pure]> { + let arguments = (ins + AnyType:$input, + PTO_MaskTypeConstraint:$mask, + OptionalAttr:$position + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PsetB8Op : PTO_Op<"pset_b8", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B8MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PsetB16Op : PTO_Op<"pset_b16", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B16MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +// NOTE: The op families introduced below are intentionally marked as +// unvalidated scaffolding. They are added to preserve missing CCE builtin +// semantics at the dialect layer, but they have not yet been validated through +// PTO lowering or end-to-end sample execution. +def PTO_PsetB32Op : PTO_Op<"pset_b32", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B32MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB8Op : PTO_Op<"pge_b8", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B8MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB16Op : PTO_Op<"pge_b16", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B16MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB32Op : PTO_Op<"pge_b32", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B32MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PltB8Op : PTO_Op<"plt_b8", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B8MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +def PTO_PltB16Op : PTO_Op<"plt_b16", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B16MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +def PTO_PltB32Op : PTO_Op<"plt_b32", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B32MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +class PTO_MaskUnaryOp : PTO_Op { + let arguments = (ins PTO_MaskTypeConstraint:$input); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PpackOp : PTO_MaskUnaryOp<"ppack"> { + let arguments = (ins PTO_MaskTypeConstraint:$input, StrAttr:$part); + let assemblyFormat = [{ + $input `,` $part attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PunpackOp : PTO_MaskUnaryOp<"punpack"> { + let arguments = (ins PTO_MaskTypeConstraint:$input, StrAttr:$part); + let assemblyFormat = [{ + $input `,` $part attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PnotOp : PTO_Op<"pnot", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PselOp : PTO_Op<"psel", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PandOp : PTO_Op<"pand", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PorOp : PTO_Op<"por", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PxorOp : PTO_Op<"pxor", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PldsOp : PTO_Op<"plds", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($result) + }]; +} + +def PTO_PldiOp : PTO_Op<"pldi", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($result) + }]; +} + +def PTO_PstiOp : PTO_Op<"psti", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_MaskTypeConstraint:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $dist attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_VabsOp : PTO_Op<"vabs", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +class PTO_UnaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VexpOp : PTO_UnaryVecOp<"vexp">; +def PTO_VlnOp : PTO_UnaryVecOp<"vln">; +def PTO_VsqrtOp : PTO_UnaryVecOp<"vsqrt">; +def PTO_VnegOp : PTO_UnaryVecOp<"vneg">; +def PTO_VrsqrtOp : PTO_UnaryVecOp<"vrsqrt">; +def PTO_VrecOp : PTO_UnaryVecOp<"vrec">; +def PTO_VreluOp : PTO_UnaryVecOp<"vrelu">; +def PTO_VnotOp : PTO_UnaryVecOp<"vnot">; +def PTO_VcaddOp : PTO_UnaryVecOp<"vcadd">; +def PTO_VcmaxOp : PTO_UnaryVecOp<"vcmax">; +def PTO_VcminOp : PTO_UnaryVecOp<"vcmin">; +def PTO_VcgaddOp : PTO_UnaryVecOp<"vcgadd">; +def PTO_VcgmaxOp : PTO_UnaryVecOp<"vcgmax">; +def PTO_VcgminOp : PTO_UnaryVecOp<"vcgmin">; +def PTO_VcpaddOp : PTO_UnaryVecOp<"vcpadd">; + +class PTO_BinaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VaddOp : PTO_BinaryVecOp<"vadd">; +def PTO_VsubOp : PTO_BinaryVecOp<"vsub">; +def PTO_VsaddOp : PTO_BinaryVecOp<"vsadd">; +def PTO_VssubOp : PTO_BinaryVecOp<"vssub">; +def PTO_VmulOp : PTO_BinaryVecOp<"vmul">; +def PTO_VdivOp : PTO_BinaryVecOp<"vdiv">; +def PTO_VmaxOp : PTO_BinaryVecOp<"vmax">; +def PTO_VminOp : PTO_BinaryVecOp<"vmin">; +def PTO_VandOp : PTO_BinaryVecOp<"vand">; +def PTO_VorOp : PTO_BinaryVecOp<"vor">; +def PTO_VxorOp : PTO_BinaryVecOp<"vxor">; + +def PTO_VaddcOp : PTO_Op<"vaddc", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VsubcOp : PTO_Op<"vsubc", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VaddcsOp : PTO_Op<"vaddcs", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$carry_in, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $carry_in `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($carry_in) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VsubcsOp : PTO_Op<"vsubcs", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$carry_in, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $carry_in `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($carry_in) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VbcntOp : PTO_UnaryVecOp<"vbcnt">; +def PTO_VclsOp : PTO_UnaryVecOp<"vcls">; + +def PTO_VshlOp : PTO_BinaryVecOp<"vshl">; +def PTO_VshrOp : PTO_BinaryVecOp<"vshr">; + +def PTO_VselOp : PTO_Op<"vsel", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VcmpOp : PTO_Op<"vcmp", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + PTO_MaskTypeConstraint:$mask, + StrAttr:$cmp_mode + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask `,` $cmp_mode attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VcmpsOp : PTO_Op<"vcmps", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + AnyType:$scalar, + PTO_MaskTypeConstraint:$mask, + StrAttr:$cmp_mode + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $scalar `,` $mask `,` $cmp_mode attr-dict `:` type($src) `,` type($scalar) `,` type($mask) `->` type($result) + }]; +} + +class PTO_PredicatePairReorderOp + : PTO_Op { + let arguments = (ins + operandTy:$lhs, + operandTy:$rhs + ); + let results = (outs + operandTy:$low, + operandTy:$high + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_PdintlvB8Op : PTO_PredicatePairReorderOp<"pdintlv_b8", + PTO_B8MaskTypeConstraint>; +def PTO_PdintlvB16Op : PTO_PredicatePairReorderOp<"pdintlv_b16", + PTO_B16MaskTypeConstraint>; +def PTO_PdintlvB32Op : PTO_PredicatePairReorderOp<"pdintlv_b32", + PTO_B32MaskTypeConstraint>; + +def PTO_PintlvB8Op : PTO_PredicatePairReorderOp<"pintlv_b8", + PTO_B8MaskTypeConstraint>; +def PTO_PintlvB16Op : PTO_PredicatePairReorderOp<"pintlv_b16", + PTO_B16MaskTypeConstraint>; +def PTO_PintlvB32Op : PTO_PredicatePairReorderOp<"pintlv_b32", + PTO_B32MaskTypeConstraint>; + +class PTO_VecScalarOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$input, + AnyType:$scalar + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $scalar attr-dict `:` type($input) `,` type($scalar) `->` type($result) + }]; +} + +class PTO_VecScalarMaskedOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$input, + AnyType:$scalar, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $scalar `,` $mask attr-dict `:` type($input) `,` type($scalar) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VtrcOp : PTO_Op<"vtrc", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + StrAttr:$round_mode + ); + let results = (outs PTO_VectorType:$result); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def PTO_VcvtOp : PTO_Op<"vcvt", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + OptionalAttr:$rnd, + OptionalAttr:$sat, + OptionalAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_VciOp : PTO_Op<"vci", [Pure]> { + let arguments = (ins + AnyInteger:$index, + OptionalAttr:$order + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $index attr-dict `:` type($index) `->` type($result) + }]; +} + +def PTO_VbitsortOp : PTO_Op<"vbitsort", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$destination, + PTO_BufferType:$source, + PTO_BufferType:$indices, + Index:$repeat_times + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $destination `,` $source `,` $indices `,` $repeat_times attr-dict `:` type($destination) `,` + type($source) `,` type($indices) `,` type($repeat_times) + }]; +} + +def PTO_Vmrgsort4Op : PTO_Op<"vmrgsort4"> { + let arguments = (ins + PTO_BufferType:$destination, + PTO_BufferType:$source0, + PTO_BufferType:$source1, + PTO_BufferType:$source2, + PTO_BufferType:$source3, + I64:$count, + I64:$config + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $destination `,` $source0 `,` $source1 `,` $source2 `,` $source3 `,` $count `,` $config + attr-dict `:` type($destination) `,` type($source0) `,` type($source1) `,` type($source2) `,` + type($source3) `,` type($count) `,` type($config) + }]; +} + +def PTO_Vgather2Op : PTO_Op<"vgather2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + Index:$active_lanes + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $active_lanes attr-dict `:` type($source) `,` type($offsets) `,` type($active_lanes) `->` type($result) + }]; +} + +def PTO_VgatherbOp : PTO_Op<"vgatherb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $mask attr-dict `:` type($source) `,` type($offsets) `,` type($mask) `->` type($result) + }]; +} + +// NOTE: Unvalidated new gather/select/interleave-family abstractions. Added to +// cover CCE builtin families not yet exercised through end-to-end PTO seams. +def PTO_Vgather2BcOp : PTO_Op<"vgather2_bc", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $mask attr-dict `:` type($source) `,` type($offsets) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VmulsOp : PTO_VecScalarMaskedOp<"vmuls">; +def PTO_VaddsOp : PTO_VecScalarMaskedOp<"vadds">; +def PTO_VsaddsOp : PTO_VecScalarMaskedOp<"vsadds">; +def PTO_VmaxsOp : PTO_VecScalarMaskedOp<"vmaxs">; +def PTO_VminsOp : PTO_VecScalarMaskedOp<"vmins">; +def PTO_VlreluOp : PTO_VecScalarMaskedOp<"vlrelu">; +def PTO_VshlsOp : PTO_VecScalarMaskedOp<"vshls">; +def PTO_VshrsOp : PTO_VecScalarMaskedOp<"vshrs">; + +def PTO_VstsOp : PTO_Op<"vsts", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + OptionalAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask) + }]; +} + +def PTO_VstsPostOp : PTO_Op<"vsts_post", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + OptionalAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + + let results = (outs PTO_BufferLikeType:$updated_destination); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask) `->` type($updated_destination) + }]; +} + +def PTO_VscatterOp : PTO_Op<"vscatter", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferType:$destination, + PTO_VectorType:$offsets, + Index:$active_lanes + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `,` $offsets `,` $active_lanes attr-dict `:` type($value) `,` type($destination) `,` type($offsets) `,` type($active_lanes) + }]; +} + +def PTO_PstsOp : PTO_Op<"psts", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_MaskTypeConstraint:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $dist attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_CopyUbufToGmOp : PTO_Op<"copy_ubuf_to_gm", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$reserved, + I64:$burst_dst_stride, + I64:$burst_src_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` + $reserved `,` $burst_dst_stride `,` $burst_src_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($reserved) `,` + type($burst_dst_stride) `,` type($burst_src_stride) + }]; +} + +// NOTE: Unvalidated new x2 / pair / align-store-family abstractions. Added to +// reflect CCE builtin families but not yet end-to-end validated. +def PTO_VselrOp : PTO_Op<"vselr", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1 + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result) + }]; +} + +def PTO_VslideOp : PTO_Op<"vslide", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + I16:$amt + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $amt attr-dict `:` type($src0) `,` type($src1) `,` type($amt) `->` type($result) + }]; +} + +def PTO_VsqzOp : PTO_UnaryVecOp<"vsqz">; + +def PTO_VusqzOp : PTO_Op<"vusqz", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $mask attr-dict `:` type($src) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VpackOp : PTO_Op<"vpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_VsunpackOp : PTO_Op<"vsunpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + Index:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_VzunpackOp : PTO_Op<"vzunpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + Index:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_Vselrv2Op : PTO_Op<"vselrv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1 + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result) + }]; +} + +def PTO_VintlvOp : PTO_Op<"vintlv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_VdintlvOp : PTO_Op<"vdintlv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_Vintlvv2Op : PTO_Op<"vintlvv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $part attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_Vdintlvv2Op : PTO_Op<"vdintlvv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $part attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_VmullOp : PTO_Op<"vmull", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($low) `,` type($high) + }]; +} + +def PTO_VmulaOp : PTO_Op<"vmula", [Pure]> { + let arguments = (ins + PTO_VectorType:$acc, + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $acc `,` $lhs `,` $rhs `,` $mask attr-dict `:` type($acc) `,` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) + }]; +} + +class PTO_UnmaskedBinaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_VpreluOp : PTO_UnmaskedBinaryVecOp<"vprelu">; +def PTO_VexpdiffOp : PTO_Op<"vexpdiff", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_VectorType:$max, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $max `,` $part attr-dict `:` type($input) `,` type($max) `->` type($result) + }]; +} + +def PTO_VaxpyOp : PTO_Op<"vaxpy", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + AnyType:$alpha + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $alpha attr-dict `:` type($src0) `,` type($src1) `,` type($alpha) `->` type($result) + }]; +} + +def PTO_VaddreluconvOp : PTO_Op<"vaddreluconv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_VmulconvOp : PTO_Op<"vmulconv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_Vstsx2Op : PTO_Op<"vstsx2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$low, + PTO_VectorType:$high, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $low `,` $high `,` $destination `[` $offset `]` `,` $dist `,` $mask attr-dict `:` type($low) `,` type($high) `,` type($destination) `,` type($offset) `,` type($mask) + }]; +} + +def PTO_VsldbOp : PTO_Op<"vsldb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + I16:$block_stride, + I16:$repeat_stride, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($source) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VsstbOp : PTO_Op<"vsstb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + I16:$block_stride, + I16:$repeat_stride, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask) + }]; +} + +def PTO_VstasOp : PTO_Op<"vstas", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$value, + PTO_BufferLikeType:$destination, + I32:$offset + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `,` $offset attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_VstarOp : PTO_Op<"vstar", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$value, + PTO_BufferLikeType:$destination + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination attr-dict `:` type($value) `,` type($destination) + }]; +} + +// NOTE: Unvalidated stateful store-family abstractions. These preserve +// align/base/offset update results explicitly in SSA form instead of relying on +// implicit CCE reference updates. +// Keep `base/base_out` pointer-only (`PTO_BufferType`): memref semantics for +// stateful post-update addresses are intentionally out of scope in this change. +def PTO_PstuOp : PTO_Op<"pstu", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + PTO_MaskTypeConstraint:$value, + PTO_BufferType:$base + ); + let results = (outs PTO_AlignTypeConstraint:$align_out, PTO_BufferType:$base_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $value `,` $base attr-dict `:` type($align_in) `,` type($value) `,` type($base) `->` type($align_out) `,` type($base_out) + }]; +} + +def PTO_VstusOp : PTO_Op<"vstus", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + I32:$offset, + PTO_VectorType:$value, + PTO_BufferType:$base + ); + let results = (outs PTO_AlignTypeConstraint:$align_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $offset `,` $value `,` $base attr-dict `:` type($align_in) `,` type($offset) `,` type($value) `,` type($base) `->` type($align_out) + }]; +} + +def PTO_VsturOp : PTO_Op<"vstur", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + PTO_VectorType:$value, + PTO_BufferType:$base, + StrAttr:$mode + ); + let results = (outs PTO_AlignTypeConstraint:$align_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $value `,` $base `,` $mode attr-dict `:` type($align_in) `,` type($value) `,` type($base) `->` type($align_out) + }]; +} + +#endif // MLIR_DIALECT_PTO_IR_VPTOOPS diff --git a/include/PTO/IR/VPTOTypeDefs.td b/include/PTO/IR/VPTOTypeDefs.td new file mode 100644 index 000000000..04e8ac583 --- /dev/null +++ b/include/PTO/IR/VPTOTypeDefs.td @@ -0,0 +1,53 @@ +//===- VPTOTypeDefs.td ---------------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS +#define MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" + +def VRegType : TypeDef { + let mnemonic = "vreg"; + let summary = "A 256-byte PTO low-level vector"; + + let parameters = (ins + "int64_t":$elementCount, + "Type":$elementType + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def MaskType : TypeDef { + let mnemonic = "mask"; + let summary = "A PTO low-level predicate/mask register"; + + let parameters = (ins + StringRefParameter<"mask granularity view">:$granularity + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static bool isSupportedGranularity(::llvm::StringRef granularity); + + bool isB8() const { return getGranularity() == "b8"; } + bool isB16() const { return getGranularity() == "b16"; } + bool isB32() const { return getGranularity() == "b32"; } + }]; +} + +def AlignType : TypeDef { + let mnemonic = "align"; + let summary = "A PTO low-level vector_align carrier"; +} + +#endif // MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS diff --git a/include/PTO/Transforms/HIVMIntrinsicNaming.h b/include/PTO/Transforms/HIVMIntrinsicNaming.h new file mode 100644 index 000000000..7ba956168 --- /dev/null +++ b/include/PTO/Transforms/HIVMIntrinsicNaming.h @@ -0,0 +1,60 @@ +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H +#define MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H + +#include +#include + +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::pto { + +struct NamingInputs { + std::string sourceOpName; + std::string family; + std::string vectorShape; + std::string elementType; + std::vector usedFields; + std::vector missingFields; +}; + +struct UnresolvedEmissionRecord { + std::string sourceOpName; + std::string placeholderName; + std::string candidateName; + std::vector usedFields; + std::vector missingFields; + std::string resultTypeFragment; + std::string location; +}; + +struct IntrinsicSelection { + bool resolved = false; + std::string sourceOpName; + std::string calleeName; + std::string placeholderName; + std::string candidateName; + std::vector usedFields; + std::vector missingFields; + std::string resultTypeFragment; + std::string location; + + std::string getEmittedCallee() const { + return resolved ? calleeName : placeholderName; + } + + UnresolvedEmissionRecord asUnresolvedRecord() const { + return UnresolvedEmissionRecord{sourceOpName, placeholderName, candidateName, + usedFields, missingFields, resultTypeFragment, + location}; + } +}; + +FailureOr selectIntrinsic(Operation *op); +FailureOr selectLoadIntrinsic(Operation *op); +FailureOr selectUnaryIntrinsic(Operation *op); +FailureOr selectStoreIntrinsic(Operation *op); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 8b50dee9e..4193bc7d5 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -64,7 +64,12 @@ std::unique_ptr createPTORemoveRedundantBarrierPass(); std::unique_ptr createPTOViewToMemrefPass(); std::unique_ptr createInferPTOLayoutPass(); std::unique_ptr createPTOA5NormalizeTMovPass(); - +std::unique_ptr createPTOVPTOExpandBridgeOpsPass(); +std::unique_ptr createPTOVPTOPtrBoundaryPass(); +std::unique_ptr createPTOValidateVPTOIRPass(); +std::unique_ptr createPTOValidateVPTOEmissionIRPass(); +std::unique_ptr createLowerPTOToVPTOPass(); +std::unique_ptr createLowerPTOToVPTOPass(StringRef loweringStrategy); //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 3dfd20435..5094c356a 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -15,6 +15,9 @@ // //===----------------------------------------------------------------------===// +// The VPTO backend is emitted from tools/ptoas rather than a TableGen pass; +// these registrations continue to describe the shared pre-backend pipeline. + #ifndef MLIR_DIALECT_PTO_PASSES #define MLIR_DIALECT_PTO_PASSES @@ -248,4 +251,99 @@ def PTOViewToMemref : Pass<"pto-view-to-memref", "ModuleOp"> { ]; } +def PTOValidateVPTOIR : Pass<"pto-validate-vpto-ir", "ModuleOp"> { + let summary = + "Validate authoring-stage VPTO legality before ptr-boundary canonicalization"; + let description = [{ + Runs the authoring-stage VPTO legality verifier on post-mainline VPTO IR. + This stage keeps the memref-first authoring surface legal, while checking + the shared structural contracts that must hold before emission-boundary + canonicalization. + }]; + let constructor = "mlir::pto::createPTOValidateVPTOIRPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def PTOValidateVPTOEmissionIR + : Pass<"pto-validate-vpto-emission-ir", "ModuleOp"> { + let summary = + "Validate emission-stage VPTO legality after ptr-boundary canonicalization"; + let description = [{ + Runs the emission-stage VPTO legality verifier on ptr-form VPTO IR after + `PTOVPTOPtrBoundary`. This stage re-checks the shared authoring contracts + and confirms the final emission surface no longer carries memref boundary + state or residual bridge scaffold. + }]; + let constructor = "mlir::pto::createPTOValidateVPTOEmissionIRPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def PTOVPTOExpandBridgeOps + : Pass<"pto-vpto-expand-bridge-ops", "func::FuncOp"> { + let summary = + "Expand temporary VPTO bridge ops back to emission-ready VPTO IR"; + let description = [{ + Low-level fusion may keep temporary bridge ops in VPTO IR so legality and + alias analysis can still see memref-form operands. This pass expands those + bridge ops back to the existing emission-ready pointer-level VPTO forms + before backend emission. + }]; + let constructor = "mlir::pto::createPTOVPTOExpandBridgeOpsPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::LLVM::LLVMDialect", + "mlir::pto::PTODialect"]; +} + +def PTOVPTOPtrBoundary + : Pass<"pto-vpto-ptr-boundary", "ModuleOp"> { + let summary = + "Canonicalize the final VPTO emission boundary from memref-first IR to ptr ABI"; + let description = [{ + Runs the final emission-boundary ptr canonicalization after the backend + mainline has finished its memref-first optimization pipeline. This pass + rewrites eligible memref function arguments to same-space `!pto.ptr`, + rejects memref function results, canonicalizes supported body-level VPTO + buffer-like ops to ptr-form, and drops dead boundary/view scaffold such as + trivial `pto.castptr`, `pto.bind_tile`, `memref.subview`, + `memref.reinterpret_cast`, and `memref.memory_space_cast` once they become + unused. + }]; + let constructor = "mlir::pto::createPTOVPTOPtrBoundaryPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect"]; +} + +def PTOToVPTO : Pass<"pto-to-vpto", "ModuleOp"> { + let summary = "Lower PTO tile ops to VPTO backend ops"; + let description = [{ + Lowers PTO tile ops to VPTO backend ops. For already-planned fusion groups, + the pass rewrites the `pto.fusion_region` body in place and preserves the + wrapper until explicit flatten. Residual non-fused PTO ops may continue to + be lowered directly in their parent block and are not wrapped into + synthetic `pto.fusion_region` containers solely for backend lowering. + }]; + let constructor = "mlir::pto::createLowerPTOToVPTOPass()"; + let options = [ + Option<"loweringStrategy", "pto-lowering-strategy", "std::string", + "\"post-update\"", + "vector lowering strategy: post-update or no-post-update"> + ]; + let dependentDialects = [ + "pto::PTODialect", + "func::FuncDialect", + "arith::ArithDialect", + "memref::MemRefDialect", + "scf::SCFDialect" + ]; +} + #endif // MLIR_DIALECT_PTO_PASSES diff --git a/include/PTO/Transforms/VPTOLLVMEmitter.h b/include/PTO/Transforms/VPTOLLVMEmitter.h new file mode 100644 index 000000000..dc56f64b2 --- /dev/null +++ b/include/PTO/Transforms/VPTOLLVMEmitter.h @@ -0,0 +1,43 @@ +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H + +#include + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class ModuleOp; +} + +namespace llvm { +class raw_ostream; +} + +namespace mlir::pto { + +struct VPTOEmissionOptions { + bool dumpVPTOIR = false; + bool printIntrinsicSelections = false; + bool allowUnresolved = true; + std::string unresolvedReportPath; + std::string targetTriple; + std::string march; + std::string aicoreArch; + std::string defaultTargetCPU; + std::string defaultTargetFeatures; +}; + +LogicalResult +translateVPTOModuleToLLVMText(ModuleOp module, llvm::raw_ostream &os, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS); + +LogicalResult +translateVPTOModuleToLLVMBitcode(ModuleOp module, llvm::raw_ostream &os, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H diff --git a/include/PTO/Transforms/VPTOLowering.h b/include/PTO/Transforms/VPTOLowering.h new file mode 100644 index 000000000..17730ab4e --- /dev/null +++ b/include/PTO/Transforms/VPTOLowering.h @@ -0,0 +1,241 @@ +//===- VPTOLowering.h - PTO to VPTO lowering contracts ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ + +#include "PTO/IR/PTO.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { + +enum class VPTOTileDomain { + Vec, + Acc, + Mat, +}; + +enum class VPTOLoweringStrategy { + PostUpdate, + NoPostUpdate, +}; + +struct VPTOPartitionTrace { + SmallVector offsets; + SmallVector sizes; + bool hasDynamicOffsets = false; + bool hasDynamicSizes = false; +}; + +struct VPTOLoopProgramming { + int64_t loop2 = 1; + int64_t loop1 = 1; + int64_t srcLoop2Stride = 1; + int64_t srcLoop1Stride = 1; + int64_t dstLoop2Stride = 1; + int64_t dstLoop1Stride = 1; +}; + +enum class VPTOLoopScopeKind { + None, + AIVVectorScope, +}; + +struct VPTOLoopScopeContract { + VPTOLoopScopeKind kind = VPTOLoopScopeKind::None; + StringRef loweredAttr = "llvm.loop.aivector_scope"; + int64_t loopDepth = 0; +}; + +struct VPTOLoadContract { + StringRef sourceLayout; + SmallVector sourceShape; + SmallVector sourceStrides; + StringRef tileLayout; + VPTOTileDomain tileDomain = VPTOTileDomain::Vec; + Type elementType; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + StringRef padMode; + Value padValue; + Value leftPaddingNum; + Value rightPaddingNum; + bool initOutBuffer = false; + Value initCondition; + VPTOPartitionTrace trace; +}; + +struct VPTOUnaryContract { + StringRef family; + VPTOTileDomain tileDomain = VPTOTileDomain::Vec; + StringRef tileLayout; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + Type elementType; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOBinaryContract { + StringRef family; + VPTOTileDomain tileDomain = VPTOTileDomain::Vec; + StringRef tileLayout; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + Type elementType; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOStoreContract { + VPTOTileDomain srcDomain = VPTOTileDomain::Vec; + StringRef destinationLayout; + SmallVector destinationShape; + SmallVector destinationStrides; + Type elementType; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + VPTOPartitionTrace trace; +}; + +void set_loop2_stride_outtoub(Operation *copyOp, int64_t dstStride, + int64_t srcStride, Builder &builder); +void set_loop1_stride_outtoub(Operation *copyOp, int64_t dstStride, + int64_t srcStride, Builder &builder); +void set_loop_size_outtoub(Operation *copyOp, int64_t loop2, int64_t loop1, + Builder &builder); +void set_loop2_stride_ubtoout(Operation *copyOp, int64_t srcStride, + int64_t dstStride, Builder &builder); +void set_loop1_stride_ubtoout(Operation *copyOp, int64_t srcStride, + int64_t dstStride, Builder &builder); +void set_loop_size_ubtoout(Operation *copyOp, int64_t loop2, int64_t loop1, + Builder &builder); +FailureOr +createLoopScopeRegion(Location loc, const VPTOLoopScopeContract &contract, + PatternRewriter &rewriter); +Value materializeBufferPointer(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc); + +LogicalResult lowerTLOAD(TLoadOp op, PatternRewriter &rewriter); +LogicalResult lowerTABS(TAbsOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTADD(TAddOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSUB(TSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTMUL(TMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTDIV(TDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTMAX(TMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTMIN(TMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTAND(TAndOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTANDS(TAndSOp op, PatternRewriter &rewriter); +LogicalResult lowerTOR(TOrOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTORS(TOrSOp op, PatternRewriter &rewriter); +LogicalResult lowerTXOR(TXorOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTXORS(TXorSOp op, PatternRewriter &rewriter); +LogicalResult lowerTEXP(TExpOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTLOG(TLogOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSQRT(TSqrtOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRSQRT(TRsqrtOp op, PatternRewriter &rewriter); +LogicalResult lowerTRECIP(TRecipOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTNEG(TNegOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTLRELU(TLReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTCI(TCIOp op, PatternRewriter &rewriter); +LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter); +LogicalResult lowerTCmp(TCmpOp op, PatternRewriter &rewriter); +LogicalResult lowerTCmpS(TCmpSOp op, PatternRewriter &rewriter); +LogicalResult lowerTSel(TSelOp op, PatternRewriter &rewriter); +LogicalResult lowerTAddC(TAddCOp op, PatternRewriter &rewriter); +LogicalResult lowerTAddS(TAddSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTAddSC(TAddSCOp op, PatternRewriter &rewriter); +LogicalResult lowerTMinS(TMinSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTDivS(TDivSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTMulS(TMulSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSubC(TSubCOp op, PatternRewriter &rewriter); +LogicalResult lowerTSubS(TSubSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSubSC(TSubSCOp op, PatternRewriter &rewriter); +LogicalResult lowerTMaxS(TMaxSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSelS(TSelSOp op, PatternRewriter &rewriter); +LogicalResult lowerTRELU(TReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTNOT(TNotOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTTRANS(TTransOp op, PatternRewriter &rewriter); +LogicalResult lowerTFILLPAD(TFillPadOp op, PatternRewriter &rewriter); +LogicalResult lowerTFILLPADExpand(TFillPadExpandOp op, PatternRewriter &rewriter); +LogicalResult lowerTRowMax(TRowMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRowMin(TRowMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRowSum(TRowSumOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTColMax(TColMaxOp op, PatternRewriter &rewriter); +LogicalResult lowerTColMin(TColMinOp op, PatternRewriter &rewriter); +LogicalResult lowerTColSum(TColSumOp op, PatternRewriter &rewriter); +LogicalResult lowerTRowExpand(TRowExpandOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTColExpand(TColExpandOp op, PatternRewriter &rewriter); +LogicalResult lowerTRowExpandMul(TRowExpandMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRowExpandDiv(TRowExpandDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRowExpandSub(TRowExpandSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTPartAdd(TPartAddOp op, PatternRewriter &rewriter); +LogicalResult lowerTPartMax(TPartMaxOp op, PatternRewriter &rewriter); +LogicalResult lowerTPartMin(TPartMinOp op, PatternRewriter &rewriter); +LogicalResult lowerTExpandS(TExpandsOp op, PatternRewriter &rewriter); +LogicalResult lowerTGather(TGatherOp op, PatternRewriter &rewriter); +LogicalResult lowerTGatherB(TGatherBOp op, PatternRewriter &rewriter); +LogicalResult lowerTScatter(TScatterOp op, PatternRewriter &rewriter); +LogicalResult lowerTMrgSort(TMrgSortOp op, PatternRewriter &rewriter); +LogicalResult lowerTSort32(TSort32Op op, PatternRewriter &rewriter); +LogicalResult lowerTSTORE(TStoreOp op, PatternRewriter &rewriter); +LogicalResult lowerSetFlag(SetFlagOp op, PatternRewriter &rewriter); +LogicalResult lowerWaitFlag(WaitFlagOp op, PatternRewriter &rewriter); +LogicalResult lowerBarrier(BarrierOp op, PatternRewriter &rewriter); +LogicalResult lowerGetBuf(GetBufOp op, PatternRewriter &rewriter); +LogicalResult lowerRlsBuf(RlsBufOp op, PatternRewriter &rewriter); +LogicalResult convertVPTOEmissionBoundaryToPtr( + ModuleOp module, llvm::raw_ostream *diagOS = nullptr); + +} // namespace pto +} // namespace mlir + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ diff --git a/include/pto-c/Dialect/PTO.h b/include/pto-c/Dialect/PTO.h index f7ee619c7..acfac3a48 100644 --- a/include/pto-c/Dialect/PTO.h +++ b/include/pto-c/Dialect/PTO.h @@ -26,7 +26,10 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PTO, pto); // ---- !pto.ptr ---- bool mlirPTOTypeIsAPtrType(MlirType type); MlirType mlirPTOPtrTypeGet(MlirContext ctx, MlirType elementType); +MlirType mlirPTOPtrTypeGetWithMemorySpace(MlirContext ctx, MlirType elementType, + MlirAttribute memorySpace); MlirType mlirPTOPtrTypeGetElementType(MlirType type); +MlirAttribute mlirPTOPtrTypeGetMemorySpace(MlirType type); // ---- !pto.async_session / !pto.async_event ---- bool mlirPTOTypeIsAAsyncSessionType(MlirType type); diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index e9e32ba98..6fe560a06 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -39,6 +39,7 @@ target_link_libraries(_pto PRIVATE MLIRSupport MLIRArithDialect MLIRMemRefDialect + MLIRSCFDialect MLIRDestinationStyleOpInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces @@ -47,6 +48,7 @@ target_link_libraries(_pto PRIVATE MLIRLoopLikeInterface MLIRViewLikeInterface MLIRFunctionInterfaces + MLIRLLVMDialect ) # 关键:放到 mlir/_mlir_libs 下(匹配 MLIR dialect python 的 import 习惯) diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index 4ff530512..98d293db2 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -567,20 +567,34 @@ static void bindPTOModule(pybind11::module &m) { [](MlirType type) -> bool { return mlirPTOTypeIsAPtrType(type); }) .def_classmethod( "get", - [](py::object cls, MlirType elementType, + [](py::object cls, MlirType elementType, py::object memorySpace, MlirContext context) -> py::object { MlirContext ctx = context; if (!ctx.ptr) ctx = mlirTypeGetContext(elementType); - MlirType t = mlirPTOPtrTypeGet(ctx, elementType); + MlirType t = {nullptr}; + if (memorySpace.is_none()) { + t = mlirPTOPtrTypeGet(ctx, elementType); + } else { + MlirAttribute memorySpaceAttr = + py::cast(memorySpace); + t = mlirPTOPtrTypeGetWithMemorySpace(ctx, elementType, + memorySpaceAttr); + } return cls.attr("__call__")(t); }, py::arg("cls"), py::arg("element_type"), + py::arg("memory_space") = py::none(), py::arg("context") = py::none()) .def_property_readonly( "element_type", [](MlirType self) -> MlirType { return mlirPTOPtrTypeGetElementType(self); + }) + .def_property_readonly( + "memory_space", + [](MlirType self) -> MlirAttribute { + return mlirPTOPtrTypeGetMemorySpace(self); }); mlir_type_subclass( diff --git a/lib/CAPI/Dialect/PTO.cpp b/lib/CAPI/Dialect/PTO.cpp index 3d5682ae8..d0c9ce820 100644 --- a/lib/CAPI/Dialect/PTO.cpp +++ b/lib/CAPI/Dialect/PTO.cpp @@ -60,6 +60,14 @@ MlirType mlirPTOPtrTypeGet(MlirContext ctx, MlirType elementType) { return wrap(mlir::pto::PtrType::get(c, elem)); } +MlirType mlirPTOPtrTypeGetWithMemorySpace(MlirContext ctx, MlirType elementType, + MlirAttribute memorySpace) { + auto c = unwrap(ctx); + auto elem = unwrap(elementType); + auto space = mlir::cast(unwrap(memorySpace)); + return wrap(mlir::pto::PtrType::get(c, elem, space)); +} + MlirType mlirPTOPtrTypeGetElementType(MlirType type) { auto t = cast(unwrap(type));; return wrap(t.getElementType()); @@ -105,6 +113,11 @@ MlirType mlirPTOF4E2M1x2TypeGet(MlirContext ctx) { return wrap(mlir::pto::F4E2M1x2Type::get(unwrap(ctx))); } +MlirAttribute mlirPTOPtrTypeGetMemorySpace(MlirType type) { + auto t = cast(unwrap(type)); + return wrap(t.getMemorySpace()); +} + bool mlirPTOAttrIsAAddressSpaceAttr(MlirAttribute attr) { return mlir::isa(unwrap(attr)); } diff --git a/lib/PTO/IR/CMakeLists.txt b/lib/PTO/IR/CMakeLists.txt index b055d8290..74b9e0bd6 100644 --- a/lib/PTO/IR/CMakeLists.txt +++ b/lib/PTO/IR/CMakeLists.txt @@ -14,6 +14,7 @@ # [关键] 库名重命名为 PTOIR,避免与 LLVM 里的 PTODialect/MLIRPTODialect 冲突 add_mlir_dialect_library(PTOIR PTO.cpp + VPTO.cpp PTOAttrs.cpp PTOSyncUtils.cpp PTOTypeDefs.cpp @@ -29,6 +30,7 @@ add_mlir_dialect_library(PTOIR MLIRIR MLIRFuncDialect MLIRMemRefDialect + MLIRSCFDialect MLIRControlFlowInterfaces MLIRInferTypeOpInterface MLIRSideEffectInterfaces diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 2f5d8e3fc..74751c313 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -170,6 +170,11 @@ static LogicalResult verifyArithmeticElemTypeForArch( Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error); static bool isRowMajorTileBuf(Type ty); +static ParseResult parseQuotedPipeToken(OpAsmParser &parser, PipeAttr &attr); +static ParseResult parseQuotedEventToken(OpAsmParser &parser, EventAttr &attr); +static ParseResult parseLegacyOrAttrPipe(OpAsmParser &parser, PipeAttr &attr); +static ParseResult parseLegacyOrAttrEvent(OpAsmParser &parser, EventAttr &attr); +static ParseResult parseI32LiteralAttr(OpAsmParser &parser, IntegerAttr &attr); #define GET_ENUM_CLASSES #include "PTO/IR/PTOEnums.cpp.inc" @@ -374,6 +379,24 @@ static LogicalResult dispatchVerifierByArch(Operation *op, FnA2A3 &&verifyA2A3, return verifyA5(); } } +static std::optional parsePtrAddressSpaceKeyword(StringRef keyword) { + return llvm::StringSwitch>(keyword) + .Case("gm", pto::AddressSpace::GM) + .Case("ub", pto::AddressSpace::VEC) + .Default(std::nullopt); +} + +static StringRef printPtrAddressSpaceKeyword(pto::AddressSpace space) { + switch (space) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return "gm"; + case pto::AddressSpace::VEC: + return "ub"; + default: + return {}; + } +} static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, OperationState &result, @@ -483,17 +506,22 @@ static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { mlir::Type elem; if (failed(parser.parseType(elem))) return mlir::Type(); + auto memorySpace = pto::AddressSpaceAttr::get(ctx, pto::AddressSpace::GM); if (succeeded(parser.parseOptionalComma())) { - // ptr no longer accepts an address space; consume the attr for recovery. - mlir::Attribute memorySpace; - (void)parser.parseAttribute(memorySpace); - parser.emitError(parser.getCurrentLocation(), - "!pto.ptr no longer accepts address space; use !pto.ptr"); - return mlir::Type(); + StringRef memorySpaceKeyword; + if (failed(parser.parseKeyword(&memorySpaceKeyword))) + return mlir::Type(); + auto parsed = parsePtrAddressSpaceKeyword(memorySpaceKeyword); + if (!parsed) { + parser.emitError(parser.getCurrentLocation(), + "!pto.ptr address space must be `gm` or `ub`"); + return mlir::Type(); + } + memorySpace = pto::AddressSpaceAttr::get(ctx, *parsed); } if (failed(parser.parseGreater())) return mlir::Type(); - return mlir::pto::PtrType::get(ctx, elem); + return mlir::pto::PtrType::get(ctx, elem, memorySpace); } if (head == "pto.tensor_view") { @@ -519,6 +547,40 @@ void TensorViewType::print(::mlir::AsmPrinter &printer) const { printShapeAndElem(printer, getShape(), getElementType()); } +mlir::Type PtrType::parse(::mlir::AsmParser &parser) { + Type elementType; + if (failed(parser.parseLess()) || failed(parser.parseType(elementType))) + return {}; + + auto memorySpace = + pto::AddressSpaceAttr::get(parser.getContext(), pto::AddressSpace::GM); + if (succeeded(parser.parseOptionalComma())) { + StringRef memorySpaceKeyword; + if (failed(parser.parseKeyword(&memorySpaceKeyword))) + return {}; + auto parsed = parsePtrAddressSpaceKeyword(memorySpaceKeyword); + if (!parsed) { + parser.emitError(parser.getCurrentLocation(), + "!pto.ptr address space must be `gm` or `ub`"); + return {}; + } + memorySpace = pto::AddressSpaceAttr::get(parser.getContext(), *parsed); + } + + if (failed(parser.parseGreater())) + return {}; + return PtrType::get(parser.getContext(), elementType, memorySpace); +} + +void PtrType::print(::mlir::AsmPrinter &printer) const { + printer << "<" << getElementType(); + StringRef memorySpaceKeyword = + printPtrAddressSpaceKeyword(getMemorySpace().getAddressSpace()); + if (!memorySpaceKeyword.empty()) + printer << ", " << memorySpaceKeyword; + printer << ">"; +} + //===----------------------------------------------------------------------===// // pto.tdivs custom asm to support both: // pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, f32) outs(%dst : !pto.tile_buf<...>) @@ -1519,6 +1581,43 @@ LogicalResult mlir::pto::AddPtrOp::verify() { return success(); } +LogicalResult mlir::pto::CastPtrOp::verify() { + Type inputType = getInput().getType(); + Type resultType = getResult().getType(); + + auto inputPtrType = dyn_cast(inputType); + auto resultPtrType = dyn_cast(resultType); + auto inputMemRefType = dyn_cast(inputType); + bool inputIsInteger = isa(inputType); + bool resultIsInteger = isa(resultType); + + if (!inputPtrType && !inputMemRefType && !inputIsInteger) + return emitOpError("input must be an integer, memref, or !pto.ptr<...>"); + if (!resultPtrType && !resultIsInteger) + return emitOpError("result must be an integer or !pto.ptr<...>"); + + if (inputIsInteger && resultIsInteger) + return emitOpError("integer-to-integer cast is not a ptr cast"); + + if (inputMemRefType && resultIsInteger) + return emitOpError("memref-to-integer cast is unsupported"); + + if (inputMemRefType && resultPtrType) { + auto memrefSpace = dyn_cast_or_null( + inputMemRefType.getMemorySpace()); + auto resultSpace = resultPtrType.getMemorySpace(); + if (memrefSpace && memrefSpace != resultSpace) + return emitOpError("memref-to-ptr cast must stay within the same PTO memory space"); + } + + if (inputPtrType && resultPtrType && + inputPtrType.getMemorySpace() != resultPtrType.getMemorySpace()) { + return emitOpError("ptr-to-ptr cast must stay within the same PTO memory space"); + } + + return success(); +} + @@ -1541,6 +1640,8 @@ void PTODialect::initialize() { AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace(); auto memRefType = dyn_cast(type); if (!memRefType) return {}; @@ -1552,7 +1653,7 @@ AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { bool mlir::pto::isScalarPtrOrMemRef(Type type) { if (auto pty = dyn_cast(type)) - return true; + return static_cast(pty); if (auto memTy = dyn_cast(type)) return isGmAddressSpaceAttr(memTy.getMemorySpace()); return false; @@ -5523,14 +5624,19 @@ static LogicalResult verifyBufSyncOp(Operation *op, Attribute opTypeAttr, if (!opTypeAttr) return op->emitOpError("expects 'op_type' attribute"); - auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); - if (failed(opTypeOr)) { - auto diag = - op->emitOpError("expects 'op_type' to be pipe_event_type/sync_op_type, got "); - diag << opTypeAttr; - return failure(); + pto::PIPE pipe = pto::PIPE::PIPE_UNASSIGNED; + if (auto pipeAttr = dyn_cast(opTypeAttr)) { + pipe = pipeAttr.getPipe(); + } else { + auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); + if (failed(opTypeOr)) { + auto diag = op->emitOpError( + "expects 'op_type' to be pipe_event_type/sync_op_type/pipe, got "); + diag << opTypeAttr; + return failure(); + } + pipe = mapSyncOpTypeToPipe(*opTypeOr); } - pto::PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); if (!isConcreteSyncPipe(pipe)) return op->emitOpError("expects 'op_type' to map to a concrete pipe, not PIPE_ALL/PIPE_UNASSIGNED"); @@ -5558,6 +5664,282 @@ LogicalResult RlsBufOp::verify() { return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), getModeAttr()); } + +static ParseResult parseQuotedPipeToken(OpAsmParser &parser, PipeAttr &attr) { + std::string pipeName; + auto loc = parser.getCurrentLocation(); + if (failed(parser.parseString(&pipeName))) + return failure(); + auto pipe = symbolizePIPE(pipeName); + if (!pipe) + return parser.emitError(loc) << "invalid pipe token: " << pipeName; + attr = PipeAttr::get(parser.getContext(), *pipe); + return success(); +} + +static ParseResult parseQuotedEventToken(OpAsmParser &parser, EventAttr &attr) { + std::string eventName; + auto loc = parser.getCurrentLocation(); + if (failed(parser.parseString(&eventName))) + return failure(); + auto event = symbolizeEVENT(eventName); + if (!event) + return parser.emitError(loc) << "invalid event token: " << eventName; + attr = EventAttr::get(parser.getContext(), *event); + return success(); +} + +static ParseResult parseLegacyOrAttrPipe(OpAsmParser &parser, PipeAttr &attr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + auto pipe = symbolizePIPE(token); + if (!pipe) + return parser.emitError(loc) << "invalid pipe token: " << token; + attr = PipeAttr::get(parser.getContext(), *pipe); + return success(); + } + + if (succeeded(parser.parseOptionalLess())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseGreater()) + return failure(); + auto pipe = symbolizePIPE(keyword); + if (!pipe) + return parser.emitError(loc) << "invalid pipe token: " << keyword; + attr = PipeAttr::get(parser.getContext(), *pipe); + return success(); + } + + Attribute parsed; + if (failed(parser.parseAttribute(parsed))) + return failure(); + auto pipeAttr = dyn_cast(parsed); + if (!pipeAttr) + return parser.emitError(loc, "expected pipe attribute"); + attr = pipeAttr; + return success(); +} + +static ParseResult parseLegacyOrAttrEvent(OpAsmParser &parser, EventAttr &attr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + auto event = symbolizeEVENT(token); + if (!event) + return parser.emitError(loc) << "invalid event token: " << token; + attr = EventAttr::get(parser.getContext(), *event); + return success(); + } + + if (succeeded(parser.parseOptionalLess())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseGreater()) + return failure(); + auto event = symbolizeEVENT(keyword); + if (!event) + return parser.emitError(loc) << "invalid event token: " << keyword; + attr = EventAttr::get(parser.getContext(), *event); + return success(); + } + + Attribute parsed; + if (failed(parser.parseAttribute(parsed))) + return failure(); + auto eventAttr = dyn_cast(parsed); + if (!eventAttr) + return parser.emitError(loc, "expected event attribute"); + attr = eventAttr; + return success(); +} + +static ParseResult parseI32LiteralAttr(OpAsmParser &parser, IntegerAttr &attr) { + auto loc = parser.getCurrentLocation(); + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) + return parser.emitError(loc, "expected 32-bit integer literal"); + attr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), value); + return success(); +} + +static void printLegacySyncTriplet(OpAsmPrinter &p, PipeAttr srcPipe, + PipeAttr dstPipe, EventAttr eventId, + ArrayRef attrs) { + p << "[\"" << stringifyPIPE(srcPipe.getPipe()) << "\", \"" + << stringifyPIPE(dstPipe.getPipe()) << "\", \"" + << stringifyEVENT(eventId.getEvent()) << "\"]"; + p.printOptionalAttrDict(attrs, {"src_pipe", "dst_pipe", "event_id"}); +} + +ParseResult SetFlagOp::parse(OpAsmParser &parser, OperationState &result) { + PipeAttr srcPipe; + PipeAttr dstPipe; + EventAttr eventId; + if (parser.parseLSquare() || parseLegacyOrAttrPipe(parser, srcPipe) || + parser.parseComma() || parseLegacyOrAttrPipe(parser, dstPipe) || + parser.parseComma() || parseLegacyOrAttrEvent(parser, eventId) || + parser.parseRSquare()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("src_pipe", srcPipe); + result.addAttribute("dst_pipe", dstPipe); + result.addAttribute("event_id", eventId); + return success(); +} + +void SetFlagOp::print(OpAsmPrinter &p) { + printLegacySyncTriplet(p, getSrcPipe(), getDstPipe(), getEventId(), + (*this)->getAttrs()); +} + +ParseResult WaitFlagOp::parse(OpAsmParser &parser, OperationState &result) { + PipeAttr srcPipe; + PipeAttr dstPipe; + EventAttr eventId; + if (parser.parseLSquare() || parseLegacyOrAttrPipe(parser, srcPipe) || + parser.parseComma() || parseLegacyOrAttrPipe(parser, dstPipe) || + parser.parseComma() || parseLegacyOrAttrEvent(parser, eventId) || + parser.parseRSquare()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("src_pipe", srcPipe); + result.addAttribute("dst_pipe", dstPipe); + result.addAttribute("event_id", eventId); + return success(); +} + +void WaitFlagOp::print(OpAsmPrinter &p) { + printLegacySyncTriplet(p, getSrcPipe(), getDstPipe(), getEventId(), + (*this)->getAttrs()); +} + +static ParseResult parseLegacyOrAttrOpType(OpAsmParser &parser, + Attribute &opTypeAttr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + if (auto pipe = symbolizePIPE(token)) { + opTypeAttr = PipeAttr::get(parser.getContext(), *pipe); + return success(); + } + if (auto opType = symbolizeSyncOpType(token)) { + opTypeAttr = PipeEventTypeAttr::get(parser.getContext(), *opType); + return success(); + } + return parser.emitError(loc) << "invalid get_buf/rls_buf token: " << token; + } + + if (succeeded(parser.parseOptionalLSquare())) { + if (failed(parser.parseAttribute(opTypeAttr))) + return failure(); + return success(); + } + + if (failed(parser.parseAttribute(opTypeAttr))) + return failure(); + return success(); +} + +static ParseResult parseBufSyncOp(OpAsmParser &parser, OperationState &result) { + Attribute opTypeAttr; + IntegerAttr bufIdAttr; + IntegerAttr modeAttr; + + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + if (auto pipe = symbolizePIPE(token)) + opTypeAttr = PipeAttr::get(parser.getContext(), *pipe); + else if (auto opType = symbolizeSyncOpType(token)) + opTypeAttr = PipeEventTypeAttr::get(parser.getContext(), *opType); + else + return parser.emitError(loc) << "invalid get_buf/rls_buf token: " << token; + + if (parser.parseComma() || parseI32LiteralAttr(parser, bufIdAttr)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parseI32LiteralAttr(parser, modeAttr)) + return failure(); + } else { + modeAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), 0); + } + } else if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(opTypeAttr) || parser.parseComma() || + parseI32LiteralAttr(parser, bufIdAttr)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parseI32LiteralAttr(parser, modeAttr)) + return failure(); + } else { + modeAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), 0); + } + if (parser.parseRSquare()) + return failure(); + } else { + return parser.emitError(loc, "expected string pipe/op_type or '['"); + } + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("op_type", opTypeAttr); + result.addAttribute("buf_id", bufIdAttr); + result.addAttribute("mode", modeAttr); + return success(); +} + +static void printBufSyncOp(OpAsmPrinter &p, Attribute opTypeAttr, + IntegerAttr bufIdAttr, IntegerAttr modeAttr, + ArrayRef attrs) { + if (auto pipeAttr = dyn_cast(opTypeAttr)) { + p << " \"" << stringifyPIPE(pipeAttr.getPipe()) << "\", " + << bufIdAttr.getInt() << ", " << modeAttr.getInt(); + } else if (auto pipeEventType = dyn_cast(opTypeAttr)) { + auto pipe = mapSyncOpTypeToPipe(pipeEventType.getOpType()); + if (isConcreteSyncPipe(pipe)) { + p << " \"" << stringifyPIPE(pipe) << "\", " << bufIdAttr.getInt() + << ", " << modeAttr.getInt(); + } else { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } + } else if (auto syncOpType = dyn_cast(opTypeAttr)) { + auto pipe = mapSyncOpTypeToPipe(syncOpType.getOpType()); + if (isConcreteSyncPipe(pipe)) { + p << " \"" << stringifyPIPE(pipe) << "\", " << bufIdAttr.getInt() + << ", " << modeAttr.getInt(); + } else { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } + } else { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } + p.printOptionalAttrDict(attrs, {"op_type", "buf_id", "mode"}); +} + +ParseResult GetBufOp::parse(OpAsmParser &parser, OperationState &result) { + return parseBufSyncOp(parser, result); +} + +void GetBufOp::print(OpAsmPrinter &p) { + printBufSyncOp(p, getOpTypeAttr(), getBufIdAttr(), getModeAttr(), + (*this)->getAttrs()); +} + +ParseResult RlsBufOp::parse(OpAsmParser &parser, OperationState &result) { + return parseBufSyncOp(parser, result); +} + +void RlsBufOp::print(OpAsmPrinter &p) { + printBufSyncOp(p, getOpTypeAttr(), getBufIdAttr(), getModeAttr(), + (*this)->getAttrs()); +} // ---- TOp ---- LogicalResult TGemvBiasOp::verify() { auto verifyA2A3 = [&]() -> LogicalResult { diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp new file mode 100644 index 000000000..aa37efb79 --- /dev/null +++ b/lib/PTO/IR/VPTO.cpp @@ -0,0 +1,2924 @@ +//===- VPTO.cpp - VPTO dialect -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +static llvm::cl::opt disableVPTOAlignChainVerification( + "vpto-disable-align-chain-verification", + llvm::cl::desc("Disable !pto.align linear-chain verifier checks"), + llvm::cl::init(false), llvm::cl::Hidden); + +static std::string formatVRegType(int64_t elementCount, Type elementType) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.vreg<" << elementCount << "x" << elementType << ">"; + return storage; +} + +static std::string formatMaskType(StringRef granularity) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.mask<" << granularity << ">"; + return storage; +} + +static LogicalResult verifyVRegTypeLike(Operation *op, Type type, + StringRef roleDescription) { + auto vecType = dyn_cast(type); + if (!vecType) + return op->emitOpError() << roleDescription << " must be !pto.vreg<...>"; + + return VRegType::verify( + [&]() { return op->emitOpError() << roleDescription << " "; }, + vecType.getElementCount(), vecType.getElementType()); +} + +static LogicalResult verifyMaskTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (!isa(type)) + return op->emitOpError() << roleDescription << " must be !pto.mask<...>"; + return success(); +} + +static LogicalResult verifyMaskTypeWithGranularityLike(Operation *op, Type type, + StringRef roleDescription, + StringRef granularity) { + auto maskType = dyn_cast(type); + if (!maskType) + return op->emitOpError() << roleDescription << " must be !pto.mask<...>"; + if (maskType.getGranularity() != granularity) { + return op->emitOpError() + << roleDescription << " must be " << formatMaskType(granularity); + } + return success(); +} + +static LogicalResult verifyEnclosingLoopLike(Operation *op, + StringRef opNameForDiag) { + if (!op->getParentOfType()) { + return op->emitOpError() + << "requires enclosing loop structure for " << opNameForDiag + << " lowering"; + } + return success(); +} + +static LogicalResult verifyNotNestedInVecScope(Operation *op, + StringRef opNameForDiag) { + if (op->getParentOfType() || + op->getParentOfType()) { + return op->emitOpError() + << "must not be nested under pto.vecscope/pto.strict_vecscope; " + << opNameForDiag << " is a UB helper op rather than a vecscope op"; + } + return success(); +} + +static LogicalResult verifyNestedInVecScope(Operation *op, + StringRef opNameForDiag) { + if (op->getParentOfType() || op->getParentOfType()) + return success(); + return op->emitOpError() + << "must be nested under pto.vecscope/pto.strict_vecscope; " + << opNameForDiag << " is part of the vecscope control sequence"; +} + +static LogicalResult verifyAlignTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (!isa(type)) + return op->emitOpError() << roleDescription << " must be !pto.align"; + return success(); +} + +static bool isSupportedVdupPosition(std::optional position) { + return !position || *position == "LOWEST" || *position == "HIGHEST"; +} + +static std::optional getVdupMaskGranularity(Type elementType) { + if (auto intType = dyn_cast(elementType)) { + switch (intType.getWidth()) { + case 8: + return StringRef("b8"); + case 16: + return StringRef("b16"); + case 32: + return StringRef("b32"); + default: + return std::nullopt; + } + } + if (elementType.isF16() || elementType.isBF16()) + return StringRef("b16"); + if (elementType.isF32()) + return StringRef("b32"); + return std::nullopt; +} + +static bool isStoreAlignProducer(Operation *op) { + return isa(op); +} + +static bool isStoreAlignSink(Operation *op) { + return isa(op); +} + +static bool isLoadAlignProducer(Operation *op) { + return isa(op); +} + +static bool isValueOwnedByRegion(Value value, Region *region) { + if (auto blockArg = dyn_cast(value)) + return blockArg.getParentRegion() == region; + if (Operation *def = value.getDefiningOp()) + return def->getParentRegion() == region; + return false; +} + +static FailureOr resolveStoreAlignRoot(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (true) { + if (!visited.insert(current.getAsOpaquePointer()).second) { + return failure(); + } + + if (auto blockArg = dyn_cast(current)) { + auto *owner = blockArg.getOwner(); + auto forOp = dyn_cast(owner->getParentOp()); + if (!forOp) + return failure(); + unsigned argNumber = blockArg.getArgNumber(); + unsigned ivCount = forOp.getNumInductionVars(); + if (argNumber < ivCount) + return failure(); + unsigned iterIdx = argNumber - ivCount; + if (iterIdx >= forOp.getInitArgs().size()) + return failure(); + current = forOp.getInitArgs()[iterIdx]; + continue; + } + + if (Operation *def = current.getDefiningOp()) { + if (isStoreAlignProducer(def)) + return current; + if (auto forOp = dyn_cast(def)) { + auto result = dyn_cast(current); + if (!result) + return failure(); + unsigned resultIdx = result.getResultNumber(); + if (resultIdx >= forOp.getYieldedValues().size()) + return failure(); + current = forOp.getYieldedValues()[resultIdx]; + continue; + } + } + + return failure(); + } +} + +static LogicalResult verifyStoreAlignLoopThreading(Value align, Operation *user, + StringRef roleDescription) { + Operation *cursor = user; + while (auto forOp = cursor->getParentOfType()) { + Region *body = &forOp.getRegion(); + if (!isValueOwnedByRegion(align, body)) { + return user->emitOpError() + << roleDescription + << " must be threaded through scf.for iter_args when used inside a " + "loop"; + } + cursor = forOp; + } + return success(); +} + +static LogicalResult verifyStoreAlignLinearUses(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (visited.insert(current.getAsOpaquePointer()).second) { + SmallVector nextValues; + SmallVector terminalUsers; + + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (isStoreAlignSink(owner)) { + terminalUsers.push_back(owner); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + continue; + } + if (auto forOp = dyn_cast(owner)) { + unsigned firstInitArg = forOp.getNumControlOperands(); + if (use.getOperandNumber() < firstInitArg) + return user->emitOpError() + << "found unexpected scf.for control operand use for !pto.align"; + unsigned iterIdx = use.getOperandNumber() - firstInitArg; + if (iterIdx >= forOp.getRegionIterArgs().size()) + return user->emitOpError() + << "found invalid scf.for iter_args use for !pto.align"; + nextValues.push_back(forOp.getRegionIterArgs()[iterIdx]); + continue; + } + if (auto yieldOp = dyn_cast(owner)) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) + return user->emitOpError() + << "found !pto.align yielded from non-scf.for loop"; + unsigned resultIdx = use.getOperandNumber(); + if (resultIdx >= forOp.getNumResults()) + return user->emitOpError() + << "found invalid scf.yield result mapping for !pto.align"; + nextValues.push_back(forOp.getResult(resultIdx)); + continue; + } + return user->emitOpError() + << "found unsupported !pto.align consumer " << owner->getName(); + } + + if (nextValues.size() + terminalUsers.size() > 1) { + return user->emitOpError() + << "!pto.align value must form a single linear store-state chain"; + } + if (nextValues.empty()) + return success(); + current = nextValues.front(); + } + + return success(); +} + +static LogicalResult verifyStoreAlignChain(Value align, Operation *user, + StringRef roleDescription) { + if (disableVPTOAlignChainVerification) + return success(); + + if (failed(verifyAlignTypeLike(user, align.getType(), roleDescription))) + return failure(); + + if (failed(verifyStoreAlignLoopThreading(align, user, roleDescription))) + return failure(); + + FailureOr root = resolveStoreAlignRoot(align, user); + if (failed(root)) { + if (Operation *def = align.getDefiningOp()) { + if (!isa(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op, got " + << def->getName(); + } + } + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op"; + } + + Operation *def = (*root).getDefiningOp(); + if (!isStoreAlignProducer(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op, got " + << def->getName(); + } + + return verifyStoreAlignLinearUses(*root, user); +} + +static FailureOr resolveLoadAlignRoot(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (true) { + if (!visited.insert(current.getAsOpaquePointer()).second) + return failure(); + + if (auto blockArg = dyn_cast(current)) { + auto *owner = blockArg.getOwner(); + auto forOp = dyn_cast(owner->getParentOp()); + if (!forOp) + return failure(); + unsigned argNumber = blockArg.getArgNumber(); + unsigned ivCount = forOp.getNumInductionVars(); + if (argNumber < ivCount) + return failure(); + unsigned iterIdx = argNumber - ivCount; + if (iterIdx >= forOp.getInitArgs().size()) + return failure(); + current = forOp.getInitArgs()[iterIdx]; + continue; + } + + if (Operation *def = current.getDefiningOp()) { + if (isLoadAlignProducer(def)) + return current; + if (auto forOp = dyn_cast(def)) { + auto result = dyn_cast(current); + if (!result) + return failure(); + unsigned resultIdx = result.getResultNumber(); + if (resultIdx >= forOp.getYieldedValues().size()) + return failure(); + current = forOp.getYieldedValues()[resultIdx]; + continue; + } + } + + return failure(); + } +} + +static LogicalResult verifyLoadAlignLoopThreading(Value align, Operation *user, + StringRef roleDescription) { + Operation *cursor = user; + while (auto forOp = cursor->getParentOfType()) { + Region *body = &forOp.getRegion(); + if (!isValueOwnedByRegion(align, body)) { + return user->emitOpError() + << roleDescription + << " must be threaded through scf.for iter_args when used inside a " + "loop"; + } + cursor = forOp; + } + return success(); +} + +static LogicalResult verifyLoadAlignLinearUses(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (visited.insert(current.getAsOpaquePointer()).second) { + SmallVector nextValues; + + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getUpdatedAlign()); + continue; + } + if (auto forOp = dyn_cast(owner)) { + unsigned firstInitArg = forOp.getNumControlOperands(); + if (use.getOperandNumber() < firstInitArg) { + return user->emitOpError() + << "found unexpected scf.for control operand use for !pto.align"; + } + unsigned iterIdx = use.getOperandNumber() - firstInitArg; + if (iterIdx >= forOp.getRegionIterArgs().size()) { + return user->emitOpError() + << "found invalid scf.for iter_args use for !pto.align"; + } + nextValues.push_back(forOp.getRegionIterArgs()[iterIdx]); + continue; + } + if (auto yieldOp = dyn_cast(owner)) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) { + return user->emitOpError() + << "found !pto.align yielded from non-scf.for loop"; + } + unsigned resultIdx = use.getOperandNumber(); + if (resultIdx >= forOp.getNumResults()) { + return user->emitOpError() + << "found invalid scf.yield result mapping for !pto.align"; + } + nextValues.push_back(forOp.getResult(resultIdx)); + continue; + } + return user->emitOpError() + << "found unsupported !pto.align consumer " << owner->getName(); + } + + if (nextValues.size() > 1) { + return user->emitOpError() + << "!pto.align value must form a single linear load-state chain"; + } + if (nextValues.empty()) + return success(); + current = nextValues.front(); + } + + return success(); +} + +static LogicalResult verifyLoadAlignChain(Value align, Operation *user, + StringRef roleDescription) { + if (disableVPTOAlignChainVerification) + return success(); + + if (failed(verifyAlignTypeLike(user, align.getType(), roleDescription))) + return failure(); + + if (failed(verifyLoadAlignLoopThreading(align, user, roleDescription))) + return failure(); + + FailureOr root = resolveLoadAlignRoot(align, user); + if (failed(root)) { + if (Operation *def = align.getDefiningOp()) { + if (!isa(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op, got " + << def->getName(); + } + } + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op"; + } + + Operation *def = (*root).getDefiningOp(); + if (!isLoadAlignProducer(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op, got " + << def->getName(); + } + + return verifyLoadAlignLinearUses(*root, user); +} + +static bool isSupportedPredicatePattern(StringRef pattern) { + return pattern == "PAT_ALL" || pattern == "PAT_VL1" || pattern == "PAT_VL2" || + pattern == "PAT_VL3" || pattern == "PAT_VL4" || pattern == "PAT_VL8" || + pattern == "PAT_VL16" || pattern == "PAT_VL32" || + pattern == "PAT_VL64" || pattern == "PAT_VL128" || + pattern == "PAT_M3" || pattern == "PAT_M4" || pattern == "PAT_H" || + pattern == "PAT_Q" || pattern == "PAT_ALLF"; +} + +static bool isSupportedPredicateLoadDist(StringRef dist) { + return dist == "NORM" || dist == "US" || dist == "DS"; +} + +static bool isSupportedPredicateStoreDist(StringRef dist) { + return dist == "NORM" || dist == "PK"; +} + +static bool isSupportedStrideToken(StringRef stride) { + return stride == "STRIDE_S3_B16" || stride == "STRIDE_S4_B64" || + stride == "STRIDE_S8_B32" || stride == "STRIDE_S2_B64" || + stride == "STRIDE_VSST_S8_B16"; +} + +static bool isSupportedPartToken(StringRef part) { + return part == "LOWER" || part == "HIGHER"; +} + +static bool isSupportedSprToken(StringRef spr) { return spr == "AR"; } + +static std::optional normalizeRoundModeToken(StringRef token) { + if (token == "R" || token == "ROUND_R") + return StringRef("R"); + if (token == "A" || token == "ROUND_A") + return StringRef("A"); + if (token == "F" || token == "ROUND_F") + return StringRef("F"); + if (token == "C" || token == "ROUND_C") + return StringRef("C"); + if (token == "Z" || token == "ROUND_Z") + return StringRef("Z"); + if (token == "O" || token == "ROUND_O") + return StringRef("O"); + return std::nullopt; +} + +static std::optional normalizeSaturationToken(StringRef token) { + if (token == "SAT" || token == "RS_ENABLE") + return StringRef("SAT"); + if (token == "NOSAT" || token == "RS_DISABLE") + return StringRef("NOSAT"); + return std::nullopt; +} + +static std::optional normalizeEvenOddPartToken(StringRef token) { + if (token == "EVEN" || token == "PART_EVEN") + return StringRef("EVEN"); + if (token == "ODD" || token == "PART_ODD") + return StringRef("ODD"); + return std::nullopt; +} + +namespace { + +enum class VcvtElemKind { + Invalid, + F16, + BF16, + F32, + S8, + U8, + S16, + U16, + S32, + U32, + S64, +}; + +struct VcvtContract { + bool requiresRnd; + bool requiresSat; + bool requiresPart; +}; + +static VcvtElemKind classifyVcvtElemType(Type type) { + if (type.isF16()) + return VcvtElemKind::F16; + if (type.isBF16()) + return VcvtElemKind::BF16; + if (type.isF32()) + return VcvtElemKind::F32; + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? VcvtElemKind::U8 : VcvtElemKind::S8; + case 16: + return intType.isUnsigned() ? VcvtElemKind::U16 : VcvtElemKind::S16; + case 32: + return intType.isUnsigned() ? VcvtElemKind::U32 : VcvtElemKind::S32; + case 64: + return intType.isUnsigned() ? VcvtElemKind::Invalid : VcvtElemKind::S64; + default: + return VcvtElemKind::Invalid; + } + } + return VcvtElemKind::Invalid; +} + +static std::optional getVcvtElemBitWidth(VcvtElemKind kind) { + switch (kind) { + case VcvtElemKind::F16: + case VcvtElemKind::BF16: + case VcvtElemKind::S16: + case VcvtElemKind::U16: + return 16; + case VcvtElemKind::F32: + case VcvtElemKind::S32: + case VcvtElemKind::U32: + return 32; + case VcvtElemKind::S8: + case VcvtElemKind::U8: + return 8; + case VcvtElemKind::S64: + return 64; + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +static std::optional lookupVcvtContract(VcvtElemKind src, + VcvtElemKind dst) { + switch (src) { + case VcvtElemKind::F32: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::BF16: + case VcvtElemKind::S16: + case VcvtElemKind::S64: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/false}; + default: + return std::nullopt; + } + case VcvtElemKind::F16: + switch (dst) { + case VcvtElemKind::F32: + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/false}; + case VcvtElemKind::S8: + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::BF16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U8: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::U16: + case VcvtElemKind::U32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S8: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::S16: + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U16: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::U32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/false}; + case VcvtElemKind::F32: + case VcvtElemKind::U32: + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U32: + switch (dst) { + case VcvtElemKind::U8: + case VcvtElemKind::U16: + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S32: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/false}; + case VcvtElemKind::U8: + case VcvtElemKind::U16: + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::S64: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S64: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +} // namespace + +static std::optional getDistElementWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (type.isF16() || type.isBF16()) + return 16; + if (type.isF32()) + return 32; + if (type.isF64()) + return 64; + return std::nullopt; +} + +static bool matchesWidthFamily(StringRef dist, unsigned width, + ArrayRef allowedWidths) { + return llvm::is_contained(allowedWidths, width); +} + +static bool isSupportedVldx2DistToken(StringRef dist) { + return dist == "BDINTLV" || dist == "DINTLV"; +} + +static bool isSupportedVldsDistToken(StringRef dist) { + return dist == "NORM" || dist == "BRC" || dist == "US" || dist == "DS" || + dist == "UNPK" || dist == "BRC_BLK" || dist == "E2B" || + dist == "UNPK4" || dist == "SPLT4CHN" || dist == "SPLT2CHN"; +} + +static bool isSupportedVstsDistToken(StringRef dist) { + return dist == "NORM" || dist == "1PT" || dist == "PK" || + dist == "PK4" || dist == "MRG4CHN" || dist == "MRG2CHN"; +} + +static bool isSupportedVstsx2DistToken(StringRef dist) { + return dist == "INTLV"; +} + +static LogicalResult verifyVldsDistWidth(Operation *op, StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return op->emitOpError("requires load element type with a concrete bit width"); + + if (dist == "NORM" || dist == "BRC_BLK") + return success(); + if (dist == "BRC") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist BRC only supports 8/16/32-bit elements"); + if (dist == "US") + return matchesWidthFamily(dist, *width, {8, 16}) + ? success() + : op->emitOpError("dist US only supports 8/16-bit elements"); + if (dist == "DS") + return matchesWidthFamily(dist, *width, {8, 16}) + ? success() + : op->emitOpError("dist DS only supports 8/16-bit elements"); + if (dist == "UNPK") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist UNPK only supports 8/16/32-bit elements"); + if (dist == "E2B") + return matchesWidthFamily(dist, *width, {16, 32}) + ? success() + : op->emitOpError("dist E2B only supports 16/32-bit elements"); + if (dist == "UNPK4") + return *width == 8 + ? success() + : op->emitOpError("dist UNPK4 only supports 8-bit elements"); + if (dist == "SPLT4CHN") + return *width == 8 + ? success() + : op->emitOpError("dist SPLT4CHN only supports 8-bit elements"); + if (dist == "SPLT2CHN") + return matchesWidthFamily(dist, *width, {8, 16}) + ? success() + : op->emitOpError("dist SPLT2CHN only supports 8/16-bit elements"); + + return op->emitOpError("requires a supported load distribution token"); +} + +static LogicalResult verifyVldsx2DistWidth(Operation *op, StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return op->emitOpError( + "requires x2 load element type with a concrete bit width"); + if (dist == "BDINTLV") + return success(); + if (dist == "DINTLV") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist DINTLV only supports 8/16/32-bit elements"); + return op->emitOpError("requires a supported x2 load distribution token"); +} + +static LogicalResult verifyVstsDistWidth(Operation *op, StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return op->emitOpError( + "requires store element type with a concrete bit width"); + + if (dist == "NORM") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist NORM only supports 8/16/32-bit elements"); + if (dist == "1PT") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist 1PT only supports 8/16/32-bit elements"); + if (dist == "PK") + return matchesWidthFamily(dist, *width, {16, 32, 64}) + ? success() + : op->emitOpError("dist PK only supports 16/32/64-bit elements"); + if (dist == "PK4") + return *width == 32 + ? success() + : op->emitOpError("dist PK4 only supports 32-bit elements"); + if (dist == "MRG4CHN") + return *width == 8 + ? success() + : op->emitOpError("dist MRG4CHN only supports 8-bit elements"); + if (dist == "MRG2CHN") + return matchesWidthFamily(dist, *width, {8, 16}) + ? success() + : op->emitOpError("dist MRG2CHN only supports 8/16-bit elements"); + + return op->emitOpError("requires a supported store distribution token"); +} + +static LogicalResult verifyVstsx2DistWidth(Operation *op, StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return op->emitOpError( + "requires x2 store element type with a concrete bit width"); + if (dist == "INTLV") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist INTLV only supports 8/16/32-bit elements"); + return op->emitOpError("requires a supported x2 store distribution token"); +} + +static bool isSupportedPostMode(StringRef mode) { + return mode == "NO_POST_UPDATE" || mode == "POST_UPDATE"; +} + +static std::optional getOptionalPostModeAttr(Operation *op) { + if (auto mode = op->getAttrOfType("mode")) + return mode.getValue(); + return std::nullopt; +} + +static unsigned getIntOrFloatBitWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (auto floatType = dyn_cast(type)) + return floatType.getWidth(); + return 0; +} + +static bool isIntegerOrFloatLike(Type type) { + return isa(type) || isa(type); +} + +static std::optional getVRegStorageBitWidth(Type type) { + auto vecType = dyn_cast(type); + if (!vecType) + return std::nullopt; + unsigned elemWidth = getIntOrFloatBitWidth(vecType.getElementType()); + if (!elemWidth) + return std::nullopt; + return vecType.getElementCount() * static_cast(elemWidth); +} + +static LogicalResult verifyIntegerVRegTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (failed(verifyVRegTypeLike(op, type, roleDescription))) + return failure(); + auto vecType = cast(type); + if (!isa(vecType.getElementType())) + return op->emitOpError() + << roleDescription << " must use integer vector element type"; + return success(); +} + +enum class MemoryRole { + Unknown, + GM, + UB, + Other, +}; + +static MemoryRole classifyMemoryRole(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) { + if (auto ptrType = dyn_cast(type)) { + switch (ptrType.getMemorySpace().getAddressSpace()) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return MemoryRole::GM; + case pto::AddressSpace::VEC: + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + return MemoryRole::Other; + } + + Attribute memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + return MemoryRole::Unknown; + + if (auto addrSpace = dyn_cast(memorySpace)) { + switch (addrSpace.getAddressSpace()) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return MemoryRole::GM; + case pto::AddressSpace::VEC: + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + + if (auto intAttr = dyn_cast(memorySpace)) { + switch (intAttr.getInt()) { + case static_cast(pto::AddressSpace::GM): + case static_cast(pto::AddressSpace::Zero): + return MemoryRole::GM; + case static_cast(pto::AddressSpace::VEC): + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + + return MemoryRole::Other; +} + +static bool isBufferLike(Type type) { + return isa(type); +} + +static int64_t getPtrElementByteSize(Type type) { + auto ptrType = dyn_cast(type); + if (!ptrType) + return 0; + + Type elementType = ptrType.getElementType(); + if (auto floatType = dyn_cast(elementType)) + return (floatType.getWidth() + 7) / 8; + if (auto intType = dyn_cast(elementType)) + return (intType.getWidth() + 7) / 8; + return 0; +} + +template +static LogicalResult verifyCopyGmToUbufOp(CopyOp op, bool expectSourceGM) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto destinationType = dyn_cast(op.getDestination().getType()); + if (!sourceType || !destinationType) + return op.emitOpError("requires typed !pto.ptr source and destination"); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + MemoryRole destinationRole = classifyMemoryRole(op.getDestination().getType()); + bool directionMatches = true; + if (expectSourceGM) { + directionMatches &= sourceRole != MemoryRole::UB; + directionMatches &= destinationRole != MemoryRole::GM; + } else { + directionMatches &= sourceRole != MemoryRole::GM; + directionMatches &= destinationRole != MemoryRole::UB; + } + + if (!directionMatches) { + return op.emitOpError() + << "requires " + << (expectSourceGM ? "GM source and UB destination" + : "UB source and GM destination"); + } + + int64_t sourceElemBytes = getPtrElementByteSize(sourceType); + int64_t destinationElemBytes = getPtrElementByteSize(destinationType); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) + return op.emitOpError("requires copy source and destination element types with known byte width"); + if (sourceElemBytes != destinationElemBytes) + return op.emitOpError("requires source and destination element byte widths to match"); + + return success(); +} + +template +static LogicalResult verifyCopyUbufToGmOp(CopyOp op, bool expectSourceGM) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto destinationType = dyn_cast(op.getDestination().getType()); + if (!sourceType || !destinationType) + return op.emitOpError("requires typed !pto.ptr source and destination"); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + MemoryRole destinationRole = classifyMemoryRole(op.getDestination().getType()); + bool directionMatches = true; + if (expectSourceGM) { + directionMatches &= sourceRole != MemoryRole::UB; + directionMatches &= destinationRole != MemoryRole::GM; + } else { + directionMatches &= sourceRole != MemoryRole::GM; + directionMatches &= destinationRole != MemoryRole::UB; + } + + if (!directionMatches) { + return op.emitOpError() + << "requires " + << (expectSourceGM ? "GM source and UB destination" + : "UB source and GM destination"); + } + + int64_t sourceElemBytes = getPtrElementByteSize(sourceType); + int64_t destinationElemBytes = getPtrElementByteSize(destinationType); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) + return op.emitOpError("requires copy source and destination element types with known byte width"); + if (sourceElemBytes != destinationElemBytes) + return op.emitOpError("requires source and destination element byte widths to match"); + + return success(); +} + +Type VRegType::parse(AsmParser &parser) { + SmallVector shape; + Type elementType; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true)) || + shape.size() != 1 || failed(parser.parseType(elementType)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), shape.front(), + elementType); +} + +void VRegType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x"; + printer.printType(getElementType()); + printer << ">"; +} + +LogicalResult VRegType::verify(function_ref emitError, + int64_t elementCount, Type elementType) { + if (elementCount <= 0) + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected a positive element count"; + + auto intOrFloat = mlir::dyn_cast(elementType); + unsigned elementBitWidth = 0; + if (intOrFloat) { + elementBitWidth = intOrFloat.getWidth(); + } else if (auto floatType = mlir::dyn_cast(elementType)) { + elementBitWidth = floatType.getWidth(); + } else { + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected an integer or floating-point element type"; + } + + if (elementCount * static_cast(elementBitWidth) != 2048) + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected exactly 256 bytes"; + + return success(); +} + +LogicalResult VecScopeOp::verify() { + Region &bodyRegion = getBody(); + if (bodyRegion.empty()) + return emitOpError("expects a non-empty body region"); + + Block &body = bodyRegion.front(); + if (body.getNumArguments() != 0) + return emitOpError() << "expects body block to have no arguments, got " + << body.getNumArguments(); + + return success(); +} + +LogicalResult StrictVecScopeOp::verify() { + Region &bodyRegion = getBody(); + if (bodyRegion.empty()) + return emitOpError("expects a non-empty body region"); + + Block &body = bodyRegion.front(); + if (body.getNumArguments() != getCaptures().size()) + return emitOpError() << "expects body block to have " + << getCaptures().size() + << " arguments to match explicit captures, got " + << body.getNumArguments(); + + for (auto [idx, pair] : + llvm::enumerate(llvm::zip(body.getArguments(), getCaptures()))) { + BlockArgument blockArg = std::get<0>(pair); + Value capture = std::get<1>(pair); + if (blockArg.getType() != capture.getType()) + return emitOpError() << "expects body block argument #" << idx + << " to have type " << capture.getType() + << ", got " << blockArg.getType(); + } + return success(); +} + +bool MaskType::isSupportedGranularity(StringRef granularity) { + return granularity == "b8" || granularity == "b16" || + granularity == "b32"; +} + +Type MaskType::parse(AsmParser &parser) { + auto loc = parser.getCurrentLocation(); + StringRef granularity; + if (failed(parser.parseLess()) || failed(parser.parseKeyword(&granularity)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), granularity); +} + +void MaskType::print(AsmPrinter &printer) const { + printer << "<" << getGranularity() << ">"; +} + +LogicalResult +MaskType::verify(function_ref emitError, + StringRef granularity) { + if (!isSupportedGranularity(granularity)) + return emitError() << "'" << formatMaskType(granularity) + << "' expected granularity to be one of b8, b16, b32"; + return success(); +} + +void CopyGmToUbufOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult CopyGmToUbufOp::verify() { + return verifyCopyGmToUbufOp(*this, true); +} + +LogicalResult VbrOp::verify() { + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + + auto resultVecType = cast(getResult().getType()); + Type elementType = getValue().getType(); + if (isa(elementType)) + return emitOpError("value must be a scalar matching the result element type"); + if (elementType != resultVecType.getElementType()) + return emitOpError("value type must match result element type"); + return success(); +} + +LogicalResult VcaddOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("input and result must have the same vector type"); + return success(); +} + +LogicalResult VcmaxOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("input and result must have the same vector type"); + return success(); +} + +LogicalResult VcminOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("input and result must have the same vector type"); + return success(); +} + +LogicalResult VciOp::verify() { + auto resultType = dyn_cast(getResult().getType()); + if (!resultType) + return emitOpError("result must be !pto.vreg<...>"); + if (!isa(resultType.getElementType())) + return emitOpError("result element type must be integer"); + auto indexType = dyn_cast(getIndex().getType()); + if (!indexType) + return emitOpError("index must be an integer scalar"); + if (indexType != resultType.getElementType()) + return emitOpError("index type must match result element type"); + return success(); +} + +void Vgather2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vgather2Op::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + if (!isa(offsetsType.getElementType())) + return emitOpError("offset vector must use integer element type"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + if (!getActiveLanes().getType().isIndex()) + return emitOpError("active_lanes must be index"); + return success(); +} + +LogicalResult CopyUbufToUbufOp::verify() { + if (!isBufferLike(getSource().getType()) || !isBufferLike(getDestination().getType())) + return emitOpError("requires pointer-like source and destination"); + if (classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getDestination().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed source and destination"); + return success(); +} + +void VgatherbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VgatherbOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + if (failed(verifyMaskTypeWithGranularityLike(getOperation(), getMask().getType(), + "mask type", "b32"))) + return failure(); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + return success(); +} + +void Vgather2BcOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vgather2BcOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + return success(); +} + +LogicalResult VbitsortOp::verify() { + if (!isBufferLike(getDestination().getType()) || !isBufferLike(getSource().getType()) || + !isBufferLike(getIndices().getType())) + return emitOpError("requires pointer-like destination/source/indices"); + if (classifyMemoryRole(getDestination().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getIndices().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed destination/source/indices"); + if (!getRepeatTimes().getType().isIndex()) + return emitOpError("repeat_times must be index"); + if (failed(verifyNotNestedInVecScope(*this, "pto.vbitsort"))) + return failure(); + return success(); +} + +void VbitsortOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getIndicesMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult Vmrgsort4Op::verify() { + if (!isBufferLike(getDestination().getType()) || !isBufferLike(getSource0().getType()) || + !isBufferLike(getSource1().getType()) || !isBufferLike(getSource2().getType()) || + !isBufferLike(getSource3().getType())) + return emitOpError("requires pointer-like destination and sources"); + if (classifyMemoryRole(getDestination().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource0().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource1().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource2().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource3().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed destination and sources"); + return success(); +} + +LogicalResult VmaxOp::verify() { + if (failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getLhs().getType() != getRhs().getType() || + getLhs().getType() != getResult().getType()) + return emitOpError("lhs, rhs, and result must have the same vector type"); + return success(); +} + +LogicalResult VminOp::verify() { + if (failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getLhs().getType() != getRhs().getType() || + getLhs().getType() != getResult().getType()) + return emitOpError("lhs, rhs, and result must have the same vector type"); + return success(); +} + +void VldsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +template +static LogicalResult verifyVldsCommon(LoadOp op) { + if (!isBufferLike(op.getSource().getType())) + return op.emitOpError("requires a pointer-like source"); + + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + if (sourceRole == MemoryRole::GM) + return op.emitOpError("requires a UB-backed source"); + + if (op.getDistAttr()) { + StringRef dist = *op.getDist(); + if (!isSupportedVldsDistToken(dist)) + return op.emitOpError( + "supports only NORM, BRC, US, DS, UNPK, BRC_BLK, E2B, UNPK4, " + "and SPLT2CHN/SPLT4CHN load distributions"); + if (failed(verifyVldsDistWidth( + op.getOperation(), dist, + cast(op.getResult().getType()).getElementType()))) + return failure(); + } + + return success(); +} + +LogicalResult VldsOp::verify() { + if (failed(verifyVldsCommon(*this))) + return failure(); + if (std::optional mode = getOptionalPostModeAttr(getOperation()); + mode && !isSupportedPostMode(*mode)) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} +void VldsPostOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldsPostOp::verify() { + if (failed(verifyVldsCommon(*this))) + return failure(); + if (getUpdatedSource().getType() != getSource().getType()) + return emitOpError("requires updated source result to match source type"); + return success(); +} + +void VldasOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldasOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyAlignTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + return success(); +} + +LogicalResult InitAlignOp::verify() { + return verifyAlignTypeLike(*this, getResult().getType(), "result type"); +} + +LogicalResult SprclrOp::verify() { + if (!isSupportedSprToken(getSpr())) + return emitOpError("requires spr to be \"AR\""); + if (failed(verifyNestedInVecScope(*this, "pto.sprclr"))) + return failure(); + return success(); +} + +void VldusOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldusOp::verify() { + if (failed(verifyLoadAlignChain(getAlign(), *this, "align type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type")) || + failed(verifyAlignTypeLike(*this, getUpdatedAlign().getType(), + "updated align type"))) + return failure(); + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + return success(); +} + +void UvldOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult UvldOp::verify() { + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a buffer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + auto sourceMemRef = dyn_cast(getSource().getType()); + if (!sourceMemRef) + return success(); + + Type sourceElementType = sourceMemRef.getElementType(); + Type vectorElementType = cast(getResult().getType()).getElementType(); + if (sourceElementType != vectorElementType) + return emitOpError( + "requires source element type to match vector element type"); + return success(); +} + +LogicalResult VdupOp::verify() { + auto resultType = dyn_cast(getResult().getType()); + if (!resultType) + return emitOpError("result must be !pto.vreg<...>"); + + std::optional granularity = + getVdupMaskGranularity(resultType.getElementType()); + if (!granularity) + return emitOpError("result element type must use b8, b16, or b32 mask granularity"); + if (failed(verifyMaskTypeWithGranularityLike( + getOperation(), getMask().getType(), "mask type", *granularity))) + return failure(); + + if (!isSupportedVdupPosition(getPosition())) + return emitOpError("position must be LOWEST or HIGHEST"); + + Type inputType = getInput().getType(); + if (auto inputVecType = dyn_cast(inputType)) { + if (inputVecType != resultType) + return emitOpError("vector input must match result vector type"); + return success(); + } + + if (getPosition()) + return emitOpError("position is only supported for vector input"); + + if (inputType != resultType.getElementType()) + return emitOpError("scalar input must match result element type"); + + return success(); +} + +LogicalResult PsetB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b8"))) + return failure(); + + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PsetB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b16"))) + return failure(); + + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PsetB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b32"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b8"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b16"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b32"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +template +static LogicalResult verifyPredicateLaneCountOp(PltOp op, + StringRef granularity) { + if (failed(verifyMaskTypeWithGranularityLike(op, op.getMask().getType(), + "mask type", granularity))) + return failure(); + Type scalarType = op.getScalar().getType(); + auto scalarIntType = dyn_cast(scalarType); + if (!scalarIntType || scalarIntType.getWidth() != 32) + return op.emitOpError("requires scalar to be i32"); + if (op.getScalarOut().getType() != scalarType) + return op.emitOpError("requires scalar_out to match scalar type"); + return success(); +} + +LogicalResult PltB8Op::verify() { return verifyPredicateLaneCountOp(*this, "b8"); } +LogicalResult PltB16Op::verify() { + return verifyPredicateLaneCountOp(*this, "b16"); +} +LogicalResult PltB32Op::verify() { + return verifyPredicateLaneCountOp(*this, "b32"); +} + +LogicalResult PpackOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getPart() != "LOWER") + return emitOpError("currently supports only LOWER part"); + return success(); +} + +LogicalResult PunpackOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getPart() != "LOWER") + return emitOpError("currently supports only LOWER part"); + return success(); +} + +LogicalResult PnotOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + return success(); +} + +LogicalResult PselOp::verify() { + if (failed(verifyMaskTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyMaskTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + return success(); +} + +template +static LogicalResult verifyBinaryMaskOp(BinaryMaskOp op) { + if (failed(verifyMaskTypeLike(op, op.getSrc0().getType(), "src0 type")) || + failed(verifyMaskTypeLike(op, op.getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + return success(); +} + +LogicalResult PandOp::verify() { return verifyBinaryMaskOp(*this); } +LogicalResult PorOp::verify() { return verifyBinaryMaskOp(*this); } +LogicalResult PxorOp::verify() { return verifyBinaryMaskOp(*this); } + +void PldsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult PldsOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedPredicateLoadDist(getDist())) + return emitOpError("requires predicate load dist to be NORM, US, or DS"); + if (failed(verifyEnclosingLoopLike(*this, "pto.plds"))) + return failure(); + return success(); +} + +void PldiOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult PldiOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!matchPattern(getOffset(), m_Constant())) + return emitOpError("requires offset to be a constant index immediate"); + if (!isSupportedPredicateLoadDist(getDist())) + return emitOpError("requires predicate load dist to be NORM, US, or DS"); + if (failed(verifyEnclosingLoopLike(*this, "pto.pldi"))) + return failure(); + return success(); +} + +template +static LogicalResult verifyElementwiseVecScalarOpLike(OpTy op) { + auto inputType = dyn_cast(op.getInput().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!inputType || !resultType) + return op.emitOpError("input and result must be !pto.vreg<...>"); + if (inputType != resultType) + return op.emitOpError("input and result vector types must match"); + + Type elemType = inputType.getElementType(); + Type scalarType = op.getScalar().getType(); + if (scalarType == elemType) + return success(); + + auto elemInt = dyn_cast(elemType); + auto scalarInt = dyn_cast(scalarType); + if (!elemInt || !scalarInt || elemInt.getWidth() != scalarInt.getWidth()) + return op.emitOpError("scalar type must match vector element type"); + + if (elemInt.isSigned() && (scalarInt.isSigned() || scalarInt.isSignless())) + return success(); + if (elemInt.isUnsigned() && + (scalarInt.isUnsigned() || scalarInt.isSignless())) + return success(); + if (elemInt.isSignless() && scalarInt.isSignless()) + return success(); + + return op.emitOpError( + "integer scalar type must match vector element width and use matching signedness or signless i"); +} + +template +static LogicalResult verifyVecScalarOpLike(OpTy op) { + if (failed(verifyElementwiseVecScalarOpLike(op))) + return failure(); + return success(); +} + +template +static LogicalResult verifySignedSaturatingVecScalarOpLike(OpTy op) { + if (failed(verifyElementwiseVecScalarOpLike(op))) + return failure(); + auto inputType = cast(op.getInput().getType()); + auto elemType = dyn_cast(inputType.getElementType()); + if (!elemType || elemType.isUnsigned() || elemType.getWidth() != 16) + return op.emitOpError("requires s16 vector element type"); + return success(); +} + +template +static LogicalResult verifyVecScalarMaskedOpLike(OpTy op) { + if (failed(verifyElementwiseVecScalarOpLike(op))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + return success(); +} + +template +static LogicalResult verifyCarryVecOp(CarryOp op) { + if (failed(verifyIntegerVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyIntegerVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type")) || + failed(verifyIntegerVRegTypeLike(op, op.getResult().getType(), + "result type")) || + failed(verifyMaskTypeLike(op, op.getCarry().getType(), "carry type"))) + return failure(); + + auto lhsType = cast(op.getLhs().getType()); + auto rhsType = cast(op.getRhs().getType()); + auto resultType = cast(op.getResult().getType()); + auto lhsElemType = cast(lhsType.getElementType()); + if (lhsType != rhsType || lhsType != resultType) + return op.emitOpError("requires lhs, rhs, and result to have matching vector types"); + if (lhsElemType.getWidth() != 32) + return op.emitOpError("currently requires 32-bit integer vector elements"); + return success(); +} + +template +static LogicalResult verifyCarryVecOpWithInput(CarryWithInputOp op) { + if (failed(verifyCarryVecOp(op)) || + failed(verifyMaskTypeLike(op, op.getCarryIn().getType(), + "carry_in type"))) + return failure(); + return success(); +} + +LogicalResult VmulsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VaddsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VsaddsOp::verify() { + if (failed(verifySignedSaturatingVecScalarOpLike(*this))) + return failure(); + return verifyMaskTypeLike(*this, getMask().getType(), "mask type"); +} +LogicalResult VmaxsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VminsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VlreluOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VshlsOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (inputType != resultType) + return emitOpError("input and result vector types must match"); + if (!isa(inputType.getElementType())) + return emitOpError("requires integer vector and integer scalar"); + auto scalarType = dyn_cast(getScalar().getType()); + if (!scalarType || !scalarType.isSignlessInteger(16)) + return emitOpError("requires signless i16 scalar"); + return success(); +} +LogicalResult VshrsOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (inputType != resultType) + return emitOpError("input and result vector types must match"); + if (!isa(inputType.getElementType())) + return emitOpError("requires integer vector and integer scalar"); + auto scalarType = dyn_cast(getScalar().getType()); + if (!scalarType || !scalarType.isSignlessInteger(16)) + return emitOpError("requires signless i16 scalar"); + return success(); +} + +LogicalResult VabsOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "operand type"))) + return failure(); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("requires matching register vector shape"); + return success(); +} + +template +static LogicalResult verifyUnaryVecOp(UnaryOp op) { + if (failed(verifyVRegTypeLike(op, op.getInput().getType(), "operand type"))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getInput().getType() != op.getResult().getType()) + return op.emitOpError("requires matching register vector shape"); + return success(); +} + +LogicalResult VexpOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VlnOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VsqrtOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VnegOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VrsqrtOp::verify() { + if (failed(verifyUnaryVecOp(*this))) + return failure(); + auto inputType = cast(getInput().getType()); + Type elemType = inputType.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 vector element type"); + return success(); +} +LogicalResult VrecOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VreluOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VnotOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VbcntOp::verify() { + if (failed(verifyUnaryVecOp(*this))) + return failure(); + auto inputType = cast(getInput().getType()); + if (!isa(inputType.getElementType())) + return emitOpError("requires integer vector element type"); + return success(); +} +LogicalResult VclsOp::verify() { + if (failed(verifyUnaryVecOp(*this))) + return failure(); + auto inputType = cast(getInput().getType()); + if (!isa(inputType.getElementType())) + return emitOpError("requires integer vector element type"); + return success(); +} + +template +static LogicalResult verifyBinaryVecOp(BinaryOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type"))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires matching register vector shapes"); + return success(); +} + +template +static LogicalResult verifySignedSaturatingBinaryVecOp(BinaryOp op) { + if (failed(verifyBinaryVecOp(op))) + return failure(); + auto lhsType = cast(op.getLhs().getType()); + auto elemType = dyn_cast(lhsType.getElementType()); + if (!elemType || elemType.isUnsigned() || elemType.getWidth() != 16) + return op.emitOpError("requires s16 vector element type"); + return success(); +} + +LogicalResult VaddOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VsubOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VsaddOp::verify() { + return verifySignedSaturatingBinaryVecOp(*this); +} +LogicalResult VssubOp::verify() { + return verifySignedSaturatingBinaryVecOp(*this); +} +LogicalResult VmulOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VdivOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VandOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VorOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VxorOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VshlOp::verify() { + if (failed(verifyBinaryVecOp(*this))) + return failure(); + auto lhsType = cast(getLhs().getType()); + if (!isa(lhsType.getElementType())) + return emitOpError("requires integer vector element type"); + return success(); +} +LogicalResult VshrOp::verify() { + if (failed(verifyBinaryVecOp(*this))) + return failure(); + auto lhsType = cast(getLhs().getType()); + if (!isa(lhsType.getElementType())) + return emitOpError("requires integer vector element type"); + return success(); +} +LogicalResult VaddcOp::verify() { return verifyCarryVecOp(*this); } +LogicalResult VsubcOp::verify() { return verifyCarryVecOp(*this); } +LogicalResult VaddcsOp::verify() { return verifyCarryVecOpWithInput(*this); } +LogicalResult VsubcsOp::verify() { return verifyCarryVecOpWithInput(*this); } + +template +static LogicalResult verifyReductionVecOp(ReductionOp op) { + return verifyUnaryVecOp(op); +} + +template +static LogicalResult verifyGroupReductionVecOp(ReductionOp op) { + if (failed(verifyReductionVecOp(op))) + return failure(); + auto inputType = cast(op.getInput().getType()); + Type elemType = inputType.getElementType(); + if (auto intType = dyn_cast(elemType)) { + if (intType.getWidth() < 16 || intType.getWidth() > 32) + return op.emitOpError( + "requires 16-bit or 32-bit integer vector element type"); + return success(); + } + if (!elemType.isF16() && !elemType.isF32()) + return op.emitOpError("requires i16/i32/f16/f32 vector element type"); + return success(); +} + +LogicalResult VcgaddOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcgmaxOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcgminOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcpaddOp::verify() { + if (failed(verifyReductionVecOp(*this))) + return failure(); + auto inputType = cast(getInput().getType()); + Type elemType = inputType.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 vector element type"); + return success(); +} + +template +static LogicalResult verifyLaneSelectOp(SelectOp op) { + if (failed(verifyVRegTypeLike(op, op.getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(op, op.getSrc1().getType(), "src1 type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + + auto src0Type = cast(op.getSrc0().getType()); + auto src1Type = cast(op.getSrc1().getType()); + auto resultType = cast(op.getResult().getType()); + if (src0Type != resultType) + return op.emitOpError("requires src0 and result to have identical vector types"); + if (src1Type.getElementCount() != src0Type.getElementCount()) + return op.emitOpError("requires src0/src1 to have identical element counts"); + auto src1ElemType = dyn_cast(src1Type.getElementType()); + if (!src1ElemType) + return op.emitOpError("requires src1 to use integer vector elements"); + if (src1ElemType.getWidth() != getIntOrFloatBitWidth(src0Type.getElementType())) + return op.emitOpError("requires src1 integer element width to match src0 element width"); + return success(); +} + +template +static LogicalResult verifyPairVecResults(PairOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getLow().getType(), "low result type")) || + failed(verifyVRegTypeLike(op, op.getHigh().getType(), "high result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getLow().getType() || + op.getLhs().getType() != op.getHigh().getType()) + return op.emitOpError("requires operands and results to share one vector type"); + return success(); +} + +template +static LogicalResult verifyPartVecOp(PartOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires operands and result to share one vector type"); + if (!isSupportedPartToken(op.getPart())) + return op.emitOpError("requires part to be LOWER or HIGHER"); + return success(); +} + +LogicalResult VselOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc0().getType() != getSrc1().getType() || + getSrc0().getType() != getResult().getType()) + return emitOpError("requires src0, src1, and result to have identical vector types"); + return success(); +} + +LogicalResult VselrOp::verify() { return verifyLaneSelectOp(*this); } +LogicalResult Vselrv2Op::verify() { return verifyLaneSelectOp(*this); } + +LogicalResult VslideOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc0().getType() != getSrc1().getType() || + getSrc0().getType() != getResult().getType()) + return emitOpError("requires src0, src1, and result to share one vector type"); + return success(); +} + +LogicalResult VsqzOp::verify() { return verifyUnaryVecOp(*this); } + +LogicalResult VusqzOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc().getType() != getResult().getType()) + return emitOpError("requires src and result to share one vector type"); + auto srcType = cast(getSrc().getType()); + auto elemType = dyn_cast(srcType.getElementType()); + if (!elemType) + return emitOpError("requires signed integer vector element type"); + if (elemType.isUnsigned()) + return emitOpError("requires signed integer vector element type"); + unsigned width = elemType.getWidth(); + if (width != 8 && width != 16 && width != 32) + return emitOpError("requires s8/s16/s32 vector element type"); + return success(); +} + +LogicalResult VpackOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!isSupportedPartToken(getPart())) + return emitOpError("requires part to be LOWER or HIGHER"); + auto srcType = cast(getSrc().getType()); + auto resultType = cast(getResult().getType()); + Type srcElemType = srcType.getElementType(); + Type resultElemType = resultType.getElementType(); + if (!isa(srcElemType) || !isa(resultElemType)) + return emitOpError("currently requires integer source and result element types"); + if (resultType.getElementCount() != srcType.getElementCount() * 2) + return emitOpError( + "requires result element count to be twice the source element count"); + unsigned srcWidth = getIntOrFloatBitWidth(srcElemType); + unsigned resultWidth = getIntOrFloatBitWidth(resultElemType); + if (!srcWidth || resultWidth * 2 != srcWidth) + return emitOpError( + "requires result element width to be half the source element width"); + auto srcIntType = cast(srcElemType); + auto resultIntType = cast(resultElemType); + if (!resultIntType.isUnsigned()) + return emitOpError("requires unsigned result element type"); + if (!((srcIntType.getWidth() == 32 && resultIntType.getWidth() == 16) || + (srcIntType.getWidth() == 16 && resultIntType.getWidth() == 8))) + return emitOpError( + "currently supports only s32/u32 -> u16 and s16/u16 -> u8"); + return success(); +} + +template +static LogicalResult verifyUnpackVecOp(UnpackOp op) { + if (failed(verifyVRegTypeLike(op, op.getSrc().getType(), "src type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + auto srcType = cast(op.getSrc().getType()); + auto resultType = cast(op.getResult().getType()); + Type srcElemType = srcType.getElementType(); + Type resultElemType = resultType.getElementType(); + if (!isa(srcElemType) || !isa(resultElemType)) + return op.emitOpError( + "currently requires integer source and result element types"); + if (srcType.getElementCount() != resultType.getElementCount() * 2) + return op.emitOpError( + "requires source element count to be twice the result element count"); + unsigned srcWidth = getIntOrFloatBitWidth(srcElemType); + unsigned resultWidth = getIntOrFloatBitWidth(resultElemType); + if (!srcWidth || srcWidth * 2 != resultWidth) + return op.emitOpError( + "requires result element width to be twice the source element width"); + return success(); +} + +LogicalResult VsunpackOp::verify() { return verifyUnpackVecOp(*this); } +LogicalResult VzunpackOp::verify() { return verifyUnpackVecOp(*this); } + +static bool isSupportedCmpMode(StringRef mode) { + return mode == "eq" || mode == "ne" || mode == "lt" || mode == "le" || + mode == "gt" || mode == "ge"; +} + +LogicalResult VcmpOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc0().getType() != getSrc1().getType()) + return emitOpError("requires src0 and src1 to have identical vector types"); + if (!isSupportedCmpMode(getCmpMode())) + return emitOpError("requires cmp_mode to be one of eq/ne/lt/le/gt/ge"); + return success(); +} + +LogicalResult VcmpsOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + auto srcType = cast(getSrc().getType()); + if (getScalar().getType() != srcType.getElementType()) + return emitOpError("requires scalar type to match source element type"); + if (!isSupportedCmpMode(getCmpMode())) + return emitOpError("requires cmp_mode to be one of eq/ne/lt/le/gt/ge"); + return success(); +} + +ParseResult VtrcOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand input; + std::string roundModeToken; + NamedAttrList attrs; + Type inputType, resultType; + + if (parser.parseOperand(input) || parser.parseComma() || + parser.parseKeywordOrString(&roundModeToken) || + parser.parseOptionalAttrDict(attrs) || + parser.parseColonType(inputType) || parser.parseArrow() || + parser.parseType(resultType)) + return failure(); + + auto normalized = normalizeRoundModeToken(roundModeToken); + if (!normalized) + return parser.emitError(parser.getCurrentLocation()) + << "round mode must be one of R/A/F/C/Z/O or " + "ROUND_R/ROUND_A/ROUND_F/ROUND_C/ROUND_Z/ROUND_O"; + + attrs.set("round_mode", parser.getBuilder().getStringAttr(*normalized)); + result.addAttributes(attrs); + if (parser.resolveOperand(input, inputType, result.operands)) + return failure(); + result.addTypes(resultType); + return success(); +} + +void VtrcOp::print(OpAsmPrinter &printer) { + printer << ' ' << getInput() << ", "; + Builder builder(getContext()); + auto normalized = normalizeRoundModeToken(getRoundMode()); + printer.printAttributeWithoutType( + builder.getStringAttr(normalized.value_or(getRoundMode()))); + printer.printOptionalAttrDict((*this)->getAttrs(), {"round_mode"}); + printer << " : " << getInput().getType() << " -> " << getResult().getType(); +} + +LogicalResult VtrcOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (inputType != resultType) + return emitOpError("requires input and result to have identical vreg type"); + if (!normalizeRoundModeToken(getRoundMode())) + return emitOpError("round mode must be one of R/A/F/C/Z/O"); + return success(); +} + +ParseResult VcvtOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand input; + NamedAttrList attrs; + Type inputType, resultType; + + if (parser.parseOperand(input) || parser.parseOptionalAttrDict(attrs) || + parser.parseColonType(inputType) || parser.parseArrow() || + parser.parseType(resultType)) + return failure(); + + Attribute legacyRndAttr = attrs.get("round_mode"); + Attribute rndAttr = attrs.get("rnd"); + if (legacyRndAttr && rndAttr) + return parser.emitError(parser.getCurrentLocation()) + << "rnd and round_mode cannot be specified together"; + + auto normalizeNamedStringAttr = + [&](StringRef sourceName, StringRef canonicalName, + auto normalizeFn) -> ParseResult { + Attribute rawAttr = attrs.get(sourceName); + if (!rawAttr) + return success(); + auto strAttr = dyn_cast(rawAttr); + if (!strAttr) + return parser.emitError(parser.getCurrentLocation()) + << sourceName << " must be a string literal"; + auto normalized = normalizeFn(strAttr.getValue()); + if (!normalized) + return parser.emitError(parser.getCurrentLocation()) + << sourceName << " has unsupported value '" << strAttr.getValue() + << "'"; + attrs.erase(sourceName); + attrs.set(canonicalName, parser.getBuilder().getStringAttr(*normalized)); + return success(); + }; + + if (failed(normalizeNamedStringAttr("round_mode", "rnd", + normalizeRoundModeToken)) || + failed(normalizeNamedStringAttr("rnd", "rnd", normalizeRoundModeToken)) || + failed(normalizeNamedStringAttr("sat", "sat", normalizeSaturationToken)) || + failed( + normalizeNamedStringAttr("part", "part", normalizeEvenOddPartToken))) + return failure(); + + result.addAttributes(attrs); + if (parser.resolveOperand(input, inputType, result.operands)) + return failure(); + result.addTypes(resultType); + return success(); +} + +void VcvtOp::print(OpAsmPrinter &printer) { + printer << ' ' << getInput(); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getInput().getType() << " -> " << getResult().getType(); +} + +LogicalResult VcvtOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + + VcvtElemKind inputElemKind = classifyVcvtElemType(inputType.getElementType()); + VcvtElemKind resultElemKind = classifyVcvtElemType(resultType.getElementType()); + auto contract = lookupVcvtContract(inputElemKind, resultElemKind); + if (!contract) + return emitOpError("unsupported vcvt source/result element type pair"); + + auto inputElemBits = getVcvtElemBitWidth(inputElemKind); + auto resultElemBits = getVcvtElemBitWidth(resultElemKind); + if (!inputElemBits || !resultElemBits) + return emitOpError("could not determine vcvt element bit width"); + if (inputType.getElementCount() * static_cast(*inputElemBits) != + resultType.getElementCount() * static_cast(*resultElemBits)) { + return emitOpError("requires source and result vectors to carry the same " + "total number of bits"); + } + + if (getRndAttr()) { + StringRef roundMode = *getRnd(); + if (!normalizeRoundModeToken(roundMode)) + return emitOpError("rnd must be one of R/A/F/C/Z/O"); + } + if (static_cast(getRndAttr()) != contract->requiresRnd) { + return contract->requiresRnd ? emitOpError("requires rnd attr for this vcvt type pair") + : emitOpError("rnd attr is not valid for this vcvt type pair"); + } + + if (getSatAttr()) { + StringRef sat = *getSat(); + if (!normalizeSaturationToken(sat)) + return emitOpError("sat must be SAT or NOSAT"); + } + if (static_cast(getSatAttr()) != contract->requiresSat) { + return contract->requiresSat ? emitOpError("requires sat attr for this vcvt type pair") + : emitOpError("sat attr is not valid for this vcvt type pair"); + } + + if (getPartAttr()) { + StringRef part = *getPart(); + if (!normalizeEvenOddPartToken(part)) + return emitOpError("part must be EVEN or ODD"); + } + if (static_cast(getPartAttr()) != contract->requiresPart) { + return contract->requiresPart ? emitOpError("requires part attr for this vcvt type pair") + : emitOpError("part attr is not valid for this vcvt type pair"); + } + + return success(); +} + +LogicalResult PdintlvB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b8"))) + return failure(); + return success(); +} + +LogicalResult PdintlvB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b16"))) + return failure(); + return success(); +} + +LogicalResult PdintlvB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b32"))) + return failure(); + return success(); +} + +LogicalResult PintlvB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b8"))) + return failure(); + return success(); +} + +LogicalResult PintlvB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b16"))) + return failure(); + return success(); +} + +LogicalResult PintlvB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b32"))) + return failure(); + return success(); +} + +LogicalResult VintlvOp::verify() { return verifyPairVecResults(*this); } +LogicalResult VdintlvOp::verify() { return verifyPairVecResults(*this); } +LogicalResult Vintlvv2Op::verify() { return verifyPartVecOp(*this); } +LogicalResult Vdintlvv2Op::verify() { return verifyPartVecOp(*this); } + +LogicalResult VmullOp::verify() { + if (failed(verifyPairVecResults(*this)) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + auto lhsType = cast(getLhs().getType()); + auto lhsElemType = dyn_cast(lhsType.getElementType()); + if (!lhsElemType) + return emitOpError("requires integer vector element type"); + if (lhsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit integer vector elements"); + return success(); +} + +LogicalResult VmulaOp::verify() { + if (failed(verifyVRegTypeLike(*this, getAcc().getType(), "acc type")) || + failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getAcc().getType() != getLhs().getType() || + getAcc().getType() != getRhs().getType() || + getAcc().getType() != getResult().getType()) + return emitOpError("requires acc, lhs, rhs, and result to share one vector type"); + return success(); +} + +template +static LogicalResult verifyBinaryVecNoMaskOp(BinaryVecNoMaskOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires lhs, rhs, and result to share one vector type"); + return success(); +} + +template +static LogicalResult verifyFloatBinaryVecNoMaskOp(BinaryVecNoMaskOp op) { + if (failed(verifyBinaryVecNoMaskOp(op))) + return failure(); + auto lhsType = cast(op.getLhs().getType()); + Type elemType = lhsType.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return op.emitOpError("requires f16 or f32 vector element type"); + return success(); +} + +LogicalResult VpreluOp::verify() { return verifyFloatBinaryVecNoMaskOp(*this); } +LogicalResult VexpdiffOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyVRegTypeLike(*this, getMax().getType(), "max type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + + auto inputType = cast(getInput().getType()); + auto maxType = cast(getMax().getType()); + auto resultType = cast(getResult().getType()); + if (inputType != maxType) + return emitOpError("requires input and max to share one vector type"); + + Type inputElemType = inputType.getElementType(); + if (!inputElemType.isF16() && !inputElemType.isF32()) + return emitOpError("requires f16 or f32 input vector element type"); + if (!resultType.getElementType().isF32()) + return emitOpError("requires f32 result vector element type"); + + auto inputBits = getVRegStorageBitWidth(inputType); + auto resultBits = getVRegStorageBitWidth(resultType); + if (!inputBits || !resultBits || *inputBits != *resultBits) + return emitOpError( + "requires source and result to preserve total vector storage width"); + + StringRef part = getPart(); + if (part != "EVEN" && part != "ODD") + return emitOpError("part must be EVEN or ODD"); + return success(); +} + +LogicalResult VaxpyOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + auto src0Type = cast(getSrc0().getType()); + auto src1Type = cast(getSrc1().getType()); + auto resultType = cast(getResult().getType()); + if (src0Type != src1Type || src0Type != resultType) + return emitOpError("requires src0, src1, and result to share one vector type"); + Type elemType = src0Type.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 vector element type"); + if (getAlpha().getType() != elemType) + return emitOpError("requires alpha type to match vector element type"); + return success(); +} + +template +static LogicalResult verifyFusedConvVecOp(ConvOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + auto lhsType = cast(op.getLhs().getType()); + auto rhsType = cast(op.getRhs().getType()); + auto resultType = cast(op.getResult().getType()); + if (lhsType != rhsType) + return op.emitOpError("requires lhs and rhs to share one vector type"); + if (!isIntegerOrFloatLike(lhsType.getElementType()) || + !isIntegerOrFloatLike(resultType.getElementType())) + return op.emitOpError( + "requires integer or floating-point vector element types"); + auto lhsBits = getVRegStorageBitWidth(lhsType); + auto resultBits = getVRegStorageBitWidth(resultType); + if (!lhsBits || !resultBits || *lhsBits != *resultBits) + return op.emitOpError( + "requires source and result to preserve total vector storage width"); + return success(); +} + +LogicalResult VaddreluconvOp::verify() { + return verifyFusedConvVecOp(*this); +} +LogicalResult VmulconvOp::verify() { return verifyFusedConvVecOp(*this); } + +void Vldsx2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vldsx2Op::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (failed(verifyVRegTypeLike(*this, getLow().getType(), "low result type")) || + failed(verifyVRegTypeLike(*this, getHigh().getType(), "high result type"))) + return failure(); + if (getLow().getType() != getHigh().getType()) + return emitOpError("requires low/high results to share one vector type"); + if (!isSupportedVldx2DistToken(getDist())) + return emitOpError("requires a supported x2 load distribution token"); + if (failed(verifyVldsx2DistWidth( + getOperation(), getDist(), + cast(getLow().getType()).getElementType()))) + return failure(); + return success(); +} + +void VstsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +template +static LogicalResult verifyVstsCommon(StoreOp op) { + if (failed(verifyVRegTypeLike(op, op.getValue().getType(), "value type"))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + + if (!isBufferLike(op.getDestination().getType())) + return op.emitOpError("requires a pointer-like destination"); + + MemoryRole destinationRole = classifyMemoryRole(op.getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return op.emitOpError("requires a UB-backed destination"); + + if (std::optional dist = op.getDist(); + dist && !isSupportedVstsDistToken(*dist)) { + return op.emitOpError("requires a supported store distribution token"); + } + if (std::optional dist = op.getDist(); + dist && + failed(verifyVstsDistWidth( + op.getOperation(), *dist, + cast(op.getValue().getType()).getElementType()))) + return failure(); + + return success(); +} + +LogicalResult VstsOp::verify() { + if (failed(verifyVstsCommon(*this))) + return failure(); + if (std::optional mode = getOptionalPostModeAttr(getOperation()); + mode && !isSupportedPostMode(*mode)) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} +void VstsPostOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstsPostOp::verify() { + if (failed(verifyVstsCommon(*this))) + return failure(); + if (getUpdatedDestination().getType() != getDestination().getType()) + return emitOpError( + "requires updated destination result to match destination type"); + return success(); +} + +void Vstsx2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLowMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getHighMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult Vstsx2Op::verify() { + if (failed(verifyVRegTypeLike(*this, getLow().getType(), "low value type")) || + failed(verifyVRegTypeLike(*this, getHigh().getType(), "high value type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (getLow().getType() != getHigh().getType()) + return emitOpError("requires low/high values to share one vector type"); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedVstsx2DistToken(getDist())) + return emitOpError("requires a supported x2 store distribution token"); + if (failed(verifyVstsx2DistWidth( + getOperation(), getDist(), + cast(getLow().getType()).getElementType()))) + return failure(); + return success(); +} + +void VscatterOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VscatterOp::verify() { + if (failed(verifyVRegTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + auto offsetsType = dyn_cast(getOffsets().getType()); + auto valueType = dyn_cast(getValue().getType()); + if (!offsetsType || !valueType) + return emitOpError("value and offsets must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != valueType.getElementCount()) + return emitOpError("offset and value vectors must have the same element count"); + MemoryRole destinationRole = classifyMemoryRole(getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getActiveLanes().getType().isIndex()) + return emitOpError("active_lanes must be index"); + return success(); +} + +void VsldbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VsldbOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!getBlockStride().getType().isSignlessInteger(16)) + return emitOpError("requires block_stride to be i16"); + if (!getRepeatStride().getType().isSignlessInteger(16)) + return emitOpError("requires repeat_stride to be i16"); + return success(); +} + +void PstsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void PstiOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult PstiOp::verify() { + if (failed(verifyMaskTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!matchPattern(getOffset(), m_Constant())) + return emitOpError("requires offset to be a constant index immediate"); + if (!isSupportedPredicateStoreDist(getDist())) + return emitOpError("requires predicate store dist to be NORM or PK"); + return success(); +} + +LogicalResult PstsOp::verify() { + if (failed(verifyMaskTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + MemoryRole destinationRole = classifyMemoryRole(getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedPredicateStoreDist(getDist())) + return emitOpError("requires predicate store dist to be NORM or PK"); + return success(); +} + +void VsstbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VsstbOp::verify() { + if (failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getBlockStride().getType().isSignlessInteger(16)) + return emitOpError("requires block_stride to be i16"); + if (!getRepeatStride().getType().isSignlessInteger(16)) + return emitOpError("requires repeat_stride to be i16"); + return success(); +} + +void VstasOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstasOp::verify() { + if (failed(verifyStoreAlignChain(getValue(), *this, "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + return success(); +} + +void VstarOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstarOp::verify() { + if (failed(verifyStoreAlignChain(getValue(), *this, "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + return success(); +} + +void PstuOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult PstuOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyMaskTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType()) || !isBufferLike(getBaseOut().getType())) + return emitOpError("requires pointer-like base and base_out"); + if (getBase().getType() != getBaseOut().getType()) + return emitOpError("requires base and base_out to have identical types"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + auto baseType = cast(getBase().getType()); + auto maskType = cast(getValue().getType()); + auto elemType = dyn_cast(baseType.getElementType()); + if (!elemType || elemType.isSigned() || (elemType.getWidth() != 16 && elemType.getWidth() != 32)) + return emitOpError("requires ui16/ui32 UB base type"); + if (maskType.isB16() && elemType.getWidth() != 16) + return emitOpError("requires !pto.mask to pair with !pto.ptr"); + if (maskType.isB32() && elemType.getWidth() != 32) + return emitOpError("requires !pto.mask to pair with !pto.ptr"); + return success(); +} + +void VstusOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult VstusOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType())) + return emitOpError("requires a pointer-like base"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + return success(); +} + +void VsturOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult VsturOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType())) + return emitOpError("requires a pointer-like base"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + if (!isSupportedPostMode(getMode())) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} + +void CopyUbufToGmOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult CopyUbufToGmOp::verify() { + return verifyCopyUbufToGmOp(*this, false); +} diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 6979ad706..17ddd92df 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -12,6 +12,14 @@ # See LICENSE in the root of the software repository for the full text of the License. add_mlir_dialect_library(PTOTransforms + HIVMIntrinsicNaming.cpp + VPTOLLVMEmitter.cpp + PTOVPTOExpandBridgeOps.cpp + PTOVPTOPtrBoundary.cpp + PTOToVPTO.cpp + PTOToVPTOLowering.cpp + PTOValidateVPTOIR.cpp + InsertSync/PTOInsertSync.cpp InsertSync/InsertSyncDebug.cpp PTOViewToMemref.cpp @@ -49,6 +57,9 @@ add_mlir_dialect_library(PTOTransforms PTOPassesIncGen PTOOpsIncGen + LINK_COMPONENTS + Analysis + LINK_LIBS PUBLIC PTOIR MLIRIR @@ -62,7 +73,12 @@ add_mlir_dialect_library(PTOTransforms MLIRTransformUtils MLIRTransforms MLIRTensorDialect + MLIRSCFDialect MLIRSCFToEmitC + MLIRSCFToControlFlow + MLIRConvertToLLVMPass + MLIRTargetLLVMIRExport + MLIRToLLVMIRTranslationRegistration ) install(TARGETS PTOTransforms diff --git a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp new file mode 100644 index 000000000..d87cb6867 --- /dev/null +++ b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp @@ -0,0 +1,561 @@ +//===- HIVMIntrinsicNaming.cpp - HIVM intrinsic selection -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/HIVMIntrinsicNaming.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +using namespace mlir; + +namespace mlir::pto { +namespace { + +static std::string getLocationString(Location loc) { + std::string storage; + llvm::raw_string_ostream os(storage); + loc.print(os); + return storage; +} + +static std::string sanitizeNameFragment(llvm::StringRef text) { + std::string out; + out.reserve(text.size()); + for (char c : text) { + if (std::isalnum(static_cast(c)) || c == '.' || c == '_') + out.push_back(c); + else + out.push_back('_'); + } + return out; +} + +static std::string printAttrText(Attribute attr) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << attr; + return storage; +} + +static std::string getElementTypeFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); + return "unknown"; +} + +static std::string getVectorTypeFragment(Type type) { + auto vecType = dyn_cast(type); + if (!vecType) + return {}; + return ("v" + std::to_string(vecType.getElementCount()) + + getElementTypeFragment(vecType.getElementType())); +} + +static std::string getCopyElementFragment(Type type) { + auto ptrType = dyn_cast(type); + if (!ptrType) + return {}; + Type elementType = ptrType.getElementType(); + if (auto floatType = dyn_cast(elementType)) { + switch ((floatType.getWidth() + 7) / 8) { + case 1: + return "u8"; + case 2: + return "u16"; + case 4: + case 8: + return "u32"; + default: + return {}; + } + } + if (auto intType = dyn_cast(elementType)) { + switch ((intType.getWidth() + 7) / 8) { + case 1: + return "u8"; + case 2: + return "u16"; + case 4: + case 8: + return "u32"; + default: + return {}; + } + } + return {}; +} + +static std::string getOpMnemonic(Operation *op) { + return op->getName().stripDialect().str(); +} + +static IntrinsicSelection makeResolved(Operation *op, llvm::StringRef calleeName, + llvm::ArrayRef usedFields, + llvm::StringRef resultTypeFragment) { + IntrinsicSelection selection; + selection.resolved = true; + selection.sourceOpName = op->getName().getStringRef().str(); + selection.calleeName = calleeName.str(); + selection.usedFields.assign(usedFields.begin(), usedFields.end()); + selection.resultTypeFragment = resultTypeFragment.str(); + selection.location = getLocationString(op->getLoc()); + return selection; +} + +static IntrinsicSelection makeUnresolved(Operation *op, + llvm::StringRef familyOrOp, + llvm::StringRef candidateName, + llvm::ArrayRef usedFields, + llvm::ArrayRef missingFields, + llvm::StringRef resultTypeFragment) { + IntrinsicSelection selection; + selection.resolved = false; + selection.sourceOpName = op->getName().getStringRef().str(); + selection.candidateName = candidateName.str(); + selection.usedFields.assign(usedFields.begin(), usedFields.end()); + selection.missingFields.assign(missingFields.begin(), missingFields.end()); + selection.resultTypeFragment = resultTypeFragment.str(); + selection.location = getLocationString(op->getLoc()); + + std::string name = "__ptoas_hivm_unresolved."; + name += sanitizeNameFragment(familyOrOp); + if (!resultTypeFragment.empty()) { + name += "."; + name += sanitizeNameFragment(resultTypeFragment); + } + selection.placeholderName = std::move(name); + return selection; +} + +static FailureOr selectSyncLike(Operation *op) { + llvm::SmallVector usedFields; + usedFields.push_back("op=" + getOpMnemonic(op)); + + if (auto setFlag = dyn_cast(op)) { + usedFields.push_back("src_pipe=" + printAttrText(setFlag.getSrcPipe())); + usedFields.push_back("dst_pipe=" + printAttrText(setFlag.getDstPipe())); + usedFields.push_back("event=" + printAttrText(setFlag.getEventId())); + return makeResolved(op, "llvm.hivm.SET.FLAG.IMM", usedFields, ""); + } else if (auto waitFlag = dyn_cast(op)) { + usedFields.push_back("src_pipe=" + printAttrText(waitFlag.getSrcPipe())); + usedFields.push_back("dst_pipe=" + printAttrText(waitFlag.getDstPipe())); + usedFields.push_back("event=" + printAttrText(waitFlag.getEventId())); + return makeResolved(op, "llvm.hivm.WAIT.FLAG.IMM", usedFields, ""); + } else if (auto barrier = dyn_cast(op)) { + usedFields.push_back("pipe=" + printAttrText(barrier.getPipe())); + return makeResolved(op, "llvm.hivm.BARRIER", usedFields, ""); + } + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, ""); +} + +static FailureOr selectConfigLike(Operation *op) { + llvm::SmallVector usedFields = {"op=" + getOpMnemonic(op)}; + + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB", usedFields, + ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB", + usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.OUTTOUB", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT", usedFields, + ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT", usedFields, + ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT", usedFields, ""); + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, + ""); +} + +static FailureOr selectPredicateIntrinsic(Operation *op) { + llvm::SmallVector usedFields; + if (auto pset = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pset.getResult().getType()); + usedFields = {"family=pset", "bitwidth=8", "result=" + resultFragment, + "pattern=i32"}; + return makeResolved(op, "llvm.hivm.pset.b8", usedFields, resultFragment); + } + if (auto pset = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pset.getResult().getType()); + usedFields = {"family=pset", "bitwidth=16", "result=" + resultFragment, + "pattern=i32"}; + return makeResolved(op, "llvm.hivm.pset.b16", usedFields, resultFragment); + } + if (auto pset = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pset.getResult().getType()); + usedFields = {"family=pset", "bitwidth=32", "result=" + resultFragment, + "pattern=i32"}; + return makeResolved(op, "llvm.hivm.pset.b32", usedFields, resultFragment); + } + if (auto pge = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pge.getResult().getType()); + usedFields = {"family=pge", "bitwidth=8", "result=" + resultFragment, + "pattern=i32", "variant=i32_zero"}; + return makeResolved(op, "llvm.hivm.pge.b8", usedFields, resultFragment); + } + if (auto pge = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pge.getResult().getType()); + usedFields = {"family=pge", "bitwidth=16", "result=" + resultFragment, + "pattern=i32", "variant=i32_zero"}; + return makeResolved(op, "llvm.hivm.pge.b16", usedFields, resultFragment); + } + if (auto pge = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pge.getResult().getType()); + usedFields = {"family=pge", "bitwidth=32", "result=" + resultFragment, + "pattern=i32", "variant=i32_zero"}; + return makeResolved(op, "llvm.hivm.pge.b32", usedFields, resultFragment); + } + if (auto plt = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(plt.getMask().getType()); + usedFields = {"family=plt", "bitwidth=8", "result=" + resultFragment, + "variant=v300", "scalar=i32", "scalar_out=i32"}; + return makeResolved(op, "llvm.hivm.plt.b8.v300", usedFields, resultFragment); + } + if (auto plt = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(plt.getMask().getType()); + usedFields = {"family=plt", "bitwidth=16", "result=" + resultFragment, + "variant=v300", "scalar=i32", "scalar_out=i32"}; + return makeResolved(op, "llvm.hivm.plt.b16.v300", usedFields, resultFragment); + } + if (auto plt = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(plt.getMask().getType()); + usedFields = {"family=plt", "bitwidth=32", "result=" + resultFragment, + "variant=v300", "scalar=i32", "scalar_out=i32"}; + return makeResolved(op, "llvm.hivm.plt.b32.v300", usedFields, resultFragment); + } + + return failure(); +} + +} // namespace + +FailureOr selectLoadIntrinsic(Operation *op) { + if (auto vlds = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(vlds.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vldsx1", "vector=" + vecFragment, "mode=NO_POST_UPDATE"}; + if (vlds.getDistAttr()) + usedFields.push_back("dist=" + (*vlds.getDist()).str()); + + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vldsx1", usedFields, vecFragment); + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + std::string candidate = "llvm.hivm.vldsx1"; + return makeUnresolved(op, "vldsx1", candidate, usedFields, missingFields, + vecFragment); + } + + if (auto vldsPost = dyn_cast(op)) { + const std::string vecFragment = + getVectorTypeFragment(vldsPost.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vldsx1", "variant=post", "vector=" + vecFragment, + "mode=POST_UPDATE"}; + if (vldsPost.getDistAttr()) + usedFields.push_back("dist=" + (*vldsPost.getDist()).str()); + + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vldsx1.post", usedFields, vecFragment); + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + std::string candidate = "llvm.hivm.vldsx1.post"; + return makeUnresolved(op, "vldsx1.post", candidate, usedFields, + missingFields, vecFragment); + } + + return failure(); +} + +FailureOr selectUnaryIntrinsic(Operation *op) { + auto vabs = dyn_cast(op); + if (vabs) { + const std::string vecFragment = getVectorTypeFragment(vabs.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vabs", "vector=" + vecFragment, "variant=x"}; + + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vabs.v64f32.x", usedFields, vecFragment); + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + std::string candidate = "llvm.hivm.vabs"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeUnresolved(op, "vabs", candidate, usedFields, missingFields, + vecFragment); + } + + if (auto vexp = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(vexp.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vexp", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vexp"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto vdup = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(vdup.getResult().getType()); + const bool vectorInput = isa(vdup.getInput().getType()); + const StringRef position = vdup.getPosition().value_or("LOWEST"); + const char *family = + vectorInput ? (position == "HIGHEST" ? "vdupm" : "vdup") : "vdups"; + llvm::SmallVector usedFields = { + "family=" + std::string(family), "vector=" + vecFragment, + "variant=z"}; + if (!vectorInput && !isa(vdup.getInput().getType())) { + llvm::SmallVector missingFields = {"scalar_input_vdup_mapping"}; + return makeUnresolved(op, "vdup", "llvm.hivm.vdups", usedFields, missingFields, + vecFragment); + } + std::string candidate = "llvm.hivm."; + candidate += family; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".z"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vadd", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vadd"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vsub", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vsub"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmul", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmul"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmax", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmax"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmuls", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmuls"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vadds", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vadds"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmaxs", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmaxs"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmins", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmins"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vlrelu", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vlrelu"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vshls", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vshls"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vshrs", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vshrs"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + return failure(); +} + +FailureOr selectStoreIntrinsic(Operation *op) { + llvm::SmallVector usedFields; + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + + if (auto vsts = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(vsts.getValue().getType()); + usedFields = {"family=vstsx1", "vector=" + vecFragment, + "predicate_source=explicit_mask", "mode=NO_POST_UPDATE"}; + if (vsts.getDistAttr()) + usedFields.push_back("dist=" + (*vsts.getDist()).str()); + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vstsx1", usedFields, vecFragment); + return makeUnresolved(op, "vstsx1", "llvm.hivm.vstsx1", usedFields, missingFields, + vecFragment); + } + + if (auto vstsPost = dyn_cast(op)) { + const std::string vecFragment = + getVectorTypeFragment(vstsPost.getValue().getType()); + usedFields = {"family=vstsx1", "variant=post", "vector=" + vecFragment, + "predicate_source=explicit_mask", "mode=POST_UPDATE"}; + if (vstsPost.getDistAttr()) + usedFields.push_back("dist=" + (*vstsPost.getDist()).str()); + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vstsx1.post", usedFields, + vecFragment); + std::string candidate = "llvm.hivm.vstsx1.post"; + return makeUnresolved(op, "vstsx1.post", candidate, usedFields, + missingFields, vecFragment); + } + + if (auto copy = dyn_cast(op)) { + std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); + usedFields = {"family=copy_gm_to_ubuf"}; + if (!elemFragment.empty()) + usedFields.push_back("element=" + elemFragment); + if (elemFragment == "u8" || elemFragment == "u16" || + elemFragment == "u32" || elemFragment == "f32") { + std::string callee = "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2."; + callee += elemFragment; + callee += ".DV"; + return makeResolved(op, callee, usedFields, ""); + } + std::string candidate = "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2"; + if (!elemFragment.empty()) + candidate += "." + elemFragment + ".DV"; + missingFields.push_back("element_type_mapping"); + return makeUnresolved(op, "copy_gm_to_ubuf", candidate, usedFields, + missingFields, ""); + } + + if (auto copy = dyn_cast(op)) { + std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); + usedFields = {"family=copy_ubuf_to_gm"}; + if (!elemFragment.empty()) + usedFields.push_back("element=" + elemFragment); + return makeResolved(op, "llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV", + usedFields, ""); + } + + if (isa(op)) { + usedFields = {"family=copy_ubuf_to_ubuf"}; + return makeUnresolved(op, "copy_ubuf_to_ubuf", "copy_ubuf_to_ubuf", + usedFields, missingFields, ""); + } + + return failure(); +} + +FailureOr selectIntrinsic(Operation *op) { + if (isa(op)) + return selectSyncLike(op); + + if (isa(op)) + return selectConfigLike(op); + + if (succeeded(selectLoadIntrinsic(op))) + return *selectLoadIntrinsic(op); + if (succeeded(selectUnaryIntrinsic(op))) + return *selectUnaryIntrinsic(op); + if (succeeded(selectPredicateIntrinsic(op))) + return *selectPredicateIntrinsic(op); + if (succeeded(selectStoreIntrinsic(op))) + return *selectStoreIntrinsic(op); + + llvm::SmallVector usedFields = {"op=" + getOpMnemonic(op)}; + llvm::SmallVector missingFields = {"family_mapping", + "confirmed_hivm_name"}; + return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, + ""); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToVPTO.cpp b/lib/PTO/Transforms/PTOToVPTO.cpp new file mode 100644 index 000000000..1661e1fcf --- /dev/null +++ b/lib/PTO/Transforms/PTOToVPTO.cpp @@ -0,0 +1,604 @@ +//===- PTOToVPTO.cpp - PTO to VPTO pass wiring ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VPTOLowering.h" +#include "PTO/Transforms/Passes.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { + +#define GEN_PASS_DEF_PTOTOVPTO +#include "PTO/Transforms/Passes.h.inc" + +namespace { + + +FailureOr +parseVPTOLoweringStrategy(StringRef strategyName) { + if (strategyName == "post-update") + return VPTOLoweringStrategy::PostUpdate; + if (strategyName == "no-post-update") + return VPTOLoweringStrategy::NoPostUpdate; + return failure(); +} + +LogicalResult lowerTLOADOp(TLoadOp op, PatternRewriter &rewriter) { + return lowerTLOAD(op, rewriter); +} + +LogicalResult lowerTABSOp(TAbsOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTABS(op, rewriter, strategy); +} + +LogicalResult lowerTADDOp(TAddOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTADD(op, rewriter, strategy); +} + +LogicalResult lowerTSUBOp(TSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTSUB(op, rewriter, strategy); +} + +LogicalResult lowerTMULOp(TMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMUL(op, rewriter, strategy); +} + +LogicalResult lowerTDIVOp(TDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTDIV(op, rewriter, strategy); +} + +LogicalResult lowerTMAXOp(TMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMAX(op, rewriter, strategy); +} + +LogicalResult lowerTMINOp(TMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMIN(op, rewriter, strategy); +} + +LogicalResult lowerTANDOp(TAndOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTAND(op, rewriter, strategy); +} + +LogicalResult lowerTANDSOp(TAndSOp op, PatternRewriter &rewriter) { + return lowerTANDS(op, rewriter); +} + +LogicalResult lowerTOROp(TOrOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTOR(op, rewriter, strategy); +} + +LogicalResult lowerTORSOp(TOrSOp op, PatternRewriter &rewriter) { + return lowerTORS(op, rewriter); +} + +LogicalResult lowerTXOROp(TXorOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTXOR(op, rewriter, strategy); +} + +LogicalResult lowerTXORSOp(TXorSOp op, PatternRewriter &rewriter) { + return lowerTXORS(op, rewriter); +} + +LogicalResult lowerTEXPOp(TExpOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTEXP(op, rewriter, strategy); +} + +LogicalResult lowerTLOGOp(TLogOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTLOG(op, rewriter, strategy); +} + +LogicalResult lowerTSQRTOp(TSqrtOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTSQRT(op, rewriter, strategy); +} + +LogicalResult lowerTRSQRTOp(TRsqrtOp op, PatternRewriter &rewriter) { + return lowerTRSQRT(op, rewriter); +} + +LogicalResult lowerTRECIPOp(TRecipOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRECIP(op, rewriter, strategy); +} + +LogicalResult lowerTNEGOp(TNegOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTNEG(op, rewriter, strategy); +} + +LogicalResult lowerTLRELUOp(TLReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTLRELU(op, rewriter, strategy); +} + +LogicalResult lowerTCIOp(TCIOp op, PatternRewriter &rewriter) { + return lowerTCI(op, rewriter); +} + +LogicalResult lowerTCVTOp(TCvtOp op, PatternRewriter &rewriter) { + return lowerTCVT(op, rewriter); +} + +LogicalResult lowerTCmpOp(TCmpOp op, PatternRewriter &rewriter) { + return lowerTCmp(op, rewriter); +} + +LogicalResult lowerTCmpSOp(TCmpSOp op, PatternRewriter &rewriter) { + return lowerTCmpS(op, rewriter); +} + +LogicalResult lowerTSelOp(TSelOp op, PatternRewriter &rewriter) { + return lowerTSel(op, rewriter); +} + +LogicalResult lowerTAddCOp(TAddCOp op, PatternRewriter &rewriter) { + return lowerTAddC(op, rewriter); +} + +LogicalResult lowerTAddSOp(TAddSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTAddS(op, rewriter, strategy); +} + +LogicalResult lowerTAddSCOp(TAddSCOp op, PatternRewriter &rewriter) { + return lowerTAddSC(op, rewriter); +} + +LogicalResult lowerTMinSOp(TMinSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMinS(op, rewriter, strategy); +} + +LogicalResult lowerTSubCOp(TSubCOp op, PatternRewriter &rewriter) { + return lowerTSubC(op, rewriter); +} + +LogicalResult lowerTSubSOp(TSubSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTSubS(op, rewriter, strategy); +} + +LogicalResult lowerTSubSCOp(TSubSCOp op, PatternRewriter &rewriter) { + return lowerTSubSC(op, rewriter); +} + +LogicalResult lowerTMaxSOp(TMaxSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMaxS(op, rewriter, strategy); +} + +LogicalResult lowerTDivSOp(TDivSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTDivS(op, rewriter, strategy); +} + +LogicalResult lowerTMulSOp(TMulSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMulS(op, rewriter, strategy); +} + +LogicalResult lowerTSelSOp(TSelSOp op, PatternRewriter &rewriter) { + return lowerTSelS(op, rewriter); +} + +LogicalResult lowerTRELUOp(TReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRELU(op, rewriter, strategy); +} + +LogicalResult lowerTNOTOp(TNotOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTNOT(op, rewriter, strategy); +} + +LogicalResult lowerTTRANSOp(TTransOp op, PatternRewriter &rewriter) { + return lowerTTRANS(op, rewriter); +} + +LogicalResult lowerTFILLPADOp(TFillPadOp op, PatternRewriter &rewriter) { + return lowerTFILLPAD(op, rewriter); +} + +LogicalResult lowerTFILLPADExpandOp(TFillPadExpandOp op, PatternRewriter &rewriter) { + return lowerTFILLPADExpand(op, rewriter); +} + +LogicalResult lowerTRowMaxOp(TRowMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowMax(op, rewriter, strategy); +} + +LogicalResult lowerTRowMinOp(TRowMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowMin(op, rewriter, strategy); +} + +LogicalResult lowerTRowSumOp(TRowSumOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowSum(op, rewriter, strategy); +} + +LogicalResult lowerTColMaxOp(TColMaxOp op, PatternRewriter &rewriter) { + return lowerTColMax(op, rewriter); +} + +LogicalResult lowerTColMinOp(TColMinOp op, PatternRewriter &rewriter) { + return lowerTColMin(op, rewriter); +} + +LogicalResult lowerTColSumOp(TColSumOp op, PatternRewriter &rewriter) { + return lowerTColSum(op, rewriter); +} + +LogicalResult lowerTRowExpandOp(TRowExpandOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpand(op, rewriter, strategy); +} + +LogicalResult lowerTColExpandOp(TColExpandOp op, PatternRewriter &rewriter) { + return lowerTColExpand(op, rewriter); +} + +LogicalResult lowerTRowExpandMulOp(TRowExpandMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandMul(op, rewriter, strategy); +} + +LogicalResult lowerTRowExpandDivOp(TRowExpandDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandDiv(op, rewriter, strategy); +} + +LogicalResult lowerTRowExpandSubOp(TRowExpandSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandSub(op, rewriter, strategy); +} + +LogicalResult lowerTPartAddOp(TPartAddOp op, PatternRewriter &rewriter) { + return lowerTPartAdd(op, rewriter); +} + +LogicalResult lowerTPartMaxOp(TPartMaxOp op, PatternRewriter &rewriter) { + return lowerTPartMax(op, rewriter); +} + +LogicalResult lowerTPartMinOp(TPartMinOp op, PatternRewriter &rewriter) { + return lowerTPartMin(op, rewriter); +} + +LogicalResult lowerTExpandSOp(TExpandsOp op, PatternRewriter &rewriter) { + return lowerTExpandS(op, rewriter); +} + +LogicalResult lowerTGatherOp(TGatherOp op, PatternRewriter &rewriter) { + return lowerTGather(op, rewriter); +} + +LogicalResult lowerTGatherBOp(TGatherBOp op, PatternRewriter &rewriter) { + return lowerTGatherB(op, rewriter); +} + +LogicalResult lowerTScatterOp(TScatterOp op, PatternRewriter &rewriter) { + return lowerTScatter(op, rewriter); +} + +LogicalResult lowerTMrgSortOp(TMrgSortOp op, PatternRewriter &rewriter) { + return lowerTMrgSort(op, rewriter); +} + +LogicalResult lowerTSort32Op(TSort32Op op, PatternRewriter &rewriter) { + return lowerTSort32(op, rewriter); +} + +LogicalResult lowerTSTOREOp(TStoreOp op, PatternRewriter &rewriter) { + return lowerTSTORE(op, rewriter); +} + +LogicalResult lowerSetFlagOp(SetFlagOp op, PatternRewriter &rewriter) { + return lowerSetFlag(op, rewriter); +} + +LogicalResult lowerWaitFlagOp(WaitFlagOp op, PatternRewriter &rewriter) { + return lowerWaitFlag(op, rewriter); +} + +LogicalResult lowerBarrierOp(BarrierOp op, PatternRewriter &rewriter) { + return lowerBarrier(op, rewriter); +} + +LogicalResult lowerGetBufOp(GetBufOp op, PatternRewriter &rewriter) { + return lowerGetBuf(op, rewriter); +} + +LogicalResult lowerRlsBufOp(RlsBufOp op, PatternRewriter &rewriter) { + return lowerRlsBuf(op, rewriter); +} + +LogicalResult lowerTensorPipelineOp(Operation *op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + rewriter.setInsertionPoint(op); + + LogicalResult lowered = success(); + if (auto tload = dyn_cast(op)) + lowered = lowerTLOADOp(tload, rewriter); + else if (auto tabs = dyn_cast(op)) + lowered = lowerTABSOp(tabs, rewriter, strategy); + else if (auto tadd = dyn_cast(op)) + lowered = lowerTADDOp(tadd, rewriter, strategy); + else if (auto tsub = dyn_cast(op)) + lowered = lowerTSUBOp(tsub, rewriter, strategy); + else if (auto tmul = dyn_cast(op)) + lowered = lowerTMULOp(tmul, rewriter, strategy); + else if (auto tdiv = dyn_cast(op)) + lowered = lowerTDIVOp(tdiv, rewriter, strategy); + else if (auto tmax = dyn_cast(op)) + lowered = lowerTMAXOp(tmax, rewriter, strategy); + else if (auto tmin = dyn_cast(op)) + lowered = lowerTMINOp(tmin, rewriter, strategy); + else if (auto tand = dyn_cast(op)) + lowered = lowerTANDOp(tand, rewriter, strategy); + else if (auto tands = dyn_cast(op)) + lowered = lowerTANDSOp(tands, rewriter); + else if (auto tor = dyn_cast(op)) + lowered = lowerTOROp(tor, rewriter, strategy); + else if (auto tors = dyn_cast(op)) + lowered = lowerTORSOp(tors, rewriter); + else if (auto txor = dyn_cast(op)) + lowered = lowerTXOROp(txor, rewriter, strategy); + else if (auto txors = dyn_cast(op)) + lowered = lowerTXORSOp(txors, rewriter); + else if (auto texp = dyn_cast(op)) + lowered = lowerTEXPOp(texp, rewriter, strategy); + else if (auto tlog = dyn_cast(op)) + lowered = lowerTLOGOp(tlog, rewriter, strategy); + else if (auto tsqrt = dyn_cast(op)) + lowered = lowerTSQRTOp(tsqrt, rewriter, strategy); + else if (auto trsqr = dyn_cast(op)) + lowered = lowerTRSQRTOp(trsqr, rewriter); + else if (auto trecip = dyn_cast(op)) + lowered = lowerTRECIPOp(trecip, rewriter, strategy); + else if (auto tneg = dyn_cast(op)) + lowered = lowerTNEGOp(tneg, rewriter, strategy); + else if (auto tlrelu = dyn_cast(op)) + lowered = lowerTLRELUOp(tlrelu, rewriter, strategy); + else if (auto tci = dyn_cast(op)) + lowered = lowerTCIOp(tci, rewriter); + else if (auto tcvt = dyn_cast(op)) + lowered = lowerTCVTOp(tcvt, rewriter); + else if (auto tcmp = dyn_cast(op)) + lowered = lowerTCmpOp(tcmp, rewriter); + else if (auto tcmps = dyn_cast(op)) + lowered = lowerTCmpSOp(tcmps, rewriter); + else if (auto tsel = dyn_cast(op)) + lowered = lowerTSelOp(tsel, rewriter); + else if (auto taddc = dyn_cast(op)) + lowered = lowerTAddCOp(taddc, rewriter); + else if (auto tadds = dyn_cast(op)) + lowered = lowerTAddSOp(tadds, rewriter, strategy); + else if (auto taddsc = dyn_cast(op)) + lowered = lowerTAddSCOp(taddsc, rewriter); + else if (auto tmins = dyn_cast(op)) + lowered = lowerTMinSOp(tmins, rewriter, strategy); + else if (auto tsubc = dyn_cast(op)) + lowered = lowerTSubCOp(tsubc, rewriter); + else if (auto tsubs = dyn_cast(op)) + lowered = lowerTSubSOp(tsubs, rewriter, strategy); + else if (auto tsubsc = dyn_cast(op)) + lowered = lowerTSubSCOp(tsubsc, rewriter); + else if (auto tmaxs = dyn_cast(op)) + lowered = lowerTMaxSOp(tmaxs, rewriter, strategy); + else if (auto tdivs = dyn_cast(op)) + lowered = lowerTDivSOp(tdivs, rewriter, strategy); + else if (auto tmuls = dyn_cast(op)) + lowered = lowerTMulSOp(tmuls, rewriter, strategy); + else if (auto tsels = dyn_cast(op)) + lowered = lowerTSelSOp(tsels, rewriter); + else if (auto trelu = dyn_cast(op)) + lowered = lowerTRELUOp(trelu, rewriter, strategy); + else if (auto tnot = dyn_cast(op)) + lowered = lowerTNOTOp(tnot, rewriter, strategy); + else if (auto ttrans = dyn_cast(op)) + lowered = lowerTTRANSOp(ttrans, rewriter); + else if (auto tfillpad = dyn_cast(op)) + lowered = lowerTFILLPADOp(tfillpad, rewriter); + else if (auto tfillpadExpand = dyn_cast(op)) + lowered = lowerTFILLPADExpandOp(tfillpadExpand, rewriter); + else if (auto trowmax = dyn_cast(op)) + lowered = lowerTRowMaxOp(trowmax, rewriter, strategy); + else if (auto trowmin = dyn_cast(op)) + lowered = lowerTRowMinOp(trowmin, rewriter, strategy); + else if (auto trowsum = dyn_cast(op)) + lowered = lowerTRowSumOp(trowsum, rewriter, strategy); + else if (auto tcolmax = dyn_cast(op)) + lowered = lowerTColMaxOp(tcolmax, rewriter); + else if (auto tcolmin = dyn_cast(op)) + lowered = lowerTColMinOp(tcolmin, rewriter); + else if (auto tcolsum = dyn_cast(op)) + lowered = lowerTColSumOp(tcolsum, rewriter); + else if (auto trowexpand = dyn_cast(op)) + lowered = lowerTRowExpandOp(trowexpand, rewriter, strategy); + else if (auto tcolexpand = dyn_cast(op)) + lowered = lowerTColExpandOp(tcolexpand, rewriter); + else if (auto trowexpandmul = dyn_cast(op)) + lowered = lowerTRowExpandMulOp(trowexpandmul, rewriter, strategy); + else if (auto trowexpanddiv = dyn_cast(op)) + lowered = lowerTRowExpandDivOp(trowexpanddiv, rewriter, strategy); + else if (auto trowexpandsub = dyn_cast(op)) + lowered = lowerTRowExpandSubOp(trowexpandsub, rewriter, strategy); + else if (auto tpartadd = dyn_cast(op)) + lowered = lowerTPartAddOp(tpartadd, rewriter); + else if (auto tpartmax = dyn_cast(op)) + lowered = lowerTPartMaxOp(tpartmax, rewriter); + else if (auto tpartmin = dyn_cast(op)) + lowered = lowerTPartMinOp(tpartmin, rewriter); + else if (auto texpands = dyn_cast(op)) + lowered = lowerTExpandSOp(texpands, rewriter); + else if (auto tgather = dyn_cast(op)) + lowered = lowerTGatherOp(tgather, rewriter); + else if (auto tgatherb = dyn_cast(op)) + lowered = lowerTGatherBOp(tgatherb, rewriter); + else if (auto tscatter = dyn_cast(op)) + lowered = lowerTScatterOp(tscatter, rewriter); + else if (auto tmrgsort = dyn_cast(op)) + lowered = lowerTMrgSortOp(tmrgsort, rewriter); + else if (auto tsort32 = dyn_cast(op)) + lowered = lowerTSort32Op(tsort32, rewriter); + else if (auto tstore = dyn_cast(op)) + lowered = lowerTSTOREOp(tstore, rewriter); + else + return success(); + + if (failed(lowered)) + return failure(); + + rewriter.eraseOp(op); + return success(); +} + +LogicalResult lowerResidualPTOOp(Operation *op, PatternRewriter &rewriter) { + rewriter.setInsertionPoint(op); + + LogicalResult lowered = success(); + if (auto setFlag = dyn_cast(op)) + lowered = lowerSetFlagOp(setFlag, rewriter); + else if (auto waitFlag = dyn_cast(op)) + lowered = lowerWaitFlagOp(waitFlag, rewriter); + else if (auto barrier = dyn_cast(op)) + lowered = lowerBarrierOp(barrier, rewriter); + else if (auto getBuf = dyn_cast(op)) + lowered = lowerGetBufOp(getBuf, rewriter); + else if (auto rlsBuf = dyn_cast(op)) + lowered = lowerRlsBufOp(rlsBuf, rewriter); + else if (isa(op) && op->use_empty()) + lowered = success(); + else + return success(); + + if (failed(lowered)) + return failure(); + + rewriter.eraseOp(op); + return success(); +} + +struct PTOToVPTOPass : public impl::PTOToVPTOBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOToVPTOPass) + + PTOToVPTOPass() = default; + + explicit PTOToVPTOPass(StringRef loweringStrategy) { + this->loweringStrategy = loweringStrategy.str(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + FailureOr loweringStrategy = + parseVPTOLoweringStrategy(this->loweringStrategy); + if (failed(loweringStrategy)) { + module.emitError() + << "unsupported pto-lowering-strategy: " << this->loweringStrategy + << " (expected post-update or no-post-update)"; + signalPassFailure(); + return; + } + SmallVector tensorPipelineOps; + SmallVector residualPTOOps; + module.walk([&](Operation *op) { + if (isa(op)) + tensorPipelineOps.push_back(op); + else if (isa(op)) + residualPTOOps.push_back(op); + }); + + PatternRewriter rewriter(&getContext()); + bool sawFailure = false; + for (Operation *op : tensorPipelineOps) { + if (!op->getBlock()) + continue; + if (failed(lowerTensorPipelineOp(op, rewriter, *loweringStrategy))) + sawFailure = true; + } + for (Operation *op : residualPTOOps) { + if (!op->getBlock()) + continue; + if (failed(lowerResidualPTOOp(op, rewriter))) + sawFailure = true; + } + + bool erasedDeadScaffold = true; + while (erasedDeadScaffold) { + erasedDeadScaffold = false; + SmallVector deadScaffoldOps; + module.walk([&](Operation *op) { + if ((isa(op)) && op->use_empty()) + deadScaffoldOps.push_back(op); + }); + for (Operation *op : deadScaffoldOps) { + if (!op->getBlock()) + continue; + rewriter.setInsertionPoint(op); + rewriter.eraseOp(op); + erasedDeadScaffold = true; + } + } + + // Keep the backend mainline memref-first through PTOToVPTO. Pointer ABI + // bridging belongs to the emission boundary, where text/LLVM emitters can + // materialize the required ptr-only signature on a cloned module. + + if (sawFailure) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr createLowerPTOToVPTOPass() { + return std::make_unique(); +} + +std::unique_ptr createLowerPTOToVPTOPass(StringRef loweringStrategy) { + return std::make_unique(loweringStrategy); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp new file mode 100644 index 000000000..9568be1bf --- /dev/null +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -0,0 +1,7290 @@ +//===- PTOToVPTOLowering.cpp - PTO to VPTO lowering helpers --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VPTOLowering.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOSyncUtils.h" + +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/ADT/APFloat.h" + +#include +#include + +namespace mlir { +namespace pto { + +namespace { + +constexpr StringLiteral kLoweredLoopScopeAttrName = "llvm.loop.aivector_scope"; + +struct ResolvedTensorView { + Value root; + Attribute layoutAttr; + SmallVector shape; + SmallVector strides; + OpFoldResult offsetElems; +}; + +struct VecNdTransferPlan { + Value outerCount; + Value outerSrcStrideElems; + Value outerDstStrideElems; + Value loop2Size; + Value loop1Size; + Value loop2FirstStrideBytes; + Value loop2SecondStrideBytes; + Value loop1FirstStrideBytes; + Value loop1SecondStrideBytes; + Value nBurst; + Value lenBurst; + Value firstStrideBytes; + Value secondStrideBytes; +}; + +struct VPTORowReduceContract { + StringRef family; + VPTOTileDomain srcDomain = VPTOTileDomain::Vec; + VPTOTileDomain dstDomain = VPTOTileDomain::Vec; + StringRef srcLayout; + StringRef dstLayout; + Type elementType; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOColReduceContract { + StringRef family; + VPTOTileDomain srcDomain = VPTOTileDomain::Vec; + VPTOTileDomain dstDomain = VPTOTileDomain::Vec; + StringRef srcLayout; + StringRef dstLayout; + Type elementType; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + int64_t dstValidRows = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + bool isBinary = false; + Value tmp; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOPartContract { + StringRef family; + VPTOTileDomain src0Domain = VPTOTileDomain::Vec; + VPTOTileDomain src1Domain = VPTOTileDomain::Vec; + VPTOTileDomain dstDomain = VPTOTileDomain::Vec; + StringRef src0Layout; + StringRef src1Layout; + StringRef dstLayout; + Type elementType; + Value src0ValidRowsValue; + Value src0ValidColsValue; + Value src1ValidRowsValue; + Value src1ValidColsValue; + Value dstValidRowsValue; + Value dstValidColsValue; + int64_t src0ValidRows = ShapedType::kDynamic; + int64_t src0ValidCols = ShapedType::kDynamic; + int64_t src1ValidRows = ShapedType::kDynamic; + int64_t src1ValidCols = ShapedType::kDynamic; + int64_t dstValidRows = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOExpandContract { + StringRef family; + VPTOTileDomain srcDomain = VPTOTileDomain::Vec; + VPTOTileDomain dstDomain = VPTOTileDomain::Vec; + StringRef srcLayout; + StringRef dstLayout; + Type elementType; + Value srcValidRowsValue; + Value srcValidColsValue; + Value dstValidRowsValue; + Value dstValidColsValue; + int64_t srcValidRows = ShapedType::kDynamic; + int64_t srcValidCols = ShapedType::kDynamic; + int64_t dstValidRows = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + VPTOLoopScopeContract loopScope; +}; + +StringRef inferVecTransferLayoutFromTile(StringRef explicitLayout, + StringRef tileLayout) { + if (explicitLayout != "nd") + return explicitLayout; + if (tileLayout == "col_major") + return "dn"; + return "nd"; +} + +int64_t getElementByteSize(Type type); +Value materializeIndexValue(Value maybeValue, int64_t fallback, + PatternRewriter &rewriter, Location loc); +Value materializeI64Value(Value maybeValue, int64_t fallback, + PatternRewriter &rewriter, Location loc); + +LogicalResult emitUnresolvedInstalledA5BaselineError(Operation *op, + StringRef family) { + return op->emitOpError() + << family + << " lowering is intentionally unresolved until the installed A5 PTO " + "helper baseline is located and traced"; +} + +std::optional getConstInt(Value value) { + if (!value) + return std::nullopt; + + if (auto constIndex = value.getDefiningOp()) + return constIndex.value(); + if (auto constInt = value.getDefiningOp()) + return constInt.value(); + if (auto constOp = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) + return intAttr.getInt(); + } + return std::nullopt; +} + +std::optional getConstInt(OpFoldResult value) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + return std::nullopt; + } + return getConstInt(cast(value)); +} + +Value materializeIndexOfr(OpFoldResult value, PatternRewriter &rewriter, + Location loc) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) + return rewriter.create(loc, intAttr.getInt()); + return {}; + } + Value v = cast(value); + if (v.getType().isIndex()) + return v; + if (isa(v.getType())) + return rewriter.create(loc, rewriter.getIndexType(), v); + return {}; +} + +Value materializeI64Ofr(OpFoldResult value, PatternRewriter &rewriter, + Location loc) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) + return rewriter.create(loc, intAttr.getInt(), 64); + return {}; + } + return materializeI64Value(cast(value), ShapedType::kDynamic, rewriter, loc); +} + +Value materializeIndexBuilder(OpFoldResult value, PatternRewriter &rewriter, Location loc) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) + return rewriter.create(loc, intAttr.getInt()); + return {}; + } + Value v = cast(value); + if (v.getType().isIndex()) + return v; + if (isa(v.getType())) + return rewriter.create(loc, rewriter.getIndexType(), v); + return {}; +} + +Value createI64Mul(Value lhs, Value rhs, PatternRewriter &rewriter, Location loc) { + if (!lhs || !rhs) + return {}; + if (std::optional lhsConst = getConstInt(lhs)) { + if (std::optional rhsConst = getConstInt(rhs)) + return rewriter.create(loc, (*lhsConst) * (*rhsConst), 64); + } + return rewriter.create(loc, lhs, rhs); +} + +Value createI64Add(Value lhs, Value rhs, PatternRewriter &rewriter, Location loc) { + if (!lhs || !rhs) + return {}; + if (std::optional lhsConst = getConstInt(lhs)) { + if (std::optional rhsConst = getConstInt(rhs)) + return rewriter.create(loc, (*lhsConst) + (*rhsConst), 64); + } + return rewriter.create(loc, lhs, rhs); +} + +OpFoldResult addOfr(OpFoldResult lhs, OpFoldResult rhs, PatternRewriter &rewriter, + Location loc) { + if (auto lhsConst = getConstInt(lhs)) { + if (auto rhsConst = getConstInt(rhs)) + return rewriter.getIndexAttr((*lhsConst) + (*rhsConst)); + } + Value lhsValue = materializeIndexBuilder(lhs, rewriter, loc); + Value rhsValue = materializeIndexBuilder(rhs, rewriter, loc); + if (!lhsValue || !rhsValue) + return {}; + return rewriter.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult multiplyOfr(OpFoldResult lhs, OpFoldResult rhs, PatternRewriter &rewriter, + Location loc) { + if (auto lhsConst = getConstInt(lhs)) { + if (auto rhsConst = getConstInt(rhs)) + return rewriter.getIndexAttr((*lhsConst) * (*rhsConst)); + } + Value lhsValue = materializeIndexBuilder(lhs, rewriter, loc); + Value rhsValue = materializeIndexBuilder(rhs, rewriter, loc); + if (!lhsValue || !rhsValue) + return {}; + return rewriter.create(loc, lhsValue, rhsValue).getResult(); +} + +bool resolveTensorView(Value value, ResolvedTensorView &info, PatternRewriter &rewriter, + Location loc) { + if (!value) + return false; + + if (auto part = value.getDefiningOp()) { + if (!resolveTensorView(part.getSource(), info, rewriter, loc)) + return false; + SmallVector offsets; + offsets.reserve(part.getOffsets().size()); + for (Value offset : part.getOffsets()) + offsets.push_back(offset); + if (offsets.size() != info.strides.size()) + return false; + OpFoldResult totalOffset = info.offsetElems; + for (auto [offset, stride] : llvm::zip(offsets, info.strides)) { + OpFoldResult term = multiplyOfr(offset, stride, rewriter, loc); + if (!term) + return false; + totalOffset = addOfr(totalOffset, term, rewriter, loc); + if (!totalOffset) + return false; + } + info.offsetElems = totalOffset; + info.shape.clear(); + for (Value size : part.getSizes()) + info.shape.push_back(size); + return true; + } + + if (auto source = value.getDefiningOp()) { + info.root = source.getPtr(); + info.layoutAttr = source.getLayoutAttr(); + info.shape.assign(source.getShape().begin(), source.getShape().end()); + info.strides.assign(source.getStrides().begin(), source.getStrides().end()); + info.offsetElems = rewriter.getIndexAttr(0); + return true; + } + + if (auto subview = value.getDefiningOp()) { + ResolvedTensorView parent; + Value source = subview.getSource(); + if (auto reinterpret = source.getDefiningOp()) { + Value root = reinterpret.getSource(); + while (true) { + if (auto cast = root.getDefiningOp()) { + root = cast.getSource(); + continue; + } + break; + } + parent.root = root; + if (Attribute layout = reinterpret->getAttr("layout")) + parent.layoutAttr = layout; + auto parentShapes = + getMixedValues(reinterpret.getStaticSizes(), reinterpret.getSizes(), rewriter); + auto parentStrides = + getMixedValues(reinterpret.getStaticStrides(), reinterpret.getStrides(), rewriter); + auto offsets = + getMixedValues(reinterpret.getStaticOffsets(), reinterpret.getOffsets(), rewriter); + parent.shape.assign(parentShapes.begin(), parentShapes.end()); + parent.strides.assign(parentStrides.begin(), parentStrides.end()); + parent.offsetElems = + offsets.empty() ? OpFoldResult(rewriter.getIndexAttr(0)) : offsets.front(); + } else if (!resolveTensorView(source, parent, rewriter, loc)) { + return false; + } + + if (parent.strides.empty()) { + auto sourceType = dyn_cast(source.getType()); + if (!sourceType) + return false; + SmallVector strides; + int64_t offset = 0; + if (failed(getStridesAndOffset(sourceType, strides, offset))) { + strides.assign(sourceType.getRank(), 1); + int64_t running = 1; + for (int i = sourceType.getRank() - 1; i >= 0; --i) { + strides[i] = running; + int64_t dim = sourceType.getShape()[i]; + if (dim != ShapedType::kDynamic) + running *= dim; + } + } + for (int64_t stride : strides) + parent.strides.push_back(rewriter.getIndexAttr(stride == ShapedType::kDynamic ? 1 : stride)); + parent.offsetElems = rewriter.getIndexAttr(offset); + parent.root = source; + } + + info = parent; + if (subview.getMixedOffsets().size() != info.strides.size()) + return false; + + OpFoldResult totalOffset = info.offsetElems; + for (auto [offset, stride] : llvm::zip(subview.getMixedOffsets(), info.strides)) { + OpFoldResult term = multiplyOfr(offset, stride, rewriter, loc); + if (!term) + return false; + totalOffset = addOfr(totalOffset, term, rewriter, loc); + if (!totalOffset) + return false; + } + + SmallVector newStrides; + newStrides.reserve(info.strides.size()); + for (auto [srcStride, step] : llvm::zip(info.strides, subview.getMixedStrides())) { + OpFoldResult product = multiplyOfr(srcStride, step, rewriter, loc); + if (!product) + return false; + newStrides.push_back(product); + } + + info.offsetElems = totalOffset; + info.shape.assign(subview.getMixedSizes().begin(), subview.getMixedSizes().end()); + info.strides = std::move(newStrides); + return true; + } + + if (auto reinterpret = value.getDefiningOp()) { + Value root = reinterpret.getSource(); + while (true) { + if (auto cast = root.getDefiningOp()) { + root = cast.getSource(); + continue; + } + if (auto unrealized = root.getDefiningOp()) { + if (!unrealized.getInputs().empty()) { + root = unrealized.getInputs().front(); + continue; + } + } + break; + } + info.root = root; + if (Attribute layout = reinterpret->getAttr("layout")) + info.layoutAttr = layout; + auto reinterpretShapes = + getMixedValues(reinterpret.getStaticSizes(), reinterpret.getSizes(), rewriter); + auto reinterpretStrides = + getMixedValues(reinterpret.getStaticStrides(), reinterpret.getStrides(), rewriter); + auto offsets = + getMixedValues(reinterpret.getStaticOffsets(), reinterpret.getOffsets(), rewriter); + info.shape.assign(reinterpretShapes.begin(), reinterpretShapes.end()); + info.strides.assign(reinterpretStrides.begin(), reinterpretStrides.end()); + if (!offsets.empty()) { + if (offsets.size() != 1) + return false; + info.offsetElems = offsets.front(); + } else { + info.offsetElems = rewriter.getIndexAttr(0); + } + return true; + } + + if (auto cast = value.getDefiningOp()) + return resolveTensorView(cast.getSource(), info, rewriter, loc); + + if (auto memrefType = dyn_cast(value.getType())) { + info.root = value; + info.shape.clear(); + for (int64_t dim : memrefType.getShape()) + info.shape.push_back(rewriter.getIndexAttr(dim == ShapedType::kDynamic ? 1 : dim)); + SmallVector strides; + int64_t offset = 0; + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + strides.assign(memrefType.getRank(), 1); + int64_t running = 1; + for (int i = memrefType.getRank() - 1; i >= 0; --i) { + strides[i] = running; + int64_t dim = memrefType.getShape()[i]; + if (dim != ShapedType::kDynamic) + running *= dim; + } + offset = 0; + } + info.strides.clear(); + for (int64_t stride : strides) + info.strides.push_back(rewriter.getIndexAttr(stride == ShapedType::kDynamic ? 1 : stride)); + info.offsetElems = rewriter.getIndexAttr(offset); + return true; + } + + return false; +} + +void normalizeMixedGlobalShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &globalShape, + SmallVectorImpl &globalStride, + PatternRewriter &rewriter, Location loc) { + constexpr int64_t kRank = 5; + globalShape.assign(kRank, rewriter.getIndexAttr(1)); + globalStride.assign(kRank, rewriter.getIndexAttr(1)); + + size_t rank = std::min(shape.size(), strides.size()); + rank = std::min(rank, kRank); + size_t base = kRank - rank; + for (size_t i = 0; i < rank; ++i) { + globalShape[base + i] = shape[shape.size() - rank + i]; + globalStride[base + i] = strides[strides.size() - rank + i]; + } + + for (int i = static_cast(kRank) - 2; i >= 0; --i) { + if (i >= static_cast(base)) + continue; + OpFoldResult product = multiplyOfr(globalStride[i + 1], globalShape[i + 1], rewriter, loc); + if (!product) + product = rewriter.getIndexAttr(ShapedType::kDynamic); + globalStride[i] = product; + } +} + +Value adjustPointerByElemOffset(Value ptr, Value elemOffsetI64, int64_t elemBytes, + PatternRewriter &rewriter, Location loc) { + if (!ptr || !elemOffsetI64 || elemBytes <= 0) + return {}; + + Value offset = elemOffsetI64.getType().isIndex() + ? rewriter.create( + loc, rewriter.getI64Type(), elemOffsetI64) + : elemOffsetI64; + Value byteOffset = offset; + if (elemBytes != 1) { + Value elemBytesValue = rewriter.create(loc, elemBytes, 64); + byteOffset = createI64Mul(offset, elemBytesValue, rewriter, loc); + } + if (auto ptrType = dyn_cast(ptr.getType())) { + auto bytePtrType = PtrType::get(rewriter.getContext(), rewriter.getI8Type(), + ptrType.getMemorySpace()); + Value bytePtr = ptrType == bytePtrType + ? ptr + : rewriter.create(loc, bytePtrType, ptr).getResult(); + Value byteOffsetIndex = + byteOffset.getType().isIndex() + ? byteOffset + : rewriter.create(loc, rewriter.getIndexType(), + byteOffset); + return rewriter.create(loc, bytePtrType, bytePtr, byteOffsetIndex); + } + return {}; +} + +Value castPtrToElementType(Value ptr, Type elementType, PatternRewriter &rewriter, + Location loc) { + auto ptrType = dyn_cast_or_null(ptr.getType()); + if (!ptrType || !elementType) + return {}; + auto targetType = + PtrType::get(rewriter.getContext(), elementType, ptrType.getMemorySpace()); + if (targetType == ptrType) + return ptr; + return rewriter.create(loc, targetType, ptr).getResult(); +} + +Type getCopyTransferElementType(Type elementType, Builder &builder) { + if (getElementByteSize(elementType) == 8) + return builder.getI32Type(); + return elementType; +} + +LogicalResult buildVecNdLoadPlan(ArrayRef shape, + ArrayRef strides, int64_t tileCols, + Value validColsValue, int64_t validCols, + Type elementType, PatternRewriter &rewriter, + Location loc, VecNdTransferPlan &plan) { + if (tileCols == ShapedType::kDynamic) + return failure(); + int64_t elemBytes = getElementByteSize(elementType); + if (elemBytes <= 0) + return failure(); + + SmallVector globalShape; + SmallVector globalStride; + normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, rewriter, loc); + + auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; + Value gShape0 = toI64(globalShape[0]); + Value gShape1 = toI64(globalShape[1]); + Value gShape2 = toI64(globalShape[2]); + Value gShape3 = toI64(globalShape[3]); + Value gStride0 = toI64(globalStride[0]); + Value gStride1 = toI64(globalStride[1]); + Value gStride2 = toI64(globalStride[2]); + Value gStride3 = toI64(globalStride[3]); + Value validColsI64 = materializeI64Value(validColsValue, validCols, rewriter, loc); + if (!gShape0 || !gShape1 || !gShape2 || !gShape3 || !gStride0 || !gStride1 || + !gStride2 || !gStride3 || !validColsI64) + return failure(); + + Value tileColsI64 = rewriter.create(loc, tileCols, 64); + Value elemBytesI64 = rewriter.create(loc, elemBytes, 64); + Value dstStride2 = createI64Mul(gShape3, tileColsI64, rewriter, loc); + Value dstStride1 = createI64Mul(gShape2, dstStride2, rewriter, loc); + Value dstStride0 = createI64Mul(gShape1, dstStride1, rewriter, loc); + + plan.outerCount = gShape0; + plan.outerSrcStrideElems = gStride0; + plan.outerDstStrideElems = dstStride0; + plan.loop2Size = gShape1; + plan.loop1Size = gShape2; + plan.loop2FirstStrideBytes = createI64Mul(dstStride1, elemBytesI64, rewriter, loc); + plan.loop2SecondStrideBytes = createI64Mul(gStride1, elemBytesI64, rewriter, loc); + plan.loop1FirstStrideBytes = createI64Mul(dstStride2, elemBytesI64, rewriter, loc); + plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); + plan.nBurst = gShape3; + plan.lenBurst = createI64Mul(validColsI64, elemBytesI64, rewriter, loc); + plan.firstStrideBytes = createI64Mul(gStride3, elemBytesI64, rewriter, loc); + plan.secondStrideBytes = createI64Mul(tileColsI64, elemBytesI64, rewriter, loc); + return success(); +} + +LogicalResult buildVecDnLoadPlan(ArrayRef shape, + ArrayRef strides, int64_t tileRows, + Value validRowsValue, int64_t validRows, + Type elementType, PatternRewriter &rewriter, + Location loc, VecNdTransferPlan &plan) { + if (tileRows == ShapedType::kDynamic) + return failure(); + int64_t elemBytes = getElementByteSize(elementType); + if (elemBytes <= 0) + return failure(); + + SmallVector globalShape; + SmallVector globalStride; + normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, + rewriter, loc); + + auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; + Value gShape0 = toI64(globalShape[0]); + Value gShape1 = toI64(globalShape[1]); + Value gShape2 = toI64(globalShape[2]); + Value gShape4 = toI64(globalShape[4]); + Value gStride0 = toI64(globalStride[0]); + Value gStride1 = toI64(globalStride[1]); + Value gStride2 = toI64(globalStride[2]); + Value gStride4 = toI64(globalStride[4]); + Value validRowsI64 = materializeI64Value(validRowsValue, validRows, rewriter, loc); + if (!gShape0 || !gShape1 || !gShape2 || !gShape4 || !gStride0 || !gStride1 || + !gStride2 || !gStride4 || !validRowsI64) + return failure(); + + Value tileRowsI64 = rewriter.create(loc, tileRows, 64); + Value elemBytesI64 = rewriter.create(loc, elemBytes, 64); + Value dstStride2 = createI64Mul(gShape4, tileRowsI64, rewriter, loc); + Value dstStride1 = createI64Mul(gShape2, dstStride2, rewriter, loc); + Value dstStride0 = createI64Mul(gShape1, dstStride1, rewriter, loc); + + plan.outerCount = gShape0; + plan.outerSrcStrideElems = gStride0; + plan.outerDstStrideElems = dstStride0; + plan.loop2Size = gShape1; + plan.loop1Size = gShape2; + plan.loop2FirstStrideBytes = createI64Mul(dstStride1, elemBytesI64, rewriter, loc); + plan.loop2SecondStrideBytes = createI64Mul(gStride1, elemBytesI64, rewriter, loc); + plan.loop1FirstStrideBytes = createI64Mul(dstStride2, elemBytesI64, rewriter, loc); + plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); + plan.nBurst = gShape4; + plan.lenBurst = createI64Mul(validRowsI64, elemBytesI64, rewriter, loc); + plan.firstStrideBytes = createI64Mul(gStride4, elemBytesI64, rewriter, loc); + plan.secondStrideBytes = createI64Mul(tileRowsI64, elemBytesI64, rewriter, loc); + return success(); +} + +LogicalResult buildVecNdStorePlan(ArrayRef shape, + ArrayRef strides, int64_t tileCols, + Value validColsValue, int64_t validCols, + Type elementType, PatternRewriter &rewriter, + Location loc, VecNdTransferPlan &plan) { + if (failed(buildVecNdLoadPlan(shape, strides, tileCols, validColsValue, validCols, + elementType, rewriter, loc, plan))) + return failure(); + std::swap(plan.outerSrcStrideElems, plan.outerDstStrideElems); + std::swap(plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); + std::swap(plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); + return success(); +} + +LogicalResult buildVecDnStorePlan(ArrayRef shape, + ArrayRef strides, int64_t tileRows, + Value validRowsValue, int64_t validRows, + Type elementType, PatternRewriter &rewriter, + Location loc, VecNdTransferPlan &plan) { + if (tileRows == ShapedType::kDynamic) + return failure(); + int64_t elemBytes = getElementByteSize(elementType); + if (elemBytes <= 0) + return failure(); + + SmallVector globalShape; + SmallVector globalStride; + normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, + rewriter, loc); + + auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; + Value gShape0 = toI64(globalShape[0]); + Value gShape1 = toI64(globalShape[1]); + Value gShape2 = toI64(globalShape[2]); + Value gShape4 = toI64(globalShape[4]); + Value gStride0 = toI64(globalStride[0]); + Value gStride1 = toI64(globalStride[1]); + Value gStride2 = toI64(globalStride[2]); + Value gStride4 = toI64(globalStride[4]); + Value validRowsI64 = materializeI64Value(validRowsValue, validRows, rewriter, loc); + if (!gShape0 || !gShape1 || !gShape2 || !gShape4 || !gStride0 || !gStride1 || + !gStride2 || !gStride4 || !validRowsI64) + return failure(); + + Value tileRowsI64 = rewriter.create(loc, tileRows, 64); + Value elemBytesI64 = rewriter.create(loc, elemBytes, 64); + Value outerSrcStride = + createI64Mul(createI64Mul(createI64Mul(gShape1, gShape2, rewriter, loc), + gShape4, rewriter, loc), + tileRowsI64, rewriter, loc); + Value loop1SrcStride = + createI64Mul(createI64Mul(tileRowsI64, gShape4, rewriter, loc), elemBytesI64, + rewriter, loc); + Value loop2SrcStride = + createI64Mul(createI64Mul(createI64Mul(gShape2, tileRowsI64, rewriter, loc), + gShape4, rewriter, loc), + elemBytesI64, rewriter, loc); + + plan.outerCount = gShape0; + plan.outerSrcStrideElems = outerSrcStride; + plan.outerDstStrideElems = gStride0; + plan.loop2Size = gShape1; + plan.loop1Size = gShape2; + plan.loop2FirstStrideBytes = loop2SrcStride; + plan.loop2SecondStrideBytes = createI64Mul(gStride1, elemBytesI64, rewriter, loc); + plan.loop1FirstStrideBytes = loop1SrcStride; + plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); + plan.nBurst = gShape4; + plan.lenBurst = createI64Mul(validRowsI64, elemBytesI64, rewriter, loc); + plan.firstStrideBytes = createI64Mul(gStride4, elemBytesI64, rewriter, loc); + plan.secondStrideBytes = createI64Mul(tileRowsI64, elemBytesI64, rewriter, loc); + return success(); +} + +StringRef stringifyTileLayout(TileBufType type) { + if (auto layoutAttr = dyn_cast_or_null(type.getBLayoutAttr())) { + switch (layoutAttr.getValue()) { + case BLayout::RowMajor: + return "row_major"; + case BLayout::ColMajor: + return "col_major"; + } + } + return "row_major"; +} + +StringRef stringifyTileLayoutConfig(TileBufConfigAttr config) { + if (!config) + return "row_major"; + if (auto layoutAttr = dyn_cast_or_null(config.getBLayout())) { + switch (layoutAttr.getValue()) { + case BLayout::RowMajor: + return "row_major"; + case BLayout::ColMajor: + return "col_major"; + } + } + return "row_major"; +} + +StringRef stringifyPadModeAttr(PadModeAttr padMode) { + if (!padMode) + return "none"; + + switch (padMode.getPadmode()) { + case PadMode::PadNull: + return "none"; + case PadMode::PadFirstElem: + return "first_elem"; + case PadMode::PadValue: + return "value"; + } + return "none"; +} + +StringRef stringifyLayoutAttr(Attribute layoutAttr) { + if (auto attr = dyn_cast_or_null(layoutAttr)) + return stringifyLayout(attr.getLayout()); + return "nd"; +} + +PipeAttr stringifyPipeAttr(PipeAttr pipe, PatternRewriter &rewriter) { + return PipeAttr::get(rewriter.getContext(), pipe.getPipe()); +} + +EventAttr stringifyEventAttr(EventAttr event, PatternRewriter &rewriter) { + return EventAttr::get(rewriter.getContext(), event.getEvent()); +} + +StringRef stringifyCmpModeAttr(CmpModeAttr cmpMode) { + if (!cmpMode) + return "eq"; + switch (cmpMode.getValue()) { + case CmpMode::EQ: + return "eq"; + case CmpMode::NE: + return "ne"; + case CmpMode::LT: + return "lt"; + case CmpMode::LE: + return "le"; + case CmpMode::GT: + return "gt"; + case CmpMode::GE: + return "ge"; + } + return "eq"; +} + +StringRef stringifyElementTypeFragment(Type type) { + if (!type) + return "unknown"; + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) { + if (intType.isUnsigned()) + switch (intType.getWidth()) { + case 8: + return "u8"; + case 16: + return "u16"; + case 32: + return "u32"; + case 64: + return "u64"; + default: + break; + } + switch (intType.getWidth()) { + case 8: + return "s8"; + case 16: + return "s16"; + case 32: + return "s32"; + case 64: + return "s64"; + default: + break; + } + } + return "unknown"; +} + +StringRef stringifyCopyTransferTypeFragment(Type type) { + switch (getElementByteSize(type)) { + case 1: + return "u8"; + case 2: + return "u16"; + case 4: + case 8: + return "u32"; + default: + return stringifyElementTypeFragment(type); + } +} + +static bool isSupportedPackedCmp32ElementType(Type type) { + if (!type) + return false; + if (type.isF32()) + return true; + auto intType = dyn_cast(type); + return intType && intType.getWidth() == 32; +} + +VPTOTileDomain deriveTileDomain(Attribute memorySpace) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) { + switch (addrSpace.getAddressSpace()) { + case AddressSpace::ACC: + return VPTOTileDomain::Acc; + case AddressSpace::MAT: + return VPTOTileDomain::Mat; + case AddressSpace::VEC: + default: + return VPTOTileDomain::Vec; + } + } + if (auto intAttr = dyn_cast_or_null(memorySpace)) { + switch (intAttr.getInt()) { + case static_cast(AddressSpace::ACC): + return VPTOTileDomain::Acc; + case static_cast(AddressSpace::MAT): + return VPTOTileDomain::Mat; + default: + return VPTOTileDomain::Vec; + } + } + return VPTOTileDomain::Vec; +} + +void getValidShape(TileBufType type, int64_t &rows, int64_t &cols) { + ArrayRef validShape = type.getValidShape(); + rows = validShape.size() > 0 ? validShape[0] : ShapedType::kDynamic; + cols = validShape.size() > 1 ? validShape[1] : ShapedType::kDynamic; +} + +static std::pair getIfResultYieldedValues(Value value) { + auto result = dyn_cast(value); + if (!result) + return {Value(), Value()}; + auto ifOp = dyn_cast(result.getOwner()); + if (!ifOp) + return {Value(), Value()}; + unsigned resultNumber = result.getResultNumber(); + auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); + auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); + if (!thenYield || !elseYield) + return {Value(), Value()}; + if (resultNumber >= thenYield.getNumOperands() || + resultNumber >= elseYield.getNumOperands()) + return {Value(), Value()}; + return {thenYield.getOperand(resultNumber), elseYield.getOperand(resultNumber)}; +} + +static bool equalOrBothNull(Value lhs, Value rhs) { + if (!lhs && !rhs) + return true; + if (!lhs || !rhs) + return false; + if (lhs == rhs) + return true; + auto lhsConst = getConstInt(lhs); + auto rhsConst = getConstInt(rhs); + return lhsConst && rhsConst && *lhsConst == *rhsConst; +} + +TileBufConfigAttr lookupTileConfig(Value value) { + if (!value) + return {}; + if (auto bind = value.getDefiningOp()) + return bind.getConfig(); + if (auto cast = value.getDefiningOp()) + return cast.getConfig().value_or(TileBufConfigAttr{}); + if (auto subview = value.getDefiningOp()) + return lookupTileConfig(subview.getSource()); + if (auto reinterpret = value.getDefiningOp()) + return lookupTileConfig(reinterpret.getSource()); + if (auto cast = value.getDefiningOp()) + return lookupTileConfig(cast.getSource()); + if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); + thenValue && elseValue) { + TileBufConfigAttr thenConfig = lookupTileConfig(thenValue); + TileBufConfigAttr elseConfig = lookupTileConfig(elseValue); + if (thenConfig && elseConfig && thenConfig == elseConfig) + return thenConfig; + } + return {}; +} + +bool hasStructuredTileDriver(Value value) { + if (!value) + return false; + if (isa(value.getType())) + return true; + if (value.getDefiningOp()) + return true; + if (auto subview = value.getDefiningOp()) + return hasStructuredTileDriver(subview.getSource()); + if (auto reinterpret = value.getDefiningOp()) + return hasStructuredTileDriver(reinterpret.getSource()); + if (auto cast = value.getDefiningOp()) + return hasStructuredTileDriver(cast.getSource()); + if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); + thenValue && elseValue) { + return hasStructuredTileDriver(thenValue) && hasStructuredTileDriver(elseValue); + } + return false; +} + +void lookupValidDims(Value value, Value &validRow, Value &validCol) { + if (!value) { + validRow = {}; + validCol = {}; + return; + } + if (auto bind = value.getDefiningOp()) { + validRow = bind.getValidRow(); + validCol = bind.getValidCol(); + return; + } + if (auto cast = value.getDefiningOp()) { + validRow = cast.getValidRow(); + validCol = cast.getValidCol(); + return; + } + if (auto subview = value.getDefiningOp()) { + lookupValidDims(subview.getSource(), validRow, validCol); + return; + } + if (auto reinterpret = value.getDefiningOp()) { + lookupValidDims(reinterpret.getSource(), validRow, validCol); + return; + } + if (auto cast = value.getDefiningOp()) { + lookupValidDims(cast.getSource(), validRow, validCol); + return; + } + if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); + thenValue && elseValue) { + Value thenRow; + Value thenCol; + Value elseRow; + Value elseCol; + lookupValidDims(thenValue, thenRow, thenCol); + lookupValidDims(elseValue, elseRow, elseCol); + validRow = equalOrBothNull(thenRow, elseRow) ? thenRow : Value(); + validCol = equalOrBothNull(thenCol, elseCol) ? thenCol : Value(); + return; + } + validRow = {}; + validCol = {}; +} + +Type getElementType(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + return {}; +} + +Attribute getMemorySpace(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getMemorySpace(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getMemorySpace(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace(); + return {}; +} + +StringRef deriveTileLayout(Value value) { + if (auto tileType = dyn_cast(value.getType())) + return stringifyTileLayout(tileType); + return stringifyTileLayoutConfig(lookupTileConfig(value)); +} + +void deriveValidShape(Value value, int64_t &rows, int64_t &cols) { + if (auto tileType = dyn_cast(value.getType())) { + getValidShape(tileType, rows, cols); + return; + } + + Value validRow; + Value validCol; + lookupValidDims(value, validRow, validCol); + rows = getConstInt(validRow).value_or(ShapedType::kDynamic); + cols = getConstInt(validCol).value_or(ShapedType::kDynamic); + if (rows != ShapedType::kDynamic && cols != ShapedType::kDynamic) + return; + if (!hasStructuredTileDriver(value)) + return; + + auto shapedType = dyn_cast(value.getType()); + if (!shapedType || !shapedType.hasRank()) + return; + + ArrayRef shape = shapedType.getShape(); + if (shape.empty()) { + if (rows == ShapedType::kDynamic) + rows = 1; + if (cols == ShapedType::kDynamic) + cols = 1; + return; + } + if (shape.size() == 1) { + if (rows == ShapedType::kDynamic) + rows = 1; + if (cols == ShapedType::kDynamic) + cols = shape.front(); + return; + } + + if (cols == ShapedType::kDynamic) + cols = shape.back(); + if (rows == ShapedType::kDynamic) { + int64_t flatRows = 1; + for (int64_t dim : shape.drop_back()) { + if (dim == ShapedType::kDynamic) { + flatRows = ShapedType::kDynamic; + break; + } + flatRows *= dim; + } + rows = flatRows; + } +} + +void deriveValidShapeValues(Value value, Value &rows, Value &cols) { + if (auto tileType = dyn_cast(value.getType())) { + ArrayRef validShape = tileType.getValidShape(); + rows = {}; + cols = {}; + (void)validShape; + lookupValidDims(value, rows, cols); + return; + } + lookupValidDims(value, rows, cols); +} + +void appendStaticSizes(ValueRange values, SmallVectorImpl &out, + bool &hasDynamic) { + out.clear(); + hasDynamic = false; + out.reserve(values.size()); + for (Value value : values) { + if (std::optional constant = getConstInt(value)) { + out.push_back(*constant); + continue; + } + out.push_back(ShapedType::kDynamic); + hasDynamic = true; + } +} + +int64_t getElementByteSize(Type type) { + if (auto floatType = dyn_cast(type)) + return (floatType.getWidth() + 7) / 8; + if (auto intType = dyn_cast(type)) + return (intType.getWidth() + 7) / 8; + return 0; +} + +Value materializeIndexValue(Value maybeValue, int64_t fallback, + PatternRewriter &rewriter, Location loc) { + if (maybeValue) + return maybeValue; + if (fallback != ShapedType::kDynamic) + return rewriter.create(loc, fallback); + return {}; +} + +Value materializeI64Value(Value maybeValue, int64_t fallback, + PatternRewriter &rewriter, Location loc) { + if (maybeValue) { + Type type = maybeValue.getType(); + if (type.isIndex()) + return rewriter.create(loc, rewriter.getI64Type(), maybeValue); + if (type.isInteger(64)) + return maybeValue; + if (auto intType = dyn_cast(type)) + return rewriter.create(loc, rewriter.getI64Type(), maybeValue); + } + if (fallback != ShapedType::kDynamic) + return rewriter.create(loc, fallback, 64); + return {}; +} + +void recordStaticValues(ValueRange values, SmallVectorImpl &out) { + out.clear(); + out.reserve(values.size()); + for (Value value : values) + out.push_back(getConstInt(value).value_or(ShapedType::kDynamic)); +} + +void recordStaticSizes(ArrayRef values, + SmallVectorImpl &out, bool &hasDynamic) { + out.clear(); + hasDynamic = false; + out.reserve(values.size()); + for (OpFoldResult value : values) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) { + out.push_back(intAttr.getInt()); + continue; + } + } else if (std::optional constant = + getConstInt(cast(value))) { + out.push_back(*constant); + continue; + } + out.push_back(ShapedType::kDynamic); + hasDynamic = true; + } +} + +void mergeSubviewTrace(VPTOPartitionTrace &trace, ArrayRef offsets, + ArrayRef sizes, bool hasDynamicOffsets, + bool hasDynamicSizes) { + if (trace.offsets.empty()) { + trace.offsets.assign(offsets.begin(), offsets.end()); + trace.hasDynamicOffsets = hasDynamicOffsets; + } else { + size_t count = std::min(trace.offsets.size(), offsets.size()); + for (size_t i = 0; i < count; ++i) { + if (trace.offsets[i] == ShapedType::kDynamic || + offsets[i] == ShapedType::kDynamic) { + trace.offsets[i] = ShapedType::kDynamic; + trace.hasDynamicOffsets = true; + continue; + } + trace.offsets[i] += offsets[i]; + } + trace.hasDynamicOffsets = trace.hasDynamicOffsets || hasDynamicOffsets; + } + + trace.sizes.assign(sizes.begin(), sizes.end()); + trace.hasDynamicSizes = hasDynamicSizes; +} + +Value resolveTensorViewBase(Value value, Attribute &layoutAttr, + SmallVectorImpl &shape, + SmallVectorImpl &strides) { + if (!value) + return {}; + + if (auto part = value.getDefiningOp()) { + return resolveTensorViewBase(part.getSource(), layoutAttr, shape, strides); + } + + if (auto source = value.getDefiningOp()) { + layoutAttr = source.getLayoutAttr(); + auto tensorType = dyn_cast(source.getResult().getType()); + shape.assign(tensorType.getShape().begin(), tensorType.getShape().end()); + recordStaticValues(source.getStrides(), strides); + return source.getPtr(); + } + + if (auto subview = value.getDefiningOp()) { + Value base = + resolveTensorViewBase(subview.getSource(), layoutAttr, shape, strides); + if (shape.empty()) { + bool hasDynamicSizes = false; + recordStaticSizes(subview.getMixedSizes(), shape, hasDynamicSizes); + } + return base ? base : value; + } + + if (auto reinterpret = value.getDefiningOp()) { + if (Attribute layout = reinterpret->getAttr("layout")) + layoutAttr = layout; + if (shape.empty()) { + bool hasDynamicSizes = false; + recordStaticSizes(reinterpret.getMixedSizes(), shape, hasDynamicSizes); + } + if (strides.empty()) { + bool hasDynamicStrides = false; + recordStaticSizes(reinterpret.getMixedStrides(), strides, + hasDynamicStrides); + } + Value base = + resolveTensorViewBase(reinterpret.getSource(), layoutAttr, shape, strides); + return base ? base : value; + } + + if (auto cast = value.getDefiningOp()) { + Value base = + resolveTensorViewBase(cast.getSource(), layoutAttr, shape, strides); + return base ? base : value; + } + + if (auto memrefType = dyn_cast(value.getType())) { + if (shape.empty()) + shape.assign(memrefType.getShape().begin(), memrefType.getShape().end()); + if (strides.empty()) { + int64_t offset = 0; + if (failed(mlir::getStridesAndOffset(memrefType, strides, offset))) + strides.assign(shape.size(), ShapedType::kDynamic); + } + return value; + } + + return {}; +} + +pto::VRegType getVPTOVRegType(MLIRContext *context, Type elementType) { + unsigned bitWidth = 0; + if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + else if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + + if (bitWidth == 0 || 2048 % bitWidth != 0) + return {}; + return pto::VRegType::get(context, 2048 / bitWidth, elementType); +} + +pto::MaskType getVPTOMaskType(MLIRContext *context, StringRef granularity) { + return pto::MaskType::get(context, granularity); +} + +pto::MaskType getVPTOMaskTypeForElementType(MLIRContext *context, + Type elementType) { + unsigned bitWidth = 0; + if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + else if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + + switch (bitWidth) { + case 8: + return getVPTOMaskType(context, "b8"); + case 16: + return getVPTOMaskType(context, "b16"); + case 32: + return getVPTOMaskType(context, "b32"); + default: + return {}; + } +} + +ArrayAttr asI64ArrayAttr(Builder &builder, ArrayRef values) { + SmallVector attrs; + attrs.reserve(values.size()); + for (int64_t value : values) + attrs.push_back(builder.getI64IntegerAttr(value)); + return builder.getArrayAttr(attrs); +} + +void normalizeToPTOGlobalShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &globalShape, + SmallVectorImpl &globalStride) { + constexpr int64_t kRank = 5; + globalShape.assign(kRank, 1); + globalStride.assign(kRank, 1); + + size_t shapeRank = std::min(shape.size(), kRank); + size_t strideRank = std::min(strides.size(), kRank); + size_t rank = std::min(shapeRank, strideRank); + size_t base = kRank - rank; + + for (size_t i = 0; i < rank; ++i) { + globalShape[base + i] = shape[shape.size() - rank + i]; + globalStride[base + i] = strides[strides.size() - rank + i]; + } + + for (int i = static_cast(kRank) - 2; i >= 0; --i) { + if (i >= static_cast(base)) + continue; + if (globalStride[i + 1] == ShapedType::kDynamic || + globalShape[i + 1] == ShapedType::kDynamic) { + globalStride[i] = ShapedType::kDynamic; + continue; + } + globalStride[i] = globalStride[i + 1] * globalShape[i + 1]; + } +} + +int64_t packLoopStrideConfig(int64_t first, int64_t second) { + return (static_cast(first) << 40) | static_cast(second); +} + +int64_t packLoopSizeConfig(int64_t loop2, int64_t loop1) { + return (static_cast(loop2) << 21) | static_cast(loop1); +} + +LogicalResult deriveVecNDTransferConfig(ArrayRef shape, + ArrayRef strides, + StringRef tileLayout, Type elementType, + int64_t validRows, int64_t validCols, + SmallVectorImpl &globalShape, + SmallVectorImpl &globalStride, + int64_t &nBurst, int64_t &lenBurst, + int64_t &gmStrideBytes, + int64_t &ubStrideBytes, + int64_t &loop1Size, + int64_t &loop2Size, + int64_t &loop1FirstStrideBytes, + int64_t &loop1SecondStrideBytes, + int64_t &loop2FirstStrideBytes, + int64_t &loop2SecondStrideBytes) { + if (tileLayout != "row_major") + return failure(); + + int64_t elemBytes = getElementByteSize(elementType); + if (elemBytes <= 0) + return failure(); + + normalizeToPTOGlobalShapeAndStride(shape, strides, globalShape, globalStride); + if (globalShape.size() != 5 || globalStride.size() != 5) + return failure(); + if (llvm::any_of(globalShape, [](int64_t v) { return v == ShapedType::kDynamic; }) || + llvm::any_of(globalStride, [](int64_t v) { return v == ShapedType::kDynamic; })) + return failure(); + nBurst = globalShape[3]; + lenBurst = (validCols == ShapedType::kDynamic) ? ShapedType::kDynamic + : validCols * elemBytes; + gmStrideBytes = globalStride[3] * elemBytes; + ubStrideBytes = globalShape[4] * elemBytes; + + int64_t dstStride2 = globalShape[3] * validCols; + int64_t dstStride1 = globalShape[2] * dstStride2; + + loop2Size = globalShape[1]; + loop1Size = globalShape[2]; + loop2FirstStrideBytes = dstStride1 * elemBytes; + loop2SecondStrideBytes = globalStride[1] * elemBytes; + loop1FirstStrideBytes = dstStride2 * elemBytes; + loop1SecondStrideBytes = globalStride[2] * elemBytes; + return success(); +} + +std::pair getStaticTileRowsCols(Value value) { + if (auto shapedType = dyn_cast(value.getType())) { + ArrayRef shape = shapedType.getShape(); + if (shape.size() >= 2) + return {shape[shape.size() - 2], shape[shape.size() - 1]}; + } + return {ShapedType::kDynamic, ShapedType::kDynamic}; +} + +Value materializeStaticOrDynamicDimAsIndex(Value value, int64_t dim, + unsigned dimPos, + PatternRewriter &rewriter, + Location loc) { + if (dim != ShapedType::kDynamic) + return rewriter.create(loc, dim); + if (isa(value.getType())) + return rewriter.create(loc, value, dimPos); + return {}; +} + +LogicalResult materializeShapeBackedValidShapeValues(Value value, Value &rows, + Value &cols, + PatternRewriter &rewriter, + Location loc) { + rows = {}; + cols = {}; + + auto shapedType = dyn_cast(value.getType()); + if (!shapedType || !shapedType.hasRank() || !hasStructuredTileDriver(value)) + return failure(); + + ArrayRef shape = shapedType.getShape(); + if (shape.empty()) { + rows = rewriter.create(loc, 1); + cols = rewriter.create(loc, 1); + return success(); + } + if (shape.size() == 1) { + rows = rewriter.create(loc, 1); + cols = materializeStaticOrDynamicDimAsIndex(value, shape.front(), 0, rewriter, loc); + return success(cols != nullptr); + } + + cols = materializeStaticOrDynamicDimAsIndex(value, shape.back(), shape.size() - 1, + rewriter, loc); + if (!cols) + return failure(); + + Value flatRows = rewriter.create(loc, 1); + for (auto [idx, dim] : llvm::enumerate(shape.drop_back())) { + Value dimValue = + materializeStaticOrDynamicDimAsIndex(value, dim, idx, rewriter, loc); + if (!dimValue) + return failure(); + flatRows = rewriter.create(loc, flatRows, dimValue); + } + rows = flatRows; + return success(); +} + +LogicalResult resolveExecutionValidShape(Value carrier, Value &rowsValue, + Value &colsValue, int64_t &rows, + int64_t &cols, + PatternRewriter &rewriter, + Location loc) { + rowsValue = materializeIndexValue(rowsValue, rows, rewriter, loc); + colsValue = materializeIndexValue(colsValue, cols, rewriter, loc); + if (rowsValue && colsValue) + return success(); + + if (succeeded(materializeShapeBackedValidShapeValues(carrier, rowsValue, colsValue, + rewriter, loc))) { + deriveValidShape(carrier, rows, cols); + return success(rowsValue && colsValue); + } + return failure(); +} + +Attribute getGmMemorySpace(MLIRContext *context) { + return AddressSpaceAttr::get(context, AddressSpace::GM); +} + +AddressSpaceAttr getNormalizedPtrMemorySpace(Attribute memorySpace, + MLIRContext *context) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return AddressSpaceAttr::get(context, + static_cast(intAttr.getInt())); + return AddressSpaceAttr::get(context, AddressSpace::GM); +} + +Value materializeMemRefView(Value value, ArrayRef shape, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + auto memrefType = + MemRefType::get(shape, elementType, AffineMap(), memorySpace); + if (value.getType() == memrefType) + return value; + return rewriter + .create( + loc, TypeRange(ArrayRef{memrefType}), value) + .getResult(0); +} + +Value materializeTileBufferView(Value value, PatternRewriter &rewriter, + Location loc) { + if (auto memrefType = dyn_cast(value.getType())) + return value; + + auto tileType = dyn_cast(value.getType()); + if (!tileType) + return {}; + + return materializeMemRefView(value, tileType.getShape(), tileType.getElementType(), + tileType.getMemorySpace(), rewriter, loc); +} + +} // namespace + +Value materializeBufferPointer(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + if (!value) + return {}; + + auto ptrMemorySpace = + getNormalizedPtrMemorySpace(memorySpace, rewriter.getContext()); + auto ptrType = PtrType::get(rewriter.getContext(), elementType, ptrMemorySpace); + + if (value.getType() == ptrType) + return value; + + if (auto bind = value.getDefiningOp()) + return materializeBufferPointer(bind.getSource(), elementType, memorySpace, + rewriter, loc); + + if (auto cast = value.getDefiningOp()) { + if (cast.getAddrs().empty()) + return {}; + return rewriter.create(loc, ptrType, cast.getAddrs().front()) + .getResult(); + } + + Value memrefValue = materializeTileBufferView(value, rewriter, loc); + auto memrefType = dyn_cast_or_null(memrefValue.getType()); + if (!memrefValue || !memrefType) + return {}; + return rewriter.create(loc, ptrType, memrefValue).getResult(); +} + +namespace { + +Value materializeBufferLikeAddress(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + if (!value) + return {}; + + if (auto bind = value.getDefiningOp()) + return materializeBufferLikeAddress(bind.getSource(), elementType, memorySpace, + rewriter, loc); + + // Keep memref semantics through the VPTO mainline whenever possible. + Value memrefValue = materializeTileBufferView(value, rewriter, loc); + if (memrefValue && isa(memrefValue.getType())) + return memrefValue; + + return materializeBufferPointer(value, elementType, memorySpace, rewriter, loc); +} + +Value offsetBufferPointer(Value basePtr, Type elementType, Value elementOffset, + PatternRewriter &rewriter, Location loc) { + if (!basePtr) + return {}; + + if (auto ptrType = dyn_cast(basePtr.getType())) { + Value offsetIndex = + elementOffset.getType().isIndex() + ? elementOffset + : rewriter.create(loc, + rewriter.getIndexType(), + elementOffset); + return rewriter.create(loc, ptrType, basePtr, offsetIndex); + } + return {}; +} + +Value buildPackedCountI64(PatternRewriter &rewriter, Location loc, + ArrayRef counts) { + Value packed = rewriter.create(loc, 0, 64); + for (auto [idx, count] : llvm::enumerate(counts)) { + Value countI64 = count.getType().isIndex() + ? rewriter.create( + loc, rewriter.getI64Type(), count) + : count; + if (idx != 0) { + Value shift = rewriter.create(loc, idx * 16, 64); + countI64 = rewriter.create(loc, countI64, shift); + } + packed = rewriter.create(loc, packed, countI64); + } + return packed; +} + +Value buildCeilDivPositiveI64(PatternRewriter &rewriter, Location loc, Value lhs, + int64_t rhs) { + Value rhsValue = rewriter.create(loc, rhs, 64); + Value rhsMinusOne = rewriter.create(loc, rhs - 1, 64); + Value biased = rewriter.create(loc, lhs, rhsMinusOne); + return rewriter.create(loc, biased, rhsValue); +} + +VPTOPartitionTrace extractPartitionTrace(Value value) { + VPTOPartitionTrace trace; + if (auto part = value.getDefiningOp()) { + appendStaticSizes(part.getOffsets(), trace.offsets, trace.hasDynamicOffsets); + appendStaticSizes(part.getSizes(), trace.sizes, trace.hasDynamicSizes); + return trace; + } + if (auto subview = value.getDefiningOp()) { + trace = extractPartitionTrace(subview.getSource()); + SmallVector offsets; + SmallVector sizes; + bool hasDynamicOffsets = false; + bool hasDynamicSizes = false; + recordStaticSizes(subview.getMixedOffsets(), offsets, hasDynamicOffsets); + recordStaticSizes(subview.getMixedSizes(), sizes, hasDynamicSizes); + mergeSubviewTrace(trace, offsets, sizes, hasDynamicOffsets, hasDynamicSizes); + return trace; + } + if (auto reinterpret = value.getDefiningOp()) + return extractPartitionTrace(reinterpret.getSource()); + if (auto cast = value.getDefiningOp()) + return extractPartitionTrace(cast.getSource()); + if (auto unrealized = value.getDefiningOp()) { + if (!unrealized.getInputs().empty()) + return extractPartitionTrace(unrealized.getInputs().front()); + } + return trace; +} + +VPTOLoadContract extractTLoadContract(TLoadOp op) { + VPTOLoadContract contract; + contract.trace = extractPartitionTrace(op.getSrc()); + contract.elementType = getElementType(op.getDst()); + + Attribute layoutAttr; + Value base = resolveTensorViewBase(op.getSrc(), layoutAttr, contract.sourceShape, + contract.sourceStrides); + (void)base; + contract.sourceLayout = stringifyLayoutAttr(layoutAttr); + + contract.tileLayout = deriveTileLayout(op.getDst()); + contract.tileDomain = deriveTileDomain(getMemorySpace(op.getDst())); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + contract.padMode = stringifyPadModeAttr(op.getPadModeAttr()); + contract.padValue = op.getPadValue(); + contract.leftPaddingNum = op.getLeftPaddingNum(); + contract.rightPaddingNum = op.getRightPaddingNum(); + contract.initOutBuffer = op.getInitOutBuffer(); + contract.initCondition = op.getInitCondition(); + return contract; +} + +VPTOUnaryContract extractTAbsContract(TAbsOp op) { + VPTOUnaryContract contract; + contract.family = "abs"; + contract.tileDomain = deriveTileDomain(getMemorySpace(op.getSrc())); + contract.tileLayout = deriveTileLayout(op.getSrc()); + deriveValidShapeValues(op.getSrc(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getSrc(), contract.validRows, contract.validCols); + contract.elementType = getElementType(op.getSrc()); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOBinaryContract buildBinaryContract(StringRef family, Value src0) { + VPTOBinaryContract contract; + contract.family = family; + contract.tileDomain = deriveTileDomain(getMemorySpace(src0)); + contract.tileLayout = deriveTileLayout(src0); + deriveValidShapeValues(src0, contract.validRowsValue, contract.validColsValue); + deriveValidShape(src0, contract.validRows, contract.validCols); + contract.elementType = getElementType(src0); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOBinaryContract extractTAddContract(TAddOp op) { + return buildBinaryContract("add", op.getSrc0()); +} + +VPTOBinaryContract extractTSubContract(TSubOp op) { + return buildBinaryContract("sub", op.getSrc0()); +} + +VPTOBinaryContract extractTMulContract(TMulOp op) { + return buildBinaryContract("mul", op.getSrc0()); +} + +VPTOBinaryContract extractTDivContract(TDivOp op) { + return buildBinaryContract("div", op.getSrc0()); +} + +VPTOUnaryContract buildUnaryContract(StringRef family, Value src) { + VPTOUnaryContract contract; + contract.family = family; + contract.tileDomain = deriveTileDomain(getMemorySpace(src)); + contract.tileLayout = deriveTileLayout(src); + deriveValidShapeValues(src, contract.validRowsValue, contract.validColsValue); + deriveValidShape(src, contract.validRows, contract.validCols); + contract.elementType = getElementType(src); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOUnaryContract extractTExpContract(TExpOp op) { + return buildUnaryContract("exp", op.getSrc()); +} + +VPTOUnaryContract extractTLogContract(TLogOp op) { + return buildUnaryContract("log", op.getSrc()); +} + +VPTOUnaryContract extractTSqrtContract(TSqrtOp op) { + return buildUnaryContract("sqrt", op.getSrc()); +} + +VPTOUnaryContract extractTRecipContract(TRecipOp op) { + return buildUnaryContract("recip", op.getSrc()); +} + +VPTOUnaryContract extractTReluContract(TReluOp op) { + return buildUnaryContract("relu", op.getSrc()); +} + +VPTOUnaryContract extractTNotContract(TNotOp op) { + return buildUnaryContract("not", op.getSrc()); +} + +static FailureOr stringifyA5RoundMode(TCvtOp op, + PatternRewriter &rewriter) { + switch (op.getRmode()) { + case RoundMode::NONE: + case RoundMode::RINT: + case RoundMode::CAST_RINT: + return rewriter.getStringAttr("ROUND_R"); + case RoundMode::ROUND: + return rewriter.getStringAttr("ROUND_A"); + case RoundMode::FLOOR: + return rewriter.getStringAttr("ROUND_F"); + case RoundMode::CEIL: + return rewriter.getStringAttr("ROUND_C"); + case RoundMode::TRUNC: + return rewriter.getStringAttr("ROUND_Z"); + case RoundMode::ODD: + return rewriter.getStringAttr("ROUND_O"); + } + return failure(); +} + +enum class VPTOCvtLoweringKind { + Vtrc, + F32ToBF16, + F16ToF32, + BF16ToF32, +}; + +static FailureOr classifyA5CvtLowering(Type srcElemType, + Type dstElemType) { + if (srcElemType.isF32() && dstElemType.isF32()) + return VPTOCvtLoweringKind::Vtrc; + if (srcElemType.isF32() && dstElemType.isBF16()) + return VPTOCvtLoweringKind::F32ToBF16; + if (srcElemType.isF16() && dstElemType.isF32()) + return VPTOCvtLoweringKind::F16ToF32; + if (srcElemType.isBF16() && dstElemType.isF32()) + return VPTOCvtLoweringKind::BF16ToF32; + return failure(); +} + +VPTOUnaryContract extractTExpandSContract(TExpandsOp op) { + VPTOUnaryContract contract; + contract.family = "expands"; + contract.tileDomain = deriveTileDomain(getMemorySpace(op.getDst())); + contract.tileLayout = deriveTileLayout(op.getDst()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, + contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + contract.elementType = getElementType(op.getDst()); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOExpandContract extractTRowExpandContract(TRowExpandOp op) { + VPTOExpandContract contract; + contract.family = "rowexpand"; + contract.srcDomain = deriveTileDomain(getMemorySpace(op.getSrc())); + contract.dstDomain = deriveTileDomain(getMemorySpace(op.getDst())); + contract.srcLayout = deriveTileLayout(op.getSrc()); + contract.dstLayout = deriveTileLayout(op.getDst()); + contract.elementType = getElementType(op.getSrc()); + deriveValidShapeValues(op.getSrc(), contract.srcValidRowsValue, + contract.srcValidColsValue); + deriveValidShape(op.getSrc(), contract.srcValidRows, contract.srcValidCols); + deriveValidShapeValues(op.getDst(), contract.dstValidRowsValue, + contract.dstValidColsValue); + deriveValidShape(op.getDst(), contract.dstValidRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOExpandContract extractTColExpandContract(TColExpandOp op) { + VPTOExpandContract contract; + contract.family = "colexpand"; + contract.srcDomain = deriveTileDomain(getMemorySpace(op.getSrc())); + contract.dstDomain = deriveTileDomain(getMemorySpace(op.getDst())); + contract.srcLayout = deriveTileLayout(op.getSrc()); + contract.dstLayout = deriveTileLayout(op.getDst()); + contract.elementType = getElementType(op.getSrc()); + deriveValidShapeValues(op.getSrc(), contract.srcValidRowsValue, + contract.srcValidColsValue); + deriveValidShape(op.getSrc(), contract.srcValidRows, contract.srcValidCols); + deriveValidShapeValues(op.getDst(), contract.dstValidRowsValue, + contract.dstValidColsValue); + deriveValidShape(op.getDst(), contract.dstValidRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTORowReduceContract extractTRowReduceContract(Value src, Value dst, + StringRef family) { + VPTORowReduceContract contract; + contract.family = family; + contract.srcDomain = deriveTileDomain(getMemorySpace(src)); + contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); + contract.srcLayout = deriveTileLayout(src); + contract.dstLayout = deriveTileLayout(dst); + contract.elementType = getElementType(src); + deriveValidShapeValues(src, contract.validRowsValue, contract.validColsValue); + deriveValidShape(src, contract.validRows, contract.validCols); + int64_t dstRows = ShapedType::kDynamic; + deriveValidShape(dst, dstRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTORowReduceContract extractTRowMaxContract(TRowMaxOp op) { + return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowmax"); +} + +VPTORowReduceContract extractTRowMinContract(TRowMinOp op) { + return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowmin"); +} + +VPTORowReduceContract extractTRowSumContract(TRowSumOp op) { + return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowsum"); +} + +VPTOColReduceContract extractTColReduceContract(Value src, Value dst, + StringRef family) { + VPTOColReduceContract contract; + contract.family = family; + contract.srcDomain = deriveTileDomain(getMemorySpace(src)); + contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); + contract.srcLayout = deriveTileLayout(src); + contract.dstLayout = deriveTileLayout(dst); + contract.elementType = getElementType(src); + deriveValidShapeValues(src, contract.validRowsValue, contract.validColsValue); + deriveValidShape(src, contract.validRows, contract.validCols); + deriveValidShape(dst, contract.dstValidRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOColReduceContract extractTColMaxContract(TColMaxOp op) { + return extractTColReduceContract(op.getSrc(), op.getDst(), "colmax"); +} + +VPTOColReduceContract extractTColMinContract(TColMinOp op) { + return extractTColReduceContract(op.getSrc(), op.getDst(), "colmin"); +} + +VPTOColReduceContract extractTColSumContract(TColSumOp op) { + VPTOColReduceContract contract = + extractTColReduceContract(op.getSrc(), op.getDst(), "colsum"); + contract.isBinary = op.getIsBinary(); + contract.tmp = op.getTmp(); + return contract; +} + +VPTOPartContract extractTPartContract(Value src0, Value src1, Value dst, + StringRef family) { + VPTOPartContract contract; + contract.family = family; + contract.src0Domain = deriveTileDomain(getMemorySpace(src0)); + contract.src1Domain = deriveTileDomain(getMemorySpace(src1)); + contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); + contract.src0Layout = deriveTileLayout(src0); + contract.src1Layout = deriveTileLayout(src1); + contract.dstLayout = deriveTileLayout(dst); + contract.elementType = getElementType(dst); + deriveValidShapeValues(src0, contract.src0ValidRowsValue, contract.src0ValidColsValue); + deriveValidShapeValues(src1, contract.src1ValidRowsValue, contract.src1ValidColsValue); + deriveValidShapeValues(dst, contract.dstValidRowsValue, contract.dstValidColsValue); + deriveValidShape(src0, contract.src0ValidRows, contract.src0ValidCols); + deriveValidShape(src1, contract.src1ValidRows, contract.src1ValidCols); + deriveValidShape(dst, contract.dstValidRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOPartContract extractTPartAddContract(TPartAddOp op) { + return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partadd"); +} + +VPTOPartContract extractTPartMaxContract(TPartMaxOp op) { + return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partmax"); +} + +VPTOPartContract extractTPartMinContract(TPartMinOp op) { + return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partmin"); +} + +VPTOStoreContract extractTStoreContract(TStoreOp op) { + VPTOStoreContract contract; + contract.trace = extractPartitionTrace(op.getDst()); + + contract.srcDomain = deriveTileDomain(getMemorySpace(op.getSrc())); + deriveValidShapeValues(op.getSrc(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getSrc(), contract.validRows, contract.validCols); + contract.elementType = getElementType(op.getSrc()); + + Attribute layoutAttr; + Value base = resolveTensorViewBase(op.getDst(), layoutAttr, + contract.destinationShape, + contract.destinationStrides); + (void)base; + contract.destinationLayout = stringifyLayoutAttr(layoutAttr); + return contract; +} + +void attachLoadContractAttrs(Operation *op, const VPTOLoadContract &contract) { + Builder builder(op->getContext()); + SmallVector globalShape; + SmallVector globalStride; + normalizeToPTOGlobalShapeAndStride(contract.sourceShape, contract.sourceStrides, + globalShape, globalStride); + op->setAttr("g_shape", asI64ArrayAttr(builder, globalShape)); + op->setAttr("g_strides", asI64ArrayAttr(builder, globalStride)); +} + +void attachStoreContractAttrs(Operation *op, const VPTOStoreContract &contract) { + Builder builder(op->getContext()); + SmallVector globalShape; + SmallVector globalStride; + normalizeToPTOGlobalShapeAndStride(contract.destinationShape, + contract.destinationStrides, globalShape, + globalStride); + op->setAttr("g_shape", asI64ArrayAttr(builder, globalShape)); + op->setAttr("g_strides", asI64ArrayAttr(builder, globalStride)); +} + +LogicalResult lowerUnsupportedAccStore(Location loc) { + emitError(loc) << "TSTORE ACC lowering TODO for vpto backend"; + return failure(); +} + +LogicalResult lowerUnsupportedMatStore(Location loc) { + emitError(loc) << "TSTORE MAT lowering TODO for vpto backend"; + return failure(); +} + +} // namespace + +FailureOr +createLoopScopeRegion(Location loc, const VPTOLoopScopeContract &contract, + PatternRewriter &rewriter) { + if (contract.kind == VPTOLoopScopeKind::None) + return failure(); + if (contract.kind != VPTOLoopScopeKind::AIVVectorScope) + return failure(); + + auto vecScope = rewriter.create(loc); + vecScope.getBody().push_back(new Block()); + return vecScope; +} + +void set_loop2_stride_outtoub(Operation *copyOp, int64_t dstStride, + int64_t srcStride, Builder &builder) { + copyOp->setAttr("pto.set_loop2_stride_outtoub", + builder.getI64IntegerAttr( + packLoopStrideConfig(dstStride, srcStride))); +} + +void set_loop1_stride_outtoub(Operation *copyOp, int64_t dstStride, + int64_t srcStride, Builder &builder) { + copyOp->setAttr("pto.set_loop1_stride_outtoub", + builder.getI64IntegerAttr( + packLoopStrideConfig(dstStride, srcStride))); +} + +void set_loop_size_outtoub(Operation *copyOp, int64_t loop2, int64_t loop1, + Builder &builder) { + copyOp->setAttr("pto.set_loop_size_outtoub", + builder.getI64IntegerAttr(packLoopSizeConfig(loop2, loop1))); +} + +void set_loop2_stride_ubtoout(Operation *copyOp, int64_t srcStride, + int64_t dstStride, Builder &builder) { + copyOp->setAttr("pto.set_loop2_stride_ubtoout", + builder.getI64IntegerAttr( + packLoopStrideConfig(srcStride, dstStride))); +} + +void set_loop1_stride_ubtoout(Operation *copyOp, int64_t srcStride, + int64_t dstStride, Builder &builder) { + copyOp->setAttr("pto.set_loop1_stride_ubtoout", + builder.getI64IntegerAttr( + packLoopStrideConfig(srcStride, dstStride))); +} + +void set_loop_size_ubtoout(Operation *copyOp, int64_t loop2, int64_t loop1, + Builder &builder) { + copyOp->setAttr("pto.set_loop_size_ubtoout", + builder.getI64IntegerAttr(packLoopSizeConfig(loop2, loop1))); +} + +LogicalResult programCopyGmToUbLoops(Operation *copyOp, + const VPTOLoadContract &contract, + Builder &builder) { + SmallVector globalShape; + SmallVector globalStride; + int64_t nBurst = 0, lenBurst = 0, gmStrideBytes = 0, ubStrideBytes = 0; + int64_t loop1Size = 0, loop2Size = 0; + int64_t loop1DstStrideBytes = 0, loop1SrcStrideBytes = 0; + int64_t loop2DstStrideBytes = 0, loop2SrcStrideBytes = 0; + if (failed(deriveVecNDTransferConfig(contract.sourceShape, contract.sourceStrides, + contract.tileLayout, contract.elementType, + contract.validRows, contract.validCols, + globalShape, globalStride, nBurst, lenBurst, + gmStrideBytes, ubStrideBytes, loop1Size, + loop2Size, loop1DstStrideBytes, + loop1SrcStrideBytes, loop2DstStrideBytes, + loop2SrcStrideBytes))) + return failure(); + + set_loop2_stride_outtoub(copyOp, loop2DstStrideBytes, loop2SrcStrideBytes, builder); + set_loop1_stride_outtoub(copyOp, loop1DstStrideBytes, loop1SrcStrideBytes, builder); + set_loop_size_outtoub(copyOp, loop2Size, loop1Size, builder); + return success(); +} + +LogicalResult programCopyUbToGmLoops(Operation *copyOp, + const VPTOStoreContract &contract, + Builder &builder) { + SmallVector globalShape; + SmallVector globalStride; + int64_t nBurst = 0, lenBurst = 0, burstDstStrideBytes = 0, burstSrcStrideBytes = 0; + int64_t loop1Size = 0, loop2Size = 0; + int64_t loop1SrcStrideBytes = 0, loop1DstStrideBytes = 0; + int64_t loop2SrcStrideBytes = 0, loop2DstStrideBytes = 0; + if (failed(deriveVecNDTransferConfig(contract.destinationShape, + contract.destinationStrides, + "row_major", contract.elementType, + contract.validRows, contract.validCols, + globalShape, globalStride, nBurst, lenBurst, + burstDstStrideBytes, burstSrcStrideBytes, + loop1Size, loop2Size, loop1SrcStrideBytes, + loop1DstStrideBytes, loop2SrcStrideBytes, + loop2DstStrideBytes))) + return failure(); + + set_loop_size_ubtoout(copyOp, loop2Size, loop1Size, builder); + set_loop1_stride_ubtoout(copyOp, loop1SrcStrideBytes, loop1DstStrideBytes, builder); + set_loop2_stride_ubtoout(copyOp, loop2SrcStrideBytes, loop2DstStrideBytes, builder); + return success(); +} + +int64_t deriveStaticRowStride(Value value) { + StringRef layout = deriveTileLayout(value); + if (layout == "col_major") + return 1; + + if (auto tileType = dyn_cast(value.getType())) { + ArrayRef shape = tileType.getShape(); + if (shape.size() >= 2) + return shape[shape.size() - 1]; + } + if (auto shapedType = dyn_cast(value.getType())) { + ArrayRef shape = shapedType.getShape(); + if (shape.size() >= 2) + return shape[shape.size() - 1]; + } + return ShapedType::kDynamic; +} + +int64_t deriveStaticShapeDim(Value value, unsigned dim) { + if (auto tileType = dyn_cast(value.getType())) { + ArrayRef shape = tileType.getShape(); + if (dim < shape.size()) + return shape[dim]; + } + if (auto shapedType = dyn_cast(value.getType())) { + ArrayRef shape = shapedType.getShape(); + if (dim < shape.size()) + return shape[dim]; + } + return ShapedType::kDynamic; +} + +int64_t deriveStaticTileCols(Value value) { + if (auto tileType = dyn_cast(value.getType())) { + ArrayRef shape = tileType.getShape(); + if (!shape.empty()) + return shape.back(); + } + if (auto shapedType = dyn_cast(value.getType())) { + ArrayRef shape = shapedType.getShape(); + if (!shape.empty()) + return shape.back(); + } + return ShapedType::kDynamic; +} + +Value buildFullWidthColsCondition(ArrayRef tileCols, + Value validColsValue, + PatternRewriter &rewriter, Location loc) { + Value condition; + for (int64_t tileCol : tileCols) { + if (tileCol == ShapedType::kDynamic) + return {}; + Value tileColValue = rewriter.create(loc, tileCol); + Value isFullWidth = rewriter.create( + loc, arith::CmpIPredicate::eq, validColsValue, tileColValue); + condition = condition ? rewriter.create(loc, condition, isFullWidth) + : isFullWidth; + } + return condition; +} + +Value buildMinIndexValue(PatternRewriter &rewriter, Location loc, Value lhs, + Value rhs) { + auto lhsLtRhs = rewriter.create(loc, arith::CmpIPredicate::slt, + lhs, rhs); + return rewriter.create(loc, lhsLtRhs, lhs, rhs); +} + +struct PredicateMaterialization { + Value mask; + Value nextScalar; +}; + +PredicateMaterialization buildPredicateForLaneCount(PatternRewriter &rewriter, + Location loc, + Type elementType, + Value laneCount) { + auto maskType = getVPTOMaskTypeForElementType(rewriter.getContext(), elementType); + Value laneCountI32 = laneCount; + if (laneCount.getType().isIndex()) { + laneCountI32 = + rewriter.create(loc, rewriter.getI32Type(), laneCount); + } else if (auto intType = dyn_cast(laneCount.getType())) { + if (intType.getWidth() < 32) + laneCountI32 = rewriter.create(loc, rewriter.getI32Type(), laneCount); + else if (intType.getWidth() > 32) + laneCountI32 = + rewriter.create(loc, rewriter.getI32Type(), laneCount); + } + unsigned bitWidth = 0; + if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + else if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + if (bitWidth == 8) { + auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), + laneCountI32); + return {plt.getMask(), plt.getScalarOut()}; + } + if (bitWidth == 16) { + auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), + laneCountI32); + return {plt.getMask(), plt.getScalarOut()}; + } + if (bitWidth == 32) { + auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), + laneCountI32); + return {plt.getMask(), plt.getScalarOut()}; + } + llvm_unreachable("unsupported element type for predicate lane-count lowering"); +} + +Value buildPredicateMaskForLaneCount(PatternRewriter &rewriter, Location loc, + Type elementType, Value laneCount) { + return buildPredicateForLaneCount(rewriter, loc, elementType, laneCount).mask; +} + +Value buildAllPredicateMask(PatternRewriter &rewriter, Location loc, + Type elementType) { + auto maskType = getVPTOMaskTypeForElementType(rewriter.getContext(), elementType); + StringAttr allPattern = rewriter.getStringAttr("PAT_ALL"); + unsigned bitWidth = 0; + if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + else if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + if (bitWidth == 8) + return rewriter.create(loc, maskType, allPattern).getResult(); + if (bitWidth == 16) + return rewriter.create(loc, maskType, allPattern).getResult(); + if (bitWidth == 32) + return rewriter.create(loc, maskType, allPattern).getResult(); + llvm_unreachable("unsupported element type for full predicate mask lowering"); +} + +LogicalResult buildMaskedVectorStore(PatternRewriter &rewriter, Location loc, + Value value, Value dstBuffer, + Value dstOffset, Value activeLanes, + int64_t vectorWidth) { + auto vecType = cast(value.getType()); + Value mask = buildPredicateMaskForLaneCount(rewriter, loc, + vecType.getElementType(), + activeLanes); + rewriter.create(loc, value, dstBuffer, dstOffset, StringAttr(), + mask); + return success(); +} + +Attribute buildRowReduceInitValue(Type elementType, StringRef family, + Builder &builder) { + if (!isa(elementType)) + return {}; + + if (family == "rowsum") + return builder.getFloatAttr(elementType, 0.0); + + const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { + if (elementType.isF16()) + return llvm::APFloat::IEEEhalf(); + if (elementType.isBF16()) + return llvm::APFloat::BFloat(); + return llvm::APFloat::IEEEsingle(); + }(); + bool negative = family == "rowmax"; + return builder.getFloatAttr(elementType, llvm::APFloat::getInf(semantics, negative)); +} + +Attribute buildPartPadValue(Type elementType, StringRef family, Builder &builder) { + if (family == "partadd") + return builder.getZeroAttr(elementType); + if (isa(elementType)) { + const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { + if (elementType.isF16()) + return llvm::APFloat::IEEEhalf(); + if (elementType.isBF16()) + return llvm::APFloat::BFloat(); + return llvm::APFloat::IEEEsingle(); + }(); + bool negative = family == "partmax"; + return builder.getFloatAttr(elementType, llvm::APFloat::getInf(semantics, negative)); + } + if (auto intType = dyn_cast(elementType)) { + unsigned width = intType.getWidth(); + if (intType.isUnsigned()) { + if (family == "partmax") + return builder.getIntegerAttr(elementType, 0); + return builder.getIntegerAttr(elementType, llvm::APInt::getAllOnes(width)); + } + if (family == "partmax") + return builder.getIntegerAttr(elementType, llvm::APInt::getSignedMinValue(width)); + return builder.getIntegerAttr(elementType, llvm::APInt::getSignedMaxValue(width)); + } + return {}; +} + +Attribute buildFillPadValue(Type elementType, PadValueAttr padAttr, Builder &builder) { + if (!padAttr) + return {}; + + switch (padAttr.getValue()) { + case PadValue::Null: + return {}; + case PadValue::Zero: + return builder.getZeroAttr(elementType); + case PadValue::Max: + if (isa(elementType)) { + const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { + if (elementType.isF16()) + return llvm::APFloat::IEEEhalf(); + if (elementType.isBF16()) + return llvm::APFloat::BFloat(); + return llvm::APFloat::IEEEsingle(); + }(); + return builder.getFloatAttr(elementType, + llvm::APFloat::getLargest(semantics)); + } + if (auto intType = dyn_cast(elementType)) { + unsigned width = intType.getWidth(); + return intType.isUnsigned() + ? builder.getIntegerAttr(elementType, + llvm::APInt::getMaxValue(width)) + : builder.getIntegerAttr(elementType, + llvm::APInt::getSignedMaxValue(width)); + } + return {}; + case PadValue::Min: + if (isa(elementType)) { + const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { + if (elementType.isF16()) + return llvm::APFloat::IEEEhalf(); + if (elementType.isBF16()) + return llvm::APFloat::BFloat(); + return llvm::APFloat::IEEEsingle(); + }(); + auto value = llvm::APFloat::getLargest(semantics); + value.changeSign(); + return builder.getFloatAttr(elementType, value); + } + if (auto intType = dyn_cast(elementType)) { + unsigned width = intType.getWidth(); + return intType.isUnsigned() + ? builder.getIntegerAttr(elementType, llvm::APInt(width, 0)) + : builder.getIntegerAttr(elementType, + llvm::APInt::getSignedMinValue(width)); + } + return {}; + } + return {}; +} + +LogicalResult buildRowReduceVecScope(StringRef family, + const VPTORowReduceContract &contract, + VPTOLoweringStrategy strategy, Value src, + Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO row-reduce element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for row-reduce lowering"; + + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return emitError(loc) << family << " lowering currently requires static valid rows and cols"; + + int64_t srcRowStride = deriveStaticRowStride(src); + int64_t dstRowStride = deriveStaticRowStride(dst); + if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic) + return emitError(loc) << family << " lowering requires static row strides"; + + Attribute initValue = buildRowReduceInitValue(contract.elementType, family, rewriter); + if (!initValue) + return emitError(loc) << family << " lowering supports only f16 and f32 element types"; + + auto getRowReduceStoreDist = [&]() -> StringAttr { + if (contract.elementType.isF16() || contract.elementType.isBF16()) + return rewriter.getStringAttr("1PT"); + if (contract.elementType.isF32()) + return rewriter.getStringAttr("1PT"); + return {}; + }; + StringAttr storeDist = getRowReduceStoreDist(); + if (!storeDist) + return emitError(loc) << family << " lowering supports only f16 and f32 row-reduce stores"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(contract.validCols, vectorWidth); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value rowsUpper = rewriter.create(loc, contract.validRows); + Value srcRowStrideValue = rewriter.create(loc, srcRowStride); + Value dstRowStrideValue = rewriter.create(loc, dstRowStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + Value initScalar = rewriter.create(loc, cast(initValue)); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + Value dstPredicate = + buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, c1); + Value validColsValue = + rewriter.create(loc, contract.validCols); + + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto rowLoop = + rewriter.create(loc, c0, rowsUpper, c1, ValueRange{dstBuffer}); + + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value dstPtr = rowLoop.getRegionIterArgs().front(); + Value rowBase = rewriter.create(loc, row, srcRowStrideValue); + Value srcPtr = + adjustPointerByElemOffset(srcBuffer, rowBase, getElementByteSize(contract.elementType), + rewriter, loc); + Value acc = rewriter.create(loc, vecType, initScalar); + Value remainingCols = rewriter.create( + loc, contract.validCols, 32); + for (int64_t repeatIndex = 0; repeatIndex < repeatTimes; ++repeatIndex) { + auto predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remainingCols); + Value srcPredicate = predicateState.mask; + auto srcVecOp = rewriter.create( + loc, TypeRange{vecType, srcPtr.getType()}, srcPtr, vectorWidthValue, + rewriter.getStringAttr("NORM")); + Value srcVec = srcVecOp.getResult(); + srcPtr = srcVecOp.getUpdatedSource(); + + Value reduced; + if (family == "rowsum") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else if (family == "rowmax") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else if (family == "rowmin") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else + return emitError(loc) << "unsupported VPTO row-reduce family: " << family; + + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + if (family == "rowsum") + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + else if (family == "rowmax") + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + else + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + remainingCols = predicateState.nextScalar; + } + + auto storeOp = rewriter.create(loc, dstPtr.getType(), acc, dstPtr, + dstRowStrideValue, storeDist, + dstPredicate); + Value nextDstPtr = storeOp.getUpdatedDestination(); + rewriter.create(loc, nextDstPtr); + return success(); + } + + auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value rowBase = rewriter.create(loc, row, srcRowStrideValue); + Value acc = rewriter.create(loc, vecType, initScalar); + for (int64_t repeatIndex = 0; repeatIndex < repeatTimes; ++repeatIndex) { + Value repeat = rewriter.create(loc, repeatIndex); + Value repeatBase = + rewriter.create(loc, repeat, vectorWidthValue); + Value srcOffset = + rewriter.create(loc, rowBase, repeatBase); + Value remainingCols = + rewriter.create(loc, validColsValue, repeatBase); + Value srcPredicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, remainingCols); + Value srcVec = + rewriter.create(loc, vecType, srcBuffer, srcOffset, + StringAttr()) + .getResult(); + + Value reduced; + if (family == "rowsum") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else if (family == "rowmax") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else if (family == "rowmin") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else + return emitError(loc) << "unsupported VPTO row-reduce family: " << family; + + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + if (family == "rowsum") + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + else if (family == "rowmax") + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + else + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + } + + Value dstOffset = rewriter.create(loc, row, dstRowStrideValue); + rewriter.create(loc, acc, dstBuffer, dstOffset, storeDist, + dstPredicate); + return success(); +} + +LogicalResult buildColReduceVecScope(StringRef family, + const VPTOColReduceContract &contract, + Value src, Value dst, Value tmp, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO col-reduce element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for col-reduce lowering"; + + Value tmpBuffer; + if (contract.isBinary) { + tmpBuffer = materializeBufferPointer(tmp, contract.elementType, getMemorySpace(tmp), + rewriter, loc); + if (!tmpBuffer) + return emitError(loc) << "binary colsum lowering requires pointer-backed tmp tile"; + } + + int64_t srcRowStride = deriveStaticRowStride(src); + int64_t dstRowStride = deriveStaticRowStride(dst); + int64_t tmpRowStride = + contract.isBinary ? deriveStaticRowStride(tmp) : ShapedType::kDynamic; + if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic || + (contract.isBinary && tmpRowStride == ShapedType::kDynamic)) + return emitError(loc) << family << " lowering requires static row strides"; + + Attribute initValue = buildRowReduceInitValue(contract.elementType, family, rewriter); + if (!initValue) + return emitError(loc) << family << " lowering supports only f16 and f32 element types"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(contract.validCols, vectorWidth); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value repeatUpper = rewriter.create(loc, repeatTimes); + Value rowUpper = rewriter.create(loc, contract.validRows); + Value srcRowStrideValue = rewriter.create(loc, srcRowStride); + Value dstRowStrideValue = rewriter.create(loc, dstRowStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + Value initScalar = rewriter.create(loc, cast(initValue)); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); + + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value chunk = chunkLoop.getInductionVar(); + Value chunkOffset = rewriter.create(loc, chunk, vectorWidthValue); + + if (!contract.isBinary) { + Value firstRowOffset = chunkOffset; + Value acc0 = + rewriter.create(loc, vecType, srcBuffer, firstRowOffset, StringAttr()).getResult(); + auto rowLoop = rewriter.create(loc, c1, rowUpper, c1, ValueRange{acc0}); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value acc = rowLoop.getRegionIterArgs().front(); + Value rowBase = rewriter.create(loc, row, srcRowStrideValue); + Value srcOffset = rewriter.create(loc, rowBase, chunkOffset); + Value srcVec = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()).getResult(); + Value nextAcc; + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + if (family == "colmax") + nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); + else if (family == "colmin") + nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); + else + nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); + rewriter.create(loc, nextAcc); + + rewriter.setInsertionPointAfter(rowLoop); + Value dstOffset = chunkOffset; + rewriter.create( + loc, rowLoop.getResult(0), dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, contract.elementType)); + return success(); + } + + Value tmpRowStrideValue = rewriter.create(loc, tmpRowStride); + auto reducePair = [&](Value lhs, Value rhs) -> Value { + return rewriter.create( + loc, vecType, lhs, rhs, buildAllPredicateMask(rewriter, loc, contract.elementType)) + .getResult(); + }; + + int64_t nLoopStatic = contract.validRows / 2; + bool remainStatic = (contract.validRows % 2) != 0; + Value pairUpper = rewriter.create(loc, nLoopStatic); + auto pairLoop = rewriter.create(loc, c0, pairUpper, c1); + { + OpBuilder::InsertionGuard pairGuard(rewriter); + rewriter.setInsertionPointToStart(pairLoop.getBody()); + Value pair = pairLoop.getInductionVar(); + Value row0 = rewriter.create( + loc, rewriter.create(loc, pair, rewriter.create(loc, 2)), + srcRowStrideValue); + Value row1 = rewriter.create( + loc, rewriter.create(loc, + rewriter.create(loc, pair, rewriter.create(loc, 2)), + c1), + srcRowStrideValue); + Value src0Offset = rewriter.create(loc, row0, chunkOffset); + Value src1Offset = rewriter.create(loc, row1, chunkOffset); + Value lhs = rewriter.create(loc, vecType, srcBuffer, src0Offset, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, srcBuffer, src1Offset, StringAttr()).getResult(); + Value sum = reducePair(lhs, rhs); + Value tmpOffset = rewriter.create(loc, pair, tmpRowStrideValue); + rewriter.create(loc, sum, tmpBuffer, tmpOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + } + + if (remainStatic && nLoopStatic > 0) { + Value lastRowOffset = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, contract.validRows - 1), + srcRowStrideValue), + chunkOffset); + Value tmpOffset = rewriter.create( + loc, rewriter.create(loc, nLoopStatic - 1), tmpRowStrideValue); + Value lhs = rewriter.create(loc, vecType, srcBuffer, lastRowOffset, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, tmpBuffer, tmpOffset, StringAttr()).getResult(); + Value sum = reducePair(lhs, rhs); + rewriter.create(loc, sum, tmpBuffer, tmpOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + } + + int64_t currentRows = nLoopStatic; + while (currentRows > 1) { + int64_t nextRows = currentRows / 2; + bool remain = (currentRows % 2) != 0; + Value nextUpper = rewriter.create(loc, nextRows); + auto foldLoop = rewriter.create(loc, c0, nextUpper, c1); + OpBuilder::InsertionGuard foldGuard(rewriter); + rewriter.setInsertionPointToStart(foldLoop.getBody()); + Value pair = foldLoop.getInductionVar(); + Value idx2 = rewriter.create( + loc, pair, rewriter.create(loc, 2)); + Value idx2p1 = rewriter.create(loc, idx2, c1); + Value lhsOff = rewriter.create(loc, idx2, tmpRowStrideValue); + Value rhsOff = rewriter.create(loc, idx2p1, tmpRowStrideValue); + Value lhs = rewriter.create(loc, vecType, tmpBuffer, lhsOff, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, tmpBuffer, rhsOff, StringAttr()).getResult(); + Value sum = reducePair(lhs, rhs); + Value outOff = rewriter.create(loc, pair, tmpRowStrideValue); + rewriter.create(loc, sum, tmpBuffer, outOff, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + + rewriter.setInsertionPointAfter(foldLoop); + if (remain && nextRows > 0) { + Value lhsOff = rewriter.create( + loc, rewriter.create(loc, nextRows - 1), tmpRowStrideValue); + Value rhsOff = rewriter.create( + loc, rewriter.create(loc, 2 * nextRows), tmpRowStrideValue); + Value lhs = rewriter.create(loc, vecType, tmpBuffer, lhsOff, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, tmpBuffer, rhsOff, StringAttr()).getResult(); + Value sum = reducePair(lhs, rhs); + rewriter.create(loc, sum, tmpBuffer, lhsOff, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + } + currentRows = nextRows; + } + + Value finalVec; + if (currentRows == 0) { + finalVec = rewriter.create(loc, vecType, initScalar).getResult(); + } else { + finalVec = rewriter.create(loc, vecType, tmpBuffer, c0, StringAttr()).getResult(); + } + Value dstOffset = chunkOffset; + rewriter.create(loc, finalVec, dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + return success(); +} + +LogicalResult buildPartFill(StringRef family, const VPTOPartContract &contract, + Value dstBuffer, pto::VRegType vecType, + int64_t dstStride, PatternRewriter &rewriter, + Location loc) { + Attribute initValue = buildPartPadValue(contract.elementType, family, rewriter); + if (!initValue) + return emitError(loc) << "unsupported pad value for " << family; + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(contract.dstValidCols, vectorWidth); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value rowsUpper = rewriter.create(loc, contract.dstValidRows); + Value repeatUpper = rewriter.create(loc, repeatTimes); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + Value initScalar = rewriter.create(loc, cast(initValue)); + Value initVec = rewriter.create(loc, vecType, initScalar); + auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value chunk = chunkLoop.getInductionVar(); + Value rowBase = rewriter.create(loc, row, dstStrideValue); + Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); + Value dstOffset = rewriter.create(loc, rowBase, chunkBase); + rewriter.create(loc, initVec, dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + vecType.getElementType())); + rewriter.setInsertionPointAfter(chunkLoop); + return success(); +} + +LogicalResult buildPartCopyRegion(Value srcBuffer, Value dstBuffer, pto::VRegType vecType, + int64_t srcStride, int64_t dstStride, + int64_t startRow, int64_t validRows, + int64_t validCols, PatternRewriter &rewriter, + Location loc) { + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(validCols, vectorWidth); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value rowsUpper = rewriter.create(loc, validRows); + Value repeatUpper = rewriter.create(loc, repeatTimes); + Value srcStrideValue = rewriter.create(loc, srcStride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + Value startRowValue = rewriter.create(loc, startRow); + auto rowLoop = rewriter.create(loc, startRowValue, rowsUpper, c1); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value chunk = chunkLoop.getInductionVar(); + Value rowSrc = rewriter.create(loc, row, srcStrideValue); + Value rowDst = rewriter.create(loc, row, dstStrideValue); + Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); + Value srcOffset = rewriter.create(loc, rowSrc, chunkBase); + Value dstOffset = rewriter.create(loc, rowDst, chunkBase); + Value vec = rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()).getResult(); + rewriter.create(loc, vec, dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + vecType.getElementType())); + rewriter.setInsertionPointAfter(chunkLoop); + return success(); +} + +LogicalResult buildPartBinaryRegion(StringRef family, Value src0Buffer, Value src1Buffer, + Value dstBuffer, pto::VRegType vecType, + int64_t src0Stride, int64_t src1Stride, + int64_t dstStride, int64_t validRows, + int64_t validCols, PatternRewriter &rewriter, + Location loc) { + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(validCols, vectorWidth); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value rowsUpper = rewriter.create(loc, validRows); + Value repeatUpper = rewriter.create(loc, repeatTimes); + Value src0StrideValue = rewriter.create(loc, src0Stride); + Value src1StrideValue = rewriter.create(loc, src1Stride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value chunk = chunkLoop.getInductionVar(); + Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); + Value rowSrc0 = rewriter.create(loc, row, src0StrideValue); + Value rowSrc1 = rewriter.create(loc, row, src1StrideValue); + Value rowDst = rewriter.create(loc, row, dstStrideValue); + Value src0Offset = rewriter.create(loc, rowSrc0, chunkBase); + Value src1Offset = rewriter.create(loc, rowSrc1, chunkBase); + Value dstOffset = rewriter.create(loc, rowDst, chunkBase); + Value lhs = rewriter.create(loc, vecType, src0Buffer, src0Offset, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, src1Buffer, src1Offset, StringAttr()).getResult(); + Value fullMask = buildAllPredicateMask(rewriter, loc, vecType.getElementType()); + Value out; + if (family == "partadd") + out = rewriter.create(loc, vecType, lhs, rhs, fullMask); + else if (family == "partmax") + out = rewriter.create(loc, vecType, lhs, rhs, fullMask); + else if (family == "partmin") + out = rewriter.create(loc, vecType, lhs, rhs, fullMask); + else + return emitError(loc) << "unsupported part family: " << family; + rewriter.create(loc, out, dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + vecType.getElementType())); + rewriter.setInsertionPointAfter(chunkLoop); + return success(); +} + +LogicalResult buildPartVecScope(StringRef family, const VPTOPartContract &contract, + Value src0, Value src1, Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO part element type"; + Value src0Buffer = materializeBufferLikeAddress(src0, contract.elementType, + getMemorySpace(src0), rewriter, loc); + Value src1Buffer = materializeBufferLikeAddress(src1, contract.elementType, + getMemorySpace(src1), rewriter, loc); + Value dstBuffer = materializeBufferLikeAddress(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!src0Buffer || !src1Buffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for part lowering"; + int64_t src0Stride = deriveStaticRowStride(src0); + int64_t src1Stride = deriveStaticRowStride(src1); + int64_t dstStride = deriveStaticRowStride(dst); + if (src0Stride == ShapedType::kDynamic || src1Stride == ShapedType::kDynamic || + dstStride == ShapedType::kDynamic) + return emitError(loc) << family << " lowering requires static row strides"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + auto condSrc0EqDst = contract.src0ValidRows == contract.dstValidRows && + contract.src0ValidCols == contract.dstValidCols; + auto condSrc0RowLtDst = contract.src0ValidRows < contract.dstValidRows && + contract.src0ValidCols == contract.dstValidCols; + auto condSrc0ColLtDst = contract.src0ValidRows <= contract.dstValidRows && + contract.src0ValidCols < contract.dstValidCols; + auto condSrc1EqDst = contract.src1ValidRows == contract.dstValidRows && + contract.src1ValidCols == contract.dstValidCols; + auto condSrc1RowLtDst = contract.src1ValidRows < contract.dstValidRows && + contract.src1ValidCols == contract.dstValidCols; + auto condSrc1ColLtDst = contract.src1ValidRows <= contract.dstValidRows && + contract.src1ValidCols < contract.dstValidCols; + + if (family == "partadd") { + if (condSrc0EqDst && condSrc1EqDst) + return buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, + src0Stride, src1Stride, dstStride, + contract.dstValidRows, contract.dstValidCols, + rewriter, loc); + if (condSrc0ColLtDst && condSrc1EqDst) { + if (failed(buildPartCopyRegion(src1Buffer, dstBuffer, vecType, src1Stride, dstStride, + 0, contract.src1ValidRows, contract.dstValidCols, + rewriter, loc))) + return failure(); + if (contract.src0ValidCols != 0) + return buildPartBinaryRegion(family, src0Buffer, dstBuffer, dstBuffer, vecType, + src0Stride, dstStride, dstStride, + contract.src0ValidRows, contract.src0ValidCols, + rewriter, loc); + return success(); + } + if (condSrc0RowLtDst && condSrc1EqDst) { + if (contract.src0ValidRows != 0 && + failed(buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, + src0Stride, src1Stride, dstStride, + contract.src0ValidRows, contract.src0ValidCols, + rewriter, loc))) + return failure(); + return buildPartCopyRegion(src1Buffer, dstBuffer, vecType, src1Stride, dstStride, + contract.src0ValidRows, contract.src1ValidRows, + contract.dstValidCols, rewriter, loc); + } + if (condSrc1ColLtDst && condSrc0EqDst) { + if (failed(buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, + 0, contract.src0ValidRows, contract.dstValidCols, + rewriter, loc))) + return failure(); + if (contract.src1ValidCols != 0) + return buildPartBinaryRegion(family, src1Buffer, dstBuffer, dstBuffer, vecType, + src1Stride, dstStride, dstStride, + contract.src1ValidRows, contract.src1ValidCols, + rewriter, loc); + return success(); + } + if (condSrc1RowLtDst && condSrc0EqDst) { + if (contract.src1ValidRows != 0 && + failed(buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, + src0Stride, src1Stride, dstStride, + contract.src1ValidRows, contract.src1ValidCols, + rewriter, loc))) + return failure(); + return buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, + contract.src1ValidRows, contract.src0ValidRows, + contract.dstValidCols, rewriter, loc); + } + return emitError(loc) << "partadd lowering only supports PTO-covered destination-equality/extension cases"; + } + + bool condDstGeSrc = contract.src0ValidRows <= contract.dstValidRows && + contract.src0ValidCols <= contract.dstValidCols && + contract.src1ValidRows <= contract.dstValidRows && + contract.src1ValidCols <= contract.dstValidCols; + if (!condDstGeSrc) + return emitError(loc) << family << " lowering only supports dst >= src0/src1 shape relation"; + if (failed(buildPartFill(family, contract, dstBuffer, vecType, dstStride, rewriter, loc))) + return failure(); + if (failed(buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, + 0, contract.src0ValidRows, contract.src0ValidCols, + rewriter, loc))) + return failure(); + return buildPartBinaryRegion(family, dstBuffer, src1Buffer, dstBuffer, vecType, + dstStride, src1Stride, dstStride, + contract.src1ValidRows, contract.src1ValidCols, + rewriter, loc); +} + +LogicalResult buildUnaryVecScope(StringRef family, + const VPTOUnaryContract &contract, + VPTOLoweringStrategy strategy, Value src, + Value dst, PatternRewriter &rewriter, + Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO unary element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for unary lowering"; + + int64_t vectorWidth = vecType.getElementCount(); + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(dst, validRowsValue, validColsValue); + deriveValidShape(dst, validRows, validCols); + if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, + validCols, rewriter, loc))) + return emitError(loc) << "unary lowering requires valid rows and cols"; + + int64_t srcStride = deriveStaticRowStride(src); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t srcCols = deriveStaticTileCols(src); + int64_t dstCols = deriveStaticTileCols(dst); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || + srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) << "unary lowering requires static row strides and cols"; + + auto buildUnaryValue = [&](Value loaded, Value predicate) -> FailureOr { + if (family == "abs") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "exp") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "log") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "sqrt") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "relu") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "not") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + return failure(); + }; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value srcStrideValue = rewriter.create(loc, srcStride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value scalarInit = rewriter.create(loc, rewriter.getI32Type(), + totalElementsValue); + Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), + validColsValue); + Value fullWidthCond = + buildFullWidthColsCondition({srcCols, dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return emitError(loc) << "unary lowering could not materialize full-width selector"; + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, + /*withElseRegion=*/true); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + { + scf::ForOp chunkLoop; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{srcBuffer, dstBuffer, scalarInit}); + } else { + chunkLoop = rewriter.create(loc, c0, totalElementsValue, + vectorStepValue, + ValueRange{scalarInit}); + } + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value remaining = chunkLoop.getRegionIterArgs().back(); + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value loadBase = srcBuffer; + Value storeBase = dstBuffer; + Value loadOffset = chunkLoop.getInductionVar(); + Value storeOffset = chunkLoop.getInductionVar(); + if (strategy == VPTOLoweringStrategy::PostUpdate) { + loadBase = chunkLoop.getRegionIterArgs()[0]; + storeBase = chunkLoop.getRegionIterArgs()[1]; + loadOffset = vectorStepValue; + storeOffset = vectorStepValue; + } + Value loaded; + Value nextSrc = {}; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto vlds = rewriter.create(loc, vecType, loadBase.getType(), + loadBase, loadOffset, StringAttr()); + loaded = vlds.getResult(); + nextSrc = vlds.getUpdatedSource(); + } else { + auto vlds = + rewriter.create(loc, vecType, loadBase, loadOffset, StringAttr()); + loaded = vlds.getResult(); + } + FailureOr computed = buildUnaryValue(loaded, predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO unary family: " << family; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto vsts = rewriter.create(loc, storeBase.getType(), *computed, + storeBase, storeOffset, StringAttr(), + predicateState.mask); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); + } else { + rewriter.create(loc, *computed, storeBase, storeOffset, + StringAttr(), predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + } + } + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcRowBase = rewriter.create(loc, row, srcStrideValue); + Value dstRowBase = rewriter.create(loc, row, dstStrideValue); + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1, + ValueRange{rowScalarInit}); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value remaining = repeatLoop.getRegionIterArgs()[0]; + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value chunkBase = + rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); + Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + auto loaded = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); + FailureOr computed = + buildUnaryValue(loaded.getResult(), predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO unary family: " << family; + rewriter.create(loc, *computed, dstBuffer, dstOffset, + StringAttr(), predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + } + rewriter.setInsertionPointAfter(ifOp); + + return success(); +} + +LogicalResult buildBinaryVecScope(StringRef family, + const VPTOBinaryContract &contract, + VPTOLoweringStrategy strategy, Value src0, + Value src1, Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO binary element type"; + + Value src0Buffer = materializeBufferPointer(src0, contract.elementType, + getMemorySpace(src0), rewriter, loc); + Value src1Buffer = materializeBufferPointer(src1, contract.elementType, + getMemorySpace(src1), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!src0Buffer || !src1Buffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for binary lowering"; + + int64_t vectorWidth = vecType.getElementCount(); + Value validRowsValue = contract.validRowsValue; + Value validColsValue = contract.validColsValue; + int64_t validRows = contract.validRows; + int64_t validCols = contract.validCols; + if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, + validCols, rewriter, loc))) + return emitError(loc) << "binary lowering requires valid rows and cols"; + + int64_t src0Stride = deriveStaticRowStride(src0); + int64_t src1Stride = deriveStaticRowStride(src1); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t src0Cols = deriveStaticTileCols(src0); + int64_t src1Cols = deriveStaticTileCols(src1); + int64_t dstCols = deriveStaticTileCols(dst); + if (src0Stride == ShapedType::kDynamic || src1Stride == ShapedType::kDynamic || + dstStride == ShapedType::kDynamic || src0Cols == ShapedType::kDynamic || + src1Cols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) << "binary lowering requires static row strides and cols"; + + auto buildBinaryValue = [&](Value lhs, Value rhs, Value mask) -> FailureOr { + if (family == "add") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "sub") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "mul") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "div") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "max") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "min") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "and") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "or") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "xor") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + return failure(); + }; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value src0StrideValue = rewriter.create(loc, src0Stride); + Value src1StrideValue = rewriter.create(loc, src1Stride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value scalarInit = rewriter.create(loc, rewriter.getI32Type(), + totalElementsValue); + Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), + validColsValue); + bool sameShapeLinearPath = src0Stride == dstStride && src1Stride == dstStride && + src0Cols == dstCols && src1Cols == dstCols; + Value fullWidthCond = buildFullWidthColsCondition( + {src0Cols, src1Cols, dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return emitError(loc) << "binary lowering could not materialize full-width selector"; + Value use1DCond = sameShapeLinearPath ? fullWidthCond : Value(); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto emit1DBody = [&]() -> LogicalResult { + scf::ForOp chunkLoop; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{src0Buffer, src1Buffer, dstBuffer, scalarInit}); + } else { + chunkLoop = rewriter.create(loc, c0, totalElementsValue, + vectorStepValue, + ValueRange{scalarInit}); + } + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value remaining = chunkLoop.getRegionIterArgs().back(); + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value lhsBase = src0Buffer; + Value rhsBase = src1Buffer; + Value dstBase = dstBuffer; + Value loadOffset = chunkLoop.getInductionVar(); + Value storeOffset = chunkLoop.getInductionVar(); + if (strategy == VPTOLoweringStrategy::PostUpdate) { + lhsBase = chunkLoop.getRegionIterArgs()[0]; + rhsBase = chunkLoop.getRegionIterArgs()[1]; + dstBase = chunkLoop.getRegionIterArgs()[2]; + loadOffset = vectorStepValue; + storeOffset = vectorStepValue; + } + Value lhsValue; + Value rhsValue; + Value nextSrc0 = {}; + Value nextSrc1 = {}; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto lhs = rewriter.create(loc, vecType, lhsBase.getType(), + lhsBase, loadOffset, StringAttr()); + auto rhs = rewriter.create(loc, vecType, rhsBase.getType(), + rhsBase, loadOffset, StringAttr()); + lhsValue = lhs.getResult(); + rhsValue = rhs.getResult(); + nextSrc0 = lhs.getUpdatedSource(); + nextSrc1 = rhs.getUpdatedSource(); + } else { + auto lhs = + rewriter.create(loc, vecType, lhsBase, loadOffset, StringAttr()); + auto rhs = + rewriter.create(loc, vecType, rhsBase, loadOffset, StringAttr()); + lhsValue = lhs.getResult(); + rhsValue = rhs.getResult(); + } + FailureOr computed = buildBinaryValue(lhsValue, rhsValue, predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO binary family: " << family; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto vsts = rewriter.create(loc, dstBase.getType(), *computed, + dstBase, storeOffset, StringAttr(), + predicateState.mask); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, + ValueRange{nextSrc0, nextSrc1, nextDst, predicateState.nextScalar}); + } else { + rewriter.create(loc, *computed, dstBase, storeOffset, + StringAttr(), predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + } + return success(); + }; + + auto emit2DBody = [&]() -> LogicalResult { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value src0RowBase = rewriter.create(loc, row, src0StrideValue); + Value src1RowBase = rewriter.create(loc, row, src1StrideValue); + Value dstRowBase = rewriter.create(loc, row, dstStrideValue); + + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1, + ValueRange{rowScalarInit}); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value remaining = repeatLoop.getRegionIterArgs()[0]; + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value chunkBase = + rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); + Value src0Offset = rewriter.create(loc, src0RowBase, chunkBase); + Value src1Offset = rewriter.create(loc, src1RowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + auto lhs = rewriter.create(loc, vecType, src0Buffer, src0Offset, + StringAttr()); + auto rhs = rewriter.create(loc, vecType, src1Buffer, src1Offset, + StringAttr()); + FailureOr computed = + buildBinaryValue(lhs.getResult(), rhs.getResult(), predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO binary family: " << family; + rewriter.create(loc, *computed, dstBuffer, dstOffset, + StringAttr(), predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + return success(); + } + + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value chunkBase = + rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); + Value src0Offset = rewriter.create(loc, src0RowBase, chunkBase); + Value src1Offset = rewriter.create(loc, src1RowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + Value nextChunk = rewriter.create(loc, chunkBase, vectorStepValue); + Value exceeds = + rewriter.create(loc, arith::CmpIPredicate::sge, nextChunk, validColsValue); + Value tailCount = rewriter.create(loc, validColsValue, chunkBase); + Value activeLanes = + rewriter.create(loc, exceeds, tailCount, vectorStepValue); + Value predicate = buildPredicateMaskForLaneCount(rewriter, loc, + contract.elementType, activeLanes); + auto lhs = + rewriter.create(loc, vecType, src0Buffer, src0Offset, StringAttr()); + auto rhs = + rewriter.create(loc, vecType, src1Buffer, src1Offset, StringAttr()); + FailureOr computed = buildBinaryValue(lhs.getResult(), rhs.getResult(), predicate); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO binary family: " << family; + rewriter.create(loc, *computed, dstBuffer, dstOffset, + StringAttr(), predicate); + return success(); + }; + + if (use1DCond) { + auto ifOp = rewriter.create(loc, TypeRange{}, use1DCond, + /*withElseRegion=*/true); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + if (failed(emit1DBody())) + return failure(); + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + if (failed(emit2DBody())) + return failure(); + rewriter.setInsertionPointAfter(ifOp); + } else { + if (failed(emit2DBody())) + return failure(); + } + return success(); +} + +LogicalResult buildExpandScalarVecScope(const VPTOUnaryContract &contract, + Value scalar, Value dst, + PatternRewriter &rewriter, + Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO expands element type"; + + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffer for expands lowering"; + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, loc); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << "expands lowering requires valid rows and cols"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t dstCols = deriveStaticTileCols(dst); + if (dstStride == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) << "expands lowering requires static destination row stride and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value fullWidthCond = + buildFullWidthColsCondition({dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return emitError(loc) << "expands lowering could not materialize full-width selector"; + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, + /*withElseRegion=*/true); + + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + { + Value scalarInit = rewriter.create( + loc, rewriter.getI32Type(), totalElementsValue); + auto chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{dstBuffer, scalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value dstPtr = chunkLoop.getRegionIterArgs()[0]; + Value remaining = chunkLoop.getRegionIterArgs()[1]; + PredicateMaterialization predicateState = buildPredicateForLaneCount( + rewriter, loc, contract.elementType, remaining); + Value computed = + rewriter.create(loc, vecType, scalar, predicateState.mask, StringAttr()); + auto vsts = rewriter.create(loc, dstPtr.getType(), computed, dstPtr, + vectorStepValue, StringAttr(), + predicateState.mask); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, ValueRange{nextDst, predicateState.nextScalar}); + } + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value rowBase = rewriter.create(loc, row, dstStrideValue); + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value repeat = repeatLoop.getInductionVar(); + Value chunkBase = rewriter.create(loc, repeat, vectorStepValue); + Value dstOffset = rewriter.create(loc, rowBase, chunkBase); + Value remainingCols = + rewriter.create(loc, validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); + Value predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + Value computed = + rewriter.create(loc, vecType, scalar, predicate, StringAttr()); + rewriter.create(loc, computed, dstBuffer, dstOffset, + StringAttr(), predicate); + } + + rewriter.setInsertionPointAfter(ifOp); + return success(); +} + +LogicalResult buildScalarUnaryVecScope(StringRef family, + const VPTOUnaryContract &contract, + VPTOLoweringStrategy strategy, + Value src, Value scalar, Value dst, + PatternRewriter &rewriter, + Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO scalar-unary element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for scalar-unary lowering"; + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, loc); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << family << " lowering requires valid rows and cols"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t srcStride = deriveStaticRowStride(src); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t srcCols = deriveStaticTileCols(src); + int64_t dstCols = deriveStaticTileCols(dst); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || + srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) + << family << " lowering requires static src/dst row stride and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value srcStrideValue = rewriter.create(loc, srcStride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value fullWidthCond = buildFullWidthColsCondition( + {srcCols, dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return emitError(loc) << family << " lowering could not materialize full-width selector"; + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, + /*withElseRegion=*/true); + + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + { + auto emitComputed = [&](Value loadedVec, Value predicate) -> FailureOr { + if (family == "adds") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + if (family == "maxs") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + if (family == "mins") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + if (family == "muls") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + if (family == "lrelu") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + return failure(); + }; + + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + auto chunkLoop = + rewriter.create(loc, c0, totalElementsValue, vectorStepValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + Value remaining = rewriter.create(loc, totalElementsValue, offset); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remaining, vectorStepValue); + Value predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + auto loaded = + rewriter.create(loc, vecType, srcBuffer, offset, StringAttr()); + FailureOr computed = emitComputed(loaded.getResult(), predicate); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; + rewriter.create(loc, *computed, dstBuffer, offset, StringAttr(), + predicate); + } else { + Value scalarInit = rewriter.create( + loc, rewriter.getI32Type(), totalElementsValue); + auto chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{srcBuffer, dstBuffer, scalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value srcPtr = chunkLoop.getRegionIterArgs()[0]; + Value dstPtr = chunkLoop.getRegionIterArgs()[1]; + Value remaining = chunkLoop.getRegionIterArgs()[2]; + PredicateMaterialization predicateState = buildPredicateForLaneCount( + rewriter, loc, contract.elementType, remaining); + auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, + vectorStepValue, StringAttr()); + FailureOr computed = emitComputed(loaded.getResult(), predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; + auto vsts = rewriter.create(loc, dstPtr.getType(), *computed, dstPtr, + vectorStepValue, StringAttr(), + predicateState.mask); + Value nextSrc = loaded.getUpdatedSource(); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); + } + } + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcRowBase = rewriter.create(loc, row, srcStrideValue); + Value dstRowBase = rewriter.create(loc, row, dstStrideValue); + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value repeat = repeatLoop.getInductionVar(); + Value chunkBase = rewriter.create(loc, repeat, vectorStepValue); + Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + Value predicate; + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + predicate = + buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, validColsValue); + } else { + Value remainingCols = + rewriter.create(loc, validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); + predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + } + auto loaded = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); + Value computed; + if (family == "adds") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else if (family == "maxs") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else if (family == "mins") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else if (family == "muls") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else if (family == "lrelu") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else + return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; + rewriter.create(loc, computed, dstBuffer, dstOffset, + StringAttr(), predicate); + } + + rewriter.setInsertionPointAfter(ifOp); + return success(); +} + +LogicalResult buildScalarBitwiseVecScope(StringRef family, + const VPTOUnaryContract &contract, + Value src, Value scalar, Value dst, + PatternRewriter &rewriter, + Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO scalar-bitwise element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for scalar-bitwise lowering"; + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(dst, validRowsValue, validColsValue); + deriveValidShape(dst, validRows, validCols); + if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, + validCols, rewriter, loc))) + return emitError(loc) << family << " lowering requires valid rows and cols"; + + int64_t vectorWidth = vecType.getElementCount(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value vectorWidthValue = + rewriter.create(loc, vectorWidth); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto chunkLoop = + rewriter.create(loc, c0, totalElementsValue, vectorStepValue); + + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + Value remaining = rewriter.create(loc, totalElementsValue, offset); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remaining, vectorWidthValue); + Value predicate = + buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, activeLanes); + Value scalarVec = + rewriter.create(loc, vecType, scalar, predicate, StringAttr()); + auto loaded = rewriter.create(loc, vecType, srcBuffer, offset, + StringAttr()); + + Value computed; + if (family == "ands") + computed = + rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); + else if (family == "ors") + computed = + rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); + else if (family == "xors") + computed = + rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); + else + return emitError(loc) << "unsupported VPTO scalar-bitwise family: " << family; + rewriter.create(loc, computed, dstBuffer, offset, StringAttr(), + predicate); + return success(); +} + +static bool isVPTOShapedLikeValue(Value value) { + Type type = value.getType(); + return isa(type); +} + +LogicalResult buildScalarDivVecScope(const VPTOUnaryContract &contract, + VPTOLoweringStrategy strategy, + Value src, Value scalar, Value dst, + bool scalarFirst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO divs element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for divs lowering"; + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, loc); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << "divs lowering requires valid rows and cols"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t srcStride = deriveStaticRowStride(src); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t srcCols = deriveStaticTileCols(src); + int64_t dstCols = deriveStaticTileCols(dst); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || + srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) + << "divs lowering requires static src/dst row stride and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value srcStrideValue = rewriter.create(loc, srcStride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value fullWidthCond = buildFullWidthColsCondition( + {srcCols, dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return emitError(loc) << "divs lowering could not materialize full-width selector"; + + auto buildDivValue = [&](Value loaded, Value predicate) -> FailureOr { + if (contract.elementType.isF32()) { + if (scalarFirst) { + Value scalarVec = + rewriter.create(loc, vecType, scalar, predicate, StringAttr()); + return rewriter.create(loc, vecType, scalarVec, loaded, predicate) + .getResult(); + } + Value one = rewriter.create( + loc, contract.elementType, + rewriter.getFloatAttr(contract.elementType, 1.0)); + Value reciprocal = rewriter.create(loc, one, scalar); + return rewriter.create(loc, vecType, loaded, reciprocal, predicate).getResult(); + } + if (contract.elementType.isF16()) { + Value scalarVec = + rewriter.create(loc, vecType, scalar, predicate, StringAttr()); + return scalarFirst + ? rewriter.create(loc, vecType, scalarVec, loaded, predicate) + .getResult() + : rewriter.create(loc, vecType, loaded, scalarVec, predicate) + .getResult(); + } + return failure(); + }; + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, + /*withElseRegion=*/true); + + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + { + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + auto chunkLoop = + rewriter.create(loc, c0, totalElementsValue, vectorStepValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + Value remaining = rewriter.create(loc, totalElementsValue, offset); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remaining, vectorStepValue); + Value predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + auto loaded = + rewriter.create(loc, vecType, srcBuffer, offset, StringAttr()); + FailureOr computed = buildDivValue(loaded.getResult(), predicate); + if (failed(computed)) + return emitError(loc) + << "divs lowering currently supports only f16 and f32 element types"; + rewriter.create(loc, *computed, dstBuffer, offset, StringAttr(), + predicate); + } else { + Value scalarInit = rewriter.create( + loc, rewriter.getI32Type(), totalElementsValue); + auto chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{srcBuffer, dstBuffer, scalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value srcPtr = chunkLoop.getRegionIterArgs()[0]; + Value dstPtr = chunkLoop.getRegionIterArgs()[1]; + Value remaining = chunkLoop.getRegionIterArgs()[2]; + PredicateMaterialization predicateState = buildPredicateForLaneCount( + rewriter, loc, contract.elementType, remaining); + auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, + vectorStepValue, StringAttr()); + FailureOr computed = buildDivValue(loaded.getResult(), predicateState.mask); + if (failed(computed)) + return emitError(loc) + << "divs lowering currently supports only f16 and f32 element types"; + auto vsts = rewriter.create(loc, dstPtr.getType(), *computed, dstPtr, + vectorStepValue, StringAttr(), + predicateState.mask); + Value nextSrc = loaded.getUpdatedSource(); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); + } + } + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcRowBase = rewriter.create(loc, row, srcStrideValue); + Value dstRowBase = rewriter.create(loc, row, dstStrideValue); + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value repeat = repeatLoop.getInductionVar(); + Value chunkBase = rewriter.create(loc, repeat, vectorStepValue); + Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + Value predicate; + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + predicate = + buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, validColsValue); + } else { + Value remainingCols = + rewriter.create(loc, validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); + predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + } + auto loaded = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); + FailureOr computed = buildDivValue(loaded.getResult(), predicate); + if (failed(computed)) + return emitError(loc) + << "divs lowering currently supports only f16 and f32 element types"; + rewriter.create(loc, *computed, dstBuffer, dstOffset, + StringAttr(), predicate); + } + + rewriter.setInsertionPointAfter(ifOp); + return success(); +} + +LogicalResult checkExpandContract(Operation *op, + const VPTOExpandContract &contract) { + bool hasPrecheckFailure = false; + if (contract.srcDomain != VPTOTileDomain::Vec || + contract.dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family + << " lowering requires vec source and destination"; + hasPrecheckFailure = true; + } + if (contract.srcLayout != "row_major" || contract.dstLayout != "row_major") { + op->emitOpError() << contract.family + << " lowering requires row-major source and destination tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType || + (!contract.elementType.isF16() && !contract.elementType.isF32())) { + op->emitOpError() << contract.family + << " lowering currently supports only f16 and f32 element types"; + hasPrecheckFailure = true; + } + auto isStatic = [](int64_t value) { return value != ShapedType::kDynamic; }; + if (!isStatic(contract.srcValidRows) || !isStatic(contract.srcValidCols) || + !isStatic(contract.dstValidRows) || !isStatic(contract.dstValidCols)) { + op->emitOpError() << contract.family + << " lowering currently requires static source and destination valid shapes"; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult buildRowExpandVecScope(const VPTOExpandContract &contract, + VPTOLoweringStrategy strategy, Value src, Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO rowexpand element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for rowexpand lowering"; + + auto [srcRows, srcCols] = getStaticTileRowsCols(src); + auto [dstRows, dstCols] = getStaticTileRowsCols(dst); + if (srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic || + srcRows == ShapedType::kDynamic || dstRows == ShapedType::kDynamic) + return emitError(loc) << "rowexpand lowering requires static physical tile shape"; + + int64_t vectorWidth = vecType.getElementCount(); + Value validRowsValue = materializeIndexValue( + contract.dstValidRowsValue, contract.dstValidRows, rewriter, loc); + Value validColsValue = materializeIndexValue( + contract.dstValidColsValue, contract.dstValidCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << "rowexpand lowering requires valid rows and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value srcStrideValue = rewriter.create(loc, srcCols); + Value dstStrideValue = rewriter.create(loc, dstCols); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), + validColsValue); + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcOffset = rewriter.create(loc, row, srcStrideValue); + Value dstBase = rewriter.create(loc, row, dstStrideValue); + auto loaded = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + Value expanded = rewriter.create( + loc, vecType, loaded.getResult(), fullMask, rewriter.getStringAttr("LOWEST")); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1, + ValueRange{rowScalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value remaining = chunkLoop.getRegionIterArgs()[0]; + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value chunkBase = + rewriter.create(loc, chunkLoop.getInductionVar(), vectorStepValue); + Value dstOffset = rewriter.create(loc, dstBase, chunkBase); + rewriter.create(loc, expanded, dstBuffer, dstOffset, StringAttr(), + predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + return success(); + } + + auto rowLoop = + rewriter.create(loc, c0, validRowsValue, c1, ValueRange{srcBuffer, dstBuffer}); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value srcPtr = rowLoop.getRegionIterArgs()[0]; + Value dstPtr = rowLoop.getRegionIterArgs()[1]; + auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, + srcStrideValue, StringAttr()); + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + Value expanded = rewriter.create( + loc, vecType, loaded.getResult(), fullMask, rewriter.getStringAttr("LOWEST")); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1, + ValueRange{dstPtr, rowScalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value dstChunkPtr = chunkLoop.getRegionIterArgs()[0]; + Value remaining = chunkLoop.getRegionIterArgs()[1]; + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + auto vsts = rewriter.create(loc, dstChunkPtr.getType(), expanded, + dstChunkPtr, vectorStepValue, StringAttr(), + predicateState.mask); + Value nextDstChunkPtr = vsts.getUpdatedDestination(); + rewriter.create(loc, ValueRange{nextDstChunkPtr, predicateState.nextScalar}); + + rewriter.setInsertionPointAfter(chunkLoop); + Value rowAdvance = rewriter.create(loc, repeatUpper, vectorStepValue); + Value dstPad = rewriter.create(loc, dstStrideValue, rowAdvance); + Value nextDstPtr = + offsetBufferPointer(dstPtr, contract.elementType, dstPad, rewriter, loc); + Value nextSrcPtr = loaded.getUpdatedSource(); + rewriter.create(loc, ValueRange{nextSrcPtr, nextDstPtr}); + return success(); +} + +LogicalResult buildColExpandVecScope(const VPTOExpandContract &contract, + Value src, Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO colexpand element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for colexpand lowering"; + + auto [dstRows, dstCols] = getStaticTileRowsCols(dst); + if (dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) + << "colexpand lowering requires static physical destination tile shape"; + + int64_t vectorWidth = vecType.getElementCount(); + Value validRowsValue = materializeIndexValue( + contract.dstValidRowsValue, contract.dstValidRows, rewriter, loc); + Value validColsValue = materializeIndexValue( + contract.dstValidColsValue, contract.dstValidCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << "colexpand lowering requires valid rows and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value dstStrideValue = rewriter.create(loc, dstCols); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = + rewriter.create(loc, c0, validColsValue, vectorStepValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value dstBase = + rewriter.create(loc, rowLoop.getInductionVar(), dstStrideValue); + Value dstOffset = + rewriter.create(loc, dstBase, chunkLoop.getInductionVar()); + auto loaded = rewriter.create( + loc, vecType, srcBuffer, chunkLoop.getInductionVar(), StringAttr()); + rewriter.create(loc, loaded.getResult(), dstBuffer, dstOffset, + StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + return success(); +} + +LogicalResult checkGenericUnaryContract(Operation *op, + const VPTOUnaryContract &contract, + Value dst, + function_ref typePredicate, + StringRef supportedTypeText) { + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(dst, dstRows, dstCols); + StringRef dstLayout = deriveTileLayout(dst); + VPTOTileDomain dstDomain = deriveTileDomain(getMemorySpace(dst)); + + bool hasPrecheckFailure = false; + if (contract.tileDomain != VPTOTileDomain::Vec || dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires tile domain vec"; + hasPrecheckFailure = true; + } + if (contract.tileLayout != "row_major" || dstLayout != "row_major") { + op->emitOpError() << contract.family << " lowering requires row-major tile layout"; + hasPrecheckFailure = true; + } + if (contract.validRows != ShapedType::kDynamic && + dstRows != ShapedType::kDynamic && dstRows > contract.validRows) { + op->emitOpError() << contract.family + << " lowering requires destination valid rows not to exceed source"; + hasPrecheckFailure = true; + } + if (contract.validCols != ShapedType::kDynamic && + dstCols != ShapedType::kDynamic && dstCols > contract.validCols) { + op->emitOpError() << contract.family + << " lowering requires destination valid cols not to exceed source"; + hasPrecheckFailure = true; + } + if (!contract.elementType || !typePredicate(contract.elementType)) { + op->emitOpError() + << contract.family << " lowering supports only " << supportedTypeText; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult checkGenericBinaryContract( + Operation *op, const VPTOBinaryContract &contract, Value src1, Value dst, + function_ref typePredicate, StringRef supportedTypeText) { + StringRef src1Layout = deriveTileLayout(src1); + StringRef dstLayout = deriveTileLayout(dst); + VPTOTileDomain src1Domain = deriveTileDomain(getMemorySpace(src1)); + VPTOTileDomain dstDomain = deriveTileDomain(getMemorySpace(dst)); + + bool hasPrecheckFailure = false; + if (contract.tileDomain != VPTOTileDomain::Vec || src1Domain != VPTOTileDomain::Vec || + dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires tile domain vec"; + hasPrecheckFailure = true; + } + if (contract.tileLayout != "row_major" || src1Layout != "row_major" || + dstLayout != "row_major") { + op->emitOpError() << contract.family << " lowering requires row-major tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType || !typePredicate(contract.elementType)) { + op->emitOpError() + << contract.family << " lowering supports only " << supportedTypeText; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult checkRowReduceContract(Operation *op, + const VPTORowReduceContract &contract, + Value dst) { + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(dst, dstRows, dstCols); + + bool hasPrecheckFailure = false; + if (contract.srcDomain != VPTOTileDomain::Vec || + contract.dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires vec source and destination"; + hasPrecheckFailure = true; + } + if (contract.srcLayout != "row_major") { + op->emitOpError() << contract.family << " lowering requires row-major source tile layout"; + hasPrecheckFailure = true; + } + if (contract.dstLayout != "row_major" && contract.dstLayout != "col_major") { + op->emitOpError() << contract.family + << " lowering requires row-major or col-major destination tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType || (!contract.elementType.isF16() && !contract.elementType.isF32())) { + op->emitOpError() << contract.family << " lowering supports only f16 and f32 element types"; + hasPrecheckFailure = true; + } + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) { + op->emitOpError() << contract.family + << " lowering currently requires static source valid rows and cols"; + hasPrecheckFailure = true; + } + if (contract.validRows != dstRows) { + op->emitOpError() << contract.family + << " lowering requires destination valid rows to match source valid rows"; + hasPrecheckFailure = true; + } + if (dstCols != 1) { + op->emitOpError() << contract.family + << " lowering requires destination valid cols to equal 1"; + hasPrecheckFailure = true; + } + if (contract.dstLayout == "col_major") { + auto [dstRowsPhysical, dstColsPhysical] = getStaticTileRowsCols(dst); + (void)dstRowsPhysical; + if (dstColsPhysical != 1) { + op->emitOpError() << contract.family + << " lowering requires col-major destinations to use physical cols == 1"; + hasPrecheckFailure = true; + } + } + return failure(hasPrecheckFailure); +} + +LogicalResult checkColReduceContract(Operation *op, + const VPTOColReduceContract &contract, + Value dst) { + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(dst, dstRows, dstCols); + + bool hasPrecheckFailure = false; + if (contract.srcDomain != VPTOTileDomain::Vec || + contract.dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires vec source and destination"; + hasPrecheckFailure = true; + } + if (contract.srcLayout != "row_major" || contract.dstLayout != "row_major") { + op->emitOpError() << contract.family + << " lowering requires row-major source and destination tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType || + (!contract.elementType.isF16() && !contract.elementType.isF32())) { + op->emitOpError() << contract.family << " lowering supports only f16 and f32 element types"; + hasPrecheckFailure = true; + } + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) { + op->emitOpError() << contract.family + << " lowering currently requires static source valid rows and cols"; + hasPrecheckFailure = true; + } + if (dstRows != 1) { + op->emitOpError() << contract.family + << " lowering requires destination valid rows to equal 1"; + hasPrecheckFailure = true; + } + if (dstCols != contract.validCols) { + op->emitOpError() << contract.family + << " lowering requires destination valid cols to match source valid cols"; + hasPrecheckFailure = true; + } + if (contract.isBinary && !contract.tmp) { + op->emitOpError() << contract.family << " lowering requires tmp for binary path"; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult checkPartContract(Operation *op, const VPTOPartContract &contract) { + bool hasPrecheckFailure = false; + if (contract.src0Domain != VPTOTileDomain::Vec || + contract.src1Domain != VPTOTileDomain::Vec || + contract.dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires vec source and destination"; + hasPrecheckFailure = true; + } + if (contract.src0Layout != "row_major" || contract.src1Layout != "row_major" || + contract.dstLayout != "row_major") { + op->emitOpError() << contract.family + << " lowering requires row-major source and destination tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType) + hasPrecheckFailure = true; + else if (contract.family == "partadd") { + bool ok = contract.elementType.isF16() || contract.elementType.isF32() || + contract.elementType.isBF16(); + if (auto intType = dyn_cast(contract.elementType)) + ok = intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + if (!ok) { + op->emitOpError() << contract.family + << " lowering supports f16, f32, bf16, and 8/16/32-bit integers"; + hasPrecheckFailure = true; + } + } else { + bool ok = contract.elementType.isF16() || contract.elementType.isF32() || + contract.elementType.isBF16(); + if (auto intType = dyn_cast(contract.elementType)) + ok = intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + if (!ok) { + op->emitOpError() << contract.family + << " lowering supports f16, f32, bf16, and 8/16/32-bit integers"; + hasPrecheckFailure = true; + } + } + auto allStatic = [&](int64_t a, int64_t b) { + return a != ShapedType::kDynamic && b != ShapedType::kDynamic; + }; + if (!allStatic(contract.src0ValidRows, contract.src0ValidCols) || + !allStatic(contract.src1ValidRows, contract.src1ValidCols) || + !allStatic(contract.dstValidRows, contract.dstValidCols)) { + op->emitOpError() << contract.family + << " lowering currently requires static source and destination valid shapes"; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult lowerTLOAD(TLoadOp op, PatternRewriter &rewriter) { + VPTOLoadContract contract = extractTLoadContract(op); + if (contract.tileDomain != VPTOTileDomain::Vec) + return op.emitOpError("currently supports only VEC TLOAD lowering"); + + ResolvedTensorView sourceView; + if (!resolveTensorView(op.getSrc(), sourceView, rewriter, op.getLoc())) + return op.emitOpError("requires a recoverable source tensor view for VPTO lowering"); + + StringRef sourceLayout = + inferVecTransferLayoutFromTile(stringifyLayoutAttr(sourceView.layoutAttr), + contract.tileLayout); + bool isNdLoad = contract.tileLayout == "row_major" && sourceLayout == "nd"; + bool isDnLoad = contract.tileLayout == "col_major" && sourceLayout == "dn"; + if (!isNdLoad && !isDnLoad) + return op.emitOpError("currently supports only ND row_major or DN col_major vec TLOAD lowering"); + + Value sourceBuffer = + materializeBufferPointer(sourceView.root, getElementType(sourceView.root), + getGmMemorySpace(rewriter.getContext()), rewriter, + op.getLoc()); + Value destinationBuffer = + materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, op.getLoc()); + if (!sourceBuffer || !destinationBuffer) + return op.emitOpError("requires A5-compatible source and destination buffers"); + + auto [tileRows, tileCols] = getStaticTileRowsCols(op.getDst()); + (void)tileRows; + bool ubPad = contract.padMode != "none" || contract.padValue || + contract.leftPaddingNum || contract.rightPaddingNum; + Value validRowsValue = + materializeI64Value(contract.validRowsValue, contract.validRows, rewriter, + op.getLoc()); + Value validColsValue = + materializeI64Value(contract.validColsValue, contract.validCols, rewriter, + op.getLoc()); + Value sidValue = rewriter.create(op.getLoc(), 0, 64); + int64_t elemBytes = getElementByteSize(contract.elementType); + if ((isNdLoad && tileCols == ShapedType::kDynamic) || + (isDnLoad && tileRows == ShapedType::kDynamic) || elemBytes <= 0) + return op.emitOpError("requires static tile shape for A5-compatible transfer arguments"); + VecNdTransferPlan plan; + LogicalResult planResult = + isNdLoad ? buildVecNdLoadPlan(sourceView.shape, sourceView.strides, tileCols, + contract.validColsValue, contract.validCols, + contract.elementType, rewriter, op.getLoc(), plan) + : buildVecDnLoadPlan(sourceView.shape, sourceView.strides, tileRows, + contract.validRowsValue, contract.validRows, + contract.elementType, rewriter, op.getLoc(), plan); + if (failed(planResult)) + return op.emitOpError("requires PTO-compatible vec copy_gm_to_ubuf arguments"); + Value leftPaddingValue = rewriter.create(op.getLoc(), 0, 64); + Value rightPaddingValue = rewriter.create(op.getLoc(), 0, 64); + Value cacheCtlValue = rewriter.create(op.getLoc(), 0, 64); + if (!validRowsValue || !validColsValue) + return op.emitOpError("requires valid rows and cols for A5-compatible transfer arguments"); + Value sourceOffset = + materializeI64Ofr(sourceView.offsetElems, rewriter, op.getLoc()); + if (!sourceOffset) + return op.emitOpError("requires a materializable source offset for VPTO lowering"); + Value sourceBase = adjustPointerByElemOffset(sourceBuffer, sourceOffset, elemBytes, + rewriter, op.getLoc()); + if (!sourceBase) + return op.emitOpError("failed to materialize source base pointer"); + + rewriter.create( + op.getLoc(), plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); + rewriter.create( + op.getLoc(), plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); + rewriter.create(op.getLoc(), plan.loop2Size, + plan.loop1Size); + + auto emitCopy = [&](Value srcPtr, Value dstPtr) { + Type transferElementType = + getCopyTransferElementType(contract.elementType, rewriter); + Value typedSrcPtr = + castPtrToElementType(srcPtr, transferElementType, rewriter, op.getLoc()); + Value typedDstPtr = + castPtrToElementType(dstPtr, transferElementType, rewriter, op.getLoc()); + if (!typedSrcPtr || !typedDstPtr) + return failure(); + Value dataSelectBitValue = + rewriter.create(op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(ubPad)); + rewriter.create( + op.getLoc(), typedSrcPtr, typedDstPtr, sidValue, plan.nBurst, + plan.lenBurst, leftPaddingValue, rightPaddingValue, dataSelectBitValue, + cacheCtlValue, plan.firstStrideBytes, plan.secondStrideBytes); + return success(); + }; + + if (std::optional outerConst = getConstInt(plan.outerCount); outerConst && *outerConst == 1) { + return emitCopy(sourceBase, destinationBuffer); + } + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value outerUpper = + rewriter.create(op.getLoc(), rewriter.getIndexType(), + plan.outerCount); + auto outerLoop = rewriter.create(op.getLoc(), c0, outerUpper, c1); + rewriter.setInsertionPointToStart(outerLoop.getBody()); + Value ivI64 = rewriter.create(op.getLoc(), rewriter.getI64Type(), + outerLoop.getInductionVar()); + Value srcStep = createI64Mul(ivI64, plan.outerSrcStrideElems, rewriter, op.getLoc()); + Value dstStep = createI64Mul(ivI64, plan.outerDstStrideElems, rewriter, op.getLoc()); + Value iterSrc = adjustPointerByElemOffset(sourceBase, srcStep, elemBytes, rewriter, + op.getLoc()); + Value iterDst = adjustPointerByElemOffset(destinationBuffer, dstStep, elemBytes, rewriter, + op.getLoc()); + return emitCopy(iterSrc, iterDst); +} + +LogicalResult lowerTABS(TAbsOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTAbsContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + + return buildUnaryVecScope("abs", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTADD(TAddOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = extractTAddContract(op); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("add", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTSUB(TSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = extractTSubContract(op); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("sub", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTMUL(TMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = extractTMulContract(op); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("mul", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTDIV(TDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = extractTDivContract(op); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 16 || intType.getWidth() == 32; + return false; + }, + "f16, f32, and 16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("div", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTMAX(TMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("max", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("max", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTMIN(TMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("min", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("min", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTAND(TAndOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("and", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("and", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTANDS(TAndSOp op, PatternRewriter &rewriter) { + return emitUnresolvedInstalledA5BaselineError(op, "tands"); +} + +LogicalResult lowerTOR(TOrOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("or", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("or", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTORS(TOrSOp op, PatternRewriter &rewriter) { + return emitUnresolvedInstalledA5BaselineError(op, "tors"); +} + +LogicalResult lowerTXOR(TXorOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("xor", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("xor", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTXORS(TXorSOp op, PatternRewriter &rewriter) { + return emitUnresolvedInstalledA5BaselineError(op, "txors"); +} + +LogicalResult lowerTEXP(TExpOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTExpContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + return buildUnaryVecScope("exp", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTLOG(TLogOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTLogContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + return buildUnaryVecScope("log", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTSQRT(TSqrtOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTSqrtContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + return buildUnaryVecScope("sqrt", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTRSQRT(TRsqrtOp op, PatternRewriter &rewriter) { + VPTOUnaryContract contract = buildUnaryContract("rsqrt", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return op.emitOpError("trsqrt lowering requires a supported VPTO vector element type"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("trsqrt lowering requires pointer-backed tile buffers"); + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, op.getLoc()); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, op.getLoc()); + if (!validRowsValue || !validColsValue) + return op.emitOpError("trsqrt lowering requires valid rows and cols"); + + int64_t srcRowStride = deriveStaticRowStride(op.getSrc()); + int64_t dstRowStride = deriveStaticRowStride(op.getDst()); + if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic) + return op.emitOpError("trsqrt lowering requires static row-major row strides"); + + int64_t vectorWidth = vecType.getElementCount(); + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value srcRowStrideValue = + rewriter.create(op.getLoc(), srcRowStride); + Value dstRowStrideValue = + rewriter.create(op.getLoc(), dstRowStride); + Value vectorStepValue = + rewriter.create(op.getLoc(), vectorWidth); + TypedAttr oneAttr = FloatAttr::get(contract.elementType, 1.0); + Value one = rewriter.create(op.getLoc(), oneAttr); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), vecType.getElementType()); + auto ones = + rewriter.create(op.getLoc(), vecType, one, fullMask, StringAttr()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, validColsValue, vectorStepValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value srcRowBase = rewriter.create( + op.getLoc(), rowLoop.getInductionVar(), srcRowStrideValue); + Value dstRowBase = rewriter.create( + op.getLoc(), rowLoop.getInductionVar(), dstRowStrideValue); + Value chunkOffset = chunkLoop.getInductionVar(); + Value srcOffset = + rewriter.create(op.getLoc(), srcRowBase, chunkOffset); + Value dstOffset = + rewriter.create(op.getLoc(), dstRowBase, chunkOffset); + Value remaining = rewriter.create(op.getLoc(), validColsValue, chunkOffset); + Value predicate = + buildPredicateMaskForLaneCount(rewriter, op.getLoc(), contract.elementType, remaining); + auto loaded = rewriter.create(op.getLoc(), vecType, srcBuffer, + srcOffset, StringAttr()); + auto sqrt = rewriter.create(op.getLoc(), vecType, loaded.getResult(), + predicate); + auto result = rewriter.create(op.getLoc(), vecType, ones.getResult(), + sqrt.getResult(), predicate); + rewriter.create( + op.getLoc(), result.getResult(), dstBuffer, dstOffset, StringAttr(), predicate); + return success(); +} + +LogicalResult lowerTRECIP(TRecipOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTRecipContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + return buildUnaryVecScope("recip", contract, strategy, op.getSrc(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTNEG(TNegOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = buildUnaryContract("muls", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 16 || intType.getWidth() == 32; + return false; + }, + "f16, f32, and 16/32-bit integer element types"))) + return failure(); + + TypedAttr negOneAttr; + if (contract.elementType.isF16()) + negOneAttr = FloatAttr::get(contract.elementType, -1.0); + else if (contract.elementType.isF32()) + negOneAttr = FloatAttr::get(contract.elementType, -1.0); + else if (auto intType = dyn_cast(contract.elementType)) + negOneAttr = IntegerAttr::get(intType, -1); + else + return op.emitOpError("tneg lowering requires scalar element type"); + + Value negOne = rewriter.create(op.getLoc(), negOneAttr); + return buildScalarUnaryVecScope("muls", contract, strategy, op.getSrc(), negOne, + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTLRELU(TLReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = buildUnaryContract("lrelu", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, + "f16 and f32 element types"))) + return failure(); + if (op.getSlope().getType() != contract.elementType) + return op.emitOpError("tlrelu lowering requires slope type to match source element type"); + return buildScalarUnaryVecScope("lrelu", contract, strategy, op.getSrc(), op.getSlope(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter) { + VPTOUnaryContract contract = buildUnaryContract("cvt", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32() || type.isBF16(); }, + "f16, f32, or bf16 element type"))) + return failure(); + + Type dstElementType = getElementType(op.getDst()); + FailureOr loweringKind = + classifyA5CvtLowering(contract.elementType, dstElementType); + if (failed(loweringKind)) + return op.emitOpError( + "current tcvt lowering supports only f32->f32, f32->bf16, f16->f32, and bf16->f32"); + + FailureOr roundMode = stringifyA5RoundMode(op, rewriter); + if (failed(roundMode)) + return op.emitOpError("tcvt lowering does not recognize the requested round mode"); + + auto srcVecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + auto dstVecType = getVPTOVRegType(rewriter.getContext(), dstElementType); + if (!srcVecType || !dstVecType) + return op.emitOpError("tcvt lowering requires legal VPTO vector types"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), dstElementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("tcvt lowering requires pointer-backed tile buffers"); + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, + op.getLoc()); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, + op.getLoc()); + if (!validRowsValue || !validColsValue) + return op.emitOpError("tcvt lowering requires valid rows and cols"); + + int64_t vectorWidth = dstVecType.getElementCount(); + if (contract.validRows != ShapedType::kDynamic && + contract.validCols != ShapedType::kDynamic) { + int64_t totalElements = contract.validRows * contract.validCols; + if (totalElements % vectorWidth != 0) + return op.emitOpError( + "tcvt lowering requires total valid elements divisible by vector width"); + } + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value totalElementsValue = + rewriter.create(op.getLoc(), validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(op.getLoc(), vectorWidth); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, totalElementsValue, vectorStepValue); + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + switch (*loweringKind) { + case VPTOCvtLoweringKind::Vtrc: { + auto loaded = + rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); + Value converted = rewriter.create(op.getLoc(), dstVecType, + loaded.getResult(), *roundMode); + rewriter.create( + op.getLoc(), converted, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } + case VPTOCvtLoweringKind::F32ToBF16: { + Value halfStep = rewriter.create( + op.getLoc(), srcVecType.getElementCount()); + Value upperOffset = + rewriter.create(op.getLoc(), offset, halfStep); + auto lower = + rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); + auto upper = rewriter.create(op.getLoc(), srcVecType, srcBuffer, + upperOffset, StringAttr()); + Value odd = rewriter.create( + op.getLoc(), dstVecType, upper.getResult(), *roundMode, + rewriter.getStringAttr("RS_ENABLE"), rewriter.getStringAttr("PART_ODD")); + Value even = rewriter.create( + op.getLoc(), dstVecType, lower.getResult(), *roundMode, + rewriter.getStringAttr("RS_ENABLE"), rewriter.getStringAttr("PART_EVEN")); + Value merged = + rewriter.create( + op.getLoc(), dstVecType, even, odd, + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + rewriter.create( + op.getLoc(), merged, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } + case VPTOCvtLoweringKind::F16ToF32: { + auto loaded = rewriter.create( + op.getLoc(), srcVecType, srcBuffer, offset, rewriter.getStringAttr("UNPK_B16")); + Value converted = rewriter.create( + op.getLoc(), dstVecType, loaded.getResult(), StringAttr(), + StringAttr(), rewriter.getStringAttr("PART_EVEN")); + rewriter.create( + op.getLoc(), converted, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } + case VPTOCvtLoweringKind::BF16ToF32: { + auto loaded = rewriter.create( + op.getLoc(), srcVecType, srcBuffer, offset, rewriter.getStringAttr("UNPK_B16")); + Value converted = rewriter.create( + op.getLoc(), dstVecType, loaded.getResult(), StringAttr(), + StringAttr(), rewriter.getStringAttr("PART_EVEN")); + rewriter.create( + op.getLoc(), converted, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } + } + return success(); +} + +template +LogicalResult buildPackedCmp32VecScope(StringRef family, + const VPTOBinaryContract &contract, + Value dst, Value dstBuffer, + PatternRewriter &rewriter, Location loc, + CompareEmitter emitCompare) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << family << " lowering requires a supported vector element type"; + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, loc); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << family << " lowering requires valid rows and cols"; + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return emitError(loc) << family << " lowering currently requires static valid rows and cols"; + + int64_t totalElements = contract.validRows * contract.validCols; + constexpr int64_t repeatElem = 64; + int64_t repeatTimes = (totalElements + repeatElem - 1) / repeatElem; + int64_t pairedRepeats = repeatTimes / 2; + int64_t remainRepeats = repeatTimes % 2; + + auto compareMaskType = + getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType); + auto packedMaskType = getVPTOMaskType(rewriter.getContext(), "b8"); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value pairUpper = rewriter.create(loc, pairedRepeats); + Value repeatStep = rewriter.create(loc, repeatElem); + Value pairSrcStride = rewriter.create(loc, repeatElem * 2); + Value pairDstStride = rewriter.create(loc, 4); + Value laneCount = rewriter.create(loc, repeatElem, 32); + Value totalRemaining = rewriter.create(loc, totalElements, 32); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto pairLoop = + rewriter.create(loc, c0, pairUpper, c1, ValueRange{totalRemaining}); + rewriter.setInsertionPointToStart(pairLoop.getBody()); + Value remaining = pairLoop.getRegionIterArgs().front(); + Value pairBase = rewriter.create(loc, pairLoop.getInductionVar(), + pairSrcStride); + Value pairNext = rewriter.create(loc, pairBase, repeatStep); + Value dstOffset = rewriter.create(loc, pairLoop.getInductionVar(), + pairDstStride); + Value dstBase = adjustPointerByElemOffset(dstBuffer, dstOffset, 4, rewriter, loc); + Value dstZero = rewriter.create(loc, 0); + auto pairMask0 = rewriter.create(loc, compareMaskType, + rewriter.getI32Type(), + remaining); + auto pairMask1 = rewriter.create(loc, compareMaskType, + rewriter.getI32Type(), + pairMask0.getScalarOut()); + Value cmp0 = emitCompare(rewriter, loc, pairBase, pairMask0.getMask()); + Value cmp1 = emitCompare(rewriter, loc, pairNext, pairMask1.getMask()); + Value packedCmp0 = rewriter + .create(loc, packedMaskType, cmp0, + rewriter.getStringAttr("LOWER")) + .getResult(); + Value packedCmp1 = rewriter + .create(loc, packedMaskType, cmp1, + rewriter.getStringAttr("LOWER")) + .getResult(); + auto interleaved = rewriter.create( + loc, packedMaskType, packedMaskType, packedCmp0, packedCmp1); + rewriter.create(loc, interleaved.getLow(), dstBase, dstZero, + "NORM"); + rewriter.create(loc, pairMask1.getScalarOut()); + + if (remainRepeats == 0) + return success(); + + rewriter.setInsertionPointAfter(pairLoop); + Value tailBase = rewriter.create(loc, pairedRepeats * repeatElem * 2); + Value tailDst = rewriter.create(loc, pairedRepeats * 4); + Value tailDstBase = adjustPointerByElemOffset(dstBuffer, tailDst, 4, rewriter, loc); + Value tailDstZero = rewriter.create(loc, 0); + auto tailMask = rewriter.create(loc, compareMaskType, + rewriter.getI32Type(), + pairLoop.getResult(0)); + Value tailCmp = emitCompare(rewriter, loc, tailBase, tailMask.getMask()); + Value packedTail = rewriter + .create(loc, packedMaskType, tailCmp, + rewriter.getStringAttr("LOWER")) + .getResult(); + rewriter.create(loc, packedTail, tailDstBase, tailDstZero, + "NORM"); + return success(); +} + +LogicalResult lowerTCmpS(TCmpSOp op, PatternRewriter &rewriter) { + VPTOBinaryContract contract = buildBinaryContract("cmps", op.getSrc()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("tcmps lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("tcmps lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return op.emitOpError("tcmps lowering requires static valid shape"); + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(op.getDst(), dstRows, dstCols); + if (contract.validRows != dstRows || contract.validCols != dstCols) + return op.emitOpError("tcmps lowering requires matching source and destination valid region"); + if (!isSupportedPackedCmp32ElementType(contract.elementType)) + return op.emitOpError("tcmps lowering currently supports only 32-bit source tiles"); + auto dstElemType = dyn_cast_or_null(getElementType(op.getDst())); + if (!dstElemType || !dstElemType.isUnsignedInteger(8)) + return op.emitOpError("tcmps lowering currently requires ui8 destination tiles"); + if (op.getScalar().getType() != contract.elementType) + return op.emitOpError("tcmps lowering requires scalar type to match source element type"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), getElementType(op.getDst()), + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("tcmps lowering requires pointer-backed tile buffers"); + + StringAttr cmpMode = rewriter.getStringAttr(stringifyCmpModeAttr(op.getCmpModeAttr())); + return buildPackedCmp32VecScope( + "tcmps", contract, op.getDst(), dstBuffer, rewriter, op.getLoc(), + [&](PatternRewriter &nestedRewriter, Location nestedLoc, Value offset, + Value mask) -> Value { + auto vecType = + getVPTOVRegType(nestedRewriter.getContext(), contract.elementType); + auto loaded = + nestedRewriter.create(nestedLoc, vecType, srcBuffer, offset, StringAttr()); + return nestedRewriter + .create(nestedLoc, + getVPTOMaskTypeForElementType( + nestedRewriter.getContext(), + contract.elementType), + loaded.getResult(), op.getScalar(), mask, cmpMode) + .getResult(); + }); +} + +LogicalResult lowerTCmp(TCmpOp op, PatternRewriter &rewriter) { + VPTOBinaryContract contract = buildBinaryContract("cmp", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("tcmp lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getSrc1()) != "row_major" || + deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("tcmp lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return op.emitOpError("tcmp lowering requires static valid shape"); + int64_t src1Rows = ShapedType::kDynamic; + int64_t src1Cols = ShapedType::kDynamic; + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(op.getSrc1(), src1Rows, src1Cols); + deriveValidShape(op.getDst(), dstRows, dstCols); + if (contract.validRows != src1Rows || contract.validCols != src1Cols || + contract.validRows != dstRows || contract.validCols != dstCols) + return op.emitOpError("tcmp lowering requires matching source and destination valid region"); + if (!isSupportedPackedCmp32ElementType(contract.elementType)) + return op.emitOpError("tcmp lowering currently supports only 32-bit source tiles"); + if (getElementType(op.getSrc1()) != contract.elementType) + return op.emitOpError("tcmp lowering requires src1 element type to match src0"); + auto dstElemType = dyn_cast_or_null(getElementType(op.getDst())); + if (!dstElemType || !dstElemType.isUnsignedInteger(8)) + return op.emitOpError("tcmp lowering currently requires ui8 destination tiles"); + + Value src0Buffer = materializeBufferPointer(op.getSrc0(), contract.elementType, + getMemorySpace(op.getSrc0()), rewriter, + op.getLoc()); + Value src1Buffer = materializeBufferPointer(op.getSrc1(), contract.elementType, + getMemorySpace(op.getSrc1()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), getElementType(op.getDst()), + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!src0Buffer || !src1Buffer || !dstBuffer) + return op.emitOpError("tcmp lowering requires pointer-backed tile buffers"); + + StringAttr cmpMode = rewriter.getStringAttr(stringifyCmpModeAttr(op.getCmpModeAttr())); + return buildPackedCmp32VecScope( + "tcmp", contract, op.getDst(), dstBuffer, rewriter, op.getLoc(), + [&](PatternRewriter &nestedRewriter, Location nestedLoc, Value offset, + Value mask) -> Value { + auto vecType = + getVPTOVRegType(nestedRewriter.getContext(), contract.elementType); + auto lhs = + nestedRewriter.create(nestedLoc, vecType, src0Buffer, offset, StringAttr()); + auto rhs = + nestedRewriter.create(nestedLoc, vecType, src1Buffer, offset, StringAttr()); + return nestedRewriter + .create(nestedLoc, + getVPTOMaskTypeForElementType( + nestedRewriter.getContext(), + contract.elementType), + lhs.getResult(), rhs.getResult(), mask, cmpMode) + .getResult(); + }); +} + +LogicalResult lowerTCI(TCIOp op, PatternRewriter &rewriter) { + Type elementType = getElementType(op.getDst()); + auto intType = dyn_cast_or_null(elementType); + if (!intType || (intType.getWidth() != 16 && intType.getWidth() != 32)) + return op.emitOpError("tci lowering requires i16 or i32 destination element type"); + if (deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("tci lowering requires tile domain vec"); + if (deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("tci lowering requires row-major tile layout"); + + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + Value validRowsValue; + Value validColsValue; + deriveValidShapeValues(op.getDst(), validRowsValue, validColsValue); + deriveValidShape(op.getDst(), validRows, validCols); + if (validRows != 1) + return op.emitOpError("tci lowering currently requires valid rows == 1"); + + Value dstBuffer = materializeBufferPointer(op.getDst(), elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!dstBuffer) + return op.emitOpError("tci lowering requires pointer-backed destination tile buffer"); + + Value upperBound = materializeIndexValue(validColsValue, validCols, rewriter, op.getLoc()); + if (!upperBound) + return op.emitOpError("tci lowering requires valid cols"); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + auto loop = rewriter.create(op.getLoc(), c0, upperBound, c1); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(loop.getBody()); + Value iv = loop.getInductionVar(); + Value ivAsElem = rewriter.create(op.getLoc(), intType, iv); + Value stored = + op.getDescending() + ? rewriter.create(op.getLoc(), op.getS(), ivAsElem).getResult() + : rewriter.create(op.getLoc(), op.getS(), ivAsElem).getResult(); + rewriter.create(op.getLoc(), dstBuffer, iv, stored); + return success(); +} + +LogicalResult lowerTRELU(TReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTReluContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { + return type.isF16() || type.isF32() || + (isa(type) && cast(type).getWidth() == 32); + }, + "f16, f32, and i32 element types"))) + return failure(); + return buildUnaryVecScope("relu", contract, strategy, op.getSrc(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTNOT(TNotOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTNotContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildUnaryVecScope("not", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTTRANS(TTransOp op, PatternRewriter &rewriter) { + VPTOUnaryContract contract = buildUnaryContract("trans", op.getSrc()); + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(op.getDst(), dstRows, dstCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("ttrans lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("ttrans lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || contract.validCols == ShapedType::kDynamic || + dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return op.emitOpError("ttrans lowering requires static valid shape"); + if (contract.validRows != dstCols || contract.validCols != dstRows) + return op.emitOpError("ttrans lowering requires transposed source/destination valid shape"); + if (contract.elementType != getElementType(op.getDst())) + return op.emitOpError("ttrans lowering requires matching source/destination element type"); + + int64_t elemBytes = getElementByteSize(contract.elementType); + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t dstStride = deriveStaticRowStride(op.getDst()); + if (elemBytes != 4) + return op.emitOpError("ttrans lowering currently supports only b32 element types"); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) + return op.emitOpError("ttrans lowering requires static source/destination row stride"); + + auto dataVecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + auto indexElemType = rewriter.getIntegerType(32); + auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElemType); + if (!dataVecType || !indexVecType) + return op.emitOpError("ttrans lowering requires supported VPTO vector types"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("ttrans lowering requires pointer-backed tile buffers"); + + constexpr int64_t repeatBytes = 256; + constexpr int64_t blockBytes = 32; + int64_t elementsPerRepeat = repeatBytes / elemBytes; + int64_t blockSizeElem = blockBytes / elemBytes; + int64_t alignedRows = + llvm::divideCeil(contract.validRows, blockSizeElem) * blockSizeElem; + int64_t repeatTimes = llvm::divideCeil(alignedRows, elementsPerRepeat); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value colsUpper = rewriter.create(op.getLoc(), contract.validCols); + Value chunkUpper = rewriter.create(op.getLoc(), repeatTimes); + Value elementsPerRepeatValue = + rewriter.create(op.getLoc(), elementsPerRepeat); + Value dstStrideValue = rewriter.create(op.getLoc(), dstStride); + Value srcStrideI32 = rewriter.create(op.getLoc(), srcStride, 32); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto colLoop = rewriter.create(op.getLoc(), c0, colsUpper, c1); + rewriter.setInsertionPointToStart(colLoop.getBody()); + auto chunkLoop = rewriter.create(op.getLoc(), c0, chunkUpper, c1); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value chunkBase = rewriter.create(op.getLoc(), chunkLoop.getInductionVar(), + elementsPerRepeatValue); + Value colI32 = rewriter.create(op.getLoc(), indexElemType, + colLoop.getInductionVar()); + Value chunkBaseI32 = + rewriter.create(op.getLoc(), indexElemType, chunkBase); + auto indices = + rewriter.create(op.getLoc(), indexVecType, chunkBaseI32, + rewriter.getStringAttr("INC_ORDER")); + Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), indexElemType); + auto scaled = rewriter.create(op.getLoc(), indexVecType, + indices.getResult(), srcStrideI32, fullMask); + auto offsets = rewriter.create(op.getLoc(), indexVecType, + scaled.getResult(), colI32, fullMask); + Value fullActiveLanes = + rewriter.create(op.getLoc(), + dataVecType.getElementCount()); + auto gathered = + rewriter.create(op.getLoc(), dataVecType, srcBuffer, + offsets.getResult(), fullActiveLanes); + Value dstBase = + rewriter.create(op.getLoc(), colLoop.getInductionVar(), dstStrideValue); + Value dstOffset = rewriter.create(op.getLoc(), dstBase, chunkBase); + rewriter.create( + op.getLoc(), gathered.getResult(), dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), contract.elementType)); + return success(); +} + +template +LogicalResult lowerTFillPadCommon(FillPadOpTy op, PatternRewriter &rewriter, + bool allowDstExpand) { + VPTOUnaryContract contract = buildUnaryContract("fillpad", op.getSrc()); + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(op.getDst(), dstRows, dstCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("fillpad lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("fillpad lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || contract.validCols == ShapedType::kDynamic || + dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return op.emitOpError("fillpad lowering requires static valid shape"); + if (!allowDstExpand && (contract.validRows != dstRows || contract.validCols != dstCols)) + return op.emitOpError("tfillpad lowering requires matching source/destination valid shape"); + if (allowDstExpand && (dstRows < contract.validRows || dstCols < contract.validCols)) + return op.emitOpError("tfillpad_expand lowering requires dst shape >= src shape"); + if (contract.elementType != getElementType(op.getDst())) + return op.emitOpError("fillpad lowering requires matching source/destination element type"); + + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t dstStride = deriveStaticRowStride(op.getDst()); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) + return op.emitOpError("fillpad lowering requires static source/destination row stride"); + + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return op.emitOpError("fillpad lowering requires supported VPTO vector element type"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("fillpad lowering requires pointer-backed tile buffers"); + + auto config = lookupTileConfig(op.getDst()); + PadValueAttr padAttr = config ? dyn_cast(config.getPad()) : PadValueAttr{}; + Attribute padValueAttr = buildFillPadValue(contract.elementType, padAttr, rewriter); + if (!padValueAttr) + return op.emitOpError("fillpad lowering requires a concrete non-null dst pad value"); + Value padScalar = rewriter.create(op.getLoc(), cast(padValueAttr)); + Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), vecType.getElementType()); + auto padVec = + rewriter.create(op.getLoc(), vecType, padScalar, fullMask, StringAttr()); + + int64_t vectorWidth = vecType.getElementCount(); + int64_t padCols = dstCols - contract.validCols; + int64_t padRows = dstRows - contract.validRows; + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value srcRowsUpper = rewriter.create(op.getLoc(), contract.validRows); + Value srcColsUpper = rewriter.create(op.getLoc(), contract.validCols); + Value dstRowsUpper = rewriter.create(op.getLoc(), dstRows); + Value vectorStep = rewriter.create(op.getLoc(), vectorWidth); + Value srcStrideValue = rewriter.create(op.getLoc(), srcStride); + Value dstStrideValue = rewriter.create(op.getLoc(), dstStride); + Value validColsValue = rewriter.create(op.getLoc(), contract.validCols); + Value dstColsValue = rewriter.create(op.getLoc(), dstCols); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + auto rowLoop = rewriter.create(op.getLoc(), c0, srcRowsUpper, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value srcRowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), + srcStrideValue); + Value dstRowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), + dstStrideValue); + + auto copyChunkLoop = + rewriter.create(op.getLoc(), c0, srcColsUpper, vectorStep); + rewriter.setInsertionPointToStart(copyChunkLoop.getBody()); + Value copyOffset = + rewriter.create(op.getLoc(), srcRowBase, copyChunkLoop.getInductionVar()); + auto loaded = rewriter.create(op.getLoc(), vecType, srcBuffer, + copyOffset, StringAttr()); + Value copyDstOffset = + rewriter.create(op.getLoc(), dstRowBase, copyChunkLoop.getInductionVar()); + Value copyRemaining = + rewriter.create(op.getLoc(), validColsValue, copyChunkLoop.getInductionVar()); + auto copyNeedsClamp = rewriter.create(op.getLoc(), arith::CmpIPredicate::slt, + copyRemaining, vectorStep); + Value copyActiveLanes = + rewriter.create(op.getLoc(), copyNeedsClamp, copyRemaining, vectorStep); + Value copyMask = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, copyActiveLanes); + rewriter.create(op.getLoc(), loaded.getResult(), dstBuffer, + copyDstOffset, StringAttr(), copyMask); + + rewriter.setInsertionPointAfter(copyChunkLoop); + if (padCols > 0) { + Value padColsUpper = rewriter.create(op.getLoc(), padCols); + auto padColLoop = rewriter.create(op.getLoc(), c0, padColsUpper, vectorStep); + rewriter.setInsertionPointToStart(padColLoop.getBody()); + Value padDstStart = rewriter.create(op.getLoc(), dstRowBase, validColsValue); + Value padDstOffset = rewriter.create(op.getLoc(), padDstStart, + padColLoop.getInductionVar()); + Value padRemaining = + rewriter.create(op.getLoc(), padColsUpper, padColLoop.getInductionVar()); + auto padNeedsClamp = rewriter.create(op.getLoc(), arith::CmpIPredicate::slt, + padRemaining, vectorStep); + Value padActiveLanes = + rewriter.create(op.getLoc(), padNeedsClamp, padRemaining, vectorStep); + Value padMask = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, padActiveLanes); + rewriter.create(op.getLoc(), padVec.getResult(), dstBuffer, + padDstOffset, StringAttr(), padMask); + } + + rewriter.setInsertionPointAfter(rowLoop); + if (padRows > 0) { + Value bottomStart = rewriter.create(op.getLoc(), srcRowsUpper, dstStrideValue); + Value bottomElements = + rewriter.create(op.getLoc(), + rewriter.create(op.getLoc(), dstRowsUpper, + dstColsValue), + bottomStart); + auto bottomLoop = rewriter.create(op.getLoc(), c0, bottomElements, vectorStep); + rewriter.setInsertionPointToStart(bottomLoop.getBody()); + Value bottomDstOffset = + rewriter.create(op.getLoc(), bottomStart, bottomLoop.getInductionVar()); + Value bottomRemaining = + rewriter.create(op.getLoc(), bottomElements, bottomLoop.getInductionVar()); + auto bottomNeedsClamp = rewriter.create( + op.getLoc(), arith::CmpIPredicate::slt, bottomRemaining, vectorStep); + Value bottomActiveLanes = rewriter.create( + op.getLoc(), bottomNeedsClamp, bottomRemaining, vectorStep); + Value bottomMask = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, bottomActiveLanes); + rewriter.create(op.getLoc(), padVec.getResult(), dstBuffer, + bottomDstOffset, StringAttr(), bottomMask); + } + + return success(); +} + +LogicalResult lowerTFILLPAD(TFillPadOp op, PatternRewriter &rewriter) { + return lowerTFillPadCommon(op, rewriter, /*allowDstExpand=*/false); +} + +LogicalResult lowerTFILLPADExpand(TFillPadExpandOp op, PatternRewriter &rewriter) { + return lowerTFillPadCommon(op, rewriter, /*allowDstExpand=*/true); +} + +LogicalResult lowerTExpandS(TExpandsOp op, PatternRewriter &rewriter) { + VPTOUnaryContract contract = extractTExpandSContract(op); + if (contract.tileDomain != VPTOTileDomain::Vec) + return op.emitOpError("expands lowering requires tile domain vec"); + if (contract.tileLayout != "row_major") + return op.emitOpError("expands lowering requires row-major tile layout"); + if (!contract.elementType) + return op.emitOpError("expands lowering requires a concrete element type"); + + Type scalarType = op.getScalar().getType(); + if (scalarType != contract.elementType) + return op.emitOpError("expands lowering requires scalar type to match destination element type"); + + if (!(contract.elementType.isF16() || contract.elementType.isF32() || + contract.elementType.isBF16())) { + if (auto intType = dyn_cast(contract.elementType)) { + unsigned width = intType.getWidth(); + if (width != 8 && width != 16 && width != 32) + return op.emitOpError("expands lowering supports only f16, f32, bf16, and 8/16/32-bit integer element types"); + } else { + return op.emitOpError("expands lowering supports only scalar integer or floating-point element types"); + } + } + + return buildExpandScalarVecScope(contract, op.getScalar(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTGather(TGatherOp op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tgather lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tgather lowering requires row-major layout for " + << role; + return success(); + }; + + if (failed(requireVecRowMajor(op.getSrc(), "src")) || + failed(requireVecRowMajor(op.getDst(), "dst"))) + return failure(); + + Type dataElementType = getElementType(op.getSrc()); + if (dataElementType != getElementType(op.getDst())) + return op.emitOpError("tgather lowering requires matching src/dst element type"); + + auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); + if (!dataVecType) + return op.emitOpError("tgather lowering requires supported VPTO data type"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("tgather lowering requires pointer-backed tile buffers"); + + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t dstStride = deriveStaticRowStride(op.getDst()); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) + return op.emitOpError("tgather lowering requires static row stride"); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + VPTOLoopScopeContract loopScope; + loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + loopScope.loweredAttr = kLoweredLoopScopeAttrName; + loopScope.loopDepth = 0; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + if (Value indices = op.getIndices()) { + if (failed(requireVecRowMajor(indices, "indices"))) + return failure(); + + Type indexElementType = getElementType(indices); + auto indexIntegerType = dyn_cast(indexElementType); + auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElementType); + if (!indexIntegerType || !indexVecType) + return op.emitOpError("tgather index lowering requires integer indices with supported VPTO vector type"); + if (indexVecType.getElementCount() != dataVecType.getElementCount()) + return op.emitOpError("tgather index lowering currently requires matching data/index vector widths"); + + Value indexBuffer = materializeBufferPointer(indices, indexElementType, + getMemorySpace(indices), rewriter, + op.getLoc()); + if (!indexBuffer) + return op.emitOpError("tgather index lowering requires pointer-backed indices tile"); + + int64_t indexStride = deriveStaticRowStride(indices); + if (indexStride == ShapedType::kDynamic) + return op.emitOpError("tgather index lowering requires static index row stride"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getDst(), validRowsValue, validColsValue); + deriveValidShape(op.getDst(), validRows, validCols); + if (failed(resolveExecutionValidShape(op.getDst(), validRowsValue, validColsValue, + validRows, validCols, rewriter, op.getLoc()))) + return op.emitOpError("tgather index lowering requires valid dst shape"); + + int64_t chunkWidth = indexVecType.getElementCount(); + Value chunkStep = rewriter.create(op.getLoc(), chunkWidth); + Value dstStrideValue = + rewriter.create(op.getLoc(), dstStride); + Value indexStrideValue = + rewriter.create(op.getLoc(), indexStride); + + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value row = rowLoop.getInductionVar(); + Value chunkBase = chunkLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); + + Value dstRowBase = + rewriter.create(op.getLoc(), row, dstStrideValue); + Value indexRowBase = + rewriter.create(op.getLoc(), row, indexStrideValue); + Value indexOffset = + rewriter.create(op.getLoc(), indexRowBase, chunkBase); + auto offsetVector = rewriter.create(op.getLoc(), indexVecType, + indexBuffer, indexOffset, + StringAttr()); + auto gathered = rewriter.create( + op.getLoc(), dataVecType, srcBuffer, offsetVector.getResult(), activeLanes); + Value dstOffset = + rewriter.create(op.getLoc(), dstRowBase, chunkBase); + return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), + dstBuffer, dstOffset, activeLanes, chunkWidth); + } + + auto maskPattern = op.getMaskPatternAttr(); + if (!maskPattern) + return op.emitOpError("tgather lowering requires indices or maskPattern"); + if (maskPattern.getValue() != MaskPattern::P1111) + return op.emitOpError("tgather mask lowering currently supports only maskPattern=P1111"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getSrc(), validRowsValue, validColsValue); + deriveValidShape(op.getSrc(), validRows, validCols); + if (failed(resolveExecutionValidShape(op.getSrc(), validRowsValue, validColsValue, + validRows, validCols, rewriter, op.getLoc()))) + return op.emitOpError("tgather mask lowering requires valid src shape"); + + int64_t chunkWidth = dataVecType.getElementCount(); + Value chunkStep = rewriter.create(op.getLoc(), chunkWidth); + Value srcStrideValue = + rewriter.create(op.getLoc(), srcStride); + Value dstStrideValue = + rewriter.create(op.getLoc(), dstStride); + + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value row = rowLoop.getInductionVar(); + Value chunkBase = chunkLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); + + Value srcRowBase = + rewriter.create(op.getLoc(), row, srcStrideValue); + Value dstRowBase = + rewriter.create(op.getLoc(), row, dstStrideValue); + Value srcOffset = + rewriter.create(op.getLoc(), srcRowBase, chunkBase); + auto loaded = rewriter.create(op.getLoc(), dataVecType, srcBuffer, + srcOffset, StringAttr()); + Value dstOffset = + rewriter.create(op.getLoc(), dstRowBase, chunkBase); + return buildMaskedVectorStore(rewriter, op.getLoc(), loaded.getResult(), dstBuffer, + dstOffset, activeLanes, chunkWidth); +} + +LogicalResult lowerTGatherB(TGatherBOp op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tgatherb lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tgatherb lowering requires row-major layout for " + << role; + return success(); + }; + + if (failed(requireVecRowMajor(op.getSrc(), "src")) || + failed(requireVecRowMajor(op.getOffsets(), "offsets")) || + failed(requireVecRowMajor(op.getDst(), "dst"))) + return failure(); + + Type dataElementType = getElementType(op.getDst()); + if (getElementType(op.getSrc()) != dataElementType) + return op.emitOpError("tgatherb lowering requires matching src/dst element type"); + + auto offsetIntegerType = dyn_cast(getElementType(op.getOffsets())); + if (!offsetIntegerType || offsetIntegerType.getWidth() != 32 || + !offsetIntegerType.isUnsigned()) + return op.emitOpError("tgatherb lowering currently requires unsigned 32-bit offsets"); + + auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); + auto offsetVecType = + getVPTOVRegType(rewriter.getContext(), getElementType(op.getOffsets())); + if (!dataVecType || !offsetVecType) + return op.emitOpError("tgatherb lowering requires supported VPTO vector types"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + Value offsetBuffer = + materializeBufferPointer(op.getOffsets(), getElementType(op.getOffsets()), + getMemorySpace(op.getOffsets()), rewriter, op.getLoc()); + if (!srcBuffer || !dstBuffer || !offsetBuffer) + return op.emitOpError("tgatherb lowering requires pointer-backed tile buffers"); + + int64_t dstStride = deriveStaticRowStride(op.getDst()); + int64_t offsetStride = deriveStaticRowStride(op.getOffsets()); + int64_t staticRows = deriveStaticShapeDim(op.getDst(), 0); + int64_t staticCols = deriveStaticShapeDim(op.getDst(), 1); + if (dstStride == ShapedType::kDynamic || offsetStride == ShapedType::kDynamic || + staticRows == ShapedType::kDynamic || staticCols == ShapedType::kDynamic) + return op.emitOpError("tgatherb lowering requires static tile shape and row stride"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getDst(), validRowsValue, validColsValue); + deriveValidShape(op.getDst(), validRows, validCols); + if (failed(resolveExecutionValidShape(op.getDst(), validRowsValue, validColsValue, + validRows, validCols, rewriter, op.getLoc()))) + return op.emitOpError("tgatherb lowering requires valid dst shape"); + + unsigned elemBytes = dataElementType.getIntOrFloatBitWidth() / 8; + int64_t elementsPerRepeat = 256 / elemBytes; + int64_t blockSizeElem = 32 / elemBytes; + int64_t staticRepeatTimes = llvm::divideCeil(staticCols, elementsPerRepeat); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value elementsPerRepeatValue = + rewriter.create(op.getLoc(), elementsPerRepeat); + Value blockSizeElemValue = + rewriter.create(op.getLoc(), blockSizeElem); + Value dstStrideValue = + rewriter.create(op.getLoc(), dstStride); + Value offsetStrideValue = + rewriter.create(op.getLoc(), offsetStride); + + VPTOLoopScopeContract loopScope; + loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + loopScope.loweredAttr = kLoweredLoopScopeAttrName; + loopScope.loopDepth = 0; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + if (staticRepeatTimes > staticRows) { + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = rewriter.create(op.getLoc(), c0, validColsValue, + elementsPerRepeatValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value row = rowLoop.getInductionVar(); + Value chunkBase = chunkLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, + elementsPerRepeatValue); + Value rowOffsetBase = + rewriter.create(op.getLoc(), row, offsetStrideValue); + Value rowDstBase = + rewriter.create(op.getLoc(), row, dstStrideValue); + Value offsetChunkBase = + rewriter.create(op.getLoc(), chunkBase, + blockSizeElemValue); + Value offsetLoadOffset = + rewriter.create(op.getLoc(), rowOffsetBase, offsetChunkBase); + auto offsets = rewriter.create(op.getLoc(), offsetVecType, + offsetBuffer, offsetLoadOffset, + StringAttr()); + auto gathered = rewriter.create( + op.getLoc(), dataVecType, srcBuffer, offsets.getResult(), activeLanes); + Value dstOffset = + rewriter.create(op.getLoc(), rowDstBase, chunkBase); + return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), + dstBuffer, dstOffset, activeLanes, + dataVecType.getElementCount()); + } + + auto chunkLoop = rewriter.create(op.getLoc(), c0, validColsValue, + elementsPerRepeatValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + + Value chunkBase = chunkLoop.getInductionVar(); + Value row = rowLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, + elementsPerRepeatValue); + Value rowOffsetBase = + rewriter.create(op.getLoc(), row, offsetStrideValue); + Value rowDstBase = + rewriter.create(op.getLoc(), row, dstStrideValue); + Value offsetChunkBase = + rewriter.create(op.getLoc(), chunkBase, + blockSizeElemValue); + Value offsetLoadOffset = + rewriter.create(op.getLoc(), rowOffsetBase, offsetChunkBase); + auto offsets = rewriter.create(op.getLoc(), offsetVecType, offsetBuffer, + offsetLoadOffset, StringAttr()); + auto gathered = rewriter.create( + op.getLoc(), dataVecType, srcBuffer, offsets.getResult(), activeLanes); + Value dstOffset = + rewriter.create(op.getLoc(), chunkBase, rowDstBase); + return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), + dstBuffer, dstOffset, activeLanes, + dataVecType.getElementCount()); +} + +LogicalResult lowerTScatter(TScatterOp op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tscatter lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tscatter lowering requires row-major layout for " + << role; + return success(); + }; + + if (failed(requireVecRowMajor(op.getSrc(), "src")) || + failed(requireVecRowMajor(op.getIndexes(), "indexes")) || + failed(requireVecRowMajor(op.getDst(), "dst"))) + return failure(); + + Type dataElementType = getElementType(op.getSrc()); + if (dataElementType != getElementType(op.getDst())) + return op.emitOpError("tscatter lowering requires matching src/dst element type"); + + Type indexElementType = getElementType(op.getIndexes()); + auto indexIntegerType = dyn_cast(indexElementType); + if (!indexIntegerType || indexIntegerType.getWidth() != 32) + return op.emitOpError("tscatter lowering currently requires 32-bit integer indexes"); + + auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); + auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElementType); + if (!dataVecType || !indexVecType || + dataVecType.getElementCount() != indexVecType.getElementCount()) + return op.emitOpError("tscatter lowering currently requires matching data/index vector widths"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + Value indexBuffer = materializeBufferPointer(op.getIndexes(), indexElementType, + getMemorySpace(op.getIndexes()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer || !indexBuffer) + return op.emitOpError("tscatter lowering requires pointer-backed tile buffers"); + + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t indexStride = deriveStaticRowStride(op.getIndexes()); + if (srcStride == ShapedType::kDynamic || indexStride == ShapedType::kDynamic) + return op.emitOpError("tscatter lowering requires static src/index row stride"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getIndexes(), validRowsValue, validColsValue); + deriveValidShape(op.getIndexes(), validRows, validCols); + if (failed(resolveExecutionValidShape(op.getIndexes(), validRowsValue, validColsValue, + validRows, validCols, rewriter, op.getLoc()))) + return op.emitOpError("tscatter lowering requires valid index shape"); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value chunkStep = rewriter.create( + op.getLoc(), indexVecType.getElementCount()); + Value srcStrideValue = + rewriter.create(op.getLoc(), srcStride); + Value indexStrideValue = + rewriter.create(op.getLoc(), indexStride); + + VPTOLoopScopeContract loopScope; + loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + loopScope.loweredAttr = kLoweredLoopScopeAttrName; + loopScope.loopDepth = 0; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value row = rowLoop.getInductionVar(); + Value chunkBase = chunkLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); + + Value srcRowBase = + rewriter.create(op.getLoc(), row, srcStrideValue); + Value indexRowBase = + rewriter.create(op.getLoc(), row, indexStrideValue); + Value srcOffset = + rewriter.create(op.getLoc(), srcRowBase, chunkBase); + Value indexOffset = + rewriter.create(op.getLoc(), indexRowBase, chunkBase); + auto srcVector = rewriter.create(op.getLoc(), dataVecType, srcBuffer, + srcOffset, StringAttr()); + auto indexVector = rewriter.create(op.getLoc(), indexVecType, indexBuffer, + indexOffset, StringAttr()); + rewriter.create(op.getLoc(), srcVector.getResult(), dstBuffer, + indexVector.getResult(), activeLanes); + return success(); +} + +LogicalResult lowerTMrgSort(TMrgSortOp op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tmrgsort lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tmrgsort lowering requires row-major layout for " + << role; + return success(); + }; + auto requireOneRow = [&](Value value, StringRef role) -> LogicalResult { + if (deriveStaticShapeDim(value, 0) != 1) + return op.emitOpError() << "tmrgsort lowering requires rows==1 for " << role; + return success(); + }; + + Location loc = op.getLoc(); + if (op.isFormat1()) { + Value src = op.getSrcs().front(); + Value dst = op.getDsts().front(); + if (failed(requireVecRowMajor(src, "src")) || failed(requireVecRowMajor(dst, "dst")) || + failed(requireOneRow(src, "src")) || failed(requireOneRow(dst, "dst"))) + return failure(); + + Type elementType = getElementType(src); + if (elementType != getElementType(dst)) + return op.emitOpError("tmrgsort format1 requires matching src/dst element type"); + if (!(elementType.isF16() || elementType.isF32())) + return op.emitOpError("tmrgsort format1 currently supports only f16/f32"); + + Value srcBuffer = materializeBufferPointer(src, elementType, getMemorySpace(src), + rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, elementType, getMemorySpace(dst), + rewriter, loc); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("tmrgsort format1 requires pointer-backed tile buffers"); + + Value blockLen = op.getBlockLen(); + if (!blockLen) + return op.emitOpError("tmrgsort format1 requires blockLen"); + Value blockLenI64; + if (blockLen.getType().isIndex()) + blockLenI64 = + rewriter.create(loc, rewriter.getI64Type(), blockLen); + else + blockLenI64 = + rewriter.create(loc, rewriter.getI64Type(), blockLen); + Value blockLenIndex = + rewriter.create(loc, rewriter.getIndexType(), blockLenI64); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(src, validRowsValue, validColsValue); + deriveValidShape(src, validRows, validCols); + Value validColsI64 = materializeI64Value(validColsValue, validCols, rewriter, loc); + + int64_t elemBytes = getElementByteSize(elementType); + Value numStructures = rewriter.create( + loc, rewriter.getI64Type(), + rewriter.create( + loc, blockLenI64, rewriter.create(loc, elemBytes, 64)), + rewriter.create(loc, 3, 64)); + Value count = buildPackedCountI64(rewriter, loc, + {numStructures, numStructures, numStructures, numStructures}); + Value repeatTimes = rewriter.create( + loc, validColsI64, + rewriter.create( + loc, blockLenI64, rewriter.create(loc, 4, 64))); + Value config = rewriter.create( + loc, repeatTimes, rewriter.create(loc, 0b1111 << 8, 64)); + + Value src0 = srcBuffer; + Value src1 = offsetBufferPointer(srcBuffer, elementType, blockLenIndex, rewriter, loc); + Value src2 = offsetBufferPointer( + srcBuffer, elementType, + rewriter.create(loc, blockLenIndex, + rewriter.create(loc, 2)), + rewriter, loc); + Value src3 = offsetBufferPointer( + srcBuffer, elementType, + rewriter.create(loc, blockLenIndex, + rewriter.create(loc, 3)), + rewriter, loc); + rewriter.create(loc, dstBuffer, src0, src1, src2, src3, count, + config); + return success(); + } + + if (!op.isFormat2()) + return op.emitOpError("unsupported tmrgsort format for current vpto backend"); + if (op.getExhausted()) + return op.emitOpError("tmrgsort format2 exhausted=true is not yet supported"); + if (op.getSrcs().size() != 4 || op.getDsts().size() != 2) + return op.emitOpError("tmrgsort format2 currently requires exactly 4 srcs and 2 dsts"); + + Type elementType = getElementType(op.getSrcs().front()); + if (!(elementType.isF16() || elementType.isF32())) + return op.emitOpError("tmrgsort format2 currently supports only f16/f32"); + + SmallVector srcBuffers; + SmallVector srcCounts; + srcBuffers.reserve(4); + srcCounts.reserve(4); + for (Value src : op.getSrcs()) { + if (failed(requireVecRowMajor(src, "src")) || failed(requireOneRow(src, "src"))) + return failure(); + if (getElementType(src) != elementType) + return op.emitOpError("tmrgsort format2 requires matching source element types"); + + Value srcBuffer = + materializeBufferPointer(src, elementType, getMemorySpace(src), rewriter, loc); + if (!srcBuffer) + return op.emitOpError("tmrgsort format2 requires pointer-backed source tiles"); + srcBuffers.push_back(srcBuffer); + + Value rowsValue; + Value colsValue; + int64_t rows = ShapedType::kDynamic; + int64_t cols = ShapedType::kDynamic; + deriveValidShapeValues(src, rowsValue, colsValue); + deriveValidShape(src, rows, cols); + Value colsI64 = materializeI64Value(colsValue, cols, rewriter, loc); + srcCounts.push_back(rewriter.create( + loc, rewriter.getI64Type(), colsI64, + rewriter.create(loc, elementType.isF32() ? 1 : 2, 64))); + } + + Value dst = op.getDsts()[0]; + Value tmp = op.getDsts()[1]; + if (failed(requireVecRowMajor(dst, "dst")) || failed(requireVecRowMajor(tmp, "tmp")) || + failed(requireOneRow(dst, "dst")) || failed(requireOneRow(tmp, "tmp"))) + return failure(); + if (getElementType(dst) != elementType || getElementType(tmp) != elementType) + return op.emitOpError("tmrgsort format2 requires matching dst/tmp element types"); + + Value dstBuffer = + materializeBufferPointer(dst, elementType, getMemorySpace(dst), rewriter, loc); + Value tmpBuffer = + materializeBufferPointer(tmp, elementType, getMemorySpace(tmp), rewriter, loc); + if (!dstBuffer || !tmpBuffer) + return op.emitOpError("tmrgsort format2 requires pointer-backed dst/tmp tiles"); + + Value count = buildPackedCountI64(rewriter, loc, srcCounts); + Value config = + rewriter.create(loc, 1 | (0b1111 << 8), 64); + rewriter.create(loc, tmpBuffer, srcBuffers[0], srcBuffers[1], + srcBuffers[2], srcBuffers[3], count, config); + + Value dstRowsValue; + Value dstColsValue; + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShapeValues(dst, dstRowsValue, dstColsValue); + deriveValidShape(dst, dstRows, dstCols); + Value dstColsI64 = materializeI64Value(dstColsValue, dstCols, rewriter, loc); + int64_t elemBytes = getElementByteSize(elementType); + Value lenBurst = buildCeilDivPositiveI64( + rewriter, loc, + rewriter.create( + loc, dstColsI64, rewriter.create(loc, elemBytes, 64)), + 32); + Value zeroI64 = rewriter.create(loc, 0, 64); + Value oneI64 = rewriter.create(loc, 1, 64); + rewriter.create(loc, tmpBuffer, dstBuffer, zeroI64, oneI64, + lenBurst, zeroI64, zeroI64); + return success(); +} + +LogicalResult lowerTSort32(TSort32Op op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tsort32 lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tsort32 lowering requires row-major layout for " + << role; + return success(); + }; + + if (failed(requireVecRowMajor(op.getSrc(), "src")) || + failed(requireVecRowMajor(op.getDst(), "dst")) || + failed(requireVecRowMajor(op.getIdx(), "idx"))) + return failure(); + + Type dataType = getElementType(op.getSrc()); + if (dataType != getElementType(op.getDst())) + return op.emitOpError("tsort32 lowering requires matching src/dst element type"); + if (!(dataType.isF16() || dataType.isF32())) + return op.emitOpError("tsort32 lowering currently supports only f16/f32 data"); + auto idxType = dyn_cast(getElementType(op.getIdx())); + if (!idxType || idxType.getWidth() != 32 || !idxType.isUnsigned()) + return op.emitOpError("tsort32 lowering currently requires u32 index tile"); + + Value srcBuffer = + materializeBufferPointer(op.getSrc(), dataType, getMemorySpace(op.getSrc()), + rewriter, op.getLoc()); + Value dstBuffer = + materializeBufferPointer(op.getDst(), dataType, getMemorySpace(op.getDst()), + rewriter, op.getLoc()); + Value idxBuffer = materializeBufferPointer(op.getIdx(), getElementType(op.getIdx()), + getMemorySpace(op.getIdx()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer || !idxBuffer) + return op.emitOpError("tsort32 lowering requires pointer-backed tiles"); + + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t dstStride = deriveStaticRowStride(op.getDst()); + int64_t idxStride = deriveStaticRowStride(op.getIdx()); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || + idxStride == ShapedType::kDynamic) + return op.emitOpError("tsort32 lowering requires static row stride"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getSrc(), validRowsValue, validColsValue); + deriveValidShape(op.getSrc(), validRows, validCols); + if (validCols == ShapedType::kDynamic || (validCols % 32) != 0) + return op.emitOpError("tsort32 lowering currently requires static validCol divisible by 32"); + + int64_t idxValidRows = ShapedType::kDynamic; + int64_t idxValidCols = ShapedType::kDynamic; + deriveValidShape(op.getIdx(), idxValidRows, idxValidCols); + bool idxBroadcast = idxValidRows == 1; + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value repeatNumPerRow = + rewriter.create(op.getLoc(), validCols / 32); + Value srcStrideValue = rewriter.create(op.getLoc(), srcStride); + Value dstStrideValue = rewriter.create(op.getLoc(), dstStride); + Value idxStrideValue = + rewriter.create(op.getLoc(), idxBroadcast ? 0 : idxStride); + + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcOffset = rewriter.create(op.getLoc(), row, srcStrideValue); + Value dstOffset = rewriter.create(op.getLoc(), row, dstStrideValue); + Value idxOffset = rewriter.create(op.getLoc(), row, idxStrideValue); + Value rowSrcPtr = + offsetBufferPointer(srcBuffer, dataType, srcOffset, rewriter, op.getLoc()); + Value rowDstPtr = + offsetBufferPointer(dstBuffer, dataType, dstOffset, rewriter, op.getLoc()); + Value rowIdxPtr = offsetBufferPointer(idxBuffer, getElementType(op.getIdx()), idxOffset, + rewriter, op.getLoc()); + rewriter.create(op.getLoc(), rowDstPtr, rowSrcPtr, rowIdxPtr, + repeatNumPerRow); + return success(); +} + +LogicalResult lowerTMulS(TMulSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = buildUnaryContract("muls", op.getSrc0()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 16 || intType.getWidth() == 32; + return false; + }, + "f16, f32, and 16/32-bit integer element types"))) + return failure(); + if (op.getScalar().getType() != contract.elementType) + return op.emitOpError("tmuls lowering requires scalar type to match source element type"); + return buildScalarUnaryVecScope("muls", contract, strategy, op.getSrc0(), op.getScalar(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTSelS(TSelSOp op, PatternRewriter &rewriter) { + VPTOBinaryContract contract = buildBinaryContract("sels", op.getSrc()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getTmp(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + + auto selectModeType = dyn_cast(op.getScalar().getType()); + if (!selectModeType) + return op.emitOpError("tsels lowering requires integer selectMode"); + + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return op.emitOpError("tsels lowering requires a supported VPTO vector element type"); + + Value src0Buffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value src1Buffer = materializeBufferPointer(op.getTmp(), contract.elementType, + getMemorySpace(op.getTmp()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!src0Buffer || !src1Buffer || !dstBuffer) + return op.emitOpError("tsels lowering requires pointer-backed tile buffers"); + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, op.getLoc()); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, op.getLoc()); + if (!validRowsValue || !validColsValue) + return op.emitOpError("tsels lowering requires valid rows and cols"); + + int64_t vectorWidth = vecType.getElementCount(); + if (contract.validRows != ShapedType::kDynamic && + contract.validCols != ShapedType::kDynamic) { + int64_t totalElements = contract.validRows * contract.validCols; + if (totalElements % vectorWidth != 0) + return op.emitOpError( + "tsels lowering currently requires total valid elements divisible by vector width"); + } + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value totalElementsValue = + rewriter.create(op.getLoc(), validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(op.getLoc(), vectorWidth); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + Value selectOne = rewriter.create( + op.getLoc(), IntegerAttr::get(selectModeType, 1)); + Value isAll = rewriter.create(op.getLoc(), arith::CmpIPredicate::eq, + op.getScalar(), selectOne); + auto ifOp = rewriter.create( + op.getLoc(), TypeRange{getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType)}, isAll, + /*withElseRegion=*/true); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value allMask = rewriter + .create(op.getLoc(), + getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + rewriter.create(op.getLoc(), allMask); + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + Value allfMask = rewriter + .create(op.getLoc(), + getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType), + rewriter.getStringAttr("PAT_ALLF")) + .getResult(); + rewriter.create(op.getLoc(), allfMask); + + rewriter.setInsertionPointAfter(ifOp); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, totalElementsValue, vectorStepValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + Value mask = ifOp.getResult(0); + auto src0Vec = rewriter.create(op.getLoc(), vecType, src0Buffer, + offset, StringAttr()); + auto src1Vec = rewriter.create(op.getLoc(), vecType, src1Buffer, + offset, StringAttr()); + Value selected = rewriter + .create(op.getLoc(), vecType, src0Vec.getResult(), + src1Vec.getResult(), mask) + .getResult(); + rewriter.create( + op.getLoc(), selected, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), contract.elementType)); + return success(); +} + +LogicalResult lowerTSel(TSelOp op, PatternRewriter &rewriter) { + VPTOBinaryContract contract = buildBinaryContract("tsel", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + + int64_t src1Rows = ShapedType::kDynamic; + int64_t src1Cols = ShapedType::kDynamic; + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + int64_t maskRows = ShapedType::kDynamic; + int64_t maskCols = ShapedType::kDynamic; + deriveValidShape(op.getSrc1(), src1Rows, src1Cols); + deriveValidShape(op.getDst(), dstRows, dstCols); + deriveValidShape(op.getMask(), maskRows, maskCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getMask())) != VPTOTileDomain::Vec) + return op.emitOpError("tsel lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getSrc1()) != "row_major" || + deriveTileLayout(op.getDst()) != "row_major" || deriveTileLayout(op.getMask()) != "row_major") + return op.emitOpError("tsel lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return op.emitOpError("tsel lowering requires static valid shape"); + if (contract.validRows != src1Rows || contract.validCols != src1Cols || + contract.validRows != dstRows || contract.validCols != dstCols || + contract.validRows != maskRows || contract.validCols != maskCols) + return op.emitOpError("tsel lowering requires matching source, mask, and destination valid region"); + if (!contract.elementType || !contract.elementType.isF32()) + return op.emitOpError("tsel lowering currently supports only f32 data tiles"); + auto maskElemType = dyn_cast_or_null(getElementType(op.getMask())); + if (!maskElemType || maskElemType.getWidth() != 8) + return op.emitOpError("tsel lowering currently requires i8 mask tiles"); + + auto [tileRows, tileCols] = getStaticTileRowsCols(op.getDst()); + auto [maskTileRows, maskTileCols] = getStaticTileRowsCols(op.getMask()); + if (tileRows == ShapedType::kDynamic || tileCols == ShapedType::kDynamic || + maskTileRows == ShapedType::kDynamic || maskTileCols == ShapedType::kDynamic) + return op.emitOpError("tsel lowering requires static tile rows and cols"); + Value maskBuffer = materializeBufferPointer(op.getMask(), getElementType(op.getMask()), + getMemorySpace(op.getMask()), rewriter, + op.getLoc()); + Value src0Buffer = materializeBufferPointer(op.getSrc0(), contract.elementType, + getMemorySpace(op.getSrc0()), rewriter, + op.getLoc()); + Value src1Buffer = materializeBufferPointer(op.getSrc1(), contract.elementType, + getMemorySpace(op.getSrc1()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!maskBuffer || !src0Buffer || !src1Buffer || !dstBuffer) + return op.emitOpError("tsel lowering requires pointer-backed tile buffers"); + + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return op.emitOpError("tsel lowering requires a supported VPTO vector element type"); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value validRowsValue = materializeIndexValue(contract.validRowsValue, contract.validRows, + rewriter, op.getLoc()); + if (!validRowsValue) + return op.emitOpError("tsel lowering requires valid rows"); + Value rowStride = rewriter.create(op.getLoc(), tileCols); + Value maskStride = rewriter.create(op.getLoc(), maskTileCols); + constexpr int64_t elementsPerRepeat = 64; + constexpr int64_t unrollConstant = 2; + int64_t repeatTimes = (contract.validCols + elementsPerRepeat - 1) / elementsPerRepeat; + int64_t pairedRepeatTimes = repeatTimes / unrollConstant; + int64_t remainRepeat = repeatTimes % unrollConstant; + int64_t repeatIdxBase = pairedRepeatTimes * unrollConstant; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto splitMaskType = getVPTOMaskType(rewriter.getContext(), "b16"); + Value fullMask = rewriter + .create(op.getLoc(), splitMaskType, + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value rowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), rowStride); + Value maskBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), maskStride); + + for (int64_t j = 0; j < pairedRepeatTimes; ++j) { + int64_t repeatIdx = j * unrollConstant; + int64_t colOffset0 = repeatIdx * elementsPerRepeat; + int64_t colOffset1 = colOffset0 + elementsPerRepeat; + int64_t maskOffsetImm = repeatIdx * 8; + int64_t count0 = std::min(elementsPerRepeat, contract.validCols - colOffset0); + int64_t count1 = std::min(elementsPerRepeat, contract.validCols - colOffset1); + + Value maskOffset = rewriter.create( + op.getLoc(), maskBase, + rewriter.create(op.getLoc(), maskOffsetImm)); + Value rawMask = rewriter + .create(op.getLoc(), + splitMaskType, + maskBuffer, maskOffset, + rewriter.getStringAttr("US")) + .getResult(); + auto splitMask = rewriter.create( + op.getLoc(), splitMaskType, splitMaskType, rawMask, fullMask); + + Value dataOffset0 = rewriter.create( + op.getLoc(), rowBase, + rewriter.create(op.getLoc(), colOffset0)); + auto lhs0 = rewriter.create(op.getLoc(), vecType, src0Buffer, + dataOffset0, StringAttr()); + auto rhs0 = rewriter.create(op.getLoc(), vecType, src1Buffer, + dataOffset0, StringAttr()); + Value selected0 = rewriter + .create(op.getLoc(), vecType, lhs0.getResult(), + rhs0.getResult(), splitMask.getLow()) + .getResult(); + Value storeMask0 = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, + rewriter.create(op.getLoc(), count0)); + rewriter.create(op.getLoc(), selected0, dstBuffer, dataOffset0, + StringAttr(), storeMask0); + + Value dataOffset1 = rewriter.create( + op.getLoc(), rowBase, + rewriter.create(op.getLoc(), colOffset1)); + auto lhs1 = rewriter.create(op.getLoc(), vecType, src0Buffer, + dataOffset1, StringAttr()); + auto rhs1 = rewriter.create(op.getLoc(), vecType, src1Buffer, + dataOffset1, StringAttr()); + Value selected1 = rewriter + .create(op.getLoc(), vecType, lhs1.getResult(), + rhs1.getResult(), splitMask.getHigh()) + .getResult(); + Value storeMask1 = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, + rewriter.create(op.getLoc(), count1)); + rewriter.create(op.getLoc(), selected1, dstBuffer, dataOffset1, + StringAttr(), storeMask1); + } + + for (int64_t j = 0; j < remainRepeat; ++j) { + int64_t repeatIdx = repeatIdxBase + j; + int64_t colOffset = repeatIdx * elementsPerRepeat; + int64_t count = std::max(0, contract.validCols - colOffset); + int64_t maskOffsetImm = repeatIdx * 8; + + Value maskOffset = rewriter.create( + op.getLoc(), maskBase, + rewriter.create(op.getLoc(), maskOffsetImm)); + Value rawMask = rewriter + .create(op.getLoc(), + splitMaskType, + maskBuffer, maskOffset, + rewriter.getStringAttr("US")) + .getResult(); + Value unpackedMask = rewriter + .create( + op.getLoc(), splitMaskType, + rawMask, rewriter.getStringAttr("LOWER")) + .getResult(); + Value dataOffset = rewriter.create( + op.getLoc(), rowBase, + rewriter.create(op.getLoc(), colOffset)); + auto lhs = rewriter.create(op.getLoc(), vecType, src0Buffer, + dataOffset, StringAttr()); + auto rhs = rewriter.create(op.getLoc(), vecType, src1Buffer, + dataOffset, StringAttr()); + Value selected = rewriter + .create(op.getLoc(), vecType, lhs.getResult(), + rhs.getResult(), unpackedMask) + .getResult(); + Value storeMask = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, + rewriter.create(op.getLoc(), count)); + rewriter.create(op.getLoc(), selected, dstBuffer, dataOffset, + StringAttr(), storeMask); + } + return success(); +} + +LogicalResult lowerTDivS(TDivSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + Value tileOperand; + Value scalarOperand; + bool scalarFirst = false; + if (isVPTOShapedLikeValue(op.getSrc()) && !isVPTOShapedLikeValue(op.getScalar())) { + tileOperand = op.getSrc(); + scalarOperand = op.getScalar(); + } else if (!isVPTOShapedLikeValue(op.getSrc()) && + isVPTOShapedLikeValue(op.getScalar())) { + tileOperand = op.getScalar(); + scalarOperand = op.getSrc(); + scalarFirst = true; + } else { + return op.emitOpError( + "divs lowering requires exactly one shaped operand and one scalar operand"); + } + + VPTOUnaryContract contract = buildUnaryContract("divs", tileOperand); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, + "f16 and f32 element types"))) + return failure(); + if (scalarOperand.getType() != contract.elementType) + return op.emitOpError( + "divs lowering requires scalar type to match source element type"); + return buildScalarDivVecScope(contract, strategy, tileOperand, scalarOperand, op.getDst(), + scalarFirst, rewriter, op.getLoc()); +} + +LogicalResult lowerTAddS(TAddSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = buildUnaryContract("adds", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 16 || intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 16/32-bit integer element types"))) + return failure(); + if (op.getScalar().getType() != contract.elementType) + return op.emitOpError("tadds lowering requires scalar type to match source element type"); + return buildScalarUnaryVecScope("adds", contract, strategy, op.getSrc(), op.getScalar(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTAddC(TAddCOp op, PatternRewriter &rewriter) { + VPTOBinaryContract first = buildBinaryContract("add", op.getSrc0()); + deriveValidShapeValues(op.getDst(), first.validRowsValue, first.validColsValue); + deriveValidShape(op.getDst(), first.validRows, first.validCols); + if (failed(checkGenericBinaryContract( + op, first, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + if (failed(buildBinaryVecScope("add", first, VPTOLoweringStrategy::PostUpdate, + op.getSrc0(), op.getSrc1(), op.getDst(), + rewriter, op.getLoc()))) + return failure(); + + VPTOBinaryContract second = buildBinaryContract("add", op.getDst()); + deriveValidShapeValues(op.getDst(), second.validRowsValue, second.validColsValue); + deriveValidShape(op.getDst(), second.validRows, second.validCols); + if (failed(checkGenericBinaryContract( + op, second, op.getSrc2(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("add", second, VPTOLoweringStrategy::PostUpdate, + op.getDst(), op.getSrc2(), op.getDst(), rewriter, + op.getLoc()); +} + +LogicalResult lowerTAddSC(TAddSCOp op, PatternRewriter &rewriter) { + return emitUnresolvedInstalledA5BaselineError(op, "taddsc"); +} + +LogicalResult lowerTSubC(TSubCOp op, PatternRewriter &rewriter) { + VPTOBinaryContract first = buildBinaryContract("sub", op.getSrc0()); + deriveValidShapeValues(op.getDst(), first.validRowsValue, first.validColsValue); + deriveValidShape(op.getDst(), first.validRows, first.validCols); + if (failed(checkGenericBinaryContract( + op, first, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + if (failed(buildBinaryVecScope("sub", first, VPTOLoweringStrategy::PostUpdate, + op.getSrc0(), op.getSrc1(), op.getDst(), + rewriter, op.getLoc()))) + return failure(); + + VPTOBinaryContract second = buildBinaryContract("add", op.getDst()); + deriveValidShapeValues(op.getDst(), second.validRowsValue, second.validColsValue); + deriveValidShape(op.getDst(), second.validRows, second.validCols); + if (failed(checkGenericBinaryContract( + op, second, op.getSrc2(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("add", second, VPTOLoweringStrategy::PostUpdate, + op.getDst(), op.getSrc2(), op.getDst(), rewriter, + op.getLoc()); +} + +LogicalResult lowerTSubS(TSubSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + (void)rewriter; + (void)strategy; + return emitUnresolvedInstalledA5BaselineError(op, "tsubs"); +} + +LogicalResult lowerTSubSC(TSubSCOp op, PatternRewriter &rewriter) { + return emitUnresolvedInstalledA5BaselineError(op, "tsubsc"); +} + +LogicalResult lowerTMaxS(TMaxSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = buildUnaryContract("maxs", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF32(); }, "f32 element type"))) + return failure(); + if (op.getScalar().getType() != contract.elementType) + return op.emitOpError("tmaxs lowering requires scalar type to match source element type"); + return buildScalarUnaryVecScope("maxs", contract, strategy, op.getSrc(), op.getScalar(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTMinS(TMinSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = buildUnaryContract("mins", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF32(); }, "f32 element type"))) + return failure(); + if (op.getScalar().getType() != contract.elementType) + return op.emitOpError("tmins lowering requires scalar type to match source element type"); + return buildScalarUnaryVecScope("mins", contract, strategy, op.getSrc(), op.getScalar(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTRowMax(TRowMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTORowReduceContract contract = extractTRowMaxContract(op); + if (failed(checkRowReduceContract(op, contract, op.getDst()))) + return failure(); + return buildRowReduceVecScope("rowmax", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTRowMin(TRowMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTORowReduceContract contract = extractTRowMinContract(op); + if (failed(checkRowReduceContract(op, contract, op.getDst()))) + return failure(); + return buildRowReduceVecScope("rowmin", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTRowSum(TRowSumOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTORowReduceContract contract = extractTRowSumContract(op); + if (failed(checkRowReduceContract(op, contract, op.getDst()))) + return failure(); + return buildRowReduceVecScope("rowsum", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTColMax(TColMaxOp op, PatternRewriter &rewriter) { + VPTOColReduceContract contract = extractTColMaxContract(op); + if (failed(checkColReduceContract(op, contract, op.getDst()))) + return failure(); + return buildColReduceVecScope("colmax", contract, op.getSrc(), op.getDst(), + Value(), rewriter, op.getLoc()); +} + +LogicalResult lowerTColMin(TColMinOp op, PatternRewriter &rewriter) { + VPTOColReduceContract contract = extractTColMinContract(op); + if (failed(checkColReduceContract(op, contract, op.getDst()))) + return failure(); + return buildColReduceVecScope("colmin", contract, op.getSrc(), op.getDst(), + Value(), rewriter, op.getLoc()); +} + +LogicalResult lowerTColSum(TColSumOp op, PatternRewriter &rewriter) { + VPTOColReduceContract contract = extractTColSumContract(op); + if (failed(checkColReduceContract(op, contract, op.getDst()))) + return failure(); + return buildColReduceVecScope("colsum", contract, op.getSrc(), op.getDst(), + op.getTmp(), rewriter, op.getLoc()); +} + +LogicalResult lowerTRowExpand(TRowExpandOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOExpandContract contract = extractTRowExpandContract(op); + if (failed(checkExpandContract(op, contract))) + return failure(); + if (contract.srcValidRows != contract.dstValidRows) + return op.emitOpError() + << "rowexpand lowering requires source and destination valid rows to match"; + return buildRowExpandVecScope(contract, strategy, op.getSrc(), op.getDst(), rewriter, + op.getLoc()); +} + +LogicalResult lowerTColExpand(TColExpandOp op, PatternRewriter &rewriter) { + VPTOExpandContract contract = extractTColExpandContract(op); + if (failed(checkExpandContract(op, contract))) + return failure(); + if (contract.srcValidCols != contract.dstValidCols) + return op.emitOpError() + << "colexpand lowering requires source and destination valid cols to match"; + return buildColExpandVecScope(contract, op.getSrc(), op.getDst(), rewriter, + op.getLoc()); +} + +template +LogicalResult lowerTRowExpandBinaryLike(OpTy op, PatternRewriter &rewriter, + StringRef family, + VPTOLoweringStrategy strategy) { + Type elementType = getElementType(op.getDst()); + if (!elementType || (!elementType.isF16() && !elementType.isF32())) + return op.emitOpError() << family + << " lowering currently supports only f16 and f32 element types"; + + if (deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getSrc0())) != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec) + return op.emitOpError() << family << " lowering requires vec tile domain"; + if (deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError() << family << " lowering requires row-major dst layout"; + + int64_t dstValidRows = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + int64_t src0ValidRows = ShapedType::kDynamic; + int64_t src0ValidCols = ShapedType::kDynamic; + int64_t src1ValidRows = ShapedType::kDynamic; + int64_t src1ValidCols = ShapedType::kDynamic; + deriveValidShape(op.getDst(), dstValidRows, dstValidCols); + deriveValidShape(op.getSrc0(), src0ValidRows, src0ValidCols); + deriveValidShape(op.getSrc1(), src1ValidRows, src1ValidCols); + if (dstValidRows == ShapedType::kDynamic || dstValidCols == ShapedType::kDynamic || + src0ValidRows == ShapedType::kDynamic || src0ValidCols == ShapedType::kDynamic || + src1ValidRows == ShapedType::kDynamic || src1ValidCols == ShapedType::kDynamic) + return op.emitOpError() << family + << " lowering currently requires static valid shapes"; + + bool src0EqDst = op.getSrc0().getType() == op.getDst().getType(); + bool src1EqDst = op.getSrc1().getType() == op.getDst().getType(); + if (!src0EqDst && !src1EqDst) + return op.emitOpError() << family + << " lowering requires src0 or src1 to match dst tile type"; + + Value baseSrc = src0EqDst ? op.getSrc0() : op.getSrc1(); + Value expandSrc = src0EqDst ? op.getSrc1() : op.getSrc0(); + StringRef expandLayout = deriveTileLayout(expandSrc); + int64_t expandValidRows = src0EqDst ? src1ValidRows : src0ValidRows; + int64_t expandValidCols = src0EqDst ? src1ValidCols : src0ValidCols; + if (expandValidRows != dstValidRows) + return op.emitOpError() << family + << " lowering requires expand operand valid rows to match dst"; + + int64_t elemBytes = getElementByteSize(elementType); + bool expandIsRowMajor = expandLayout == "row_major" && expandValidCols == 32 / elemBytes; + bool expandIsColMajor = expandLayout == "col_major" && expandValidCols == 1; + if (!expandIsRowMajor && !expandIsColMajor) + return op.emitOpError() << family + << " lowering requires PTO A5-compatible expand operand shape"; + + auto vecType = getVPTOVRegType(rewriter.getContext(), elementType); + if (!vecType) + return op.emitOpError() << family + << " lowering requires a legal VPTO vector type"; + + Value baseBuffer = materializeBufferPointer(baseSrc, elementType, + getMemorySpace(baseSrc), rewriter, + op.getLoc()); + Value expandBuffer = materializeBufferPointer(expandSrc, elementType, + getMemorySpace(expandSrc), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!baseBuffer || !expandBuffer || !dstBuffer) + return op.emitOpError() << family + << " lowering requires pointer-backed tile buffers"; + + int64_t dstRowStride = deriveStaticRowStride(op.getDst()); + int64_t baseRowStride = deriveStaticRowStride(baseSrc); + int64_t expandRowStride = deriveStaticRowStride(expandSrc); + if (dstRowStride == ShapedType::kDynamic || baseRowStride == ShapedType::kDynamic || + expandRowStride == ShapedType::kDynamic) + return op.emitOpError() << family << " lowering requires static row strides"; + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value rowsUpper = rewriter.create(op.getLoc(), dstValidRows); + Value colsUpper = rewriter.create(op.getLoc(), dstValidCols); + Value vectorStep = + rewriter.create(op.getLoc(), vecType.getElementCount()); + Value baseStrideValue = + rewriter.create(op.getLoc(), baseRowStride); + Value expandStrideValue = + rewriter.create(op.getLoc(), expandRowStride); + Value dstStrideValue = + rewriter.create(op.getLoc(), dstRowStride); + Value blockSizeValue = + rewriter.create(op.getLoc(), 32 / elemBytes); + + VPTOLoopScopeContract loopScope; + loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + loopScope.loweredAttr = kLoweredLoopScopeAttrName; + loopScope.loopDepth = 0; + + auto buildRowExpandValue = [&](Value baseVec, Value expandedVec, + Value predicate) -> FailureOr { + if (family == "trowexpandmul") + return rewriter.create(op.getLoc(), vecType, baseVec, + expandedVec, predicate) + .getResult(); + if (family == "trowexpanddiv") { + if (src0EqDst) + return rewriter.create(op.getLoc(), vecType, baseVec, + expandedVec, predicate) + .getResult(); + return rewriter.create(op.getLoc(), vecType, expandedVec, + baseVec, predicate) + .getResult(); + } + if (family == "trowexpandsub") { + if (src0EqDst) + return rewriter.create(op.getLoc(), vecType, baseVec, + expandedVec, predicate) + .getResult(); + return rewriter.create(op.getLoc(), vecType, expandedVec, + baseVec, predicate) + .getResult(); + } + return failure(); + }; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto rowLoop = rewriter.create(op.getLoc(), c0, rowsUpper, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value baseRowOffset = rewriter.create(op.getLoc(), row, baseStrideValue); + Value dstRowOffset = rewriter.create(op.getLoc(), row, dstStrideValue); + Value expandRowOffset = expandIsRowMajor + ? rewriter.create(op.getLoc(), row, blockSizeValue) + : rewriter.create(op.getLoc(), row, expandStrideValue); + + Value expandVec; + if (expandIsColMajor) { + Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), elementType); + Value expandScalar = + rewriter.create(op.getLoc(), vecType, expandBuffer, + expandRowOffset); + expandVec = rewriter + .create(op.getLoc(), vecType, expandScalar, fullMask, + StringAttr()) + .getResult(); + } else { + expandVec = rewriter + .create(op.getLoc(), vecType, expandBuffer, expandRowOffset, + rewriter.getStringAttr("BLK")) + .getResult(); + } + + auto colLoop = rewriter.create(op.getLoc(), c0, colsUpper, vectorStep); + rewriter.setInsertionPointToStart(colLoop.getBody()); + Value col = colLoop.getInductionVar(); + Value remainingCols = rewriter.create(op.getLoc(), colsUpper, col); + Value needsTailMask = rewriter.create( + op.getLoc(), arith::CmpIPredicate::slt, remainingCols, vectorStep); + Value activeLanes = rewriter.create(op.getLoc(), needsTailMask, + remainingCols, vectorStep); + Value baseOffset = rewriter.create(op.getLoc(), baseRowOffset, col); + Value dstOffset = rewriter.create(op.getLoc(), dstRowOffset, col); + Value storeMask = + buildPredicateMaskForLaneCount(rewriter, op.getLoc(), elementType, activeLanes); + Value baseVec = + rewriter.create(op.getLoc(), vecType, baseBuffer, baseOffset, StringAttr()); + FailureOr computed = + buildRowExpandValue(baseVec, expandVec, storeMask); + if (failed(computed)) + return op.emitOpError() << "unsupported rowexpand binary family"; + rewriter.create(op.getLoc(), *computed, dstBuffer, dstOffset, + StringAttr(), storeMask); + rewriter.create(op.getLoc()); + return success(); +} + +LogicalResult lowerTRowExpandMul(TRowExpandMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandBinaryLike(op, rewriter, "trowexpandmul", strategy); +} + +LogicalResult lowerTRowExpandDiv(TRowExpandDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandBinaryLike(op, rewriter, "trowexpanddiv", strategy); +} + +LogicalResult lowerTRowExpandSub(TRowExpandSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandBinaryLike(op, rewriter, "trowexpandsub", strategy); +} + +LogicalResult lowerTPartAdd(TPartAddOp op, PatternRewriter &rewriter) { + VPTOPartContract contract = extractTPartAddContract(op); + if (failed(checkPartContract(op, contract))) + return failure(); + return buildPartVecScope("partadd", contract, op.getSrc0(), op.getSrc1(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTPartMax(TPartMaxOp op, PatternRewriter &rewriter) { + VPTOPartContract contract = extractTPartMaxContract(op); + if (failed(checkPartContract(op, contract))) + return failure(); + return buildPartVecScope("partmax", contract, op.getSrc0(), op.getSrc1(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTPartMin(TPartMinOp op, PatternRewriter &rewriter) { + VPTOPartContract contract = extractTPartMinContract(op); + if (failed(checkPartContract(op, contract))) + return failure(); + return buildPartVecScope("partmin", contract, op.getSrc0(), op.getSrc1(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTSTORE(TStoreOp op, PatternRewriter &rewriter) { + VPTOStoreContract contract = extractTStoreContract(op); + + switch (contract.srcDomain) { + case VPTOTileDomain::Acc: + return lowerUnsupportedAccStore(op.getLoc()); + case VPTOTileDomain::Mat: + return lowerUnsupportedMatStore(op.getLoc()); + case VPTOTileDomain::Vec: + break; + } + + ResolvedTensorView destinationView; + if (!resolveTensorView(op.getDst(), destinationView, rewriter, op.getLoc())) + return op.emitOpError("requires a recoverable destination tensor view for VPTO lowering"); + + StringRef sourceTileLayout = deriveTileLayout(op.getSrc()); + StringRef destinationLayout = + inferVecTransferLayoutFromTile(stringifyLayoutAttr(destinationView.layoutAttr), + sourceTileLayout); + bool isNdStore = sourceTileLayout == "row_major" && destinationLayout == "nd"; + bool isDnStore = sourceTileLayout == "col_major" && destinationLayout == "dn"; + if (!isNdStore && !isDnStore) + return op.emitOpError("currently supports only ND row_major or DN col_major vec TSTORE lowering"); + + Value sourceBuffer = + materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, op.getLoc()); + Value destinationBuffer = + materializeBufferPointer(destinationView.root, getElementType(destinationView.root), + getGmMemorySpace(rewriter.getContext()), rewriter, + op.getLoc()); + if (!sourceBuffer || !destinationBuffer) + return op.emitOpError("requires A5-compatible source and destination buffers"); + + auto [tileRows, tileCols] = getStaticTileRowsCols(op.getSrc()); + Value validRowsValue = + materializeI64Value(contract.validRowsValue, contract.validRows, rewriter, + op.getLoc()); + Value validColsValue = + materializeI64Value(contract.validColsValue, contract.validCols, rewriter, + op.getLoc()); + Value sidValue = rewriter.create(op.getLoc(), 0, 64); + int64_t elemBytes = getElementByteSize(contract.elementType); + if ((isNdStore && tileCols == ShapedType::kDynamic) || + (isDnStore && tileRows == ShapedType::kDynamic) || elemBytes <= 0) + return op.emitOpError("requires static tile shape for A5-compatible transfer arguments"); + VecNdTransferPlan plan; + LogicalResult planResult = + isNdStore ? buildVecNdStorePlan(destinationView.shape, destinationView.strides, + tileCols, contract.validColsValue, + contract.validCols, contract.elementType, + rewriter, op.getLoc(), plan) + : buildVecDnStorePlan(destinationView.shape, destinationView.strides, + tileRows, contract.validRowsValue, + contract.validRows, contract.elementType, + rewriter, op.getLoc(), plan); + if (failed(planResult)) + return op.emitOpError("requires PTO-compatible vec copy_ubuf_to_gm arguments"); + Value reservedValue = rewriter.create(op.getLoc(), 0, 64); + if (!validRowsValue || !validColsValue) + return op.emitOpError("requires valid rows and cols for A5-compatible transfer arguments"); + Value destinationOffset = + materializeI64Ofr(destinationView.offsetElems, rewriter, op.getLoc()); + if (!destinationOffset) + return op.emitOpError("requires a materializable destination offset for VPTO lowering"); + Value destinationBase = + adjustPointerByElemOffset(destinationBuffer, destinationOffset, elemBytes, rewriter, + op.getLoc()); + if (!destinationBase) + return op.emitOpError("failed to materialize destination base pointer"); + + rewriter.create(op.getLoc(), plan.loop2Size, + plan.loop1Size); + rewriter.create( + op.getLoc(), plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); + rewriter.create( + op.getLoc(), plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); + + auto emitCopy = [&](Value srcPtr, Value dstPtr) { + Type transferElementType = + getCopyTransferElementType(contract.elementType, rewriter); + Value typedSrcPtr = + castPtrToElementType(srcPtr, transferElementType, rewriter, op.getLoc()); + Value typedDstPtr = + castPtrToElementType(dstPtr, transferElementType, rewriter, op.getLoc()); + if (!typedSrcPtr || !typedDstPtr) + return failure(); + rewriter.create( + op.getLoc(), typedSrcPtr, typedDstPtr, sidValue, plan.nBurst, + plan.lenBurst, reservedValue, plan.firstStrideBytes, + plan.secondStrideBytes); + return success(); + }; + + if (std::optional outerConst = getConstInt(plan.outerCount); outerConst && *outerConst == 1) { + return emitCopy(sourceBuffer, destinationBase); + } + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value outerUpper = + rewriter.create(op.getLoc(), rewriter.getIndexType(), + plan.outerCount); + auto outerLoop = rewriter.create(op.getLoc(), c0, outerUpper, c1); + rewriter.setInsertionPointToStart(outerLoop.getBody()); + Value ivI64 = rewriter.create(op.getLoc(), rewriter.getI64Type(), + outerLoop.getInductionVar()); + Value srcStep = createI64Mul(ivI64, plan.outerSrcStrideElems, rewriter, op.getLoc()); + Value dstStep = createI64Mul(ivI64, plan.outerDstStrideElems, rewriter, op.getLoc()); + Value iterSrc = adjustPointerByElemOffset(sourceBuffer, srcStep, elemBytes, rewriter, + op.getLoc()); + Value iterDst = adjustPointerByElemOffset(destinationBase, dstStep, elemBytes, rewriter, + op.getLoc()); + return emitCopy(iterSrc, iterDst); +} + +LogicalResult lowerSetFlag(SetFlagOp op, PatternRewriter &rewriter) { + rewriter.create(op.getLoc(), + stringifyPipeAttr(op.getSrcPipe(), rewriter), + stringifyPipeAttr(op.getDstPipe(), rewriter), + stringifyEventAttr(op.getEventId(), rewriter)); + return success(); +} + +LogicalResult lowerWaitFlag(WaitFlagOp op, PatternRewriter &rewriter) { + rewriter.create(op.getLoc(), + stringifyPipeAttr(op.getSrcPipe(), rewriter), + stringifyPipeAttr(op.getDstPipe(), rewriter), + stringifyEventAttr(op.getEventId(), rewriter)); + return success(); +} + +LogicalResult lowerBarrier(BarrierOp op, PatternRewriter &rewriter) { + rewriter.create(op.getLoc(), + stringifyPipeAttr(op.getPipe(), rewriter)); + return success(); +} + +static FailureOr stringifyConcreteSyncPipeAttr(Attribute opTypeAttr, + PatternRewriter &rewriter) { + if (auto pipeAttr = dyn_cast(opTypeAttr)) + return PipeAttr::get(rewriter.getContext(), pipeAttr.getPipe()); + auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); + if (failed(opTypeOr)) + return failure(); + PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return failure(); + return PipeAttr::get(rewriter.getContext(), pipe); +} + +LogicalResult lowerGetBuf(GetBufOp op, PatternRewriter &rewriter) { + FailureOr pipeAttr = + stringifyConcreteSyncPipeAttr(op.getOpTypeAttr(), rewriter); + if (failed(pipeAttr)) + return op.emitOpError("get_buf expects SyncOpType/PipeEventType that maps to a concrete pipe"); + + rewriter.create(op.getLoc(), Attribute(*pipeAttr), + static_cast(op.getBufId()), + static_cast(op.getMode())); + return success(); +} + +LogicalResult lowerRlsBuf(RlsBufOp op, PatternRewriter &rewriter) { + FailureOr pipeAttr = + stringifyConcreteSyncPipeAttr(op.getOpTypeAttr(), rewriter); + if (failed(pipeAttr)) + return op.emitOpError("rls_buf expects SyncOpType/PipeEventType that maps to a concrete pipe"); + + rewriter.create(op.getLoc(), Attribute(*pipeAttr), + static_cast(op.getBufId()), + static_cast(op.getMode())); + return success(); +} + +namespace { + +static Type convertVPTOBoundaryMemRefType(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) + return type; + auto memorySpace = dyn_cast_or_null(memrefType.getMemorySpace()); + if (!memorySpace) + return {}; + return PtrType::get(type.getContext(), memrefType.getElementType(), memorySpace); +} + +static LogicalResult eraseDeadVPTOMemRefScaffold(ModuleOp module) { + bool erasedAny = true; + while (erasedAny) { + erasedAny = false; + SmallVector deadOps; + module.walk([&](Operation *op) { + if (!op->use_empty()) + return; + if (isa(op)) + deadOps.push_back(op); + }); + for (Operation *op : deadOps) { + op->erase(); + erasedAny = true; + } + } + return success(); +} + +static LogicalResult verifyNoResidualVPTOMemRefs(ModuleOp module, + llvm::raw_ostream *diagOS) { + for (func::FuncOp func : module.getOps()) { + for (Type input : func.getFunctionType().getInputs()) { + if (!isa(input)) + continue; + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: residual memref argument in " + << func.getName() << ": " << input << "\n"; + return failure(); + } + for (Type result : func.getFunctionType().getResults()) { + if (!isa(result)) + continue; + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: residual memref result in " + << func.getName() << ": " << result << "\n"; + return failure(); + } + } + + WalkResult walk = module.walk([&](Operation *op) { + auto hasResidualMemRef = [](TypeRange types) { + return llvm::any_of(types, [](Type type) { + return isa(type); + }); + }; + if (hasResidualMemRef(op->getOperandTypes()) || + hasResidualMemRef(op->getResultTypes())) { + if (diagOS) { + *diagOS << "VPTO ptr-only boundary failed: residual memref-typed op " + << op->getName() << "\n"; + op->print(*diagOS); + *diagOS << "\n"; + } + return WalkResult::interrupt(); + } + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (!isa(arg.getType())) + continue; + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: residual memref block " + << "argument in op " << op->getName() << ": " + << arg.getType() << "\n"; + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); + return walk.wasInterrupted() ? failure() : success(); +} + +} // namespace + +LogicalResult convertVPTOFunctionBoundariesToPtr(ModuleOp module, + llvm::raw_ostream *diagOS) { + // VPTO kernels use ptr-only entry semantics: the function ABI keeps only the + // same-space base pointer, while shape/stride/offset stay in live SSA and + // address calculations inside the body. + if (failed(eraseDeadVPTOMemRefScaffold(module))) + return failure(); + + bool sawFailure = false; + for (func::FuncOp func : module.getOps()) { + if (func.isExternal()) + continue; + + FunctionType functionType = func.getFunctionType(); + SmallVector newInputs(functionType.getInputs().begin(), + functionType.getInputs().end()); + bool changed = false; + + for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { + auto memrefType = dyn_cast(inputType); + if (!memrefType) + continue; + + Type newType = convertVPTOBoundaryMemRefType(inputType); + if (!newType) { + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: unsupported memref " + << "argument type in " << func.getName() << ": " + << inputType << "\n"; + sawFailure = true; + continue; + } + + BlockArgument arg = func.getArgument(idx); + SmallVector users(arg.getUsers().begin(), arg.getUsers().end()); + arg.setType(newType); + newInputs[idx] = newType; + changed = true; + + for (Operation *user : users) { + if (auto cast = dyn_cast(user)) { + if (cast.getInput() != arg) + continue; + if (cast.getResult().getType() == newType) { + cast.getResult().replaceAllUsesWith(arg); + cast.erase(); + } + continue; + } + + if (isa(user) && + user->use_empty()) { + user->erase(); + continue; + } + + if (diagOS) { + *diagOS << "VPTO ptr-only boundary failed: argument " << idx + << " of " << func.getName() + << " still feeds a memref-dependent user after ptr rewrite:\n"; + user->print(*diagOS); + *diagOS << "\n"; + } + sawFailure = true; + } + } + + for (Type resultType : functionType.getResults()) { + if (!isa(resultType)) + continue; + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: memref result is unsupported " + << "for " << func.getName() << ": " << resultType << "\n"; + sawFailure = true; + } + + if (changed) { + func.setFunctionType( + FunctionType::get(module.getContext(), newInputs, functionType.getResults())); + } + } + + if (sawFailure) + return failure(); + + if (failed(eraseDeadVPTOMemRefScaffold(module))) + return failure(); + return verifyNoResidualVPTOMemRefs(module, diagOS); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp new file mode 100644 index 000000000..7d2cd0d4d --- /dev/null +++ b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp @@ -0,0 +1,114 @@ +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVPTOEXPANDBRIDGEOPS +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static pto::AddressSpaceAttr getPointerMemorySpace(Attribute memorySpace, + MLIRContext *ctx) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return pto::AddressSpaceAttr::get( + ctx, static_cast(intAttr.getInt())); + return pto::AddressSpaceAttr::get(ctx, pto::AddressSpace::GM); +} + +static Value materializeBufferPointer(Value value, PatternRewriter &rewriter, + Location loc) { + if (!value) + return {}; + + if (isa(value.getType())) + return value; + + auto memrefType = dyn_cast(value.getType()); + if (!memrefType) + return {}; + + auto ptrType = + pto::PtrType::get(rewriter.getContext(), memrefType.getElementType(), + getPointerMemorySpace(memrefType.getMemorySpace(), + rewriter.getContext())); + return rewriter.create(loc, ptrType, value).getResult(); +} + +static Value offsetBufferPointer(Value basePtr, Type elementType, + Value elementOffset, + PatternRewriter &rewriter, Location loc) { + if (!basePtr) + return {}; + + Value offsetIndex = elementOffset; + if (!offsetIndex.getType().isIndex()) + offsetIndex = rewriter.create(loc, + rewriter.getIndexType(), + elementOffset); + return rewriter.create(loc, basePtr.getType(), basePtr, + offsetIndex); +} + +struct ExpandUvldPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::UvldOp op, + PatternRewriter &rewriter) const override { + auto vecType = dyn_cast(op.getResult().getType()); + if (!vecType) + return failure(); + + Value basePtr = materializeBufferPointer(op.getSource(), rewriter, op.getLoc()); + if (!basePtr) + return op.emitOpError( + "requires a recoverable pointer base for uvld expansion"); + + Value loadPtr = offsetBufferPointer(basePtr, vecType.getElementType(), + op.getOffset(), rewriter, op.getLoc()); + auto alignType = pto::AlignType::get(rewriter.getContext()); + Value align = + rewriter.create(op.getLoc(), alignType, loadPtr); + auto load = rewriter.create( + op.getLoc(), TypeRange{vecType, alignType, loadPtr.getType()}, + ValueRange{loadPtr, align}); + rewriter.replaceOp(op, load.getResult()); + return success(); + } +}; + +struct PTOVPTOExpandBridgeOpsPass + : public pto::impl::PTOVPTOExpandBridgeOpsBase { + using pto::impl::PTOVPTOExpandBridgeOpsBase< + PTOVPTOExpandBridgeOpsPass>::PTOVPTOExpandBridgeOpsBase; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + if (func.isExternal()) + return; + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOVPTOExpandBridgeOpsPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp b/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp new file mode 100644 index 000000000..6aa62259f --- /dev/null +++ b/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp @@ -0,0 +1,337 @@ +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/VPTOLowering.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVPTOPTRBOUNDARY +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static Type convertVPTOBoundaryMemRefType(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) + return type; + auto memorySpace = + dyn_cast_or_null(memrefType.getMemorySpace()); + if (!memorySpace) + return {}; + return pto::PtrType::get(type.getContext(), memrefType.getElementType(), + memorySpace); +} + +static bool isTrivialVPTOBoundaryCastPtr(pto::CastPtrOp castOp) { + return castOp.getInput().getType() == castOp.getResult().getType(); +} + +static LogicalResult eraseDeadVPTOMemRefScaffold(ModuleOp module) { + bool erasedAny = true; + while (erasedAny) { + erasedAny = false; + SmallVector trivialCasts; + SmallVector deadOps; + module.walk([&](Operation *op) { + if (auto castOp = dyn_cast(op)) { + if (isTrivialVPTOBoundaryCastPtr(castOp)) { + trivialCasts.push_back(castOp); + return; + } + if (castOp->use_empty()) + deadOps.push_back(op); + return; + } + + if (!op->use_empty()) + return; + if (isa(op)) + deadOps.push_back(op); + }); + + for (pto::CastPtrOp castOp : trivialCasts) { + if (!castOp->getBlock()) + continue; + castOp.getResult().replaceAllUsesWith(castOp.getInput()); + castOp.erase(); + erasedAny = true; + } + + for (Operation *op : deadOps) { + if (!op->getBlock()) + continue; + op->erase(); + erasedAny = true; + } + } + return success(); +} + +static Type getVPTOBufferElementType(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + return {}; +} + +static Attribute getVPTOBufferMemorySpace(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getMemorySpace(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getMemorySpace(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace(); + return {}; +} + +static bool needsPtrCanonicalization(Value value) { + return isa(value.getType()); +} + +static bool isSupportedVPTOBufferLikeBoundaryOp(Operation *op) { + return isa(op); +} + +static LogicalResult canonicalizeBoundaryCastPtrOps(ModuleOp module, + llvm::raw_ostream *diagOS) { + SmallVector castsToRewrite; + module.walk([&](pto::CastPtrOp castOp) { + if (!isa(castOp.getInput().getType())) + return; + if (!isa(castOp.getResult().getType())) + return; + castsToRewrite.push_back(castOp); + }); + + PatternRewriter rewriter(module.getContext()); + for (pto::CastPtrOp castOp : castsToRewrite) { + if (!castOp->getBlock()) + continue; + + auto resultType = dyn_cast(castOp.getResult().getType()); + if (!resultType) + continue; + + rewriter.setInsertionPoint(castOp); + Value ptrValue = pto::materializeBufferPointer( + castOp.getInput(), resultType.getElementType(), + resultType.getMemorySpace(), rewriter, castOp.getLoc()); + if (!ptrValue) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "canonicalize pto.castptr input for "; + castOp->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + castOp.getResult().replaceAllUsesWith(ptrValue); + rewriter.eraseOp(castOp); + } + + return success(); +} + +static LogicalResult canonicalizeSupportedVPTOBufferLikeOps( + ModuleOp module, llvm::raw_ostream *diagOS) { + SmallVector opsToRewrite; + module.walk([&](Operation *op) { + if (isSupportedVPTOBufferLikeBoundaryOp(op)) + opsToRewrite.push_back(op); + }); + + PatternRewriter rewriter(module.getContext()); + for (Operation *op : opsToRewrite) { + rewriter.setInsertionPoint(op); + + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + bool changed = false; + + for (Value operand : op->getOperands()) { + if (!needsPtrCanonicalization(operand)) { + newOperands.push_back(operand); + continue; + } + + Type elementType = getVPTOBufferElementType(operand); + Attribute memorySpace = getVPTOBufferMemorySpace(operand); + if (!elementType || !memorySpace) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "derive element type or memory space for operand of "; + op->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + Value ptrValue = pto::materializeBufferPointer(operand, elementType, + memorySpace, rewriter, + op->getLoc()); + if (!ptrValue) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "materialize pointer operand for "; + op->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + changed = changed || (ptrValue != operand); + newOperands.push_back(ptrValue); + } + + if (!changed) + continue; + + OperationState state(op->getLoc(), op->getName().getStringRef()); + state.addOperands(newOperands); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + } + + return success(); +} + +struct PTOVPTOPtrBoundaryPass + : public pto::impl::PTOVPTOPtrBoundaryBase { + using pto::impl::PTOVPTOPtrBoundaryBase< + PTOVPTOPtrBoundaryPass>::PTOVPTOPtrBoundaryBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(pto::convertVPTOEmissionBoundaryToPtr(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult mlir::pto::convertVPTOEmissionBoundaryToPtr( + ModuleOp module, llvm::raw_ostream *diagOS) { + // VPTO kernels use ptr-only entry semantics at the emission boundary: the + // function ABI keeps only the same-space base pointer, while shape/stride + // state remains in SSA. Body-level op canonicalization is added on top of + // this entry rewrite in follow-up tasks. + if (failed(eraseDeadVPTOMemRefScaffold(module))) + return failure(); + + bool sawFailure = false; + for (func::FuncOp func : module.getOps()) { + if (func.isExternal()) + continue; + + FunctionType functionType = func.getFunctionType(); + SmallVector newInputs(functionType.getInputs().begin(), + functionType.getInputs().end()); + bool changed = false; + + for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { + auto memrefType = dyn_cast(inputType); + if (!memrefType) + continue; + + Type newType = convertVPTOBoundaryMemRefType(inputType); + if (!newType) { + if (diagOS) + *diagOS << "VPTO emission-boundary ptr rewrite failed: unsupported " + "memref argument type in " + << func.getName() << ": " << inputType << "\n"; + sawFailure = true; + continue; + } + + BlockArgument arg = func.getArgument(idx); + SmallVector users(arg.getUsers().begin(), arg.getUsers().end()); + arg.setType(newType); + newInputs[idx] = newType; + changed = true; + + for (Operation *user : users) { + if (auto cast = dyn_cast(user)) { + if (cast.getInput() != arg) + continue; + if (cast.getResult().getType() == newType) { + cast.getResult().replaceAllUsesWith(arg); + cast.erase(); + } + continue; + } + + if (isa(user) && + user->use_empty()) { + user->erase(); + continue; + } + + if (isSupportedVPTOBufferLikeBoundaryOp(user)) + continue; + + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: argument " + << idx << " of " << func.getName() + << " still feeds a memref-dependent user after ptr rewrite:\n"; + user->print(*diagOS); + *diagOS << "\n"; + } + sawFailure = true; + } + } + + for (Type resultType : functionType.getResults()) { + if (!isa(resultType)) + continue; + if (diagOS) + *diagOS << "VPTO emission-boundary ptr rewrite failed: memref result " + "is unsupported for " + << func.getName() << ": " << resultType << "\n"; + sawFailure = true; + } + + if (changed) { + func.setFunctionType( + FunctionType::get(module.getContext(), newInputs, functionType.getResults())); + } + } + + if (sawFailure) + return failure(); + + if (failed(canonicalizeBoundaryCastPtrOps(module, diagOS))) + return failure(); + + if (failed(canonicalizeSupportedVPTOBufferLikeOps(module, diagOS))) + return failure(); + + return eraseDeadVPTOMemRefScaffold(module); +} + +std::unique_ptr mlir::pto::createPTOVPTOPtrBoundaryPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp new file mode 100644 index 000000000..a81b4e81f --- /dev/null +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -0,0 +1,756 @@ +//===- PTOValidateVPTOIR.cpp - Shared VPTO legality helpers --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file owns the shared helper layer for the dual-stage VPTO legality +// verifier. Follow-up tasks add the public validation entrypoints and pass +// wrappers on top of this utility layer. +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace mlir { +namespace pto { + +LogicalResult validateVPTOAuthoringIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); +LogicalResult validateVPTOEmissionIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); + +namespace detail { + +constexpr llvm::StringLiteral kAIVectorScopeAttrName = + "llvm.loop.aivector_scope"; + +enum class VPTOMaskGranularity { + B8, + B16, + B32, +}; + +enum class VPTOBufferAddressFamily { + None, + Copy, + BufferLike, + PtrOnly, +}; + +enum class VPTOLegalityStage { + Authoring, + Emission, +}; + +class VPTOLegalityHelper { +public: + explicit VPTOLegalityHelper(ModuleOp module) : module(module) {} + + ModuleOp getModule() const { return module; } + + SmallVector getFunctions() { + SmallVector funcs; + for (func::FuncOp func : module.getOps()) + funcs.push_back(func); + return funcs; + } + + static bool isLegalityTypedValue(Type type) { + return isa(type); + } + + static bool isBufferLikeValue(Type type) { + return isa(type); + } + + static bool requiresVecScope(Operation *op) { + if (!isPTOp(op)) + return false; + + return llvm::any_of(op->getOperandTypes(), isLegalityTypedValue) || + llvm::any_of(op->getResultTypes(), isLegalityTypedValue); + } + + static bool isAIVectorScopeCarrier(scf::ForOp loop) { + return loop && loop->hasAttr(kAIVectorScopeAttrName); + } + + static bool isDedicatedVecScopeCarrier(Operation *op) { + return isa_and_nonnull(op); + } + + static bool isAnyVectorScopeCarrier(Operation *op) { + if (auto loop = dyn_cast_or_null(op)) + return isAIVectorScopeCarrier(loop); + return isDedicatedVecScopeCarrier(op); + } + + static Operation *getEnclosingVectorScopeCarrier(Operation *op) { + for (Operation *parent = op ? op->getParentOp() : nullptr; parent; + parent = parent->getParentOp()) { + if (isAnyVectorScopeCarrier(parent)) + return parent; + } + return nullptr; + } + + static std::optional getMaskGranularity(Type type) { + auto maskType = dyn_cast(type); + if (!maskType) + return std::nullopt; + return getMaskGranularity(maskType); + } + + static std::optional getMaskGranularity(MaskType type) { + if (type.isB8()) + return VPTOMaskGranularity::B8; + if (type.isB16()) + return VPTOMaskGranularity::B16; + if (type.isB32()) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + + static StringRef stringifyMaskGranularity(VPTOMaskGranularity granularity) { + switch (granularity) { + case VPTOMaskGranularity::B8: + return "b8"; + case VPTOMaskGranularity::B16: + return "b16"; + case VPTOMaskGranularity::B32: + return "b32"; + } + llvm_unreachable("unsupported VPTO mask granularity"); + } + + static std::optional + inferMaskGranularityFromType(Type type) { + if (auto vregType = dyn_cast(type)) + type = vregType.getElementType(); + + if (type.isF32()) + return VPTOMaskGranularity::B32; + if (type.isF16() || type.isBF16()) + return VPTOMaskGranularity::B16; + + auto intType = dyn_cast(type); + if (!intType) + return std::nullopt; + + switch (intType.getWidth()) { + case 8: + return VPTOMaskGranularity::B8; + case 16: + return VPTOMaskGranularity::B16; + case 32: + return VPTOMaskGranularity::B32; + default: + return std::nullopt; + } + } + + static std::optional + inferMaskGranularityFromFamily(Operation *op) { + StringRef mnemonic = getPTOpMnemonic(op); + if (mnemonic.empty()) + return std::nullopt; + + if (mnemonic.ends_with("_b8")) + return VPTOMaskGranularity::B8; + if (mnemonic.ends_with("_b16")) + return VPTOMaskGranularity::B16; + if (mnemonic.ends_with("_b32")) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + + static VPTOBufferAddressFamily classifyBufferAddressFamily(Operation *op) { + if (!op) + return VPTOBufferAddressFamily::None; + + if (isa(op)) + return VPTOBufferAddressFamily::Copy; + + if (isa(op)) + return VPTOBufferAddressFamily::PtrOnly; + + if (isa(op)) + return VPTOBufferAddressFamily::BufferLike; + + return VPTOBufferAddressFamily::None; + } + + static bool isSupportedEmissionBufferLikeOp(Operation *op) { + return classifyBufferAddressFamily(op) == + VPTOBufferAddressFamily::BufferLike; + } + + static bool isResidualEmissionScaffold(Operation *op) { + return isa(op) || + isTrivialEmissionCastPtr(op); + } + + static SmallVector collectBufferOperands(Operation *op) { + SmallVector bufferOperands; + for (OpOperand &operand : op->getOpOperands()) { + if (isBufferLikeValue(operand.get().getType())) + bufferOperands.push_back(&operand); + } + return bufferOperands; + } + +private: + static bool isPTOp(Operation *op) { + return op && op->getName().getStringRef().starts_with("pto."); + } + + static StringRef getPTOpMnemonic(Operation *op) { + if (!isPTOp(op)) + return {}; + StringRef mnemonic = op->getName().getStringRef(); + (void)mnemonic.consume_front("pto."); + return mnemonic; + } + + static bool isTrivialEmissionCastPtr(Operation *op) { + auto castOp = dyn_cast_or_null(op); + return castOp && + castOp.getInput().getType() == castOp.getResult().getType(); + } + + ModuleOp module; +}; + +class VPTOLegalityValidator { +public: + VPTOLegalityValidator(ModuleOp module, VPTOLegalityStage stage, + llvm::raw_ostream *diagOS) + : helper(module), stage(stage), diagOS(diagOS) {} + + LogicalResult validate() { + if (!helper.getModule()) { + writeDiagnostic("VPTO legality validation requires a valid module\n"); + return failure(); + } + + if (failed(validateAuthoringRules())) + return failure(); + + if (stage == VPTOLegalityStage::Emission && + failed(validateEmissionRules())) + return failure(); + + return success(); + } + +private: + LogicalResult validateAuthoringRules() { + if (failed(validateAuthoringFunctionSurface())) + return failure(); + if (failed(validateAuthoringOperationSurface())) + return failure(); + return success(); + } + + LogicalResult validateEmissionRules() { + if (failed(validateEmissionFunctionSurface())) + return failure(); + if (failed(validateEmissionOperationSurface())) + return failure(); + return success(); + } + + static std::string formatExpectedMaskType(VPTOMaskGranularity granularity) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.mask<" + << VPTOLegalityHelper::stringifyMaskGranularity(granularity) << ">"; + return storage; + } + + static LogicalResult validateMaskMatchesVectorFamily(Operation *op, + Type maskType, + StringRef maskRole, + Type vectorType, + StringRef vectorRole) { + auto actual = VPTOLegalityHelper::getMaskGranularity(maskType); + auto expected = VPTOLegalityHelper::inferMaskGranularityFromType(vectorType); + if (!actual || !expected || *actual == *expected) + return success(); + + return op->emitOpError() + << maskRole << " " << maskType << " does not match " << vectorRole + << " " << vectorType << "; expected " + << formatExpectedMaskType(*expected); + } + + static LogicalResult validateSameMaskGranularity(Operation *op, Type lhsType, + StringRef lhsRole, + Type rhsType, + StringRef rhsRole) { + auto lhs = VPTOLegalityHelper::getMaskGranularity(lhsType); + auto rhs = VPTOLegalityHelper::getMaskGranularity(rhsType); + if (!lhs || !rhs || *lhs == *rhs) + return success(); + + return op->emitOpError() << lhsRole << " " << lhsType << " does not match " + << rhsRole << " " << rhsType; + } + + template + static LogicalResult validateInputMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getInput().getType(), + "input vector type"); + } + + template + static LogicalResult validateBinaryMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", op.getLhs().getType(), + "lhs vector type"); + } + + template + static LogicalResult validateValueMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", op.getValue().getType(), + "value vector type"); + } + + template + static LogicalResult validateResultMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getResult().getType(), + "result vector type"); + } + + template + static LogicalResult validateCarryFamilyContract(CarryOp op) { + if (failed(validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getLhs().getType(), + "lhs vector type")) || + failed(validateSameMaskGranularity(op, op.getMask().getType(), + "mask type", + op.getCarry().getType(), + "carry type"))) + return failure(); + + if constexpr (std::is_same_v || + std::is_same_v) { + if (failed(validateSameMaskGranularity(op, op.getCarryIn().getType(), + "carry_in type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getCarryIn().getType(), + "carry_in type", + op.getCarry().getType(), + "carry type"))) + return failure(); + } + + return success(); + } + + template + static LogicalResult validateCompareFamilyContract(CompareOp op, Type vecType) { + if (failed(validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "seed mask type", vecType, + "input vector type")) || + failed(validateMaskMatchesVectorFamily(op, op.getResult().getType(), + "result mask type", vecType, + "input vector type")) || + failed(validateSameMaskGranularity(op, op.getMask().getType(), + "seed mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + template + static LogicalResult validateMaskOnlyUnaryContract(MaskUnaryOp op) { + return validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getResult().getType(), + "result mask type"); + } + + static LogicalResult validateMaskOnlyPnotContract(PnotOp op) { + if (failed(validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + static LogicalResult validateMaskOnlyPselContract(PselOp op) { + if (failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getSrc1().getType(), + "src1 mask type")) || + failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + template + static LogicalResult validatePredicateMovementContract( + PredicateMovementOp op) { + auto expected = VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + if (!expected) + return success(); + + if (failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getRhs().getType(), + "rhs mask type")) || + failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getLow().getType(), + "low mask type")) || + failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getHigh().getType(), + "high mask type"))) + return failure(); + + auto lhs = VPTOLegalityHelper::getMaskGranularity(op.getLhs().getType()); + if (!lhs || *lhs == *expected) + return success(); + + return op.emitOpError() + << "predicate movement family requires " + << formatExpectedMaskType(*expected) + << " but got lhs mask type " << op.getLhs().getType(); + } + + static LogicalResult validateFamilySuffixMaskResult(Operation *op, + Type resultType, + StringRef resultRole) { + auto expected = VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + auto actual = VPTOLegalityHelper::getMaskGranularity(resultType); + if (!expected || !actual || *expected == *actual) + return success(); + + return op->emitOpError() + << "family suffix requires " << resultRole << " to be " + << formatExpectedMaskType(*expected) << ", but got " << resultType; + } + + static LogicalResult validateFamilySuffixMaskContracts(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](auto concreteOp) { + return validateFamilySuffixMaskResult( + concreteOp, concreteOp.getResult().getType(), "result type"); + }) + .Case([](auto concreteOp) { + return validateFamilySuffixMaskResult(concreteOp, + concreteOp.getMask().getType(), + "mask result type"); + }) + .Default([](Operation *) { return success(); }); + } + + static LogicalResult validateMaskGranularityContracts(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](auto concreteOp) { + return validateInputMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateBinaryMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateCarryFamilyContract(concreteOp); + }) + .Case([](VcmpOp concreteOp) { + return validateCompareFamilyContract(concreteOp, + concreteOp.getSrc0().getType()); + }) + .Case([](VcmpsOp concreteOp) { + return validateCompareFamilyContract(concreteOp, + concreteOp.getSrc().getType()); + }) + .Case([](auto concreteOp) { + return validateMaskOnlyUnaryContract(concreteOp); + }) + .Case( + [](PnotOp concreteOp) { return validateMaskOnlyPnotContract(concreteOp); }) + .Case( + [](PselOp concreteOp) { return validateMaskOnlyPselContract(concreteOp); }) + .Case([](auto concreteOp) { + return validatePredicateMovementContract(concreteOp); + }) + .Case([](VselOp concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + concreteOp.getMask().getType(), + "mask type", + concreteOp.getSrc0().getType(), + "src0 vector type"); + }) + .Case([](auto concreteOp) { + return validateResultMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateValueMaskVectorConsumer(concreteOp); + }) + .Case([](Vstsx2Op concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + concreteOp.getMask().getType(), + "mask type", + concreteOp.getLow().getType(), + "low vector type"); + }) + .Case([](auto concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + concreteOp.getMask().getType(), + "mask type", + concreteOp.getLhs().getType(), + "lhs vector type"); + }) + .Default([](Operation *) { return success(); }); + } + + LogicalResult validateAuthoringFunctionSurface() { + for (func::FuncOp func : helper.getFunctions()) { + (void)func; + } + return success(); + } + + LogicalResult validateAuthoringOperationSurface() { + WalkResult loopWalkResult = helper.getModule().walk([&](scf::ForOp loop) { + if (!VPTOLegalityHelper::isAIVectorScopeCarrier(loop)) + return WalkResult::advance(); + + Operation *parentScope = + VPTOLegalityHelper::getEnclosingVectorScopeCarrier(loop); + if (!parentScope) + return WalkResult::advance(); + + if (isa(parentScope)) { + loop.emitOpError() << "does not allow nested scf.for with '" + << kAIVectorScopeAttrName << "'"; + return WalkResult::interrupt(); + } + + loop.emitOpError() + << "does not allow legacy scf.for carrier nested inside dedicated " + "pto.vecscope/pto.strict_vecscope"; + return WalkResult::interrupt(); + }); + if (loopWalkResult.wasInterrupted()) + return failure(); + + WalkResult vecScopeWalkResult = helper.getModule().walk([&](Operation *op) { + if (!VPTOLegalityHelper::isDedicatedVecScopeCarrier(op)) + return WalkResult::advance(); + + if (!VPTOLegalityHelper::getEnclosingVectorScopeCarrier(op)) + return WalkResult::advance(); + + op->emitOpError() + << "does not allow nested dedicated pto.vecscope/pto.strict_vecscope"; + return WalkResult::interrupt(); + }); + if (vecScopeWalkResult.wasInterrupted()) + return failure(); + + WalkResult opWalkResult = helper.getModule().walk([&](Operation *op) { + (void)VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + (void)VPTOLegalityHelper::classifyBufferAddressFamily(op); + + if (!VPTOLegalityHelper::requiresVecScope(op)) + return WalkResult::advance(); + + if (VPTOLegalityHelper::getEnclosingVectorScopeCarrier(op)) + return (failed(validateFamilySuffixMaskContracts(op)) || + failed(validateMaskGranularityContracts(op))) + ? WalkResult::interrupt() + : WalkResult::advance(); + + op->emitOpError() + << "requires enclosing scf.for with '" + << kAIVectorScopeAttrName + << "' or dedicated pto.vecscope/pto.strict_vecscope" + << "' because it consumes or produces !pto.vreg/!pto.mask/!pto.align"; + return WalkResult::interrupt(); + }); + return opWalkResult.wasInterrupted() ? failure() : success(); + } + + LogicalResult validateEmissionFunctionSurface() { + for (func::FuncOp func : helper.getFunctions()) { + FunctionType functionType = func.getFunctionType(); + + for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { + if (!isa(inputType)) + continue; + return func.emitError() + << "emission-stage VPTO legality rejects memref argument #" + << idx << ": " << inputType; + } + + for (auto [idx, resultType] : llvm::enumerate(functionType.getResults())) { + if (!isa(resultType)) + continue; + return func.emitError() + << "emission-stage VPTO legality rejects memref result #" + << idx << ": " << resultType; + } + } + return success(); + } + + LogicalResult validateEmissionOperationSurface() { + WalkResult walkResult = helper.getModule().walk([&](Operation *op) { + VPTOBufferAddressFamily family = + VPTOLegalityHelper::classifyBufferAddressFamily(op); + + if (family == VPTOBufferAddressFamily::BufferLike) { + for (OpOperand *operand : VPTOLegalityHelper::collectBufferOperands(op)) { + Type operandType = operand->get().getType(); + if (!isa(operandType)) + continue; + + op->emitOpError() + << "emission-stage VPTO legality rejects memref-form buffer " + "operand #" + << operand->getOperandNumber() << " of type " << operandType + << " for buffer-like family op"; + return WalkResult::interrupt(); + } + } + + if (VPTOLegalityHelper::isResidualEmissionScaffold(op)) { + op->emitOpError() + << "must be eliminated before emission-stage VPTO validation"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + return walkResult.wasInterrupted() ? failure() : success(); + } + + void writeDiagnostic(StringRef message) const { + if (diagOS) + *diagOS << message; + } + + VPTOLegalityHelper helper; + VPTOLegalityStage stage; + llvm::raw_ostream *diagOS; +}; + +} // namespace detail + +namespace { + +struct PTOValidateVPTOIRPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVPTOIRPass) + + StringRef getArgument() const final { return "pto-validate-vpto-ir"; } + + StringRef getDescription() const final { + return "Validate authoring-stage VPTO legality before emission-boundary canonicalization"; + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(validateVPTOAuthoringIR(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +struct PTOValidateVPTOEmissionIRPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVPTOEmissionIRPass) + + StringRef getArgument() const final { + return "pto-validate-vpto-emission-ir"; + } + + StringRef getDescription() const final { + return "Validate emission-stage VPTO legality after ptr-boundary canonicalization"; + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(validateVPTOEmissionIR(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult validateVPTOAuthoringIR(ModuleOp module, + llvm::raw_ostream *diagOS) { + return detail::VPTOLegalityValidator( + module, detail::VPTOLegalityStage::Authoring, diagOS) + .validate(); +} + +LogicalResult validateVPTOEmissionIR(ModuleOp module, + llvm::raw_ostream *diagOS) { + return detail::VPTOLegalityValidator( + module, detail::VPTOLegalityStage::Emission, diagOS) + .validate(); +} + +std::unique_ptr createPTOValidateVPTOIRPass() { + return std::make_unique(); +} + +std::unique_ptr createPTOValidateVPTOEmissionIRPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 8aa398169..7fd339515 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -528,7 +528,7 @@ static Type convertPTOTypeToMemRef(Type t) { // 1. 处理 !pto.ptr if (auto pty = dyn_cast(t)) { return MemRefType::get({ShapedType::kDynamic}, pty.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); + MemRefLayoutAttrInterface(), pty.getMemorySpace()); } // 2. 处理 !pto.tile_buf<...> diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp new file mode 100644 index 000000000..55418cfe4 --- /dev/null +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -0,0 +1,4562 @@ +//===- VPTOLLVMEmitter.cpp - VPTO to official LLVM IR text emitter -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VPTOLLVMEmitter.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/VPTOLowering.h" +#include "PTO/Transforms/HIVMIntrinsicNaming.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/Program.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/raw_ostream.h" + +#include + +using namespace mlir; + +namespace mlir::pto { +namespace { + +constexpr StringLiteral kAIVScopeDummyCallee = "aivscope_dummy"; + +struct QueriedTargetAttrs { + std::string targetCPU; + std::string targetFeatures; +}; + +struct ABIExpr { + enum class Kind { Constant, FuncArg, Mul }; + + Kind kind = Kind::Constant; + uint64_t constant = 0; + unsigned argIndex = 0; + std::unique_ptr lhs; + std::unique_ptr rhs; + + static ABIExpr constantExpr(uint64_t value) { + ABIExpr expr; + expr.kind = Kind::Constant; + expr.constant = value; + return expr; + } + + static ABIExpr argExpr(unsigned argIndex) { + ABIExpr expr; + expr.kind = Kind::FuncArg; + expr.argIndex = argIndex; + return expr; + } + + static ABIExpr mulExpr(ABIExpr lhs, ABIExpr rhs) { + ABIExpr expr; + expr.kind = Kind::Mul; + expr.lhs = std::make_unique(std::move(lhs)); + expr.rhs = std::make_unique(std::move(rhs)); + return expr; + } +}; + +struct ExternalMemRefABISpec { + unsigned addressSpace = 1; + int64_t rank = 0; + ABIExpr offset = ABIExpr::constantExpr(0); + ABIExpr totalSize = ABIExpr::constantExpr(1); + ABIExpr stride = ABIExpr::constantExpr(1); +}; + +struct ExternalArgABISpec { + bool isMemRef = false; + ExternalMemRefABISpec memrefSpec; +}; + +struct FunctionABISpec { + SmallVector args; +}; + +static Type getElementTypeFromVectorLike(Type type); +static Type getElementTypeFromPointerLike(Type type); +static std::optional getElementCountFromVectorLike(Type type); +static func::FuncOp getOrCreateExternalFunc(ModuleOp module, StringRef name, + FunctionType type); +static Value castIntegerLikeTo(Operation *anchor, Value value, Type targetType); + +static std::string getElementTypeFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); + return {}; +} + +static std::optional parseRoundModeImmediate(StringRef roundMode) { + if (roundMode == "R" || roundMode == "ROUND_R") + return 0; // __cce_simd::ROUND::R + if (roundMode == "A" || roundMode == "ROUND_A") + return 1; // __cce_simd::ROUND::A + if (roundMode == "F" || roundMode == "ROUND_F") + return 2; // __cce_simd::ROUND::F + if (roundMode == "C" || roundMode == "ROUND_C") + return 3; // __cce_simd::ROUND::C + if (roundMode == "Z" || roundMode == "ROUND_Z") + return 4; // __cce_simd::ROUND::Z + if (roundMode == "O" || roundMode == "ROUND_O") + return 5; // __cce_simd::ROUND::O + return std::nullopt; +} + +static std::optional parseSaturationImmediate(StringRef sat) { + if (sat == "SAT" || sat == "RS_ENABLE") + return 0; // __cce_simd::RoundingSaturation::ENABLE + if (sat == "NOSAT" || sat == "RS_DISABLE") + return 1; // __cce_simd::RoundingSaturation::DISABLE + return std::nullopt; +} + +static std::optional parsePartImmediate(StringRef part) { + if (part == "EVEN" || part == "PART_EVEN") + return 0; // __cce_simd::Part::EVEN + if (part == "ODD" || part == "PART_ODD") + return 1; // __cce_simd::Part::ODD + return std::nullopt; +} + +static FailureOr normalizeVdupScalarOperand(OpBuilder &builder, Location loc, + pto::VdupOp vdup) { + Value input = vdup.getInput(); + Type scalarType = input.getType(); + auto intType = dyn_cast(scalarType); + if (!intType || intType.getWidth() != 8) + return input; + + Type resultElemType = getElementTypeFromVectorLike(vdup.getResult().getType()); + std::string resultElemFragment = getElementTypeFragment(resultElemType); + if (resultElemFragment != "s8" && resultElemFragment != "u8") + return input; + + Type i16Type = builder.getIntegerType(16); + if (resultElemFragment == "u8") + return builder.create(loc, i16Type, input).getResult(); + return builder.create(loc, i16Type, input).getResult(); +} + +// VSQZ #st hint must only be set when the compacted vector feeds VSTUR. +// Emitting #st=1 without a matching VSTUR consumer can deadlock hardware queues. +static uint64_t determineVsqzStoreHint(pto::VsqzOp vsqz) { + Value result = vsqz.getResult(); + for (Operation *user : result.getUsers()) { + auto vstur = dyn_cast(user); + if (!vstur) + continue; + if (vstur.getValue() == result) + return 1; + } + return 0; +} + +enum class VcvtElemKind { + Invalid, + F16, + BF16, + F32, + S8, + U8, + S16, + U16, + S32, + U32, + S64, +}; + +struct VcvtContract { + const char *intrinsic; + bool requiresRnd; + bool requiresSat; + bool requiresPart; + unsigned maskBitWidth; +}; + +static VcvtElemKind classifyVcvtElemType(Type type) { + if (type.isF16()) + return VcvtElemKind::F16; + if (type.isBF16()) + return VcvtElemKind::BF16; + if (type.isF32()) + return VcvtElemKind::F32; + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? VcvtElemKind::U8 : VcvtElemKind::S8; + case 16: + return intType.isUnsigned() ? VcvtElemKind::U16 : VcvtElemKind::S16; + case 32: + return intType.isUnsigned() ? VcvtElemKind::U32 : VcvtElemKind::S32; + case 64: + return intType.isUnsigned() ? VcvtElemKind::Invalid : VcvtElemKind::S64; + default: + return VcvtElemKind::Invalid; + } + } + return VcvtElemKind::Invalid; +} + +static std::optional lookupVcvtContract(VcvtElemKind src, + VcvtElemKind dst) { + switch (src) { + case VcvtElemKind::F32: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtff.f322f16.x", true, true, true, 32}; + case VcvtElemKind::BF16: + return VcvtContract{"llvm.hivm.vcvtff.f322bf16.x", true, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtfi.f322s16.x", true, true, true, 32}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.f322s32.x", true, true, false, 32}; + case VcvtElemKind::S64: + return VcvtContract{"llvm.hivm.vcvtfi.f322s64.x", true, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::F16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtff.f162f32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.f162s32.x", true, false, true, 16}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtfi.f162s16.x", true, true, false, 16}; + case VcvtElemKind::S8: + return VcvtContract{"llvm.hivm.vcvtfi.f162s8.x", true, true, true, 16}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtfi.f162u8.x", true, true, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::BF16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtff.bf162f32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.bf162s32.x", true, true, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::U8: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.u82f16.x", false, false, true, 8}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.u82u16.x", false, false, true, 8}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.u82u32.x", false, false, true, 8}; + default: + return std::nullopt; + } + case VcvtElemKind::S8: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.s82f16.x", false, false, true, 8}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.s82s16.x", false, false, true, 8}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s82s32.x", false, false, true, 8}; + default: + return std::nullopt; + } + case VcvtElemKind::U16: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.u162u8.x", false, true, true, 16}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.u162u32.x", false, false, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::S16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.s162f16.x", true, false, false, 16}; + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s162f32.x", false, false, true, 16}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.s162u8.x", false, true, true, 16}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.s162u32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s162s32.x", false, false, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::U32: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.u322u8.x", false, true, true, 32}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.u322u16.x", false, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.u322s16.x", false, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::S32: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s322f32.x", true, false, false, 32}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.s322u8.x", false, true, true, 32}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.s322u16.x", false, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.s322s16.x", false, true, true, 32}; + case VcvtElemKind::S64: + return VcvtContract{"llvm.hivm.vcvtii.s322s64.x", false, false, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::S64: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s642f32.x", true, false, true, 32}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s642s32.x", false, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +static std::optional parseHiLoPartImmediate(StringRef part) { + if (part == "LOWER") + return 0; // __cce_simd::HiloPart::Lower + if (part == "HIGHER") + return 1; // __cce_simd::HiloPart::Higher + return std::nullopt; +} + +static std::optional parsePredicatePatternImmediate(StringRef pattern) { + if (pattern == "PAT_ALL") + return 0; + if (pattern == "PAT_VL1") + return 1; + if (pattern == "PAT_VL2") + return 2; + if (pattern == "PAT_VL3") + return 3; + if (pattern == "PAT_VL4") + return 4; + if (pattern == "PAT_VL8") + return 5; + if (pattern == "PAT_VL16") + return 6; + if (pattern == "PAT_VL32") + return 7; + if (pattern == "PAT_VL64") + return 8; + if (pattern == "PAT_VL128") + return 9; + if (pattern == "PAT_M3") + return 10; + if (pattern == "PAT_M4") + return 11; + if (pattern == "PAT_H") + return 12; + if (pattern == "PAT_Q") + return 13; + if (pattern == "PAT_ALLF") + return 15; + return std::nullopt; +} + +static Type getSignlessIntegerTypeWithSameWidth(Type type, Builder &builder) { + if (auto intType = dyn_cast(type)) + return builder.getIntegerType(intType.getWidth()); + if (auto floatType = dyn_cast(type)) + return builder.getIntegerType(floatType.getWidth()); + return {}; +} + +static std::string getVbrScalarFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); + return {}; +} + +static std::string getCopyElementFragment(Type elementType) { + if (!elementType) + return {}; + if (elementType.isF16()) + return "f16"; + if (elementType.isBF16()) + return "bf16"; + if (elementType.isF32()) + return "f32"; + if (auto intType = dyn_cast(elementType)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? "u8" : "s8"; + case 16: + return intType.isUnsigned() ? "u16" : "s16"; + case 32: + return intType.isUnsigned() ? "u32" : "s32"; + default: + return {}; + } + } + return {}; +} + +static std::optional buildABIExprFromValue(Value value); + +static std::optional buildABIExprFromFoldResult(OpFoldResult ofr) { + if (auto attr = ofr.dyn_cast()) { + if (auto intAttr = dyn_cast(attr)) + return ABIExpr::constantExpr(intAttr.getValue().getZExtValue()); + return std::nullopt; + } + return buildABIExprFromValue(ofr.get()); +} + +static std::optional buildABIExprFromValue(Value value) { + if (auto blockArg = dyn_cast(value)) { + auto func = dyn_cast(blockArg.getOwner()->getParentOp()); + if (!func || blockArg.getOwner() != &func.getBody().front()) + return std::nullopt; + return ABIExpr::argExpr(blockArg.getArgNumber()); + } + + if (auto constIndex = value.getDefiningOp()) + return ABIExpr::constantExpr(constIndex.value()); + if (auto constOp = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) + return ABIExpr::constantExpr(intAttr.getValue().getZExtValue()); + } + if (auto castOp = value.getDefiningOp()) + return buildABIExprFromValue(castOp.getIn()); + if (auto castOp = value.getDefiningOp()) + return buildABIExprFromValue(castOp.getIn()); + if (auto extOp = value.getDefiningOp()) + return buildABIExprFromValue(extOp.getIn()); + if (auto extOp = value.getDefiningOp()) + return buildABIExprFromValue(extOp.getIn()); + if (auto truncOp = value.getDefiningOp()) + return buildABIExprFromValue(truncOp.getIn()); + if (auto mulOp = value.getDefiningOp()) { + auto lhs = buildABIExprFromValue(mulOp.getLhs()); + auto rhs = buildABIExprFromValue(mulOp.getRhs()); + if (!lhs || !rhs) + return std::nullopt; + return ABIExpr::mulExpr(std::move(*lhs), std::move(*rhs)); + } + + return std::nullopt; +} + +static unsigned getExternalPointerAddressSpace(MemRefType type) { + if (auto addrAttr = dyn_cast_or_null(type.getMemorySpace())) { + switch (addrAttr.getAddressSpace()) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return 1; + case pto::AddressSpace::VEC: + return 6; + default: + break; + } + } + return 1; +} + +static std::optional deriveMemRefTotalSize(BlockArgument arg, + MemRefType type) { + if (type.getRank() != 1) + return std::nullopt; + + if (!type.isDynamicDim(0)) + return ABIExpr::constantExpr(type.getDimSize(0)); + + for (Operation *user : arg.getUsers()) { + auto reinterpret = dyn_cast(user); + if (!reinterpret || reinterpret.getSource() != arg) + continue; + + std::optional accum; + for (OpFoldResult size : reinterpret.getMixedSizes()) { + auto sizeExpr = buildABIExprFromFoldResult(size); + if (!sizeExpr) + return std::nullopt; + accum = accum ? ABIExpr::mulExpr(std::move(*accum), std::move(*sizeExpr)) + : std::move(*sizeExpr); + } + if (accum) + return accum; + } + + return std::nullopt; +} + +static llvm::StringMap collectFunctionABISpecs(ModuleOp module) { + llvm::StringMap specs; + module.walk([&](func::FuncOp funcOp) { + if (funcOp.isExternal()) + return; + + FunctionABISpec funcSpec; + funcSpec.args.reserve(funcOp.getNumArguments()); + + for (BlockArgument arg : funcOp.getArguments()) { + ExternalArgABISpec argSpec; + if (auto memrefType = dyn_cast(arg.getType())) { + if (memrefType.getRank() == 1) { + auto totalSize = deriveMemRefTotalSize(arg, memrefType); + if (totalSize) { + argSpec.isMemRef = true; + argSpec.memrefSpec.addressSpace = + getExternalPointerAddressSpace(memrefType); + argSpec.memrefSpec.rank = 1; + argSpec.memrefSpec.offset = ABIExpr::constantExpr(0); + argSpec.memrefSpec.totalSize = std::move(*totalSize); + argSpec.memrefSpec.stride = ABIExpr::constantExpr(1); + } + } + } + funcSpec.args.push_back(std::move(argSpec)); + } + + specs[funcOp.getName().str()] = std::move(funcSpec); + }); + return specs; +} + +static std::optional parsePipeImmediate(llvm::StringRef pipe) { + if (pipe == "PIPE_S") + return 0; + if (pipe == "PIPE_V") + return 1; + if (pipe == "PIPE_M") + return 2; + if (pipe == "PIPE_MTE1") + return 3; + if (pipe == "PIPE_MTE2") + return 4; + if (pipe == "PIPE_MTE3") + return 5; + if (pipe == "PIPE_ALL") + return 6; + if (pipe == "PIPE_MTE4") + return 7; + if (pipe == "PIPE_MTE5") + return 8; + if (pipe == "PIPE_V2") + return 9; + if (pipe == "PIPE_FIX") + return 10; + if (pipe == "VIRTUAL_PIPE_MTE2_L1A") + return 11; + if (pipe == "VIRTUAL_PIPE_MTE2_L1B") + return 12; + return std::nullopt; +} + +static std::optional parseEventImmediate(llvm::StringRef event) { + if (!event.consume_front("EVENT_ID")) + return std::nullopt; + uint64_t value = 0; + if (event.getAsInteger(10, value)) + return std::nullopt; + return value; +} + +static std::optional parseSprImmediate(llvm::StringRef spr) { + if (spr == "AR") + return 74; + return std::nullopt; +} + +static std::optional getDistElementWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (type.isF16() || type.isBF16()) + return 16; + if (type.isF32()) + return 32; + if (type.isF64()) + return 64; + return std::nullopt; +} + +static std::optional parseLoadDistImmediate(llvm::StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist.empty() || dist == "NORM") + return 0; + if (!width) + return std::nullopt; + if (dist == "BRC") + return *width == 8 ? std::optional(1) + : *width == 16 ? std::optional(2) + : *width == 32 ? std::optional(3) + : std::nullopt; + if (dist == "US") + return *width == 8 ? std::optional(6) + : *width == 16 ? std::optional(7) + : std::nullopt; + if (dist == "DS") + return *width == 8 ? std::optional(8) + : *width == 16 ? std::optional(9) + : std::nullopt; + if (dist == "UNPK") + return *width == 8 ? std::optional(13) + : *width == 16 ? std::optional(14) + : *width == 32 ? std::optional(18) + : std::nullopt; + if (dist == "BRC_BLK") + return 15; + if (dist == "E2B") + return *width == 16 ? std::optional(16) + : *width == 32 ? std::optional(17) + : std::nullopt; + if (dist == "UNPK4") + return *width == 8 ? std::optional(20) : std::nullopt; + if (dist == "SPLT4CHN") + return *width == 8 ? std::optional(21) : std::nullopt; + if (dist == "SPLT2CHN") + return *width == 8 ? std::optional(22) + : *width == 16 ? std::optional(23) + : std::nullopt; + return std::nullopt; +} + +static std::optional parseLoadX2DistImmediate(llvm::StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist == "BDINTLV") + return 10; + if (!width) + return std::nullopt; + if (dist == "DINTLV") + return *width == 8 ? std::optional(11) + : *width == 16 ? std::optional(12) + : *width == 32 ? std::optional(19) + : std::nullopt; + return std::nullopt; +} + +static std::optional parsePredicateLoadDistImmediate(llvm::StringRef dist) { + if (dist.empty() || dist == "NORM") + return 0; // Dist::DIST_NORM + if (dist == "US") + return 1; // Dist::DIST_US + if (dist == "DS") + return 2; // Dist::DIST_DS + return std::nullopt; +} + +static std::optional parsePredicateStoreDistImmediate(llvm::StringRef dist) { + if (dist.empty() || dist == "NORM") + return 0; // Dist::DIST_NORM + if (dist == "PK") + return 1; // Dist::DIST_PK + return std::nullopt; +} + +static Value packBlockRepeatStride(Operation *anchor, Value blockStride, + Value repeatStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value blockI32 = castIntegerLikeTo(anchor, blockStride, builder.getI32Type()); + Value repeatI32 = + castIntegerLikeTo(anchor, repeatStride, builder.getI32Type()); + if (!blockI32 || !repeatI32) + return {}; + + auto c16 = builder.create(anchor->getLoc(), 16, 32); + auto blockShifted = + builder.create(anchor->getLoc(), blockI32, c16); + return builder + .create(anchor->getLoc(), blockShifted, repeatI32) + .getResult(); +} + +static std::optional parseOrderImmediate(llvm::StringRef order) { + if (order.empty() || order == "ASC") + return 0; // INC_ORDER + if (order == "DESC") + return 1; // DEC_ORDER + return std::nullopt; +} + +static std::optional parseStoreDistImmediate(llvm::StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist.empty() || dist == "NORM") { + if (!width) + return std::nullopt; + if (*width == 8) + return 0; // norm_b8 + if (*width == 16) + return 1; // norm_b16 + if (*width == 32) + return 2; // norm_b32 + return std::nullopt; + } + if (!width) + return std::nullopt; + if (dist == "1PT") + return *width == 8 ? std::optional(3) + : *width == 16 ? std::optional(4) + : *width == 32 ? std::optional(5) + : std::nullopt; + if (dist == "PK") + return *width == 16 ? std::optional(6) + : *width == 32 ? std::optional(7) + : *width == 64 ? std::optional(10) + : std::nullopt; + if (dist == "PK4") + return *width == 32 ? std::optional(12) : std::nullopt; + if (dist == "MRG4CHN") + return *width == 8 ? std::optional(13) : std::nullopt; + if (dist == "MRG2CHN") + return *width == 8 ? std::optional(14) + : *width == 16 ? std::optional(15) + : std::nullopt; + return std::nullopt; +} + +static std::optional parseStoreX2DistImmediate(llvm::StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return std::nullopt; + if (dist == "INTLV") + return *width == 8 ? std::optional(8) + : *width == 16 ? std::optional(9) + : *width == 32 ? std::optional(11) + : std::nullopt; + return std::nullopt; +} + +static std::optional parsePostModeImmediate(StringRef mode) { + if (mode == "NO_POST_UPDATE") + return 0; + if (mode == "POST_UPDATE") + return 1; + return std::nullopt; +} + +static Type convertVPTOType(Type type, Builder &builder) { + if (auto vecType = dyn_cast(type)) + return VectorType::get({vecType.getElementCount()}, vecType.getElementType()); + if (isa(type)) + return VectorType::get({256}, builder.getI1Type()); + if (isa(type)) + return VectorType::get({32}, builder.getIntegerType(8)); + if (auto ptrType = dyn_cast(type)) { + return LLVM::LLVMPointerType::get( + builder.getContext(), + static_cast(ptrType.getMemorySpace().getAddressSpace())); + } + return type; +} + +static bool hasPtoPtrType(TypeRange types) { + return llvm::any_of(types, [](Type type) { return isa(type); }); +} + +static bool hasPtoAlignType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) + return llvm::any_of(functionType.getInputs(), hasPtoAlignType) || + llvm::any_of(functionType.getResults(), hasPtoAlignType); + return false; +} + +static bool hasPtoAlignType(TypeRange types) { + return llvm::any_of(types, [](Type type) { return hasPtoAlignType(type); }); +} + +static bool hasPtoMemRefMemorySpace(Type type) { + if (auto memRefType = dyn_cast(type)) + return isa(memRefType.getMemorySpace()); + if (auto functionType = dyn_cast(type)) + return llvm::any_of(functionType.getInputs(), hasPtoMemRefMemorySpace) || + llvm::any_of(functionType.getResults(), hasPtoMemRefMemorySpace); + return false; +} + +static bool hasPtoMemRefMemorySpace(TypeRange types) { + return llvm::any_of(types, [](Type type) { + return hasPtoMemRefMemorySpace(type); + }); +} + +struct ConvertPtoMemRefSpaceCarrierOp final : ConversionPattern { + ConvertPtoMemRefSpaceCarrierOp(TypeConverter &typeConverter, + MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && + !hasPtoMemRefMemorySpace(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure( + op, "region ops with PTO memref spaces are handled structurally"); + + FailureOr converted = + convertOpResultTypes(op, operands, *typeConverter, rewriter); + if (failed(converted)) + return failure(); + return success(); + } +}; + +struct ConvertMemRefReinterpretCastSpaceOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getType()); + auto memRefResultType = dyn_cast_or_null(convertedResultType); + if (!memRefResultType) + return rewriter.notifyMatchFailure(op, "expected memref result type"); + + rewriter.replaceOpWithNewOp( + op, memRefResultType, adaptor.getSource(), adaptor.getOffsets(), + adaptor.getSizes(), adaptor.getStrides(), op.getStaticOffsets(), + op.getStaticSizes(), op.getStaticStrides()); + return success(); + } +}; + +struct ConvertMemRefSubViewSpaceOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getType()); + auto memRefResultType = dyn_cast_or_null(convertedResultType); + if (!memRefResultType) + return rewriter.notifyMatchFailure(op, "expected memref result type"); + + rewriter.replaceOpWithNewOp( + op, memRefResultType, adaptor.getSource(), op.getMixedOffsets(), + op.getMixedSizes(), op.getMixedStrides()); + return success(); + } +}; + +struct ConvertMemRefSpaceUnrealizedCastOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && + !hasPtoMemRefMemorySpace(op->getResultTypes())) + return failure(); + + Type convertedResultType = getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + return failure(); + } +}; + +static LogicalResult normalizePtoMemRefSpaces(ModuleOp module, + llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion([&](MemRefType type) -> Type { + auto addrSpace = dyn_cast_or_null(type.getMemorySpace()); + if (!addrSpace) + return type; + return MemRefType::get( + type.getShape(), type.getElementType(), type.getLayout(), + IntegerAttr::get(IntegerType::get(context, 64), + static_cast(addrSpace.getAddressSpace()))); + }); + typeConverter.addTypeAttributeConversion( + [](MemRefType, pto::AddressSpaceAttr attr) -> Attribute { + return IntegerAttr::get(IntegerType::get(attr.getContext(), 64), + static_cast(attr.getAddressSpace())); + }); + auto materializeMemRefCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); + }; + typeConverter.addSourceMaterialization(materializeMemRefCast); + typeConverter.addTargetMaterialization(materializeMemRefCast); + typeConverter.addArgumentMaterialization(materializeMemRefCast); + + ConversionTarget target(*context); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + RewritePatternSet patterns(context); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + patterns.add( + typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: memref address-space normalization " + "failed\n"; + return failure(); + } + + SmallVector castsToFold; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return; + if (!hasPtoMemRefMemorySpace(castOp->getOperandTypes()) && + !hasPtoMemRefMemorySpace(castOp->getResultTypes())) + return; + Type convertedResultType = typeConverter.convertType(castOp.getResult(0).getType()); + if (convertedResultType && convertedResultType == castOp.getOperand(0).getType()) + castsToFold.push_back(castOp); + }); + for (UnrealizedConversionCastOp castOp : castsToFold) { + castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); + castOp.erase(); + } + + WalkResult leftover = module.walk([&](Operation *op) { + if (hasPtoMemRefMemorySpace(op->getOperandTypes()) || + hasPtoMemRefMemorySpace(op->getResultTypes())) { + diagOS << "VPTO LLVM emission failed: residual PTO memref address space on op " + << op->getName().getStringRef() << "\n"; + op->print(diagOS); + diagOS << "\n"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (leftover.wasInterrupted()) + return failure(); + return success(); +} + +struct ConvertPtoAddPtrOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + auto llvmPtrType = dyn_cast_or_null(convertedResultType); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer result type"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + auto gep = rewriter.create( + op.getLoc(), llvmPtrType, cast(op.getPtr().getType()).getElementType(), + adaptor.getPtr(), ValueRange{offset}); + rewriter.replaceOp(op, gep.getResult()); + return success(); + } +}; + +struct ConvertPtoCastPtrOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::CastPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "could not convert castptr result type"); + + Value input = adaptor.getInput(); + Type inputType = input.getType(); + if (inputType == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + + if (auto llvmPtrType = dyn_cast(convertedResultType)) { + if (isa(inputType)) { + auto intToPtr = + rewriter.create(op.getLoc(), llvmPtrType, input); + rewriter.replaceOp(op, intToPtr.getResult()); + return success(); + } + auto sourcePtrType = dyn_cast(inputType); + if (!sourcePtrType) + return rewriter.notifyMatchFailure(op, "expected integer or LLVM pointer input"); + if (sourcePtrType.getAddressSpace() == llvmPtrType.getAddressSpace()) { + auto bitcast = + rewriter.create(op.getLoc(), llvmPtrType, input); + rewriter.replaceOp(op, bitcast.getResult()); + return success(); + } + return rewriter.notifyMatchFailure(op, "cross-address-space ptr casts are unsupported"); + } + + if (auto resultIntType = dyn_cast(convertedResultType)) { + if (auto inputPtrType = dyn_cast(inputType)) { + rewriter.replaceOpWithNewOp(op, resultIntType, input); + return success(); + } + if (auto inputIntType = dyn_cast(inputType)) { + unsigned srcWidth = inputIntType.getWidth(); + unsigned dstWidth = resultIntType.getWidth(); + if (srcWidth == dstWidth) { + rewriter.replaceOp(op, input); + return success(); + } + if (srcWidth < dstWidth) { + rewriter.replaceOpWithNewOp(op, resultIntType, input); + return success(); + } + rewriter.replaceOpWithNewOp(op, resultIntType, input); + return success(); + } + } + + return rewriter.notifyMatchFailure(op, "unsupported castptr conversion"); + } +}; + +struct ConvertPtoLoadScalarOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + op.getValue().getType(), adaptor.getPtr(), + ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) { + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + } else if (type.isF16() || type.isBF16()) { + alignBytes = 2; + } else if (type.isF32()) { + alignBytes = 4; + } else if (type.isF64()) { + alignBytes = 8; + } + return alignBytes; + }; + + rewriter.replaceOpWithNewOp( + op, op.getValue().getType(), elemPtr, + getNaturalAlignment(op.getValue().getType())); + return success(); + } +}; + +struct ConvertPtoStoreScalarOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + adaptor.getValue().getType(), + adaptor.getPtr(), ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) { + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + } else if (type.isF16() || type.isBF16()) { + alignBytes = 2; + } else if (type.isF32()) { + alignBytes = 4; + } else if (type.isF64()) { + alignBytes = 8; + } + return alignBytes; + }; + + rewriter.replaceOpWithNewOp( + op, adaptor.getValue(), elemPtr, + getNaturalAlignment(adaptor.getValue().getType())); + return success(); + } +}; + +struct ConvertPtoUnrealizedCastOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return rewriter.notifyMatchFailure(op, "only 1:1 casts are supported"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "could not convert cast result type"); + + Value input = adaptor.getOperands().front(); + if (auto llvmPtrType = dyn_cast(convertedResultType)) { + if (input.getType().isInteger(64)) { + rewriter.replaceOpWithNewOp(op, llvmPtrType, input); + return success(); + } + } + if (input.getType() == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + + auto cast = rewriter.create( + op.getLoc(), TypeRange{convertedResultType}, input); + rewriter.replaceOp(op, cast.getResults()); + return success(); + } +}; + +struct ConvertPtoPtrCarrierOp final : ConversionPattern { + ConvertPtoPtrCarrierOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (isa(op)) + return failure(); + if (!hasPtoPtrType(op->getOperandTypes()) && !hasPtoPtrType(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure(op, "region ops with pto.ptr are unsupported"); + + SmallVector convertedResultTypes; + if (failed(typeConverter->convertTypes(op->getResultTypes(), convertedResultTypes))) + return failure(); + + OperationState state(op->getLoc(), op->getName().getStringRef()); + state.addOperands(operands); + state.addTypes(convertedResultTypes); + state.addAttributes(op->getAttrs()); + state.addSuccessors(op->getSuccessors()); + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +struct ConvertPtoAlignUnrealizedCastOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasPtoAlignType(op->getOperandTypes()) && + !hasPtoAlignType(op->getResultTypes())) + return failure(); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + return failure(); + } +}; + +struct ConvertPtoAlignCarrierOp final : ConversionPattern { + ConvertPtoAlignCarrierOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (isa(op)) + return failure(); + if (!hasPtoAlignType(op->getOperandTypes()) && + !hasPtoAlignType(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure(op, + "region ops with pto.align are handled structurally"); + + SmallVector convertedResultTypes; + if (failed(typeConverter->convertTypes(op->getResultTypes(), + convertedResultTypes))) + return failure(); + + OperationState state(op->getLoc(), op->getName().getStringRef()); + state.addOperands(operands); + state.addTypes(convertedResultTypes); + state.addAttributes(op->getAttrs()); + state.addSuccessors(op->getSuccessors()); + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +static LogicalResult normalizePtoPtrsToLLVM(ModuleOp module, llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + + for (func::FuncOp funcOp : module.getOps()) { + if (funcOp.isExternal()) + continue; + } + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion([&](pto::PtrType type) -> Type { + return LLVM::LLVMPointerType::get( + context, static_cast(type.getMemorySpace().getAddressSpace())); + }); + auto materializePtrCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); + }; + typeConverter.addSourceMaterialization(materializePtrCast); + typeConverter.addTargetMaterialization(materializePtrCast); + typeConverter.addArgumentMaterialization(materializePtrCast); + + ConversionTarget target(*context); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](UnrealizedConversionCastOp op) { + return !hasPtoPtrType(op->getOperandTypes()) && !hasPtoPtrType(op->getResultTypes()); + }); + + RewritePatternSet patterns(context); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + patterns.add( + typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: pto.ptr normalization failed\n"; + return failure(); + } + + SmallVector castsToFold; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return; + if (!hasPtoPtrType(castOp->getOperandTypes()) && + !hasPtoPtrType(castOp->getResultTypes())) + return; + Type convertedResultType = typeConverter.convertType(castOp.getResult(0).getType()); + if (convertedResultType && convertedResultType == castOp.getOperand(0).getType()) + castsToFold.push_back(castOp); + }); + for (UnrealizedConversionCastOp castOp : castsToFold) { + castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); + castOp.erase(); + } + + return success(); +} + +static LogicalResult normalizePtoAlignsToABI(ModuleOp module, + llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion([&](pto::AlignType type) -> Type { + return VectorType::get({32}, IntegerType::get(context, 8)); + }); + auto materializeAlignCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); + }; + typeConverter.addSourceMaterialization(materializeAlignCast); + typeConverter.addTargetMaterialization(materializeAlignCast); + typeConverter.addArgumentMaterialization(materializeAlignCast); + + ConversionTarget target(*context); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + target.addDynamicallyLegalOp( + [&](UnrealizedConversionCastOp op) { + return !hasPtoAlignType(op->getOperandTypes()) && + !hasPtoAlignType(op->getResultTypes()); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + RewritePatternSet patterns(context); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + patterns.add( + typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: pto.align normalization failed\n"; + return failure(); + } + + SmallVector castsToFold; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return; + if (!hasPtoAlignType(castOp->getOperandTypes()) && + !hasPtoAlignType(castOp->getResultTypes())) + return; + Type convertedResultType = + typeConverter.convertType(castOp.getResult(0).getType()); + if (convertedResultType && + convertedResultType == castOp.getOperand(0).getType()) + castsToFold.push_back(castOp); + }); + for (UnrealizedConversionCastOp castOp : castsToFold) { + castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); + castOp.erase(); + } + + WalkResult leftover = module.walk([&](Operation *op) { + if (hasPtoAlignType(op->getOperandTypes()) || + hasPtoAlignType(op->getResultTypes())) { + diagOS << "VPTO LLVM emission failed: residual pto.align type on op " + << op->getName().getStringRef() << "\n"; + op->print(diagOS); + diagOS << "\n"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (leftover.wasInterrupted()) + return failure(); + return success(); +} + +static Type getElementTypeFromVectorLike(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + return {}; +} + +static Type getElementTypeFromPointerLike(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memRefType = dyn_cast(type)) + return memRefType.getElementType(); + return {}; +} + +static Type getElementTypeFromABIValue(Value value) { + if (!value) + return {}; + if (Type direct = getElementTypeFromPointerLike(value.getType())) + return direct; + return {}; +} + +static std::optional getElementCountFromVectorLike(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementCount(); + if (auto vecType = dyn_cast(type)) { + if (vecType.getRank() != 1) + return std::nullopt; + return vecType.getShape().front(); + } + return std::nullopt; +} + +static Value castIntegerLikeTo(Operation *anchor, Value value, Type targetType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + if (value.getType() == targetType) + return value; + + auto targetInt = dyn_cast(targetType); + if (value.getType().isIndex() && targetInt) + return builder.create(anchor->getLoc(), targetType, value); + if (auto sourceInt = dyn_cast(value.getType())) { + if (targetInt) { + if (sourceInt.getWidth() < targetInt.getWidth()) + return builder.create(anchor->getLoc(), targetType, value); + if (sourceInt.getWidth() > targetInt.getWidth()) + return builder.create(anchor->getLoc(), targetType, value); + return value; + } + if (targetType.isIndex()) + return builder.create(anchor->getLoc(), targetType, value); + } + + return {}; +} + +static FailureOr convertElementOffsetToBytes(Operation *anchor, Value offset, + Type elementType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value offsetI32 = castIntegerLikeTo(anchor, offset, builder.getI32Type()); + if (!offsetI32) + return failure(); + + unsigned bitWidth = 0; + if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + else if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + if (bitWidth == 0 || bitWidth % 8 != 0) + return failure(); + + Value scale = builder.create( + anchor->getLoc(), builder.getI32IntegerAttr(bitWidth / 8)); + return builder.create(anchor->getLoc(), offsetI32, scale) + .getResult(); +} + +static Value buildBridgeCast(OpBuilder &builder, Location loc, Value input, + Type targetType) { + if (input.getType() == targetType) + return input; + if ((isa(input.getType()) && + isa(targetType)) || + (isa(input.getType()) && + isa(targetType))) { + return builder + .create(loc, TypeRange{targetType}, input) + .getResult(0); + } + return builder.create(loc, targetType, input).getResult(); +} + +static FailureOr requirePointerABIAddress(Operation *anchor, Value address, + llvm::raw_ostream &diagOS) { + if (isa(address.getType())) + return address; + if (auto ptrType = dyn_cast(address.getType())) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + auto llvmPtrType = LLVM::LLVMPointerType::get( + builder.getContext(), + static_cast(ptrType.getMemorySpace().getAddressSpace())); + Value abiAddress = buildBridgeCast(builder, anchor->getLoc(), address, llvmPtrType); + return abiAddress; + } + + diagOS << "VPTO LLVM emission failed: expected pointer-ABI address after " + "pre-emit canonicalization, but saw " + << address.getType() << " on op "; + anchor->print(diagOS); + diagOS << "\n"; + return failure(); +} + +static FailureOr materializeAlignABIValue(Operation *anchor, Value align, + llvm::raw_ostream &diagOS) { + if (!align) + return failure(); + if (isa(align.getType())) + return align; + + auto alignType = dyn_cast(align.getType()); + if (!alignType) { + diagOS << "VPTO LLVM emission failed: expected align ABI value, but saw " + << align.getType() << "\n"; + return failure(); + } + + Operation *def = align.getDefiningOp(); + if (!def) { + diagOS << "VPTO LLVM emission failed: unsupported non-ABI align producer " + << "" + << " for " << alignType << "\n"; + return failure(); + } + + auto defName = def->getName().getStringRef(); + if (defName != "pto.init_align" && defName != "ub.poison") { + diagOS << "VPTO LLVM emission failed: unsupported non-ABI align producer "; + diagOS << def->getName(); + diagOS << " for " << alignType << "\n"; + return failure(); + } + + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + auto abiType = cast(convertVPTOType(alignType, builder)); + auto zeroAttr = DenseElementsAttr::get(abiType, builder.getI8IntegerAttr(0)); + return builder.create(anchor->getLoc(), abiType, zeroAttr) + .getResult(); +} + +static Value getI64Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI64IntegerAttr(value)) + .getResult(); +} + +static Value getI32Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI32IntegerAttr(value)) + .getResult(); +} + +static Value getI16Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI16IntegerAttr(value)) + .getResult(); +} + +static Value buildAllTrueMask(OpBuilder &builder, Location loc) { + auto maskType = VectorType::get({256}, builder.getI1Type()); + auto attr = DenseElementsAttr::get(maskType, true); + return builder.create(loc, maskType, attr).getResult(); +} + +static FailureOr buildPltB8Mask(IRRewriter &builder, ModuleOp module, + Location loc, uint64_t laneCount, + llvm::raw_ostream &diagOS) { + Value laneCountValue = getI32Constant(builder, loc, laneCount); + auto maskType = VectorType::get({256}, builder.getI1Type()); + auto funcType = + builder.getFunctionType({builder.getI32Type()}, {maskType, builder.getI32Type()}); + auto callee = + getOrCreateExternalFunc(module, "llvm.hivm.plt.b8.v300", funcType); + auto call = builder.create(loc, callee, ValueRange{laneCountValue}); + return call.getResult(0); +} + +static FailureOr buildPltB32Mask(IRRewriter &builder, ModuleOp module, + Location loc, uint64_t laneCount, + llvm::raw_ostream &diagOS) { + // Keep this helper narrowly scoped to the verified HIVM form we have observed + // in emitc-generated device IR. For Expands/TExpandS, installed PTO source + // calls pset_b32(PAT_ALL), but save-temps from the working emitc path show + // that the compiler frontend does not preserve a pset-shaped HIVM intrinsic + // here. Instead, the full-lane mask is materialized in the final device IR as + // llvm.hivm.plt.b32.v300(i32 64), i.e. a canonical "all 64 b32 lanes active" + // form that the backend accepts. Reproduce that observed lowering here; do + // not treat it as evidence that pset_b32 and plt_b32 are generally + // interchangeable at the source or VPTO level. + Value laneCountValue = getI32Constant(builder, loc, laneCount); + auto maskType = VectorType::get({256}, builder.getI1Type()); + auto funcType = + builder.getFunctionType({builder.getI32Type()}, {maskType, builder.getI32Type()}); + auto callee = + getOrCreateExternalFunc(module, "llvm.hivm.plt.b32.v300", funcType); + auto call = builder.create(loc, callee, ValueRange{laneCountValue}); + return call.getResult(0); +} + +static FailureOr buildPltB16Mask(IRRewriter &builder, ModuleOp module, + Location loc, uint64_t laneCount, + llvm::raw_ostream &diagOS) { + Value laneCountValue = getI32Constant(builder, loc, laneCount); + auto maskType = VectorType::get({256}, builder.getI1Type()); + auto funcType = + builder.getFunctionType({builder.getI32Type()}, {maskType, builder.getI32Type()}); + auto callee = + getOrCreateExternalFunc(module, "llvm.hivm.plt.b16.v300", funcType); + auto call = builder.create(loc, callee, ValueRange{laneCountValue}); + return call.getResult(0); +} + +static FailureOr buildDynamicPltMask(IRRewriter &builder, ModuleOp module, + Location loc, Value laneCount, + Type vectorElemType, + llvm::raw_ostream &diagOS) { + Value laneCountI32 = laneCount; + Type i32Type = builder.getI32Type(); + if (laneCountI32.getType() != i32Type) { + if (laneCountI32.getType().isIndex()) { + laneCountI32 = + builder.create(loc, i32Type, laneCountI32); + } else if (auto sourceInt = dyn_cast(laneCountI32.getType())) { + auto targetInt = cast(i32Type); + if (sourceInt.getWidth() < targetInt.getWidth()) { + laneCountI32 = + builder.create(loc, i32Type, laneCountI32); + } else if (sourceInt.getWidth() > targetInt.getWidth()) { + laneCountI32 = + builder.create(loc, i32Type, laneCountI32); + } + } else { + return failure(); + } + } + + auto maskType = VectorType::get({256}, builder.getI1Type()); + auto funcType = + builder.getFunctionType({builder.getI32Type()}, {maskType, builder.getI32Type()}); + + StringRef calleeName; + if (vectorElemType.isF32()) { + calleeName = "llvm.hivm.plt.b32.v300"; + } else if (vectorElemType.isF16() || vectorElemType.isBF16()) { + calleeName = "llvm.hivm.plt.b16.v300"; + } else if (auto intType = dyn_cast(vectorElemType)) { + if (intType.getWidth() == 32) + calleeName = "llvm.hivm.plt.b32.v300"; + else if (intType.getWidth() == 16) + calleeName = "llvm.hivm.plt.b16.v300"; + } + + if (calleeName.empty()) { + diagOS << "VPTO LLVM emission failed: unsupported dynamic plt mask element " + "type " + << vectorElemType << "\n"; + return failure(); + } + + auto callee = getOrCreateExternalFunc(module, calleeName, funcType); + auto call = builder.create(loc, callee, ValueRange{laneCountI32}); + return call.getResult(0); +} + +static FailureOr packLoopPair(Operation *anchor, Value low, Value high) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value lowI64 = castIntegerLikeTo(anchor, low, builder.getI64Type()); + Value highI64 = castIntegerLikeTo(anchor, high, builder.getI64Type()); + if (!lowI64 || !highI64) + return failure(); + + Value shift = getI64Constant(builder, anchor->getLoc(), 40); + Value highShifted = + builder.create(anchor->getLoc(), highI64, shift).getResult(); + return builder.create(anchor->getLoc(), highShifted, lowI64) + .getResult(); +} + +static FailureOr packLoopSize(Operation *anchor, Value loop2, Value loop1) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value loop2I64 = castIntegerLikeTo(anchor, loop2, builder.getI64Type()); + Value loop1I64 = castIntegerLikeTo(anchor, loop1, builder.getI64Type()); + if (!loop2I64 || !loop1I64) + return failure(); + + Value shift = getI64Constant(builder, anchor->getLoc(), 21); + Value loop2Shifted = + builder.create(anchor->getLoc(), loop2I64, shift).getResult(); + return builder.create(anchor->getLoc(), loop2Shifted, loop1I64) + .getResult(); +} + +static FailureOr +packCopyGmToUbConfig0(Operation *anchor, pto::CopyGmToUbufOp op, + ValueRange operands) { + if (operands.size() != 11) + return failure(); + + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; + + Value sid = getI64Operand(2); + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value leftPadding = getI64Operand(5); + Value rightPadding = getI64Operand(6); + Value dataSelect = castIntegerLikeTo(anchor, operands[7], builder.getI64Type()); + Value cacheCtl = getI64Operand(8); + if (!sid || !nBurst || !lenBurst || !leftPadding || !rightPadding || + !dataSelect || !cacheCtl) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = sid; + config = bitOr(config, shl(nBurst, 4)); + config = bitOr(config, shl(lenBurst, 25)); + config = bitOr(config, shl(leftPadding, 46)); + config = bitOr(config, shl(rightPadding, 52)); + config = bitOr(config, shl(dataSelect, 58)); + config = bitOr(config, shl(cacheCtl, 60)); + return config; +} + +static FailureOr +packCopyGmToUbConfig1(Operation *anchor, ValueRange operands) { + if (operands.size() != 11) + return failure(); + return packLoopPair(anchor, operands[9], operands[10]); +} + +static FailureOr +packCopyUbToGmConfig0(Operation *anchor, ValueRange operands) { + if (operands.size() != 8) + return failure(); + + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; + + Value sid = getI64Operand(2); + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value reserved = getI64Operand(5); + if (!sid || !nBurst || !lenBurst || !reserved) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = sid; + config = bitOr(config, shl(nBurst, 4)); + config = bitOr(config, shl(lenBurst, 25)); + config = bitOr(config, shl(reserved, 60)); + return config; +} + +static FailureOr +packCopyUbToGmConfig1(Operation *anchor, ValueRange operands) { + if (operands.size() != 8) + return failure(); + return packLoopPair(anchor, operands[6], operands[7]); +} + +static FailureOr packVbitsortConfig(Operation *anchor, Value repeatTimes) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value repeatI64 = castIntegerLikeTo(anchor, repeatTimes, builder.getI64Type()); + if (!repeatI64) + return failure(); + return builder + .create(loc, repeatI64, getI64Constant(builder, loc, 56)) + .getResult(); +} + +static func::FuncOp getOrCreateExternalFunc(ModuleOp module, StringRef name, + FunctionType type) { + if (auto existing = module.lookupSymbol(name)) + return existing; + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(module.getBody()); + auto func = builder.create(module.getLoc(), name, type); + func.setPrivate(); + return func; +} + +static FailureOr getConfirmedCallee(Operation *op) { + if (isa(op)) + return std::string("llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB"); + if (isa(op)) + return std::string("llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB"); + if (isa(op)) + return std::string("llvm.hivm.SET.LOOP.SIZE.OUTTOUB"); + if (isa(op)) + return std::string("llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT"); + if (isa(op)) + return std::string("llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT"); + if (isa(op)) + return std::string("llvm.hivm.SET.LOOP.SIZE.UBTOOUT"); + if (auto copy = dyn_cast(op)) { + Type elementType = getElementTypeFromABIValue(copy.getSource()); + if (!elementType) + elementType = getElementTypeFromABIValue(copy.getDestination()); + std::string elem = getCopyElementFragment(elementType); + if (elem.empty()) + return failure(); + return "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2." + elem + ".DV"; + } + if (isa(op)) + return std::string("llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV"); + if (isa(op)) + return std::string("llvm.hivm.SET.FLAG.IMM"); + if (isa(op)) + return std::string("llvm.hivm.WAIT.FLAG.IMM"); + if (isa(op)) + return std::string("llvm.hivm.BARRIER"); + if (isa(op)) + return std::string("llvm.hivm.GET.BLOCK.IDX"); + if (isa(op)) + return std::string("llvm.hivm.GET.SUBBLOCKID"); + if (isa(op)) + return std::string("llvm.hivm.GET.BLOCK.NUM"); + if (isa(op)) + return std::string("llvm.hivm.GET.SUBBLOCKDIM"); + if (isa(op)) + return std::string("llvm.hivm.sprclr"); + if (isa(op)) + return std::string("llvm.hivm.plt.b8.v300"); + if (isa(op)) + return std::string("llvm.hivm.plt.b32.v300"); + if (isa(op)) + return std::string("llvm.hivm.plt.b16.v300"); + if (isa(op)) + return std::string("llvm.hivm.pset.b8"); + if (isa(op)) + return std::string("llvm.hivm.pset.b16"); + if (isa(op)) + return std::string("llvm.hivm.pset.b32"); + if (isa(op)) + return std::string("llvm.hivm.pge.b8"); + if (isa(op)) + return std::string("llvm.hivm.pge.b16"); + if (isa(op)) + return std::string("llvm.hivm.pge.b32"); + if (isa(op)) + return std::string("llvm.hivm.vldas"); + if (isa(op)) + return std::string("llvm.hivm.init.vector.align.data"); + if (auto vldus = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(vldus.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vldus.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vldus.v" + std::to_string(*lanes) + vec; + } + if (isa(op)) + return std::string("llvm.hivm.vstus"); + if (isa(op)) + return std::string("llvm.hivm.vstur"); + if (auto vlds = dyn_cast(op)) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(vlds.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vlds.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + std::string name = "llvm.hivm.vldsx1"; + name += ".v" + std::to_string(*lanes) + vec; + return name; + } + if (auto vldsPost = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vldsPost.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vldsPost.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vldsx1.post.v" + std::to_string(*lanes) + vec; + } + if (auto vldsPost = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vldsPost.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vldsPost.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vldsx1.post.v" + std::to_string(*lanes) + vec; + } + if (auto vabs = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vabs.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vabs.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vabs.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto vexp = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vexp.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vexp.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vexp.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto vln = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vln.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vln.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vln.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto vneg = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vneg.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vneg.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vneg.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto vsqrt = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(vsqrt.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vsqrt.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vsqrt.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto vrelu = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(vrelu.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vrelu.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vrelu.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto vnot = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vnot.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vnot.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vnot.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto vdup = dyn_cast(op)) { + Type inputType = vdup.getInput().getType(); + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vdup.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vdup.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + if (isa(inputType)) { + StringRef position = vdup.getPosition().value_or("LOWEST"); + StringRef family = position == "HIGHEST" ? "vdupm" : "vdup"; + return "llvm.hivm." + family.str() + ".v" + std::to_string(*lanes) + vec + ".z"; + } + return "llvm.hivm.vdups.v" + std::to_string(*lanes) + vec + ".z"; + } + if (auto vbr = dyn_cast(op)) { + std::string scalar = getVbrScalarFragment(vbr.getValue().getType()); + if (scalar.empty()) + return failure(); + return "llvm.hivm.vbr." + scalar + ".v300"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vadd.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vsub.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vmul.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vmuls.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vadds.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vmaxs.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vmins.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vlrelu.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vshls.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vshrs.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vprelu.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string srcVec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getInput().getType())); + auto srcLanes = getElementCountFromVectorLike(binary.getInput().getType()); + std::string dstElem = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + if (srcVec.empty() || dstElem.empty() || !srcLanes) + return failure(); + return "llvm.hivm.vexpdif.v" + std::to_string(*srcLanes) + srcVec + dstElem; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vdiv.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vmax.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vmin.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vand.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vor.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vxor.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vaddc.v" + std::to_string(*lanes) + vec; + } + if (auto binary = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vsubc.v" + std::to_string(*lanes) + vec; + } + if (auto binary = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vaddcs.v" + std::to_string(*lanes) + vec; + } + if (auto binary = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vsubcs.v" + std::to_string(*lanes) + vec; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vshl.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vshr.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto unary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vcadd.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto unary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vcmax.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto unary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vcmin.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto unary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vcgadd.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto unary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vcgmax.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto unary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vcgmin.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto unary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vcpadd.v" + std::to_string(*lanes) + vec + ".x"; + } + if (auto ternary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(ternary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(ternary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vmula.v" + std::to_string(*lanes) + vec + ".m"; + } + if (auto binary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(binary.getLow().getType())); + auto lanes = getElementCountFromVectorLike(binary.getLow().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vmull.v" + std::to_string(*lanes) + vec; + } + if (auto ternary = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(ternary.getResult().getType())); + auto lanes = getElementCountFromVectorLike(ternary.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vaxpy.v" + std::to_string(*lanes) + vec + ".m"; + } + if (auto vci = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vci.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vci.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + if (vec == "f16") + return "llvm.hivm.vci.v" + std::to_string(*lanes) + vec + ".f16"; + if (vec == "f32") + return "llvm.hivm.vci.v" + std::to_string(*lanes) + vec + ".f32"; + return "llvm.hivm.vci.v" + std::to_string(*lanes) + vec; + } + if (auto vbitsort = dyn_cast(op)) { + Type sourceElemType = getElementTypeFromABIValue(vbitsort.getSource()); + if (!sourceElemType) + return failure(); + if (sourceElemType.isF16()) + return std::string("llvm.hivm.VBS32.V300.f16"); + if (sourceElemType.isF32()) + return std::string("llvm.hivm.VBS32.V300.f32"); + return failure(); + } + if (auto vtrc = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vtrc.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vtrc.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vtrc." + vec + ".x"; + } + if (auto vcvt = dyn_cast(op)) { + Type inputElemType = getElementTypeFromVectorLike(vcvt.getInput().getType()); + Type resultElemType = getElementTypeFromVectorLike(vcvt.getResult().getType()); + if (!inputElemType || !resultElemType) + return failure(); + auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), + classifyVcvtElemType(resultElemType)); + if (contract) + return std::string(contract->intrinsic); + return failure(); + } + if (isa(op)) + return std::string("llvm.hivm.vstar"); + if (isa(op)) + return std::string("llvm.hivm.vstas"); + if (auto vsqz = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vsqz.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vsqz.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vsqz.v" + std::to_string(*lanes) + vec + ".x.v300"; + } + if (auto vusqz = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vusqz.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vusqz.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vusqz.v" + std::to_string(*lanes) + vec + ".m"; + } + if (auto unpack = dyn_cast(op)) { + Type inputElemType = getElementTypeFromVectorLike(unpack.getSrc().getType()); + Type resultElemType = getElementTypeFromVectorLike(unpack.getResult().getType()); + std::string input = getElementTypeFragment(inputElemType); + std::string result = getElementTypeFragment(resultElemType); + if (input.empty() || result.empty()) + return failure(); + return "llvm.hivm.vsunpack." + input + "2" + result; + } + if (auto unpack = dyn_cast(op)) { + Type inputElemType = getElementTypeFromVectorLike(unpack.getSrc().getType()); + Type resultElemType = getElementTypeFromVectorLike(unpack.getResult().getType()); + std::string input = getElementTypeFragment(inputElemType); + std::string result = getElementTypeFragment(resultElemType); + if (input.empty() || result.empty()) + return failure(); + return "llvm.hivm.vzunpack." + input + "2" + result; + } + if (auto pack = dyn_cast(op)) { + Type inputElemType = getElementTypeFromVectorLike(pack.getSrc().getType()); + Type resultElemType = getElementTypeFromVectorLike(pack.getResult().getType()); + std::string input = getElementTypeFragment(inputElemType); + std::string result = getElementTypeFragment(resultElemType); + auto part = parseHiLoPartImmediate(pack.getPart()); + if (input.empty() || result.empty() || !part) + return failure(); + return "llvm.hivm.vpack." + input + "2" + result + ".x"; + } + if (auto interleave = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(interleave.getLow().getType())); + auto lanes = getElementCountFromVectorLike(interleave.getLow().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vintlv.v" + std::to_string(*lanes) + vec; + } + if (auto deinterleave = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(deinterleave.getLow().getType())); + auto lanes = getElementCountFromVectorLike(deinterleave.getLow().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vdintlv.v" + std::to_string(*lanes) + vec; + } + if (isa(op)) + return std::string("llvm.hivm.vsldb"); + if (isa(op)) + return std::string("llvm.hivm.vsstb"); + if (auto vldsx2 = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vldsx2.getLow().getType())); + auto lanes = getElementCountFromVectorLike(vldsx2.getLow().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vldsx2.v" + std::to_string(*lanes) + vec; + } + if (auto vstsx2 = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(vstsx2.getLow().getType())); + auto lanes = getElementCountFromVectorLike(vstsx2.getLow().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vstsx2.v" + std::to_string(*lanes) + vec; + } + if (auto vsts = dyn_cast(op)) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(vsts.getValue().getType())); + auto lanes = getElementCountFromVectorLike(vsts.getValue().getType()); + if (vec.empty() || !lanes) + return failure(); + std::string name = "llvm.hivm.vstsx1"; + name += ".v" + std::to_string(*lanes) + vec; + return name; + } + if (auto vstsPost = dyn_cast(op)) { + std::string vec = getElementTypeFragment( + getElementTypeFromVectorLike(vstsPost.getValue().getType())); + auto lanes = getElementCountFromVectorLike(vstsPost.getValue().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vstsx1.post.v" + std::to_string(*lanes) + vec; + } + if (auto vstsPost = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vstsPost.getValue().getType())); + auto lanes = getElementCountFromVectorLike(vstsPost.getValue().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vstsx1.post.v" + std::to_string(*lanes) + vec; + } + if (auto vcmp = dyn_cast(op)) { + std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(vcmp.getSrc0().getType())); + if (elem.empty()) + return failure(); + return "llvm.hivm.vcmp." + vcmp.getCmpMode().str() + "." + elem + ".z"; + } + if (auto vcmps = dyn_cast(op)) { + std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(vcmps.getSrc().getType())); + if (elem.empty()) + return failure(); + return "llvm.hivm.vcmps." + vcmps.getCmpMode().str() + "." + elem + ".z"; + } + if (auto vsel = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vsel.getResult().getType())); + auto lanes = getElementCountFromVectorLike(vsel.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vsel.v" + std::to_string(*lanes) + vec; + } + if (auto vselr = dyn_cast(op)) { + Type elemType = getElementTypeFromVectorLike(vselr.getResult().getType()); + auto lanes = getElementCountFromVectorLike(vselr.getResult().getType()); + if (!elemType || !lanes) + return failure(); + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(vselr.getResult().getType())); + if (auto floatType = dyn_cast(elemType); floatType && floatType.isF32()) + vec = "u32"; + if (vec.empty()) + return failure(); + return "llvm.hivm.vselr.v" + std::to_string(*lanes) + vec; + } + if (isa(op)) + return std::string("llvm.hivm.ppack.z"); + if (isa(op)) + return std::string("llvm.hivm.punpack"); + if (isa(op)) + return std::string("llvm.hivm.pnot.z"); + if (isa(op)) + return std::string("llvm.hivm.psel"); + if (isa(op)) + return std::string("llvm.hivm.pand.z"); + if (isa(op)) + return std::string("llvm.hivm.por.z"); + if (isa(op)) + return std::string("llvm.hivm.pxor.z"); + if (isa(op)) + return std::string("llvm.hivm.pdintlv.b8"); + if (isa(op)) + return std::string("llvm.hivm.pdintlv.b16"); + if (isa(op)) + return std::string("llvm.hivm.pdintlv.b32"); + if (isa(op)) + return std::string("llvm.hivm.pintlv.b8"); + if (isa(op)) + return std::string("llvm.hivm.pintlv.b16"); + if (isa(op)) + return std::string("llvm.hivm.pintlv.b32"); + if (isa(op)) + return std::string("llvm.hivm.plds.b8"); + if (isa(op)) + return std::string("llvm.hivm.pldi.b8"); + if (isa(op)) + return std::string("llvm.hivm.psts.b8"); + if (op->getName().getStringRef() == "pto.pstu") { + Type maskOperandType = op->getOperand(1).getType(); + if (auto maskType = dyn_cast(maskOperandType)) { + if (maskType.isB16()) + return std::string("llvm.hivm.pstu.b16"); + if (maskType.isB32()) + return std::string("llvm.hivm.pstu.b32"); + } + if (Type baseElementType = getElementTypeFromABIValue(op->getOperand(2))) { + if (auto intType = dyn_cast(baseElementType)) { + if (intType.getWidth() == 16) + return std::string("llvm.hivm.pstu.b16"); + if (intType.getWidth() == 32) + return std::string("llvm.hivm.pstu.b32"); + } + } + // Current repo coverage only exercises the installed `b32` surface. Keep + // this fallback narrow to unblock those cases; `b16` still needs an + // end-to-end testcase path before we can claim the generic surface works. + return std::string("llvm.hivm.pstu.b32"); + } + if (auto pstu = dyn_cast(op)) { + if (auto maskType = dyn_cast(pstu.getValue().getType())) { + if (maskType.isB16()) + return std::string("llvm.hivm.pstu.b16"); + if (maskType.isB32()) + return std::string("llvm.hivm.pstu.b32"); + } + if (Type baseElementType = getElementTypeFromABIValue(pstu.getBase())) { + if (auto intType = dyn_cast(baseElementType)) { + if (intType.getWidth() == 16) + return std::string("llvm.hivm.pstu.b16"); + if (intType.getWidth() == 32) + return std::string("llvm.hivm.pstu.b32"); + } + } + return failure(); + } + if (isa(op)) + return std::string("llvm.hivm.psti.b8"); + if (auto gather = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(gather.getResult().getType())); + auto lanes = getElementCountFromVectorLike(gather.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vgather2.v300.v" + std::to_string(*lanes) + vec; + } + if (auto gather = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(gather.getResult().getType())); + auto lanes = getElementCountFromVectorLike(gather.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vgather2.bc.v" + std::to_string(*lanes) + vec; + } + if (auto gather = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(gather.getResult().getType())); + auto lanes = getElementCountFromVectorLike(gather.getResult().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vgatherb.v310.v" + std::to_string(*lanes) + vec; + } + if (auto scatter = dyn_cast(op)) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(scatter.getValue().getType())); + auto lanes = getElementCountFromVectorLike(scatter.getValue().getType()); + if (vec.empty() || !lanes) + return failure(); + return "llvm.hivm.vscatter.v" + std::to_string(*lanes) + vec + ".v300"; + } + return failure(); +} + +static LogicalResult +guardNoMemRefIntrinsicArgs(Operation *op, StringRef calleeName, + ValueRange callArgs, llvm::raw_ostream &diagOS) { + if (calleeName != "llvm.hivm.vldsx1" && calleeName != "llvm.hivm.vstsx1") + return success(); + + for (auto [idx, arg] : llvm::enumerate(callArgs)) { + Type argType = arg.getType(); + if (!isa(argType)) + continue; + diagOS << "VPTO LLVM emission failed: intrinsic ABI guard rejected memref " + "argument #" + << idx << " for " << calleeName << " from " + << op->getName().getStringRef() << " (" << argType << ")\n"; + return failure(); + } + return success(); +} + +static LogicalResult rewriteVPTOOp(Operation *op, ModuleOp module, + llvm::raw_ostream &diagOS) { + IRRewriter builder(op->getContext()); + builder.setInsertionPoint(op); + Location loc = op->getLoc(); + + if (auto vbr = dyn_cast(op)) { + auto calleeName = getConfirmedCallee(op); + if (failed(calleeName)) { + diagOS << "VPTO LLVM emission failed: unsupported op " + << op->getName().getStringRef() << "\n"; + return failure(); + } + + Type resultType = convertVPTOType(vbr.getResult().getType(), builder); + Type scalarType = vbr.getValue().getType(); + if (!resultType || !scalarType) { + diagOS << "VPTO LLVM emission failed: could not materialize vbr types\n"; + return failure(); + } + + auto funcType = builder.getFunctionType({scalarType}, {resultType}); + auto callee = getOrCreateExternalFunc(module, *calleeName, funcType); + auto call = + builder.create(loc, callee, ValueRange{vbr.getValue()}); + builder.replaceOp(op, call.getResults()); + return success(); + } + + if (isa(op)) { + SmallVector argTypes; + auto funcType = builder.getFunctionType(argTypes, op->getResultTypes()); + auto callee = getOrCreateExternalFunc(module, *getConfirmedCallee(op), funcType); + auto call = builder.create(loc, callee, ValueRange{}); + builder.replaceOp(op, call.getResults()); + return success(); + } + + auto calleeName = getConfirmedCallee(op); + if (failed(calleeName)) { + diagOS << "VPTO LLVM emission failed: unsupported op " + << op->getName().getStringRef() << "\n"; + return failure(); + } + + SmallVector surfaceResultTypes(op->getResultTypes().begin(), + op->getResultTypes().end()); + SmallVector loweredResultTypes; + loweredResultTypes.reserve(surfaceResultTypes.size()); + for (Type type : surfaceResultTypes) + loweredResultTypes.push_back(convertVPTOType(type, builder)); + SmallVector intrinsicResultTypes(loweredResultTypes.begin(), + loweredResultTypes.end()); + if (auto vldus = dyn_cast(op)) { + Type sourceType = convertVPTOType(vldus.getSource().getType(), builder); + if (!sourceType) { + diagOS << "VPTO LLVM emission failed: could not materialize vldus source type\n"; + return failure(); + } + intrinsicResultTypes.push_back(sourceType); + } + + SmallVector callArgs; + + if (isa(op)) { + auto packed = packLoopPair(op, op->getOperand(0), op->getOperand(1)); + if (failed(packed)) + return failure(); + callArgs.push_back(*packed); + } else if (isa(op)) { + auto packed = packLoopSize(op, op->getOperand(0), op->getOperand(1)); + if (failed(packed)) + return failure(); + callArgs.push_back(*packed); + } else if (auto copy = dyn_cast(op)) { + auto config0 = packCopyGmToUbConfig0(op, copy, op->getOperands()); + auto config1 = packCopyGmToUbConfig1(op, op->getOperands()); + auto destination = requirePointerABIAddress(op, copy.getDestination(), diagOS); + auto source = requirePointerABIAddress(op, copy.getSource(), diagOS); + if (failed(config0) || failed(config1) || failed(destination) || + failed(source)) + return failure(); + callArgs.push_back(*destination); + callArgs.push_back(*source); + callArgs.push_back(*config0); + callArgs.push_back(*config1); + } else if (auto copy = dyn_cast(op)) { + auto config0 = packCopyUbToGmConfig0(op, op->getOperands()); + auto config1 = packCopyUbToGmConfig1(op, op->getOperands()); + auto destination = requirePointerABIAddress(op, copy.getDestination(), diagOS); + auto source = requirePointerABIAddress(op, copy.getSource(), diagOS); + if (failed(config0) || failed(config1) || failed(destination) || + failed(source)) + return failure(); + callArgs.push_back(*destination); + callArgs.push_back(*source); + callArgs.push_back(*config0); + callArgs.push_back(*config1); + } else if (auto setFlag = dyn_cast(op)) { + auto src = parsePipeImmediate(stringifyPIPE(setFlag.getSrcPipe().getPipe())); + auto dst = parsePipeImmediate(stringifyPIPE(setFlag.getDstPipe().getPipe())); + auto event = parseEventImmediate(stringifyEVENT(setFlag.getEventId().getEvent())); + if (!src || !dst || !event) + return failure(); + callArgs.push_back(getI64Constant(builder, loc, *src)); + callArgs.push_back(getI64Constant(builder, loc, *dst)); + callArgs.push_back(getI64Constant(builder, loc, *event)); + } else if (auto waitFlag = dyn_cast(op)) { + auto src = + parsePipeImmediate(stringifyPIPE(waitFlag.getSrcPipe().getPipe())); + auto dst = + parsePipeImmediate(stringifyPIPE(waitFlag.getDstPipe().getPipe())); + auto event = + parseEventImmediate(stringifyEVENT(waitFlag.getEventId().getEvent())); + if (!src || !dst || !event) + return failure(); + callArgs.push_back(getI64Constant(builder, loc, *src)); + callArgs.push_back(getI64Constant(builder, loc, *dst)); + callArgs.push_back(getI64Constant(builder, loc, *event)); + } else if (auto barrier = dyn_cast(op)) { + auto pipe = parsePipeImmediate(stringifyPIPE(barrier.getPipe().getPipe())); + if (!pipe) + return failure(); + callArgs.push_back(getI64Constant(builder, loc, *pipe)); + } else if (auto sprclr = dyn_cast(op)) { + auto spr = parseSprImmediate(sprclr.getSpr()); + if (!spr) { + diagOS << "VPTO LLVM emission failed: unsupported sprclr target " + << sprclr.getSpr() << "\n"; + return failure(); + } + callArgs.push_back(getI16Constant(builder, loc, *spr)); + } else if (isa(op)) { + Value laneCount = castIntegerLikeTo(op, op->getOperand(0), builder.getI32Type()); + if (!laneCount) + return failure(); + callArgs.push_back(laneCount); + } else if (auto pset = dyn_cast(op)) { + auto pattern = parsePredicatePatternImmediate(pset.getPattern()); + if (!pattern) { + diagOS << "VPTO LLVM emission failed: unsupported pset_b8 pattern " + << pset.getPattern() << "\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *pattern)); + } else if (auto pset = dyn_cast(op)) { + auto pattern = parsePredicatePatternImmediate(pset.getPattern()); + if (!pattern) { + diagOS << "VPTO LLVM emission failed: unsupported pset_b16 pattern " + << pset.getPattern() << "\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *pattern)); + } else if (auto pset = dyn_cast(op)) { + auto pattern = parsePredicatePatternImmediate(pset.getPattern()); + if (!pattern) { + diagOS << "VPTO LLVM emission failed: unsupported pset_b32 pattern " + << pset.getPattern() << "\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *pattern)); + } else if (auto pge = dyn_cast(op)) { + auto pattern = parsePredicatePatternImmediate(pge.getPattern()); + if (!pattern) { + diagOS << "VPTO LLVM emission failed: unsupported pge_b8 pattern " + << pge.getPattern() << "\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *pattern)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto pge = dyn_cast(op)) { + auto pattern = parsePredicatePatternImmediate(pge.getPattern()); + if (!pattern) { + diagOS << "VPTO LLVM emission failed: unsupported pge_b16 pattern " + << pge.getPattern() << "\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *pattern)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto pge = dyn_cast(op)) { + auto pattern = parsePredicatePatternImmediate(pge.getPattern()); + if (!pattern) { + diagOS << "VPTO LLVM emission failed: unsupported pge_b32 pattern " + << pge.getPattern() << "\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *pattern)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (isa(op)) { + // llvm.hivm.init.vector.align.data() has no operands. + } else if (auto vldas = dyn_cast(op)) { + auto source = requirePointerABIAddress(op, vldas.getSource(), diagOS); + if (failed(source)) + return failure(); + callArgs.push_back(*source); + } else if (auto vldus = dyn_cast(op)) { + auto source = requirePointerABIAddress(op, vldus.getSource(), diagOS); + if (failed(source)) + return failure(); + callArgs.push_back(*source); + callArgs.push_back(vldus.getAlign()); + } else if (auto vstus = dyn_cast(op)) { + Type elementType = getElementTypeFromVectorLike(vstus.getValue().getType()); + auto basePtr = requirePointerABIAddress(op, vstus.getBase(), diagOS); + auto alignValue = materializeAlignABIValue(op, vstus.getAlignIn(), diagOS); + if (!elementType || failed(basePtr)) + return failure(); + auto offsetBytes = convertElementOffsetToBytes(op, vstus.getOffset(), elementType); + if (failed(offsetBytes) || failed(alignValue)) + return failure(); + callArgs.push_back(vstus.getValue()); + callArgs.push_back(*basePtr); + callArgs.push_back(*offsetBytes); + callArgs.push_back(*alignValue); + } else if (auto vstur = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, vstur.getBase(), diagOS); + auto postMode = parsePostModeImmediate(vstur.getMode()); + auto alignValue = materializeAlignABIValue(op, vstur.getAlignIn(), diagOS); + if (failed(basePtr) || !postMode) { + if (!postMode) + diagOS << "VPTO LLVM emission failed: unsupported vstur mode " + << vstur.getMode() << "\n"; + return failure(); + } + if (failed(alignValue)) + return failure(); + callArgs.push_back(vstur.getValue()); + callArgs.push_back(*basePtr); + callArgs.push_back(*alignValue); + callArgs.push_back(getI32Constant(builder, loc, *postMode)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto vlds = dyn_cast(op)) { + Type elementType = getElementTypeFromVectorLike(vlds.getResult().getType()); + auto offsetBytes = convertElementOffsetToBytes( + op, op->getOperand(1), elementType); + auto basePtr = requirePointerABIAddress(op, op->getOperand(0), diagOS); + auto dist = + parseLoadDistImmediate(vlds.getDist().value_or("NORM"), elementType); + if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) { + if (elementType && succeeded(basePtr) && !dist) + diagOS << "VPTO LLVM emission failed: unsupported vlds dist immediate\n"; + return failure(); + } + callArgs.push_back(*basePtr); + callArgs.push_back(*offsetBytes); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto vldsPost = dyn_cast(op)) { + Type elementType = getElementTypeFromVectorLike(vldsPost.getResult().getType()); + auto offsetBytes = convertElementOffsetToBytes( + op, vldsPost.getOffset(), elementType); + auto basePtr = requirePointerABIAddress(op, vldsPost.getSource(), diagOS); + auto dist = + parseLoadDistImmediate(vldsPost.getDist().value_or("NORM"), elementType); + if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) + return failure(); + callArgs.push_back(*basePtr); + callArgs.push_back(*offsetBytes); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 1)); + } else if (auto vabs = dyn_cast(op)) { + Value input = op->getOperand(0); + Value mask = op->getOperand(1); + Type vecType = loweredResultTypes.front(); + Type maskType = convertVPTOType(mask.getType(), builder); + if (input.getType() != vecType || mask.getType() != maskType) { + diagOS << "VPTO LLVM emission failed: unexpected vabs operand types\n"; + return failure(); + } + callArgs.push_back(input); + callArgs.push_back(mask); + } else if (auto unary = dyn_cast(op)) { + Value input = unary.getInput(); + Value mask = unary.getMask(); + Type vecType = loweredResultTypes.front(); + Type maskType = convertVPTOType(mask.getType(), builder); + if (input.getType() != vecType || mask.getType() != maskType) { + diagOS << "VPTO LLVM emission failed: unexpected " + << op->getName().getStringRef() << " operand types\n"; + return failure(); + } + callArgs.push_back(input); + callArgs.push_back(mask); + } else if (auto unary = dyn_cast(op)) { + Value input = unary.getInput(); + Value mask = unary.getMask(); + Type vecType = loweredResultTypes.front(); + Type maskType = convertVPTOType(mask.getType(), builder); + if (input.getType() != vecType || mask.getType() != maskType) { + diagOS << "VPTO LLVM emission failed: unexpected " + << op->getName().getStringRef() << " operand types\n"; + return failure(); + } + callArgs.push_back(input); + callArgs.push_back(mask); + } else if (auto unary = dyn_cast(op)) { + Value input = unary.getInput(); + Value mask = unary.getMask(); + Type vecType = loweredResultTypes.front(); + Type maskType = convertVPTOType(mask.getType(), builder); + if (input.getType() != vecType || mask.getType() != maskType) { + diagOS << "VPTO LLVM emission failed: unexpected " + << op->getName().getStringRef() << " operand types\n"; + return failure(); + } + callArgs.push_back(input); + callArgs.push_back(mask); + } else if (auto unary = dyn_cast(op)) { + Value input = unary.getInput(); + Value mask = unary.getMask(); + Type vecType = loweredResultTypes.front(); + Type maskType = convertVPTOType(mask.getType(), builder); + if (input.getType() != vecType || mask.getType() != maskType) { + diagOS << "VPTO LLVM emission failed: unexpected " + << op->getName().getStringRef() << " operand types\n"; + return failure(); + } + callArgs.push_back(input); + callArgs.push_back(mask); + } else if (auto unary = dyn_cast(op)) { + Value input = unary.getInput(); + Value mask = unary.getMask(); + Type vecType = loweredResultTypes.front(); + Type maskType = convertVPTOType(mask.getType(), builder); + if (input.getType() != vecType || mask.getType() != maskType) { + diagOS << "VPTO LLVM emission failed: unexpected " + << op->getName().getStringRef() << " operand types\n"; + return failure(); + } + callArgs.push_back(input); + callArgs.push_back(mask); + } else if (auto unary = dyn_cast(op)) { + Value input = unary.getInput(); + Value mask = unary.getMask(); + Type vecType = loweredResultTypes.front(); + Type maskType = convertVPTOType(mask.getType(), builder); + if (input.getType() != vecType || mask.getType() != maskType) { + diagOS << "VPTO LLVM emission failed: unexpected " + << op->getName().getStringRef() << " operand types\n"; + return failure(); + } + callArgs.push_back(input); + callArgs.push_back(mask); + } else if (auto vdup = dyn_cast(op)) { + Type scalarType = getElementTypeFromVectorLike(vdup.getResult().getType()); + bool vectorInput = isa(vdup.getInput().getType()); + if (!vectorInput && (!scalarType || vdup.getInput().getType() != scalarType)) { + diagOS << "VPTO LLVM emission failed: unexpected vdup operand types\n"; + return failure(); + } + if (vectorInput && vdup.getInput().getType() != loweredResultTypes.front()) { + diagOS << "VPTO LLVM emission failed: vector-input vdup requires matching result type\n"; + return failure(); + } + if (vectorInput) { + callArgs.push_back(vdup.getInput()); + } else { + FailureOr normalizedScalar = normalizeVdupScalarOperand(builder, loc, vdup); + if (failed(normalizedScalar)) + return failure(); + callArgs.push_back(*normalizedScalar); + } + callArgs.push_back(vdup.getMask()); + callArgs.push_back(getI32Constant(builder, loc, 1)); + } else if (isa(op)) { + callArgs.append(op->operand_begin(), op->operand_end()); + } else if (isa(op)) { + callArgs.push_back(op->getOperand(0)); + callArgs.push_back(op->getOperand(1)); + callArgs.push_back(op->getOperand(2)); + } else if (isa(op)) { + callArgs.push_back(op->getOperand(0)); + callArgs.push_back(op->getOperand(1)); + callArgs.push_back(op->getOperand(2)); + callArgs.push_back(op->getOperand(3)); + } else if (auto vmula = dyn_cast(op)) { + callArgs.push_back(vmula.getAcc()); + callArgs.push_back(vmula.getLhs()); + callArgs.push_back(vmula.getRhs()); + callArgs.push_back(vmula.getMask()); + } else if (auto vmull = dyn_cast(op)) { + callArgs.push_back(vmull.getLhs()); + callArgs.push_back(vmull.getRhs()); + callArgs.push_back(vmull.getMask()); + } else if (auto vaxpy = dyn_cast(op)) { + auto laneCount = getElementCountFromVectorLike(vaxpy.getResult().getType()); + if (!laneCount) { + diagOS << "VPTO LLVM emission failed: could not determine lane count for " + << op->getName().getStringRef() << "\n"; + return failure(); + } + Type elemType = getElementTypeFromVectorLike(vaxpy.getResult().getType()); + Value mask; + if (elemType.isF32()) { + auto fullMask = buildPltB32Mask(builder, module, loc, *laneCount, diagOS); + if (failed(fullMask)) + return failure(); + mask = *fullMask; + } else { + auto fullMask = buildPltB16Mask(builder, module, loc, *laneCount, diagOS); + if (failed(fullMask)) + return failure(); + mask = *fullMask; + } + // Installed wrapper surface is dst = alpha * src0 + dst. VPTO models this + // as a pure op returning the updated addend vector. + callArgs.push_back(vaxpy.getSrc1()); + callArgs.push_back(vaxpy.getSrc0()); + callArgs.push_back(vaxpy.getAlpha()); + callArgs.push_back(mask); + } else if (auto vci = dyn_cast(op)) { + auto orderAttr = op->getAttrOfType("order"); + auto order = parseOrderImmediate(orderAttr ? orderAttr.getValue() : StringRef("ASC")); + if (!order) { + diagOS << "VPTO LLVM emission failed: unsupported vci order "; + if (orderAttr) + diagOS << orderAttr.getValue(); + else + diagOS << ""; + diagOS << "\n"; + return failure(); + } + callArgs.push_back(vci.getIndex()); + callArgs.push_back(getI32Constant(builder, loc, *order)); + } else if (isa(op)) { + callArgs.append(op->operand_begin(), op->operand_end()); + auto laneCount = getElementCountFromVectorLike(op->getResult(0).getType()); + if (!laneCount) { + diagOS << "VPTO LLVM emission failed: could not determine lane count for " + << op->getName().getStringRef() << "\n"; + return failure(); + } + Value mask; + if (getElementTypeFromVectorLike(op->getResult(0).getType()).isF32() || + getElementTypeFromVectorLike(op->getResult(0).getType()).isInteger(32)) { + auto fullMask = buildPltB32Mask(builder, module, loc, *laneCount, diagOS); + if (failed(fullMask)) + return failure(); + mask = *fullMask; + } else { + auto fullMask = buildPltB16Mask(builder, module, loc, *laneCount, diagOS); + if (failed(fullMask)) + return failure(); + mask = *fullMask; + } + callArgs.push_back(mask); + } else if (auto vexpdiff = dyn_cast(op)) { + callArgs.push_back(vexpdiff.getInput()); + callArgs.push_back(vexpdiff.getMax()); + auto srcLaneCount = getElementCountFromVectorLike(vexpdiff.getInput().getType()); + if (!srcLaneCount) { + diagOS << "VPTO LLVM emission failed: could not determine lane count for " + << op->getName().getStringRef() << "\n"; + return failure(); + } + Value mask; + Type inputElemType = getElementTypeFromVectorLike(vexpdiff.getInput().getType()); + if (inputElemType.isF32() || inputElemType.isInteger(32)) { + auto fullMask = buildPltB32Mask(builder, module, loc, *srcLaneCount, diagOS); + if (failed(fullMask)) + return failure(); + mask = *fullMask; + } else { + auto fullMask = buildPltB16Mask(builder, module, loc, *srcLaneCount, diagOS); + if (failed(fullMask)) + return failure(); + mask = *fullMask; + } + auto part = parsePartImmediate(vexpdiff.getPart()); + if (!part) { + diagOS << "VPTO LLVM emission failed: unsupported vexpdiff part "; + diagOS << vexpdiff.getPart(); + diagOS << "\n"; + return failure(); + } + callArgs.push_back(mask); + callArgs.push_back(getI32Constant(builder, loc, *part)); + } else if (isa(op)) { + callArgs.append(op->operand_begin(), op->operand_end()); + } else if (isa(op)) { + callArgs.push_back(op->getOperand(0)); + callArgs.push_back(op->getOperand(1)); + } else if (auto vtrc = dyn_cast(op)) { + auto roundMode = parseRoundModeImmediate(vtrc.getRoundMode()); + if (!roundMode) { + diagOS << "VPTO LLVM emission failed: unsupported round mode " + << vtrc.getRoundMode() << "\n"; + return failure(); + } + auto laneCount = getElementCountFromVectorLike(vtrc.getResult().getType()); + if (!laneCount) { + diagOS << "VPTO LLVM emission failed: could not determine lane count for " + << op->getName().getStringRef() << "\n"; + return failure(); + } + auto mask = buildPltB32Mask(builder, module, loc, *laneCount, diagOS); + if (failed(mask)) + return failure(); + callArgs.push_back(vtrc.getInput()); + callArgs.push_back(getI32Constant(builder, loc, *roundMode)); + callArgs.push_back(*mask); + } else if (auto vcvt = dyn_cast(op)) { + Type inputElemType = getElementTypeFromVectorLike(vcvt.getInput().getType()); + Type resultElemType = getElementTypeFromVectorLike(vcvt.getResult().getType()); + auto inputLanes = getElementCountFromVectorLike(vcvt.getInput().getType()); + if (!inputElemType || !resultElemType || !inputLanes) { + diagOS << "VPTO LLVM emission failed: could not determine vcvt type shape\n"; + return failure(); + } + + auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), + classifyVcvtElemType(resultElemType)); + if (!contract) { + diagOS << "VPTO LLVM emission failed: unsupported vcvt type pair " + << vcvt.getInput().getType() << " -> " << vcvt.getResult().getType() + << "\n"; + return failure(); + } + + callArgs.push_back(vcvt.getInput()); + FailureOr mask = failure(); + switch (contract->maskBitWidth) { + case 8: + mask = buildPltB8Mask(builder, module, loc, *inputLanes, diagOS); + break; + case 16: + mask = buildPltB16Mask(builder, module, loc, *inputLanes, diagOS); + break; + case 32: + mask = buildPltB32Mask(builder, module, loc, *inputLanes, diagOS); + break; + default: + diagOS << "VPTO LLVM emission failed: unsupported vcvt mask width " + << contract->maskBitWidth << "\n"; + return failure(); + } + if (failed(mask)) + return failure(); + callArgs.push_back(*mask); + + if (contract->requiresRnd) { + auto roundMode = vcvt.getRndAttr() + ? parseRoundModeImmediate(*vcvt.getRnd()) + : std::nullopt; + if (!roundMode) { + diagOS << "VPTO LLVM emission failed: vcvt requires valid rnd attr\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *roundMode)); + } + if (contract->requiresSat) { + auto sat = + vcvt.getSatAttr() ? parseSaturationImmediate(*vcvt.getSat()) : std::nullopt; + if (!sat) { + diagOS << "VPTO LLVM emission failed: vcvt requires valid sat attr\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *sat)); + } + if (contract->requiresPart) { + auto part = + vcvt.getPartAttr() ? parsePartImmediate(*vcvt.getPart()) : std::nullopt; + if (!part) { + diagOS << "VPTO LLVM emission failed: vcvt requires valid part attr\n"; + return failure(); + } + callArgs.push_back(getI32Constant(builder, loc, *part)); + } + } else if (auto vstar = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, vstar.getDestination(), diagOS); + auto alignValue = materializeAlignABIValue(op, vstar.getValue(), diagOS); + if (failed(basePtr) || failed(alignValue)) + return failure(); + callArgs.push_back(*alignValue); + callArgs.push_back(*basePtr); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto vstas = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, vstas.getDestination(), diagOS); + auto alignValue = materializeAlignABIValue(op, vstas.getValue(), diagOS); + Type elementType = getElementTypeFromABIValue(vstas.getDestination()); + if (failed(basePtr) || failed(alignValue) || !elementType) { + diagOS << "VPTO LLVM emission failed: could not materialize vstas ABI " + "inputs; destination type=" + << vstas.getDestination().getType() << ", element type=" + << (elementType ? elementType : Type()) << "\n"; + return failure(); + } + auto offsetBytes = convertElementOffsetToBytes(op, vstas.getOffset(), elementType); + if (failed(offsetBytes)) { + diagOS << "VPTO LLVM emission failed: could not materialize vstas byte " + "offset from " + << vstas.getOffset().getType() << " using element type " + << elementType << "\n"; + return failure(); + } + callArgs.push_back(*alignValue); + callArgs.push_back(*basePtr); + callArgs.push_back(*offsetBytes); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto vsqz = dyn_cast(op)) { + callArgs.push_back(vsqz.getInput()); + callArgs.push_back(vsqz.getMask()); + callArgs.push_back( + getI32Constant(builder, loc, determineVsqzStoreHint(vsqz))); + } else if (auto vusqz = dyn_cast(op)) { + callArgs.push_back(vusqz.getSrc()); + callArgs.push_back(vusqz.getMask()); + } else if (auto unpack = dyn_cast(op)) { + Value part = castIntegerLikeTo(op, unpack.getPart(), builder.getI32Type()); + if (!part) { + diagOS << "VPTO LLVM emission failed: could not materialize vsunpack part\n"; + return failure(); + } + callArgs.push_back(unpack.getSrc()); + callArgs.push_back(part); + } else if (auto unpack = dyn_cast(op)) { + Value part = castIntegerLikeTo(op, unpack.getPart(), builder.getI32Type()); + if (!part) { + diagOS << "VPTO LLVM emission failed: could not materialize vzunpack part\n"; + return failure(); + } + callArgs.push_back(unpack.getSrc()); + callArgs.push_back(part); + } else if (auto pack = dyn_cast(op)) { + auto part = parseHiLoPartImmediate(pack.getPart()); + if (!part) { + diagOS << "VPTO LLVM emission failed: unsupported vpack part " + << pack.getPart() << "\n"; + return failure(); + } + callArgs.push_back(pack.getSrc()); + callArgs.push_back(getI32Constant(builder, loc, *part)); + } else if (auto interleave = dyn_cast(op)) { + callArgs.push_back(interleave.getLhs()); + callArgs.push_back(interleave.getRhs()); + } else if (auto deinterleave = dyn_cast(op)) { + callArgs.push_back(deinterleave.getLhs()); + callArgs.push_back(deinterleave.getRhs()); + } else if (auto vldsx2 = dyn_cast(op)) { + Type elementType = getElementTypeFromVectorLike(vldsx2.getLow().getType()); + auto offsetBytes = convertElementOffsetToBytes(op, vldsx2.getOffset(), elementType); + auto basePtr = requirePointerABIAddress(op, vldsx2.getSource(), diagOS); + auto dist = parseLoadX2DistImmediate(vldsx2.getDist(), elementType); + if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) { + if (elementType && succeeded(basePtr) && !dist) + diagOS << "VPTO LLVM emission failed: unsupported vldsx2 dist immediate\n"; + return failure(); + } + callArgs.push_back(*basePtr); + callArgs.push_back(*offsetBytes); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto vsldb = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, vsldb.getSource(), diagOS); + Value packedStride = packBlockRepeatStride( + op, vsldb.getBlockStride(), vsldb.getRepeatStride()); + if (failed(basePtr) || !packedStride) { + if (succeeded(basePtr) && !packedStride) + diagOS << "VPTO LLVM emission failed: could not pack vsldb control word\n"; + return failure(); + } + callArgs.push_back(*basePtr); + callArgs.push_back(packedStride); + callArgs.push_back(getI32Constant(builder, loc, 0)); + callArgs.push_back(vsldb.getMask()); + } else if (auto vsstb = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, vsstb.getDestination(), diagOS); + Value packedStride = packBlockRepeatStride( + op, vsstb.getBlockStride(), vsstb.getRepeatStride()); + if (failed(basePtr) || !packedStride) { + if (succeeded(basePtr) && !packedStride) + diagOS << "VPTO LLVM emission failed: could not pack vsstb control word\n"; + return failure(); + } + callArgs.push_back(vsstb.getValue()); + callArgs.push_back(*basePtr); + callArgs.push_back(packedStride); + callArgs.push_back(getI32Constant(builder, loc, 0)); + callArgs.push_back(vsstb.getMask()); + } else if (auto vstx2 = dyn_cast(op)) { + Type elementType = getElementTypeFromVectorLike(vstx2.getLow().getType()); + auto offsetBytes = + convertElementOffsetToBytes(op, vstx2.getOffset(), elementType); + auto basePtr = + requirePointerABIAddress(op, vstx2.getDestination(), diagOS); + auto dist = parseStoreX2DistImmediate(vstx2.getDist(), elementType); + if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) { + if (elementType && succeeded(basePtr) && !dist) + diagOS + << "VPTO LLVM emission failed: unsupported vstsx2 dist immediate\n"; + return failure(); + } + Value offsetI32 = castIntegerLikeTo(op, *offsetBytes, builder.getI32Type()); + if (!offsetI32) + return failure(); + callArgs.push_back(vstx2.getLow()); + callArgs.push_back(vstx2.getHigh()); + callArgs.push_back(*basePtr); + callArgs.push_back(offsetI32); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + callArgs.push_back(vstx2.getMask()); + } else if (auto vsts = dyn_cast(op)) { + Type elementType = getElementTypeFromVectorLike(vsts.getValue().getType()); + auto offsetBytes = convertElementOffsetToBytes( + op, op->getOperand(2), elementType); + auto basePtr = requirePointerABIAddress(op, op->getOperand(1), diagOS); + auto dist = + parseStoreDistImmediate(vsts.getDist().value_or("NORM"), elementType); + if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) { + if (elementType && succeeded(basePtr) && !dist) + diagOS << "VPTO LLVM emission failed: unsupported vsts dist immediate\n"; + return failure(); + } + callArgs.push_back(op->getOperand(0)); + callArgs.push_back(*basePtr); + callArgs.push_back(*offsetBytes); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + callArgs.push_back(op->getOperand(3)); + } else if (auto vstsPost = dyn_cast(op)) { + Type elementType = getElementTypeFromVectorLike(vstsPost.getValue().getType()); + auto offsetBytes = convertElementOffsetToBytes(op, vstsPost.getOffset(), elementType); + auto basePtr = requirePointerABIAddress(op, vstsPost.getDestination(), diagOS); + auto dist = parseStoreDistImmediate(vstsPost.getDist().value_or("NORM"), + elementType); + if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) + return failure(); + callArgs.push_back(vstsPost.getValue()); + callArgs.push_back(*basePtr); + callArgs.push_back(*offsetBytes); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 1)); + callArgs.push_back(vstsPost.getMask()); + } else if (auto ppack = dyn_cast(op)) { + auto part = parseHiLoPartImmediate(ppack.getPart()); + if (!part) { + diagOS << "VPTO LLVM emission failed: unsupported ppack part " + << ppack.getPart() << "\n"; + return failure(); + } + callArgs.push_back(ppack.getInput()); + callArgs.push_back(getI32Constant(builder, loc, *part)); + } else if (auto punpack = dyn_cast(op)) { + auto part = parseHiLoPartImmediate(punpack.getPart()); + if (!part) { + diagOS << "VPTO LLVM emission failed: unsupported punpack part " + << punpack.getPart() << "\n"; + return failure(); + } + callArgs.push_back(punpack.getInput()); + callArgs.push_back(getI32Constant(builder, loc, *part)); + } else if (auto vselr = dyn_cast(op)) { + auto resultVecType = dyn_cast(loweredResultTypes.front()); + if (!resultVecType) { + diagOS << "VPTO LLVM emission failed: unexpected vselr result type\n"; + return failure(); + } + Type intrinsicVecType = resultVecType; + if (auto resultFloat = dyn_cast(resultVecType.getElementType()); + resultFloat && resultFloat.isF32()) { + intrinsicVecType = + VectorType::get(resultVecType.getShape(), builder.getI32Type(), + resultVecType.getScalableDims()); + } + intrinsicResultTypes[0] = intrinsicVecType; + callArgs.push_back(buildBridgeCast(builder, loc, vselr.getSrc0(), intrinsicVecType)); + callArgs.push_back(vselr.getSrc1()); + } else if (isa(op)) { + callArgs.append(op->operand_begin(), op->operand_end()); + } else if (auto plds = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, plds.getSource(), diagOS); + Value offset = castIntegerLikeTo(op, plds.getOffset(), builder.getI32Type()); + auto dist = parsePredicateLoadDistImmediate(plds.getDist()); + if (failed(basePtr) || !offset || !dist) { + if (succeeded(basePtr) && offset && !dist) + diagOS << "VPTO LLVM emission failed: unsupported plds dist immediate\n"; + return failure(); + } + callArgs.push_back(*basePtr); + callArgs.push_back(offset); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto pldi = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, pldi.getSource(), diagOS); + Value offset = castIntegerLikeTo(op, pldi.getOffset(), builder.getI32Type()); + auto dist = parsePredicateLoadDistImmediate(pldi.getDist()); + if (failed(basePtr) || !offset || !dist) { + if (succeeded(basePtr) && offset && !dist) + diagOS << "VPTO LLVM emission failed: unsupported pldi dist immediate\n"; + return failure(); + } + callArgs.push_back(*basePtr); + callArgs.push_back(offset); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto psts = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, psts.getDestination(), diagOS); + Value offset = castIntegerLikeTo(op, psts.getOffset(), builder.getI32Type()); + auto dist = parsePredicateStoreDistImmediate(psts.getDist()); + if (failed(basePtr) || !offset || !dist) { + if (succeeded(basePtr) && offset && !dist) + diagOS << "VPTO LLVM emission failed: unsupported psts dist immediate\n"; + return failure(); + } + callArgs.push_back(psts.getValue()); + callArgs.push_back(*basePtr); + callArgs.push_back(offset); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (op->getName().getStringRef() == "pto.pstu") { + auto basePtr = requirePointerABIAddress(op, op->getOperand(2), diagOS); + auto alignValue = materializeAlignABIValue(op, op->getOperand(0), diagOS); + if (failed(basePtr) || failed(alignValue)) + return failure(); + callArgs.push_back(op->getOperand(1)); + callArgs.push_back(*basePtr); + callArgs.push_back(*alignValue); + } else if (auto pstu = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, pstu.getBase(), diagOS); + auto alignValue = materializeAlignABIValue(op, pstu.getAlignIn(), diagOS); + if (failed(basePtr) || failed(alignValue)) + return failure(); + callArgs.push_back(pstu.getValue()); + callArgs.push_back(*basePtr); + callArgs.push_back(*alignValue); + } else if (auto psti = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, psti.getDestination(), diagOS); + Value offset = castIntegerLikeTo(op, psti.getOffset(), builder.getI32Type()); + auto dist = parsePredicateStoreDistImmediate(psti.getDist()); + if (failed(basePtr) || !offset || !dist) { + if (succeeded(basePtr) && offset && !dist) + diagOS << "VPTO LLVM emission failed: unsupported psti dist immediate\n"; + return failure(); + } + callArgs.push_back(psti.getValue()); + callArgs.push_back(*basePtr); + callArgs.push_back(offset); + callArgs.push_back(getI32Constant(builder, loc, *dist)); + callArgs.push_back(getI32Constant(builder, loc, 0)); + } else if (auto gather = dyn_cast(op)) { + Type resultElemType = getElementTypeFromVectorLike(gather.getResult().getType()); + auto basePtr = requirePointerABIAddress(op, gather.getSource(), diagOS); + auto mask = buildDynamicPltMask(builder, module, loc, gather.getActiveLanes(), + resultElemType, diagOS); + if (!resultElemType || failed(basePtr) || failed(mask)) + return failure(); + callArgs.push_back(*basePtr); + callArgs.push_back(gather.getOffsets()); + callArgs.push_back(*mask); + } else if (auto gather = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, gather.getSource(), diagOS); + if (failed(basePtr)) + return failure(); + callArgs.push_back(*basePtr); + callArgs.push_back(gather.getOffsets()); + callArgs.push_back(gather.getMask()); + } else if (auto gather = dyn_cast(op)) { + auto basePtr = requirePointerABIAddress(op, gather.getSource(), diagOS); + if (failed(basePtr)) + return failure(); + callArgs.push_back(*basePtr); + callArgs.push_back(gather.getOffsets()); + callArgs.push_back(gather.getMask()); + } else if (auto vbitsort = dyn_cast(op)) { + auto destination = requirePointerABIAddress(op, vbitsort.getDestination(), diagOS); + auto source = requirePointerABIAddress(op, vbitsort.getSource(), diagOS); + auto indices = requirePointerABIAddress(op, vbitsort.getIndices(), diagOS); + auto config = packVbitsortConfig(op, vbitsort.getRepeatTimes()); + if (failed(destination) || failed(source) || failed(indices) || failed(config)) + return failure(); + callArgs.push_back(*destination); + callArgs.push_back(*source); + callArgs.push_back(*indices); + callArgs.push_back(*config); + } else if (auto scatter = dyn_cast(op)) { + Type valueElemType = getElementTypeFromVectorLike(scatter.getValue().getType()); + auto basePtr = requirePointerABIAddress(op, scatter.getDestination(), diagOS); + auto mask = buildDynamicPltMask(builder, module, loc, scatter.getActiveLanes(), + valueElemType, diagOS); + if (!valueElemType || failed(basePtr) || failed(mask)) + return failure(); + callArgs.push_back(scatter.getValue()); + callArgs.push_back(*basePtr); + callArgs.push_back(scatter.getOffsets()); + callArgs.push_back(*mask); + } else { + diagOS << "VPTO LLVM emission failed: op lowering is not implemented for " + << op->getName().getStringRef() << "\n"; + return failure(); + } + + SmallVector argTypes; + for (Value arg : callArgs) + argTypes.push_back(arg.getType()); + + auto funcType = builder.getFunctionType(argTypes, intrinsicResultTypes); + auto callee = getOrCreateExternalFunc(module, *calleeName, funcType); + auto call = builder.create(loc, callee, callArgs); + if (op->getNumResults() == 0) + builder.eraseOp(op); + else { + SmallVector finalResults; + finalResults.reserve(op->getNumResults()); + for (auto [idx, result] : + llvm::enumerate(call.getResults().take_front(op->getNumResults()))) { + Type surfaceType = surfaceResultTypes[idx]; + if (isa(surfaceType)) { + diagOS << "VPTO LLVM emission failed: unexpected LLVM pointer surface " + "result type on op "; + op->print(diagOS); + diagOS << "\n"; + return failure(); + } + if (isa(surfaceType)) { + finalResults.push_back(buildBridgeCast(builder, loc, result, surfaceType)); + continue; + } + finalResults.push_back(result); + } + builder.replaceOp(op, finalResults); + } + return success(); +} + +static LogicalResult rewriteVPTOOps(ModuleOp module, llvm::raw_ostream &diagOS) { + SmallVector opsToRewrite; + module.walk([&](Operation *op) { + if (op->getName().getDialectNamespace() != "pto") + return; + if (isa(op)) + return; + opsToRewrite.push_back(op); + }); + + for (Operation *op : opsToRewrite) { + if (failed(rewriteVPTOOp(op, module, diagOS))) + return failure(); + } + + bool hasVPTO = false; + module.walk([&](Operation *op) { + if (op->getName().getDialectNamespace() != "pto") + return; + if (isa(op)) + return; + hasVPTO = true; + }); + + SmallVector poisonOps; + module.walk([&](Operation *op) { + auto name = op->getName().getStringRef(); + if (name == "ub.poison" && + op->getNumResults() == 1 && + isa(op->getResult(0).getType())) + poisonOps.push_back(op); + }); + for (Operation *op : poisonOps) { + OpBuilder builder(op); + auto abiType = cast(convertVPTOType(op->getResult(0).getType(), builder)); + auto zeroAttr = DenseElementsAttr::get(abiType, builder.getI8IntegerAttr(0)); + auto zero = builder.create(op->getLoc(), abiType, zeroAttr); + op->getResult(0).replaceAllUsesWith(zero.getResult()); + op->erase(); + } + + return success(!hasVPTO); +} + +static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { + type = convertVPTOType(type, builder); + + if (auto memrefType = dyn_cast(type)) { + auto addrAttr = + dyn_cast_or_null(memrefType.getMemorySpace()); + if (!addrAttr) + return type; + unsigned addrSpace = getExternalPointerAddressSpace(memrefType); + return MemRefType::get(memrefType.getShape(), memrefType.getElementType(), + memrefType.getLayout(), + builder.getI64IntegerAttr(addrSpace)); + } + + if (auto memrefType = dyn_cast(type)) { + auto addrAttr = + dyn_cast_or_null(memrefType.getMemorySpace()); + if (!addrAttr) + return type; + // Official MemRef-to-LLVM conversion requires integer memory spaces. + return UnrankedMemRefType::get(memrefType.getElementType(), + builder.getI64IntegerAttr( + static_cast(AddressSpace::GM))); + } + + return type; +} + +static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { + Builder builder(module.getContext()); + + for (func::FuncOp funcOp : module.getOps()) { + FunctionType oldType = funcOp.getFunctionType(); + SmallVector newInputs; + SmallVector newResults; + bool changed = false; + + newInputs.reserve(oldType.getNumInputs()); + for (Type input : oldType.getInputs()) { + Type normalized = normalizeTypeForOfficialLLVMLowering(input, builder); + changed |= (normalized != input); + newInputs.push_back(normalized); + } + + newResults.reserve(oldType.getNumResults()); + for (Type result : oldType.getResults()) { + Type normalized = normalizeTypeForOfficialLLVMLowering(result, builder); + changed |= (normalized != result); + newResults.push_back(normalized); + } + + if (!changed) + continue; + + auto newType = builder.getFunctionType(newInputs, newResults); + funcOp.setFunctionTypeAttr(TypeAttr::get(newType)); + + if (funcOp.isExternal()) + continue; + Block &entry = funcOp.getBody().front(); + for (auto [arg, newType] : llvm::zip(entry.getArguments(), newInputs)) + if (arg.getType() != newType) + arg.setType(newType); + } +} + +static void ensureAIVScopeDummyDecl(ModuleOp module) { + SymbolTable symbolTable(module); + if (symbolTable.lookup(kAIVScopeDummyCallee)) + return; + + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(module.getBody()); + auto funcType = builder.getFunctionType(TypeRange{}, TypeRange{}); + auto dummy = builder.create(module.getLoc(), + kAIVScopeDummyCallee, funcType); + dummy.setPrivate(); +} + +static void materializeVecScopeCarrierLoops(ModuleOp module) { + MLIRContext *ctx = module.getContext(); + (void)ctx->getOrLoadDialect(); + (void)ctx->getOrLoadDialect(); + ensureAIVScopeDummyDecl(module); + + SmallVector scopes; + module.walk([&](pto::VecScopeOp vecScope) { scopes.push_back(vecScope); }); + + IRRewriter rewriter(module.getContext()); + for (pto::VecScopeOp vecScope : llvm::reverse(scopes)) { + if (!vecScope || vecScope.getBody().empty()) + continue; + + rewriter.setInsertionPoint(vecScope); + auto loc = vecScope.getLoc(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); + + Block &vecScopeBody = vecScope.getBody().front(); + Block *carrierBody = carrier.getBody(); + Operation *yield = carrierBody->getTerminator(); + carrierBody->getOperations().splice(Block::iterator(yield), + vecScopeBody.getOperations(), + vecScopeBody.begin(), + vecScopeBody.end()); + rewriter.setInsertionPoint(yield); + rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, + ValueRange{}); + rewriter.eraseOp(vecScope); + } + + SmallVector strictScopes; + module.walk( + [&](pto::StrictVecScopeOp strictVecScope) { strictScopes.push_back(strictVecScope); }); + + for (pto::StrictVecScopeOp strictVecScope : llvm::reverse(strictScopes)) { + if (!strictVecScope || strictVecScope.getBody().empty()) + continue; + + rewriter.setInsertionPoint(strictVecScope); + auto loc = strictVecScope.getLoc(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); + + Block &strictBody = strictVecScope.getBody().front(); + Block *carrierBody = carrier.getBody(); + Operation *yield = carrierBody->getTerminator(); + + IRMapping mapping; + for (auto [blockArg, capture] : + llvm::zip(strictBody.getArguments(), strictVecScope.getCaptures())) + mapping.map(blockArg, capture); + + rewriter.setInsertionPoint(yield); + for (Operation &nested : strictBody.getOperations()) + rewriter.clone(nested, mapping); + rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, + ValueRange{}); + + rewriter.eraseOp(strictVecScope); + } +} + +static bool satisfiesAIVectorScopeLatchPostcondition(llvm::Loop *loop) { + llvm::BasicBlock *latch = loop->getLoopLatch(); + if (!latch) + return false; + + llvm::SmallVector preds(llvm::predecessors(latch)); + if (preds.size() != 1) + return false; + + auto *predTerm = preds.front()->getTerminator(); + return predTerm && predTerm->getNumSuccessors() == 1 && + predTerm->getSuccessor(0) == latch; +} + +// Bisheng imposes a strict CFG contract on loops carrying +// `llvm.loop.aivector_scope` metadata: +// 1. the latch must have exactly one predecessor +// 2. that predecessor must have exactly one successor, namely the latch +// +// The generic SCF/LLVM lowering pipeline does not preserve this shape for us. +// Therefore VPTO LLVM emission treats this as a required postcondition instead +// of a best-effort cleanup: +// - if the loop already satisfies the contract, keep it as-is +// - otherwise normalize all latch predecessors through a dummy block +// - if normalization still cannot re-establish the contract, fail the export +// +// Failing loudly here is intentional. Silently attaching aivscope metadata to +// an unsupported latch shape only defers the problem into Bisheng as a backend +// crash, which makes future regressions harder to diagnose. +static LogicalResult ensureDummyPredForAIVectorScopeLatch(llvm::Loop *loop, + llvm::raw_ostream &diagOS) { + if (satisfiesAIVectorScopeLatchPostcondition(loop)) + return success(); + + llvm::BasicBlock *latch = loop->getLoopLatch(); + if (!latch) { + diagOS << "VPTO LLVM emission failed: aivscope loop is missing a latch\n"; + return failure(); + } + + llvm::SmallVector preds(llvm::predecessors(latch)); + if (preds.empty()) { + diagOS << "VPTO LLVM emission failed: aivscope latch has no predecessor\n"; + return failure(); + } + + auto *dummy = llvm::SplitBlockPredecessors( + latch, preds, "aivscope.dummy", static_cast(nullptr), + static_cast(nullptr), nullptr, /*PreserveLCSSA=*/false); + if (!dummy) { + diagOS << "VPTO LLVM emission failed: failed to normalize aivscope latch " + "predecessors\n"; + return failure(); + } + + if (!satisfiesAIVectorScopeLatchPostcondition(loop)) { + diagOS << "VPTO LLVM emission failed: normalized aivscope latch still does " + "not satisfy the single-predecessor/single-successor contract\n"; + return failure(); + } + return success(); +} + +static LogicalResult attachAIVectorScopeMetadata( + llvm::Module &llvmModule, llvm::raw_ostream &diagOS) { + llvm::Function *dummyCallee = llvmModule.getFunction(kAIVScopeDummyCallee); + if (!dummyCallee) + return success(); + + for (llvm::Function &function : llvmModule) { + if (function.isDeclaration()) + continue; + llvm::DominatorTree dt(function); + llvm::LoopInfo loopInfo(dt); + + // Stage 1: collect the lowered vecscope markers in this function. Each + // marker should end up inside the final LLVM loop that carries one + // `pto.vecscope` / `pto.strict_vecscope`. + llvm::SmallVector dummyCalls; + for (llvm::BasicBlock &block : function) { + for (llvm::Instruction &inst : block) { + auto *call = dyn_cast(&inst); + if (call && call->getCalledFunction() == dummyCallee) + dummyCalls.push_back(call); + } + } + + for (llvm::CallInst *dummyCall : dummyCalls) { + llvm::BasicBlock *markedBlock = dummyCall->getParent(); + llvm::Loop *loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " + << function.getName() << " does not belong to an LLVM loop\n"; + return failure(); + } + + // Stage 2: if the marker ended up in the loop latch, split the block so + // the eventual latch stays as a clean backedge block instead of carrying + // vector-thread side effects. + if (markedBlock == loop->getLoopLatch() && + dummyCall != markedBlock->getTerminator()) { + markedBlock->splitBasicBlock(dummyCall->getIterator(), "aivscope.latch"); + dt.recalculate(function); + loopInfo.releaseMemory(); + loopInfo.analyze(dt); + markedBlock = dummyCall->getParent(); + loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: split aivscope latch in " + << function.getName() + << " no longer belongs to an LLVM loop\n"; + return failure(); + } + } + + if (failed(ensureDummyPredForAIVectorScopeLatch(loop, diagOS))) + return failure(); + + // Stage 3: after any CFG surgery, re-query the loop and attach + // `llvm.loop.aivector_scope` to the normalized latch backedge. The dummy + // marker has served its purpose by this point and is removed. + dt.recalculate(function); + loopInfo.releaseMemory(); + loopInfo.analyze(dt); + loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " + << function.getName() + << " lost its loop after latch normalization\n"; + return failure(); + } + + llvm::BasicBlock *latch = loop->getLoopLatch(); + auto *branch = dyn_cast_or_null( + latch ? latch->getTerminator() : nullptr); + if (!branch || branch->isConditional()) { + diagOS << "VPTO LLVM emission failed: normalized aivscope loop in " + << function.getName() + << " does not have an unconditional latch backedge\n"; + return failure(); + } + + llvm::LLVMContext &ctx = llvmModule.getContext(); + llvm::Metadata *ops[] = { + nullptr, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, "llvm.loop.aivector_scope"))}; + auto *loopID = llvm::MDNode::getDistinct(ctx, ops); + loopID->replaceOperandWith(0, loopID); + branch->setMetadata(llvm::LLVMContext::MD_loop, loopID); + dummyCall->eraseFromParent(); + } + } + + if (dummyCallee->use_empty()) + dummyCallee->eraseFromParent(); + return success(); +} + +static void attachHIVMKernelAnnotations(llvm::Module &llvmModule) { + llvm::NamedMDNode *annotations = llvmModule.getOrInsertNamedMetadata( + "hivm.annotations"); + llvm::LLVMContext &ctx = llvmModule.getContext(); + llvm::Type *i32Ty = llvm::Type::getInt32Ty(ctx); + llvm::Constant *one = llvm::ConstantInt::get(i32Ty, 1); + + auto addAnnotation = [&](llvm::Function &function, llvm::StringRef kind) { + llvm::Metadata *ops[] = { + llvm::ValueAsMetadata::get(&function), + llvm::MDString::get(ctx, kind), + llvm::ConstantAsMetadata::get(one)}; + annotations->addOperand(llvm::MDNode::get(ctx, ops)); + }; + + for (llvm::Function &function : llvmModule) { + if (function.isDeclaration()) + continue; + if (function.getLinkage() != llvm::GlobalValue::ExternalLinkage) + continue; + + llvm::StringRef name = function.getName(); + if (name.contains(".extracted") || name.contains(".vector.thread")) + continue; + + addAnnotation(function, "kernel"); + addAnnotation(function, "kernel_with_simd"); + } +} + +static FailureOr extractQuotedLLVMFnAttr(llvm::StringRef ir, + llvm::StringRef key) { + std::string pattern = "\""; + pattern += key.str(); + pattern += "\"=\""; + size_t start = ir.find(pattern); + if (start == llvm::StringRef::npos) + return failure(); + start += pattern.size(); + size_t end = ir.find('"', start); + if (end == llvm::StringRef::npos || end <= start) + return failure(); + return ir.slice(start, end).str(); +} + +static FailureOr +queryDefaultTargetAttrs(const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + static llvm::StringMap cache; + + if (options.targetTriple.empty() || options.march.empty() || + options.aicoreArch.empty()) { + diagOS << "VPTO LLVM emission failed: missing target query options\n"; + return failure(); + } + + std::string cacheKey = + options.targetTriple + "|" + options.march + "|" + options.aicoreArch; + if (auto it = cache.find(cacheKey); it != cache.end()) + return it->second; + + auto bisheng = llvm::sys::findProgramByName("bisheng"); + if (!bisheng) { + diagOS << "VPTO LLVM emission failed: unable to find 'bisheng' in PATH\n"; + return failure(); + } + const std::string &bishengPath = *bisheng; + + llvm::SmallString<64> inputPath; + llvm::SmallString<64> outputPath; + int inputFD = -1; + int outputFD = -1; + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "c", inputFD, inputPath)) { + diagOS << "VPTO LLVM emission failed: cannot create bisheng query input: " + << ec.message() << "\n"; + return failure(); + } + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "ll", outputFD, outputPath)) { + llvm::sys::fs::remove(inputPath); + llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); + diagOS << "VPTO LLVM emission failed: cannot create bisheng query output: " + << ec.message() << "\n"; + return failure(); + } + + auto cleanup = llvm::make_scope_exit([&]() { + llvm::sys::fs::remove(inputPath); + llvm::sys::fs::remove(outputPath); + }); + + { + llvm::raw_fd_ostream inputOS(inputFD, /*shouldClose=*/false); + inputOS << "void f(void) {}\n"; + } + llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); + llvm::sys::Process::SafelyCloseFileDescriptor(outputFD); + + llvm::SmallString<128> stderrPath; + int stderrFD = -1; + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "stderr", stderrFD, + stderrPath)) { + diagOS << "VPTO LLVM emission failed: cannot create bisheng query stderr: " + << ec.message() << "\n"; + return failure(); + } + auto stderrCleanup = llvm::make_scope_exit([&]() { + llvm::sys::fs::remove(stderrPath); + }); + llvm::sys::Process::SafelyCloseFileDescriptor(stderrFD); + + llvm::SmallVector argStorage = { + bishengPath, + ("--target=" + options.targetTriple), + ("-march=" + options.march), + ("--cce-aicore-arch=" + options.aicoreArch), + "--cce-aicore-only", + "-x", + "c", + inputPath.str().str(), + "-S", + "-emit-llvm", + "-o", + outputPath.str().str(), + }; + llvm::SmallVector args; + args.reserve(argStorage.size()); + for (const std::string &arg : argStorage) + args.push_back(arg); + + std::string execErr; + bool execFailed = false; + int rc = llvm::sys::ExecuteAndWait( + bishengPath, args, std::nullopt, + {std::nullopt, std::nullopt, llvm::StringRef(stderrPath)}, 0, 0, + &execErr, &execFailed); + + auto stderrBuffer = llvm::MemoryBuffer::getFile(stderrPath); + llvm::StringRef stderrText = + stderrBuffer ? stderrBuffer.get()->getBuffer() : llvm::StringRef(); + + if (execFailed || rc != 0) { + diagOS << "VPTO LLVM emission failed: bisheng target query failed\n"; + diagOS << "Command:"; + for (llvm::StringRef arg : args) + diagOS << " " << arg; + diagOS << "\n"; + if (!execErr.empty()) + diagOS << execErr << "\n"; + if (!stderrText.empty()) + diagOS << stderrText << "\n"; + return failure(); + } + + auto outputBuffer = llvm::MemoryBuffer::getFile(outputPath); + if (!outputBuffer) { + diagOS << "VPTO LLVM emission failed: cannot read bisheng query output\n"; + return failure(); + } + + FailureOr targetCPU = + extractQuotedLLVMFnAttr(outputBuffer.get()->getBuffer(), "target-cpu"); + FailureOr targetFeatures = extractQuotedLLVMFnAttr( + outputBuffer.get()->getBuffer(), "target-features"); + if (failed(targetCPU) || failed(targetFeatures)) { + diagOS << "VPTO LLVM emission failed: cannot parse bisheng target attrs\n"; + diagOS << outputBuffer.get()->getBuffer() << "\n"; + return failure(); + } + + QueriedTargetAttrs attrs{*targetCPU, *targetFeatures}; + cache[cacheKey] = attrs; + return attrs; +} + +static LogicalResult +applyQueriedTargetAttrs(ModuleOp module, const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + FailureOr attrs = queryDefaultTargetAttrs(options, diagOS); + if (failed(attrs)) { + if (options.defaultTargetCPU.empty() || + options.defaultTargetFeatures.empty()) + return failure(); + diagOS << "VPTO LLVM emission: falling back to configured default target attributes\n"; + attrs = QueriedTargetAttrs{options.defaultTargetCPU, + options.defaultTargetFeatures}; + } + + MLIRContext *ctx = module.getContext(); + StringAttr cpuAttr = StringAttr::get(ctx, attrs->targetCPU); + LLVM::TargetFeaturesAttr featureAttr = + LLVM::TargetFeaturesAttr::get(ctx, attrs->targetFeatures); + module.walk([&](LLVM::LLVMFuncOp funcOp) { + funcOp.setTargetCpuAttr(cpuAttr); + funcOp.setTargetFeaturesAttr(featureAttr); + }); + return success(); +} + +static llvm::Value *castABIValue(llvm::IRBuilder<> &builder, llvm::Value *value, + llvm::Type *targetType) { + if (value->getType() == targetType) + return value; + + if (auto *targetPtr = dyn_cast(targetType)) { + auto *sourcePtr = dyn_cast(value->getType()); + if (!sourcePtr) + return nullptr; + if (sourcePtr->getAddressSpace() == targetPtr->getAddressSpace()) + return builder.CreateBitCast(value, targetType); + return builder.CreateAddrSpaceCast(value, targetType); + } + + if (targetType->isIntegerTy()) { + if (value->getType()->isIntegerTy()) { + unsigned srcWidth = value->getType()->getIntegerBitWidth(); + unsigned dstWidth = targetType->getIntegerBitWidth(); + if (srcWidth == dstWidth) + return value; + if (srcWidth < dstWidth) + return builder.CreateZExt(value, targetType); + return builder.CreateTrunc(value, targetType); + } + } + + return nullptr; +} + +static llvm::Value *materializeABIExpr(llvm::IRBuilder<> &builder, + const ABIExpr &expr, + llvm::Function *wrapper, + llvm::Type *targetType) { + switch (expr.kind) { + case ABIExpr::Kind::Constant: + return llvm::ConstantInt::get(targetType, expr.constant); + case ABIExpr::Kind::FuncArg: { + if (expr.argIndex >= wrapper->arg_size()) + return nullptr; + return castABIValue(builder, wrapper->getArg(expr.argIndex), targetType); + } + case ABIExpr::Kind::Mul: { + llvm::Value *lhs = + materializeABIExpr(builder, *expr.lhs, wrapper, targetType); + llvm::Value *rhs = + materializeABIExpr(builder, *expr.rhs, wrapper, targetType); + if (!lhs || !rhs) + return nullptr; + return builder.CreateMul(lhs, rhs); + } + } + return nullptr; +} + +static unsigned getMemRefExpandedArgCount(int64_t rank) { + return 2u + 1u + static_cast(rank) + static_cast(rank); +} + +static llvm::Value *resolveInsertedAggregateValue(llvm::Value *value, + llvm::ArrayRef idxs) { + auto *insert = dyn_cast(value); + if (!insert) + return nullptr; + + if (insert->getIndices() == idxs) + return insert->getInsertedValueOperand(); + + return resolveInsertedAggregateValue(insert->getAggregateOperand(), idxs); +} + +static llvm::Value *resolveAddrSpaceRoundTrip(llvm::Value *value) { + auto *outerCast = dyn_cast(value); + if (!outerCast) + return nullptr; + + auto *innerCast = dyn_cast(outerCast->getPointerOperand()); + if (!innerCast) + return nullptr; + + llvm::Value *original = innerCast->getPointerOperand(); + if (original->getType() != outerCast->getType()) + return nullptr; + + auto *innerDstPtr = dyn_cast(innerCast->getType()); + auto *outerDstPtr = dyn_cast(outerCast->getType()); + auto *origPtr = dyn_cast(original->getType()); + if (!innerDstPtr || !outerDstPtr || !origPtr) + return nullptr; + + if (innerDstPtr->getAddressSpace() == origPtr->getAddressSpace()) + return nullptr; + if (outerDstPtr->getAddressSpace() != origPtr->getAddressSpace()) + return nullptr; + + return original; +} + +static void simplifyAggregateCarrierOps(llvm::Function &function) { + bool changed = true; + while (changed) { + changed = false; + + SmallVector toErase; + for (llvm::BasicBlock &block : function) { + for (llvm::Instruction &inst : block) { + if (auto *cast = dyn_cast(&inst)) { + if (llvm::Value *resolved = resolveAddrSpaceRoundTrip(cast)) { + cast->replaceAllUsesWith(resolved); + toErase.push_back(cast); + changed = true; + continue; + } + } + + if (auto *extract = dyn_cast(&inst)) { + if (llvm::Value *resolved = + resolveInsertedAggregateValue(extract->getAggregateOperand(), + extract->getIndices())) { + extract->replaceAllUsesWith(resolved); + toErase.push_back(extract); + changed = true; + continue; + } + } + + if (llvm::isInstructionTriviallyDead(&inst)) { + toErase.push_back(&inst); + changed = true; + } + } + } + + for (llvm::Instruction *inst : toErase) + if (!inst->isTerminator()) + inst->eraseFromParent(); + } +} + +static LogicalResult rewriteFunctionsToEmitCStyleABI( + llvm::Module &llvmModule, const llvm::StringMap &specs, + llvm::raw_ostream &diagOS) { + SmallVector funcs; + for (llvm::Function &function : llvmModule) + if (!function.isDeclaration()) + funcs.push_back(&function); + + for (llvm::Function *function : funcs) { + auto it = specs.find(function->getName()); + if (it == specs.end()) + continue; + + const FunctionABISpec &spec = it->second; + if (spec.args.empty()) + continue; + + bool needsRewrite = + llvm::any_of(spec.args, [](const ExternalArgABISpec &arg) { + return arg.isMemRef; + }); + if (!needsRewrite) + continue; + + SmallVector publicArgTypes; + SmallVector oldArgBaseIndex(spec.args.size(), 0); + unsigned oldArgCursor = 0; + bool supported = true; + for (auto [idx, argSpec] : llvm::enumerate(spec.args)) { + oldArgBaseIndex[idx] = oldArgCursor; + if (argSpec.isMemRef) { + if (argSpec.memrefSpec.rank != 1) { + supported = false; + break; + } + publicArgTypes.push_back(llvm::PointerType::get( + llvmModule.getContext(), argSpec.memrefSpec.addressSpace)); + oldArgCursor += getMemRefExpandedArgCount(argSpec.memrefSpec.rank); + } else { + if (oldArgCursor >= function->arg_size()) { + supported = false; + break; + } + publicArgTypes.push_back(function->getArg(oldArgCursor)->getType()); + ++oldArgCursor; + } + } + + if (!supported || oldArgCursor != function->arg_size()) { + diagOS << "VPTO LLVM emission warning: skipping ABI rewrite for " + << function->getName() + << " because the lowered signature does not match the seam spec\n"; + continue; + } + + std::string originalName = function->getName().str(); + std::string tempName = "__ptoas_old_" + originalName; + function->setName(tempName); + function->setLinkage(llvm::GlobalValue::InternalLinkage); + + auto *publicType = llvm::FunctionType::get(function->getReturnType(), + publicArgTypes, + function->isVarArg()); + llvm::Function *replacement = llvm::Function::Create( + publicType, llvm::GlobalValue::ExternalLinkage, originalName, &llvmModule); + replacement->copyAttributesFrom(function); + replacement->setLinkage(llvm::GlobalValue::ExternalLinkage); + + unsigned publicArgIndex = 0; + for (llvm::Argument &arg : replacement->args()) + arg.setName("arg" + std::to_string(publicArgIndex++)); + + llvm::BasicBlock *bridgeEntry = llvm::BasicBlock::Create( + llvmModule.getContext(), "entry", replacement); + llvm::IRBuilder<> builder(bridgeEntry); + + llvm::ValueToValueMapTy vmap; + for (auto [idx, argSpec] : llvm::enumerate(spec.args)) { + llvm::Value *publicArg = replacement->getArg(idx); + unsigned oldBase = oldArgBaseIndex[idx]; + if (!argSpec.isMemRef) { + llvm::Value *casted = castABIValue( + builder, publicArg, function->getArg(oldBase)->getType()); + if (!casted) { + diagOS << "VPTO LLVM emission failed: cannot cast scalar arg for " + << originalName << "\n"; + return failure(); + } + vmap[function->getArg(oldBase)] = casted; + continue; + } + + llvm::Type *oldPtrTy = function->getArg(oldBase)->getType(); + llvm::Type *oldAlignedPtrTy = function->getArg(oldBase + 1)->getType(); + llvm::Type *oldOffsetTy = function->getArg(oldBase + 2)->getType(); + llvm::Type *oldSizeTy = function->getArg(oldBase + 3)->getType(); + llvm::Type *oldStrideTy = function->getArg(oldBase + 4)->getType(); + + llvm::Value *allocated = castABIValue(builder, publicArg, oldPtrTy); + llvm::Value *aligned = castABIValue(builder, publicArg, oldAlignedPtrTy); + llvm::Value *offset = materializeABIExpr( + builder, argSpec.memrefSpec.offset, replacement, oldOffsetTy); + llvm::Value *size = materializeABIExpr( + builder, argSpec.memrefSpec.totalSize, replacement, oldSizeTy); + llvm::Value *stride = materializeABIExpr( + builder, argSpec.memrefSpec.stride, replacement, oldStrideTy); + if (!allocated || !aligned || !offset || !size || !stride) { + diagOS << "VPTO LLVM emission failed: cannot materialize direct ABI for " + << originalName << "\n"; + return failure(); + } + + vmap[function->getArg(oldBase)] = allocated; + vmap[function->getArg(oldBase + 1)] = aligned; + vmap[function->getArg(oldBase + 2)] = offset; + vmap[function->getArg(oldBase + 3)] = size; + vmap[function->getArg(oldBase + 4)] = stride; + } + + llvm::SmallVector returns; + llvm::CloneFunctionInto(replacement, function, vmap, + llvm::CloneFunctionChangeType::LocalChangesOnly, + returns); + + llvm::BasicBlock *oldEntry = &replacement->getEntryBlock(); + llvm::BasicBlock *clonedEntry = oldEntry->getNextNode(); + if (!clonedEntry) { + diagOS << "VPTO LLVM emission failed: cloned function body is empty for " + << originalName << "\n"; + return failure(); + } + builder.CreateBr(clonedEntry); + + function->eraseFromParent(); + simplifyAggregateCarrierOps(*replacement); + } + + return success(); +} + +static std::unique_ptr +buildLLVMModuleFromPreparedVPTO(ModuleOp module, + llvm::LLVMContext &llvmContext, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + materializeVecScopeCarrierLoops(module); + + if (failed(normalizePtoMemRefSpaces(module, diagOS))) + return nullptr; + + if (failed(normalizePtoAlignsToABI(module, diagOS))) + return nullptr; + + if (failed(rewriteVPTOOps(module, diagOS))) { + diagOS << "VPTO LLVM emission failed: VPTO-to-call rewriting failed\n"; + return nullptr; + } + + if (failed(normalizePtoPtrsToLLVM(module, diagOS))) + return nullptr; + + normalizeFuncSignaturesForOfficialLLVMLowering(module); + + PassManager pm(module.getContext()); + pm.enableVerifier(); + pm.addPass(createConvertSCFToCFPass()); + pm.addPass(createArithToLLVMConversionPass()); + pm.addPass(createConvertIndexToLLVMPass()); + pm.addPass(createFinalizeMemRefToLLVMConversionPass()); + pm.addPass(createConvertFuncToLLVMPass()); + pm.addPass(createConvertControlFlowToLLVMPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + if (failed(pm.run(module))) { + diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; + return nullptr; + } + + if (failed(applyQueriedTargetAttrs(module, options, diagOS))) + return nullptr; + + registerBuiltinDialectTranslation(*module.getContext()); + registerLLVMDialectTranslation(*module.getContext()); + auto llvmModule = translateModuleToLLVMIR(module.getOperation(), llvmContext); + if (!llvmModule) { + diagOS << "VPTO LLVM emission failed: LLVM IR export failed\n"; + return nullptr; + } + + if (failed(attachAIVectorScopeMetadata(*llvmModule, diagOS))) + return nullptr; + attachHIVMKernelAnnotations(*llvmModule); + llvmModule->setModuleIdentifier("ptoas.hivm.official"); + llvmModule->setSourceFileName("ptoas.hivm.official"); + return llvmModule; +} + +} // namespace + +LogicalResult +translateVPTOModuleToLLVMText(ModuleOp module, llvm::raw_ostream &os, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + llvm::LLVMContext llvmContext; + auto llvmModule = + buildLLVMModuleFromPreparedVPTO(module, llvmContext, options, diagOS); + if (!llvmModule) + return failure(); + llvmModule->print(os, nullptr); + return success(); +} + +LogicalResult +translateVPTOModuleToLLVMBitcode(ModuleOp module, llvm::raw_ostream &os, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + llvm::LLVMContext llvmContext; + auto llvmModule = + buildLLVMModuleFromPreparedVPTO(module, llvmContext, options, diagOS); + if (!llvmModule) + return failure(); + llvm::WriteBitcodeToFile(*llvmModule, os); + return success(); +} + +} // namespace mlir::pto diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index ab6648fc4..b90174f9f 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -705,3 +705,875 @@ def _install_op_aliases(): __all__.extend(_install_op_aliases()) + +# ----------------------------------------------------------------------------- +# Experimental VPTO Python DSL (`@pto.vkernel`) +# ----------------------------------------------------------------------------- +import ast as _ast +import inspect as _inspect +import textwrap as _textwrap +from dataclasses import dataclass as _dataclass + + +class _VKernelType: + def render(self): + raise NotImplementedError + + +@_dataclass(frozen=True) +class _VKernelScalarType(_VKernelType): + name: str + + def render(self): + return self.name + + +@_dataclass(frozen=True) +class _VKernelPtrType(_VKernelType): + elem: _VKernelType + space: str + + def render(self): + return f"!pto.ptr<{self.elem.render()}, {self.space}>" + + +@_dataclass(frozen=True) +class _VKernelVRegType(_VKernelType): + lanes: int + elem: _VKernelType + + def render(self): + return f"!pto.vreg<{self.lanes}x{self.elem.render()}>" + + +@_dataclass(frozen=True) +class _VKernelConstBinding: + value: object + + +@_dataclass(frozen=True) +class _VKernelStructDef(_VKernelType): + name: str + fields: tuple + + def render(self): + raise _VKernelCompileError(f"{self.name} is a template-only surface type; use .jit(...) to specialize it") + + def __call__(self, **kwargs): + return _VKernelStructBinding(self, dict(kwargs)) + + +@_dataclass(frozen=True) +class _VKernelStructBinding: + schema: _VKernelStructDef + values: dict + + +@_dataclass(frozen=True) +class _VKStaticSequence: + values: tuple + + +@_dataclass(frozen=True) +class _VKStructValue: + schema: _VKernelStructDef + fields: dict + + +i1 = _VKernelScalarType("i1") +i8 = _VKernelScalarType("i8") +i16 = _VKernelScalarType("i16") +i32 = _VKernelScalarType("i32") +i64 = _VKernelScalarType("i64") +f16 = _VKernelScalarType("f16") +bf16 = _VKernelScalarType("bf16") +f32 = _VKernelScalarType("f32") +_vk_index = _VKernelScalarType("index") +mask = _VKernelScalarType("!pto.mask") +align = _VKernelScalarType("!pto.align") + + +def ptr(elem_type, space): + return _VKernelPtrType(elem_type, space) + + +def vreg(lanes, elem_type): + return _VKernelVRegType(lanes, elem_type) + + +def const(value): + return _VKernelConstBinding(value) + + +def struct(cls): + annotations = dict(getattr(cls, "__annotations__", {})) + if not annotations: + raise _VKernelCompileError("@pto.struct requires annotated fields") + fields = [] + for name, field_ty in annotations.items(): + if field_ty not in (ptr, const): + raise _VKernelCompileError( + f"unsupported field annotation for {cls.__name__}.{name}: {field_ty!r}" + ) + fields.append((name, field_ty)) + return _VKernelStructDef(cls.__name__, tuple(fields)) + + +@struct +class Tile: + ub_ptr: ptr + shape: const + + +tile = Tile + + +class _VKernelCompileError(Exception): + pass + + +@_dataclass +class _VKValue: + name: str | None = None + type: _VKernelType | None = None + literal: object | None = None + + def render_type(self): + if self.type is None: + raise _VKernelCompileError(f"unresolved type for {self.name}") + return self.type.render() + + +def _project_result(group, index, ty): + return _VKValue(f"{group.name}#{index}", ty) + + +def _load_standard_dialects(): + try: + from mlir.dialects import arith as _mlir_arith # noqa: F401 + from mlir.dialects import func as _mlir_func # noqa: F401 + from mlir.dialects import scf as _mlir_scf # noqa: F401 + except ImportError as exc: + raise RuntimeError("mlir standard dialect python bindings are required for vkernel parsing") from exc + + +class _VKernelContext: + def __init__(self): + self.ssa_counter = 0 + self.arg_counter = 0 + + def new_ssa(self): + name = f"%{self.ssa_counter}" + self.ssa_counter += 1 + return name + + def new_arg(self): + name = f"%arg{self.arg_counter}" + self.arg_counter += 1 + return name + + +def _type_key(ty): + return ty.render() if ty is not None else None + + +def _types_equal(lhs, rhs): + if lhs is None or rhs is None: + return lhs is rhs + return lhs.render() == rhs.render() + + +def _ensure_type(value, expected): + if value.type is None: + value.type = expected + return + if not _types_equal(value.type, expected): + raise _VKernelCompileError( + f"type mismatch for {value.name}: expected {expected.render()}, got {value.type.render()}" + ) + + +def _literal_text(value): + if isinstance(value, bool): + return "true" if value else "false" + return str(value) + + +def _coerce_surface_type(value): + if value is bool: + return i1 + if value is float: + return f32 + return value + + +def _ptr_elem_bytes(ptr_type): + if not isinstance(ptr_type, _VKernelPtrType): + raise _VKernelCompileError("elem_bytes requires a ptr type") + elem_name = ptr_type.elem.render() + table = { + "i8": 1, + "i16": 2, + "i32": 4, + "i64": 8, + "f16": 2, + "bf16": 2, + "f32": 4, + } + if elem_name not in table: + raise _VKernelCompileError(f"unsupported elem_bytes for {elem_name}") + return table[elem_name] + + +def _ptr_vector_lanes(ptr_type): + return 256 // _ptr_elem_bytes(ptr_type) + + +class _VKernelBuilder: + def __init__(self, py_fn, fn_def, target, kernel_name, specialization=None): + self.py_fn = py_fn + self.fn_def = fn_def + self.target = target + self.kernel_name = kernel_name + self.ctx = _VKernelContext() + self.specialization = specialization or {} + + def _emit(self, lines, indent, text): + lines.append(" " * indent + text) + + def _eval_type_expr(self, node): + expr = _ast.Expression(body=node) + globals_dict = dict(self.py_fn.__globals__) + globals_dict.update(globals()) + value = eval(compile(expr, self.py_fn.__code__.co_filename, "eval"), + globals_dict, {}) + value = _coerce_surface_type(value) + if not isinstance(value, _VKernelType): + raise _VKernelCompileError(f"unsupported vkernel type annotation: {value!r}") + return value + + def _new_value(self, ty=None): + return _VKValue(self.ctx.new_ssa(), ty) + + def _new_arg_value(self, ty=None): + return _VKValue(self.ctx.new_arg(), ty) + + def _materialize_value(self, value, lines, indent, expected_type=None): + if expected_type is not None: + _ensure_type(value, expected_type) + if value.name is not None: + return value + if value.literal is None: + raise _VKernelCompileError("value has no SSA name and cannot be materialized") + if value.type is None: + raise _VKernelCompileError("literal requires type context") + value.name = self.ctx.new_ssa() + lit = _literal_text(value.literal) + if isinstance(value.literal, bool): + self._emit(lines, indent, f"{value.name} = arith.constant {lit}") + else: + self._emit(lines, indent, f"{value.name} = arith.constant {lit} : {value.type.render()}") + return value + + def _literal_value(self, node, lines, indent, expected_type): + value = _VKValue(type=expected_type, literal=node.value) + if expected_type is None: + return value + return self._materialize_value(value, lines, indent) + + def _lower_attribute(self, node, env, lines, indent, expected_type=None): + if isinstance(node.value, _ast.Name): + if node.value.id not in env: + raise _VKernelCompileError(f"unknown name '{node.value.id}'") + base = env[node.value.id] + else: + base = self._lower_expr(node.value, env, lines, indent) + if isinstance(base, _VKStructValue): + if node.attr not in base.fields: + raise _VKernelCompileError(f"unsupported struct attribute '{node.attr}'") + field = base.fields[node.attr] + if isinstance(field, _VKValue): + return self._materialize_value(field, lines, indent, expected_type) + return field + if isinstance(base, _VKValue) and isinstance(base.type, _VKernelPtrType): + if node.attr == "elem_bytes": + return _VKValue(type=expected_type, literal=_ptr_elem_bytes(base.type)) + raise _VKernelCompileError(f"unsupported attribute access '{node.attr}'") + + def _lower_subscript(self, node, env, lines, indent, expected_type=None): + base = self._lower_expr(node.value, env, lines, indent) + if not isinstance(base, _VKStaticSequence): + raise _VKernelCompileError("subscript base must be a static sequence") + if not isinstance(node.slice, _ast.Constant) or not isinstance(node.slice.value, int): + raise _VKernelCompileError("only constant integer subscripts are supported") + index = node.slice.value + if index < 0 or index >= len(base.values): + raise _VKernelCompileError("subscript out of range") + value = base.values[index] + if not isinstance(value, _VKValue): + value = _VKValue(type=expected_type, literal=value) + return self._materialize_value(value, lines, indent, expected_type) if expected_type is not None else value + + def _lower_binop(self, node, env, lines, indent, expected_type=None): + lhs = self._lower_expr(node.left, env, lines, indent) + rhs = self._lower_expr(node.right, env, lines, indent) + if lhs.literal is not None and rhs.literal is not None: + if isinstance(node.op, _ast.Mult): + result = lhs.literal * rhs.literal + elif isinstance(node.op, _ast.FloorDiv): + result = lhs.literal // rhs.literal + else: + raise _VKernelCompileError(f"unsupported binary operator: {type(node.op).__name__}") + return _VKValue(type=expected_type, literal=result) + raise _VKernelCompileError("non-constant binary expressions are not supported yet") + + def _lower_expr(self, node, env, lines, indent, expected_type=None): + if isinstance(node, _ast.Name): + if node.id not in env: + raise _VKernelCompileError(f"unknown name '{node.id}'") + value = env[node.id] + if isinstance(value, (_VKStructValue, _VKStaticSequence)): + raise _VKernelCompileError(f"name '{node.id}' is not a scalar/SSA value") + if ( + isinstance(value, _VKValue) + and value.name is None + and value.literal is not None + and expected_type is not None + ): + return self._materialize_value( + _VKValue(type=expected_type, literal=value.literal), + lines, + indent, + ) + return self._materialize_value(value, lines, indent, expected_type) + if isinstance(node, _ast.Constant): + return self._literal_value(node, lines, indent, expected_type) + if isinstance(node, _ast.Attribute): + return self._lower_attribute(node, env, lines, indent, expected_type) + if isinstance(node, _ast.Subscript): + return self._lower_subscript(node, env, lines, indent, expected_type) + if isinstance(node, _ast.BinOp): + return self._lower_binop(node, env, lines, indent, expected_type) + if isinstance(node, _ast.Call): + results = self._lower_call(node, env, lines, indent, expected_types=[expected_type] if expected_type else None) + if len(results) != 1: + raise _VKernelCompileError("expression expected single result") + return results[0] + raise _VKernelCompileError(f"unsupported expression: {type(node).__name__}") + + def _lower_call_name(self, node): + if isinstance(node, _ast.Attribute) and isinstance(node.value, _ast.Name) and node.value.id == "pto": + return node.attr + raise _VKernelCompileError("only pto.* calls are supported") + + def _infer_expr_type(self, node, env): + if isinstance(node, _ast.Name): + if node.id not in env: + raise _VKernelCompileError(f"unknown name '{node.id}'") + value = env[node.id] + return value.type if isinstance(value, _VKValue) else None + if isinstance(node, _ast.Attribute): + try: + value = self._lower_attribute(node, env, [], 0) + except _VKernelCompileError: + return None + return value.type if isinstance(value, _VKValue) else None + if isinstance(node, _ast.Constant): + return None + return None + + def _format_typed_operands(self, values): + return ", ".join(v.name for v in values), ", ".join(v.render_type() for v in values) + + def _lower_call(self, node, env, lines, indent, expected_types=None): + opname = self._lower_call_name(node.func) + + if opname in ("set_loop_size_outtoub", "set_loop_size_ubtoout"): + ops = [self._lower_expr(arg, env, lines, indent, i64) for arg in node.args] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.{opname} {operands} : {types}") + return [] + + if opname == "castptr": + if len(node.args) != 2: + raise _VKernelCompileError("pto.castptr expects 2 arguments") + result_type = self._eval_type_expr(node.args[1]) + addr = self._lower_expr(node.args[0], env, lines, indent, i64) + result = self._new_value(result_type) + self._emit(lines, indent, f"{result.name} = pto.castptr {addr.name} : {addr.render_type()} -> {result.render_type()}") + return [result] + + if opname == "copy_gm_to_ubuf": + expected = [None, None, i64, i64, i64, i64, i64, i1, i64, i64, i64] + ops = [self._lower_expr(arg, env, lines, indent, expected[i]) for i, arg in enumerate(node.args)] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.copy_gm_to_ubuf {operands} : {types}") + return [] + + if opname == "copy_ubuf_to_gm": + expected = [None, None, i64, i64, i64, i64, i64, i64] + ops = [self._lower_expr(arg, env, lines, indent, expected[i]) for i, arg in enumerate(node.args)] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.copy_ubuf_to_gm {operands} : {types}") + return [] + + if opname in ("set_flag", "wait_flag"): + attrs = [] + for arg in node.args: + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise _VKernelCompileError(f"pto.{opname} expects string literals") + attrs.append(arg.value) + self._emit(lines, indent, f'pto.{opname}["{attrs[0]}", "{attrs[1]}", "{attrs[2]}"]') + return [] + + if opname == "barrier": + arg = node.args[0] + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise _VKernelCompileError("pto.barrier expects a string literal") + self._emit(lines, indent, f"pto.barrier #pto.pipe<{arg.value}>") + return [] + + if opname == "plt_b32": + src = self._lower_expr(node.args[0], env, lines, indent, i32) + res0 = self._new_value(mask) + res1 = self._new_value(i32) + self._emit(lines, indent, f"{res0.name}, {res1.name} = pto.plt_b32 {src.name} : i32 -> !pto.mask, i32") + return [res0, res1] + + if opname == "pset_b32": + arg = node.args[0] + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise _VKernelCompileError("pto.pset_b32 expects a string literal") + res = self._new_value(mask) + self._emit(lines, indent, f'{res.name} = pto.pset_b32 "{arg.value}" : !pto.mask') + return [res] + + if opname == "vlds": + ptr_value = self._lower_expr(node.args[0], env, lines, indent) + if not isinstance(ptr_value.type, _VKernelPtrType): + raise _VKernelCompileError("pto.vlds expects a ptr operand") + offset = self._lower_expr(node.args[1], env, lines, indent, _vk_index) + result = self._new_value(vreg(_ptr_vector_lanes(ptr_value.type), ptr_value.type.elem)) + self._emit(lines, indent, + f"{result.name} = pto.vlds {ptr_value.name}[{offset.name}] : {ptr_value.render_type()} -> {result.render_type()}") + return [result] + + if opname == "vabs": + vec_value = self._lower_expr(node.args[0], env, lines, indent) + mask_value = self._lower_expr(node.args[1], env, lines, indent, mask) + result = self._new_value(vec_value.type) + self._emit(lines, indent, + f"{result.name} = pto.vabs {vec_value.name}, {mask_value.name} : {vec_value.render_type()}, {mask_value.render_type()} -> {result.render_type()}") + return [result] + + if opname == "vsts": + vec_value = self._lower_expr(node.args[0], env, lines, indent) + ptr_value = self._lower_expr(node.args[1], env, lines, indent) + offset = self._lower_expr(node.args[2], env, lines, indent, _vk_index) + mask_value = self._lower_expr(node.args[3], env, lines, indent, mask) + self._emit(lines, indent, + f"pto.vsts {vec_value.name}, {ptr_value.name}[{offset.name}], {mask_value.name} : {vec_value.render_type()}, {ptr_value.render_type()}, {mask_value.render_type()}") + return [] + + raise _VKernelCompileError(f"unsupported pto op in vkernel: {opname}") + + def _collect_assigned_names(self, statements): + names = set() + + class Visitor(_ast.NodeVisitor): + def visit_Assign(self, node): + for target in node.targets: + self._collect_target(target) + + def _collect_target(self, target): + if isinstance(target, _ast.Name): + names.add(target.id) + elif isinstance(target, _ast.Tuple): + for elt in target.elts: + self._collect_target(elt) + + visitor = Visitor() + for stmt in statements: + if isinstance(stmt, (_ast.With, _ast.For, _ast.If)): + continue + visitor.visit(stmt) + return names + + def _compile_block(self, statements, env, indent): + lines = [] + current_env = dict(env) + + for stmt in statements: + if isinstance(stmt, _ast.Assign): + if len(stmt.targets) != 1: + raise _VKernelCompileError("multiple assignment targets are not supported") + target = stmt.targets[0] + if isinstance(target, _ast.Name): + value = self._lower_expr(stmt.value, current_env, lines, indent) + current_env[target.id] = value + elif isinstance(target, _ast.Tuple): + results = self._lower_call(stmt.value, current_env, lines, indent) + if len(results) != len(target.elts): + raise _VKernelCompileError("tuple assignment arity mismatch") + for elt, value in zip(target.elts, results): + if not isinstance(elt, _ast.Name): + raise _VKernelCompileError("tuple assignment only supports names") + current_env[elt.id] = value + else: + raise _VKernelCompileError("unsupported assignment target") + continue + + if isinstance(stmt, _ast.AnnAssign): + if stmt.value is None: + raise _VKernelCompileError("annotation-only assignment is not supported") + if not isinstance(stmt.target, _ast.Name): + raise _VKernelCompileError("annotated assignment only supports names") + target_type = self._eval_type_expr(stmt.annotation) + value = self._lower_expr(stmt.value, current_env, lines, indent, target_type) + current_env[stmt.target.id] = value + continue + + if isinstance(stmt, _ast.Expr): + if isinstance(stmt.value, _ast.Call): + self._lower_call(stmt.value, current_env, lines, indent) + else: + self._lower_expr(stmt.value, current_env, lines, indent) + continue + + if isinstance(stmt, _ast.Return): + if stmt.value is not None: + raise _VKernelCompileError("only empty return is supported") + self._emit(lines, indent, "return") + continue + + if isinstance(stmt, _ast.With): + if len(stmt.items) != 1: + raise _VKernelCompileError("only single with item is supported") + item = stmt.items[0] + name = self._lower_call_name(item.context_expr.func) + if name not in ("strict_vecscope", "vecscope"): + raise _VKernelCompileError("unsupported with context") + if name == "strict_vecscope": + body_lines, body_result = self._compile_strict_vecscope(item, stmt.body, current_env, indent) + else: + body_lines, body_result = self._compile_vecscope(stmt.body, current_env, indent) + lines.extend(body_lines) + current_env.update(body_result) + continue + + if isinstance(stmt, _ast.For): + loop_lines, updated_env = self._compile_for(stmt, current_env, indent) + lines.extend(loop_lines) + current_env = updated_env + continue + + if isinstance(stmt, _ast.If): + if_lines, updated_env = self._compile_if(stmt, current_env, indent) + lines.extend(if_lines) + current_env = updated_env + continue + + raise _VKernelCompileError(f"unsupported statement: {type(stmt).__name__}") + + return lines, current_env + + def _compile_vecscope(self, body, outer_env, indent): + body_lines, _ = self._compile_block(body, dict(outer_env), indent + 1) + lines = [] + self._emit(lines, indent, "pto.vecscope {") + lines.extend(body_lines) + self._emit(lines, indent, "}") + return lines, {} + + def _compile_strict_vecscope(self, item, body, outer_env, indent): + if not isinstance(item.optional_vars, _ast.Tuple): + raise _VKernelCompileError("pto.strict_vecscope requires tuple binding in 'as'") + if len(item.context_expr.args) != len(item.optional_vars.elts): + raise _VKernelCompileError("strict_vecscope capture arity must match bound block arguments") + arg_names = [] + inner_env = {} + for elt in item.optional_vars.elts: + if not isinstance(elt, _ast.Name): + raise _VKernelCompileError("pto.strict_vecscope bindings must be names") + arg = self._new_arg_value() + arg_names.append((elt.id, arg)) + inner_env[elt.id] = arg + + for expr, (_, arg) in zip(item.context_expr.args, arg_names): + inferred_type = self._infer_expr_type(expr, outer_env) + if inferred_type is not None: + arg.type = inferred_type + + lines = [] + body_lines, body_env = self._compile_block(body, inner_env, indent + 1) + captures = [] + for name, arg in arg_names: + if arg.type is None and name in body_env and body_env[name].type is not None: + arg.type = body_env[name].type + for expr, (_, arg) in zip(item.context_expr.args, arg_names): + if arg.type is None: + raise _VKernelCompileError("strict_vecscope block argument type could not be inferred") + capture = self._lower_expr(expr, outer_env, lines, indent, expected_type=arg.type) + captures.append(capture) + capture_operands = ", ".join(value.name for value in captures) + block_args = ", ".join(f"{arg.name}: {arg.render_type()}" for _, arg in arg_names) + func_type = ", ".join(arg.render_type() for _, arg in arg_names) + + self._emit(lines, indent, f"pto.strict_vecscope({capture_operands}) {{") + self._emit(lines, indent, f"^bb0({block_args}):") + lines.extend(body_lines) + self._emit(lines, indent, f"}} : ({func_type}) -> ()") + return lines, {} + + def _compile_for(self, stmt, outer_env, indent): + if not isinstance(stmt.target, _ast.Name): + raise _VKernelCompileError("for target must be a single name") + if not isinstance(stmt.iter, _ast.Call) or not isinstance(stmt.iter.func, _ast.Name) or stmt.iter.func.id != "range": + raise _VKernelCompileError("only Python range(...) loops are supported") + if len(stmt.iter.args) != 3: + raise _VKernelCompileError("range expects exactly 3 arguments in vkernel") + + lines = [] + lb = self._lower_expr(stmt.iter.args[0], outer_env, lines, indent, _vk_index) + ub = self._lower_expr(stmt.iter.args[1], outer_env, lines, indent, _vk_index) + step = self._lower_expr(stmt.iter.args[2], outer_env, lines, indent, _vk_index) + + loop_env = dict(outer_env) + iv = self._new_arg_value(_vk_index) + loop_env[stmt.target.id] = iv + candidate_carried = [] + for name in self._collect_assigned_names(stmt.body): + if name in outer_env and name != stmt.target.id: + iter_arg = self._new_arg_value(outer_env[name].type) + loop_env[name] = iter_arg + candidate_carried.append((name, outer_env[name], iter_arg)) + + body_lines, body_env = self._compile_block(stmt.body, loop_env, indent + 1) + carried = [] + for name, before, iter_arg in candidate_carried: + after = body_env.get(name) + if after is not None and after is not iter_arg: + carried.append((name, before, after)) + + result_prefix = "" + yield_line = None + if carried: + results = [after.render_type() for _, _, after in carried] + result_value = self._new_value() + result_prefix = f"{result_value.name}:{len(carried)} = " + iter_arg_map = {name: iter_arg for name, _, iter_arg in candidate_carried} + carried_with_initials = [] + for name, before, after in carried: + before = self._materialize_value(before, lines, indent, after.type) + carried_with_initials.append((name, before, after)) + carried = carried_with_initials + iter_args = ", ".join( + f"{iter_arg_map[name].name} = {before.name}" for name, before, _ in carried + ) + self._emit( + lines, + indent, + f"{result_prefix}scf.for {iv.name} = {lb.name} to {ub.name} step {step.name} iter_args({iter_args}) -> ({', '.join(results)}) {{", + ) + yield_line = f"scf.yield {', '.join(after.name for _, _, after in carried)} : {', '.join(results)}" + else: + self._emit(lines, indent, f"scf.for {iv.name} = {lb.name} to {ub.name} step {step.name} {{") + lines.extend(body_lines) + if yield_line: + self._emit(lines, indent + 1, yield_line) + self._emit(lines, indent, "}") + + updated_env = dict(outer_env) + if carried: + for idx, (name, _, after) in enumerate(carried): + updated_env[name] = _project_result(result_value, idx, after.type) + return lines, updated_env + + def _compile_if(self, stmt, outer_env, indent): + lines = [] + cond = self._lower_expr(stmt.test, outer_env, lines, indent, i1) + then_lines, then_env = self._compile_block(stmt.body, dict(outer_env), indent + 1) + else_lines, else_env = self._compile_block(stmt.orelse, dict(outer_env), indent + 1) + updated = [] + for name, before in outer_env.items(): + then_val = then_env.get(name, before) + else_val = else_env.get(name, before) + if then_val is not before or else_val is not before: + if not _types_equal(then_val.type, else_val.type): + raise _VKernelCompileError(f"if merge type mismatch for '{name}'") + updated.append((name, then_val, else_val)) + + if updated: + result = self._new_value() + types = ", ".join(val.type.render() for _, val, _ in updated) + self._emit(lines, indent, f"{result.name}:{len(updated)} = scf.if {cond.name} -> ({types}) {{") + lines.extend(then_lines) + self._emit(lines, indent + 1, f"scf.yield {', '.join(val.name for _, val, _ in updated)} : {types}") + self._emit(lines, indent, "} else {") + lines.extend(else_lines) + self._emit(lines, indent + 1, f"scf.yield {', '.join(val.name for _, _, val in updated)} : {types}") + self._emit(lines, indent, "}") + updated_env = dict(outer_env) + for idx, (name, then_val, _) in enumerate(updated): + updated_env[name] = _project_result(result, idx, then_val.type) + return lines, updated_env + + self._emit(lines, indent, f"scf.if {cond.name} {{") + lines.extend(then_lines) + self._emit(lines, indent, "} else {") + lines.extend(else_lines) + self._emit(lines, indent, "}") + return lines, dict(outer_env) + + def build_text(self): + lines = [f'module attributes {{pto.target_arch = "{self.target}"}} {{'] + arg_types = [] + env = {} + for arg in self.fn_def.args.args: + arg_ty = _coerce_surface_type(self.py_fn.__annotations__.get(arg.arg)) + if arg_ty is None: + raise _VKernelCompileError(f"missing type annotation for argument '{arg.arg}'") + if not isinstance(arg_ty, _VKernelType): + raise _VKernelCompileError(f"unsupported type annotation for argument '{arg.arg}'") + if isinstance(arg_ty, _VKernelStructDef): + if arg.arg not in self.specialization: + raise _VKernelCompileError( + f"template argument '{arg.arg}: {arg_ty.name}' requires .jit(...) specialization" + ) + binding = self.specialization[arg.arg] + if not isinstance(binding, _VKernelStructBinding) or binding.schema != arg_ty: + raise _VKernelCompileError( + f"specialization for '{arg.arg}' must be a {arg_ty.name}(...) binding" + ) + struct_fields = {} + for field_name, field_kind in arg_ty.fields: + if field_name not in binding.values: + raise _VKernelCompileError( + f"missing field '{field_name}' in specialization for '{arg.arg}'" + ) + field_value = binding.values[field_name] + if field_kind is ptr: + if not isinstance(field_value, _VKernelPtrType): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must be a pto.ptr(...) type object" + ) + arg_val = self._new_arg_value(field_value) + arg_types.append(f"{arg_val.name}: {field_value.render()}") + struct_fields[field_name] = arg_val + continue + if field_kind is const: + if not isinstance(field_value, _VKernelConstBinding): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must use pto.const(...)" + ) + static_value = field_value.value + if not isinstance(static_value, (list, tuple)) or not all( + isinstance(v, int) for v in static_value + ): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must be a list/tuple of ints" + ) + struct_fields[field_name] = _VKStaticSequence( + tuple(_VKValue(literal=v) for v in static_value) + ) + continue + raise _VKernelCompileError( + f"unsupported struct field kind for {arg_ty.name}.{field_name}" + ) + env[arg.arg] = _VKStructValue(arg_ty, struct_fields) + continue + arg_val = self._new_arg_value(arg_ty) + arg_types.append(f"{arg_val.name}: {arg_ty.render()}") + env[arg.arg] = arg_val + self._emit(lines, 1, f"func.func @{self.kernel_name}({', '.join(arg_types)}) {{") + body_lines, _ = self._compile_block(self.fn_def.body, env, 2) + lines.extend(body_lines) + if not any(line.strip() == "return" for line in body_lines): + self._emit(lines, 2, "return") + self._emit(lines, 1, "}") + lines.append("}") + return "\n".join(lines) + "\n" + + +class VKernelHandle: + def __init__(self, py_fn, target="a5", name=None, verify=True, specialization=None): + self._py_fn = py_fn + self._target = target + self._name = name or py_fn.__name__ + self._verify = verify + self._specialization = specialization or {} + self._cached_text = None + + def _load_ast(self): + source = _textwrap.dedent(_inspect.getsource(self._py_fn)) + module = _ast.parse(source) + for node in module.body: + if isinstance(node, _ast.FunctionDef) and node.name == self._py_fn.__name__: + return node + raise _VKernelCompileError(f"failed to locate function AST for {self._py_fn.__name__}") + + def mlir_text(self): + if self._cached_text is None: + builder = _VKernelBuilder( + self._py_fn, + self._load_ast(), + self._target, + self._name, + specialization=self._specialization, + ) + self._cached_text = builder.build_text() + return self._cached_text + + def mlir_module(self): + with _ods_ir.Context() as ctx: + _load_standard_dialects() + register_dialect(ctx, load=True) + return _ods_ir.Module.parse(self.mlir_text(), ctx) + + def verify(self): + mod = self.mlir_module() + mod.operation.verify() + return True + + def dump(self): + print(self.mlir_text(), end="") + + def emit(self, path): + with open(path, "w", encoding="utf-8") as f: + f.write(self.mlir_text()) + + def jit(self, **kwargs): + return VKernelHandle( + self._py_fn, + target=self._target, + name=self._name, + verify=self._verify, + specialization=kwargs, + ) + + def __str__(self): + return self.mlir_text() + + +def vkernel(py_fn=None, *, target="a5", name=None, verify=True): + def wrap(fn): + return VKernelHandle(fn, target=target, name=name, verify=verify) + + if py_fn is None: + return wrap + return wrap(py_fn) + + +__all__.extend([ + "vkernel", + "VKernelHandle", + "struct", + "Tile", + "tile", + "const", + "ptr", + "vreg", + "i1", "i8", "i16", "i32", "i64", + "f16", "bf16", "f32", + "mask", "align", +]) diff --git a/scripts/batch_compile_output_cpp.sh b/scripts/batch_compile_output_cpp.sh new file mode 100755 index 000000000..13426a9f6 --- /dev/null +++ b/scripts/batch_compile_output_cpp.sh @@ -0,0 +1,464 @@ +#!/usr/bin/env bash + +set -u + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" +ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}" + +DEFAULT_SOURCE_DIR="${PTO_SOURCE_DIR:-${REPO_ROOT}}" +SRC_ROOT="${PTOAS_OUT_DIR:-${DEFAULT_SOURCE_DIR}/build/output}" +BUILD_ROOT="${DEFAULT_SOURCE_DIR}/build/output_asm" +LOG_DIR="${DEFAULT_SOURCE_DIR}/build/output_log" + +COMPILER="${COMPILER:-}" +PTO_ISA_PATH="${PTO_ISA_PATH:-${PTO_ISA_ROOT:-}}" +EXTRA_ARGS=() + +JOBS="${JOBS:-$(nproc)}" +AICORE_ARCH="${AICORE_ARCH:-dav-c310-vec}" +MEM_BASE_DEFINE="${MEM_BASE_DEFINE:-REGISTER_BASE}" +ENABLE_DEFAULT_ARGS=1 + +print_usage() { + cat <<'EOF' +批量编译 output 目录下所有 .cpp 文件为 .S,并汇总结果。 + +用法: + scripts/batch_compile_output_cpp.sh \ + [--compiler <编译器路径>] \ + [--pto-isa-path ] \ + [--compile-arg <单个参数>]... \ + [--jobs <并行数>] \ + [--aicore-arch ] \ + [--mem-base-define <宏名>] \ + [--src-root <源码目录>] \ + [--build-root <产物目录>] \ + [--log-dir <日志目录>] + +参数说明: + --compiler, -c 编译器路径。默认优先使用环境变量 COMPILER, + 其次使用 PATH 中的 bisheng 或 + ${ASCEND_HOME_PATH}/bin/bisheng + --pto-isa-path, -p PTO-ISA 根路径。默认优先使用环境变量 + PTO_ISA_PATH / PTO_ISA_ROOT。脚本会自动检测 include 目录: + 1) /include + 2) /tests/common (存在时自动加入) + 3) + --compile-arg 额外编译参数,可重复传入 + --jobs, -j 并行编译任务数,默认: nproc + --aicore-arch 默认: dav-c220-vec + --mem-base-define 默认: MEMORY_BASE (可改为 REGISTER_BASE) + --no-default-args 不使用脚本内置默认参数(仅使用 --compile-arg) + --src-root 要扫描的 .cpp 根目录,默认: $PTOAS_OUT_DIR + 或 $PTO_SOURCE_DIR/build/output + --build-root .S 产物目录,默认: $PTO_SOURCE_DIR/build/output_asm + --log-dir 编译日志目录,默认: /logs + --help, -h 显示帮助 + +推荐先执行: + source scripts/ptoas_env.sh + +默认编译参数来源: + 由 test/npu_validation/scripts/generate_testcase.py 中 + CMAKE_CCE_COMPILE_OPTIONS + target_compile_options() 提取: + -xcce -fenable-matrix --cce-aicore-enable-tl -fPIC -Xhost-start -Xhost-end + -mllvm -cce-aicore-function-stack-size=0x8000 + -mllvm -cce-aicore-record-overflow=true + -mllvm -cce-aicore-addr-transform + -mllvm -cce-aicore-dcci-insert-for-scalar=false + --cce-aicore-arch= -D -std=c++17 +EOF +} + +die() { + echo "[ERROR] $*" >&2 + exit 1 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --compiler | -c) + [[ $# -ge 2 ]] || die "--compiler 缺少参数" + COMPILER="$2" + shift 2 + ;; + --pto-isa-path | -p) + [[ $# -ge 2 ]] || die "--pto-isa-path 缺少参数" + PTO_ISA_PATH="$2" + shift 2 + ;; + --compile-arg) + [[ $# -ge 2 ]] || die "--compile-arg 缺少参数" + EXTRA_ARGS+=("$2") + shift 2 + ;; + --jobs | -j) + [[ $# -ge 2 ]] || die "--jobs 缺少参数" + JOBS="$2" + shift 2 + ;; + --aicore-arch) + [[ $# -ge 2 ]] || die "--aicore-arch 缺少参数" + AICORE_ARCH="$2" + shift 2 + ;; + --mem-base-define) + [[ $# -ge 2 ]] || die "--mem-base-define 缺少参数" + MEM_BASE_DEFINE="$2" + shift 2 + ;; + --no-default-args) + ENABLE_DEFAULT_ARGS=0 + shift + ;; + --src-root) + [[ $# -ge 2 ]] || die "--src-root 缺少参数" + SRC_ROOT="$2" + shift 2 + ;; + --build-root) + [[ $# -ge 2 ]] || die "--build-root 缺少参数" + BUILD_ROOT="$2" + shift 2 + ;; + --log-dir) + [[ $# -ge 2 ]] || die "--log-dir 缺少参数" + LOG_DIR="$2" + shift 2 + ;; + --help | -h) + print_usage + exit 0 + ;; + *) + die "未知参数: $1 (使用 --help 查看用法)" + ;; + esac +done + +if [[ -z "${COMPILER}" ]]; then + if command -v bisheng >/dev/null 2>&1; then + COMPILER="$(command -v bisheng)" + elif [[ -n "${ASCEND_HOME_PATH:-}" && -x "${ASCEND_HOME_PATH}/bin/bisheng" ]]; then + COMPILER="${ASCEND_HOME_PATH}/bin/bisheng" + fi +elif [[ "${COMPILER}" != */* ]] && command -v "${COMPILER}" >/dev/null 2>&1; then + COMPILER="$(command -v "${COMPILER}")" +fi + +[[ -n "${COMPILER}" ]] || die "未找到编译器,请先 source scripts/ptoas_env.sh,或通过 --compiler/COMPILER 指定 bisheng 路径" +[[ -n "${PTO_ISA_PATH}" ]] || die "未找到 PTO-ISA 路径,请通过 --pto-isa-path、PTO_ISA_PATH 或 PTO_ISA_ROOT 指定" +[[ -x "${COMPILER}" ]] || die "编译器不可执行: ${COMPILER}" +[[ -d "${SRC_ROOT}" ]] || die "源码目录不存在: ${SRC_ROOT}" +[[ -d "${PTO_ISA_PATH}" ]] || die "PTO-ISA 路径不存在: ${PTO_ISA_PATH}" +[[ "${JOBS}" =~ ^[1-9][0-9]*$ ]] || die "--jobs 必须为正整数" + +if [[ -z "${LOG_DIR}" ]]; then + LOG_DIR="${BUILD_ROOT}/logs" +fi + +mkdir -p "${BUILD_ROOT}" "${LOG_DIR}" || die "创建目录失败" + +INCLUDE_DIRS=() +if [[ -f "${PTO_ISA_PATH}/include/pto/pto-inst.hpp" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}/include") +fi +if [[ -d "${PTO_ISA_PATH}/tests/common" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}/tests/common") +fi +if [[ -f "${PTO_ISA_PATH}/pto/pto-inst.hpp" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}") +fi +[[ ${#INCLUDE_DIRS[@]} -gt 0 ]] || die "未找到 pto/pto-inst.hpp,请检查 --pto-isa-path" + +if [[ -n "${ASCEND_HOME_PATH:-}" && -d "${ASCEND_HOME_PATH}/include" ]]; then + INCLUDE_DIRS+=("${ASCEND_HOME_PATH}/include") +fi +ASCEND_DRIVER_PATH="${ASCEND_DRIVER_PATH:-/usr/local/Ascend/driver}" +if [[ -d "${ASCEND_DRIVER_PATH}/kernel/inc" ]]; then + INCLUDE_DIRS+=("${ASCEND_DRIVER_PATH}/kernel/inc") +fi + +DEFAULT_ARGS=() +if [[ ${ENABLE_DEFAULT_ARGS} -eq 1 ]]; then + DEFAULT_ARGS=( + "-xcce" + "-fenable-matrix" + "--cce-aicore-enable-tl" + "--cce-aicore-only" + "-fPIC" + "-Xhost-start" + "-Xhost-end" + "-mllvm" "-cce-aicore-stack-size=0x8000" + "-mllvm" "-cce-aicore-function-stack-size=0x8000" + "-mllvm" "-cce-aicore-record-overflow=true" + "-mllvm" "-cce-aicore-addr-transform" + "-mllvm" "-cce-aicore-dcci-insert-for-scalar=false" + "--cce-aicore-arch=${AICORE_ARCH}" + "-D${MEM_BASE_DEFINE}" + "-std=c++17" + ) + if [[ "${AICORE_ARCH}" == dav-l310* || "${AICORE_ARCH}" == dav-l311* ]]; then + FILTERED_DEFAULT_ARGS=() + i=0 + while [[ ${i} -lt ${#DEFAULT_ARGS[@]} ]]; do + if [[ "${DEFAULT_ARGS[${i}]}" == "-mllvm" ]] && [[ $((i + 1)) -lt ${#DEFAULT_ARGS[@]} ]] && + [[ "${DEFAULT_ARGS[$((i + 1))]}" == "-cce-aicore-stack-size=0x8000" ]]; then + i=$((i + 2)) + continue + fi + FILTERED_DEFAULT_ARGS+=("${DEFAULT_ARGS[${i}]}") + i=$((i + 1)) + done + DEFAULT_ARGS=("${FILTERED_DEFAULT_ARGS[@]}") + fi +fi + +declare -a CPP_FILES=() +while IFS= read -r -d '' file; do + CPP_FILES+=("${file}") +done < <(find "${SRC_ROOT}" -type f -name "*.cpp" -print0 | sort -z) + +TOTAL_COUNT=${#CPP_FILES[@]} +[[ ${TOTAL_COUNT} -gt 0 ]] || die "未在 ${SRC_ROOT} 下找到 .cpp 文件" + +STATUS_FILE="$(mktemp "${BUILD_ROOT}/compile_status.XXXXXX")" || die "创建状态文件失败" +trap 'rm -f "${STATUS_FILE}"' EXIT + +record_compile_status() { + local status="$1" + local rel_path="$2" + printf '%s\t%s\n' "${status}" "${rel_path}" >>"${STATUS_FILE}" +} + +cleanup_work_dir() { + local work_dir="$1" + [[ -n "${work_dir}" ]] && rm -rf -- "${work_dir}" +} + +get_log_failure_reason() { + local log_path="$1" + local excerpt + + excerpt="$(grep -E -i 'error:|fatal:|undefined reference|undefined symbol|undeclared identifier|exception|traceback|failed' "${log_path}" | tail -n 5 || true)" + if [[ -z "${excerpt}" ]]; then + excerpt="$(tail -n 10 "${log_path}" 2>/dev/null || true)" + fi + printf '%s' "${excerpt}" +} + +find_generated_output() { + local work_dir="$1" + local src_stem="$2" + local candidate + + for candidate in \ + "${work_dir}/${src_stem}.o" \ + "${work_dir}/${src_stem}.S" \ + "${work_dir}/${src_stem}.s"; do + if [[ -f "${candidate}" ]]; then + printf '%s\n' "${candidate}" + return 0 + fi + done + + find "${work_dir}" -maxdepth 1 -type f \( -name "*.o" -o -name "*.S" -o -name "*.s" \) | head -n 1 +} + +write_rebuild_cmd() { + local cmd_path="$1" + local asm_path="$2" + local src_stem="$3" + shift 3 + local -a cmd=("$@") + local cmd_text="" + local arg + + for arg in "${cmd[@]}"; do + printf -v cmd_text '%s %q' "${cmd_text}" "${arg}" + done + cmd_text="${cmd_text# }" + + { + echo "#!/usr/bin/env bash" + echo + echo "set -euo pipefail" + echo + printf 'ASM_PATH=%q\n' "${asm_path}" + printf 'SRC_STEM=%q\n' "${src_stem}" + printf 'WORK_ROOT=%q\n' "${BUILD_ROOT}" + echo + echo 'WORK_DIR="$(mktemp -d "${WORK_ROOT}/tmp_rebuild.XXXXXX")"' + echo 'trap '\''rm -rf -- "${WORK_DIR}"'\'' EXIT' + echo + echo 'cd "${WORK_DIR}"' + echo "${cmd_text}" + echo + echo 'GENERATED_FILE=""' + echo 'for candidate in "${WORK_DIR}/${SRC_STEM}.o" "${WORK_DIR}/${SRC_STEM}.S" "${WORK_DIR}/${SRC_STEM}.s"; do' + echo ' if [[ -f "${candidate}" ]]; then' + echo ' GENERATED_FILE="${candidate}"' + echo ' break' + echo ' fi' + echo 'done' + echo + echo 'if [[ -z "${GENERATED_FILE}" ]]; then' + echo ' GENERATED_FILE="$(find "${WORK_DIR}" -maxdepth 1 -type f \( -name "*.o" -o -name "*.S" -o -name "*.s" \) | head -n 1)"' + echo 'fi' + echo + echo 'if [[ -z "${GENERATED_FILE}" || ! -f "${GENERATED_FILE}" ]]; then' + echo ' echo "[ERROR] 编译成功但未找到输出文件,期望类型: .o/.S/.s" >&2' + echo ' exit 1' + echo 'fi' + echo + echo 'mkdir -p "$(dirname -- "${ASM_PATH}")"' + echo 'mv -f -- "${GENERATED_FILE}" "${ASM_PATH}"' + printf 'echo "已更新: %s"\n' "${asm_path}" + } >"${cmd_path}" || return 1 + + chmod +x "${cmd_path}" +} + +compile_one() { + local src="$1" + local rel_path asm_path log_path cmd_path src_base src_stem work_dir generated_file + local -a cmd=() + + rel_path="${src#"${SRC_ROOT}/"}" + asm_path="${BUILD_ROOT}/${rel_path%.cpp}.S" + log_path="${LOG_DIR}/${rel_path%.cpp}.log" + cmd_path="$(dirname -- "${log_path}")/cmd.sh" + src_base="$(basename -- "${src}")" + src_stem="${src_base%.cpp}" + + mkdir -p "$(dirname -- "${asm_path}")" "$(dirname -- "${log_path}")" || { + record_compile_status "FAIL" "${rel_path}" + return 0 + } + + cmd=("${COMPILER}") + if [[ ${#DEFAULT_ARGS[@]} -gt 0 ]]; then + cmd+=("${DEFAULT_ARGS[@]}") + fi + if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then + cmd+=("${EXTRA_ARGS[@]}") + fi + local inc + for inc in "${INCLUDE_DIRS[@]}"; do + cmd+=("-I${inc}") + done + cmd+=("-c" "${src}") + + if ! write_rebuild_cmd "${cmd_path}" "${asm_path}" "${src_stem}" "${cmd[@]}"; then + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + echo "[BUILD] ${rel_path}" + work_dir="$(mktemp -d "${BUILD_ROOT}/tmp_compile.XXXXXX")" || { + record_compile_status "FAIL" "${rel_path}" + return 0 + } + + if ! (cd "${work_dir}" && "${cmd[@]}") >"${log_path}" 2>&1; then + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + generated_file="$(find_generated_output "${work_dir}" "${src_stem}")" + + if [[ -z "${generated_file}" || ! -f "${generated_file}" ]]; then + { + echo + echo "[ERROR] 编译成功但未找到输出文件,期望类型: .o/.S/.s" + echo "[ERROR] 临时目录: ${work_dir}" + } >>"${log_path}" + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + if mv -f -- "${generated_file}" "${asm_path}"; then + cleanup_work_dir "${work_dir}" + record_compile_status "OK" "${rel_path}" + else + { + echo + echo "[ERROR] 输出重命名失败: ${generated_file} -> ${asm_path}" + } >>"${log_path}" + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + fi +} + +START_TIME="$(date +%s)" + +echo "[INFO] 编译器: ${COMPILER}" +echo "[INFO] 源目录: ${SRC_ROOT}" +echo "[INFO] 产物目录(.S): ${BUILD_ROOT}" +echo "[INFO] 日志目录: ${LOG_DIR}" +echo "[INFO] PTO-ISA: ${PTO_ISA_PATH}" +echo "[INFO] 并行度: ${JOBS}" +echo "[INFO] include: ${INCLUDE_DIRS[*]}" +if [[ ${ENABLE_DEFAULT_ARGS} -eq 1 ]]; then + echo "[INFO] 默认参数(来自 generate_testcase.py): ${DEFAULT_ARGS[*]}" +else + echo "[INFO] 默认参数: 已禁用 (--no-default-args)" +fi +if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then + echo "[INFO] 额外参数: ${EXTRA_ARGS[*]}" +fi +echo "[INFO] 文件总数: ${TOTAL_COUNT}" +echo + +running_jobs=0 +for src in "${CPP_FILES[@]}"; do + compile_one "${src}" & + running_jobs=$((running_jobs + 1)) + if [[ ${running_jobs} -ge ${JOBS} ]]; then + wait -n + running_jobs=$((running_jobs - 1)) + fi +done + +wait + +SUCCESS_COUNT="$(awk -F'\t' '$1=="OK"{c++} END{print c+0}' "${STATUS_FILE}")" +FAIL_COUNT="$(awk -F'\t' '$1=="FAIL"{c++} END{print c+0}' "${STATUS_FILE}")" + +declare -a FAILED_FILES=() +while IFS= read -r failed; do + [[ -n "${failed}" ]] && FAILED_FILES+=("${failed}") +done < <(awk -F'\t' '$1=="FAIL"{print $2}' "${STATUS_FILE}") + +END_TIME="$(date +%s)" +ELAPSED="$((END_TIME - START_TIME))" + +echo +echo "========== 编译汇总 ==========" +echo "总文件数 : ${TOTAL_COUNT}" +echo "成功数 : ${SUCCESS_COUNT}" +echo "失败数 : ${FAIL_COUNT}" +echo "耗时(秒) : ${ELAPSED}" + +if [[ ${FAIL_COUNT} -gt 0 ]]; then + failure_reason="" + echo + echo "失败文件列表:" + for f in "${FAILED_FILES[@]}"; do + echo " - ${f} (log: ${LOG_DIR}/${f%.cpp}.log)" + failure_reason="$(get_log_failure_reason "${LOG_DIR}/${f%.cpp}.log")" + if [[ -n "${failure_reason}" ]]; then + while IFS= read -r line; do + [[ -n "${line}" ]] || continue + echo " reason: ${line}" + done <<<"${failure_reason}" + fi + done + exit 1 +fi + +echo "[INFO] 全部编译成功" +exit 0 diff --git a/scripts/compile_pto_to_vpto_llvm.sh b/scripts/compile_pto_to_vpto_llvm.sh new file mode 100755 index 000000000..2d15c86b6 --- /dev/null +++ b/scripts/compile_pto_to_vpto_llvm.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PTO_FILE="${1:-}" +OUT_DIR_ARG="${2:-}" + +PTOAS_BIN="${PTOAS_BIN:-${ROOT_DIR}/build/tools/ptoas/ptoas}" +PTOAS_FLAGS="${PTOAS_FLAGS:---pto-arch a5}" +VPTO_FLAGS="${VPTO_FLAGS:---pto-backend=vpto --vpto-emit-hivm-llvm}" +AICORE_ARCH="${AICORE_ARCH:-dav-c310-vec}" +ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}" +BISHENG_BIN="" +BISHENG_FLAGS="${BISHENG_FLAGS:-}" +LLVM_IR="" +DEVICE_OBJ="" + +log() { + echo "[$(date +'%F %T')] $*" +} + +die() { + echo "ERROR: $*" >&2 + exit 1 +} + +on_error() { + local exit_code="$1" + if [[ -n "${LLVM_IR}" && -f "${LLVM_IR}" ]]; then + echo "Retained LLVM IR: ${LLVM_IR}" >&2 + fi + if [[ -n "${DEVICE_OBJ}" ]]; then + echo "Expected device object: ${DEVICE_OBJ}" >&2 + fi + exit "${exit_code}" +} + +trap 'on_error $?' ERR + +usage() { + cat < [output_dir] + +Environment overrides: + PTOAS_BIN path to ptoas + PTOAS_FLAGS default: --pto-arch a5 + VPTO_FLAGS default: --pto-backend=vpto --vpto-emit-hivm-llvm + ASCEND_HOME_PATH default: \$HOME/cann + BISHENG_BIN + BISHENG_FLAGS extra flags passed to bisheng when compiling .ll to .o + AICORE_ARCH default: dav-c310-vec + +Example: + $(basename "$0") test/samples/PyPTOIRParser/paged_attention_example_kernel_online_update.pto +EOF +} + +[[ -n "${PTO_FILE}" ]] || { + usage + exit 1 +} + +[[ "${PTO_FILE}" == *.pto ]] || die "input must be a .pto file: ${PTO_FILE}" +[[ -f "${PTO_FILE}" ]] || die "missing input file: ${PTO_FILE}" + +set +u +source "${ROOT_DIR}/scripts/ptoas_env.sh" +set -u + +if [[ -n "${ASCEND_HOME_PATH}" && -f "${ASCEND_HOME_PATH}/set_env.sh" ]]; then + set +u + source "${ASCEND_HOME_PATH}/set_env.sh" >/dev/null 2>&1 + set -u +fi + +BISHENG_BIN="${BISHENG_BIN:-${ASCEND_HOME_PATH}/bin/bisheng}" + +[[ -x "${PTOAS_BIN}" ]] || die "PTOAS_BIN is not executable: ${PTOAS_BIN}" +command -v "${BISHENG_BIN}" >/dev/null 2>&1 || die "bisheng not found: ${BISHENG_BIN}" + +pto_abs="$(cd "$(dirname "${PTO_FILE}")" && pwd)/$(basename "${PTO_FILE}")" +pto_base="$(basename "${PTO_FILE}" .pto)" + +if [[ -n "${OUT_DIR_ARG}" ]]; then + OUT_DIR="${OUT_DIR_ARG}" +else + OUT_DIR="${ROOT_DIR}/build/vpto_quick/${pto_base}" +fi + +mkdir -p "${OUT_DIR}" +OUT_DIR="$(cd "${OUT_DIR}" && pwd)" + +LLVM_IR="${OUT_DIR}/${pto_base}.ll" +DEVICE_OBJ="${OUT_DIR}/${pto_base}.o" + +log "step 1/2: lower PTO to VPTO LLVM IR" +"${PTOAS_BIN}" ${PTOAS_FLAGS} ${VPTO_FLAGS} \ + "${pto_abs}" \ + -o "${LLVM_IR}" + +log "step 2/2: compile LLVM IR to device object" +"${BISHENG_BIN}" \ + --target=hiipu64-hisilicon-cce \ + -march="${AICORE_ARCH}" \ + --cce-aicore-arch="${AICORE_ARCH}" \ + --cce-aicore-only \ + ${BISHENG_FLAGS} \ + -c -x ir "${LLVM_IR}" \ + -o "${DEVICE_OBJ}" + +log "done" +echo "LLVM IR: ${LLVM_IR}" +echo "Device object: ${DEVICE_OBJ}" diff --git a/scripts/ptoas_env.sh b/scripts/ptoas_env.sh new file mode 100644 index 000000000..95dcd9a8d --- /dev/null +++ b/scripts/ptoas_env.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +# PTOAS runtime environment bootstrap. +# Usage: +# source scripts/ptoas_env.sh +# +# Optional overrides before sourcing: +# export WORKSPACE_DIR=/path/to/workspace +# export LLVM_BUILD_DIR=/path/to/llvm-project/build-shared +# export PTO_SOURCE_DIR=/path/to/PTOAS +# export PTO_INSTALL_DIR=/path/to/PTOAS/install +# export PTO_PYTHON_BIN=/path/to/python3 + +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + echo "This script must be sourced: source scripts/ptoas_env.sh" + exit 1 +fi + +_PTOAS_ENV_SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +_PTOAS_REPO_DIR="$(cd -- "${_PTOAS_ENV_SCRIPT_DIR}/.." && pwd)" + +# Default layout: +# / +# ├── PTOAS/ +# └── llvm-project/ +export PTO_SOURCE_DIR="${PTO_SOURCE_DIR:-${_PTOAS_REPO_DIR}}" +export WORKSPACE_DIR="${WORKSPACE_DIR:-$(cd -- "${PTO_SOURCE_DIR}/.." && pwd)}" +export LLVM_SOURCE_DIR="${LLVM_SOURCE_DIR:-${WORKSPACE_DIR}/llvm-project}" +export LLVM_BUILD_DIR="${LLVM_BUILD_DIR:-${LLVM_SOURCE_DIR}/build-shared}" +export PTO_INSTALL_DIR="${PTO_INSTALL_DIR:-${PTO_SOURCE_DIR}/install}" +export PTO_ISA_PATH="${PTO_ISA_PATH:-${WORKSPACE_DIR}/pto-isa}" +export ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}" + +export MLIR_PYTHON_ROOT="${MLIR_PYTHON_ROOT:-${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core}" +export PTO_PYTHON_ROOT="${PTO_PYTHON_ROOT:-${PTO_INSTALL_DIR}}" +export PTO_PYTHON_BUILD_ROOT="${PTO_PYTHON_BUILD_ROOT:-${PTO_SOURCE_DIR}/build/python}" +export PYBIND11_CMAKE_DIR=$(python3 -m pybind11 --cmakedir) +export PTOAS_FLAGS="${PTOAS_FLAGS:-}" +export PTOAS_OUT_DIR=$PTO_SOURCE_DIR/build/output + +_ptoas_prepend_path() { + local var_name="$1" + local value="$2" + local current="${!var_name:-}" + if [[ -z "${value}" ]]; then + return 0 + fi + if [[ ! -e "${value}" ]]; then + return 0 + fi + if [[ ":${current}:" == *":${value}:"* ]]; then + return 0 + fi + if [[ -z "${current}" ]]; then + printf -v "${var_name}" '%s' "${value}" + else + printf -v "${var_name}" '%s:%s' "${value}" "${current}" + fi + export "${var_name}" +} + +_ptoas_prepend_path PYTHONPATH "${MLIR_PYTHON_ROOT}" +_ptoas_prepend_path PYTHONPATH "${PTO_PYTHON_ROOT}" +_ptoas_prepend_path PYTHONPATH "${PTO_PYTHON_BUILD_ROOT}" + +_ptoas_prepend_path LD_LIBRARY_PATH "${LLVM_BUILD_DIR}/lib" +_ptoas_prepend_path LD_LIBRARY_PATH "${PTO_INSTALL_DIR}/lib" +_ptoas_prepend_path LD_LIBRARY_PATH "${PTO_SOURCE_DIR}/build/lib" + +_ptoas_prepend_path PATH "${PTO_SOURCE_DIR}/build/tools/ptoas" + +if [[ -n "${PTO_PYTHON_BIN:-}" && -x "${PTO_PYTHON_BIN}" ]]; then + alias ptoas-python="${PTO_PYTHON_BIN}" +fi + +echo "[ptoas_env] PTO_SOURCE_DIR=${PTO_SOURCE_DIR}" +echo "[ptoas_env] LLVM_BUILD_DIR=${LLVM_BUILD_DIR}" +echo "[ptoas_env] PTO_INSTALL_DIR=${PTO_INSTALL_DIR}" +echo "[ptoas_env] PTO_ISA_PATH=${PTO_ISA_PATH}" +echo "[ptoas_env] ASCEND_HOME_PATH=${ASCEND_HOME_PATH}" +echo "[ptoas_env] PATH/PYTHONPATH/LD_LIBRARY_PATH updated" + +unset _PTOAS_ENV_SCRIPT_DIR +unset _PTOAS_REPO_DIR diff --git a/test/dsl/abs.py b/test/dsl/abs.py new file mode 100644 index 000000000..7c67e5959 --- /dev/null +++ b/test/dsl/abs.py @@ -0,0 +1,34 @@ +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="abs_kernel_2d") +def abs_kernel_2d(inp: pto.ptr(pto.f32, "gm"), out: pto.ptr(pto.f32, "gm")): + ub_in = pto.castptr(0, pto.ptr(pto.f32, "ub")) + ub_out = pto.castptr(4096, pto.ptr(pto.f32, "ub")) + + pto.set_loop_size_outtoub(1, 1) + pto.copy_gm_to_ubuf(inp, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + pto.set_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + pto.wait_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + + with pto.vecscope(): + remaining: pto.i32 = 1024 + for offset in range(0, 1024, 64): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(ub_in, offset) + vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, ub_out, offset, mask) + + pto.set_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + pto.wait_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + + pto.set_loop_size_ubtoout(1, 1) + pto.copy_ubuf_to_gm(ub_out, out, 0, 32, 128, 0, 128, 128) + pto.barrier("PIPE_ALL") + + return + + +if __name__ == "__main__": + print(abs_kernel_2d.mlir_text(), end="") diff --git a/test/dsl/strict_vecscope.py b/test/dsl/strict_vecscope.py new file mode 100644 index 000000000..e882df3d8 --- /dev/null +++ b/test/dsl/strict_vecscope.py @@ -0,0 +1,42 @@ +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="abs_strict_vecscope_kernel_2d") +def abs_strict_vecscope_kernel_2d( + inp: pto.ptr(pto.f32, "gm"), out: pto.ptr(pto.f32, "gm") +): + ub_in = pto.castptr(0, pto.ptr(pto.f32, "ub")) + ub_out = pto.castptr(4096, pto.ptr(pto.f32, "ub")) + + pto.set_loop_size_outtoub(1, 1) + pto.copy_gm_to_ubuf(inp, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + pto.set_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + pto.wait_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + + with pto.strict_vecscope(ub_in, ub_out, 0, 1024, 64, 1024) as ( + src, + dst, + lb, + ub, + step, + remaining, + ): + for offset in range(lb, ub, step): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(src, offset) + vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, dst, offset, mask) + + pto.set_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + pto.wait_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + + pto.set_loop_size_ubtoout(1, 1) + pto.copy_ubuf_to_gm(ub_out, out, 0, 32, 128, 0, 128, 128) + pto.barrier("PIPE_ALL") + + return + + +if __name__ == "__main__": + print(abs_strict_vecscope_kernel_2d.mlir_text(), end="") diff --git a/test/dsl/template_abs.py b/test/dsl/template_abs.py new file mode 100644 index 000000000..87b330e32 --- /dev/null +++ b/test/dsl/template_abs.py @@ -0,0 +1,48 @@ +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="template_abs_kernel") +def template_abs_kernel(src: pto.Tile, dst: pto.Tile): + total = src.shape[0] * src.shape[1] + step = 256 // src.ub_ptr.elem_bytes + + with pto.strict_vecscope(src.ub_ptr, dst.ub_ptr, 0, total, step, total) as ( + vin, + vout, + lb, + ub, + vec_step, + remaining, + ): + for offset in range(lb, ub, vec_step): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(vin, offset) + vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, vout, offset, mask) + + +template_abs_kernel_f32 = template_abs_kernel.jit( + src=pto.Tile( + ub_ptr=pto.ptr(pto.f32, "ub"), + shape=pto.const([32, 32]), + ), + dst=pto.Tile( + ub_ptr=pto.ptr(pto.f32, "ub"), + shape=pto.const([32, 32]), + ), +) + +template_abs_kernel_f16 = template_abs_kernel.jit( + src=pto.Tile( + ub_ptr=pto.ptr(pto.f16, "ub"), + shape=pto.const([32, 32]), + ), + dst=pto.Tile( + ub_ptr=pto.ptr(pto.f16, "ub"), + shape=pto.const([32, 32]), + ), +) + + +if __name__ == "__main__": + print(template_abs_kernel_f32.mlir_text(), end="") diff --git a/test/lit.cfg.py b/test/lit.cfg.py new file mode 100644 index 000000000..95e17569a --- /dev/null +++ b/test/lit.cfg.py @@ -0,0 +1,85 @@ +import os +import lit.formats + +config.name = "PTOAS" +config.test_format = lit.formats.ShTest(execute_external=True) + +# Keep discovery focused on lit-style tests. +config.suffixes = [".mlir", ".pto"] +config.excludes = [ + "CMakeLists.txt", + "README.md", + "lit.cfg.py", + "resources", +] + +config.test_source_root = os.path.dirname(__file__) + + +def _resolve_build_root(): + env_build_dir = os.environ.get("PTOAS_BUILD_DIR") + if env_build_dir: + return os.path.abspath(env_build_dir) + + repo_root = os.path.abspath(os.path.join(config.test_source_root, "..")) + return os.path.join(repo_root, "build") + + +build_root = _resolve_build_root() +config.test_exec_root = os.path.join(build_root, "test") +os.makedirs(config.test_exec_root, exist_ok=True) + + +def _resolve_llvm_bin_dir(): + env_build_dir = os.environ.get("LLVM_BUILD_DIR") + candidates = [] + if env_build_dir: + candidates.append(os.path.join(os.path.abspath(env_build_dir), "bin")) + + repo_root = os.path.abspath(os.path.join(config.test_source_root, "..")) + candidates.append( + os.path.abspath( + os.path.join(repo_root, "..", "llvm-project", "build-shared", "bin") + ) + ) + + for candidate in candidates: + if os.path.isdir(candidate): + return candidate + return "" + + +def _resolve_ptoas_bin(): + env_bin = os.environ.get("PTOAS_BIN") + if env_bin: + return env_bin + + candidate = os.path.join(build_root, "tools", "ptoas", "ptoas") + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + + return "ptoas" + + +def _prepend_path(path_var, entry): + if not entry: + return path_var + if not path_var: + return entry + return entry + os.pathsep + path_var + + +ptoas_bin = _resolve_ptoas_bin() +ptoas_dir = os.path.dirname(ptoas_bin) if os.path.isabs(ptoas_bin) else "" +llvm_bin_dir = _resolve_llvm_bin_dir() + +path_env = config.environment.get("PATH", os.environ.get("PATH", "")) +if llvm_bin_dir: + path_env = _prepend_path(path_env, llvm_bin_dir) +if ptoas_dir: + path_env = _prepend_path(path_env, ptoas_dir) +config.environment["PATH"] = path_env + +# Keep RUN lines using bare `ptoas` stable regardless of shell cwd. +if os.path.isabs(ptoas_bin): + config.substitutions.append(("ptoas", ptoas_bin)) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 309a94c96..f90b95f3a 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -7,9 +7,12 @@ // See LICENSE in the root of the software repository for the full text of the License. #include "PTO/IR/PTO.h" +#include "PTO/Transforms/VPTOLowering.h" +#include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -42,6 +45,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringMap.h" +#include "llvm/Support/MemoryBuffer.h" #include using namespace mlir; @@ -210,12 +214,87 @@ static llvm::cl::opt ptoBuildLevel( llvm::cl::value_desc("level1|level2|level3"), llvm::cl::init("level2")); +static llvm::cl::opt ptoBackend( + "pto-backend", + llvm::cl::desc("Final PTOAS backend: emitc or vpto (default: emitc)"), + llvm::cl::value_desc("emitc|vpto"), llvm::cl::init("emitc")); + +static llvm::cl::opt emitVPTO( + "emit-vpto", + llvm::cl::desc("Write final post-pass VPTO IR to -o"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoPrintIR( + "vpto-print-ir", + llvm::cl::desc("Print post-pass VPTO backend IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoLoweringStrategy( + "vpto-lowering-strategy", + llvm::cl::desc("VPTO vector lowering strategy: post-update or no-post-update"), + llvm::cl::value_desc("post-update|no-post-update"), + llvm::cl::init("post-update")); + +static llvm::cl::opt dumpVPTOIR( + "dump-vpto-ir", + llvm::cl::desc("Print post-pass VPTO backend IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt ptoPrintSeamIR( + "pto-print-seam-ir", + llvm::cl::desc("Print shared pre-backend seam IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt ptoSeamIRFile( + "pto-seam-ir-file", + llvm::cl::desc("Write shared pre-backend seam IR to a file"), + llvm::cl::value_desc("path"), + llvm::cl::init("")); + +static llvm::cl::opt vptoPrintIntrinsics( + "vpto-print-intrinsics", + llvm::cl::desc("Print VPTO intrinsic selection decisions to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoEmitHIVMOfficialLLVM( + "vpto-emit-hivm-llvm", + llvm::cl::desc("After lowering to VPTO IR, emit textual LLVM/HIVM via " + "the official LLVM dialect export path"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoEmitHIVMOfficialBitcode( + "vpto-emit-hivm-bc", + llvm::cl::desc("After lowering to VPTO IR, emit LLVM bitcode via the " + "official LLVM dialect export path"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoAllowUnresolved( + "vpto-allow-unresolved", + llvm::cl::desc("Emit explicit unresolved VPTO comments instead of failing"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoUnresolvedReport( + "vpto-unresolved-report", + llvm::cl::desc("Write unresolved VPTO mappings to a sidecar report"), + llvm::cl::value_desc("path"), llvm::cl::init("")); + +static llvm::cl::opt hivmUnresolvedReport( + "hivm-unresolved-report", + llvm::cl::desc("Write unresolved HIVM mappings to a sidecar report"), + llvm::cl::value_desc("path"), + llvm::cl::init("")); + enum class PTOBuildLevel { Level1, Level2, Level3, }; +enum class PTOBackend { + EmitC, + VPTO, +}; + static PTOBuildLevel defaultBuildLevel() { return PTOBuildLevel::Level2; } @@ -261,6 +340,94 @@ static bool parseAutoSyncTailHint(llvm::StringRef hintStr, std::string &normaliz return false; } +static bool parseBackend(llvm::StringRef backendStr, PTOBackend &out) { + std::string s = backendStr.str(); + for (char &c : s) + c = static_cast(std::tolower(static_cast(c))); + if (s == "emitc") { + out = PTOBackend::EmitC; + return true; + } + if (s == "vpto") { + out = PTOBackend::VPTO; + return true; + } + return false; +} + +static LogicalResult emitSharedPreBackendSeamIR(ModuleOp module, + llvm::StringRef outputPath) { + if (outputPath.empty()) + return success(); + + if (outputPath == "-") { + module->print(llvm::outs()); + llvm::outs() << "\n"; + llvm::outs().flush(); + return success(); + } + + std::error_code ec; + llvm::ToolOutputFile outputFile(outputPath, ec, llvm::sys::fs::OF_None); + if (ec) { + llvm::errs() << "Error: failed to open seam IR file '" << outputPath + << "': " << ec.message() << "\n"; + return failure(); + } + + module->print(outputFile.os()); + outputFile.os() << "\n"; + outputFile.keep(); + return success(); +} + +static bool containsVPTOOpPrefix(llvm::StringRef line, + llvm::StringRef opPrefix) { + size_t searchFrom = 0; + while (searchFrom < line.size()) { + size_t pos = line.find(opPrefix, searchFrom); + if (pos == llvm::StringRef::npos) + return false; + + if (pos == 0) + return true; + + unsigned char before = static_cast(line[pos - 1]); + if (std::isspace(before) || before == '(' || before == '=' || + before == ',') + return true; + + searchFrom = pos + 1; + } + return false; +} + +static bool containsVPTOIR(llvm::StringRef input) { + llvm::StringRef rest = input; + while (!rest.empty()) { + auto split = rest.split('\n'); + llvm::StringRef line = split.first.trim(); + if (!line.starts_with("//") && + (line.contains("!pto.vec<") || line.contains("!pto.mask") || + line.contains("!pto.align") || + containsVPTOOpPrefix(line, "pto.copy_") || + containsVPTOOpPrefix(line, "pto.set_loop") || + containsVPTOOpPrefix(line, "pto.v") || + containsVPTOOpPrefix(line, "pto.plt_") || + containsVPTOOpPrefix(line, "pto.pset_") || + containsVPTOOpPrefix(line, "pto.psts") || + containsVPTOOpPrefix(line, "pto.pdintlv_") || + containsVPTOOpPrefix(line, "pto.set_flag") || + containsVPTOOpPrefix(line, "pto.wait_flag") || + containsVPTOOpPrefix(line, "pto.pipe_barrier") || + containsVPTOOpPrefix(line, "pto.get_buf") || + containsVPTOOpPrefix(line, "pto.rls_buf"))) + return true; + rest = split.second; + } + return false; +} + // -------------------------------------------------------------------------- // Post-process C++ output: rewrite marker calls into Tile member calls. // @@ -894,6 +1061,85 @@ static bool shouldDeclareVariablesAtTop(ModuleOp module) { llvm::any_of(module.getOps(), hasMultiBlockFunc); } +static LogicalResult prepareVPTOForEmission(ModuleOp module) { + if (failed(convertVPTOEmissionBoundaryToPtr(module, &llvm::errs()))) { + llvm::errs() << "Error: VPTO emission boundary canonicalization failed.\n"; + return failure(); + } + + PassManager prepPM(module->getContext()); + prepPM.enableVerifier(); + prepPM.addNestedPass(createPTOVPTOExpandBridgeOpsPass()); + prepPM.addPass(createCSEPass()); + prepPM.addPass(pto::createPTOValidateVPTOEmissionIRPass()); + if (failed(prepPM.run(module))) { + llvm::errs() << "Error: VPTO emission preparation failed.\n"; + return failure(); + } + + return success(); +} + +static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { + PassManager backendPM(module.getContext()); + backendPM.addPass(pto::createLowerPTOToVPTOPass()); + backendPM.addPass(mlir::createCSEPass()); + if (failed(backendPM.run(module))) { + llvm::errs() << "Error: backend lowering pass execution failed.\n"; + return failure(); + } + return success(); +} + +static pto::VPTOEmissionOptions buildVPTOEmissionOptions() { + pto::VPTOEmissionOptions options; + options.dumpVPTOIR = false; + options.printIntrinsicSelections = vptoPrintIntrinsics; + options.allowUnresolved = vptoAllowUnresolved; + options.unresolvedReportPath = + !hivmUnresolvedReport.empty() ? hivmUnresolvedReport : vptoUnresolvedReport; + options.targetTriple = "hiipu64-hisilicon-cce"; + options.march = "dav-c310-vec"; + options.aicoreArch = "dav-c310-vec"; + options.defaultTargetCPU = "dav-c310-vec"; + options.defaultTargetFeatures = + "+ATOMIC,+ArchV130,+AregRedefinable,+ArithmeticBf16,+AtomicForB8 ," + "+F8e4m3,+F8e5m2,+F8e8m0,+FFTSBlk,+Fp4e1m2x2,+Fp4e2m1x2,+LDExtRefine," + "+MOVX8,+SPR7bits,+SyncV,+dav-c310-vec"; + return options; +} + +static int emitPreparedVPTOBackendResult(ModuleOp module, + llvm::ToolOutputFile &outputFile) { + if (emitVPTO || (!vptoEmitHIVMOfficialLLVM && !vptoEmitHIVMOfficialBitcode)) { + module.print(outputFile.os()); + outputFile.os() << "\n"; + outputFile.keep(); + return 0; + } + + pto::VPTOEmissionOptions options = buildVPTOEmissionOptions(); + LogicalResult emissionStatus = + vptoEmitHIVMOfficialBitcode + ? pto::translateVPTOModuleToLLVMBitcode(module, outputFile.os(), + options, llvm::errs()) + : pto::translateVPTOModuleToLLVMText(module, outputFile.os(), + options, llvm::errs()); + if (failed(emissionStatus)) { + llvm::errs() << "Error: Failed to emit VPTO text.\n"; + return 1; + } + outputFile.keep(); + return 0; +} + +static int emitVPTOBackendResult(ModuleOp module, + llvm::ToolOutputFile &outputFile) { + if (failed(prepareVPTOForEmission(module))) + return 1; + return emitPreparedVPTOBackendResult(module, outputFile); +} + int main(int argc, char **argv) { DialectRegistry registry; registry.insert(); @@ -933,6 +1179,36 @@ int main(int argc, char **argv) { mlir::registerPassManagerCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "PTO Assembler (ptoas)\n"); + PTOBackend effectiveBackend = PTOBackend::EmitC; + if (!parseBackend(ptoBackend, effectiveBackend)) { + llvm::errs() << "Error: invalid --pto-backend='" << ptoBackend + << "'. Expected 'emitc' or 'vpto'.\n"; + return 1; + } + + if (vptoEmitHIVMOfficialLLVM && vptoEmitHIVMOfficialBitcode) { + llvm::errs() << "Error: --vpto-emit-hivm-llvm and --vpto-emit-hivm-bc " + "cannot be used together.\n"; + return 1; + } + + if (emitVPTO && + (vptoEmitHIVMOfficialLLVM || vptoEmitHIVMOfficialBitcode)) { + llvm::errs() << "Error: --emit-vpto cannot be used together with HIVM " + "emission flags.\n"; + return 1; + } + + if (effectiveBackend != PTOBackend::VPTO && + (vptoEmitHIVMOfficialLLVM || vptoEmitHIVMOfficialBitcode || emitVPTO || + vptoPrintIntrinsics || vptoAllowUnresolved || + !vptoUnresolvedReport.empty() || !hivmUnresolvedReport.empty() || + ptoPrintSeamIR || !ptoSeamIRFile.empty())) { + llvm::errs() << "Error: VPTO-specific flags require " + "--pto-backend=vpto.\n"; + return 1; + } + // Read whole input first (so we can auto-detect .ptobc by magic). auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); if (!fileOrErr) { @@ -957,6 +1233,8 @@ int main(int argc, char **argv) { OwningOpRef module; llvm::StringRef buf = (*fileOrErr)->getBuffer(); const bool isPTOBC = (buf.size() >= 6 && std::memcmp(buf.data(), "PTOBC\0", 6) == 0); + const bool inputIsVPTOIR = containsVPTOIR(buf); + auto normalizeArch = [](llvm::StringRef archValue) { std::string normalized = archValue.str(); for (char &c : normalized) @@ -1092,6 +1370,24 @@ int main(int argc, char **argv) { return 1; } + // [Fix] ToolOutputFile Usage + std::error_code ec; + llvm::ToolOutputFile outputFile(outputFilename, ec, llvm::sys::fs::OF_None); + if (ec) { + llvm::errs() << ec.message() << "\n"; + return 1; + } + + if (effectiveBackend == PTOBackend::VPTO && inputIsVPTOIR) { + if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { + llvm::errs() << "Error: shared pre-backend seam IR is unavailable when " + "the input is already VPTO IR.\n"; + return 1; + } + + return emitVPTOBackendResult(*module, outputFile); + } + // Main PassManager PassManager pm(&context); @@ -1123,15 +1419,6 @@ int main(int argc, char **argv) { // Conditionally add Sync pass based on flag. if (enableInsertSync) pm.addNestedPass(pto::createPTOInsertSyncPass()); - - - // [Fix] ToolOutputFile Usage - std::error_code ec; - llvm::ToolOutputFile outputFile(outputFilename, ec, llvm::sys::fs::OF_None); - if (ec) { - llvm::errs() << ec.message() << "\n"; - return 1; - } if (emitMlirIR) { if (failed(pm.run(*module))) { @@ -1143,6 +1430,28 @@ int main(int argc, char **argv) { } pm.addPass(createCSEPass()); + + module->getOperation()->setAttr("pto.target_arch", + mlir::StringAttr::get(&context, arch)); + + if (effectiveBackend == PTOBackend::VPTO) { + if (failed(pm.run(*module))) { + llvm::errs() << "Error: Pass execution failed.\n"; + return 1; + } + + if (ptoPrintSeamIR) { + module->print(llvm::errs()); + llvm::errs() << "\n"; + } + if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) + return 1; + + if (failed(lowerPTOToVPTOBackend(*module))) + return 1; + return emitVPTOBackendResult(*module, outputFile); + } + if (arch == "a3") { pm.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A3)); } else { From cd451b43ca105d3129c0f1591bb27e4a16d8d144 Mon Sep 17 00:00:00 2001 From: WenboCodes Date: Fri, 10 Apr 2026 16:48:53 +0800 Subject: [PATCH 002/192] clarify block query docs and trim conversion section Explain block/subblock runtime queries in workload-partitioning terms and remove redundant supported-forms wording from conversion ops docs. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/isa/09-conversion-ops.md | 6 ----- docs/vpto-spec.md | 42 ++++++++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md index c3674f523..2099474ca 100644 --- a/docs/isa/09-conversion-ops.md +++ b/docs/isa/09-conversion-ops.md @@ -121,12 +121,6 @@ as `32 -> 16` or `16 -> 32` style conversions. - Use for width-changing conversions that select the even or odd half of the destination packing layout. -### A5 Supported Forms - -The forms below are expressed in PTO surface syntax. Source/target type -combinations not listed here should not currently be assumed to be supported -on A5. - #### Float To Int - `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index efb383feb..c20942b75 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -310,17 +310,41 @@ PTO micro Instruction source programs are not restricted to `pto` operations alo - `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. - Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. -### Runtime Query Operations +### BlockDim Query Operations -PTO micro Instruction also provides scalar runtime-query ops for inspecting the -current execution instance. These ops are pure, have no side effects, and may -be used in ordinary scalar control-flow or address computation. +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. #### `pto.get_block_idx` - **syntax:** `%block = pto.get_block_idx` - **result:** `i64` -- **semantics:** Return the current block ID. +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. ```c block = block_idx(); @@ -330,7 +354,7 @@ block = block_idx(); - **syntax:** `%subblock = pto.get_subblock_idx` - **result:** `i64` -- **semantics:** Return the current vector subblock ID. +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. ```c subblock = subblock_idx(); @@ -340,8 +364,7 @@ subblock = subblock_idx(); - **syntax:** `%block_num = pto.get_block_num` - **result:** `i64` -- **semantics:** Return the total number of launched blocks visible to the - current kernel instance. +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. ```c block_num = block_num(); @@ -351,8 +374,7 @@ block_num = block_num(); - **syntax:** `%subblock_num = pto.get_subblock_num` - **result:** `i64` -- **semantics:** Return the number of vector subblocks visible to the current - execution instance. +- **semantics:** Return the total number of visible subblocks for the current execution instance. ```c subblock_num = subblock_num(); From a34c04ca59674c8a1183503faff2a61d695fbbf4 Mon Sep 17 00:00:00 2001 From: Lok Date: Sat, 11 Apr 2026 15:37:01 +0800 Subject: [PATCH 003/192] docs(isa): clarify get_buf/rls_buf usage and mode parameter - Add detailed mode parameter documentation (mode=0 vs mode=1) - Add 'Why get_buf/rls_buf is More Programmer-Friendly' section: - No manual priming/draining for ping/pong loops - No loop peeling for complex/nested loop dependencies - Simpler mental model (buffer ID + program order) - Add quick example comparison showing set_flag overhead vs get_buf simplicity - Update Example 2 and 3b with explicit mode=0 in code - Update comparison table with 'Loop peeling' row --- docs/isa/01-pipeline-sync.md | 153 ++++++++++++++++++++++++++++++----- 1 file changed, 134 insertions(+), 19 deletions(-) diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index 3040ec3d1..752b0ea85 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -90,6 +90,29 @@ rls_buf(pipe, buf_id, mode); --- +### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + ### `pto.mem_bar` - **syntax:** `pto.mem_bar "BARRIER_TYPE"` @@ -116,6 +139,97 @@ pto.mem_bar "VST_VLD" --- +## Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +### 2. No Loop Peeling for Complex Dependencies + +For nested loops or non-1:1 producer-consumer ratios (e.g., 1 producer : N consumers, or complex control flow), `set_flag`/`wait_flag` requires manual **loop peeling** to handle boundary conditions: + +```mlir +// set_flag/wait_flag: must peel first/last iterations for correct signal counts +// Iteration 0: special case (no previous consumer) +// Iteration N-1: special case (no next producer) +// Each boundary needs explicit handling +``` + +With `get_buf`/`rls_buf`, the acquire/release protocol handles all boundaries automatically: + +```mlir +// get_buf/rls_buf: uniform loop body, no peeling +scf.for %i = %c0 to %N step %c1 { + pto.get_buf %bufid, "PIPE_X" // blocks if buffer in use + // ... work ... + pto.rls_buf %bufid, "PIPE_X" // signals completion +} +// Works correctly for any loop structure or dependency pattern +``` + +### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | 8 IDs for 2-buffer ping/pong (grows with buffers) | Just buffer IDs (shared global pool) | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 8 set_flag + 8 wait_flag inside loop + // Must track {IN,OUT} × {FWD,REV} × {0,1} = 8 event IDs +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + ## Intra-Core Sync Patterns & Examples ### Example 1: `set_flag` / `wait_flag` (Explicit Events) @@ -161,16 +275,16 @@ Instead of naming events, each pipeline declares when it **acquires** (`get_buf` ```mlir // ─── Stage 1: MTE2 loads data into UB ─── // MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration -pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... // MTE2 done writing ub_ptr — release it so Vector can consume -pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // ─── Stage 2: Vector computation ─── // Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) -pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 // Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) -pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 pto.vecscope { %mask = pto.pset_b32 "PAT_ALL" : !pto.mask @@ -180,16 +294,16 @@ pto.vecscope { } // Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration -pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 // Vector done writing ub_out — release so MTE3 can consume -pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 // ─── Stage 3: MTE3 stores result to GM ─── // MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) -pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 pto.copy_ubuf_to_gm %ub_out, %gm_out, ... // MTE3 done reading ub_out — release so Vector can reuse it in next iteration -pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 ``` **Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). @@ -293,14 +407,14 @@ scf.for %i = %c0 to %N step %c1 { // ── MTE2: load tile[i] into buf[i%2] ── // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). - pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... - pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // ── Vector: compute on buf[i%2] ── // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) - pto.get_buf %bufid_buf[%pp], "PIPE_V" - pto.get_buf %bufid_out[%pp], "PIPE_V" + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 scf.for %dummy = %c0 to %c1 step %c1 { %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> %mask = pto.pset_b32 "PAT_ALL" : !pto.mask @@ -308,14 +422,14 @@ scf.for %i = %c0 to %N step %c1 { pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } {llvm.loop.aivector_scope} // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) - pto.rls_buf %bufid_buf[%pp], "PIPE_V" - pto.rls_buf %bufid_out[%pp], "PIPE_V" + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 // ── MTE3: store result ── // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) - pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... - pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 } // No post-loop drain needed — last rls_buf completes the pipeline. ``` @@ -335,10 +449,11 @@ scf.for %i = %c0 to %N step %c1 { | IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | | Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | | Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | -| Pre-loop setup | `set_flag` to prime each reverse dep | None | -| Post-loop teardown | `wait_flag` to drain all primed signals | None | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | | Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | -| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | | Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | --- From 46f1a8ab2783d42897c4984cd2f7afd511912771 Mon Sep 17 00:00:00 2001 From: Lok Date: Sat, 11 Apr 2026 15:47:38 +0800 Subject: [PATCH 004/192] fix: correct event ID explanation in comparison table - set_flag/wait_flag: 2 IDs per buffer (1 forward + 1 reverse pipe-pair) - get_buf/rls_buf: 1 ID per buffer (handles both directions automatically) - 8 per pipe-pair is HW limit, not a formula --- docs/isa/01-pipeline-sync.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index 752b0ea85..e4877f108 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -446,8 +446,8 @@ scf.for %i = %c0 to %N step %c1 { | Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | |--------|--------------------------|------------------------| | Dependency model | Explicit event signals | Implicit via buffer acquire/release | -| IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | -| Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | | Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | | Pre-loop setup | `set_flag` to prime each reverse dep | **None** | | Post-loop teardown | `wait_flag` to drain all primed signals | **None** | From 3df56a7c5d4483500e6e7faed3903b24e2e7c057 Mon Sep 17 00:00:00 2001 From: Lok Date: Sat, 11 Apr 2026 15:55:15 +0800 Subject: [PATCH 005/192] fix: clarify event ID management in comparison table - set_flag/wait_flag: 8 IDs per pipe-pair direction (HW limit) - get_buf/rls_buf: 1 buffer ID per shared resource (HW limit: 32 global), same ID used across all pipelines --- docs/isa/01-pipeline-sync.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index e4877f108..ee3032ee2 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -181,7 +181,7 @@ scf.for %i = %c0 to %N step %c1 { | Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | |--------|------------------------|---------------------| | **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | -| **Event ID management** | 8 IDs for 2-buffer ping/pong (grows with buffers) | Just buffer IDs (shared global pool) | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); 2-buffer ping/pong uses 4 IDs per pipe-pair (2 fwd + 2 rev) | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | | **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | ### Quick Example Comparison From 31540f2dcd140b4f00c9a9d3d41ef716e1c60f61 Mon Sep 17 00:00:00 2001 From: Lok Date: Sat, 11 Apr 2026 15:59:55 +0800 Subject: [PATCH 006/192] fix: simplify event ID explanation and drain example - Event ID mgmt: each buffer occupies 1 ID per direction (removed misleading 4 IDs calc) - Drain example: use concrete EVT_*_0/EVT_*_1 instead of {(N-1)%2} expressions --- docs/isa/01-pipeline-sync.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index ee3032ee2..ba9ae4d4e 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -181,7 +181,7 @@ scf.for %i = %c0 to %N step %c1 { | Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | |--------|------------------------|---------------------| | **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | -| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); 2-buffer ping/pong uses 4 IDs per pipe-pair (2 fwd + 2 rev) | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | | **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | ### Quick Example Comparison @@ -202,10 +202,10 @@ scf.for %i = ... { } // AFTER loop: drain 4 signals -pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] -pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] -pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] -pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] ``` **get_buf/rls_buf approach:** From 440925c7307e71d0b66f58b51f92727bec9e78bb Mon Sep 17 00:00:00 2001 From: Lok Date: Sat, 11 Apr 2026 16:16:52 +0800 Subject: [PATCH 007/192] fix: correct set_flag/wait_flag count in quick example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 4 set_flag + 4 wait_flag (not 8) - 4 IDs = 2 pipe-pair directions × 2 ping/pong buffers --- docs/isa/01-pipeline-sync.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index ba9ae4d4e..818105be2 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -197,8 +197,8 @@ pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] scf.for %i = ... { - // 8 set_flag + 8 wait_flag inside loop - // Must track {IN,OUT} × {FWD,REV} × {0,1} = 8 event IDs + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers } // AFTER loop: drain 4 signals From a7e54516a1e6ffcc146692de064aa2b518db0cca Mon Sep 17 00:00:00 2001 From: Lok Date: Sat, 11 Apr 2026 16:20:13 +0800 Subject: [PATCH 008/192] fix: add concrete 1:N example for loop peeling comparison MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - set_flag/wait_flag: 1 MTE2 load, 8 Vector slices — must peel set/wait outside loop - get_buf/rls_buf: same pattern but acquire/release can stay inside or outside --- docs/isa/01-pipeline-sync.md | 37 ++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index 818105be2..268c8e105 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -155,25 +155,38 @@ With `get_buf`/`rls_buf`: ### 2. No Loop Peeling for Complex Dependencies -For nested loops or non-1:1 producer-consumer ratios (e.g., 1 producer : N consumers, or complex control flow), `set_flag`/`wait_flag` requires manual **loop peeling** to handle boundary conditions: +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: ```mlir -// set_flag/wait_flag: must peel first/last iterations for correct signal counts -// Iteration 0: special case (no previous consumer) -// Iteration N-1: special case (no next producer) -// Each boundary needs explicit handling +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} ``` -With `get_buf`/`rls_buf`, the acquire/release protocol handles all boundaries automatically: +With `get_buf`/`rls_buf`, acquire/release can stay **inside or outside** the loop — both work: ```mlir -// get_buf/rls_buf: uniform loop body, no peeling -scf.for %i = %c0 to %N step %c1 { - pto.get_buf %bufid, "PIPE_X" // blocks if buffer in use - // ... work ... - pto.rls_buf %bufid, "PIPE_X" // signals completion +// get_buf/rls_buf: same 1:8 pattern, but more flexible +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector can acquire once outside, or acquire/release per slice — both correct +pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // acquire once +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] } -// Works correctly for any loop structure or dependency pattern +pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // release once +// No peeling required — get_buf blocks until MTE2's rls_buf completes ``` ### 3. Simpler Mental Model From 4d5dac841671649d410d6cb480839a4a2f4372af Mon Sep 17 00:00:00 2001 From: Lok Date: Sat, 11 Apr 2026 16:24:42 +0800 Subject: [PATCH 009/192] fix: show get_buf/rls_buf inside scf loop for 1:N example - Acquire/release per slice inside loop - Iteration 0 blocks until MTE2 done, iterations 1-7 proceed immediately --- docs/isa/01-pipeline-sync.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index 268c8e105..c7b32b677 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -171,22 +171,23 @@ scf.for %slice = %c0 to %c8 step %c1 { } ``` -With `get_buf`/`rls_buf`, acquire/release can stay **inside or outside** the loop — both work: +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: ```mlir -// get_buf/rls_buf: same 1:8 pattern, but more flexible +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine // MTE2 loads large tile pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 -// Vector can acquire once outside, or acquire/release per slice — both correct -pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // acquire once +// Vector acquires/releases per slice — all 8 iterations work correctly scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 } -pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // release once -// No peeling required — get_buf blocks until MTE2's rls_buf completes +// No peeling required — get_buf handles the MTE2→V dependency automatically ``` ### 3. Simpler Mental Model From dcd8586a5538c9542991825c455a9c270e0c0019 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sun, 12 Apr 2026 06:06:14 +0800 Subject: [PATCH 010/192] refactor vpto llvm emiter --- docs/isa/09-conversion-ops.md | 15 +- docs/release/vpto-spec-v0.1.md | 16 +- docs/release/vpto-spec-v0.2.md | 16 +- docs/tilelang-dsl-guide.md | 10 +- include/PTO/IR/VPTOOps.td | 24 +- .../PTO/Transforms/VPTOLLVMEmitterHelper.h | 6 + lib/PTO/IR/VPTO.cpp | 113 +- lib/PTO/Transforms/CMakeLists.txt | 1 + lib/PTO/Transforms/PTOToVPTOLowering.cpp | 4 +- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 8308 +++++++++-------- lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp | 684 ++ 11 files changed, 5098 insertions(+), 4099 deletions(-) create mode 100644 include/PTO/Transforms/VPTOLLVMEmitterHelper.h create mode 100644 lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md index 2099474ca..efb3a9ed4 100644 --- a/docs/isa/09-conversion-ops.md +++ b/docs/isa/09-conversion-ops.md @@ -203,7 +203,7 @@ For conversions that change width (e.g., f32→f16), use even/odd parts and comb ## `pto.vtrc` -- **syntax:** `%result = pto.vtrc %input, "RND" : !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` - **semantics:** Truncate/round float to integer-valued float (stays in float type). ```c @@ -212,19 +212,20 @@ for (int i = 0; i < N; i++) ``` - **inputs:** - `%input` is the floating-point source vector and `RND` selects the - truncation/rounding rule. + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. - **outputs:** `%result` is still a floating-point vector, but each active lane now carries an integer-valued floating-point result. - **constraints and limitations:** - This op does not change the element type. `O` is supported for avoiding - double-rounding errors during staged conversions. + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. **Example:** ```mlir // Round to nearest integer, keep as float -%rounded = pto.vtrc %input, "R" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // input: [1.4, 2.6, -1.5, 3.0] // output: [1.0, 3.0, -2.0, 3.0] ``` @@ -245,7 +246,7 @@ for (int i = 0; i < N; i++) : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> // Floor for integer division -%floored = pto.vtrc %ratio, "F" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %int_div = pto.vcvt %floored, %mask {rnd = "Z"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> ``` diff --git a/docs/release/vpto-spec-v0.1.md b/docs/release/vpto-spec-v0.1.md index ab7c5f6b4..a2949a485 100644 --- a/docs/release/vpto-spec-v0.1.md +++ b/docs/release/vpto-spec-v0.1.md @@ -3470,7 +3470,7 @@ For conversions that change width (e.g., f32→f16), use even/odd parts and comb #### `pto.vtrc` -- **syntax:** `%result = pto.vtrc %input, "ROUND_MODE" : !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vtrc %input, %mask, "ROUND_MODE" : !pto.vreg, !pto.mask -> !pto.vreg` - **semantics:** Truncate/round float to integer-valued float (stays in float type). ```c @@ -3479,19 +3479,21 @@ for (int i = 0; i < N; i++) ``` - **inputs:** - `%input` is the floating-point source vector and `ROUND_MODE` selects the - truncation/rounding rule. + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `ROUND_MODE` selects the truncation/rounding rule. - **outputs:** `%result` is still a floating-point vector, but each active lane now carries an integer-valued floating-point result. - **constraints and limitations:** - This op does not change the element type. `ROUND_O` is supported for avoiding - double-rounding errors during staged conversions. + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `ROUND_MODE` must be one of `ROUND_R`, `ROUND_A`, `ROUND_F`, + `ROUND_C`, or `ROUND_Z`. `BW` must match the element width: `b16` for + `f16`/`bf16`, `b32` for `f32`. **Example:** ```mlir // Round to nearest integer, keep as float -%rounded = pto.vtrc %input, "ROUND_R" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%rounded = pto.vtrc %input, %mask, "ROUND_R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // input: [1.4, 2.6, -1.5, 3.0] // output: [1.0, 3.0, -2.0, 3.0] ``` @@ -3512,7 +3514,7 @@ for (int i = 0; i < N; i++) : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> // Floor for integer division -%floored = pto.vtrc %ratio, "ROUND_F" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%floored = pto.vtrc %ratio, %mask, "ROUND_F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> ``` diff --git a/docs/release/vpto-spec-v0.2.md b/docs/release/vpto-spec-v0.2.md index 3c1e31419..90b632a14 100644 --- a/docs/release/vpto-spec-v0.2.md +++ b/docs/release/vpto-spec-v0.2.md @@ -3675,7 +3675,7 @@ For conversions that change width (e.g., f32→f16), use even/odd parts and comb #### `pto.vtrc` -- **syntax:** `%result = pto.vtrc %input, "ROUND_MODE" : !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vtrc %input, %mask, "ROUND_MODE" : !pto.vreg, !pto.mask -> !pto.vreg` - **semantics:** Truncate/round float to integer-valued float (stays in float type). ```c @@ -3684,19 +3684,21 @@ for (int i = 0; i < N; i++) ``` - **inputs:** - `%input` is the floating-point source vector and `ROUND_MODE` selects the - truncation/rounding rule. + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `ROUND_MODE` selects the truncation/rounding rule. - **outputs:** `%result` is still a floating-point vector, but each active lane now carries an integer-valued floating-point result. - **constraints and limitations:** - This op does not change the element type. `ROUND_O` is supported for avoiding - double-rounding errors during staged conversions. + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `ROUND_MODE` must be one of `ROUND_R`, `ROUND_A`, `ROUND_F`, + `ROUND_C`, or `ROUND_Z`. `BW` must match the element width: `b16` for + `f16`/`bf16`, `b32` for `f32`. **Example:** ```mlir // Round to nearest integer, keep as float -%rounded = pto.vtrc %input, "ROUND_R" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%rounded = pto.vtrc %input, %mask, "ROUND_R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // input: [1.4, 2.6, -1.5, 3.0] // output: [1.0, 3.0, -2.0, 3.0] ``` @@ -3717,7 +3719,7 @@ for (int i = 0; i < N; i++) : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> // Floor for integer division -%floored = pto.vtrc %ratio, "ROUND_F" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%floored = pto.vtrc %ratio, %mask, "ROUND_F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> ``` diff --git a/docs/tilelang-dsl-guide.md b/docs/tilelang-dsl-guide.md index 2dc224e50..dfdd41263 100644 --- a/docs/tilelang-dsl-guide.md +++ b/docs/tilelang-dsl-guide.md @@ -2494,20 +2494,22 @@ Operations for rearranging data within vectors. Type conversion and specialized operations. -#### `pto.vtrc(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vtrc(vec: VRegType, mask: MaskType, rnd: str) -> VRegType` -**Description**: Truncate vector elements. +**Description**: Truncate/round floating-point vector elements to integer-valued +floating-point results under an explicit predicate mask. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | +| `mask` | `MaskType` | Predicate mask; granularity must match element width | +| `rnd` | `str` | Round mode: `R`, `A`, `F`, `C`, or `Z` | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `VRegType` | Truncated vector | +| `result` | `VRegType` | Rounded result with the same floating-point element type | #### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType) -> VRegType` diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 2fbad0291..dc416e868 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -581,8 +581,6 @@ def PTO_VexpOp : PTO_UnaryVecOp<"vexp">; def PTO_VlnOp : PTO_UnaryVecOp<"vln">; def PTO_VsqrtOp : PTO_UnaryVecOp<"vsqrt">; def PTO_VnegOp : PTO_UnaryVecOp<"vneg">; -def PTO_VrsqrtOp : PTO_UnaryVecOp<"vrsqrt">; -def PTO_VrecOp : PTO_UnaryVecOp<"vrec">; def PTO_VreluOp : PTO_UnaryVecOp<"vrelu">; def PTO_VnotOp : PTO_UnaryVecOp<"vnot">; def PTO_VcaddOp : PTO_UnaryVecOp<"vcadd">; @@ -610,8 +608,6 @@ class PTO_BinaryVecOp : PTO_Op { def PTO_VaddOp : PTO_BinaryVecOp<"vadd">; def PTO_VsubOp : PTO_BinaryVecOp<"vsub">; -def PTO_VsaddOp : PTO_BinaryVecOp<"vsadd">; -def PTO_VssubOp : PTO_BinaryVecOp<"vssub">; def PTO_VmulOp : PTO_BinaryVecOp<"vmul">; def PTO_VdivOp : PTO_BinaryVecOp<"vdiv">; def PTO_VmaxOp : PTO_BinaryVecOp<"vmax">; @@ -694,9 +690,6 @@ def PTO_VsubcsOp : PTO_Op<"vsubcs", [Pure]> { }]; } -def PTO_VbcntOp : PTO_UnaryVecOp<"vbcnt">; -def PTO_VclsOp : PTO_UnaryVecOp<"vcls">; - def PTO_VshlOp : PTO_BinaryVecOp<"vshl">; def PTO_VshrOp : PTO_BinaryVecOp<"vshr">; @@ -811,6 +804,7 @@ class PTO_VecScalarMaskedOp : PTO_Op { def PTO_VtrcOp : PTO_Op<"vtrc", [Pure]> { let arguments = (ins PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask, StrAttr:$round_mode ); let results = (outs PTO_VectorType:$result); @@ -943,7 +937,6 @@ def PTO_Vgather2BcOp : PTO_Op<"vgather2_bc", [ def PTO_VmulsOp : PTO_VecScalarMaskedOp<"vmuls">; def PTO_VaddsOp : PTO_VecScalarMaskedOp<"vadds">; -def PTO_VsaddsOp : PTO_VecScalarMaskedOp<"vsadds">; def PTO_VmaxsOp : PTO_VecScalarMaskedOp<"vmaxs">; def PTO_VminsOp : PTO_VecScalarMaskedOp<"vmins">; def PTO_VlreluOp : PTO_VecScalarMaskedOp<"vlrelu">; @@ -1071,21 +1064,6 @@ def PTO_VselrOp : PTO_Op<"vselr", [Pure]> { }]; } -def PTO_VslideOp : PTO_Op<"vslide", [Pure]> { - let arguments = (ins - PTO_VectorType:$src0, - PTO_VectorType:$src1, - I16:$amt - ); - let results = (outs PTO_VectorType:$result); - - let hasVerifier = 1; - - let assemblyFormat = [{ - $src0 `,` $src1 `,` $amt attr-dict `:` type($src0) `,` type($src1) `,` type($amt) `->` type($result) - }]; -} - def PTO_VsqzOp : PTO_UnaryVecOp<"vsqz">; def PTO_VusqzOp : PTO_Op<"vusqz", [Pure]> { diff --git a/include/PTO/Transforms/VPTOLLVMEmitterHelper.h b/include/PTO/Transforms/VPTOLLVMEmitterHelper.h new file mode 100644 index 000000000..555bbe274 --- /dev/null +++ b/include/PTO/Transforms/VPTOLLVMEmitterHelper.h @@ -0,0 +1,6 @@ +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H + +#include "PTO/Transforms/VPTOLLVMEmitter.h" + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index aa37efb79..02351975d 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -140,6 +140,11 @@ static std::optional getVdupMaskGranularity(Type elementType) { return std::nullopt; } +static bool isSupportedVtrcRoundMode(StringRef mode) { + return mode == "R" || mode == "A" || mode == "F" || mode == "C" || + mode == "Z"; +} + static bool isStoreAlignProducer(Operation *op) { return isa(op); } @@ -1769,17 +1774,6 @@ static LogicalResult verifyVecScalarOpLike(OpTy op) { return success(); } -template -static LogicalResult verifySignedSaturatingVecScalarOpLike(OpTy op) { - if (failed(verifyElementwiseVecScalarOpLike(op))) - return failure(); - auto inputType = cast(op.getInput().getType()); - auto elemType = dyn_cast(inputType.getElementType()); - if (!elemType || elemType.isUnsigned() || elemType.getWidth() != 16) - return op.emitOpError("requires s16 vector element type"); - return success(); -} - template static LogicalResult verifyVecScalarMaskedOpLike(OpTy op) { if (failed(verifyElementwiseVecScalarOpLike(op))) @@ -1821,11 +1815,6 @@ static LogicalResult verifyCarryVecOpWithInput(CarryWithInputOp op) { LogicalResult VmulsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } LogicalResult VaddsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } -LogicalResult VsaddsOp::verify() { - if (failed(verifySignedSaturatingVecScalarOpLike(*this))) - return failure(); - return verifyMaskTypeLike(*this, getMask().getType(), "mask type"); -} LogicalResult VmaxsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } LogicalResult VminsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } LogicalResult VlreluOp::verify() { return verifyVecScalarMaskedOpLike(*this); } @@ -1891,34 +1880,8 @@ LogicalResult VexpOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VlnOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VsqrtOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VnegOp::verify() { return verifyUnaryVecOp(*this); } -LogicalResult VrsqrtOp::verify() { - if (failed(verifyUnaryVecOp(*this))) - return failure(); - auto inputType = cast(getInput().getType()); - Type elemType = inputType.getElementType(); - if (!elemType.isF16() && !elemType.isF32()) - return emitOpError("requires f16 or f32 vector element type"); - return success(); -} -LogicalResult VrecOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VreluOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VnotOp::verify() { return verifyUnaryVecOp(*this); } -LogicalResult VbcntOp::verify() { - if (failed(verifyUnaryVecOp(*this))) - return failure(); - auto inputType = cast(getInput().getType()); - if (!isa(inputType.getElementType())) - return emitOpError("requires integer vector element type"); - return success(); -} -LogicalResult VclsOp::verify() { - if (failed(verifyUnaryVecOp(*this))) - return failure(); - auto inputType = cast(getInput().getType()); - if (!isa(inputType.getElementType())) - return emitOpError("requires integer vector element type"); - return success(); -} template static LogicalResult verifyBinaryVecOp(BinaryOp op) { @@ -1936,25 +1899,8 @@ static LogicalResult verifyBinaryVecOp(BinaryOp op) { return success(); } -template -static LogicalResult verifySignedSaturatingBinaryVecOp(BinaryOp op) { - if (failed(verifyBinaryVecOp(op))) - return failure(); - auto lhsType = cast(op.getLhs().getType()); - auto elemType = dyn_cast(lhsType.getElementType()); - if (!elemType || elemType.isUnsigned() || elemType.getWidth() != 16) - return op.emitOpError("requires s16 vector element type"); - return success(); -} - LogicalResult VaddOp::verify() { return verifyBinaryVecOp(*this); } LogicalResult VsubOp::verify() { return verifyBinaryVecOp(*this); } -LogicalResult VsaddOp::verify() { - return verifySignedSaturatingBinaryVecOp(*this); -} -LogicalResult VssubOp::verify() { - return verifySignedSaturatingBinaryVecOp(*this); -} LogicalResult VmulOp::verify() { return verifyBinaryVecOp(*this); } LogicalResult VdivOp::verify() { return verifyBinaryVecOp(*this); } LogicalResult VandOp::verify() { return verifyBinaryVecOp(*this); } @@ -2081,17 +2027,6 @@ LogicalResult VselOp::verify() { LogicalResult VselrOp::verify() { return verifyLaneSelectOp(*this); } LogicalResult Vselrv2Op::verify() { return verifyLaneSelectOp(*this); } -LogicalResult VslideOp::verify() { - if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || - failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || - failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) - return failure(); - if (getSrc0().getType() != getSrc1().getType() || - getSrc0().getType() != getResult().getType()) - return emitOpError("requires src0, src1, and result to share one vector type"); - return success(); -} - LogicalResult VsqzOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VusqzOp::verify() { @@ -2203,39 +2138,44 @@ LogicalResult VcmpsOp::verify() { ParseResult VtrcOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand input; + OpAsmParser::UnresolvedOperand mask; std::string roundModeToken; NamedAttrList attrs; - Type inputType, resultType; + Type inputType, maskType, resultType; if (parser.parseOperand(input) || parser.parseComma() || + parser.parseOperand(mask) || parser.parseComma() || parser.parseKeywordOrString(&roundModeToken) || parser.parseOptionalAttrDict(attrs) || - parser.parseColonType(inputType) || parser.parseArrow() || + parser.parseColonType(inputType) || parser.parseComma() || + parser.parseType(maskType) || parser.parseArrow() || parser.parseType(resultType)) return failure(); auto normalized = normalizeRoundModeToken(roundModeToken); - if (!normalized) + if (!normalized || !isSupportedVtrcRoundMode(*normalized)) return parser.emitError(parser.getCurrentLocation()) - << "round mode must be one of R/A/F/C/Z/O or " - "ROUND_R/ROUND_A/ROUND_F/ROUND_C/ROUND_Z/ROUND_O"; + << "round mode must be one of R/A/F/C/Z or " + "ROUND_R/ROUND_A/ROUND_F/ROUND_C/ROUND_Z"; attrs.set("round_mode", parser.getBuilder().getStringAttr(*normalized)); result.addAttributes(attrs); - if (parser.resolveOperand(input, inputType, result.operands)) + if (parser.resolveOperand(input, inputType, result.operands) || + parser.resolveOperand(mask, maskType, result.operands)) return failure(); result.addTypes(resultType); return success(); } void VtrcOp::print(OpAsmPrinter &printer) { - printer << ' ' << getInput() << ", "; + printer << ' ' << getInput() << ", " << getMask() << ", "; Builder builder(getContext()); auto normalized = normalizeRoundModeToken(getRoundMode()); printer.printAttributeWithoutType( builder.getStringAttr(normalized.value_or(getRoundMode()))); printer.printOptionalAttrDict((*this)->getAttrs(), {"round_mode"}); - printer << " : " << getInput().getType() << " -> " << getResult().getType(); + printer << " : " << getInput().getType() << ", " << getMask().getType() + << " -> " << getResult().getType(); } LogicalResult VtrcOp::verify() { @@ -2243,10 +2183,23 @@ LogicalResult VtrcOp::verify() { auto resultType = dyn_cast(getResult().getType()); if (!inputType || !resultType) return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); if (inputType != resultType) return emitOpError("requires input and result to have identical vreg type"); - if (!normalizeRoundModeToken(getRoundMode())) - return emitOpError("round mode must be one of R/A/F/C/Z/O"); + auto elemType = inputType.getElementType(); + if (!(elemType.isF16() || elemType.isF32() || elemType.isBF16())) + return emitOpError("requires f16/f32/bf16 vector element type"); + auto expectedGranularity = getVdupMaskGranularity(elemType); + if (!expectedGranularity) + return emitOpError("requires element type with supported predicate granularity"); + if (failed(verifyMaskTypeWithGranularityLike(*this, getMask().getType(), + "mask type", + *expectedGranularity))) + return failure(); + auto normalized = normalizeRoundModeToken(getRoundMode()); + if (!normalized || !isSupportedVtrcRoundMode(*normalized)) + return emitOpError("round mode must be one of R/A/F/C/Z"); return success(); } diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 17ddd92df..732db768e 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(PTOTransforms HIVMIntrinsicNaming.cpp VPTOLLVMEmitter.cpp + VPTOLLVMEmitterHelper.cpp PTOVPTOExpandBridgeOps.cpp PTOVPTOPtrBoundary.cpp PTOToVPTO.cpp diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp index 9568be1bf..89d4f33e7 100644 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -4817,8 +4817,10 @@ LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter) { case VPTOCvtLoweringKind::Vtrc: { auto loaded = rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); + Value mask = buildAllPredicateMask(rewriter, op.getLoc(), dstElementType); Value converted = rewriter.create(op.getLoc(), dstVecType, - loaded.getResult(), *roundMode); + loaded.getResult(), mask, + *roundMode); rewriter.create( op.getLoc(), converted, dstBuffer, offset, StringAttr(), buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 55418cfe4..16c6a8bd7 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -1,21 +1,8 @@ -//===- VPTOLLVMEmitter.cpp - VPTO to official LLVM IR text emitter -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - #include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/IR/PTO.h" -#include "PTO/IR/PTO.h" -#include "PTO/Transforms/VPTOLowering.h" -#include "PTO/Transforms/HIVMIntrinsicNaming.h" -#include "PTO/Transforms/Passes.h" +#include "PTO/IR/PTOSyncUtils.h" -#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/Passes.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" @@ -24,111 +11,141 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SymbolTable.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" -#include "llvm/ADT/SmallString.h" +#include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Dominators.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Module.h" -#include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/FileSystem.h" -#include "llvm/Support/Process.h" -#include "llvm/Support/Program.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/raw_ostream.h" -#include +namespace mlir::pto { -using namespace mlir; +void materializeVecScopeCarrierLoops(ModuleOp module); +LogicalResult normalizePtoMemRefSpaces(ModuleOp module, + llvm::raw_ostream &diagOS); +LogicalResult applyQueriedTargetAttrs(ModuleOp module, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS); +LogicalResult attachAIVectorScopeMetadata(llvm::Module &llvmModule, + llvm::raw_ostream &diagOS); +void attachHIVMKernelAnnotations(llvm::Module &llvmModule); -namespace mlir::pto { namespace { -constexpr StringLiteral kAIVScopeDummyCallee = "aivscope_dummy"; +static std::string getElementTypeFragment(Type type); +static Type getElementTypeFromVectorLike(Type type); +static std::optional getElementCountFromVectorLike(Type type); -struct QueriedTargetAttrs { - std::string targetCPU; - std::string targetFeatures; -}; +static Type convertVPTOType(Type type, Builder &builder) { + if (auto vecType = dyn_cast(type)) + return VectorType::get({vecType.getElementCount()}, vecType.getElementType()); + if (isa(type)) + return VectorType::get({256}, builder.getI1Type()); + if (isa(type)) + return VectorType::get({32}, builder.getI8Type()); + if (auto ptrType = dyn_cast(type)) { + return LLVM::LLVMPointerType::get( + builder.getContext(), + static_cast(ptrType.getMemorySpace().getAddressSpace())); + } + return type; +} -struct ABIExpr { - enum class Kind { Constant, FuncArg, Mul }; +static bool hasVPTOConvertibleType(Type type) { + return isa(type); +} - Kind kind = Kind::Constant; - uint64_t constant = 0; - unsigned argIndex = 0; - std::unique_ptr lhs; - std::unique_ptr rhs; +static bool hasVPTOConvertibleType(TypeRange types) { + return llvm::any_of(types, [](Type type) { return hasVPTOConvertibleType(type); }); +} - static ABIExpr constantExpr(uint64_t value) { - ABIExpr expr; - expr.kind = Kind::Constant; - expr.constant = value; - return expr; - } +static Value materializeVPTOCast(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); +} - static ABIExpr argExpr(unsigned argIndex) { - ABIExpr expr; - expr.kind = Kind::FuncArg; - expr.argIndex = argIndex; - return expr; +class VPTOTypeConverter final : public TypeConverter { +public: + explicit VPTOTypeConverter(MLIRContext *context) { + addConversion([](Type type) { return type; }); + addConversion([](Type type) -> Type { + // The conversion callback outlives this constructor, so build on demand + // from the current type context instead of capturing a local Builder. + Builder builder(type.getContext()); + return convertVPTOType(type, builder); + }); + addSourceMaterialization(materializeVPTOCast); + addTargetMaterialization(materializeVPTOCast); + addArgumentMaterialization(materializeVPTOCast); } +}; - static ABIExpr mulExpr(ABIExpr lhs, ABIExpr rhs) { - ABIExpr expr; - expr.kind = Kind::Mul; - expr.lhs = std::make_unique(std::move(lhs)); - expr.rhs = std::make_unique(std::move(rhs)); - return expr; - } +struct PlannedDecl { + std::string name; + FunctionType type; }; -struct ExternalMemRefABISpec { - unsigned addressSpace = 1; - int64_t rank = 0; - ABIExpr offset = ABIExpr::constantExpr(0); - ABIExpr totalSize = ABIExpr::constantExpr(1); - ABIExpr stride = ABIExpr::constantExpr(1); +struct LoweringState { + SmallVector plannedDecls; }; -struct ExternalArgABISpec { - bool isMemRef = false; - ExternalMemRefABISpec memrefSpec; +enum class VcvtElemKind { + Invalid, + F16, + BF16, + F32, + S8, + U8, + S16, + U16, + S32, + U32, + S64, }; -struct FunctionABISpec { - SmallVector args; +struct VcvtContract { + const char *intrinsic; + bool requiresRnd; + bool requiresSat; + bool requiresPart; + unsigned maskBitWidth; }; -static Type getElementTypeFromVectorLike(Type type); -static Type getElementTypeFromPointerLike(Type type); -static std::optional getElementCountFromVectorLike(Type type); -static func::FuncOp getOrCreateExternalFunc(ModuleOp module, StringRef name, - FunctionType type); -static Value castIntegerLikeTo(Operation *anchor, Value value, Type targetType); +static Value getI64Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI64IntegerAttr(value)) + .getResult(); +} + +static Value getI32Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI32IntegerAttr(value)) + .getResult(); +} + +static FailureOr buildLaneTypedCallee(MLIRContext *context, + Type resultType, + StringRef stem, + StringRef suffix) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec + + suffix.str()) + .getValue(); +} static std::string getElementTypeFragment(Type type) { if (type.isF16()) @@ -142,47 +159,70 @@ static std::string getElementTypeFragment(Type type) { return {}; } -static std::optional parseRoundModeImmediate(StringRef roundMode) { - if (roundMode == "R" || roundMode == "ROUND_R") - return 0; // __cce_simd::ROUND::R - if (roundMode == "A" || roundMode == "ROUND_A") - return 1; // __cce_simd::ROUND::A - if (roundMode == "F" || roundMode == "ROUND_F") - return 2; // __cce_simd::ROUND::F - if (roundMode == "C" || roundMode == "ROUND_C") - return 3; // __cce_simd::ROUND::C - if (roundMode == "Z" || roundMode == "ROUND_Z") - return 4; // __cce_simd::ROUND::Z - if (roundMode == "O" || roundMode == "ROUND_O") - return 5; // __cce_simd::ROUND::O - return std::nullopt; +static std::string getVbrScalarFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); + return {}; } -static std::optional parseSaturationImmediate(StringRef sat) { - if (sat == "SAT" || sat == "RS_ENABLE") - return 0; // __cce_simd::RoundingSaturation::ENABLE - if (sat == "NOSAT" || sat == "RS_DISABLE") - return 1; // __cce_simd::RoundingSaturation::DISABLE - return std::nullopt; +static Type getElementTypeFromVectorLike(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + return {}; } -static std::optional parsePartImmediate(StringRef part) { - if (part == "EVEN" || part == "PART_EVEN") - return 0; // __cce_simd::Part::EVEN - if (part == "ODD" || part == "PART_ODD") - return 1; // __cce_simd::Part::ODD +static std::optional getElementCountFromVectorLike(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementCount(); + if (auto vecType = dyn_cast(type)) { + if (vecType.getRank() != 1) + return std::nullopt; + return vecType.getShape().front(); + } return std::nullopt; } +static Value castIntegerLikeTo(Operation *anchor, Value value, Type targetType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + if (value.getType() == targetType) + return value; + + auto targetInt = dyn_cast(targetType); + if (value.getType().isIndex() && targetInt) + return builder.create(anchor->getLoc(), targetType, value); + if (auto sourceInt = dyn_cast(value.getType())) { + if (targetInt) { + if (sourceInt.getWidth() < targetInt.getWidth()) + return builder.create(anchor->getLoc(), targetType, value); + if (sourceInt.getWidth() > targetInt.getWidth()) + return builder.create(anchor->getLoc(), targetType, value); + return value; + } + if (targetType.isIndex()) + return builder.create(anchor->getLoc(), targetType, value); + } + + return {}; +} + static FailureOr normalizeVdupScalarOperand(OpBuilder &builder, Location loc, - pto::VdupOp vdup) { - Value input = vdup.getInput(); - Type scalarType = input.getType(); - auto intType = dyn_cast(scalarType); + pto::VdupOp op) { + Value input = op.getInput(); + auto intType = dyn_cast(input.getType()); if (!intType || intType.getWidth() != 8) return input; - Type resultElemType = getElementTypeFromVectorLike(vdup.getResult().getType()); + Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); std::string resultElemFragment = getElementTypeFragment(resultElemType); if (resultElemFragment != "s8" && resultElemFragment != "u8") return input; @@ -193,201 +233,28 @@ static FailureOr normalizeVdupScalarOperand(OpBuilder &builder, Location return builder.create(loc, i16Type, input).getResult(); } -// VSQZ #st hint must only be set when the compacted vector feeds VSTUR. -// Emitting #st=1 without a matching VSTUR consumer can deadlock hardware queues. -static uint64_t determineVsqzStoreHint(pto::VsqzOp vsqz) { - Value result = vsqz.getResult(); - for (Operation *user : result.getUsers()) { - auto vstur = dyn_cast(user); - if (!vstur) - continue; - if (vstur.getValue() == result) - return 1; - } - return 0; -} - -enum class VcvtElemKind { - Invalid, - F16, - BF16, - F32, - S8, - U8, - S16, - U16, - S32, - U32, - S64, -}; - -struct VcvtContract { - const char *intrinsic; - bool requiresRnd; - bool requiresSat; - bool requiresPart; - unsigned maskBitWidth; -}; - -static VcvtElemKind classifyVcvtElemType(Type type) { - if (type.isF16()) - return VcvtElemKind::F16; - if (type.isBF16()) - return VcvtElemKind::BF16; - if (type.isF32()) - return VcvtElemKind::F32; - if (auto intType = dyn_cast(type)) { +static std::string getCopyElementFragment(Type elementType) { + if (!elementType) + return {}; + if (elementType.isF16()) + return "f16"; + if (elementType.isBF16()) + return "bf16"; + if (elementType.isF32()) + return "f32"; + if (auto intType = dyn_cast(elementType)) { switch (intType.getWidth()) { case 8: - return intType.isUnsigned() ? VcvtElemKind::U8 : VcvtElemKind::S8; + return intType.isUnsigned() ? "u8" : "s8"; case 16: - return intType.isUnsigned() ? VcvtElemKind::U16 : VcvtElemKind::S16; + return intType.isUnsigned() ? "u16" : "s16"; case 32: - return intType.isUnsigned() ? VcvtElemKind::U32 : VcvtElemKind::S32; - case 64: - return intType.isUnsigned() ? VcvtElemKind::Invalid : VcvtElemKind::S64; + return intType.isUnsigned() ? "u32" : "s32"; default: - return VcvtElemKind::Invalid; + return {}; } } - return VcvtElemKind::Invalid; -} - -static std::optional lookupVcvtContract(VcvtElemKind src, - VcvtElemKind dst) { - switch (src) { - case VcvtElemKind::F32: - switch (dst) { - case VcvtElemKind::F16: - return VcvtContract{"llvm.hivm.vcvtff.f322f16.x", true, true, true, 32}; - case VcvtElemKind::BF16: - return VcvtContract{"llvm.hivm.vcvtff.f322bf16.x", true, true, true, 32}; - case VcvtElemKind::S16: - return VcvtContract{"llvm.hivm.vcvtfi.f322s16.x", true, true, true, 32}; - case VcvtElemKind::S32: - return VcvtContract{"llvm.hivm.vcvtfi.f322s32.x", true, true, false, 32}; - case VcvtElemKind::S64: - return VcvtContract{"llvm.hivm.vcvtfi.f322s64.x", true, true, true, 32}; - default: - return std::nullopt; - } - case VcvtElemKind::F16: - switch (dst) { - case VcvtElemKind::F32: - return VcvtContract{"llvm.hivm.vcvtff.f162f32.x", false, false, true, 16}; - case VcvtElemKind::S32: - return VcvtContract{"llvm.hivm.vcvtfi.f162s32.x", true, false, true, 16}; - case VcvtElemKind::S16: - return VcvtContract{"llvm.hivm.vcvtfi.f162s16.x", true, true, false, 16}; - case VcvtElemKind::S8: - return VcvtContract{"llvm.hivm.vcvtfi.f162s8.x", true, true, true, 16}; - case VcvtElemKind::U8: - return VcvtContract{"llvm.hivm.vcvtfi.f162u8.x", true, true, true, 16}; - default: - return std::nullopt; - } - case VcvtElemKind::BF16: - switch (dst) { - case VcvtElemKind::F32: - return VcvtContract{"llvm.hivm.vcvtff.bf162f32.x", false, false, true, 16}; - case VcvtElemKind::S32: - return VcvtContract{"llvm.hivm.vcvtfi.bf162s32.x", true, true, true, 16}; - default: - return std::nullopt; - } - case VcvtElemKind::U8: - switch (dst) { - case VcvtElemKind::F16: - return VcvtContract{"llvm.hivm.vcvtif.u82f16.x", false, false, true, 8}; - case VcvtElemKind::U16: - return VcvtContract{"llvm.hivm.vcvtii.u82u16.x", false, false, true, 8}; - case VcvtElemKind::U32: - return VcvtContract{"llvm.hivm.vcvtii.u82u32.x", false, false, true, 8}; - default: - return std::nullopt; - } - case VcvtElemKind::S8: - switch (dst) { - case VcvtElemKind::F16: - return VcvtContract{"llvm.hivm.vcvtif.s82f16.x", false, false, true, 8}; - case VcvtElemKind::S16: - return VcvtContract{"llvm.hivm.vcvtii.s82s16.x", false, false, true, 8}; - case VcvtElemKind::S32: - return VcvtContract{"llvm.hivm.vcvtii.s82s32.x", false, false, true, 8}; - default: - return std::nullopt; - } - case VcvtElemKind::U16: - switch (dst) { - case VcvtElemKind::U8: - return VcvtContract{"llvm.hivm.vcvtii.u162u8.x", false, true, true, 16}; - case VcvtElemKind::U32: - return VcvtContract{"llvm.hivm.vcvtii.u162u32.x", false, false, true, 16}; - default: - return std::nullopt; - } - case VcvtElemKind::S16: - switch (dst) { - case VcvtElemKind::F16: - return VcvtContract{"llvm.hivm.vcvtif.s162f16.x", true, false, false, 16}; - case VcvtElemKind::F32: - return VcvtContract{"llvm.hivm.vcvtif.s162f32.x", false, false, true, 16}; - case VcvtElemKind::U8: - return VcvtContract{"llvm.hivm.vcvtii.s162u8.x", false, true, true, 16}; - case VcvtElemKind::U32: - return VcvtContract{"llvm.hivm.vcvtii.s162u32.x", false, false, true, 16}; - case VcvtElemKind::S32: - return VcvtContract{"llvm.hivm.vcvtii.s162s32.x", false, false, true, 16}; - default: - return std::nullopt; - } - case VcvtElemKind::U32: - switch (dst) { - case VcvtElemKind::U8: - return VcvtContract{"llvm.hivm.vcvtii.u322u8.x", false, true, true, 32}; - case VcvtElemKind::U16: - return VcvtContract{"llvm.hivm.vcvtii.u322u16.x", false, true, true, 32}; - case VcvtElemKind::S16: - return VcvtContract{"llvm.hivm.vcvtii.u322s16.x", false, true, true, 32}; - default: - return std::nullopt; - } - case VcvtElemKind::S32: - switch (dst) { - case VcvtElemKind::F32: - return VcvtContract{"llvm.hivm.vcvtif.s322f32.x", true, false, false, 32}; - case VcvtElemKind::U8: - return VcvtContract{"llvm.hivm.vcvtii.s322u8.x", false, true, true, 32}; - case VcvtElemKind::U16: - return VcvtContract{"llvm.hivm.vcvtii.s322u16.x", false, true, true, 32}; - case VcvtElemKind::S16: - return VcvtContract{"llvm.hivm.vcvtii.s322s16.x", false, true, true, 32}; - case VcvtElemKind::S64: - return VcvtContract{"llvm.hivm.vcvtii.s322s64.x", false, false, true, 32}; - default: - return std::nullopt; - } - case VcvtElemKind::S64: - switch (dst) { - case VcvtElemKind::F32: - return VcvtContract{"llvm.hivm.vcvtif.s642f32.x", true, false, true, 32}; - case VcvtElemKind::S32: - return VcvtContract{"llvm.hivm.vcvtii.s642s32.x", false, true, true, 32}; - default: - return std::nullopt; - } - case VcvtElemKind::Invalid: - return std::nullopt; - } - return std::nullopt; -} - -static std::optional parseHiLoPartImmediate(StringRef part) { - if (part == "LOWER") - return 0; // __cce_simd::HiloPart::Lower - if (part == "HIGHER") - return 1; // __cce_simd::HiloPart::Higher - return std::nullopt; + return {}; } static std::optional parsePredicatePatternImmediate(StringRef pattern) { @@ -424,173 +291,73 @@ static std::optional parsePredicatePatternImmediate(StringRef pattern) return std::nullopt; } -static Type getSignlessIntegerTypeWithSameWidth(Type type, Builder &builder) { - if (auto intType = dyn_cast(type)) - return builder.getIntegerType(intType.getWidth()); - if (auto floatType = dyn_cast(type)) - return builder.getIntegerType(floatType.getWidth()); - return {}; -} - -static std::string getVbrScalarFragment(Type type) { - if (type.isF16()) - return "f16"; - if (type.isBF16()) - return "bf16"; - if (type.isF32()) - return "f32"; - if (auto intType = dyn_cast(type)) - return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); - return {}; +static std::optional parseHiLoPartImmediate(StringRef part) { + if (part == "LOWER") + return 0; + if (part == "HIGHER") + return 1; + return std::nullopt; } -static std::string getCopyElementFragment(Type elementType) { - if (!elementType) - return {}; - if (elementType.isF16()) - return "f16"; - if (elementType.isBF16()) - return "bf16"; - if (elementType.isF32()) - return "f32"; - if (auto intType = dyn_cast(elementType)) { - switch (intType.getWidth()) { - case 8: - return intType.isUnsigned() ? "u8" : "s8"; - case 16: - return intType.isUnsigned() ? "u16" : "s16"; - case 32: - return intType.isUnsigned() ? "u32" : "s32"; - default: - return {}; - } - } - return {}; +static std::optional parseRoundModeImmediate(StringRef roundMode) { + if (roundMode == "R" || roundMode == "ROUND_R") + return 0; + if (roundMode == "A" || roundMode == "ROUND_A") + return 1; + if (roundMode == "F" || roundMode == "ROUND_F") + return 2; + if (roundMode == "C" || roundMode == "ROUND_C") + return 3; + if (roundMode == "Z" || roundMode == "ROUND_Z") + return 4; + if (roundMode == "O" || roundMode == "ROUND_O") + return 5; + return std::nullopt; } -static std::optional buildABIExprFromValue(Value value); - -static std::optional buildABIExprFromFoldResult(OpFoldResult ofr) { - if (auto attr = ofr.dyn_cast()) { - if (auto intAttr = dyn_cast(attr)) - return ABIExpr::constantExpr(intAttr.getValue().getZExtValue()); - return std::nullopt; - } - return buildABIExprFromValue(ofr.get()); +static std::optional parseSaturationImmediate(StringRef sat) { + if (sat == "SAT" || sat == "RS_ENABLE") + return 0; + if (sat == "NOSAT" || sat == "RS_DISABLE") + return 1; + return std::nullopt; } -static std::optional buildABIExprFromValue(Value value) { - if (auto blockArg = dyn_cast(value)) { - auto func = dyn_cast(blockArg.getOwner()->getParentOp()); - if (!func || blockArg.getOwner() != &func.getBody().front()) - return std::nullopt; - return ABIExpr::argExpr(blockArg.getArgNumber()); - } - - if (auto constIndex = value.getDefiningOp()) - return ABIExpr::constantExpr(constIndex.value()); - if (auto constOp = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(constOp.getValue())) - return ABIExpr::constantExpr(intAttr.getValue().getZExtValue()); - } - if (auto castOp = value.getDefiningOp()) - return buildABIExprFromValue(castOp.getIn()); - if (auto castOp = value.getDefiningOp()) - return buildABIExprFromValue(castOp.getIn()); - if (auto extOp = value.getDefiningOp()) - return buildABIExprFromValue(extOp.getIn()); - if (auto extOp = value.getDefiningOp()) - return buildABIExprFromValue(extOp.getIn()); - if (auto truncOp = value.getDefiningOp()) - return buildABIExprFromValue(truncOp.getIn()); - if (auto mulOp = value.getDefiningOp()) { - auto lhs = buildABIExprFromValue(mulOp.getLhs()); - auto rhs = buildABIExprFromValue(mulOp.getRhs()); - if (!lhs || !rhs) - return std::nullopt; - return ABIExpr::mulExpr(std::move(*lhs), std::move(*rhs)); - } - +static std::optional parsePartImmediate(StringRef part) { + if (part == "EVEN" || part == "PART_EVEN") + return 0; + if (part == "ODD" || part == "PART_ODD") + return 1; return std::nullopt; } -static unsigned getExternalPointerAddressSpace(MemRefType type) { - if (auto addrAttr = dyn_cast_or_null(type.getMemorySpace())) { - switch (addrAttr.getAddressSpace()) { - case pto::AddressSpace::GM: - case pto::AddressSpace::Zero: - return 1; - case pto::AddressSpace::VEC: - return 6; - default: - break; - } - } - return 1; +static std::optional parsePredicateStoreDistImmediate(StringRef dist) { + if (dist == "NORM") + return 0; + if (dist == "PK") + return 1; + return std::nullopt; } -static std::optional deriveMemRefTotalSize(BlockArgument arg, - MemRefType type) { - if (type.getRank() != 1) - return std::nullopt; - - if (!type.isDynamicDim(0)) - return ABIExpr::constantExpr(type.getDimSize(0)); - - for (Operation *user : arg.getUsers()) { - auto reinterpret = dyn_cast(user); - if (!reinterpret || reinterpret.getSource() != arg) - continue; - - std::optional accum; - for (OpFoldResult size : reinterpret.getMixedSizes()) { - auto sizeExpr = buildABIExprFromFoldResult(size); - if (!sizeExpr) - return std::nullopt; - accum = accum ? ABIExpr::mulExpr(std::move(*accum), std::move(*sizeExpr)) - : std::move(*sizeExpr); - } - if (accum) - return accum; - } - +static std::optional parsePredicateLoadDistImmediate(StringRef dist) { + if (dist.empty() || dist == "NORM") + return 0; + if (dist == "US") + return 1; + if (dist == "DS") + return 2; return std::nullopt; } -static llvm::StringMap collectFunctionABISpecs(ModuleOp module) { - llvm::StringMap specs; - module.walk([&](func::FuncOp funcOp) { - if (funcOp.isExternal()) - return; - - FunctionABISpec funcSpec; - funcSpec.args.reserve(funcOp.getNumArguments()); - - for (BlockArgument arg : funcOp.getArguments()) { - ExternalArgABISpec argSpec; - if (auto memrefType = dyn_cast(arg.getType())) { - if (memrefType.getRank() == 1) { - auto totalSize = deriveMemRefTotalSize(arg, memrefType); - if (totalSize) { - argSpec.isMemRef = true; - argSpec.memrefSpec.addressSpace = - getExternalPointerAddressSpace(memrefType); - argSpec.memrefSpec.rank = 1; - argSpec.memrefSpec.offset = ABIExpr::constantExpr(0); - argSpec.memrefSpec.totalSize = std::move(*totalSize); - argSpec.memrefSpec.stride = ABIExpr::constantExpr(1); - } - } - } - funcSpec.args.push_back(std::move(argSpec)); - } - - specs[funcOp.getName().str()] = std::move(funcSpec); - }); - return specs; +static std::optional parsePostModeImmediate(StringRef mode) { + if (mode == "NO_POST_UPDATE") + return 0; + if (mode == "POST_UPDATE") + return 1; + return std::nullopt; } -static std::optional parsePipeImmediate(llvm::StringRef pipe) { +static std::optional parsePipeImmediate(StringRef pipe) { if (pipe == "PIPE_S") return 0; if (pipe == "PIPE_V") @@ -620,7 +387,7 @@ static std::optional parsePipeImmediate(llvm::StringRef pipe) { return std::nullopt; } -static std::optional parseEventImmediate(llvm::StringRef event) { +static std::optional parseEventImmediate(StringRef event) { if (!event.consume_front("EVENT_ID")) return std::nullopt; uint64_t value = 0; @@ -629,7 +396,7 @@ static std::optional parseEventImmediate(llvm::StringRef event) { return value; } -static std::optional parseSprImmediate(llvm::StringRef spr) { +static std::optional parseSprImmediate(StringRef spr) { if (spr == "AR") return 74; return std::nullopt; @@ -647,120 +414,242 @@ static std::optional getDistElementWidth(Type type) { return std::nullopt; } -static std::optional parseLoadDistImmediate(llvm::StringRef dist, - Type elementType) { - auto width = getDistElementWidth(elementType); - if (dist.empty() || dist == "NORM") - return 0; - if (!width) - return std::nullopt; - if (dist == "BRC") - return *width == 8 ? std::optional(1) - : *width == 16 ? std::optional(2) - : *width == 32 ? std::optional(3) - : std::nullopt; - if (dist == "US") - return *width == 8 ? std::optional(6) - : *width == 16 ? std::optional(7) - : std::nullopt; - if (dist == "DS") - return *width == 8 ? std::optional(8) - : *width == 16 ? std::optional(9) - : std::nullopt; - if (dist == "UNPK") - return *width == 8 ? std::optional(13) - : *width == 16 ? std::optional(14) - : *width == 32 ? std::optional(18) - : std::nullopt; - if (dist == "BRC_BLK") - return 15; - if (dist == "E2B") - return *width == 16 ? std::optional(16) - : *width == 32 ? std::optional(17) - : std::nullopt; - if (dist == "UNPK4") - return *width == 8 ? std::optional(20) : std::nullopt; - if (dist == "SPLT4CHN") - return *width == 8 ? std::optional(21) : std::nullopt; - if (dist == "SPLT2CHN") - return *width == 8 ? std::optional(22) - : *width == 16 ? std::optional(23) - : std::nullopt; - return std::nullopt; +static VcvtElemKind classifyVcvtElemType(Type type) { + if (type.isF16()) + return VcvtElemKind::F16; + if (type.isBF16()) + return VcvtElemKind::BF16; + if (type.isF32()) + return VcvtElemKind::F32; + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? VcvtElemKind::U8 : VcvtElemKind::S8; + case 16: + return intType.isUnsigned() ? VcvtElemKind::U16 : VcvtElemKind::S16; + case 32: + return intType.isUnsigned() ? VcvtElemKind::U32 : VcvtElemKind::S32; + case 64: + return intType.isUnsigned() ? VcvtElemKind::Invalid : VcvtElemKind::S64; + default: + return VcvtElemKind::Invalid; + } + } + return VcvtElemKind::Invalid; } -static std::optional parseLoadX2DistImmediate(llvm::StringRef dist, - Type elementType) { - auto width = getDistElementWidth(elementType); - if (dist == "BDINTLV") - return 10; - if (!width) - return std::nullopt; - if (dist == "DINTLV") - return *width == 8 ? std::optional(11) - : *width == 16 ? std::optional(12) - : *width == 32 ? std::optional(19) - : std::nullopt; +static std::optional lookupVcvtContract(VcvtElemKind src, + VcvtElemKind dst) { + switch (src) { + case VcvtElemKind::F32: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtff.f322f16.x", true, true, true, 32}; + case VcvtElemKind::BF16: + return VcvtContract{"llvm.hivm.vcvtff.f322bf16.x", true, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtfi.f322s16.x", true, true, true, 32}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.f322s32.x", true, true, false, 32}; + case VcvtElemKind::S64: + return VcvtContract{"llvm.hivm.vcvtfi.f322s64.x", true, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::F16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtff.f162f32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.f162s32.x", true, false, true, 16}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtfi.f162s16.x", true, true, false, 16}; + case VcvtElemKind::S8: + return VcvtContract{"llvm.hivm.vcvtfi.f162s8.x", true, true, true, 16}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtfi.f162u8.x", true, true, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::BF16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtff.bf162f32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.bf162s32.x", true, true, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::U8: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.u82f16.x", false, false, true, 8}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.u82u16.x", false, false, true, 8}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.u82u32.x", false, false, true, 8}; + default: + return std::nullopt; + } + case VcvtElemKind::S8: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.s82f16.x", false, false, true, 8}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.s82s16.x", false, false, true, 8}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s82s32.x", false, false, true, 8}; + default: + return std::nullopt; + } + case VcvtElemKind::U16: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.u162u8.x", false, true, true, 16}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.u162u32.x", false, false, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::S16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.s162f16.x", true, false, false, 16}; + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s162f32.x", false, false, true, 16}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.s162u8.x", false, true, true, 16}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.s162u32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s162s32.x", false, false, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::U32: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.u322u8.x", false, true, true, 32}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.u322u16.x", false, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.u322s16.x", false, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::S32: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s322f32.x", true, false, false, 32}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.s322u8.x", false, true, true, 32}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.s322u16.x", false, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.s322s16.x", false, true, true, 32}; + case VcvtElemKind::S64: + return VcvtContract{"llvm.hivm.vcvtii.s322s64.x", false, false, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::S64: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s642f32.x", true, false, true, 32}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s642s32.x", false, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::Invalid: + return std::nullopt; + } return std::nullopt; } -static std::optional parsePredicateLoadDistImmediate(llvm::StringRef dist) { - if (dist.empty() || dist == "NORM") - return 0; // Dist::DIST_NORM - if (dist == "US") - return 1; // Dist::DIST_US - if (dist == "DS") - return 2; // Dist::DIST_DS - return std::nullopt; +// VSQZ #st hint must only be set when the compacted vector feeds VSTUR. +// Emitting #st=1 without a matching VSTUR consumer can deadlock hardware queues. +static uint64_t determineVsqzStoreHint(pto::VsqzOp vsqz) { + Value result = vsqz.getResult(); + for (Operation *user : result.getUsers()) { + auto vstur = dyn_cast(user); + if (!vstur) + continue; + if (vstur.getValue() == result) + return 1; + } + return 0; } -static std::optional parsePredicateStoreDistImmediate(llvm::StringRef dist) { +static std::optional parseLoadDistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); if (dist.empty() || dist == "NORM") - return 0; // Dist::DIST_NORM - if (dist == "PK") - return 1; // Dist::DIST_PK + return 0; + if (!width) + return std::nullopt; + if (dist == "BRC") + return *width == 8 ? std::optional(1) + : *width == 16 ? std::optional(2) + : *width == 32 ? std::optional(3) + : std::nullopt; + if (dist == "US") + return *width == 8 ? std::optional(6) + : *width == 16 ? std::optional(7) + : std::nullopt; + if (dist == "DS") + return *width == 8 ? std::optional(8) + : *width == 16 ? std::optional(9) + : std::nullopt; + if (dist == "UNPK") + return *width == 8 ? std::optional(13) + : *width == 16 ? std::optional(14) + : *width == 32 ? std::optional(18) + : std::nullopt; + if (dist == "BRC_BLK") + return 15; + if (dist == "E2B") + return *width == 16 ? std::optional(16) + : *width == 32 ? std::optional(17) + : std::nullopt; + if (dist == "UNPK4") + return *width == 8 ? std::optional(20) : std::nullopt; + if (dist == "SPLT4CHN") + return *width == 8 ? std::optional(21) : std::nullopt; + if (dist == "SPLT2CHN") + return *width == 8 ? std::optional(22) + : *width == 16 ? std::optional(23) + : std::nullopt; return std::nullopt; } -static Value packBlockRepeatStride(Operation *anchor, Value blockStride, - Value repeatStride) { - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); - - Value blockI32 = castIntegerLikeTo(anchor, blockStride, builder.getI32Type()); - Value repeatI32 = - castIntegerLikeTo(anchor, repeatStride, builder.getI32Type()); - if (!blockI32 || !repeatI32) - return {}; - - auto c16 = builder.create(anchor->getLoc(), 16, 32); - auto blockShifted = - builder.create(anchor->getLoc(), blockI32, c16); - return builder - .create(anchor->getLoc(), blockShifted, repeatI32) - .getResult(); -} - -static std::optional parseOrderImmediate(llvm::StringRef order) { - if (order.empty() || order == "ASC") - return 0; // INC_ORDER - if (order == "DESC") - return 1; // DEC_ORDER +static std::optional parseLoadX2DistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist == "BDINTLV") + return 10; + if (!width) + return std::nullopt; + if (dist == "DINTLV") + return *width == 8 ? std::optional(11) + : *width == 16 ? std::optional(12) + : *width == 32 ? std::optional(19) + : std::nullopt; return std::nullopt; } -static std::optional parseStoreDistImmediate(llvm::StringRef dist, +static std::optional parseStoreDistImmediate(StringRef dist, Type elementType) { auto width = getDistElementWidth(elementType); if (dist.empty() || dist == "NORM") { if (!width) return std::nullopt; if (*width == 8) - return 0; // norm_b8 + return 0; if (*width == 16) - return 1; // norm_b16 + return 1; if (*width == 32) - return 2; // norm_b32 + return 2; return std::nullopt; } if (!width) @@ -786,7 +675,7 @@ static std::optional parseStoreDistImmediate(llvm::StringRef dist, return std::nullopt; } -static std::optional parseStoreX2DistImmediate(llvm::StringRef dist, +static std::optional parseStoreX2DistImmediate(StringRef dist, Type elementType) { auto width = getDistElementWidth(elementType); if (!width) @@ -799,2926 +688,4126 @@ static std::optional parseStoreX2DistImmediate(llvm::StringRef dist, return std::nullopt; } -static std::optional parsePostModeImmediate(StringRef mode) { - if (mode == "NO_POST_UPDATE") - return 0; - if (mode == "POST_UPDATE") - return 1; - return std::nullopt; -} +static Value packBlockRepeatStride(Operation *anchor, Value blockStride, + Value repeatStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); -static Type convertVPTOType(Type type, Builder &builder) { - if (auto vecType = dyn_cast(type)) - return VectorType::get({vecType.getElementCount()}, vecType.getElementType()); - if (isa(type)) - return VectorType::get({256}, builder.getI1Type()); - if (isa(type)) - return VectorType::get({32}, builder.getIntegerType(8)); - if (auto ptrType = dyn_cast(type)) { - return LLVM::LLVMPointerType::get( - builder.getContext(), - static_cast(ptrType.getMemorySpace().getAddressSpace())); - } - return type; -} + Value blockI32 = castIntegerLikeTo(anchor, blockStride, builder.getI32Type()); + Value repeatI32 = + castIntegerLikeTo(anchor, repeatStride, builder.getI32Type()); + if (!blockI32 || !repeatI32) + return {}; -static bool hasPtoPtrType(TypeRange types) { - return llvm::any_of(types, [](Type type) { return isa(type); }); + auto c16 = builder.create(anchor->getLoc(), 16, 32); + auto blockShifted = + builder.create(anchor->getLoc(), blockI32, c16); + return builder + .create(anchor->getLoc(), blockShifted, repeatI32) + .getResult(); } -static bool hasPtoAlignType(Type type) { - if (isa(type)) - return true; - if (auto functionType = dyn_cast(type)) - return llvm::any_of(functionType.getInputs(), hasPtoAlignType) || - llvm::any_of(functionType.getResults(), hasPtoAlignType); - return false; +static std::optional parseOrderImmediate(StringRef order) { + if (order.empty() || order == "ASC") + return 0; + if (order == "DESC") + return 1; + return std::nullopt; } -static bool hasPtoAlignType(TypeRange types) { - return llvm::any_of(types, [](Type type) { return hasPtoAlignType(type); }); -} +static FailureOr packLoopPair(Operation *anchor, Value low, Value high) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); -static bool hasPtoMemRefMemorySpace(Type type) { - if (auto memRefType = dyn_cast(type)) - return isa(memRefType.getMemorySpace()); - if (auto functionType = dyn_cast(type)) - return llvm::any_of(functionType.getInputs(), hasPtoMemRefMemorySpace) || - llvm::any_of(functionType.getResults(), hasPtoMemRefMemorySpace); - return false; -} + Value lowI64 = castIntegerLikeTo(anchor, low, builder.getI64Type()); + Value highI64 = castIntegerLikeTo(anchor, high, builder.getI64Type()); + if (!lowI64 || !highI64) + return failure(); -static bool hasPtoMemRefMemorySpace(TypeRange types) { - return llvm::any_of(types, [](Type type) { - return hasPtoMemRefMemorySpace(type); - }); + Value shift = getI64Constant(builder, anchor->getLoc(), 40); + Value highShifted = + builder.create(anchor->getLoc(), highI64, shift).getResult(); + return builder.create(anchor->getLoc(), highShifted, lowI64) + .getResult(); } -struct ConvertPtoMemRefSpaceCarrierOp final : ConversionPattern { - ConvertPtoMemRefSpaceCarrierOp(TypeConverter &typeConverter, - MLIRContext *context) - : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} +static FailureOr packLoopSize(Operation *anchor, Value loop2, Value loop1) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && - !hasPtoMemRefMemorySpace(op->getResultTypes())) - return failure(); - if (op->getNumRegions() != 0) - return rewriter.notifyMatchFailure( - op, "region ops with PTO memref spaces are handled structurally"); + Value loop2I64 = castIntegerLikeTo(anchor, loop2, builder.getI64Type()); + Value loop1I64 = castIntegerLikeTo(anchor, loop1, builder.getI64Type()); + if (!loop2I64 || !loop1I64) + return failure(); - FailureOr converted = - convertOpResultTypes(op, operands, *typeConverter, rewriter); - if (failed(converted)) - return failure(); - return success(); - } -}; + Value shift = getI64Constant(builder, anchor->getLoc(), 21); + Value loop2Shifted = + builder.create(anchor->getLoc(), loop2I64, shift).getResult(); + return builder.create(anchor->getLoc(), loop2Shifted, loop1I64) + .getResult(); +} -struct ConvertMemRefReinterpretCastSpaceOp final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +static FailureOr +packCopyGmToUbConfig0(Operation *anchor, ValueRange operands) { + if (operands.size() != 11) + return failure(); - LogicalResult - matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type convertedResultType = getTypeConverter()->convertType(op.getType()); - auto memRefResultType = dyn_cast_or_null(convertedResultType); - if (!memRefResultType) - return rewriter.notifyMatchFailure(op, "expected memref result type"); - - rewriter.replaceOpWithNewOp( - op, memRefResultType, adaptor.getSource(), adaptor.getOffsets(), - adaptor.getSizes(), adaptor.getStrides(), op.getStaticOffsets(), - op.getStaticSizes(), op.getStaticStrides()); - return success(); - } -}; + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -struct ConvertMemRefSubViewSpaceOp final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; - LogicalResult - matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type convertedResultType = getTypeConverter()->convertType(op.getType()); - auto memRefResultType = dyn_cast_or_null(convertedResultType); - if (!memRefResultType) - return rewriter.notifyMatchFailure(op, "expected memref result type"); - - rewriter.replaceOpWithNewOp( - op, memRefResultType, adaptor.getSource(), op.getMixedOffsets(), - op.getMixedSizes(), op.getMixedStrides()); - return success(); - } -}; + Value sid = getI64Operand(2); + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value leftPadding = getI64Operand(5); + Value rightPadding = getI64Operand(6); + Value dataSelect = castIntegerLikeTo(anchor, operands[7], builder.getI64Type()); + Value cacheCtl = getI64Operand(8); + if (!sid || !nBurst || !lenBurst || !leftPadding || !rightPadding || + !dataSelect || !cacheCtl) + return failure(); -struct ConvertMemRefSpaceUnrealizedCastOp final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; - LogicalResult - matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->getNumOperands() != 1 || op->getNumResults() != 1) - return failure(); - if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && - !hasPtoMemRefMemorySpace(op->getResultTypes())) - return failure(); + Value config = sid; + config = bitOr(config, shl(nBurst, 4)); + config = bitOr(config, shl(lenBurst, 25)); + config = bitOr(config, shl(leftPadding, 46)); + config = bitOr(config, shl(rightPadding, 52)); + config = bitOr(config, shl(dataSelect, 58)); + config = bitOr(config, shl(cacheCtl, 60)); + return config; +} - Type convertedResultType = getTypeConverter()->convertType(op.getResult(0).getType()); - if (!convertedResultType) - return failure(); +static FailureOr +packCopyGmToUbConfig1(Operation *anchor, ValueRange operands) { + if (operands.size() != 11) + return failure(); + return packLoopPair(anchor, operands[9], operands[10]); +} - Value input = adaptor.getOperands().front(); - if (input.getType() == convertedResultType) { - rewriter.replaceOp(op, input); - return success(); - } +static FailureOr +packCopyUbToGmConfig0(Operation *anchor, ValueRange operands) { + if (operands.size() != 8) return failure(); - } -}; -static LogicalResult normalizePtoMemRefSpaces(ModuleOp module, - llvm::raw_ostream &diagOS) { - MLIRContext *context = module.getContext(); - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addConversion([&](MemRefType type) -> Type { - auto addrSpace = dyn_cast_or_null(type.getMemorySpace()); - if (!addrSpace) - return type; - return MemRefType::get( - type.getShape(), type.getElementType(), type.getLayout(), - IntegerAttr::get(IntegerType::get(context, 64), - static_cast(addrSpace.getAddressSpace()))); - }); - typeConverter.addTypeAttributeConversion( - [](MemRefType, pto::AddressSpaceAttr attr) -> Attribute { - return IntegerAttr::get(IntegerType::get(attr.getContext(), 64), - static_cast(attr.getAddressSpace())); - }); - auto materializeMemRefCast = [](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1) - return {}; - return builder - .create(loc, TypeRange{resultType}, inputs) - .getResult(0); + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); }; - typeConverter.addSourceMaterialization(materializeMemRefCast); - typeConverter.addTargetMaterialization(materializeMemRefCast); - typeConverter.addArgumentMaterialization(materializeMemRefCast); - ConversionTarget target(*context); - target.addLegalOp(); - target.addDynamicallyLegalOp( - [&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - target.addDynamicallyLegalOp( - [&](func::CallOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](Operation *op) { - return isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter); - }); - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + Value sid = getI64Operand(2); + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value reserved = getI64Operand(5); + if (!sid || !nBurst || !lenBurst || !reserved) + return failure(); - RewritePatternSet patterns(context); - scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, - target); - populateFunctionOpInterfaceTypeConversionPattern(patterns, - typeConverter); - populateCallOpTypeConversionPattern(patterns, typeConverter); - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); - populateReturnOpTypeConversionPattern(patterns, typeConverter); - patterns.add( - typeConverter, context); - patterns.add(typeConverter, context); + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - diagOS << "VPTO LLVM emission failed: memref address-space normalization " - "failed\n"; + Value config = sid; + config = bitOr(config, shl(nBurst, 4)); + config = bitOr(config, shl(lenBurst, 25)); + config = bitOr(config, shl(reserved, 60)); + return config; +} + +static FailureOr +packCopyUbToGmConfig1(Operation *anchor, ValueRange operands) { + if (operands.size() != 8) return failure(); - } + return packLoopPair(anchor, operands[6], operands[7]); +} - SmallVector castsToFold; - module.walk([&](UnrealizedConversionCastOp castOp) { - if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) - return; - if (!hasPtoMemRefMemorySpace(castOp->getOperandTypes()) && - !hasPtoMemRefMemorySpace(castOp->getResultTypes())) - return; - Type convertedResultType = typeConverter.convertType(castOp.getResult(0).getType()); - if (convertedResultType && convertedResultType == castOp.getOperand(0).getType()) - castsToFold.push_back(castOp); - }); - for (UnrealizedConversionCastOp castOp : castsToFold) { - castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); - castOp.erase(); - } +static FailureOr packVbitsortConfig(Operation *anchor, Value repeatTimes) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); - WalkResult leftover = module.walk([&](Operation *op) { - if (hasPtoMemRefMemorySpace(op->getOperandTypes()) || - hasPtoMemRefMemorySpace(op->getResultTypes())) { - diagOS << "VPTO LLVM emission failed: residual PTO memref address space on op " - << op->getName().getStringRef() << "\n"; - op->print(diagOS); - diagOS << "\n"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (leftover.wasInterrupted()) + Value repeatI64 = castIntegerLikeTo(anchor, repeatTimes, builder.getI64Type()); + if (!repeatI64) return failure(); - return success(); + return builder + .create(loc, repeatI64, getI64Constant(builder, loc, 56)) + .getResult(); } -struct ConvertPtoAddPtrOp final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +static FailureOr convertElementOffsetToBytes(Operation *anchor, Value offset, + Type elementType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); - LogicalResult - matchAndRewrite(pto::AddPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto convertedResultType = - getTypeConverter()->convertType(op.getResult().getType()); - auto llvmPtrType = dyn_cast_or_null(convertedResultType); - if (!llvmPtrType) - return rewriter.notifyMatchFailure(op, "expected LLVM pointer result type"); + Value offsetI32 = castIntegerLikeTo(anchor, offset, builder.getI32Type()); + if (!offsetI32) + return failure(); - Value offset = adaptor.getOffset(); - if (offset.getType().isIndex()) - offset = rewriter.create(op.getLoc(), - rewriter.getI64Type(), offset); + unsigned bitWidth = 0; + if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + else if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + if (bitWidth == 0 || bitWidth % 8 != 0) + return failure(); - auto gep = rewriter.create( - op.getLoc(), llvmPtrType, cast(op.getPtr().getType()).getElementType(), - adaptor.getPtr(), ValueRange{offset}); - rewriter.replaceOp(op, gep.getResult()); - return success(); + Value scale = builder.create( + anchor->getLoc(), builder.getI32IntegerAttr(bitWidth / 8)); + return builder.create(anchor->getLoc(), offsetI32, scale) + .getResult(); +} + +static FailureOr materializeDynamicPltMask(ConversionPatternRewriter &rewriter, + LoweringState &state, + Location loc, + Value laneCount, + Type vectorElemType) { + Type i32Type = rewriter.getI32Type(); + Value laneCountI32 = laneCount; + if (laneCountI32.getType() != i32Type) { + laneCountI32 = castIntegerLikeTo(rewriter.getInsertionBlock()->getParentOp(), + laneCountI32, i32Type); + if (!laneCountI32) + return failure(); } -}; -struct ConvertPtoCastPtrOp final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + StringRef calleeName; + if (vectorElemType.isF32()) { + calleeName = StringRef("llvm.hivm.plt.b32.v300"); + } else if (vectorElemType.isF16() || vectorElemType.isBF16()) { + calleeName = StringRef("llvm.hivm.plt.b16.v300"); + } else if (auto intType = dyn_cast(vectorElemType)) { + if (intType.getWidth() == 32) + calleeName = StringRef("llvm.hivm.plt.b32.v300"); + else if (intType.getWidth() == 16) + calleeName = StringRef("llvm.hivm.plt.b16.v300"); + else if (intType.getWidth() == 8) + calleeName = StringRef("llvm.hivm.plt.b8.v300"); + } + if (calleeName.empty()) + return failure(); - LogicalResult - matchAndRewrite(pto::CastPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type convertedResultType = - getTypeConverter()->convertType(op.getResult().getType()); - if (!convertedResultType) - return rewriter.notifyMatchFailure(op, "could not convert castptr result type"); + Type maskType = VectorType::get({256}, rewriter.getI1Type()); + auto funcType = + rewriter.getFunctionType(TypeRange{i32Type}, TypeRange{maskType, i32Type}); + auto call = rewriter.create(loc, calleeName, funcType.getResults(), + ValueRange{laneCountI32}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + return call.getResult(0); +} - Value input = adaptor.getInput(); - Type inputType = input.getType(); - if (inputType == convertedResultType) { - rewriter.replaceOp(op, input); - return success(); - } - - if (auto llvmPtrType = dyn_cast(convertedResultType)) { - if (isa(inputType)) { - auto intToPtr = - rewriter.create(op.getLoc(), llvmPtrType, input); - rewriter.replaceOp(op, intToPtr.getResult()); - return success(); - } - auto sourcePtrType = dyn_cast(inputType); - if (!sourcePtrType) - return rewriter.notifyMatchFailure(op, "expected integer or LLVM pointer input"); - if (sourcePtrType.getAddressSpace() == llvmPtrType.getAddressSpace()) { - auto bitcast = - rewriter.create(op.getLoc(), llvmPtrType, input); - rewriter.replaceOp(op, bitcast.getResult()); - return success(); - } - return rewriter.notifyMatchFailure(op, "cross-address-space ptr casts are unsupported"); - } - - if (auto resultIntType = dyn_cast(convertedResultType)) { - if (auto inputPtrType = dyn_cast(inputType)) { - rewriter.replaceOpWithNewOp(op, resultIntType, input); - return success(); - } - if (auto inputIntType = dyn_cast(inputType)) { - unsigned srcWidth = inputIntType.getWidth(); - unsigned dstWidth = resultIntType.getWidth(); - if (srcWidth == dstWidth) { - rewriter.replaceOp(op, input); - return success(); - } - if (srcWidth < dstWidth) { - rewriter.replaceOpWithNewOp(op, resultIntType, input); - return success(); - } - rewriter.replaceOpWithNewOp(op, resultIntType, input); - return success(); - } - } +static FailureOr buildCarryBinaryCallee(MLIRContext *context, + Type resultType, + StringRef stem) { + std::string vec = + getElementTypeFragment(cast(resultType).getElementType()); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec) + .getValue(); +} - return rewriter.notifyMatchFailure(op, "unsupported castptr conversion"); - } -}; +template +static StringRef getUnaryMaskedStem() { + if constexpr (std::is_same_v) + return "vabs"; + if constexpr (std::is_same_v) + return "vexp"; + if constexpr (std::is_same_v) + return "vln"; + if constexpr (std::is_same_v) + return "vneg"; + if constexpr (std::is_same_v) + return "vsqrt"; + if constexpr (std::is_same_v) + return "vrelu"; + if constexpr (std::is_same_v) + return "vnot"; + return {}; +} -struct ConvertPtoLoadScalarOp final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +template +static StringRef getBinaryMaskedStem() { + if constexpr (std::is_same_v) + return "vadd"; + if constexpr (std::is_same_v) + return "vsub"; + if constexpr (std::is_same_v) + return "vmul"; + if constexpr (std::is_same_v) + return "vdiv"; + if constexpr (std::is_same_v) + return "vmax"; + if constexpr (std::is_same_v) + return "vmin"; + if constexpr (std::is_same_v) + return "vand"; + if constexpr (std::is_same_v) + return "vor"; + if constexpr (std::is_same_v) + return "vxor"; + if constexpr (std::is_same_v) + return "vshl"; + if constexpr (std::is_same_v) + return "vshr"; + return {}; +} - LogicalResult - matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); - if (!llvmPtrType) - return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); +template +static StringRef getCarryBinaryStem() { + if constexpr (std::is_same_v) + return "vaddc"; + if constexpr (std::is_same_v) + return "vsubc"; + if constexpr (std::is_same_v) + return "vaddcs"; + if constexpr (std::is_same_v) + return "vsubcs"; + return {}; +} - Value offset = adaptor.getOffset(); - if (offset.getType().isIndex()) - offset = rewriter.create(op.getLoc(), - rewriter.getI64Type(), offset); +template +static constexpr bool hasCarryInput() { + return std::is_same_v || + std::is_same_v; +} - Value elemPtr = adaptor.getPtr(); - if (!matchPattern(offset, m_Zero())) { - elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - op.getValue().getType(), adaptor.getPtr(), - ValueRange{offset}); - } +static FailureOr buildVselCallee(MLIRContext *context, + Type resultType) { + std::string vec = + getElementTypeFragment(cast(resultType).getElementType()); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vsel.v" + std::to_string(*lanes) + + vec) + .getValue(); +} - auto getNaturalAlignment = [&](Type type) -> unsigned { - unsigned alignBytes = 0; - if (auto intType = dyn_cast(type)) { - alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); - } else if (type.isF16() || type.isBF16()) { - alignBytes = 2; - } else if (type.isF32()) { - alignBytes = 4; - } else if (type.isF64()) { - alignBytes = 8; - } - return alignBytes; - }; +static FailureOr buildVselrCallee(MLIRContext *context, + Type resultType) { + Type elemType = getElementTypeFromVectorLike(resultType); + auto lanes = getElementCountFromVectorLike(resultType); + if (!elemType || !lanes) + return failure(); - rewriter.replaceOpWithNewOp( - op, op.getValue().getType(), elemPtr, - getNaturalAlignment(op.getValue().getType())); - return success(); - } -}; + std::string vec = getElementTypeFragment(elemType); + if (auto floatType = dyn_cast(elemType); + floatType && floatType.isF32()) + vec = "u32"; + if (vec.empty()) + return failure(); -struct ConvertPtoStoreScalarOp final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + return StringAttr::get(context, "llvm.hivm.vselr.v" + std::to_string(*lanes) + + vec) + .getValue(); +} - LogicalResult - matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); - if (!llvmPtrType) - return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); +static FailureOr buildVdupCallee(MLIRContext *context, pto::VdupOp op) { + Type inputType = op.getInput().getType(); + Type resultType = op.getResult().getType(); + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); - Value offset = adaptor.getOffset(); - if (offset.getType().isIndex()) - offset = rewriter.create(op.getLoc(), - rewriter.getI64Type(), offset); + if (isa(inputType)) { + StringRef position = op.getPosition().value_or("LOWEST"); + StringRef family = position == "HIGHEST" ? "vdupm" : "vdup"; + return StringAttr::get(context, "llvm.hivm." + family.str() + ".v" + + std::to_string(*lanes) + vec + ".z") + .getValue(); + } - Value elemPtr = adaptor.getPtr(); - if (!matchPattern(offset, m_Zero())) { - elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - adaptor.getValue().getType(), - adaptor.getPtr(), ValueRange{offset}); - } + return StringAttr::get(context, "llvm.hivm.vdups.v" + std::to_string(*lanes) + + vec + ".z") + .getValue(); +} - auto getNaturalAlignment = [&](Type type) -> unsigned { - unsigned alignBytes = 0; - if (auto intType = dyn_cast(type)) { - alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); - } else if (type.isF16() || type.isBF16()) { - alignBytes = 2; - } else if (type.isF32()) { - alignBytes = 4; - } else if (type.isF64()) { - alignBytes = 8; - } - return alignBytes; - }; +static FailureOr buildVbrCallee(MLIRContext *context, Type scalarType) { + std::string scalar = getVbrScalarFragment(scalarType); + if (scalar.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.vbr." + scalar + ".v300").getValue(); +} - rewriter.replaceOpWithNewOp( - op, adaptor.getValue(), elemPtr, - getNaturalAlignment(adaptor.getValue().getType())); - return success(); +static FailureOr buildPstuCallee(MLIRContext *context, pto::PstuOp op) { + if (auto maskType = dyn_cast(op.getValue().getType())) { + if (maskType.isB16()) + return StringAttr::get(context, "llvm.hivm.pstu.b16").getValue(); + if (maskType.isB32()) + return StringAttr::get(context, "llvm.hivm.pstu.b32").getValue(); } -}; + return failure(); +} -struct ConvertPtoUnrealizedCastOp final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +static StringRef buildVstusCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstus").getValue(); +} - LogicalResult - matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->getNumOperands() != 1 || op->getNumResults() != 1) - return rewriter.notifyMatchFailure(op, "only 1:1 casts are supported"); +static StringRef buildVsturCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstur").getValue(); +} - Type convertedResultType = - getTypeConverter()->convertType(op.getResult(0).getType()); - if (!convertedResultType) - return rewriter.notifyMatchFailure(op, "could not convert cast result type"); +static StringRef buildInitAlignCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.init.vector.align.data").getValue(); +} - Value input = adaptor.getOperands().front(); - if (auto llvmPtrType = dyn_cast(convertedResultType)) { - if (input.getType().isInteger(64)) { - rewriter.replaceOpWithNewOp(op, llvmPtrType, input); - return success(); - } - } - if (input.getType() == convertedResultType) { - rewriter.replaceOp(op, input); - return success(); - } +static StringRef buildSprclrCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.sprclr").getValue(); +} - auto cast = rewriter.create( - op.getLoc(), TypeRange{convertedResultType}, input); - rewriter.replaceOp(op, cast.getResults()); - return success(); - } -}; +static StringRef buildVstarCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstar").getValue(); +} -struct ConvertPtoPtrCarrierOp final : ConversionPattern { - ConvertPtoPtrCarrierOp(TypeConverter &typeConverter, MLIRContext *context) - : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} +static StringRef buildVstasCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstas").getValue(); +} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (isa(op)) - return failure(); - if (!hasPtoPtrType(op->getOperandTypes()) && !hasPtoPtrType(op->getResultTypes())) - return failure(); - if (op->getNumRegions() != 0) - return rewriter.notifyMatchFailure(op, "region ops with pto.ptr are unsupported"); +static FailureOr buildVldsPostCallee(MLIRContext *context, + Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldsx1.post.v" + + std::to_string(*lanes) + vec) + .getValue(); +} - SmallVector convertedResultTypes; - if (failed(typeConverter->convertTypes(op->getResultTypes(), convertedResultTypes))) - return failure(); +static FailureOr buildVstsPostCallee(MLIRContext *context, + Type valueType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); + auto lanes = getElementCountFromVectorLike(valueType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vstsx1.post.v" + + std::to_string(*lanes) + vec) + .getValue(); +} - OperationState state(op->getLoc(), op->getName().getStringRef()); - state.addOperands(operands); - state.addTypes(convertedResultTypes); - state.addAttributes(op->getAttrs()); - state.addSuccessors(op->getSuccessors()); +static StringRef buildVldasCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vldas").getValue(); +} - Operation *newOp = rewriter.create(state); - rewriter.replaceOp(op, newOp->getResults()); - return success(); - } -}; +static FailureOr buildVldusCallee(MLIRContext *context, + Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldus.v" + + std::to_string(*lanes) + vec) + .getValue(); +} -struct ConvertPtoAlignUnrealizedCastOp final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +static FailureOr buildVcmpCallee(MLIRContext *context, Type inputType, + StringRef cmpMode, + bool isScalarCompare) { + std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + if (elem.empty()) + return failure(); + StringRef stem = isScalarCompare ? "vcmps" : "vcmp"; + return StringAttr::get(context, "llvm.hivm." + stem.str() + "." + + cmpMode.str() + "." + elem + ".z") + .getValue(); +} - LogicalResult - matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->getNumOperands() != 1 || op->getNumResults() != 1) - return failure(); - if (!hasPtoAlignType(op->getOperandTypes()) && - !hasPtoAlignType(op->getResultTypes())) - return failure(); +template +static StringRef getVecScalarMaskedStem() { + if constexpr (std::is_same_v) + return "vmuls"; + if constexpr (std::is_same_v) + return "vadds"; + if constexpr (std::is_same_v) + return "vmaxs"; + if constexpr (std::is_same_v) + return "vmins"; + if constexpr (std::is_same_v) + return "vlrelu"; + if constexpr (std::is_same_v) + return "vshls"; + if constexpr (std::is_same_v) + return "vshrs"; + return {}; +} - Type convertedResultType = - getTypeConverter()->convertType(op.getResult(0).getType()); - if (!convertedResultType) - return failure(); +template +static StringRef getReductionUnaryStem() { + if constexpr (std::is_same_v) + return "vcadd"; + if constexpr (std::is_same_v) + return "vcmax"; + if constexpr (std::is_same_v) + return "vcmin"; + if constexpr (std::is_same_v) + return "vcgadd"; + if constexpr (std::is_same_v) + return "vcgmax"; + if constexpr (std::is_same_v) + return "vcgmin"; + if constexpr (std::is_same_v) + return "vcpadd"; + return {}; +} - Value input = adaptor.getOperands().front(); - if (input.getType() == convertedResultType) { - rewriter.replaceOp(op, input); - return success(); - } +static FailureOr buildCopyGmToUbCallee(MLIRContext *context, + pto::CopyGmToUbufOp op) { + Type elementType = cast(op.getSource().getType()).getElementType(); + std::string elem = getCopyElementFragment(elementType); + if (elem.empty()) return failure(); - } -}; - -struct ConvertPtoAlignCarrierOp final : ConversionPattern { - ConvertPtoAlignCarrierOp(TypeConverter &typeConverter, MLIRContext *context) - : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2." + elem + + ".DV") + .getValue(); +} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (isa(op)) - return failure(); - if (!hasPtoAlignType(op->getOperandTypes()) && - !hasPtoAlignType(op->getResultTypes())) - return failure(); - if (op->getNumRegions() != 0) - return rewriter.notifyMatchFailure(op, - "region ops with pto.align are handled structurally"); +static StringRef buildCopyUbToGmCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV") + .getValue(); +} - SmallVector convertedResultTypes; - if (failed(typeConverter->convertTypes(op->getResultTypes(), - convertedResultTypes))) - return failure(); +static StringRef buildPstiCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psti.b8").getValue(); +} - OperationState state(op->getLoc(), op->getName().getStringRef()); - state.addOperands(operands); - state.addTypes(convertedResultTypes); - state.addAttributes(op->getAttrs()); - state.addSuccessors(op->getSuccessors()); +static StringRef buildPstsCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psts.b8").getValue(); +} - Operation *newOp = rewriter.create(state); - rewriter.replaceOp(op, newOp->getResults()); - return success(); - } -}; +static StringRef buildPldiCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pldi.b8").getValue(); +} -static LogicalResult normalizePtoPtrsToLLVM(ModuleOp module, llvm::raw_ostream &diagOS) { - MLIRContext *context = module.getContext(); +static StringRef buildPldsCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plds.b8").getValue(); +} - for (func::FuncOp funcOp : module.getOps()) { - if (funcOp.isExternal()) - continue; - } +static StringRef buildPnotCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pnot.z").getValue(); +} - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addConversion([&](pto::PtrType type) -> Type { - return LLVM::LLVMPointerType::get( - context, static_cast(type.getMemorySpace().getAddressSpace())); - }); - auto materializePtrCast = [](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1) - return {}; - return builder - .create(loc, TypeRange{resultType}, inputs) - .getResult(0); - }; - typeConverter.addSourceMaterialization(materializePtrCast); - typeConverter.addTargetMaterialization(materializePtrCast); - typeConverter.addArgumentMaterialization(materializePtrCast); +static StringRef buildPselCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psel").getValue(); +} - ConversionTarget target(*context); - target.addLegalOp(); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - target.addDynamicallyLegalOp( - [&](func::CallOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](Operation *op) { - return isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter); - }); - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); - target.addIllegalOp(); - target.addDynamicallyLegalOp([](UnrealizedConversionCastOp op) { - return !hasPtoPtrType(op->getOperandTypes()) && !hasPtoPtrType(op->getResultTypes()); - }); +static StringRef buildPandCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pand.z").getValue(); +} - RewritePatternSet patterns(context); - scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, - target); - populateFunctionOpInterfaceTypeConversionPattern(patterns, - typeConverter); - populateCallOpTypeConversionPattern(patterns, typeConverter); - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); - populateReturnOpTypeConversionPattern(patterns, typeConverter); - patterns.add( - typeConverter, context); - patterns.add(typeConverter, context); +static StringRef buildPorCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.por.z").getValue(); +} - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - diagOS << "VPTO LLVM emission failed: pto.ptr normalization failed\n"; - return failure(); - } +static StringRef buildPxorCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pxor.z").getValue(); +} - SmallVector castsToFold; - module.walk([&](UnrealizedConversionCastOp castOp) { - if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) - return; - if (!hasPtoPtrType(castOp->getOperandTypes()) && - !hasPtoPtrType(castOp->getResultTypes())) - return; - Type convertedResultType = typeConverter.convertType(castOp.getResult(0).getType()); - if (convertedResultType && convertedResultType == castOp.getOperand(0).getType()) - castsToFold.push_back(castOp); - }); - for (UnrealizedConversionCastOp castOp : castsToFold) { - castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); - castOp.erase(); - } +static StringRef buildPpackCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.ppack.z").getValue(); +} - return success(); +static StringRef buildPunpackCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.punpack").getValue(); } -static LogicalResult normalizePtoAlignsToABI(ModuleOp module, - llvm::raw_ostream &diagOS) { - MLIRContext *context = module.getContext(); +template +static StringRef buildPredicatePairReorderCallee(MLIRContext *context); - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addConversion([&](pto::AlignType type) -> Type { - return VectorType::get({32}, IntegerType::get(context, 8)); - }); - auto materializeAlignCast = [](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1) - return {}; - return builder - .create(loc, TypeRange{resultType}, inputs) - .getResult(0); - }; - typeConverter.addSourceMaterialization(materializeAlignCast); - typeConverter.addTargetMaterialization(materializeAlignCast); - typeConverter.addArgumentMaterialization(materializeAlignCast); - - ConversionTarget target(*context); - target.addLegalOp(); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - target.addDynamicallyLegalOp( - [&](func::CallOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](Operation *op) { - return isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter); - }); - target.addDynamicallyLegalOp( - [&](UnrealizedConversionCastOp op) { - return !hasPtoAlignType(op->getOperandTypes()) && - !hasPtoAlignType(op->getResultTypes()); - }); - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); - - RewritePatternSet patterns(context); - scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, - target); - populateFunctionOpInterfaceTypeConversionPattern(patterns, - typeConverter); - populateCallOpTypeConversionPattern(patterns, typeConverter); - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); - populateReturnOpTypeConversionPattern(patterns, typeConverter); - patterns.add( - typeConverter, context); +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b8").getValue(); +} - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - diagOS << "VPTO LLVM emission failed: pto.align normalization failed\n"; - return failure(); - } +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b16").getValue(); +} - SmallVector castsToFold; - module.walk([&](UnrealizedConversionCastOp castOp) { - if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) - return; - if (!hasPtoAlignType(castOp->getOperandTypes()) && - !hasPtoAlignType(castOp->getResultTypes())) - return; - Type convertedResultType = - typeConverter.convertType(castOp.getResult(0).getType()); - if (convertedResultType && - convertedResultType == castOp.getOperand(0).getType()) - castsToFold.push_back(castOp); - }); - for (UnrealizedConversionCastOp castOp : castsToFold) { - castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); - castOp.erase(); - } +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b32").getValue(); +} - WalkResult leftover = module.walk([&](Operation *op) { - if (hasPtoAlignType(op->getOperandTypes()) || - hasPtoAlignType(op->getResultTypes())) { - diagOS << "VPTO LLVM emission failed: residual pto.align type on op " - << op->getName().getStringRef() << "\n"; - op->print(diagOS); - diagOS << "\n"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (leftover.wasInterrupted()) - return failure(); - return success(); +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b8").getValue(); } -static Type getElementTypeFromVectorLike(Type type) { - if (auto vecType = dyn_cast(type)) - return vecType.getElementType(); - if (auto vecType = dyn_cast(type)) - return vecType.getElementType(); - return {}; +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b16").getValue(); } -static Type getElementTypeFromPointerLike(Type type) { - if (auto ptrType = dyn_cast(type)) - return ptrType.getElementType(); - if (auto memRefType = dyn_cast(type)) - return memRefType.getElementType(); - return {}; +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b32").getValue(); } -static Type getElementTypeFromABIValue(Value value) { - if (!value) - return {}; - if (Type direct = getElementTypeFromPointerLike(value.getType())) - return direct; - return {}; +static FailureOr buildInterleaveCallee(MLIRContext *context, + Type resultType, + StringRef stem) { + return buildLaneTypedCallee(context, resultType, stem, ""); } -static std::optional getElementCountFromVectorLike(Type type) { - if (auto vecType = dyn_cast(type)) - return vecType.getElementCount(); - if (auto vecType = dyn_cast(type)) { - if (vecType.getRank() != 1) - return std::nullopt; - return vecType.getShape().front(); - } - return std::nullopt; +static FailureOr buildUnpackCallee(MLIRContext *context, + Type inputType, + Type resultType, + StringRef stem) { + std::string input = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + std::string result = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (input.empty() || result.empty()) + return failure(); + return StringAttr::get(context, + "llvm.hivm." + stem.str() + "." + input + "2" + result) + .getValue(); } -static Value castIntegerLikeTo(Operation *anchor, Value value, Type targetType) { - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); +static FailureOr buildVpackCallee(MLIRContext *context, Type inputType, + Type resultType) { + std::string input = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + std::string result = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (input.empty() || result.empty()) + return failure(); - if (value.getType() == targetType) - return value; + return StringAttr::get(context, "llvm.hivm.vpack." + input + "2" + result + ".x") + .getValue(); +} - auto targetInt = dyn_cast(targetType); - if (value.getType().isIndex() && targetInt) - return builder.create(anchor->getLoc(), targetType, value); - if (auto sourceInt = dyn_cast(value.getType())) { - if (targetInt) { - if (sourceInt.getWidth() < targetInt.getWidth()) - return builder.create(anchor->getLoc(), targetType, value); - if (sourceInt.getWidth() > targetInt.getWidth()) - return builder.create(anchor->getLoc(), targetType, value); - return value; - } - if (targetType.isIndex()) - return builder.create(anchor->getLoc(), targetType, value); - } +static FailureOr buildVsqzCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vsqz", ".x.v300"); +} - return {}; +static FailureOr buildVusqzCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vusqz", ".m"); } -static FailureOr convertElementOffsetToBytes(Operation *anchor, Value offset, - Type elementType) { - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); +static FailureOr buildVmulaCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vmula", ".m"); +} - Value offsetI32 = castIntegerLikeTo(anchor, offset, builder.getI32Type()); - if (!offsetI32) - return failure(); +static FailureOr buildVmullCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vmull", ""); +} - unsigned bitWidth = 0; - if (auto intType = dyn_cast(elementType)) - bitWidth = intType.getWidth(); - else if (auto floatType = dyn_cast(elementType)) - bitWidth = floatType.getWidth(); - if (bitWidth == 0 || bitWidth % 8 != 0) - return failure(); +template +static StringRef getPredicateStoreCallee(MLIRContext *context); - Value scale = builder.create( - anchor->getLoc(), builder.getI32IntegerAttr(bitWidth / 8)); - return builder.create(anchor->getLoc(), offsetI32, scale) - .getResult(); +template <> +StringRef getPredicateStoreCallee(MLIRContext *context) { + return buildPstiCallee(context); } -static Value buildBridgeCast(OpBuilder &builder, Location loc, Value input, - Type targetType) { - if (input.getType() == targetType) - return input; - if ((isa(input.getType()) && - isa(targetType)) || - (isa(input.getType()) && - isa(targetType))) { - return builder - .create(loc, TypeRange{targetType}, input) - .getResult(0); - } - return builder.create(loc, targetType, input).getResult(); +template <> +StringRef getPredicateStoreCallee(MLIRContext *context) { + return buildPstsCallee(context); } -static FailureOr requirePointerABIAddress(Operation *anchor, Value address, - llvm::raw_ostream &diagOS) { - if (isa(address.getType())) - return address; - if (auto ptrType = dyn_cast(address.getType())) { - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); - auto llvmPtrType = LLVM::LLVMPointerType::get( - builder.getContext(), - static_cast(ptrType.getMemorySpace().getAddressSpace())); - Value abiAddress = buildBridgeCast(builder, anchor->getLoc(), address, llvmPtrType); - return abiAddress; - } +template +static StringRef getPredicateLoadCallee(MLIRContext *context); - diagOS << "VPTO LLVM emission failed: expected pointer-ABI address after " - "pre-emit canonicalization, but saw " - << address.getType() << " on op "; - anchor->print(diagOS); - diagOS << "\n"; - return failure(); +template <> +StringRef getPredicateLoadCallee(MLIRContext *context) { + return buildPldiCallee(context); } -static FailureOr materializeAlignABIValue(Operation *anchor, Value align, - llvm::raw_ostream &diagOS) { - if (!align) - return failure(); - if (isa(align.getType())) - return align; +template <> +StringRef getPredicateLoadCallee(MLIRContext *context) { + return buildPldsCallee(context); +} - auto alignType = dyn_cast(align.getType()); - if (!alignType) { - diagOS << "VPTO LLVM emission failed: expected align ABI value, but saw " - << align.getType() << "\n"; - return failure(); - } +template +static StringRef getPredicateMaskCallee(MLIRContext *context); - Operation *def = align.getDefiningOp(); - if (!def) { - diagOS << "VPTO LLVM emission failed: unsupported non-ABI align producer " - << "" - << " for " << alignType << "\n"; - return failure(); - } +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPnotCallee(context); +} - auto defName = def->getName().getStringRef(); - if (defName != "pto.init_align" && defName != "ub.poison") { - diagOS << "VPTO LLVM emission failed: unsupported non-ABI align producer "; - diagOS << def->getName(); - diagOS << " for " << alignType << "\n"; - return failure(); - } +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPselCallee(context); +} - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); - auto abiType = cast(convertVPTOType(alignType, builder)); - auto zeroAttr = DenseElementsAttr::get(abiType, builder.getI8IntegerAttr(0)); - return builder.create(anchor->getLoc(), abiType, zeroAttr) - .getResult(); +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPandCallee(context); } -static Value getI64Constant(OpBuilder &builder, Location loc, uint64_t value) { - return builder.create(loc, builder.getI64IntegerAttr(value)) - .getResult(); +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPorCallee(context); } -static Value getI32Constant(OpBuilder &builder, Location loc, uint64_t value) { - return builder.create(loc, builder.getI32IntegerAttr(value)) - .getResult(); +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPxorCallee(context); } -static Value getI16Constant(OpBuilder &builder, Location loc, uint64_t value) { - return builder.create(loc, builder.getI16IntegerAttr(value)) - .getResult(); +template +static StringRef getPredicatePackCallee(MLIRContext *context); + +template <> +StringRef getPredicatePackCallee(MLIRContext *context) { + return buildPpackCallee(context); } -static Value buildAllTrueMask(OpBuilder &builder, Location loc) { - auto maskType = VectorType::get({256}, builder.getI1Type()); - auto attr = DenseElementsAttr::get(maskType, true); - return builder.create(loc, maskType, attr).getResult(); +template <> +StringRef getPredicatePackCallee(MLIRContext *context) { + return buildPunpackCallee(context); } -static FailureOr buildPltB8Mask(IRRewriter &builder, ModuleOp module, - Location loc, uint64_t laneCount, - llvm::raw_ostream &diagOS) { - Value laneCountValue = getI32Constant(builder, loc, laneCount); - auto maskType = VectorType::get({256}, builder.getI1Type()); - auto funcType = - builder.getFunctionType({builder.getI32Type()}, {maskType, builder.getI32Type()}); - auto callee = - getOrCreateExternalFunc(module, "llvm.hivm.plt.b8.v300", funcType); - auto call = builder.create(loc, callee, ValueRange{laneCountValue}); - return call.getResult(0); +template +static StringRef buildPltCallee(MLIRContext *context); + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b8.v300").getValue(); } -static FailureOr buildPltB32Mask(IRRewriter &builder, ModuleOp module, - Location loc, uint64_t laneCount, - llvm::raw_ostream &diagOS) { - // Keep this helper narrowly scoped to the verified HIVM form we have observed - // in emitc-generated device IR. For Expands/TExpandS, installed PTO source - // calls pset_b32(PAT_ALL), but save-temps from the working emitc path show - // that the compiler frontend does not preserve a pset-shaped HIVM intrinsic - // here. Instead, the full-lane mask is materialized in the final device IR as - // llvm.hivm.plt.b32.v300(i32 64), i.e. a canonical "all 64 b32 lanes active" - // form that the backend accepts. Reproduce that observed lowering here; do - // not treat it as evidence that pset_b32 and plt_b32 are generally - // interchangeable at the source or VPTO level. - Value laneCountValue = getI32Constant(builder, loc, laneCount); - auto maskType = VectorType::get({256}, builder.getI1Type()); - auto funcType = - builder.getFunctionType({builder.getI32Type()}, {maskType, builder.getI32Type()}); - auto callee = - getOrCreateExternalFunc(module, "llvm.hivm.plt.b32.v300", funcType); - auto call = builder.create(loc, callee, ValueRange{laneCountValue}); - return call.getResult(0); +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b16.v300").getValue(); } -static FailureOr buildPltB16Mask(IRRewriter &builder, ModuleOp module, - Location loc, uint64_t laneCount, - llvm::raw_ostream &diagOS) { - Value laneCountValue = getI32Constant(builder, loc, laneCount); - auto maskType = VectorType::get({256}, builder.getI1Type()); - auto funcType = - builder.getFunctionType({builder.getI32Type()}, {maskType, builder.getI32Type()}); - auto callee = - getOrCreateExternalFunc(module, "llvm.hivm.plt.b16.v300", funcType); - auto call = builder.create(loc, callee, ValueRange{laneCountValue}); - return call.getResult(0); +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b32.v300").getValue(); } -static FailureOr buildDynamicPltMask(IRRewriter &builder, ModuleOp module, - Location loc, Value laneCount, - Type vectorElemType, - llvm::raw_ostream &diagOS) { - Value laneCountI32 = laneCount; - Type i32Type = builder.getI32Type(); - if (laneCountI32.getType() != i32Type) { - if (laneCountI32.getType().isIndex()) { - laneCountI32 = - builder.create(loc, i32Type, laneCountI32); - } else if (auto sourceInt = dyn_cast(laneCountI32.getType())) { - auto targetInt = cast(i32Type); - if (sourceInt.getWidth() < targetInt.getWidth()) { - laneCountI32 = - builder.create(loc, i32Type, laneCountI32); - } else if (sourceInt.getWidth() > targetInt.getWidth()) { - laneCountI32 = - builder.create(loc, i32Type, laneCountI32); - } - } else { - return failure(); - } - } +template +static StringRef buildPsetCallee(MLIRContext *context); - auto maskType = VectorType::get({256}, builder.getI1Type()); - auto funcType = - builder.getFunctionType({builder.getI32Type()}, {maskType, builder.getI32Type()}); +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b8").getValue(); +} - StringRef calleeName; - if (vectorElemType.isF32()) { - calleeName = "llvm.hivm.plt.b32.v300"; - } else if (vectorElemType.isF16() || vectorElemType.isBF16()) { - calleeName = "llvm.hivm.plt.b16.v300"; - } else if (auto intType = dyn_cast(vectorElemType)) { - if (intType.getWidth() == 32) - calleeName = "llvm.hivm.plt.b32.v300"; - else if (intType.getWidth() == 16) - calleeName = "llvm.hivm.plt.b16.v300"; - } +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b16").getValue(); +} - if (calleeName.empty()) { - diagOS << "VPTO LLVM emission failed: unsupported dynamic plt mask element " - "type " - << vectorElemType << "\n"; - return failure(); - } +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b32").getValue(); +} - auto callee = getOrCreateExternalFunc(module, calleeName, funcType); - auto call = builder.create(loc, callee, ValueRange{laneCountI32}); - return call.getResult(0); +template +static StringRef buildPgeCallee(MLIRContext *context); + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b8").getValue(); } -static FailureOr packLoopPair(Operation *anchor, Value low, Value high) { - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b16").getValue(); +} - Value lowI64 = castIntegerLikeTo(anchor, low, builder.getI64Type()); - Value highI64 = castIntegerLikeTo(anchor, high, builder.getI64Type()); - if (!lowI64 || !highI64) +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b32").getValue(); +} + +static FailureOr buildVldsCallee(MLIRContext *context, Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) return failure(); + return StringAttr::get(context, "llvm.hivm.vldsx1.v" + std::to_string(*lanes) + + vec) + .getValue(); +} - Value shift = getI64Constant(builder, anchor->getLoc(), 40); - Value highShifted = - builder.create(anchor->getLoc(), highI64, shift).getResult(); - return builder.create(anchor->getLoc(), highShifted, lowI64) - .getResult(); +static FailureOr buildVldsx2Callee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vldsx2", ""); } -static FailureOr packLoopSize(Operation *anchor, Value loop2, Value loop1) { - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); +static StringRef buildVsldbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsldb").getValue(); +} - Value loop2I64 = castIntegerLikeTo(anchor, loop2, builder.getI64Type()); - Value loop1I64 = castIntegerLikeTo(anchor, loop1, builder.getI64Type()); - if (!loop2I64 || !loop1I64) +static FailureOr buildVstsCallee(MLIRContext *context, Type valueType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); + auto lanes = getElementCountFromVectorLike(valueType); + if (vec.empty() || !lanes) return failure(); + return StringAttr::get(context, "llvm.hivm.vstsx1.v" + std::to_string(*lanes) + + vec) + .getValue(); +} - Value shift = getI64Constant(builder, anchor->getLoc(), 21); - Value loop2Shifted = - builder.create(anchor->getLoc(), loop2I64, shift).getResult(); - return builder.create(anchor->getLoc(), loop2Shifted, loop1I64) - .getResult(); +static FailureOr buildVstsx2Callee(MLIRContext *context, Type valueType) { + return buildLaneTypedCallee(context, valueType, "vstsx2", ""); } -static FailureOr -packCopyGmToUbConfig0(Operation *anchor, pto::CopyGmToUbufOp op, - ValueRange operands) { - if (operands.size() != 11) +static StringRef buildVsstbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsstb").getValue(); +} + +static FailureOr buildVgather2Callee(MLIRContext *context, + Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) return failure(); + return StringAttr::get(context, "llvm.hivm.vgather2.v300.v" + + std::to_string(*lanes) + vec) + .getValue(); +} - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); - Location loc = anchor->getLoc(); +static FailureOr buildVgather2BcCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); +} - auto getI64Operand = [&](unsigned idx) -> Value { - return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); - }; +static FailureOr buildVgatherbCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vgatherb.v310", ""); +} - Value sid = getI64Operand(2); - Value nBurst = getI64Operand(3); - Value lenBurst = getI64Operand(4); - Value leftPadding = getI64Operand(5); - Value rightPadding = getI64Operand(6); - Value dataSelect = castIntegerLikeTo(anchor, operands[7], builder.getI64Type()); - Value cacheCtl = getI64Operand(8); - if (!sid || !nBurst || !lenBurst || !leftPadding || !rightPadding || - !dataSelect || !cacheCtl) - return failure(); +static FailureOr buildVscatterCallee(MLIRContext *context, + Type valueType) { + return buildLaneTypedCallee(context, valueType, "vscatter", ".v300"); +} - auto shl = [&](Value value, uint64_t amount) -> Value { - return builder.create(loc, value, - getI64Constant(builder, loc, amount)); - }; - auto bitOr = [&](Value lhs, Value rhs) -> Value { - return builder.create(loc, lhs, rhs); - }; +static FailureOr buildVpreluCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vprelu", ".x"); +} - Value config = sid; - config = bitOr(config, shl(nBurst, 4)); - config = bitOr(config, shl(lenBurst, 25)); - config = bitOr(config, shl(leftPadding, 46)); - config = bitOr(config, shl(rightPadding, 52)); - config = bitOr(config, shl(dataSelect, 58)); - config = bitOr(config, shl(cacheCtl, 60)); - return config; +static FailureOr buildVaxpyCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vaxpy", ".m"); } -static FailureOr -packCopyGmToUbConfig1(Operation *anchor, ValueRange operands) { - if (operands.size() != 11) +static FailureOr buildVciCallee(MLIRContext *context, Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) return failure(); - return packLoopPair(anchor, operands[9], operands[10]); + if (vec == "f16" || vec == "f32") + return StringAttr::get(context, "llvm.hivm.vci.v" + std::to_string(*lanes) + + vec + "." + vec) + .getValue(); + return StringAttr::get(context, + "llvm.hivm.vci.v" + std::to_string(*lanes) + vec) + .getValue(); } -static FailureOr -packCopyUbToGmConfig0(Operation *anchor, ValueRange operands) { - if (operands.size() != 8) +static FailureOr buildVtrcCallee(MLIRContext *context, Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) return failure(); + return StringAttr::get(context, "llvm.hivm.vtrc." + vec + ".x").getValue(); +} - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); - Location loc = anchor->getLoc(); +static FailureOr buildVexpdiffCallee(MLIRContext *context, + Type inputType, + Type resultType) { + std::string srcVec = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + auto srcLanes = getElementCountFromVectorLike(inputType); + std::string dstElem = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (srcVec.empty() || dstElem.empty() || !srcLanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vexpdif.v" + + std::to_string(*srcLanes) + srcVec + + dstElem) + .getValue(); +} - auto getI64Operand = [&](unsigned idx) -> Value { - return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); - }; +static FailureOr buildVbitsortCallee(MLIRContext *context, + pto::VbitsortOp op) { + Type sourceElemType = cast(op.getSource().getType()).getElementType(); + if (sourceElemType.isF16()) + return StringAttr::get(context, "llvm.hivm.VBS32.V300.f16").getValue(); + if (sourceElemType.isF32()) + return StringAttr::get(context, "llvm.hivm.VBS32.V300.f32").getValue(); + return failure(); +} - Value sid = getI64Operand(2); - Value nBurst = getI64Operand(3); - Value lenBurst = getI64Operand(4); - Value reserved = getI64Operand(5); - if (!sid || !nBurst || !lenBurst || !reserved) +static FailureOr buildVcvtContract(pto::VcvtOp op) { + Type inputElemType = getElementTypeFromVectorLike(op.getInput().getType()); + Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!inputElemType || !resultElemType) + return failure(); + auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), + classifyVcvtElemType(resultElemType)); + if (!contract) return failure(); + return *contract; +} - auto shl = [&](Value value, uint64_t amount) -> Value { - return builder.create(loc, value, - getI64Constant(builder, loc, amount)); - }; - auto bitOr = [&](Value lhs, Value rhs) -> Value { - return builder.create(loc, lhs, rhs); - }; +template +static StringRef buildSetLoopCallee(MLIRContext *context); - Value config = sid; - config = bitOr(config, shl(nBurst, 4)); - config = bitOr(config, shl(lenBurst, 25)); - config = bitOr(config, shl(reserved, 60)); - return config; +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB") + .getValue(); } -static FailureOr -packCopyUbToGmConfig1(Operation *anchor, ValueRange operands) { - if (operands.size() != 8) - return failure(); - return packLoopPair(anchor, operands[6], operands[7]); +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB") + .getValue(); } -static FailureOr packVbitsortConfig(Operation *anchor, Value repeatTimes) { - OpBuilder builder(anchor); - builder.setInsertionPoint(anchor); - Location loc = anchor->getLoc(); +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.OUTTOUB") + .getValue(); +} - Value repeatI64 = castIntegerLikeTo(anchor, repeatTimes, builder.getI64Type()); - if (!repeatI64) - return failure(); - return builder - .create(loc, repeatI64, getI64Constant(builder, loc, 56)) - .getResult(); +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT") + .getValue(); } -static func::FuncOp getOrCreateExternalFunc(ModuleOp module, StringRef name, - FunctionType type) { - if (auto existing = module.lookupSymbol(name)) - return existing; - OpBuilder builder(module.getBodyRegion()); - builder.setInsertionPointToStart(module.getBody()); - auto func = builder.create(module.getLoc(), name, type); - func.setPrivate(); - return func; -} - -static FailureOr getConfirmedCallee(Operation *op) { - if (isa(op)) - return std::string("llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB"); - if (isa(op)) - return std::string("llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB"); - if (isa(op)) - return std::string("llvm.hivm.SET.LOOP.SIZE.OUTTOUB"); - if (isa(op)) - return std::string("llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT"); - if (isa(op)) - return std::string("llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT"); - if (isa(op)) - return std::string("llvm.hivm.SET.LOOP.SIZE.UBTOOUT"); - if (auto copy = dyn_cast(op)) { - Type elementType = getElementTypeFromABIValue(copy.getSource()); - if (!elementType) - elementType = getElementTypeFromABIValue(copy.getDestination()); - std::string elem = getCopyElementFragment(elementType); - if (elem.empty()) - return failure(); - return "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2." + elem + ".DV"; - } - if (isa(op)) - return std::string("llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV"); - if (isa(op)) - return std::string("llvm.hivm.SET.FLAG.IMM"); - if (isa(op)) - return std::string("llvm.hivm.WAIT.FLAG.IMM"); - if (isa(op)) - return std::string("llvm.hivm.BARRIER"); - if (isa(op)) - return std::string("llvm.hivm.GET.BLOCK.IDX"); - if (isa(op)) - return std::string("llvm.hivm.GET.SUBBLOCKID"); - if (isa(op)) - return std::string("llvm.hivm.GET.BLOCK.NUM"); - if (isa(op)) - return std::string("llvm.hivm.GET.SUBBLOCKDIM"); - if (isa(op)) - return std::string("llvm.hivm.sprclr"); - if (isa(op)) - return std::string("llvm.hivm.plt.b8.v300"); - if (isa(op)) - return std::string("llvm.hivm.plt.b32.v300"); - if (isa(op)) - return std::string("llvm.hivm.plt.b16.v300"); - if (isa(op)) - return std::string("llvm.hivm.pset.b8"); - if (isa(op)) - return std::string("llvm.hivm.pset.b16"); - if (isa(op)) - return std::string("llvm.hivm.pset.b32"); - if (isa(op)) - return std::string("llvm.hivm.pge.b8"); - if (isa(op)) - return std::string("llvm.hivm.pge.b16"); - if (isa(op)) - return std::string("llvm.hivm.pge.b32"); - if (isa(op)) - return std::string("llvm.hivm.vldas"); - if (isa(op)) - return std::string("llvm.hivm.init.vector.align.data"); - if (auto vldus = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(vldus.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vldus.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vldus.v" + std::to_string(*lanes) + vec; - } - if (isa(op)) - return std::string("llvm.hivm.vstus"); - if (isa(op)) - return std::string("llvm.hivm.vstur"); - if (auto vlds = dyn_cast(op)) { - std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(vlds.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vlds.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - std::string name = "llvm.hivm.vldsx1"; - name += ".v" + std::to_string(*lanes) + vec; - return name; - } - if (auto vldsPost = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vldsPost.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vldsPost.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vldsx1.post.v" + std::to_string(*lanes) + vec; - } - if (auto vldsPost = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vldsPost.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vldsPost.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vldsx1.post.v" + std::to_string(*lanes) + vec; - } - if (auto vabs = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vabs.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vabs.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vabs.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto vexp = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vexp.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vexp.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vexp.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto vln = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vln.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vln.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vln.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto vneg = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vneg.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vneg.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vneg.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto vsqrt = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(vsqrt.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vsqrt.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vsqrt.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto vrelu = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(vrelu.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vrelu.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vrelu.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto vnot = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vnot.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vnot.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vnot.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto vdup = dyn_cast(op)) { - Type inputType = vdup.getInput().getType(); - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vdup.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vdup.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - if (isa(inputType)) { - StringRef position = vdup.getPosition().value_or("LOWEST"); - StringRef family = position == "HIGHEST" ? "vdupm" : "vdup"; - return "llvm.hivm." + family.str() + ".v" + std::to_string(*lanes) + vec + ".z"; - } - return "llvm.hivm.vdups.v" + std::to_string(*lanes) + vec + ".z"; - } - if (auto vbr = dyn_cast(op)) { - std::string scalar = getVbrScalarFragment(vbr.getValue().getType()); - if (scalar.empty()) - return failure(); - return "llvm.hivm.vbr." + scalar + ".v300"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vadd.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vsub.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vmul.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vmuls.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vadds.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vmaxs.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vmins.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vlrelu.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vshls.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vshrs.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vprelu.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string srcVec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getInput().getType())); - auto srcLanes = getElementCountFromVectorLike(binary.getInput().getType()); - std::string dstElem = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - if (srcVec.empty() || dstElem.empty() || !srcLanes) - return failure(); - return "llvm.hivm.vexpdif.v" + std::to_string(*srcLanes) + srcVec + dstElem; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vdiv.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vmax.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vmin.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vand.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vor.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vxor.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vaddc.v" + std::to_string(*lanes) + vec; - } - if (auto binary = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vsubc.v" + std::to_string(*lanes) + vec; - } - if (auto binary = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vaddcs.v" + std::to_string(*lanes) + vec; - } - if (auto binary = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vsubcs.v" + std::to_string(*lanes) + vec; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vshl.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(binary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vshr.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto unary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vcadd.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto unary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vcmax.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto unary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vcmin.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto unary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vcgadd.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto unary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vcgmax.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto unary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vcgmin.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto unary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(unary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(unary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vcpadd.v" + std::to_string(*lanes) + vec + ".x"; - } - if (auto ternary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(ternary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(ternary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vmula.v" + std::to_string(*lanes) + vec + ".m"; - } - if (auto binary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(binary.getLow().getType())); - auto lanes = getElementCountFromVectorLike(binary.getLow().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vmull.v" + std::to_string(*lanes) + vec; - } - if (auto ternary = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(ternary.getResult().getType())); - auto lanes = getElementCountFromVectorLike(ternary.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vaxpy.v" + std::to_string(*lanes) + vec + ".m"; - } - if (auto vci = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vci.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vci.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - if (vec == "f16") - return "llvm.hivm.vci.v" + std::to_string(*lanes) + vec + ".f16"; - if (vec == "f32") - return "llvm.hivm.vci.v" + std::to_string(*lanes) + vec + ".f32"; - return "llvm.hivm.vci.v" + std::to_string(*lanes) + vec; - } - if (auto vbitsort = dyn_cast(op)) { - Type sourceElemType = getElementTypeFromABIValue(vbitsort.getSource()); - if (!sourceElemType) - return failure(); - if (sourceElemType.isF16()) - return std::string("llvm.hivm.VBS32.V300.f16"); - if (sourceElemType.isF32()) - return std::string("llvm.hivm.VBS32.V300.f32"); - return failure(); - } - if (auto vtrc = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vtrc.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vtrc.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vtrc." + vec + ".x"; - } - if (auto vcvt = dyn_cast(op)) { - Type inputElemType = getElementTypeFromVectorLike(vcvt.getInput().getType()); - Type resultElemType = getElementTypeFromVectorLike(vcvt.getResult().getType()); - if (!inputElemType || !resultElemType) - return failure(); - auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), - classifyVcvtElemType(resultElemType)); - if (contract) - return std::string(contract->intrinsic); - return failure(); - } - if (isa(op)) - return std::string("llvm.hivm.vstar"); - if (isa(op)) - return std::string("llvm.hivm.vstas"); - if (auto vsqz = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vsqz.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vsqz.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vsqz.v" + std::to_string(*lanes) + vec + ".x.v300"; - } - if (auto vusqz = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vusqz.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vusqz.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vusqz.v" + std::to_string(*lanes) + vec + ".m"; - } - if (auto unpack = dyn_cast(op)) { - Type inputElemType = getElementTypeFromVectorLike(unpack.getSrc().getType()); - Type resultElemType = getElementTypeFromVectorLike(unpack.getResult().getType()); - std::string input = getElementTypeFragment(inputElemType); - std::string result = getElementTypeFragment(resultElemType); - if (input.empty() || result.empty()) - return failure(); - return "llvm.hivm.vsunpack." + input + "2" + result; - } - if (auto unpack = dyn_cast(op)) { - Type inputElemType = getElementTypeFromVectorLike(unpack.getSrc().getType()); - Type resultElemType = getElementTypeFromVectorLike(unpack.getResult().getType()); - std::string input = getElementTypeFragment(inputElemType); - std::string result = getElementTypeFragment(resultElemType); - if (input.empty() || result.empty()) - return failure(); - return "llvm.hivm.vzunpack." + input + "2" + result; - } - if (auto pack = dyn_cast(op)) { - Type inputElemType = getElementTypeFromVectorLike(pack.getSrc().getType()); - Type resultElemType = getElementTypeFromVectorLike(pack.getResult().getType()); - std::string input = getElementTypeFragment(inputElemType); - std::string result = getElementTypeFragment(resultElemType); - auto part = parseHiLoPartImmediate(pack.getPart()); - if (input.empty() || result.empty() || !part) - return failure(); - return "llvm.hivm.vpack." + input + "2" + result + ".x"; - } - if (auto interleave = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(interleave.getLow().getType())); - auto lanes = getElementCountFromVectorLike(interleave.getLow().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vintlv.v" + std::to_string(*lanes) + vec; - } - if (auto deinterleave = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(deinterleave.getLow().getType())); - auto lanes = getElementCountFromVectorLike(deinterleave.getLow().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vdintlv.v" + std::to_string(*lanes) + vec; - } - if (isa(op)) - return std::string("llvm.hivm.vsldb"); - if (isa(op)) - return std::string("llvm.hivm.vsstb"); - if (auto vldsx2 = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vldsx2.getLow().getType())); - auto lanes = getElementCountFromVectorLike(vldsx2.getLow().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vldsx2.v" + std::to_string(*lanes) + vec; - } - if (auto vstsx2 = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(vstsx2.getLow().getType())); - auto lanes = getElementCountFromVectorLike(vstsx2.getLow().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vstsx2.v" + std::to_string(*lanes) + vec; - } - if (auto vsts = dyn_cast(op)) { - std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(vsts.getValue().getType())); - auto lanes = getElementCountFromVectorLike(vsts.getValue().getType()); - if (vec.empty() || !lanes) - return failure(); - std::string name = "llvm.hivm.vstsx1"; - name += ".v" + std::to_string(*lanes) + vec; - return name; - } - if (auto vstsPost = dyn_cast(op)) { - std::string vec = getElementTypeFragment( - getElementTypeFromVectorLike(vstsPost.getValue().getType())); - auto lanes = getElementCountFromVectorLike(vstsPost.getValue().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vstsx1.post.v" + std::to_string(*lanes) + vec; - } - if (auto vstsPost = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vstsPost.getValue().getType())); - auto lanes = getElementCountFromVectorLike(vstsPost.getValue().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vstsx1.post.v" + std::to_string(*lanes) + vec; - } - if (auto vcmp = dyn_cast(op)) { - std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(vcmp.getSrc0().getType())); - if (elem.empty()) - return failure(); - return "llvm.hivm.vcmp." + vcmp.getCmpMode().str() + "." + elem + ".z"; - } - if (auto vcmps = dyn_cast(op)) { - std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(vcmps.getSrc().getType())); - if (elem.empty()) - return failure(); - return "llvm.hivm.vcmps." + vcmps.getCmpMode().str() + "." + elem + ".z"; - } - if (auto vsel = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vsel.getResult().getType())); - auto lanes = getElementCountFromVectorLike(vsel.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vsel.v" + std::to_string(*lanes) + vec; - } - if (auto vselr = dyn_cast(op)) { - Type elemType = getElementTypeFromVectorLike(vselr.getResult().getType()); - auto lanes = getElementCountFromVectorLike(vselr.getResult().getType()); - if (!elemType || !lanes) - return failure(); - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(vselr.getResult().getType())); - if (auto floatType = dyn_cast(elemType); floatType && floatType.isF32()) - vec = "u32"; - if (vec.empty()) - return failure(); - return "llvm.hivm.vselr.v" + std::to_string(*lanes) + vec; - } - if (isa(op)) - return std::string("llvm.hivm.ppack.z"); - if (isa(op)) - return std::string("llvm.hivm.punpack"); - if (isa(op)) - return std::string("llvm.hivm.pnot.z"); - if (isa(op)) - return std::string("llvm.hivm.psel"); - if (isa(op)) - return std::string("llvm.hivm.pand.z"); - if (isa(op)) - return std::string("llvm.hivm.por.z"); - if (isa(op)) - return std::string("llvm.hivm.pxor.z"); - if (isa(op)) - return std::string("llvm.hivm.pdintlv.b8"); - if (isa(op)) - return std::string("llvm.hivm.pdintlv.b16"); - if (isa(op)) - return std::string("llvm.hivm.pdintlv.b32"); - if (isa(op)) - return std::string("llvm.hivm.pintlv.b8"); - if (isa(op)) - return std::string("llvm.hivm.pintlv.b16"); - if (isa(op)) - return std::string("llvm.hivm.pintlv.b32"); - if (isa(op)) - return std::string("llvm.hivm.plds.b8"); - if (isa(op)) - return std::string("llvm.hivm.pldi.b8"); - if (isa(op)) - return std::string("llvm.hivm.psts.b8"); - if (op->getName().getStringRef() == "pto.pstu") { - Type maskOperandType = op->getOperand(1).getType(); - if (auto maskType = dyn_cast(maskOperandType)) { - if (maskType.isB16()) - return std::string("llvm.hivm.pstu.b16"); - if (maskType.isB32()) - return std::string("llvm.hivm.pstu.b32"); - } - if (Type baseElementType = getElementTypeFromABIValue(op->getOperand(2))) { - if (auto intType = dyn_cast(baseElementType)) { - if (intType.getWidth() == 16) - return std::string("llvm.hivm.pstu.b16"); - if (intType.getWidth() == 32) - return std::string("llvm.hivm.pstu.b32"); - } - } - // Current repo coverage only exercises the installed `b32` surface. Keep - // this fallback narrow to unblock those cases; `b16` still needs an - // end-to-end testcase path before we can claim the generic surface works. - return std::string("llvm.hivm.pstu.b32"); - } - if (auto pstu = dyn_cast(op)) { - if (auto maskType = dyn_cast(pstu.getValue().getType())) { - if (maskType.isB16()) - return std::string("llvm.hivm.pstu.b16"); - if (maskType.isB32()) - return std::string("llvm.hivm.pstu.b32"); - } - if (Type baseElementType = getElementTypeFromABIValue(pstu.getBase())) { - if (auto intType = dyn_cast(baseElementType)) { - if (intType.getWidth() == 16) - return std::string("llvm.hivm.pstu.b16"); - if (intType.getWidth() == 32) - return std::string("llvm.hivm.pstu.b32"); - } - } - return failure(); - } - if (isa(op)) - return std::string("llvm.hivm.psti.b8"); - if (auto gather = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(gather.getResult().getType())); - auto lanes = getElementCountFromVectorLike(gather.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vgather2.v300.v" + std::to_string(*lanes) + vec; - } - if (auto gather = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(gather.getResult().getType())); - auto lanes = getElementCountFromVectorLike(gather.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vgather2.bc.v" + std::to_string(*lanes) + vec; - } - if (auto gather = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(gather.getResult().getType())); - auto lanes = getElementCountFromVectorLike(gather.getResult().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vgatherb.v310.v" + std::to_string(*lanes) + vec; - } - if (auto scatter = dyn_cast(op)) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(scatter.getValue().getType())); - auto lanes = getElementCountFromVectorLike(scatter.getValue().getType()); - if (vec.empty() || !lanes) - return failure(); - return "llvm.hivm.vscatter.v" + std::to_string(*lanes) + vec + ".v300"; - } - return failure(); +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT") + .getValue(); } -static LogicalResult -guardNoMemRefIntrinsicArgs(Operation *op, StringRef calleeName, - ValueRange callArgs, llvm::raw_ostream &diagOS) { - if (calleeName != "llvm.hivm.vldsx1" && calleeName != "llvm.hivm.vstsx1") - return success(); +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT") + .getValue(); +} + +template +static StringRef buildSyncCallee(MLIRContext *context); + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.FLAG.IMM").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.WAIT.FLAG.IMM").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.BARRIER").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BUFI.mode").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.RLS.BUFI.mode").getValue(); +} + +template +static StringRef buildRuntimeQueryCallee(MLIRContext *context); + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BLOCK.IDX").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKID").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BLOCK.NUM").getValue(); +} - for (auto [idx, arg] : llvm::enumerate(callArgs)) { - Type argType = arg.getType(); - if (!isa(argType)) +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKDIM").getValue(); +} + +static LogicalResult +materializeDecls(ModuleOp module, ArrayRef plannedDecls, + llvm::raw_ostream &diagOS) { + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(&module.getBodyRegion().front()); + for (const PlannedDecl &decl : plannedDecls) { + if (func::FuncOp existing = module.lookupSymbol(decl.name)) { + if (existing.getFunctionType() != decl.type) { + diagOS << "VPTO LLVM emission failed: conflicting declaration for " + << decl.name << "\n"; + return failure(); + } continue; - diagOS << "VPTO LLVM emission failed: intrinsic ABI guard rejected memref " - "argument #" - << idx << " for " << calleeName << " from " - << op->getName().getStringRef() << " (" << argType << ")\n"; - return failure(); + } + auto func = + builder.create(module.getLoc(), decl.name, decl.type); + func.setPrivate(); } return success(); } -static LogicalResult rewriteVPTOOp(Operation *op, ModuleOp module, - llvm::raw_ostream &diagOS) { - IRRewriter builder(op->getContext()); - builder.setInsertionPoint(op); - Location loc = op->getLoc(); +template +class LowerUnaryMaskedOpPattern final : public OpConversionPattern { +public: + explicit LowerUnaryMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} - if (auto vbr = dyn_cast(op)) { - auto calleeName = getConfirmedCallee(op); - if (failed(calleeName)) { - diagOS << "VPTO LLVM emission failed: unsupported op " - << op->getName().getStringRef() << "\n"; - return failure(); - } - - Type resultType = convertVPTOType(vbr.getResult().getType(), builder); - Type scalarType = vbr.getValue().getType(); - if (!resultType || !scalarType) { - diagOS << "VPTO LLVM emission failed: could not materialize vbr types\n"; - return failure(); + LogicalResult + matchAndRewrite(UnaryOp op, typename UnaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getUnaryMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported unary VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert unary result type"); + + Value input = adaptor.getOperands()[0]; + Value mask = adaptor.getOperands()[1]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(1).getType()); + if (!input || !mask || input.getType() != resultType || + mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted unary VPTO operand types"); } - auto funcType = builder.getFunctionType({scalarType}, {resultType}); - auto callee = getOrCreateExternalFunc(module, *calleeName, funcType); - auto call = - builder.create(loc, callee, ValueRange{vbr.getValue()}); - builder.replaceOp(op, call.getResults()); + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); return success(); } - if (isa(op)) { - SmallVector argTypes; - auto funcType = builder.getFunctionType(argTypes, op->getResultTypes()); - auto callee = getOrCreateExternalFunc(module, *getConfirmedCallee(op), funcType); - auto call = builder.create(loc, callee, ValueRange{}); - builder.replaceOp(op, call.getResults()); - return success(); - } +private: + LoweringState &state; +}; - auto calleeName = getConfirmedCallee(op); - if (failed(calleeName)) { - diagOS << "VPTO LLVM emission failed: unsupported op " - << op->getName().getStringRef() << "\n"; - return failure(); - } +class LowerVsqzOpPattern final : public OpConversionPattern { +public: + explicit LowerVsqzOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} - SmallVector surfaceResultTypes(op->getResultTypes().begin(), - op->getResultTypes().end()); - SmallVector loweredResultTypes; - loweredResultTypes.reserve(surfaceResultTypes.size()); - for (Type type : surfaceResultTypes) - loweredResultTypes.push_back(convertVPTOType(type, builder)); - SmallVector intrinsicResultTypes(loweredResultTypes.begin(), - loweredResultTypes.end()); - if (auto vldus = dyn_cast(op)) { - Type sourceType = convertVPTOType(vldus.getSource().getType(), builder); - if (!sourceType) { - diagOS << "VPTO LLVM emission failed: could not materialize vldus source type\n"; - return failure(); + LogicalResult + matchAndRewrite(pto::VsqzOp op, pto::VsqzOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVsqzCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsqz VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vsqz types"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vsqz operand types"); } - intrinsicResultTypes.push_back(sourceType); + + Value storeHint = + getI32Constant(rewriter, op.getLoc(), determineVsqzStoreHint(op)); + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, maskType, storeHint.getType()}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{input, mask, storeHint}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); } - SmallVector callArgs; +private: + LoweringState &state; +}; - if (isa(op)) { - auto packed = packLoopPair(op, op->getOperand(0), op->getOperand(1)); - if (failed(packed)) - return failure(); - callArgs.push_back(*packed); - } else if (isa(op)) { - auto packed = packLoopSize(op, op->getOperand(0), op->getOperand(1)); - if (failed(packed)) - return failure(); - callArgs.push_back(*packed); - } else if (auto copy = dyn_cast(op)) { - auto config0 = packCopyGmToUbConfig0(op, copy, op->getOperands()); - auto config1 = packCopyGmToUbConfig1(op, op->getOperands()); - auto destination = requirePointerABIAddress(op, copy.getDestination(), diagOS); - auto source = requirePointerABIAddress(op, copy.getSource(), diagOS); - if (failed(config0) || failed(config1) || failed(destination) || - failed(source)) - return failure(); - callArgs.push_back(*destination); - callArgs.push_back(*source); - callArgs.push_back(*config0); - callArgs.push_back(*config1); - } else if (auto copy = dyn_cast(op)) { - auto config0 = packCopyUbToGmConfig0(op, op->getOperands()); - auto config1 = packCopyUbToGmConfig1(op, op->getOperands()); - auto destination = requirePointerABIAddress(op, copy.getDestination(), diagOS); - auto source = requirePointerABIAddress(op, copy.getSource(), diagOS); - if (failed(config0) || failed(config1) || failed(destination) || - failed(source)) - return failure(); - callArgs.push_back(*destination); - callArgs.push_back(*source); - callArgs.push_back(*config0); - callArgs.push_back(*config1); - } else if (auto setFlag = dyn_cast(op)) { - auto src = parsePipeImmediate(stringifyPIPE(setFlag.getSrcPipe().getPipe())); - auto dst = parsePipeImmediate(stringifyPIPE(setFlag.getDstPipe().getPipe())); - auto event = parseEventImmediate(stringifyEVENT(setFlag.getEventId().getEvent())); - if (!src || !dst || !event) - return failure(); - callArgs.push_back(getI64Constant(builder, loc, *src)); - callArgs.push_back(getI64Constant(builder, loc, *dst)); - callArgs.push_back(getI64Constant(builder, loc, *event)); - } else if (auto waitFlag = dyn_cast(op)) { - auto src = - parsePipeImmediate(stringifyPIPE(waitFlag.getSrcPipe().getPipe())); - auto dst = - parsePipeImmediate(stringifyPIPE(waitFlag.getDstPipe().getPipe())); - auto event = - parseEventImmediate(stringifyEVENT(waitFlag.getEventId().getEvent())); - if (!src || !dst || !event) - return failure(); - callArgs.push_back(getI64Constant(builder, loc, *src)); - callArgs.push_back(getI64Constant(builder, loc, *dst)); - callArgs.push_back(getI64Constant(builder, loc, *event)); - } else if (auto barrier = dyn_cast(op)) { - auto pipe = parsePipeImmediate(stringifyPIPE(barrier.getPipe().getPipe())); - if (!pipe) - return failure(); - callArgs.push_back(getI64Constant(builder, loc, *pipe)); - } else if (auto sprclr = dyn_cast(op)) { - auto spr = parseSprImmediate(sprclr.getSpr()); - if (!spr) { - diagOS << "VPTO LLVM emission failed: unsupported sprclr target " - << sprclr.getSpr() << "\n"; - return failure(); - } - callArgs.push_back(getI16Constant(builder, loc, *spr)); - } else if (isa(op)) { - Value laneCount = castIntegerLikeTo(op, op->getOperand(0), builder.getI32Type()); - if (!laneCount) - return failure(); - callArgs.push_back(laneCount); - } else if (auto pset = dyn_cast(op)) { - auto pattern = parsePredicatePatternImmediate(pset.getPattern()); - if (!pattern) { - diagOS << "VPTO LLVM emission failed: unsupported pset_b8 pattern " - << pset.getPattern() << "\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *pattern)); - } else if (auto pset = dyn_cast(op)) { - auto pattern = parsePredicatePatternImmediate(pset.getPattern()); - if (!pattern) { - diagOS << "VPTO LLVM emission failed: unsupported pset_b16 pattern " - << pset.getPattern() << "\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *pattern)); - } else if (auto pset = dyn_cast(op)) { - auto pattern = parsePredicatePatternImmediate(pset.getPattern()); - if (!pattern) { - diagOS << "VPTO LLVM emission failed: unsupported pset_b32 pattern " - << pset.getPattern() << "\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *pattern)); - } else if (auto pge = dyn_cast(op)) { - auto pattern = parsePredicatePatternImmediate(pge.getPattern()); - if (!pattern) { - diagOS << "VPTO LLVM emission failed: unsupported pge_b8 pattern " - << pge.getPattern() << "\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *pattern)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto pge = dyn_cast(op)) { - auto pattern = parsePredicatePatternImmediate(pge.getPattern()); - if (!pattern) { - diagOS << "VPTO LLVM emission failed: unsupported pge_b16 pattern " - << pge.getPattern() << "\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *pattern)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto pge = dyn_cast(op)) { - auto pattern = parsePredicatePatternImmediate(pge.getPattern()); - if (!pattern) { - diagOS << "VPTO LLVM emission failed: unsupported pge_b32 pattern " - << pge.getPattern() << "\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *pattern)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (isa(op)) { - // llvm.hivm.init.vector.align.data() has no operands. - } else if (auto vldas = dyn_cast(op)) { - auto source = requirePointerABIAddress(op, vldas.getSource(), diagOS); - if (failed(source)) - return failure(); - callArgs.push_back(*source); - } else if (auto vldus = dyn_cast(op)) { - auto source = requirePointerABIAddress(op, vldus.getSource(), diagOS); - if (failed(source)) - return failure(); - callArgs.push_back(*source); - callArgs.push_back(vldus.getAlign()); - } else if (auto vstus = dyn_cast(op)) { - Type elementType = getElementTypeFromVectorLike(vstus.getValue().getType()); - auto basePtr = requirePointerABIAddress(op, vstus.getBase(), diagOS); - auto alignValue = materializeAlignABIValue(op, vstus.getAlignIn(), diagOS); - if (!elementType || failed(basePtr)) - return failure(); - auto offsetBytes = convertElementOffsetToBytes(op, vstus.getOffset(), elementType); - if (failed(offsetBytes) || failed(alignValue)) - return failure(); - callArgs.push_back(vstus.getValue()); - callArgs.push_back(*basePtr); - callArgs.push_back(*offsetBytes); - callArgs.push_back(*alignValue); - } else if (auto vstur = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, vstur.getBase(), diagOS); - auto postMode = parsePostModeImmediate(vstur.getMode()); - auto alignValue = materializeAlignABIValue(op, vstur.getAlignIn(), diagOS); - if (failed(basePtr) || !postMode) { - if (!postMode) - diagOS << "VPTO LLVM emission failed: unsupported vstur mode " - << vstur.getMode() << "\n"; - return failure(); +class LowerVusqzOpPattern final : public OpConversionPattern { +public: + explicit LowerVusqzOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VusqzOp op, pto::VusqzOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVusqzCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vusqz VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vusqz types"); + + Value src = adaptor.getSrc(); + Value mask = adaptor.getMask(); + if (!src || !mask || src.getType() != resultType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vusqz operand types"); } - if (failed(alignValue)) - return failure(); - callArgs.push_back(vstur.getValue()); - callArgs.push_back(*basePtr); - callArgs.push_back(*alignValue); - callArgs.push_back(getI32Constant(builder, loc, *postMode)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto vlds = dyn_cast(op)) { - Type elementType = getElementTypeFromVectorLike(vlds.getResult().getType()); - auto offsetBytes = convertElementOffsetToBytes( - op, op->getOperand(1), elementType); - auto basePtr = requirePointerABIAddress(op, op->getOperand(0), diagOS); - auto dist = - parseLoadDistImmediate(vlds.getDist().value_or("NORM"), elementType); - if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) { - if (elementType && succeeded(basePtr) && !dist) - diagOS << "VPTO LLVM emission failed: unsupported vlds dist immediate\n"; - return failure(); + + auto funcType = + rewriter.getFunctionType(TypeRange{resultType, maskType}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmulaOpPattern final : public OpConversionPattern { +public: + explicit LowerVmulaOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VmulaOp op, pto::VmulaOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVmulaCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmula VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vmula types"); + + Value acc = adaptor.getAcc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + Value mask = adaptor.getMask(); + if (!acc || !lhs || !rhs || !mask || acc.getType() != resultType || + lhs.getType() != resultType || rhs.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vmula operand types"); } - callArgs.push_back(*basePtr); - callArgs.push_back(*offsetBytes); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto vldsPost = dyn_cast(op)) { - Type elementType = getElementTypeFromVectorLike(vldsPost.getResult().getType()); - auto offsetBytes = convertElementOffsetToBytes( - op, vldsPost.getOffset(), elementType); - auto basePtr = requirePointerABIAddress(op, vldsPost.getSource(), diagOS); - auto dist = - parseLoadDistImmediate(vldsPost.getDist().value_or("NORM"), elementType); - if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) - return failure(); - callArgs.push_back(*basePtr); - callArgs.push_back(*offsetBytes); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 1)); - } else if (auto vabs = dyn_cast(op)) { - Value input = op->getOperand(0); - Value mask = op->getOperand(1); - Type vecType = loweredResultTypes.front(); - Type maskType = convertVPTOType(mask.getType(), builder); - if (input.getType() != vecType || mask.getType() != maskType) { - diagOS << "VPTO LLVM emission failed: unexpected vabs operand types\n"; - return failure(); + + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, resultType, resultType, maskType}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{acc, lhs, rhs, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmullOpPattern final : public OpConversionPattern { +public: + explicit LowerVmullOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VmullOp op, pto::VmullOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVmullCallee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmull VPTO signature"); + + Type inputType = this->getTypeConverter()->convertType(op.getLhs().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + SmallVector resultTypes; + if (!inputType || !maskType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) { + return rewriter.notifyMatchFailure(op, "failed to convert vmull types"); } - callArgs.push_back(input); - callArgs.push_back(mask); - } else if (auto unary = dyn_cast(op)) { - Value input = unary.getInput(); - Value mask = unary.getMask(); - Type vecType = loweredResultTypes.front(); - Type maskType = convertVPTOType(mask.getType(), builder); - if (input.getType() != vecType || mask.getType() != maskType) { - diagOS << "VPTO LLVM emission failed: unexpected " - << op->getName().getStringRef() << " operand types\n"; - return failure(); + if (resultTypes.size() != 2 || resultTypes[0] != resultTypes[1]) + return rewriter.notifyMatchFailure(op, "unexpected converted vmull results"); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + Value mask = adaptor.getMask(); + if (!lhs || !rhs || !mask || lhs.getType() != inputType || + rhs.getType() != inputType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vmull operand types"); } - callArgs.push_back(input); - callArgs.push_back(mask); - } else if (auto unary = dyn_cast(op)) { - Value input = unary.getInput(); - Value mask = unary.getMask(); - Type vecType = loweredResultTypes.front(); - Type maskType = convertVPTOType(mask.getType(), builder); - if (input.getType() != vecType || mask.getType() != maskType) { - diagOS << "VPTO LLVM emission failed: unexpected " - << op->getName().getStringRef() << " operand types\n"; - return failure(); + + auto funcType = rewriter.getFunctionType(TypeRange{inputType, inputType, maskType}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, resultTypes, + ValueRange{lhs, rhs, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerBinaryMaskedOpPattern final : public OpConversionPattern { +public: + explicit LowerBinaryMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BinaryOp op, typename BinaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getBinaryMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert binary result type"); + + Value lhs = adaptor.getOperands()[0]; + Value rhs = adaptor.getOperands()[1]; + Value mask = adaptor.getOperands()[2]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(2).getType()); + if (!lhs || !rhs || !mask || lhs.getType() != resultType || + rhs.getType() != resultType || mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted binary VPTO operand types"); } - callArgs.push_back(input); - callArgs.push_back(mask); - } else if (auto unary = dyn_cast(op)) { - Value input = unary.getInput(); - Value mask = unary.getMask(); - Type vecType = loweredResultTypes.front(); - Type maskType = convertVPTOType(mask.getType(), builder); - if (input.getType() != vecType || mask.getType() != maskType) { - diagOS << "VPTO LLVM emission failed: unexpected " - << op->getName().getStringRef() << " operand types\n"; - return failure(); + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{lhs, rhs, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCarryBinaryOpPattern final : public OpConversionPattern { +public: + explicit LowerCarryBinaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CarryOp op, typename CarryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getCarryBinaryStem(); + FailureOr calleeName = + buildCarryBinaryCallee(op.getContext(), op.getResult().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported carry VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type carryType = + this->getTypeConverter()->convertType(op->getResult(1).getType()); + if (!resultType || !carryType) + return rewriter.notifyMatchFailure(op, + "failed to convert carry result types"); + + SmallVector callArgs; + callArgs.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); + const size_t expectedArgCount = hasCarryInput() ? 4 : 3; + if (callArgs.size() != expectedArgCount || callArgs[0].getType() != resultType || + callArgs[1].getType() != resultType || callArgs.back().getType() != carryType) + return rewriter.notifyMatchFailure(op, + "unexpected converted carry operand types"); + if constexpr (hasCarryInput()) { + if (callArgs[2].getType() != carryType) + return rewriter.notifyMatchFailure( + op, "unexpected converted carry input operand type"); } - callArgs.push_back(input); - callArgs.push_back(mask); - } else if (auto unary = dyn_cast(op)) { - Value input = unary.getInput(); - Value mask = unary.getMask(); - Type vecType = loweredResultTypes.front(); - Type maskType = convertVPTOType(mask.getType(), builder); - if (input.getType() != vecType || mask.getType() != maskType) { - diagOS << "VPTO LLVM emission failed: unexpected " - << op->getName().getStringRef() << " operand types\n"; - return failure(); + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType, carryType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCopyOpPattern final : public OpConversionPattern { +public: + explicit LowerCopyOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = failure(); + if constexpr (std::is_same_v) + calleeName = buildCopyGmToUbCallee(op.getContext(), op); + else + calleeName = buildCopyUbToGmCallee(op.getContext()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported copy VPTO signature"); + + auto llvmSourceType = + dyn_cast(adaptor.getOperands()[0].getType()); + auto llvmDestType = + dyn_cast(adaptor.getOperands()[1].getType()); + if (!llvmSourceType || !llvmDestType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); + + FailureOr config0 = failure(); + FailureOr config1 = failure(); + if constexpr (std::is_same_v) { + config0 = packCopyGmToUbConfig0(op, adaptor.getOperands()); + config1 = packCopyGmToUbConfig1(op, adaptor.getOperands()); + } else { + config0 = packCopyUbToGmConfig0(op, adaptor.getOperands()); + config1 = packCopyUbToGmConfig1(op, adaptor.getOperands()); } - callArgs.push_back(input); - callArgs.push_back(mask); - } else if (auto unary = dyn_cast(op)) { - Value input = unary.getInput(); - Value mask = unary.getMask(); - Type vecType = loweredResultTypes.front(); - Type maskType = convertVPTOType(mask.getType(), builder); - if (input.getType() != vecType || mask.getType() != maskType) { - diagOS << "VPTO LLVM emission failed: unexpected " - << op->getName().getStringRef() << " operand types\n"; - return failure(); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + + SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], + *config0, *config1}; + auto funcType = rewriter.getFunctionType( + TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + (void)call; + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerVecScalarMaskedOpPattern final + : public OpConversionPattern { +public: + explicit LowerVecScalarMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(VecScalarOp op, typename VecScalarOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getVecScalarMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported vec-scalar VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert vec-scalar result type"); + + Value input = adaptor.getOperands()[0]; + Value scalar = adaptor.getOperands()[1]; + Value mask = adaptor.getOperands()[2]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(2).getType()); + if (!input || !scalar || !mask || input.getType() != resultType || + mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted vec-scalar VPTO operand types"); } - callArgs.push_back(input); - callArgs.push_back(mask); - } else if (auto unary = dyn_cast(op)) { - Value input = unary.getInput(); - Value mask = unary.getMask(); - Type vecType = loweredResultTypes.front(); - Type maskType = convertVPTOType(mask.getType(), builder); - if (input.getType() != vecType || mask.getType() != maskType) { - diagOS << "VPTO LLVM emission failed: unexpected " - << op->getName().getStringRef() << " operand types\n"; - return failure(); + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{input, scalar, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerReductionUnaryOpPattern final + : public OpConversionPattern { +public: + explicit LowerReductionUnaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ReductionOp op, typename ReductionOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getReductionUnaryStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported reduction VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) { + return rewriter.notifyMatchFailure( + op, "failed to convert reduction result type"); } - callArgs.push_back(input); - callArgs.push_back(mask); - } else if (auto vdup = dyn_cast(op)) { - Type scalarType = getElementTypeFromVectorLike(vdup.getResult().getType()); - bool vectorInput = isa(vdup.getInput().getType()); - if (!vectorInput && (!scalarType || vdup.getInput().getType() != scalarType)) { - diagOS << "VPTO LLVM emission failed: unexpected vdup operand types\n"; - return failure(); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted reduction operand types"); } - if (vectorInput && vdup.getInput().getType() != loweredResultTypes.front()) { - diagOS << "VPTO LLVM emission failed: vector-input vdup requires matching result type\n"; - return failure(); + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVselOpPattern final : public OpConversionPattern { +public: + explicit LowerVselOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VselOp op, pto::VselOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVselCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsel VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vsel result type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + Value mask = adaptor.getMask(); + if (!src0 || !src1 || !mask || src0.getType() != resultType || + src1.getType() != resultType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vsel operand types"); } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{src0, src1, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVdupOpPattern final : public OpConversionPattern { +public: + explicit LowerVdupOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VdupOp op, pto::VdupOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = buildVdupCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vdup VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vdup result type"); + + Value mask = adaptor.getMask(); + if (!mask || mask.getType() != maskType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vdup mask type"); + + SmallVector callArgs; + bool vectorInput = isa(op.getInput().getType()); if (vectorInput) { - callArgs.push_back(vdup.getInput()); + Value input = adaptor.getInput(); + if (!input || input.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "vector-input vdup requires matching result type"); + } + callArgs.push_back(input); } else { - FailureOr normalizedScalar = normalizeVdupScalarOperand(builder, loc, vdup); + Type scalarType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!scalarType || op.getInput().getType() != scalarType) { + return rewriter.notifyMatchFailure(op, + "unexpected scalar-input vdup type"); + } + FailureOr normalizedScalar = + normalizeVdupScalarOperand(rewriter, op.getLoc(), op); if (failed(normalizedScalar)) - return failure(); + return rewriter.notifyMatchFailure(op, + "failed to normalize scalar vdup input"); callArgs.push_back(*normalizedScalar); } - callArgs.push_back(vdup.getMask()); - callArgs.push_back(getI32Constant(builder, loc, 1)); - } else if (isa(op)) { - callArgs.append(op->operand_begin(), op->operand_end()); - } else if (isa(op)) { - callArgs.push_back(op->getOperand(0)); - callArgs.push_back(op->getOperand(1)); - callArgs.push_back(op->getOperand(2)); - } else if (isa(op)) { - callArgs.push_back(op->getOperand(0)); - callArgs.push_back(op->getOperand(1)); - callArgs.push_back(op->getOperand(2)); - callArgs.push_back(op->getOperand(3)); - } else if (auto vmula = dyn_cast(op)) { - callArgs.push_back(vmula.getAcc()); - callArgs.push_back(vmula.getLhs()); - callArgs.push_back(vmula.getRhs()); - callArgs.push_back(vmula.getMask()); - } else if (auto vmull = dyn_cast(op)) { - callArgs.push_back(vmull.getLhs()); - callArgs.push_back(vmull.getRhs()); - callArgs.push_back(vmull.getMask()); - } else if (auto vaxpy = dyn_cast(op)) { - auto laneCount = getElementCountFromVectorLike(vaxpy.getResult().getType()); - if (!laneCount) { - diagOS << "VPTO LLVM emission failed: could not determine lane count for " - << op->getName().getStringRef() << "\n"; - return failure(); - } - Type elemType = getElementTypeFromVectorLike(vaxpy.getResult().getType()); - Value mask; - if (elemType.isF32()) { - auto fullMask = buildPltB32Mask(builder, module, loc, *laneCount, diagOS); - if (failed(fullMask)) - return failure(); - mask = *fullMask; - } else { - auto fullMask = buildPltB16Mask(builder, module, loc, *laneCount, diagOS); - if (failed(fullMask)) - return failure(); - mask = *fullMask; - } - // Installed wrapper surface is dst = alpha * src0 + dst. VPTO models this - // as a pure op returning the updated addend vector. - callArgs.push_back(vaxpy.getSrc1()); - callArgs.push_back(vaxpy.getSrc0()); - callArgs.push_back(vaxpy.getAlpha()); - callArgs.push_back(mask); - } else if (auto vci = dyn_cast(op)) { - auto orderAttr = op->getAttrOfType("order"); - auto order = parseOrderImmediate(orderAttr ? orderAttr.getValue() : StringRef("ASC")); - if (!order) { - diagOS << "VPTO LLVM emission failed: unsupported vci order "; - if (orderAttr) - diagOS << orderAttr.getValue(); - else - diagOS << ""; - diagOS << "\n"; - return failure(); - } - callArgs.push_back(vci.getIndex()); - callArgs.push_back(getI32Constant(builder, loc, *order)); - } else if (isa(op)) { - callArgs.append(op->operand_begin(), op->operand_end()); - auto laneCount = getElementCountFromVectorLike(op->getResult(0).getType()); - if (!laneCount) { - diagOS << "VPTO LLVM emission failed: could not determine lane count for " - << op->getName().getStringRef() << "\n"; - return failure(); - } - Value mask; - if (getElementTypeFromVectorLike(op->getResult(0).getType()).isF32() || - getElementTypeFromVectorLike(op->getResult(0).getType()).isInteger(32)) { - auto fullMask = buildPltB32Mask(builder, module, loc, *laneCount, diagOS); - if (failed(fullMask)) - return failure(); - mask = *fullMask; - } else { - auto fullMask = buildPltB16Mask(builder, module, loc, *laneCount, diagOS); - if (failed(fullMask)) - return failure(); - mask = *fullMask; - } - callArgs.push_back(mask); - } else if (auto vexpdiff = dyn_cast(op)) { - callArgs.push_back(vexpdiff.getInput()); - callArgs.push_back(vexpdiff.getMax()); - auto srcLaneCount = getElementCountFromVectorLike(vexpdiff.getInput().getType()); - if (!srcLaneCount) { - diagOS << "VPTO LLVM emission failed: could not determine lane count for " - << op->getName().getStringRef() << "\n"; - return failure(); - } - Value mask; - Type inputElemType = getElementTypeFromVectorLike(vexpdiff.getInput().getType()); - if (inputElemType.isF32() || inputElemType.isInteger(32)) { - auto fullMask = buildPltB32Mask(builder, module, loc, *srcLaneCount, diagOS); - if (failed(fullMask)) - return failure(); - mask = *fullMask; - } else { - auto fullMask = buildPltB16Mask(builder, module, loc, *srcLaneCount, diagOS); - if (failed(fullMask)) - return failure(); - mask = *fullMask; - } - auto part = parsePartImmediate(vexpdiff.getPart()); - if (!part) { - diagOS << "VPTO LLVM emission failed: unsupported vexpdiff part "; - diagOS << vexpdiff.getPart(); - diagOS << "\n"; - return failure(); - } + callArgs.push_back(mask); - callArgs.push_back(getI32Constant(builder, loc, *part)); - } else if (isa(op)) { - callArgs.append(op->operand_begin(), op->operand_end()); - } else if (isa(op)) { - callArgs.push_back(op->getOperand(0)); - callArgs.push_back(op->getOperand(1)); - } else if (auto vtrc = dyn_cast(op)) { - auto roundMode = parseRoundModeImmediate(vtrc.getRoundMode()); - if (!roundMode) { - diagOS << "VPTO LLVM emission failed: unsupported round mode " - << vtrc.getRoundMode() << "\n"; - return failure(); + callArgs.push_back(getI32Constant(rewriter, op.getLoc(), 1)); + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVbrOpPattern final : public OpConversionPattern { +public: + explicit LowerVbrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VbrOp op, pto::VbrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVbrCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vbr VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vbr result type"); + + Value scalar = adaptor.getValue(); + if (!scalar || scalar.getType() != op.getValue().getType()) + return rewriter.notifyMatchFailure(op, + "unexpected converted vbr operand type"); + + auto funcType = rewriter.getFunctionType(TypeRange{scalar.getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{scalar}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVselrOpPattern final : public OpConversionPattern { +public: + explicit LowerVselrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VselrOp op, pto::VselrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVselrCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vselr VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + auto resultVectorType = dyn_cast(resultType); + if (!resultVectorType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr result type"); + + Type intrinsicResultType = resultType; + if (auto floatType = dyn_cast(resultVectorType.getElementType()); + floatType && floatType.isF32()) { + intrinsicResultType = VectorType::get( + resultVectorType.getShape(), rewriter.getI32Type(), + resultVectorType.getScalableDims()); } - auto laneCount = getElementCountFromVectorLike(vtrc.getResult().getType()); - if (!laneCount) { - diagOS << "VPTO LLVM emission failed: could not determine lane count for " - << op->getName().getStringRef() << "\n"; - return failure(); + + Type indexType = this->getTypeConverter()->convertType(op.getSrc1().getType()); + if (!indexType) + return rewriter.notifyMatchFailure(op, + "failed to convert vselr index type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + if (!src0 || !src1 || src1.getType() != indexType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr operand types"); + + if (src0.getType() != intrinsicResultType) { + if (src0.getType() != resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr source type"); + src0 = rewriter.create(op.getLoc(), intrinsicResultType, src0); } - auto mask = buildPltB32Mask(builder, module, loc, *laneCount, diagOS); - if (failed(mask)) - return failure(); - callArgs.push_back(vtrc.getInput()); - callArgs.push_back(getI32Constant(builder, loc, *roundMode)); - callArgs.push_back(*mask); - } else if (auto vcvt = dyn_cast(op)) { - Type inputElemType = getElementTypeFromVectorLike(vcvt.getInput().getType()); - Type resultElemType = getElementTypeFromVectorLike(vcvt.getResult().getType()); - auto inputLanes = getElementCountFromVectorLike(vcvt.getInput().getType()); - if (!inputElemType || !resultElemType || !inputLanes) { - diagOS << "VPTO LLVM emission failed: could not determine vcvt type shape\n"; - return failure(); + + auto funcType = rewriter.getFunctionType( + TypeRange{intrinsicResultType, indexType}, TypeRange{intrinsicResultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{intrinsicResultType}, + ValueRange{src0, src1}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + + Value result = call.getResult(0); + if (intrinsicResultType != resultType) + result = rewriter.create(op.getLoc(), resultType, result); + rewriter.replaceOp(op, result); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerPnotOpPattern final : public OpConversionPattern { +public: + explicit LowerPnotOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::PnotOp op, pto::PnotOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert pnot result type"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted pnot operand types"); } - auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), - classifyVcvtElemType(resultElemType)); - if (!contract) { - diagOS << "VPTO LLVM emission failed: unsupported vcvt type pair " - << vcvt.getInput().getType() << " -> " << vcvt.getResult().getType() - << "\n"; - return failure(); + StringRef calleeName = getPredicateMaskCallee(op.getContext()); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerInterleaveOpPattern final + : public OpConversionPattern { +public: + explicit LowerInterleaveOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(InterleaveOp op, typename InterleaveOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = std::is_same_v ? "vintlv" : "vdintlv"; + FailureOr calleeName = + buildInterleaveCallee(op.getContext(), op.getLow().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported interleave VPTO signature"); + + Type lowType = this->getTypeConverter()->convertType(op.getLow().getType()); + Type highType = this->getTypeConverter()->convertType(op.getHigh().getType()); + if (!lowType || !highType || lowType != highType) { + return rewriter.notifyMatchFailure( + op, "failed to convert interleave result types"); } - callArgs.push_back(vcvt.getInput()); - FailureOr mask = failure(); - switch (contract->maskBitWidth) { - case 8: - mask = buildPltB8Mask(builder, module, loc, *inputLanes, diagOS); - break; - case 16: - mask = buildPltB16Mask(builder, module, loc, *inputLanes, diagOS); - break; - case 32: - mask = buildPltB32Mask(builder, module, loc, *inputLanes, diagOS); - break; - default: - diagOS << "VPTO LLVM emission failed: unsupported vcvt mask width " - << contract->maskBitWidth << "\n"; - return failure(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (!lhs || !rhs || lhs.getType() != lowType || rhs.getType() != lowType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted interleave operand types"); } - if (failed(mask)) - return failure(); - callArgs.push_back(*mask); - if (contract->requiresRnd) { - auto roundMode = vcvt.getRndAttr() - ? parseRoundModeImmediate(*vcvt.getRnd()) - : std::nullopt; - if (!roundMode) { - diagOS << "VPTO LLVM emission failed: vcvt requires valid rnd attr\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *roundMode)); - } - if (contract->requiresSat) { - auto sat = - vcvt.getSatAttr() ? parseSaturationImmediate(*vcvt.getSat()) : std::nullopt; - if (!sat) { - diagOS << "VPTO LLVM emission failed: vcvt requires valid sat attr\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *sat)); + auto funcType = rewriter.getFunctionType(TypeRange{lowType, lowType}, + TypeRange{lowType, highType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{lowType, highType}, ValueRange{lhs, rhs}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicatePackOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicatePackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PackOp op, typename PackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-pack result type"); + + auto part = parseHiLoPartImmediate(op.getPart()); + if (!part) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-pack part immediate"); + + Value input = adaptor.getInput(); + if (!input || input.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-pack operand type"); + + Value partValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*part)); + StringRef calleeName = getPredicatePackCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, rewriter.getI32Type()}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), calleeName, TypeRange{resultType}, ValueRange{input, partValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerUnpackOpPattern final : public OpConversionPattern { +public: + explicit LowerUnpackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(UnpackOp op, typename UnpackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = std::is_same_v ? "vsunpack" + : "vzunpack"; + FailureOr calleeName = buildUnpackCallee( + op.getContext(), op.getSrc().getType(), op.getResult().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported unpack VPTO signature"); + + Type srcType = this->getTypeConverter()->convertType(op.getSrc().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!srcType || !resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert unpack types"); + + Value src = adaptor.getSrc(); + if (!src || src.getType() != srcType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted unpack source type"); } - if (contract->requiresPart) { - auto part = - vcvt.getPartAttr() ? parsePartImmediate(*vcvt.getPart()) : std::nullopt; - if (!part) { - diagOS << "VPTO LLVM emission failed: vcvt requires valid part attr\n"; - return failure(); - } - callArgs.push_back(getI32Constant(builder, loc, *part)); + + Value part = castIntegerLikeTo(op, adaptor.getPart(), rewriter.getI32Type()); + if (!part) + return rewriter.notifyMatchFailure(op, "failed to materialize unpack part"); + + auto funcType = rewriter.getFunctionType(TypeRange{srcType, part.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, part}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVpackOpPattern final : public OpConversionPattern { +public: + explicit LowerVpackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VpackOp op, pto::VpackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVpackCallee(op.getContext(), op.getSrc().getType(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vpack VPTO signature"); + + Type srcType = this->getTypeConverter()->convertType(op.getSrc().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!srcType || !resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vpack types"); + + auto partImm = parseHiLoPartImmediate(op.getPart()); + if (!partImm) + return rewriter.notifyMatchFailure(op, "unsupported vpack part immediate"); + + Value src = adaptor.getSrc(); + if (!src || src.getType() != srcType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted vpack source type"); } - } else if (auto vstar = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, vstar.getDestination(), diagOS); - auto alignValue = materializeAlignABIValue(op, vstar.getValue(), diagOS); - if (failed(basePtr) || failed(alignValue)) - return failure(); - callArgs.push_back(*alignValue); - callArgs.push_back(*basePtr); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto vstas = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, vstas.getDestination(), diagOS); - auto alignValue = materializeAlignABIValue(op, vstas.getValue(), diagOS); - Type elementType = getElementTypeFromABIValue(vstas.getDestination()); - if (failed(basePtr) || failed(alignValue) || !elementType) { - diagOS << "VPTO LLVM emission failed: could not materialize vstas ABI " - "inputs; destination type=" - << vstas.getDestination().getType() << ", element type=" - << (elementType ? elementType : Type()) << "\n"; - return failure(); + + Value part = getI32Constant(rewriter, op.getLoc(), *partImm); + auto funcType = rewriter.getFunctionType(TypeRange{srcType, part.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, part}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateMaskBinaryOpPattern final + : public OpConversionPattern { +public: + explicit LowerPredicateMaskBinaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(PredicateMaskOp op, typename PredicateMaskOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-mask result type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + Value mask = adaptor.getMask(); + if (!src0 || !src1 || !mask || src0.getType() != resultType || + src1.getType() != resultType || mask.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-mask operand types"); } - auto offsetBytes = convertElementOffsetToBytes(op, vstas.getOffset(), elementType); - if (failed(offsetBytes)) { - diagOS << "VPTO LLVM emission failed: could not materialize vstas byte " - "offset from " - << vstas.getOffset().getType() << " using element type " - << elementType << "\n"; - return failure(); + + StringRef calleeName = getPredicateMaskCallee(op.getContext()); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{src0, src1, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicatePairReorderOpPattern final + : public OpConversionPattern { +public: + explicit LowerPredicatePairReorderOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ReorderOp op, typename ReorderOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-pair-reorder result types"); + if (resultTypes.size() != 2 || resultTypes[0] != resultTypes[1]) + return rewriter.notifyMatchFailure( + op, "unexpected predicate-pair-reorder converted result types"); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (!lhs || !rhs || lhs.getType() != resultTypes[0] || + rhs.getType() != resultTypes[0]) { + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-pair-reorder operand types"); } - callArgs.push_back(*alignValue); - callArgs.push_back(*basePtr); - callArgs.push_back(*offsetBytes); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto vsqz = dyn_cast(op)) { - callArgs.push_back(vsqz.getInput()); - callArgs.push_back(vsqz.getMask()); - callArgs.push_back( - getI32Constant(builder, loc, determineVsqzStoreHint(vsqz))); - } else if (auto vusqz = dyn_cast(op)) { - callArgs.push_back(vusqz.getSrc()); - callArgs.push_back(vusqz.getMask()); - } else if (auto unpack = dyn_cast(op)) { - Value part = castIntegerLikeTo(op, unpack.getPart(), builder.getI32Type()); - if (!part) { - diagOS << "VPTO LLVM emission failed: could not materialize vsunpack part\n"; - return failure(); + + StringRef calleeName = + buildPredicatePairReorderCallee(op.getContext()); + auto call = rewriter.create(op.getLoc(), calleeName, resultTypes, + ValueRange{lhs, rhs}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCmpOpPattern final : public OpConversionPattern { +public: + explicit LowerCmpOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CmpOp op, typename CmpOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isScalarCompare = std::is_same_v; + Type inputType = Type(); + if constexpr (isScalarCompare) + inputType = op.getSrc().getType(); + else + inputType = op.getSrc0().getType(); + FailureOr calleeName = + buildVcmpCallee(op.getContext(), inputType, op.getCmpMode(), + isScalarCompare); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported compare VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = + this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, + "failed to convert compare result type"); + if (resultType != maskType) + return rewriter.notifyMatchFailure(op, + "unexpected compare mask conversion"); + + SmallVector callArgs; + callArgs.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); + if constexpr (isScalarCompare) { + if (callArgs.size() != 3 || !callArgs[0] || !callArgs[1] || !callArgs[2] || + callArgs[2].getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted scalar-compare operand types"); + } + } else { + if (callArgs.size() != 3 || !callArgs[0] || !callArgs[1] || !callArgs[2] || + callArgs[0].getType() != callArgs[1].getType() || + callArgs[2].getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted compare operand types"); + } } - callArgs.push_back(unpack.getSrc()); - callArgs.push_back(part); - } else if (auto unpack = dyn_cast(op)) { - Value part = castIntegerLikeTo(op, unpack.getPart(), builder.getI32Type()); - if (!part) { - diagOS << "VPTO LLVM emission failed: could not materialize vzunpack part\n"; - return failure(); + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPltOpPattern final : public OpConversionPattern { +public: + explicit LowerPltOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PltOp op, typename PltOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value laneCount = castIntegerLikeTo(op, adaptor.getScalar(), rewriter.getI32Type()); + if (!laneCount) + return rewriter.notifyMatchFailure(op, "failed to materialize plt lane count"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert plt result types"); + + StringRef calleeName = buildPltCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), calleeName, + resultTypes, ValueRange{laneCount}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPsetOpPattern final : public OpConversionPattern { +public: + explicit LowerPsetOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PsetOp op, typename PsetOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pattern = parsePredicatePatternImmediate(op.getPattern()); + if (!pattern) + return rewriter.notifyMatchFailure(op, "unsupported pset pattern"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pset result types"); + + StringRef calleeName = buildPsetCallee(op.getContext()); + Value patternValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*pattern)); + auto funcType = rewriter.getFunctionType(TypeRange{rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), calleeName, + resultTypes, ValueRange{patternValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPgeOpPattern final : public OpConversionPattern { +public: + explicit LowerPgeOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PgeOp op, typename PgeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pattern = parsePredicatePatternImmediate(op.getPattern()); + if (!pattern) + return rewriter.notifyMatchFailure(op, "unsupported pge pattern"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pge result types"); + + StringRef calleeName = buildPgeCallee(op.getContext()); + Value patternValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*pattern)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI32Type(), rewriter.getI32Type()}, resultTypes); + auto call = + rewriter.create(op.getLoc(), calleeName, resultTypes, + ValueRange{patternValue, zero}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldsOpPattern final : public OpConversionPattern { +public: + explicit LowerVldsOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VldsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vlds element type"); + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = + parseLoadDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vlds operands"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert vlds result types"); + + FailureOr calleeName = buildVldsCallee(op.getContext(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vlds signature"); + + Value distValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, zero}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, + resultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldsPostOpPattern final + : public OpConversionPattern { +public: + explicit LowerVldsPostOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VldsPostOp op, pto::VldsPostOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vlds_post element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = + parseLoadDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vlds_post operands"); } - callArgs.push_back(unpack.getSrc()); - callArgs.push_back(part); - } else if (auto pack = dyn_cast(op)) { - auto part = parseHiLoPartImmediate(pack.getPart()); - if (!part) { - diagOS << "VPTO LLVM emission failed: unsupported vpack part " - << pack.getPart() << "\n"; - return failure(); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type updatedSourceType = + this->getTypeConverter()->convertType(op.getUpdatedSource().getType()); + if (!resultType || !updatedSourceType || updatedSourceType != adaptor.getSource().getType()) { + return rewriter.notifyMatchFailure(op, + "failed to convert vlds_post result types"); } - callArgs.push_back(pack.getSrc()); - callArgs.push_back(getI32Constant(builder, loc, *part)); - } else if (auto interleave = dyn_cast(op)) { - callArgs.push_back(interleave.getLhs()); - callArgs.push_back(interleave.getRhs()); - } else if (auto deinterleave = dyn_cast(op)) { - callArgs.push_back(deinterleave.getLhs()); - callArgs.push_back(deinterleave.getRhs()); - } else if (auto vldsx2 = dyn_cast(op)) { - Type elementType = getElementTypeFromVectorLike(vldsx2.getLow().getType()); - auto offsetBytes = convertElementOffsetToBytes(op, vldsx2.getOffset(), elementType); - auto basePtr = requirePointerABIAddress(op, vldsx2.getSource(), diagOS); - auto dist = parseLoadX2DistImmediate(vldsx2.getDist(), elementType); - if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) { - if (elementType && succeeded(basePtr) && !dist) - diagOS << "VPTO LLVM emission failed: unsupported vldsx2 dist immediate\n"; - return failure(); + + FailureOr calleeName = + buildVldsPostCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vlds_post signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value postValue = getI32Constant(rewriter, op.getLoc(), 1); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, postValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), (*offsetBytes).getType(), + distValue.getType(), postValue.getType()}, + TypeRange{resultType, updatedSourceType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType, updatedSourceType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldsx2OpPattern final : public OpConversionPattern { +public: + explicit LowerVldsx2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vldsx2Op op, pto::Vldsx2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getLow().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vldsx2 element type"); + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = parseLoadX2DistImmediate(op.getDist(), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vldsx2 operands"); } - callArgs.push_back(*basePtr); - callArgs.push_back(*offsetBytes); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto vsldb = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, vsldb.getSource(), diagOS); - Value packedStride = packBlockRepeatStride( - op, vsldb.getBlockStride(), vsldb.getRepeatStride()); - if (failed(basePtr) || !packedStride) { - if (succeeded(basePtr) && !packedStride) - diagOS << "VPTO LLVM emission failed: could not pack vsldb control word\n"; - return failure(); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes)) || + resultTypes.size() != 2) { + return rewriter.notifyMatchFailure(op, + "failed to convert vldsx2 result types"); } - callArgs.push_back(*basePtr); - callArgs.push_back(packedStride); - callArgs.push_back(getI32Constant(builder, loc, 0)); - callArgs.push_back(vsldb.getMask()); - } else if (auto vsstb = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, vsstb.getDestination(), diagOS); - Value packedStride = packBlockRepeatStride( - op, vsstb.getBlockStride(), vsstb.getRepeatStride()); - if (failed(basePtr) || !packedStride) { - if (succeeded(basePtr) && !packedStride) - diagOS << "VPTO LLVM emission failed: could not pack vsstb control word\n"; - return failure(); + + FailureOr calleeName = + buildVldsx2Callee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vldsx2 signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), (*offsetBytes).getType(), + distValue.getType(), zeroValue.getType()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, + resultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsldbOpPattern final : public OpConversionPattern { +public: + explicit LowerVsldbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsldbOp op, pto::VsldbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); + Value packedStride = + packBlockRepeatStride(op, adaptor.getBlockStride(), adaptor.getRepeatStride()); + if (!basePtr || !packedStride) + return rewriter.notifyMatchFailure(op, "failed to materialize vsldb operands"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vsldb result type"); + + StringRef calleeName = buildVsldbCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), packedStride, zeroValue, + adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), packedStride.getType(), + zeroValue.getType(), adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerInitAlignOpPattern final + : public OpConversionPattern { +public: + explicit LowerInitAlignOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::InitAlignOp op, pto::InitAlignOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert init_align result type"); + + StringRef calleeName = buildInitAlignCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldasOpPattern final : public OpConversionPattern { +public: + explicit LowerVldasOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VldasOp op, pto::VldasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!sourceType || !resultType) + return rewriter.notifyMatchFailure(op, + "expected converted vldas operand/result types"); + + StringRef calleeName = buildVldasCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{adaptor.getSource().getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{adaptor.getSource()}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldusOpPattern final : public OpConversionPattern { +public: + explicit LowerVldusOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VldusOp op, pto::VldusOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(adaptor.getSource().getType()); + SmallVector resultTypes; + if (!sourceType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes)) || + resultTypes.size() != 2 || adaptor.getAlign().getType() != resultTypes[1]) { + return rewriter.notifyMatchFailure(op, + "expected converted vldus operand/result types"); } - callArgs.push_back(vsstb.getValue()); - callArgs.push_back(*basePtr); - callArgs.push_back(packedStride); - callArgs.push_back(getI32Constant(builder, loc, 0)); - callArgs.push_back(vsstb.getMask()); - } else if (auto vstx2 = dyn_cast(op)) { - Type elementType = getElementTypeFromVectorLike(vstx2.getLow().getType()); + + FailureOr calleeName = + buildVldusCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vldus signature"); + + SmallVector intrinsicResultTypes(resultTypes.begin(), resultTypes.end()); + // The installed no-post A5 vldus intrinsic returns an extra hidden base ptr. + intrinsicResultTypes.push_back(adaptor.getSource().getType()); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getAlign().getType()}, + intrinsicResultTypes); + auto call = rewriter.create( + op.getLoc(), *calleeName, intrinsicResultTypes, + ValueRange{adaptor.getSource(), adaptor.getAlign()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults().take_front(resultTypes.size())); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerSprclrOpPattern final : public OpConversionPattern { +public: + explicit LowerSprclrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::SprclrOp op, pto::SprclrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto spr = parseSprImmediate(op.getSpr()); + if (!spr) + return rewriter.notifyMatchFailure(op, "unsupported sprclr target"); + + StringRef calleeName = buildSprclrCallee(op.getContext()); + Value sprValue = rewriter.create( + op.getLoc(), rewriter.getI16IntegerAttr(*spr)); + auto funcType = rewriter.getFunctionType(TypeRange{sprValue.getType()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, ValueRange{sprValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsOpPattern final : public OpConversionPattern { +public: + explicit LowerVstsOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vsts element type"); auto offsetBytes = - convertElementOffsetToBytes(op, vstx2.getOffset(), elementType); + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getDestination().getType()); + auto dist = + parseStoreDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vsts operands"); + + FailureOr calleeName = + buildVstsCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsts signature"); + + Value distValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), + *offsetBytes, distValue, zero, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + rewriter.getI32Type(), rewriter.getI32Type(), + rewriter.getI32Type(), adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsstbOpPattern final : public OpConversionPattern { +public: + explicit LowerVsstbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsstbOp op, pto::VsstbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto basePtr = - requirePointerABIAddress(op, vstx2.getDestination(), diagOS); - auto dist = parseStoreX2DistImmediate(vstx2.getDist(), elementType); - if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) { - if (elementType && succeeded(basePtr) && !dist) - diagOS - << "VPTO LLVM emission failed: unsupported vstsx2 dist immediate\n"; - return failure(); - } - Value offsetI32 = castIntegerLikeTo(op, *offsetBytes, builder.getI32Type()); - if (!offsetI32) - return failure(); - callArgs.push_back(vstx2.getLow()); - callArgs.push_back(vstx2.getHigh()); - callArgs.push_back(*basePtr); - callArgs.push_back(offsetI32); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - callArgs.push_back(vstx2.getMask()); - } else if (auto vsts = dyn_cast(op)) { - Type elementType = getElementTypeFromVectorLike(vsts.getValue().getType()); - auto offsetBytes = convertElementOffsetToBytes( - op, op->getOperand(2), elementType); - auto basePtr = requirePointerABIAddress(op, op->getOperand(1), diagOS); + dyn_cast(adaptor.getDestination().getType()); + Value packedStride = + packBlockRepeatStride(op, adaptor.getBlockStride(), adaptor.getRepeatStride()); + if (!basePtr || !packedStride) + return rewriter.notifyMatchFailure(op, "failed to materialize vsstb operands"); + + StringRef calleeName = buildVsstbCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), + packedStride, zeroValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + packedStride.getType(), zeroValue.getType(), + adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsPostOpPattern final + : public OpConversionPattern { +public: + explicit LowerVstsPostOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VstsPostOp op, pto::VstsPostOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vsts_post element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); auto dist = - parseStoreDistImmediate(vsts.getDist().value_or("NORM"), elementType); - if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) { - if (elementType && succeeded(basePtr) && !dist) - diagOS << "VPTO LLVM emission failed: unsupported vsts dist immediate\n"; - return failure(); - } - callArgs.push_back(op->getOperand(0)); - callArgs.push_back(*basePtr); - callArgs.push_back(*offsetBytes); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - callArgs.push_back(op->getOperand(3)); - } else if (auto vstsPost = dyn_cast(op)) { - Type elementType = getElementTypeFromVectorLike(vstsPost.getValue().getType()); - auto offsetBytes = convertElementOffsetToBytes(op, vstsPost.getOffset(), elementType); - auto basePtr = requirePointerABIAddress(op, vstsPost.getDestination(), diagOS); - auto dist = parseStoreDistImmediate(vstsPost.getDist().value_or("NORM"), - elementType); - if (!elementType || failed(offsetBytes) || failed(basePtr) || !dist) - return failure(); - callArgs.push_back(vstsPost.getValue()); - callArgs.push_back(*basePtr); - callArgs.push_back(*offsetBytes); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 1)); - callArgs.push_back(vstsPost.getMask()); - } else if (auto ppack = dyn_cast(op)) { - auto part = parseHiLoPartImmediate(ppack.getPart()); - if (!part) { - diagOS << "VPTO LLVM emission failed: unsupported ppack part " - << ppack.getPart() << "\n"; - return failure(); - } - callArgs.push_back(ppack.getInput()); - callArgs.push_back(getI32Constant(builder, loc, *part)); - } else if (auto punpack = dyn_cast(op)) { - auto part = parseHiLoPartImmediate(punpack.getPart()); - if (!part) { - diagOS << "VPTO LLVM emission failed: unsupported punpack part " - << punpack.getPart() << "\n"; - return failure(); - } - callArgs.push_back(punpack.getInput()); - callArgs.push_back(getI32Constant(builder, loc, *part)); - } else if (auto vselr = dyn_cast(op)) { - auto resultVecType = dyn_cast(loweredResultTypes.front()); - if (!resultVecType) { - diagOS << "VPTO LLVM emission failed: unexpected vselr result type\n"; - return failure(); - } - Type intrinsicVecType = resultVecType; - if (auto resultFloat = dyn_cast(resultVecType.getElementType()); - resultFloat && resultFloat.isF32()) { - intrinsicVecType = - VectorType::get(resultVecType.getShape(), builder.getI32Type(), - resultVecType.getScalableDims()); - } - intrinsicResultTypes[0] = intrinsicVecType; - callArgs.push_back(buildBridgeCast(builder, loc, vselr.getSrc0(), intrinsicVecType)); - callArgs.push_back(vselr.getSrc1()); - } else if (isa(op)) { - callArgs.append(op->operand_begin(), op->operand_end()); - } else if (auto plds = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, plds.getSource(), diagOS); - Value offset = castIntegerLikeTo(op, plds.getOffset(), builder.getI32Type()); - auto dist = parsePredicateLoadDistImmediate(plds.getDist()); - if (failed(basePtr) || !offset || !dist) { - if (succeeded(basePtr) && offset && !dist) - diagOS << "VPTO LLVM emission failed: unsupported plds dist immediate\n"; - return failure(); - } - callArgs.push_back(*basePtr); - callArgs.push_back(offset); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto pldi = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, pldi.getSource(), diagOS); - Value offset = castIntegerLikeTo(op, pldi.getOffset(), builder.getI32Type()); - auto dist = parsePredicateLoadDistImmediate(pldi.getDist()); - if (failed(basePtr) || !offset || !dist) { - if (succeeded(basePtr) && offset && !dist) - diagOS << "VPTO LLVM emission failed: unsupported pldi dist immediate\n"; - return failure(); + parseStoreDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vsts_post operands"); } - callArgs.push_back(*basePtr); - callArgs.push_back(offset); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto psts = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, psts.getDestination(), diagOS); - Value offset = castIntegerLikeTo(op, psts.getOffset(), builder.getI32Type()); - auto dist = parsePredicateStoreDistImmediate(psts.getDist()); - if (failed(basePtr) || !offset || !dist) { - if (succeeded(basePtr) && offset && !dist) - diagOS << "VPTO LLVM emission failed: unsupported psts dist immediate\n"; - return failure(); + + Type updatedDestinationType = + this->getTypeConverter()->convertType(op.getUpdatedDestination().getType()); + if (!updatedDestinationType || updatedDestinationType != adaptor.getDestination().getType()) { + return rewriter.notifyMatchFailure(op, + "failed to convert vsts_post result type"); } - callArgs.push_back(psts.getValue()); - callArgs.push_back(*basePtr); - callArgs.push_back(offset); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (op->getName().getStringRef() == "pto.pstu") { - auto basePtr = requirePointerABIAddress(op, op->getOperand(2), diagOS); - auto alignValue = materializeAlignABIValue(op, op->getOperand(0), diagOS); - if (failed(basePtr) || failed(alignValue)) - return failure(); - callArgs.push_back(op->getOperand(1)); - callArgs.push_back(*basePtr); - callArgs.push_back(*alignValue); - } else if (auto pstu = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, pstu.getBase(), diagOS); - auto alignValue = materializeAlignABIValue(op, pstu.getAlignIn(), diagOS); - if (failed(basePtr) || failed(alignValue)) - return failure(); - callArgs.push_back(pstu.getValue()); - callArgs.push_back(*basePtr); - callArgs.push_back(*alignValue); - } else if (auto psti = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, psti.getDestination(), diagOS); - Value offset = castIntegerLikeTo(op, psti.getOffset(), builder.getI32Type()); - auto dist = parsePredicateStoreDistImmediate(psti.getDist()); - if (failed(basePtr) || !offset || !dist) { - if (succeeded(basePtr) && offset && !dist) - diagOS << "VPTO LLVM emission failed: unsupported psti dist immediate\n"; - return failure(); + + FailureOr calleeName = + buildVstsPostCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsts_post signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value postValue = getI32Constant(rewriter, op.getLoc(), 1); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), *offsetBytes, + distValue, postValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + (*offsetBytes).getType(), distValue.getType(), postValue.getType(), + adaptor.getMask().getType()}, + TypeRange{updatedDestinationType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{updatedDestinationType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsx2OpPattern final : public OpConversionPattern { +public: + explicit LowerVstsx2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vstsx2Op op, pto::Vstsx2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getLow().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vstsx2 element type"); + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + auto dist = parseStoreX2DistImmediate(op.getDist(), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vstsx2 operands"); } - callArgs.push_back(psti.getValue()); - callArgs.push_back(*basePtr); - callArgs.push_back(offset); - callArgs.push_back(getI32Constant(builder, loc, *dist)); - callArgs.push_back(getI32Constant(builder, loc, 0)); - } else if (auto gather = dyn_cast(op)) { - Type resultElemType = getElementTypeFromVectorLike(gather.getResult().getType()); - auto basePtr = requirePointerABIAddress(op, gather.getSource(), diagOS); - auto mask = buildDynamicPltMask(builder, module, loc, gather.getActiveLanes(), - resultElemType, diagOS); - if (!resultElemType || failed(basePtr) || failed(mask)) - return failure(); - callArgs.push_back(*basePtr); - callArgs.push_back(gather.getOffsets()); + + FailureOr calleeName = + buildVstsx2Callee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vstsx2 signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getLow(), adaptor.getHigh(), + adaptor.getDestination(), *offsetBytes, distValue, + zeroValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getLow().getType(), adaptor.getHigh().getType(), + adaptor.getDestination().getType(), (*offsetBytes).getType(), + distValue.getType(), zeroValue.getType(), + adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerPstuOpPattern final : public OpConversionPattern { +public: + explicit LowerPstuOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::PstuOp op, pto::PstuOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = buildPstuCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported pstu signature"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pstu result types"); + if (resultTypes.size() != 2) + return rewriter.notifyMatchFailure(op, "unexpected converted pstu result arity"); + + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!baseType || adaptor.getAlignIn().getType() != resultTypes[0] || + adaptor.getBase().getType() != resultTypes[1]) { + return rewriter.notifyMatchFailure(op, + "unexpected converted pstu operand/result types"); + } + + SmallVector args{adaptor.getValue(), adaptor.getBase(), adaptor.getAlignIn()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + adaptor.getAlignIn().getType()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, resultTypes, + args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstusOpPattern final : public OpConversionPattern { +public: + explicit LowerVstusOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstusOp op, pto::VstusOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vstus element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + if (failed(offsetBytes)) + return rewriter.notifyMatchFailure(op, "failed to convert vstus offset"); + + Type resultType = this->getTypeConverter()->convertType(op.getAlignOut().getType()); + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!resultType || !baseType || adaptor.getAlignIn().getType() != resultType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstus operand/result types"); + } + + StringRef calleeName = buildVstusCallee(op.getContext()); + SmallVector args{adaptor.getValue(), adaptor.getBase(), *offsetBytes, + adaptor.getAlignIn()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + (*offsetBytes).getType(), adaptor.getAlignIn().getType()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsturOpPattern final : public OpConversionPattern { +public: + explicit LowerVsturOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsturOp op, pto::VsturOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto postMode = parsePostModeImmediate(op.getMode()); + if (!postMode) + return rewriter.notifyMatchFailure(op, "unsupported vstur mode immediate"); + + Type resultType = this->getTypeConverter()->convertType(op.getAlignOut().getType()); + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!resultType || !baseType || adaptor.getAlignIn().getType() != resultType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstur operand/result types"); + } + + StringRef calleeName = buildVsturCallee(op.getContext()); + Value modeValue = getI32Constant(rewriter, op.getLoc(), *postMode); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getBase(), adaptor.getAlignIn(), + modeValue, zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + adaptor.getAlignIn().getType(), modeValue.getType(), + zeroValue.getType()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstarOpPattern final : public OpConversionPattern { +public: + explicit LowerVstarOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstarOp op, pto::VstarOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto baseType = dyn_cast(adaptor.getDestination().getType()); + Type alignType = this->getTypeConverter()->convertType(op.getValue().getType()); + if (!baseType || !alignType || adaptor.getValue().getType() != alignType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstar operand types"); + } + + StringRef calleeName = buildVstarCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + zeroValue.getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstasOpPattern final : public OpConversionPattern { +public: + explicit LowerVstasOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstasOp op, pto::VstasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto baseType = dyn_cast(adaptor.getDestination().getType()); + Type alignType = this->getTypeConverter()->convertType(op.getValue().getType()); + auto dstType = dyn_cast(op.getDestination().getType()); + if (!baseType || !alignType || adaptor.getValue().getType() != alignType || !dstType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstas operand types"); + } + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), dstType.getElementType()); + if (failed(offsetBytes)) + return rewriter.notifyMatchFailure(op, "failed to convert vstas offset"); + + StringRef calleeName = buildVstasCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), *offsetBytes, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + (*offsetBytes).getType(), zeroValue.getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgather2OpPattern final + : public OpConversionPattern { +public: + explicit LowerVgather2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vgather2Op op, pto::Vgather2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + if (!elemType || !basePtr) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgather2 operand types"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), adaptor.getActiveLanes(), elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "failed to materialize vgather2 mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vgather2 result type"); + + FailureOr calleeName = + buildVgather2Callee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2 signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + (*mask).getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), *mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgather2BcOpPattern final + : public OpConversionPattern { +public: + explicit LowerVgather2BcOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vgather2BcOp op, pto::Vgather2BcOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!basePtr || !resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgather2_bc operand/result types"); + + FailureOr calleeName = + buildVgather2BcCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2_bc signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgatherbOpPattern final + : public OpConversionPattern { +public: + explicit LowerVgatherbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VgatherbOp op, pto::VgatherbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!basePtr || !resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgatherb operand/result types"); + + FailureOr calleeName = + buildVgatherbCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgatherb signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVscatterOpPattern final + : public OpConversionPattern { +public: + explicit LowerVscatterOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VscatterOp op, pto::VscatterOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elemType = getElementTypeFromVectorLike(op.getValue().getType()); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + if (!elemType || !basePtr) + return rewriter.notifyMatchFailure(op, + "unexpected converted vscatter operand types"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), adaptor.getActiveLanes(), elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "failed to materialize vscatter mask"); + + FailureOr calleeName = + buildVscatterCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vscatter signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + adaptor.getOffsets().getType(), (*mask).getType()}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{adaptor.getValue(), adaptor.getDestination(), + adaptor.getOffsets(), *mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVpreluOpPattern final : public OpConversionPattern { +public: + explicit LowerVpreluOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VpreluOp op, pto::VpreluOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto laneCount = getElementCountFromVectorLike(op.getResult().getType()); + Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!laneCount || !elemType) + return rewriter.notifyMatchFailure(op, "unsupported vprelu signature"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), + elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to materialize vprelu mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vprelu result type"); + + FailureOr calleeName = + buildVpreluCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vprelu callee"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getLhs().getType(), adaptor.getRhs().getType(), + (*mask).getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getLhs(), adaptor.getRhs(), *mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVaxpyOpPattern final : public OpConversionPattern { +public: + explicit LowerVaxpyOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VaxpyOp op, pto::VaxpyOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto laneCount = getElementCountFromVectorLike(op.getResult().getType()); + Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!laneCount || !elemType) + return rewriter.notifyMatchFailure(op, "unsupported vaxpy signature"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), + elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to materialize vaxpy mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vaxpy result type"); + + FailureOr calleeName = + buildVaxpyCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vaxpy callee"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSrc1().getType(), adaptor.getSrc0().getType(), + adaptor.getAlpha().getType(), (*mask).getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSrc1(), adaptor.getSrc0(), adaptor.getAlpha(), + *mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVciOpPattern final : public OpConversionPattern { +public: + explicit LowerVciOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VciOp op, pto::VciOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto order = parseOrderImmediate(op.getOrder().value_or("ASC")); + if (!order) + return rewriter.notifyMatchFailure(op, "unsupported vci order"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vci result type"); + + FailureOr calleeName = + buildVciCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vci callee"); + + Value orderValue = getI32Constant(rewriter, op.getLoc(), *order); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getIndex().getType(), orderValue.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getIndex(), orderValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVexpdiffOpPattern final + : public OpConversionPattern { +public: + explicit LowerVexpdiffOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VexpdiffOp op, pto::VexpdiffOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto laneCount = getElementCountFromVectorLike(op.getInput().getType()); + Type elemType = getElementTypeFromVectorLike(op.getInput().getType()); + auto part = parsePartImmediate(op.getPart()); + if (!laneCount || !elemType || !part) + return rewriter.notifyMatchFailure(op, "unsupported vexpdiff signature"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), + elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to materialize vexpdiff mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vexpdiff result type"); + + FailureOr calleeName = + buildVexpdiffCallee(op.getContext(), op.getInput().getType(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vexpdiff callee"); + + Value partValue = getI32Constant(rewriter, op.getLoc(), *part); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getInput().getType(), adaptor.getMax().getType(), + (*mask).getType(), partValue.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getInput(), adaptor.getMax(), *mask, partValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVbitsortOpPattern final + : public OpConversionPattern { +public: + explicit LowerVbitsortOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VbitsortOp op, pto::VbitsortOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = + dyn_cast(adaptor.getDestination().getType()); + auto srcType = dyn_cast(adaptor.getSource().getType()); + auto idxType = + dyn_cast(adaptor.getIndices().getType()); + if (!dstType || !srcType || !idxType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vbitsort operand types"); + + FailureOr config = packVbitsortConfig(op, adaptor.getRepeatTimes()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack vbitsort config"); + + FailureOr calleeName = buildVbitsortCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vbitsort signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getDestination().getType(), adaptor.getSource().getType(), + adaptor.getIndices().getType(), (*config).getType()}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{adaptor.getDestination(), adaptor.getSource(), + adaptor.getIndices(), *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVcvtOpPattern final : public OpConversionPattern { +public: + explicit LowerVcvtOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VcvtOp op, pto::VcvtOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto inputLanes = getElementCountFromVectorLike(op.getInput().getType()); + if (!inputLanes) + return rewriter.notifyMatchFailure(op, "unsupported vcvt input shape"); + + FailureOr contract = buildVcvtContract(op); + if (failed(contract)) + return rewriter.notifyMatchFailure(op, "unsupported vcvt type pair"); + + Type maskElemType = rewriter.getIntegerType((*contract).maskBitWidth); + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), + getI32Constant(rewriter, op.getLoc(), *inputLanes), maskElemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to materialize vcvt mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); + + SmallVector callArgs; + SmallVector argTypes; + callArgs.push_back(adaptor.getInput()); + argTypes.push_back(adaptor.getInput().getType()); callArgs.push_back(*mask); - } else if (auto gather = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, gather.getSource(), diagOS); - if (failed(basePtr)) + argTypes.push_back((*mask).getType()); + + if ((*contract).requiresRnd) { + auto roundMode = + op.getRndAttr() ? parseRoundModeImmediate(*op.getRnd()) : std::nullopt; + if (!roundMode) + return rewriter.notifyMatchFailure(op, "vcvt requires valid rnd attr"); + Value roundValue = getI32Constant(rewriter, op.getLoc(), *roundMode); + callArgs.push_back(roundValue); + argTypes.push_back(roundValue.getType()); + } + + if ((*contract).requiresSat) { + auto saturation = + op.getSatAttr() ? parseSaturationImmediate(*op.getSat()) : std::nullopt; + if (!saturation) + return rewriter.notifyMatchFailure(op, "vcvt requires valid sat attr"); + Value satValue = getI32Constant(rewriter, op.getLoc(), *saturation); + callArgs.push_back(satValue); + argTypes.push_back(satValue.getType()); + } + + if ((*contract).requiresPart) { + auto part = op.getPartAttr() ? parsePartImmediate(*op.getPart()) : std::nullopt; + if (!part) + return rewriter.notifyMatchFailure(op, "vcvt requires valid part attr"); + Value partValue = getI32Constant(rewriter, op.getLoc(), *part); + callArgs.push_back(partValue); + argTypes.push_back(partValue.getType()); + } + + auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), StringRef((*contract).intrinsic), TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{std::string((*contract).intrinsic), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVtrcOpPattern final : public OpConversionPattern { +public: + explicit LowerVtrcOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VtrcOp op, pto::VtrcOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto roundMode = parseRoundModeImmediate(op.getRoundMode()); + if (!roundMode) + return rewriter.notifyMatchFailure(op, "unsupported vtrc signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vtrc result type"); + + FailureOr calleeName = + buildVtrcCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vtrc callee"); + + Value roundValue = getI32Constant(rewriter, op.getLoc(), *roundMode); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getInput().getType(), roundValue.getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getInput(), roundValue, adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateStoreOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicateStoreOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(StoreOp op, typename StoreOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmDestType = + dyn_cast(adaptor.getDestination().getType()); + Type valueType = this->getTypeConverter()->convertType(op.getValue().getType()); + if (!llvmDestType || !valueType) + return rewriter.notifyMatchFailure( + op, "expected converted predicate-store operand types"); + + auto dist = parsePredicateStoreDistImmediate(op.getDist()); + if (!dist) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-store dist immediate"); + + Value offset = castIntegerLikeTo(op, adaptor.getOffset(), rewriter.getI32Type()); + if (!offset) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-store offset to i32"); + + StringRef calleeName = getPredicateStoreCallee(op.getContext()); + SmallVector args; + args.push_back(adaptor.getValue()); + args.push_back(adaptor.getDestination()); + args.push_back(offset); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist))); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(0))); + auto funcType = rewriter.getFunctionType( + TypeRange{valueType, llvmDestType, rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateLoadOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicateLoadOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(LoadOp op, typename LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmSourceType = + dyn_cast(adaptor.getSource().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!llvmSourceType || !resultType) + return rewriter.notifyMatchFailure( + op, "expected converted predicate-load operand/result types"); + + auto dist = parsePredicateLoadDistImmediate(op.getDist()); + if (!dist) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-load dist immediate"); + + Value offset = castIntegerLikeTo(op, adaptor.getOffset(), rewriter.getI32Type()); + if (!offset) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-load offset to i32"); + + StringRef calleeName = getPredicateLoadCallee(op.getContext()); + SmallVector args; + args.push_back(adaptor.getSource()); + args.push_back(offset); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist))); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(0))); + auto funcType = rewriter.getFunctionType( + TypeRange{llvmSourceType, rewriter.getI32Type(), rewriter.getI32Type(), + rewriter.getI32Type()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, resultType, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerSetLoopConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerSetLoopConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(LoopOp op, typename LoopOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr packed = failure(); + if constexpr (std::is_same_v || + std::is_same_v) { + packed = packLoopSize(op, adaptor.getFirst(), adaptor.getSecond()); + } else { + packed = packLoopPair(op, adaptor.getFirst(), adaptor.getSecond()); + } + if (failed(packed)) + return rewriter.notifyMatchFailure(op, + "failed to pack loop configuration"); + + StringRef calleeName = buildSetLoopCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*packed}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPipeEventSyncOpPattern final : public OpConversionPattern { +public: + explicit LowerPipeEventSyncOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(SyncOp op, typename SyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto src = parsePipeImmediate(stringifyPIPE(op.getSrcPipe().getPipe())); + auto dst = parsePipeImmediate(stringifyPIPE(op.getDstPipe().getPipe())); + auto event = parseEventImmediate(stringifyEVENT(op.getEventId().getEvent())); + if (!src || !dst || !event) + return rewriter.notifyMatchFailure(op, "unsupported sync immediate"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value srcValue = getI64Constant(rewriter, op.getLoc(), *src); + Value dstValue = getI64Constant(rewriter, op.getLoc(), *dst); + Value eventValue = getI64Constant(rewriter, op.getLoc(), *event); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI64Type(), rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{srcValue, dstValue, eventValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerBarrierOpPattern final : public OpConversionPattern { +public: + explicit LowerBarrierOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pipe = parsePipeImmediate(stringifyPIPE(op.getPipe().getPipe())); + if (!pipe) + return rewriter.notifyMatchFailure(op, "unsupported barrier pipe"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value pipeValue = getI64Constant(rewriter, op.getLoc(), *pipe); + auto funcType = + rewriter.getFunctionType(TypeRange{rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{pipeValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerBufSyncOpPattern final : public OpConversionPattern { +public: + explicit LowerBufSyncOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BufSyncOp op, typename BufSyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + PIPE pipe = PIPE::PIPE_UNASSIGNED; + if (auto pipeAttr = dyn_cast(op.getOpTypeAttr())) { + pipe = pipeAttr.getPipe(); + } else { + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure( + op, "buffer sync expects pipe/sync_op_type/pipe_event_type attr"); + pipe = mapSyncOpTypeToPipe(*opTypeOr); + } + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, + "buffer sync op_type cannot map to concrete pipe"); + + auto pipeImm = parsePipeImmediate(stringifyPIPE(pipe)); + if (!pipeImm) + return rewriter.notifyMatchFailure(op, "unsupported buffer sync pipe"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value pipeValue = getI64Constant(rewriter, op.getLoc(), *pipeImm); + Value bufIdValue = + getI64Constant(rewriter, op.getLoc(), op.getBufIdAttr().getInt()); + Value modeValue = + getI64Constant(rewriter, op.getLoc(), op.getModeAttr().getInt()); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI64Type(), rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{pipeValue, bufIdValue, modeValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerRuntimeQueryOpPattern final : public OpConversionPattern { +public: + explicit LowerRuntimeQueryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(QueryOp op, typename QueryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert runtime-query result type"); + + StringRef calleeName = buildRuntimeQueryCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, ValueRange{}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class ConvertVPTOUnrealizedCastOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) return failure(); - callArgs.push_back(*basePtr); - callArgs.push_back(gather.getOffsets()); - callArgs.push_back(gather.getMask()); - } else if (auto gather = dyn_cast(op)) { - auto basePtr = requirePointerABIAddress(op, gather.getSource(), diagOS); - if (failed(basePtr)) + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) return failure(); - callArgs.push_back(*basePtr); - callArgs.push_back(gather.getOffsets()); - callArgs.push_back(gather.getMask()); - } else if (auto vbitsort = dyn_cast(op)) { - auto destination = requirePointerABIAddress(op, vbitsort.getDestination(), diagOS); - auto source = requirePointerABIAddress(op, vbitsort.getSource(), diagOS); - auto indices = requirePointerABIAddress(op, vbitsort.getIndices(), diagOS); - auto config = packVbitsortConfig(op, vbitsort.getRepeatTimes()); - if (failed(destination) || failed(source) || failed(indices) || failed(config)) + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) return failure(); - callArgs.push_back(*destination); - callArgs.push_back(*source); - callArgs.push_back(*indices); - callArgs.push_back(*config); - } else if (auto scatter = dyn_cast(op)) { - Type valueElemType = getElementTypeFromVectorLike(scatter.getValue().getType()); - auto basePtr = requirePointerABIAddress(op, scatter.getDestination(), diagOS); - auto mask = buildDynamicPltMask(builder, module, loc, scatter.getActiveLanes(), - valueElemType, diagOS); - if (!valueElemType || failed(basePtr) || failed(mask)) + + Value input = adaptor.getOperands().front(); + if (input.getType() != convertedResultType) return failure(); - callArgs.push_back(scatter.getValue()); - callArgs.push_back(*basePtr); - callArgs.push_back(scatter.getOffsets()); - callArgs.push_back(*mask); - } else { - diagOS << "VPTO LLVM emission failed: op lowering is not implemented for " - << op->getName().getStringRef() << "\n"; - return failure(); + + rewriter.replaceOp(op, input); + return success(); } +}; - SmallVector argTypes; - for (Value arg : callArgs) - argTypes.push_back(arg.getType()); - - auto funcType = builder.getFunctionType(argTypes, intrinsicResultTypes); - auto callee = getOrCreateExternalFunc(module, *calleeName, funcType); - auto call = builder.create(loc, callee, callArgs); - if (op->getNumResults() == 0) - builder.eraseOp(op); - else { - SmallVector finalResults; - finalResults.reserve(op->getNumResults()); - for (auto [idx, result] : - llvm::enumerate(call.getResults().take_front(op->getNumResults()))) { - Type surfaceType = surfaceResultTypes[idx]; - if (isa(surfaceType)) { - diagOS << "VPTO LLVM emission failed: unexpected LLVM pointer surface " - "result type on op "; - op->print(diagOS); - diagOS << "\n"; - return failure(); +class ConvertPtoAddPtrOp final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getResult().getType()); + auto llvmPtrType = dyn_cast(convertedResultType); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer result type"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + auto gep = rewriter.create( + op.getLoc(), llvmPtrType, cast(op.getPtr().getType()).getElementType(), + adaptor.getPtr(), ValueRange{offset}); + rewriter.replaceOp(op, gep.getResult()); + return success(); + } +}; + +class ConvertPtoCastPtrOp final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::CastPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, + "could not convert castptr result type"); + + Value input = adaptor.getInput(); + Type inputType = input.getType(); + if (inputType == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + + if (auto llvmPtrType = dyn_cast(convertedResultType)) { + if (isa(inputType)) { + rewriter.replaceOpWithNewOp(op, llvmPtrType, input); + return success(); + } + auto sourcePtrType = dyn_cast(inputType); + if (!sourcePtrType) + return rewriter.notifyMatchFailure(op, + "expected integer or LLVM pointer input"); + if (sourcePtrType.getAddressSpace() == llvmPtrType.getAddressSpace()) { + rewriter.replaceOpWithNewOp(op, llvmPtrType, input); + return success(); } - if (isa(surfaceType)) { - finalResults.push_back(buildBridgeCast(builder, loc, result, surfaceType)); - continue; + return rewriter.notifyMatchFailure( + op, "cross-address-space ptr casts are unsupported"); + } + + if (auto resultIntType = dyn_cast(convertedResultType)) { + if (isa(inputType)) { + rewriter.replaceOpWithNewOp(op, resultIntType, input); + return success(); } - finalResults.push_back(result); } - builder.replaceOp(op, finalResults); + + return rewriter.notifyMatchFailure(op, "unsupported castptr conversion"); + } +}; + +class ConvertPtoLoadScalarOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + op.getValue().getType(), adaptor.getPtr(), + ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + else if (type.isF16() || type.isBF16()) + alignBytes = 2; + else if (type.isF32()) + alignBytes = 4; + else if (type.isF64()) + alignBytes = 8; + return alignBytes; + }; + + rewriter.replaceOpWithNewOp( + op, op.getValue().getType(), elemPtr, + getNaturalAlignment(op.getValue().getType())); + return success(); + } +}; + +class ConvertPtoStoreScalarOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + adaptor.getValue().getType(), + adaptor.getPtr(), ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + else if (type.isF16() || type.isBF16()) + alignBytes = 2; + else if (type.isF32()) + alignBytes = 4; + else if (type.isF64()) + alignBytes = 8; + return alignBytes; + }; + + rewriter.create(op.getLoc(), adaptor.getValue(), elemPtr, + getNaturalAlignment(adaptor.getValue().getType())); + rewriter.eraseOp(op); + return success(); + } +}; + +class ConvertVPTOTypedCarrierOp final : public ConversionPattern { +public: + ConvertVPTOTypedCarrierOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (isa(op)) + return failure(); + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure( + op, "region ops with VPTO types are handled structurally"); + + FailureOr converted = + convertOpResultTypes(op, operands, *typeConverter, rewriter); + if (failed(converted)) + return failure(); + return success(); } - return success(); +}; + +static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, + RewritePatternSet &patterns, + LoweringState &state) { + patterns.add, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerVsqzOpPattern, LowerVusqzOpPattern, + LowerVmulaOpPattern, LowerVmullOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerVdupOpPattern, + LowerVbrOpPattern, + LowerPredicatePackOpPattern, + LowerPredicatePackOpPattern, + LowerVselOpPattern, LowerVselrOpPattern, LowerPnotOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerUnpackOpPattern, + LowerUnpackOpPattern, + LowerVpackOpPattern, + LowerInterleaveOpPattern, + LowerInterleaveOpPattern, + LowerCmpOpPattern, + LowerCmpOpPattern, + LowerPltOpPattern, + LowerPltOpPattern, + LowerPltOpPattern, + LowerPsetOpPattern, + LowerPsetOpPattern, + LowerPsetOpPattern, + LowerPgeOpPattern, + LowerPgeOpPattern, + LowerPgeOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerPipeEventSyncOpPattern, + LowerPipeEventSyncOpPattern, + LowerBarrierOpPattern, + LowerBufSyncOpPattern, + LowerBufSyncOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerVldsOpPattern, LowerVldsPostOpPattern, + LowerVldsx2OpPattern, LowerVsldbOpPattern, + LowerVldasOpPattern, LowerInitAlignOpPattern, + LowerVldusOpPattern, LowerSprclrOpPattern, + LowerVstsOpPattern, LowerVsstbOpPattern, + LowerVstsPostOpPattern, LowerVstsx2OpPattern, + LowerVstarOpPattern, LowerVstasOpPattern, + LowerVgather2OpPattern, LowerVgather2BcOpPattern, + LowerVgatherbOpPattern, LowerVscatterOpPattern, + LowerVpreluOpPattern, LowerVaxpyOpPattern, + LowerVciOpPattern, LowerVexpdiffOpPattern, + LowerVbitsortOpPattern, LowerVtrcOpPattern, LowerVcvtOpPattern, + LowerPredicateLoadOpPattern, + LowerPredicateLoadOpPattern, + LowerPredicateStoreOpPattern, + LowerPredicateStoreOpPattern, + LowerPstuOpPattern, LowerVstusOpPattern, LowerVsturOpPattern, + LowerCopyOpPattern, + LowerCopyOpPattern>( + typeConverter, patterns.getContext(), state); } -static LogicalResult rewriteVPTOOps(ModuleOp module, llvm::raw_ostream &diagOS) { - SmallVector opsToRewrite; - module.walk([&](Operation *op) { - if (op->getName().getDialectNamespace() != "pto") - return; - if (isa(op)) - return; - opsToRewrite.push_back(op); - }); +static void configureVPTOOpLoweringTarget(ConversionTarget &target, + VPTOTypeConverter &typeConverter) { + (void)typeConverter; + target.addLegalOp(); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); +} - for (Operation *op : opsToRewrite) { - if (failed(rewriteVPTOOp(op, module, diagOS))) - return failure(); - } +static void populateVPTOStructuralTypePatterns( + VPTOTypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target) { + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); +} - bool hasVPTO = false; - module.walk([&](Operation *op) { - if (op->getName().getDialectNamespace() != "pto") +static void foldVPTOTypeCasts(ModuleOp module, TypeConverter &typeConverter) { + SmallVector castsToFold; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) return; - if (isa(op)) + if (!hasVPTOConvertibleType(castOp->getOperandTypes()) && + !hasVPTOConvertibleType(castOp->getResultTypes())) return; - hasVPTO = true; - }); - - SmallVector poisonOps; - module.walk([&](Operation *op) { - auto name = op->getName().getStringRef(); - if (name == "ub.poison" && - op->getNumResults() == 1 && - isa(op->getResult(0).getType())) - poisonOps.push_back(op); + Type convertedResultType = + typeConverter.convertType(castOp.getResult(0).getType()); + if (convertedResultType && + convertedResultType == castOp.getOperand(0).getType()) + castsToFold.push_back(castOp); }); - for (Operation *op : poisonOps) { - OpBuilder builder(op); - auto abiType = cast(convertVPTOType(op->getResult(0).getType(), builder)); - auto zeroAttr = DenseElementsAttr::get(abiType, builder.getI8IntegerAttr(0)); - auto zero = builder.create(op->getLoc(), abiType, zeroAttr); - op->getResult(0).replaceAllUsesWith(zero.getResult()); - op->erase(); + for (UnrealizedConversionCastOp castOp : castsToFold) { + castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); + castOp.erase(); } - - return success(!hasVPTO); } -static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { - type = convertVPTOType(type, builder); +static LogicalResult lowerVPTOOps(ModuleOp module, llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + VPTOTypeConverter typeConverter(context); + ConversionTarget target(*context); + RewritePatternSet patterns(context); + LoweringState state; + + configureVPTOOpLoweringTarget(target, typeConverter); + populateVPTOOpLoweringPatterns(typeConverter, patterns, state); - if (auto memrefType = dyn_cast(type)) { - auto addrAttr = - dyn_cast_or_null(memrefType.getMemorySpace()); - if (!addrAttr) - return type; - unsigned addrSpace = getExternalPointerAddressSpace(memrefType); - return MemRefType::get(memrefType.getShape(), memrefType.getElementType(), - memrefType.getLayout(), - builder.getI64IntegerAttr(addrSpace)); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: VPTO op lowering failed\n"; + return failure(); } + if (failed(materializeDecls(module, state.plannedDecls, diagOS))) + return failure(); + return success(); +} + +static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + VPTOTypeConverter typeConverter(context); + ConversionTarget target(*context); + RewritePatternSet patterns(context); + + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + target.addIllegalOp(); + target.addDynamicallyLegalOp( + [&](UnrealizedConversionCastOp op) { + return !hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes()); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + populateVPTOStructuralTypePatterns(typeConverter, patterns, target); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); - if (auto memrefType = dyn_cast(type)) { - auto addrAttr = - dyn_cast_or_null(memrefType.getMemorySpace()); - if (!addrAttr) - return type; - // Official MemRef-to-LLVM conversion requires integer memory spaces. - return UnrankedMemRefType::get(memrefType.getElementType(), - builder.getI64IntegerAttr( - static_cast(AddressSpace::GM))); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: VPTO type lowering failed\n"; + return failure(); } + foldVPTOTypeCasts(module, typeConverter); + return success(); +} +static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { + type = convertVPTOType(type, builder); return type; } @@ -3731,14 +4820,11 @@ static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { SmallVector newResults; bool changed = false; - newInputs.reserve(oldType.getNumInputs()); for (Type input : oldType.getInputs()) { Type normalized = normalizeTypeForOfficialLLVMLowering(input, builder); changed |= (normalized != input); newInputs.push_back(normalized); } - - newResults.reserve(oldType.getNumResults()); for (Type result : oldType.getResults()) { Type normalized = normalizeTypeForOfficialLLVMLowering(result, builder); changed |= (normalized != result); @@ -3760,745 +4846,31 @@ static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { } } -static void ensureAIVScopeDummyDecl(ModuleOp module) { - SymbolTable symbolTable(module); - if (symbolTable.lookup(kAIVScopeDummyCallee)) - return; - - OpBuilder builder(module.getBodyRegion()); - builder.setInsertionPointToStart(module.getBody()); - auto funcType = builder.getFunctionType(TypeRange{}, TypeRange{}); - auto dummy = builder.create(module.getLoc(), - kAIVScopeDummyCallee, funcType); - dummy.setPrivate(); -} - -static void materializeVecScopeCarrierLoops(ModuleOp module) { - MLIRContext *ctx = module.getContext(); - (void)ctx->getOrLoadDialect(); - (void)ctx->getOrLoadDialect(); - ensureAIVScopeDummyDecl(module); - - SmallVector scopes; - module.walk([&](pto::VecScopeOp vecScope) { scopes.push_back(vecScope); }); - - IRRewriter rewriter(module.getContext()); - for (pto::VecScopeOp vecScope : llvm::reverse(scopes)) { - if (!vecScope || vecScope.getBody().empty()) - continue; - - rewriter.setInsertionPoint(vecScope); - auto loc = vecScope.getLoc(); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); - - Block &vecScopeBody = vecScope.getBody().front(); - Block *carrierBody = carrier.getBody(); - Operation *yield = carrierBody->getTerminator(); - carrierBody->getOperations().splice(Block::iterator(yield), - vecScopeBody.getOperations(), - vecScopeBody.begin(), - vecScopeBody.end()); - rewriter.setInsertionPoint(yield); - rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, - ValueRange{}); - rewriter.eraseOp(vecScope); - } - - SmallVector strictScopes; - module.walk( - [&](pto::StrictVecScopeOp strictVecScope) { strictScopes.push_back(strictVecScope); }); - - for (pto::StrictVecScopeOp strictVecScope : llvm::reverse(strictScopes)) { - if (!strictVecScope || strictVecScope.getBody().empty()) - continue; - - rewriter.setInsertionPoint(strictVecScope); - auto loc = strictVecScope.getLoc(); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); - - Block &strictBody = strictVecScope.getBody().front(); - Block *carrierBody = carrier.getBody(); - Operation *yield = carrierBody->getTerminator(); - - IRMapping mapping; - for (auto [blockArg, capture] : - llvm::zip(strictBody.getArguments(), strictVecScope.getCaptures())) - mapping.map(blockArg, capture); - - rewriter.setInsertionPoint(yield); - for (Operation &nested : strictBody.getOperations()) - rewriter.clone(nested, mapping); - rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, - ValueRange{}); - - rewriter.eraseOp(strictVecScope); - } -} - -static bool satisfiesAIVectorScopeLatchPostcondition(llvm::Loop *loop) { - llvm::BasicBlock *latch = loop->getLoopLatch(); - if (!latch) - return false; - - llvm::SmallVector preds(llvm::predecessors(latch)); - if (preds.size() != 1) - return false; - - auto *predTerm = preds.front()->getTerminator(); - return predTerm && predTerm->getNumSuccessors() == 1 && - predTerm->getSuccessor(0) == latch; -} - -// Bisheng imposes a strict CFG contract on loops carrying -// `llvm.loop.aivector_scope` metadata: -// 1. the latch must have exactly one predecessor -// 2. that predecessor must have exactly one successor, namely the latch -// -// The generic SCF/LLVM lowering pipeline does not preserve this shape for us. -// Therefore VPTO LLVM emission treats this as a required postcondition instead -// of a best-effort cleanup: -// - if the loop already satisfies the contract, keep it as-is -// - otherwise normalize all latch predecessors through a dummy block -// - if normalization still cannot re-establish the contract, fail the export -// -// Failing loudly here is intentional. Silently attaching aivscope metadata to -// an unsupported latch shape only defers the problem into Bisheng as a backend -// crash, which makes future regressions harder to diagnose. -static LogicalResult ensureDummyPredForAIVectorScopeLatch(llvm::Loop *loop, - llvm::raw_ostream &diagOS) { - if (satisfiesAIVectorScopeLatchPostcondition(loop)) - return success(); - - llvm::BasicBlock *latch = loop->getLoopLatch(); - if (!latch) { - diagOS << "VPTO LLVM emission failed: aivscope loop is missing a latch\n"; - return failure(); - } - - llvm::SmallVector preds(llvm::predecessors(latch)); - if (preds.empty()) { - diagOS << "VPTO LLVM emission failed: aivscope latch has no predecessor\n"; - return failure(); - } - - auto *dummy = llvm::SplitBlockPredecessors( - latch, preds, "aivscope.dummy", static_cast(nullptr), - static_cast(nullptr), nullptr, /*PreserveLCSSA=*/false); - if (!dummy) { - diagOS << "VPTO LLVM emission failed: failed to normalize aivscope latch " - "predecessors\n"; - return failure(); - } - - if (!satisfiesAIVectorScopeLatchPostcondition(loop)) { - diagOS << "VPTO LLVM emission failed: normalized aivscope latch still does " - "not satisfy the single-predecessor/single-successor contract\n"; - return failure(); - } - return success(); -} - -static LogicalResult attachAIVectorScopeMetadata( - llvm::Module &llvmModule, llvm::raw_ostream &diagOS) { - llvm::Function *dummyCallee = llvmModule.getFunction(kAIVScopeDummyCallee); - if (!dummyCallee) - return success(); - - for (llvm::Function &function : llvmModule) { - if (function.isDeclaration()) - continue; - llvm::DominatorTree dt(function); - llvm::LoopInfo loopInfo(dt); - - // Stage 1: collect the lowered vecscope markers in this function. Each - // marker should end up inside the final LLVM loop that carries one - // `pto.vecscope` / `pto.strict_vecscope`. - llvm::SmallVector dummyCalls; - for (llvm::BasicBlock &block : function) { - for (llvm::Instruction &inst : block) { - auto *call = dyn_cast(&inst); - if (call && call->getCalledFunction() == dummyCallee) - dummyCalls.push_back(call); - } - } - - for (llvm::CallInst *dummyCall : dummyCalls) { - llvm::BasicBlock *markedBlock = dummyCall->getParent(); - llvm::Loop *loop = loopInfo.getLoopFor(markedBlock); - if (!loop) { - diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " - << function.getName() << " does not belong to an LLVM loop\n"; - return failure(); - } - - // Stage 2: if the marker ended up in the loop latch, split the block so - // the eventual latch stays as a clean backedge block instead of carrying - // vector-thread side effects. - if (markedBlock == loop->getLoopLatch() && - dummyCall != markedBlock->getTerminator()) { - markedBlock->splitBasicBlock(dummyCall->getIterator(), "aivscope.latch"); - dt.recalculate(function); - loopInfo.releaseMemory(); - loopInfo.analyze(dt); - markedBlock = dummyCall->getParent(); - loop = loopInfo.getLoopFor(markedBlock); - if (!loop) { - diagOS << "VPTO LLVM emission failed: split aivscope latch in " - << function.getName() - << " no longer belongs to an LLVM loop\n"; - return failure(); - } - } - - if (failed(ensureDummyPredForAIVectorScopeLatch(loop, diagOS))) - return failure(); - - // Stage 3: after any CFG surgery, re-query the loop and attach - // `llvm.loop.aivector_scope` to the normalized latch backedge. The dummy - // marker has served its purpose by this point and is removed. - dt.recalculate(function); - loopInfo.releaseMemory(); - loopInfo.analyze(dt); - loop = loopInfo.getLoopFor(markedBlock); - if (!loop) { - diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " - << function.getName() - << " lost its loop after latch normalization\n"; - return failure(); - } - - llvm::BasicBlock *latch = loop->getLoopLatch(); - auto *branch = dyn_cast_or_null( - latch ? latch->getTerminator() : nullptr); - if (!branch || branch->isConditional()) { - diagOS << "VPTO LLVM emission failed: normalized aivscope loop in " - << function.getName() - << " does not have an unconditional latch backedge\n"; - return failure(); - } - - llvm::LLVMContext &ctx = llvmModule.getContext(); - llvm::Metadata *ops[] = { - nullptr, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, "llvm.loop.aivector_scope"))}; - auto *loopID = llvm::MDNode::getDistinct(ctx, ops); - loopID->replaceOperandWith(0, loopID); - branch->setMetadata(llvm::LLVMContext::MD_loop, loopID); - dummyCall->eraseFromParent(); - } - } - - if (dummyCallee->use_empty()) - dummyCallee->eraseFromParent(); - return success(); -} - -static void attachHIVMKernelAnnotations(llvm::Module &llvmModule) { - llvm::NamedMDNode *annotations = llvmModule.getOrInsertNamedMetadata( - "hivm.annotations"); - llvm::LLVMContext &ctx = llvmModule.getContext(); - llvm::Type *i32Ty = llvm::Type::getInt32Ty(ctx); - llvm::Constant *one = llvm::ConstantInt::get(i32Ty, 1); - - auto addAnnotation = [&](llvm::Function &function, llvm::StringRef kind) { - llvm::Metadata *ops[] = { - llvm::ValueAsMetadata::get(&function), - llvm::MDString::get(ctx, kind), - llvm::ConstantAsMetadata::get(one)}; - annotations->addOperand(llvm::MDNode::get(ctx, ops)); - }; - - for (llvm::Function &function : llvmModule) { - if (function.isDeclaration()) - continue; - if (function.getLinkage() != llvm::GlobalValue::ExternalLinkage) - continue; - - llvm::StringRef name = function.getName(); - if (name.contains(".extracted") || name.contains(".vector.thread")) - continue; - - addAnnotation(function, "kernel"); - addAnnotation(function, "kernel_with_simd"); - } -} - -static FailureOr extractQuotedLLVMFnAttr(llvm::StringRef ir, - llvm::StringRef key) { - std::string pattern = "\""; - pattern += key.str(); - pattern += "\"=\""; - size_t start = ir.find(pattern); - if (start == llvm::StringRef::npos) - return failure(); - start += pattern.size(); - size_t end = ir.find('"', start); - if (end == llvm::StringRef::npos || end <= start) - return failure(); - return ir.slice(start, end).str(); -} - -static FailureOr -queryDefaultTargetAttrs(const VPTOEmissionOptions &options, - llvm::raw_ostream &diagOS) { - static llvm::StringMap cache; - - if (options.targetTriple.empty() || options.march.empty() || - options.aicoreArch.empty()) { - diagOS << "VPTO LLVM emission failed: missing target query options\n"; - return failure(); - } - - std::string cacheKey = - options.targetTriple + "|" + options.march + "|" + options.aicoreArch; - if (auto it = cache.find(cacheKey); it != cache.end()) - return it->second; - - auto bisheng = llvm::sys::findProgramByName("bisheng"); - if (!bisheng) { - diagOS << "VPTO LLVM emission failed: unable to find 'bisheng' in PATH\n"; - return failure(); - } - const std::string &bishengPath = *bisheng; - - llvm::SmallString<64> inputPath; - llvm::SmallString<64> outputPath; - int inputFD = -1; - int outputFD = -1; - if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", - "c", inputFD, inputPath)) { - diagOS << "VPTO LLVM emission failed: cannot create bisheng query input: " - << ec.message() << "\n"; - return failure(); - } - if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", - "ll", outputFD, outputPath)) { - llvm::sys::fs::remove(inputPath); - llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); - diagOS << "VPTO LLVM emission failed: cannot create bisheng query output: " - << ec.message() << "\n"; - return failure(); - } +template +static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, + const VPTOEmissionOptions &options, + EmitFn &&emit) { + OwningOpRef clonedOp(module->clone()); + ModuleOp clonedModule = cast(*clonedOp); - auto cleanup = llvm::make_scope_exit([&]() { - llvm::sys::fs::remove(inputPath); - llvm::sys::fs::remove(outputPath); - }); + materializeVecScopeCarrierLoops(clonedModule); - { - llvm::raw_fd_ostream inputOS(inputFD, /*shouldClose=*/false); - inputOS << "void f(void) {}\n"; - } - llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); - llvm::sys::Process::SafelyCloseFileDescriptor(outputFD); - - llvm::SmallString<128> stderrPath; - int stderrFD = -1; - if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", - "stderr", stderrFD, - stderrPath)) { - diagOS << "VPTO LLVM emission failed: cannot create bisheng query stderr: " - << ec.message() << "\n"; - return failure(); - } - auto stderrCleanup = llvm::make_scope_exit([&]() { - llvm::sys::fs::remove(stderrPath); - }); - llvm::sys::Process::SafelyCloseFileDescriptor(stderrFD); - - llvm::SmallVector argStorage = { - bishengPath, - ("--target=" + options.targetTriple), - ("-march=" + options.march), - ("--cce-aicore-arch=" + options.aicoreArch), - "--cce-aicore-only", - "-x", - "c", - inputPath.str().str(), - "-S", - "-emit-llvm", - "-o", - outputPath.str().str(), - }; - llvm::SmallVector args; - args.reserve(argStorage.size()); - for (const std::string &arg : argStorage) - args.push_back(arg); - - std::string execErr; - bool execFailed = false; - int rc = llvm::sys::ExecuteAndWait( - bishengPath, args, std::nullopt, - {std::nullopt, std::nullopt, llvm::StringRef(stderrPath)}, 0, 0, - &execErr, &execFailed); - - auto stderrBuffer = llvm::MemoryBuffer::getFile(stderrPath); - llvm::StringRef stderrText = - stderrBuffer ? stderrBuffer.get()->getBuffer() : llvm::StringRef(); - - if (execFailed || rc != 0) { - diagOS << "VPTO LLVM emission failed: bisheng target query failed\n"; - diagOS << "Command:"; - for (llvm::StringRef arg : args) - diagOS << " " << arg; - diagOS << "\n"; - if (!execErr.empty()) - diagOS << execErr << "\n"; - if (!stderrText.empty()) - diagOS << stderrText << "\n"; + if (failed(normalizePtoMemRefSpaces(clonedModule, diagOS))) { + diagOS << "VPTO LLVM emission failed: normalizePtoMemRefSpaces failed\n"; return failure(); } - - auto outputBuffer = llvm::MemoryBuffer::getFile(outputPath); - if (!outputBuffer) { - diagOS << "VPTO LLVM emission failed: cannot read bisheng query output\n"; + if (failed(lowerVPTOOps(clonedModule, diagOS))) { + diagOS << "VPTO LLVM emission failed: lowerVPTOOps failed\n"; return failure(); } - - FailureOr targetCPU = - extractQuotedLLVMFnAttr(outputBuffer.get()->getBuffer(), "target-cpu"); - FailureOr targetFeatures = extractQuotedLLVMFnAttr( - outputBuffer.get()->getBuffer(), "target-features"); - if (failed(targetCPU) || failed(targetFeatures)) { - diagOS << "VPTO LLVM emission failed: cannot parse bisheng target attrs\n"; - diagOS << outputBuffer.get()->getBuffer() << "\n"; + if (failed(lowerVPTOTypes(clonedModule, diagOS))) { + diagOS << "VPTO LLVM emission failed: lowerVPTOTypes failed\n"; return failure(); } - QueriedTargetAttrs attrs{*targetCPU, *targetFeatures}; - cache[cacheKey] = attrs; - return attrs; -} - -static LogicalResult -applyQueriedTargetAttrs(ModuleOp module, const VPTOEmissionOptions &options, - llvm::raw_ostream &diagOS) { - FailureOr attrs = queryDefaultTargetAttrs(options, diagOS); - if (failed(attrs)) { - if (options.defaultTargetCPU.empty() || - options.defaultTargetFeatures.empty()) - return failure(); - diagOS << "VPTO LLVM emission: falling back to configured default target attributes\n"; - attrs = QueriedTargetAttrs{options.defaultTargetCPU, - options.defaultTargetFeatures}; - } - - MLIRContext *ctx = module.getContext(); - StringAttr cpuAttr = StringAttr::get(ctx, attrs->targetCPU); - LLVM::TargetFeaturesAttr featureAttr = - LLVM::TargetFeaturesAttr::get(ctx, attrs->targetFeatures); - module.walk([&](LLVM::LLVMFuncOp funcOp) { - funcOp.setTargetCpuAttr(cpuAttr); - funcOp.setTargetFeaturesAttr(featureAttr); - }); - return success(); -} - -static llvm::Value *castABIValue(llvm::IRBuilder<> &builder, llvm::Value *value, - llvm::Type *targetType) { - if (value->getType() == targetType) - return value; - - if (auto *targetPtr = dyn_cast(targetType)) { - auto *sourcePtr = dyn_cast(value->getType()); - if (!sourcePtr) - return nullptr; - if (sourcePtr->getAddressSpace() == targetPtr->getAddressSpace()) - return builder.CreateBitCast(value, targetType); - return builder.CreateAddrSpaceCast(value, targetType); - } - - if (targetType->isIntegerTy()) { - if (value->getType()->isIntegerTy()) { - unsigned srcWidth = value->getType()->getIntegerBitWidth(); - unsigned dstWidth = targetType->getIntegerBitWidth(); - if (srcWidth == dstWidth) - return value; - if (srcWidth < dstWidth) - return builder.CreateZExt(value, targetType); - return builder.CreateTrunc(value, targetType); - } - } - - return nullptr; -} - -static llvm::Value *materializeABIExpr(llvm::IRBuilder<> &builder, - const ABIExpr &expr, - llvm::Function *wrapper, - llvm::Type *targetType) { - switch (expr.kind) { - case ABIExpr::Kind::Constant: - return llvm::ConstantInt::get(targetType, expr.constant); - case ABIExpr::Kind::FuncArg: { - if (expr.argIndex >= wrapper->arg_size()) - return nullptr; - return castABIValue(builder, wrapper->getArg(expr.argIndex), targetType); - } - case ABIExpr::Kind::Mul: { - llvm::Value *lhs = - materializeABIExpr(builder, *expr.lhs, wrapper, targetType); - llvm::Value *rhs = - materializeABIExpr(builder, *expr.rhs, wrapper, targetType); - if (!lhs || !rhs) - return nullptr; - return builder.CreateMul(lhs, rhs); - } - } - return nullptr; -} - -static unsigned getMemRefExpandedArgCount(int64_t rank) { - return 2u + 1u + static_cast(rank) + static_cast(rank); -} - -static llvm::Value *resolveInsertedAggregateValue(llvm::Value *value, - llvm::ArrayRef idxs) { - auto *insert = dyn_cast(value); - if (!insert) - return nullptr; - - if (insert->getIndices() == idxs) - return insert->getInsertedValueOperand(); - - return resolveInsertedAggregateValue(insert->getAggregateOperand(), idxs); -} - -static llvm::Value *resolveAddrSpaceRoundTrip(llvm::Value *value) { - auto *outerCast = dyn_cast(value); - if (!outerCast) - return nullptr; - - auto *innerCast = dyn_cast(outerCast->getPointerOperand()); - if (!innerCast) - return nullptr; - - llvm::Value *original = innerCast->getPointerOperand(); - if (original->getType() != outerCast->getType()) - return nullptr; - - auto *innerDstPtr = dyn_cast(innerCast->getType()); - auto *outerDstPtr = dyn_cast(outerCast->getType()); - auto *origPtr = dyn_cast(original->getType()); - if (!innerDstPtr || !outerDstPtr || !origPtr) - return nullptr; - - if (innerDstPtr->getAddressSpace() == origPtr->getAddressSpace()) - return nullptr; - if (outerDstPtr->getAddressSpace() != origPtr->getAddressSpace()) - return nullptr; - - return original; -} - -static void simplifyAggregateCarrierOps(llvm::Function &function) { - bool changed = true; - while (changed) { - changed = false; - - SmallVector toErase; - for (llvm::BasicBlock &block : function) { - for (llvm::Instruction &inst : block) { - if (auto *cast = dyn_cast(&inst)) { - if (llvm::Value *resolved = resolveAddrSpaceRoundTrip(cast)) { - cast->replaceAllUsesWith(resolved); - toErase.push_back(cast); - changed = true; - continue; - } - } - - if (auto *extract = dyn_cast(&inst)) { - if (llvm::Value *resolved = - resolveInsertedAggregateValue(extract->getAggregateOperand(), - extract->getIndices())) { - extract->replaceAllUsesWith(resolved); - toErase.push_back(extract); - changed = true; - continue; - } - } - - if (llvm::isInstructionTriviallyDead(&inst)) { - toErase.push_back(&inst); - changed = true; - } - } - } - - for (llvm::Instruction *inst : toErase) - if (!inst->isTerminator()) - inst->eraseFromParent(); - } -} - -static LogicalResult rewriteFunctionsToEmitCStyleABI( - llvm::Module &llvmModule, const llvm::StringMap &specs, - llvm::raw_ostream &diagOS) { - SmallVector funcs; - for (llvm::Function &function : llvmModule) - if (!function.isDeclaration()) - funcs.push_back(&function); - - for (llvm::Function *function : funcs) { - auto it = specs.find(function->getName()); - if (it == specs.end()) - continue; - - const FunctionABISpec &spec = it->second; - if (spec.args.empty()) - continue; - - bool needsRewrite = - llvm::any_of(spec.args, [](const ExternalArgABISpec &arg) { - return arg.isMemRef; - }); - if (!needsRewrite) - continue; - - SmallVector publicArgTypes; - SmallVector oldArgBaseIndex(spec.args.size(), 0); - unsigned oldArgCursor = 0; - bool supported = true; - for (auto [idx, argSpec] : llvm::enumerate(spec.args)) { - oldArgBaseIndex[idx] = oldArgCursor; - if (argSpec.isMemRef) { - if (argSpec.memrefSpec.rank != 1) { - supported = false; - break; - } - publicArgTypes.push_back(llvm::PointerType::get( - llvmModule.getContext(), argSpec.memrefSpec.addressSpace)); - oldArgCursor += getMemRefExpandedArgCount(argSpec.memrefSpec.rank); - } else { - if (oldArgCursor >= function->arg_size()) { - supported = false; - break; - } - publicArgTypes.push_back(function->getArg(oldArgCursor)->getType()); - ++oldArgCursor; - } - } - - if (!supported || oldArgCursor != function->arg_size()) { - diagOS << "VPTO LLVM emission warning: skipping ABI rewrite for " - << function->getName() - << " because the lowered signature does not match the seam spec\n"; - continue; - } - - std::string originalName = function->getName().str(); - std::string tempName = "__ptoas_old_" + originalName; - function->setName(tempName); - function->setLinkage(llvm::GlobalValue::InternalLinkage); - - auto *publicType = llvm::FunctionType::get(function->getReturnType(), - publicArgTypes, - function->isVarArg()); - llvm::Function *replacement = llvm::Function::Create( - publicType, llvm::GlobalValue::ExternalLinkage, originalName, &llvmModule); - replacement->copyAttributesFrom(function); - replacement->setLinkage(llvm::GlobalValue::ExternalLinkage); - - unsigned publicArgIndex = 0; - for (llvm::Argument &arg : replacement->args()) - arg.setName("arg" + std::to_string(publicArgIndex++)); - - llvm::BasicBlock *bridgeEntry = llvm::BasicBlock::Create( - llvmModule.getContext(), "entry", replacement); - llvm::IRBuilder<> builder(bridgeEntry); - - llvm::ValueToValueMapTy vmap; - for (auto [idx, argSpec] : llvm::enumerate(spec.args)) { - llvm::Value *publicArg = replacement->getArg(idx); - unsigned oldBase = oldArgBaseIndex[idx]; - if (!argSpec.isMemRef) { - llvm::Value *casted = castABIValue( - builder, publicArg, function->getArg(oldBase)->getType()); - if (!casted) { - diagOS << "VPTO LLVM emission failed: cannot cast scalar arg for " - << originalName << "\n"; - return failure(); - } - vmap[function->getArg(oldBase)] = casted; - continue; - } - - llvm::Type *oldPtrTy = function->getArg(oldBase)->getType(); - llvm::Type *oldAlignedPtrTy = function->getArg(oldBase + 1)->getType(); - llvm::Type *oldOffsetTy = function->getArg(oldBase + 2)->getType(); - llvm::Type *oldSizeTy = function->getArg(oldBase + 3)->getType(); - llvm::Type *oldStrideTy = function->getArg(oldBase + 4)->getType(); - - llvm::Value *allocated = castABIValue(builder, publicArg, oldPtrTy); - llvm::Value *aligned = castABIValue(builder, publicArg, oldAlignedPtrTy); - llvm::Value *offset = materializeABIExpr( - builder, argSpec.memrefSpec.offset, replacement, oldOffsetTy); - llvm::Value *size = materializeABIExpr( - builder, argSpec.memrefSpec.totalSize, replacement, oldSizeTy); - llvm::Value *stride = materializeABIExpr( - builder, argSpec.memrefSpec.stride, replacement, oldStrideTy); - if (!allocated || !aligned || !offset || !size || !stride) { - diagOS << "VPTO LLVM emission failed: cannot materialize direct ABI for " - << originalName << "\n"; - return failure(); - } - - vmap[function->getArg(oldBase)] = allocated; - vmap[function->getArg(oldBase + 1)] = aligned; - vmap[function->getArg(oldBase + 2)] = offset; - vmap[function->getArg(oldBase + 3)] = size; - vmap[function->getArg(oldBase + 4)] = stride; - } - - llvm::SmallVector returns; - llvm::CloneFunctionInto(replacement, function, vmap, - llvm::CloneFunctionChangeType::LocalChangesOnly, - returns); - - llvm::BasicBlock *oldEntry = &replacement->getEntryBlock(); - llvm::BasicBlock *clonedEntry = oldEntry->getNextNode(); - if (!clonedEntry) { - diagOS << "VPTO LLVM emission failed: cloned function body is empty for " - << originalName << "\n"; - return failure(); - } - builder.CreateBr(clonedEntry); - - function->eraseFromParent(); - simplifyAggregateCarrierOps(*replacement); - } - - return success(); -} - -static std::unique_ptr -buildLLVMModuleFromPreparedVPTO(ModuleOp module, - llvm::LLVMContext &llvmContext, - const VPTOEmissionOptions &options, - llvm::raw_ostream &diagOS) { - materializeVecScopeCarrierLoops(module); - - if (failed(normalizePtoMemRefSpaces(module, diagOS))) - return nullptr; - - if (failed(normalizePtoAlignsToABI(module, diagOS))) - return nullptr; - - if (failed(rewriteVPTOOps(module, diagOS))) { - diagOS << "VPTO LLVM emission failed: VPTO-to-call rewriting failed\n"; - return nullptr; - } - - if (failed(normalizePtoPtrsToLLVM(module, diagOS))) - return nullptr; + normalizeFuncSignaturesForOfficialLLVMLowering(clonedModule); - normalizeFuncSignaturesForOfficialLLVMLowering(module); - - PassManager pm(module.getContext()); + PassManager pm(clonedModule.getContext()); pm.enableVerifier(); pm.addPass(createConvertSCFToCFPass()); pm.addPass(createArithToLLVMConversionPass()); @@ -4507,28 +4879,30 @@ buildLLVMModuleFromPreparedVPTO(ModuleOp module, pm.addPass(createConvertFuncToLLVMPass()); pm.addPass(createConvertControlFlowToLLVMPass()); pm.addPass(createReconcileUnrealizedCastsPass()); - if (failed(pm.run(module))) { + if (failed(pm.run(clonedModule))) { diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; - return nullptr; + return failure(); } - if (failed(applyQueriedTargetAttrs(module, options, diagOS))) - return nullptr; + if (failed(applyQueriedTargetAttrs(clonedModule, options, diagOS))) + return failure(); - registerBuiltinDialectTranslation(*module.getContext()); - registerLLVMDialectTranslation(*module.getContext()); - auto llvmModule = translateModuleToLLVMIR(module.getOperation(), llvmContext); + llvm::LLVMContext llvmContext; + registerBuiltinDialectTranslation(*clonedModule.getContext()); + registerLLVMDialectTranslation(*clonedModule.getContext()); + std::unique_ptr llvmModule = + translateModuleToLLVMIR(clonedModule.getOperation(), llvmContext); if (!llvmModule) { diagOS << "VPTO LLVM emission failed: LLVM IR export failed\n"; - return nullptr; + return failure(); } if (failed(attachAIVectorScopeMetadata(*llvmModule, diagOS))) - return nullptr; + return failure(); attachHIVMKernelAnnotations(*llvmModule); llvmModule->setModuleIdentifier("ptoas.hivm.official"); llvmModule->setSourceFileName("ptoas.hivm.official"); - return llvmModule; + return emit(*llvmModule); } } // namespace @@ -4537,26 +4911,20 @@ LogicalResult translateVPTOModuleToLLVMText(ModuleOp module, llvm::raw_ostream &os, const VPTOEmissionOptions &options, llvm::raw_ostream &diagOS) { - llvm::LLVMContext llvmContext; - auto llvmModule = - buildLLVMModuleFromPreparedVPTO(module, llvmContext, options, diagOS); - if (!llvmModule) - return failure(); - llvmModule->print(os, nullptr); - return success(); + return runPipeline(module, diagOS, options, [&](llvm::Module &llvmModule) { + llvmModule.print(os, nullptr); + return success(); + }); } LogicalResult translateVPTOModuleToLLVMBitcode(ModuleOp module, llvm::raw_ostream &os, const VPTOEmissionOptions &options, llvm::raw_ostream &diagOS) { - llvm::LLVMContext llvmContext; - auto llvmModule = - buildLLVMModuleFromPreparedVPTO(module, llvmContext, options, diagOS); - if (!llvmModule) - return failure(); - llvm::WriteBitcodeToFile(*llvmModule, os); - return success(); + return runPipeline(module, diagOS, options, [&](llvm::Module &llvmModule) { + llvm::WriteBitcodeToFile(llvmModule, os); + return success(); + }); } } // namespace mlir::pto diff --git a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp new file mode 100644 index 000000000..931f4966f --- /dev/null +++ b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp @@ -0,0 +1,684 @@ +//===- VPTOLLVMEmitterHelper.cpp - VPTO LLVM emission helpers ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VPTOLLVMEmitterHelper.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace mlir; + +namespace mlir::pto { +namespace { + +constexpr StringLiteral kAIVScopeDummyCallee = "aivscope_dummy"; + +struct QueriedTargetAttrs { + std::string targetCPU; + std::string targetFeatures; +}; + +static bool hasPtoMemRefMemorySpace(Type type) { + if (auto memRefType = dyn_cast(type)) + return isa(memRefType.getMemorySpace()); + if (auto functionType = dyn_cast(type)) + return llvm::any_of(functionType.getInputs(), hasPtoMemRefMemorySpace) || + llvm::any_of(functionType.getResults(), hasPtoMemRefMemorySpace); + return false; +} + +static bool hasPtoMemRefMemorySpace(TypeRange types) { + return llvm::any_of(types, [](Type type) { + return hasPtoMemRefMemorySpace(type); + }); +} + +struct ConvertPtoMemRefSpaceCarrierOp final : ConversionPattern { + ConvertPtoMemRefSpaceCarrierOp(TypeConverter &typeConverter, + MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && + !hasPtoMemRefMemorySpace(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure( + op, "region ops with PTO memref spaces are handled structurally"); + + FailureOr converted = + convertOpResultTypes(op, operands, *typeConverter, rewriter); + if (failed(converted)) + return failure(); + return success(); + } +}; + +struct ConvertMemRefReinterpretCastSpaceOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getType()); + auto memRefResultType = dyn_cast_or_null(convertedResultType); + if (!memRefResultType) + return rewriter.notifyMatchFailure(op, "expected memref result type"); + + rewriter.replaceOpWithNewOp( + op, memRefResultType, adaptor.getSource(), adaptor.getOffsets(), + adaptor.getSizes(), adaptor.getStrides(), op.getStaticOffsets(), + op.getStaticSizes(), op.getStaticStrides()); + return success(); + } +}; + +struct ConvertMemRefSubViewSpaceOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getType()); + auto memRefResultType = dyn_cast_or_null(convertedResultType); + if (!memRefResultType) + return rewriter.notifyMatchFailure(op, "expected memref result type"); + + rewriter.replaceOpWithNewOp( + op, memRefResultType, adaptor.getSource(), op.getMixedOffsets(), + op.getMixedSizes(), op.getMixedStrides()); + return success(); + } +}; + +struct ConvertMemRefSpaceUnrealizedCastOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && + !hasPtoMemRefMemorySpace(op->getResultTypes())) + return failure(); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + return failure(); + } +}; + +static void ensureAIVScopeDummyDecl(ModuleOp module) { + SymbolTable symbolTable(module); + if (symbolTable.lookup(kAIVScopeDummyCallee)) + return; + + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(module.getBody()); + auto funcType = builder.getFunctionType(TypeRange{}, TypeRange{}); + auto dummy = builder.create(module.getLoc(), + kAIVScopeDummyCallee, funcType); + dummy.setPrivate(); +} + +static bool satisfiesAIVectorScopeLatchPostcondition(llvm::Loop *loop) { + llvm::BasicBlock *latch = loop->getLoopLatch(); + if (!latch) + return false; + + llvm::SmallVector preds(llvm::predecessors(latch)); + if (preds.size() != 1) + return false; + + auto *predTerm = preds.front()->getTerminator(); + return predTerm && predTerm->getNumSuccessors() == 1 && + predTerm->getSuccessor(0) == latch; +} + +static LogicalResult ensureDummyPredForAIVectorScopeLatch( + llvm::Loop *loop, llvm::raw_ostream &diagOS) { + if (satisfiesAIVectorScopeLatchPostcondition(loop)) + return success(); + + llvm::BasicBlock *latch = loop->getLoopLatch(); + if (!latch) { + diagOS << "VPTO LLVM emission failed: aivscope loop is missing a latch\n"; + return failure(); + } + + llvm::SmallVector preds(llvm::predecessors(latch)); + if (preds.empty()) { + diagOS << "VPTO LLVM emission failed: aivscope latch has no predecessor\n"; + return failure(); + } + + auto *dummy = llvm::SplitBlockPredecessors( + latch, preds, "aivscope.dummy", static_cast(nullptr), + static_cast(nullptr), nullptr, /*PreserveLCSSA=*/false); + if (!dummy) { + diagOS << "VPTO LLVM emission failed: failed to normalize aivscope latch " + "predecessors\n"; + return failure(); + } + + if (!satisfiesAIVectorScopeLatchPostcondition(loop)) { + diagOS << "VPTO LLVM emission failed: normalized aivscope latch still does " + "not satisfy the single-predecessor/single-successor contract\n"; + return failure(); + } + return success(); +} + +static FailureOr extractQuotedLLVMFnAttr(llvm::StringRef ir, + llvm::StringRef key) { + std::string pattern = "\""; + pattern += key.str(); + pattern += "\"=\""; + size_t start = ir.find(pattern); + if (start == llvm::StringRef::npos) + return failure(); + start += pattern.size(); + size_t end = ir.find('"', start); + if (end == llvm::StringRef::npos || end <= start) + return failure(); + return ir.slice(start, end).str(); +} + +static FailureOr +queryDefaultTargetAttrs(const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + static llvm::StringMap cache; + + if (options.targetTriple.empty() || options.march.empty() || + options.aicoreArch.empty()) { + diagOS << "VPTO LLVM emission failed: missing target query options\n"; + return failure(); + } + + std::string cacheKey = + options.targetTriple + "|" + options.march + "|" + options.aicoreArch; + if (auto it = cache.find(cacheKey); it != cache.end()) + return it->second; + + auto bisheng = llvm::sys::findProgramByName("bisheng"); + if (!bisheng) { + diagOS << "VPTO LLVM emission failed: unable to find 'bisheng' in PATH\n"; + return failure(); + } + const std::string &bishengPath = *bisheng; + + llvm::SmallString<64> inputPath; + llvm::SmallString<64> outputPath; + int inputFD = -1; + int outputFD = -1; + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "c", inputFD, inputPath)) { + diagOS << "VPTO LLVM emission failed: cannot create bisheng query input: " + << ec.message() << "\n"; + return failure(); + } + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "ll", outputFD, outputPath)) { + llvm::sys::fs::remove(inputPath); + llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); + diagOS << "VPTO LLVM emission failed: cannot create bisheng query output: " + << ec.message() << "\n"; + return failure(); + } + + auto cleanup = llvm::make_scope_exit([&]() { + llvm::sys::fs::remove(inputPath); + llvm::sys::fs::remove(outputPath); + }); + + { + llvm::raw_fd_ostream inputOS(inputFD, /*shouldClose=*/false); + inputOS << "void f(void) {}\n"; + } + llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); + llvm::sys::Process::SafelyCloseFileDescriptor(outputFD); + + llvm::SmallString<128> stderrPath; + int stderrFD = -1; + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "stderr", stderrFD, + stderrPath)) { + diagOS << "VPTO LLVM emission failed: cannot create bisheng query stderr: " + << ec.message() << "\n"; + return failure(); + } + auto stderrCleanup = llvm::make_scope_exit([&]() { + llvm::sys::fs::remove(stderrPath); + }); + llvm::sys::Process::SafelyCloseFileDescriptor(stderrFD); + + llvm::SmallVector argStorage = { + bishengPath, + ("--target=" + options.targetTriple), + ("-march=" + options.march), + ("--cce-aicore-arch=" + options.aicoreArch), + "--cce-aicore-only", + "-x", + "c", + inputPath.str().str(), + "-S", + "-emit-llvm", + "-o", + outputPath.str().str(), + }; + llvm::SmallVector args; + args.reserve(argStorage.size()); + for (const std::string &arg : argStorage) + args.push_back(arg); + + std::string execErr; + bool execFailed = false; + int rc = llvm::sys::ExecuteAndWait( + bishengPath, args, std::nullopt, + {std::nullopt, std::nullopt, llvm::StringRef(stderrPath)}, 0, 0, + &execErr, &execFailed); + + auto stderrBuffer = llvm::MemoryBuffer::getFile(stderrPath); + llvm::StringRef stderrText = + stderrBuffer ? stderrBuffer.get()->getBuffer() : llvm::StringRef(); + + if (execFailed || rc != 0) { + diagOS << "VPTO LLVM emission failed: bisheng target query failed\n"; + diagOS << "Command:"; + for (llvm::StringRef arg : args) + diagOS << " " << arg; + diagOS << "\n"; + if (!execErr.empty()) + diagOS << execErr << "\n"; + if (!stderrText.empty()) + diagOS << stderrText << "\n"; + return failure(); + } + + auto outputBuffer = llvm::MemoryBuffer::getFile(outputPath); + if (!outputBuffer) { + diagOS << "VPTO LLVM emission failed: cannot read bisheng query output\n"; + return failure(); + } + + FailureOr targetCPU = + extractQuotedLLVMFnAttr(outputBuffer.get()->getBuffer(), "target-cpu"); + FailureOr targetFeatures = + extractQuotedLLVMFnAttr(outputBuffer.get()->getBuffer(), "target-features"); + if (failed(targetCPU) || failed(targetFeatures)) { + diagOS << "VPTO LLVM emission failed: cannot parse bisheng target attrs\n"; + diagOS << outputBuffer.get()->getBuffer() << "\n"; + return failure(); + } + + QueriedTargetAttrs attrs{*targetCPU, *targetFeatures}; + cache[cacheKey] = attrs; + return attrs; +} + +} // namespace + +LogicalResult normalizePtoMemRefSpaces(ModuleOp module, + llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion([&](MemRefType type) -> Type { + auto addrSpace = dyn_cast_or_null(type.getMemorySpace()); + if (!addrSpace) + return type; + return MemRefType::get( + type.getShape(), type.getElementType(), type.getLayout(), + IntegerAttr::get(IntegerType::get(context, 64), + static_cast(addrSpace.getAddressSpace()))); + }); + typeConverter.addTypeAttributeConversion( + [](MemRefType, pto::AddressSpaceAttr attr) -> Attribute { + return IntegerAttr::get(IntegerType::get(attr.getContext(), 64), + static_cast(attr.getAddressSpace())); + }); + auto materializeMemRefCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); + }; + typeConverter.addSourceMaterialization(materializeMemRefCast); + typeConverter.addTargetMaterialization(materializeMemRefCast); + typeConverter.addArgumentMaterialization(materializeMemRefCast); + + ConversionTarget target(*context); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + RewritePatternSet patterns(context); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: memref address-space normalization " + "failed\n"; + return failure(); + } + + SmallVector castsToFold; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return; + if (!hasPtoMemRefMemorySpace(castOp->getOperandTypes()) && + !hasPtoMemRefMemorySpace(castOp->getResultTypes())) + return; + Type convertedResultType = + typeConverter.convertType(castOp.getResult(0).getType()); + if (convertedResultType && + convertedResultType == castOp.getOperand(0).getType()) + castsToFold.push_back(castOp); + }); + for (UnrealizedConversionCastOp castOp : castsToFold) { + castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); + castOp.erase(); + } + + WalkResult leftover = module.walk([&](Operation *op) { + if (hasPtoMemRefMemorySpace(op->getOperandTypes()) || + hasPtoMemRefMemorySpace(op->getResultTypes())) { + diagOS << "VPTO LLVM emission failed: residual PTO memref address space " + "on op " + << op->getName().getStringRef() << "\n"; + op->print(diagOS); + diagOS << "\n"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (leftover.wasInterrupted()) + return failure(); + return success(); +} + +void materializeVecScopeCarrierLoops(ModuleOp module) { + MLIRContext *ctx = module.getContext(); + (void)ctx->getOrLoadDialect(); + (void)ctx->getOrLoadDialect(); + ensureAIVScopeDummyDecl(module); + + SmallVector scopes; + module.walk([&](pto::VecScopeOp vecScope) { scopes.push_back(vecScope); }); + + IRRewriter rewriter(module.getContext()); + for (pto::VecScopeOp vecScope : llvm::reverse(scopes)) { + if (!vecScope || vecScope.getBody().empty()) + continue; + + rewriter.setInsertionPoint(vecScope); + auto loc = vecScope.getLoc(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); + + Block &vecScopeBody = vecScope.getBody().front(); + Block *carrierBody = carrier.getBody(); + Operation *yield = carrierBody->getTerminator(); + carrierBody->getOperations().splice(Block::iterator(yield), + vecScopeBody.getOperations(), + vecScopeBody.begin(), + vecScopeBody.end()); + rewriter.setInsertionPoint(yield); + rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, + ValueRange{}); + rewriter.eraseOp(vecScope); + } + + SmallVector strictScopes; + module.walk([&](pto::StrictVecScopeOp strictVecScope) { + strictScopes.push_back(strictVecScope); + }); + + for (pto::StrictVecScopeOp strictVecScope : llvm::reverse(strictScopes)) { + if (!strictVecScope || strictVecScope.getBody().empty()) + continue; + + rewriter.setInsertionPoint(strictVecScope); + auto loc = strictVecScope.getLoc(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); + + Block &strictBody = strictVecScope.getBody().front(); + Block *carrierBody = carrier.getBody(); + Operation *yield = carrierBody->getTerminator(); + + IRMapping mapping; + for (auto [blockArg, capture] : + llvm::zip(strictBody.getArguments(), strictVecScope.getCaptures())) + mapping.map(blockArg, capture); + + rewriter.setInsertionPoint(yield); + for (Operation &nested : strictBody.getOperations()) + rewriter.clone(nested, mapping); + rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, + ValueRange{}); + + rewriter.eraseOp(strictVecScope); + } +} + +LogicalResult attachAIVectorScopeMetadata(llvm::Module &llvmModule, + llvm::raw_ostream &diagOS) { + llvm::Function *dummyCallee = llvmModule.getFunction(kAIVScopeDummyCallee); + if (!dummyCallee) + return success(); + + for (llvm::Function &function : llvmModule) { + if (function.isDeclaration()) + continue; + llvm::DominatorTree dt(function); + llvm::LoopInfo loopInfo(dt); + + llvm::SmallVector dummyCalls; + for (llvm::BasicBlock &block : function) { + for (llvm::Instruction &inst : block) { + auto *call = dyn_cast(&inst); + if (call && call->getCalledFunction() == dummyCallee) + dummyCalls.push_back(call); + } + } + + for (llvm::CallInst *dummyCall : dummyCalls) { + llvm::BasicBlock *markedBlock = dummyCall->getParent(); + llvm::Loop *loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " + << function.getName() << " does not belong to an LLVM loop\n"; + return failure(); + } + + if (markedBlock == loop->getLoopLatch() && + dummyCall != markedBlock->getTerminator()) { + markedBlock->splitBasicBlock(dummyCall->getIterator(), "aivscope.latch"); + dt.recalculate(function); + loopInfo.releaseMemory(); + loopInfo.analyze(dt); + markedBlock = dummyCall->getParent(); + loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: split aivscope latch in " + << function.getName() + << " no longer belongs to an LLVM loop\n"; + return failure(); + } + } + + if (failed(ensureDummyPredForAIVectorScopeLatch(loop, diagOS))) + return failure(); + + dt.recalculate(function); + loopInfo.releaseMemory(); + loopInfo.analyze(dt); + loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " + << function.getName() + << " lost its loop after latch normalization\n"; + return failure(); + } + + llvm::BasicBlock *latch = loop->getLoopLatch(); + auto *branch = dyn_cast_or_null( + latch ? latch->getTerminator() : nullptr); + if (!branch || branch->isConditional()) { + diagOS << "VPTO LLVM emission failed: normalized aivscope loop in " + << function.getName() + << " does not have an unconditional latch backedge\n"; + return failure(); + } + + llvm::LLVMContext &ctx = llvmModule.getContext(); + llvm::Metadata *ops[] = { + nullptr, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, "llvm.loop.aivector_scope"))}; + auto *loopID = llvm::MDNode::getDistinct(ctx, ops); + loopID->replaceOperandWith(0, loopID); + branch->setMetadata(llvm::LLVMContext::MD_loop, loopID); + dummyCall->eraseFromParent(); + } + } + + if (dummyCallee->use_empty()) + dummyCallee->eraseFromParent(); + return success(); +} + +void attachHIVMKernelAnnotations(llvm::Module &llvmModule) { + llvm::NamedMDNode *annotations = + llvmModule.getOrInsertNamedMetadata("hivm.annotations"); + llvm::LLVMContext &ctx = llvmModule.getContext(); + llvm::Type *i32Ty = llvm::Type::getInt32Ty(ctx); + llvm::Constant *one = llvm::ConstantInt::get(i32Ty, 1); + + auto addAnnotation = [&](llvm::Function &function, llvm::StringRef kind) { + llvm::Metadata *ops[] = { + llvm::ValueAsMetadata::get(&function), + llvm::MDString::get(ctx, kind), + llvm::ConstantAsMetadata::get(one)}; + annotations->addOperand(llvm::MDNode::get(ctx, ops)); + }; + + for (llvm::Function &function : llvmModule) { + if (function.isDeclaration()) + continue; + if (function.getLinkage() != llvm::GlobalValue::ExternalLinkage) + continue; + + llvm::StringRef name = function.getName(); + if (name.contains(".extracted") || name.contains(".vector.thread")) + continue; + + addAnnotation(function, "kernel"); + addAnnotation(function, "kernel_with_simd"); + } +} + +LogicalResult +applyQueriedTargetAttrs(ModuleOp module, const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + FailureOr attrs = queryDefaultTargetAttrs(options, diagOS); + if (failed(attrs)) { + if (options.defaultTargetCPU.empty() || + options.defaultTargetFeatures.empty()) + return failure(); + diagOS << "VPTO LLVM emission: falling back to configured default target " + "attributes\n"; + attrs = QueriedTargetAttrs{options.defaultTargetCPU, + options.defaultTargetFeatures}; + } + + MLIRContext *ctx = module.getContext(); + StringAttr cpuAttr = StringAttr::get(ctx, attrs->targetCPU); + LLVM::TargetFeaturesAttr featureAttr = + LLVM::TargetFeaturesAttr::get(ctx, attrs->targetFeatures); + module.walk([&](LLVM::LLVMFuncOp funcOp) { + funcOp.setTargetCpuAttr(cpuAttr); + funcOp.setTargetFeaturesAttr(featureAttr); + }); + return success(); +} + +} // namespace mlir::pto From 227da17ef7e1bb42b402046506b218a6cd52bdf7 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sun, 12 Apr 2026 17:47:44 +0800 Subject: [PATCH 011/192] add online softmax (q * k is ready) case --- .../kernels/online-softmax-update/compare.py | 50 ++++++ .../kernels/online-softmax-update/golden.py | 65 +++++++ .../kernels/online-softmax-update/kernel.pto | 164 ++++++++++++++++++ .../kernels/online-softmax-update/launch.cpp | 56 ++++++ .../kernels/online-softmax-update/main.cpp | 153 ++++++++++++++++ .../kernels/online-softmax-update/stub.cpp | 23 +++ 6 files changed, 511 insertions(+) create mode 100644 test/vpto/cases/kernels/online-softmax-update/compare.py create mode 100644 test/vpto/cases/kernels/online-softmax-update/golden.py create mode 100644 test/vpto/cases/kernels/online-softmax-update/kernel.pto create mode 100644 test/vpto/cases/kernels/online-softmax-update/launch.cpp create mode 100644 test/vpto/cases/kernels/online-softmax-update/main.cpp create mode 100644 test/vpto/cases/kernels/online-softmax-update/stub.cpp diff --git a/test/vpto/cases/kernels/online-softmax-update/compare.py b/test/vpto/cases/kernels/online-softmax-update/compare.py new file mode 100644 index 000000000..e6af92b4a --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# case: kernels/online-softmax-update +# family: kernels +# target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +# scenarios: online-softmax-update, 16x128-f32, oldmax-oldsum-qk-to-newmax-newsum-expmax-out + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(abs_diff)) + print( + f"[ERROR] Mismatch: max diff={float(abs_diff[idx])} at idx={idx} " + f"(golden={float(golden[idx])}, out={float(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + ok = True + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_v5.bin", "v5.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_v6.bin", "v6.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_v7.bin", "v7.bin", np.float32, 1e-4) and ok + if not ok: + print("[ERROR] compare failed") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/kernels/online-softmax-update/golden.py b/test/vpto/cases/kernels/online-softmax-update/golden.py new file mode 100644 index 000000000..ea41425eb --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# case: kernels/online-softmax-update +# family: kernels +# target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +# scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 24 +COLS = 128 +SEED = 19 +SEQ = 73 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + seq = SEQ + oldmax = rng.uniform(-3.0, 1.5, size=(ROWS,)).astype(np.float32) + oldsum = rng.uniform(0.5, 4.0, size=(ROWS,)).astype(np.float32) + qk = rng.normal(loc=0.0, scale=1.5, size=(ROWS, COLS)).astype(np.float32) + + qk_active = qk[:, :seq] + qk_rowmax = np.max(qk_active, axis=1) + newmax = np.maximum(qk_rowmax, oldmax) + tmp_active = np.exp(qk_active - newmax[:, None], dtype=np.float32) + cursum = np.sum(tmp_active, axis=1, dtype=np.float32) + raw_expmax = np.exp(oldmax - newmax, dtype=np.float32) + newsum = raw_expmax * oldsum + cursum + expmax = (raw_expmax * oldsum) / newsum + out = np.zeros((ROWS, COLS), dtype=np.float32) + out[:, :seq] = tmp_active / newsum[:, None] + + zeros_state = np.zeros((ROWS,), dtype=np.float32) + zeros_out = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + oldmax.tofile(output_dir / "v1.bin") + oldsum.tofile(output_dir / "v2.bin") + qk.reshape(-1).tofile(output_dir / "v3.bin") + zeros_state.tofile(output_dir / "v4.bin") + zeros_state.tofile(output_dir / "v5.bin") + zeros_state.tofile(output_dir / "v6.bin") + zeros_out.reshape(-1).tofile(output_dir / "v7.bin") + np.array([seq], dtype=np.int32).tofile(output_dir / "v8.bin") + np.array([ROWS], dtype=np.int32).tofile(output_dir / "v9.bin") + newmax.tofile(output_dir / "golden_v4.bin") + newsum.tofile(output_dir / "golden_v5.bin") + expmax.tofile(output_dir / "golden_v6.bin") + out.astype(np.float32, copy=False).reshape(-1).tofile(output_dir / "golden_v7.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel.pto b/test/vpto/cases/kernels/online-softmax-update/kernel.pto new file mode 100644 index 000000000..9d49bc6cb --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/kernel.pto @@ -0,0 +1,164 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @online_softmax_update_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr, + %arg4: !pto.ptr, + %arg5: !pto.ptr, + %arg6: !pto.ptr, + %arg7: i32, + %arg8: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c8448_i64 = arith.constant 8448 : i64 + %c16640_i64 = arith.constant 16640 : i64 + %c16768_i64 = arith.constant 16768 : i64 + %c16896_i64 = arith.constant 16896 : i64 + + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %false = arith.constant false + + %block = pto.get_block_idx + %block_idx = arith.index_cast %block : i64 to index + %row_base = arith.muli %block_idx, %c8 : index + %qk_base = arith.muli %row_base, %c128 : index + %block_rows_i32 = arith.index_cast %c8 : index to i32 + %row_base_i32 = arith.index_cast %row_base : index to i32 + %remaining_rows = arith.subi %arg8, %row_base_i32 : i32 + %has_rows = arith.cmpi sgt, %remaining_rows, %c0_i32 : i32 + %too_many_rows = arith.cmpi sgt, %remaining_rows, %c8_i32 : i32 + %row_count_i32 = arith.select %too_many_rows, %c8_i32, %remaining_rows : i32 + %row_count = arith.index_cast %row_count_i32 : i32 to index + %row_count_i64 = arith.extui %row_count_i32 : i32 to i64 + %gm_oldmax = pto.addptr %arg0, %row_base : !pto.ptr -> !pto.ptr + %gm_oldsum = pto.addptr %arg1, %row_base : !pto.ptr -> !pto.ptr + %gm_qk = pto.addptr %arg2, %qk_base : !pto.ptr -> !pto.ptr + %gm_qk_hi = pto.addptr %gm_qk, %c64 : !pto.ptr -> !pto.ptr + %gm_newmax = pto.addptr %arg3, %row_base : !pto.ptr -> !pto.ptr + %gm_newsum = pto.addptr %arg4, %row_base : !pto.ptr -> !pto.ptr + %gm_expmax = pto.addptr %arg5, %row_base : !pto.ptr -> !pto.ptr + %gm_out = pto.addptr %arg6, %qk_base : !pto.ptr -> !pto.ptr + + %ub_oldmax = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_oldsum = pto.castptr %c128_i64 : i64 -> !pto.ptr + %ub_qk = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_qk_hi = pto.addptr %ub_qk, %c64 : !pto.ptr -> !pto.ptr + %ub_out = pto.castptr %c8448_i64 : i64 -> !pto.ptr + %ub_newmax = pto.castptr %c16640_i64 : i64 -> !pto.ptr + %ub_newsum = pto.castptr %c16768_i64 : i64 -> !pto.ptr + %ub_expmax = pto.castptr %c16896_i64 : i64 -> !pto.ptr + + scf.if %has_rows { + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_oldmax, %ub_oldmax, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_oldsum, %ub_oldsum, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_qk, %ub_qk, %c0_i64, %row_count_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c512_i64, %c512_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_qk_hi, %ub_qk_hi, %c0_i64, %row_count_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c512_i64, %c512_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + scf.for %row = %c0 to %row_count step %c1 { + %row_qk = arith.muli %row, %c128 : index + %oldmax_bc = pto.vlds %ub_oldmax[%row] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + %oldsum_bc = pto.vlds %ub_oldsum[%row] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + + %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + %next_max, %next_sum = scf.if %has_chunk -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdiff %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdiff %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.yield %merged_max, %merged_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + scf.yield %running_max, %running_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + %raw_expmax = pto.vexpdiff %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + %zero = pto.vsub %final_max, %final_max, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.for %chunk = %c0 to %c128 step %c64 { + %chunk_base = arith.addi %row_qk, %chunk : index + pto.vsts %zero, %ub_out[%chunk_base], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + + scf.for %chunk = %c0 to %c128 step %c64 { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + scf.if %has_chunk { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %exp = pto.vexpdiff %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_newmax, %gm_newmax, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_newsum, %gm_newsum, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_expmax, %gm_expmax, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %row_count_i64, %c512_i64, %c0_i64, %c512_i64, %c512_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/kernels/online-softmax-update/launch.cpp b/test/vpto/cases/kernels/online-softmax-update/launch.cpp new file mode 100644 index 000000000..5cf6c4e2f --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/launch.cpp @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#include +#include + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +namespace pto { +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +} // namespace pto +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ AICORE void online_softmax_update_kernel_2d( + __gm__ float *v1, __gm__ float *v2, __gm__ float *v3, + __gm__ float *v4, __gm__ float *v5, __gm__ float *v6, + __gm__ float *v7, int32_t v8, int32_t v9); + +void LaunchOnline_softmax_update_kernel_2d(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream) { + const int32_t blockRows = 8; + const int32_t blocks = (v9 + blockRows - 1) / blockRows; + online_softmax_update_kernel_2d<<>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ float *)v3, + (__gm__ float *)v4, (__gm__ float *)v5, (__gm__ float *)v6, + (__gm__ float *)v7, v8, v9); +} diff --git a/test/vpto/cases/kernels/online-softmax-update/main.cpp b/test/vpto/cases/kernels/online-softmax-update/main.cpp new file mode 100644 index 000000000..6282f13a8 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/main.cpp @@ -0,0 +1,153 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +namespace pto { +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +} // namespace pto +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchOnline_softmax_update_kernel_2d(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream); + +int main() { + constexpr size_t elemCountSeq = 1; + constexpr size_t elemCountRows = 1; + size_t fileSizeSeq = elemCountSeq * sizeof(int32_t); + size_t fileSizeRows = elemCountRows * sizeof(int32_t); + size_t elemCountState = 0; + size_t elemCountOut = 0; + size_t fileSizeState = 0; + size_t fileSizeOut = 0; + float *v1Host = nullptr, *v2Host = nullptr, *v3Host = nullptr; + float *v4Host = nullptr, *v5Host = nullptr, *v6Host = nullptr; + float *v7Host = nullptr; + float *v1Device = nullptr, *v2Device = nullptr, *v3Device = nullptr; + float *v4Device = nullptr, *v5Device = nullptr, *v6Device = nullptr; + float *v7Device = nullptr; + int32_t v8Host = 0, v9Host = 0; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ReadFile("./v8.bin", fileSizeSeq, &v8Host, fileSizeSeq); + ReadFile("./v9.bin", fileSizeRows, &v9Host, fileSizeRows); + + elemCountState = static_cast(v9Host); + elemCountOut = static_cast(v9Host) * 128; + fileSizeState = elemCountState * sizeof(float); + fileSizeOut = elemCountOut * sizeof(float); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v5Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v6Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v7Host), fileSizeOut)); + + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v5Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v6Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v7Device, fileSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSizeState, v1Host, fileSizeState); + ReadFile("./v2.bin", fileSizeState, v2Host, fileSizeState); + ReadFile("./v3.bin", fileSizeOut, v3Host, fileSizeOut); + ReadFile("./v4.bin", fileSizeState, v4Host, fileSizeState); + ReadFile("./v5.bin", fileSizeState, v5Host, fileSizeState); + ReadFile("./v6.bin", fileSizeState, v6Host, fileSizeState); + ReadFile("./v7.bin", fileSizeOut, v7Host, fileSizeOut); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSizeState, v1Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSizeState, v2Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSizeOut, v3Host, fileSizeOut, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSizeState, v4Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v5Device, fileSizeState, v5Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v6Device, fileSizeState, v6Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v7Device, fileSizeOut, v7Host, fileSizeOut, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchOnline_softmax_update_kernel_2d(v1Device, v2Device, v3Device, + v4Device, v5Device, v6Device, + v7Device, v8Host, v9Host, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSizeState, v4Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v5Host, fileSizeState, v5Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v6Host, fileSizeState, v6Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v7Host, fileSizeOut, v7Device, fileSizeOut, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", v4Host, fileSizeState); + WriteFile("./v5.bin", v5Host, fileSizeState); + WriteFile("./v6.bin", v6Host, fileSizeState); + WriteFile("./v7.bin", v7Host, fileSizeOut); + +cleanup: + aclrtFree(v1Device); aclrtFree(v2Device); aclrtFree(v3Device); + aclrtFree(v4Device); aclrtFree(v5Device); aclrtFree(v6Device); aclrtFree(v7Device); + aclrtFreeHost(v1Host); aclrtFreeHost(v2Host); aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); aclrtFreeHost(v5Host); aclrtFreeHost(v6Host); aclrtFreeHost(v7Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + + return rc; +} diff --git a/test/vpto/cases/kernels/online-softmax-update/stub.cpp b/test/vpto/cases/kernels/online-softmax-update/stub.cpp new file mode 100644 index 000000000..003519801 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/stub.cpp @@ -0,0 +1,23 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ AICORE void online_softmax_update_kernel_2d( + __gm__ float *v1, __gm__ float *v2, __gm__ float *v3, + __gm__ float *v4, __gm__ float *v5, __gm__ float *v6, + __gm__ float *v7, int32_t v8, int32_t v9) { + (void)v1; (void)v2; (void)v3; (void)v4; + (void)v5; (void)v6; (void)v7; (void)v8; (void)v9; +} From dbfe96bece90f9ed1ec9b9b4593392f088fbcd79 Mon Sep 17 00:00:00 2001 From: WenboCodes Date: Sun, 12 Apr 2026 19:19:06 +0800 Subject: [PATCH 012/192] docs: add VPTO spec v0.3 release draft Add the merged v0.3 PTO micro-instruction release spec document for A5, including ISA group references and updated synchronization notes. Co-Authored-By: Claude Sonnet 4.6 --- docs/release/vpto-spec-v0.3.md | 5349 ++++++++++++++++++++++++++++++++ 1 file changed, 5349 insertions(+) create mode 100644 docs/release/vpto-spec-v0.3.md diff --git a/docs/release/vpto-spec-v0.3.md b/docs/release/vpto-spec-v0.3.md new file mode 100644 index 000000000..8de281795 --- /dev/null +++ b/docs/release/vpto-spec-v0.3.md @@ -0,0 +1,5349 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.3: Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +##### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +##### 2. No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +##### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +##### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +##### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | +| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | +| `INTLV` | `b8`, `b16`, `b32` | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +##### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +##### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +#### Movement + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. `%result` + uses an integer element type, and the scalar `%index` type matches that + result element type. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. This is typically used in even/odd placement forms such +as `32 -> 16` or `16 -> 32` style conversions. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | + +--- + +##### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +###### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +###### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +##### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | | Y | | + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### Sorting Operations + +##### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +##### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | From a77c8b2adab6795f6b4a9ff6dcb9c6a8cf0bd975 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sun, 12 Apr 2026 22:45:20 +0800 Subject: [PATCH 013/192] add PTO-Gym submodule --- .gitmodules | 3 +++ 3rdparty/PTO-Gym | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 3rdparty/PTO-Gym diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..9ae183956 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "3rdparty/PTO-Gym"] + path = 3rdparty/PTO-Gym + url = git@github.com:PTO-ISA/PTO-Gym.git diff --git a/3rdparty/PTO-Gym b/3rdparty/PTO-Gym new file mode 160000 index 000000000..8a186eae3 --- /dev/null +++ b/3rdparty/PTO-Gym @@ -0,0 +1 @@ +Subproject commit 8a186eae3befc4f1417f4618addbd9e942339acd From 949a9286a16cb9f8fe5bfb3408b18a9e8e29af12 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sun, 12 Apr 2026 23:01:22 +0800 Subject: [PATCH 014/192] feat: add PTO-Gym guide skill --- .../skills/pto-gym-vpto-validation/SKILL.md | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 .codex/skills/pto-gym-vpto-validation/SKILL.md diff --git a/.codex/skills/pto-gym-vpto-validation/SKILL.md b/.codex/skills/pto-gym-vpto-validation/SKILL.md new file mode 100644 index 000000000..0e1451a61 --- /dev/null +++ b/.codex/skills/pto-gym-vpto-validation/SKILL.md @@ -0,0 +1,85 @@ +--- +name: pto-gym-vpto-validation +description: Run PTO-Gym validation from this PTOAS repo. Use when the user asks to run PTO-Gym SIM or board validation from the current source tree. Always force PTOAS onto the VPTO LLVM path instead of relying on the repo default backend. +--- + +# PTO-Gym VPTO Validation + +Use this skill when the task is specifically about: +- running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation.sh` +- running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation_parallel.sh` +- validating PTO-Gym cases from this PTOAS source tree + +## Required Rule + +When PTO-Gym is run from this repo, do not rely on the default PTOAS backend. + +Always pass PTOAS flags that force the VPTO LLVM path. +The current `ptoas` CLI spellings in this repo are `--pto-backend=vpto` and +`--vpto-emit-hivm-llvm`; do not shorten `--pto-backend` to `--backend`. + +Use: + +```bash +PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +``` + +If the caller already provides `PTOAS_FLAGS`, make sure these options are still +present. Do not silently fall back to the repo default backend. + +## Canonical Environment + +Use `.work/` under the repo for all scratch output and temp files: + +```bash +mkdir -p .work/tmp .work/runs +export TMPDIR=$PWD/.work/tmp +export TMP=$TMPDIR +export TEMP=$TMPDIR +``` + +Typical simulator environment: + +```bash +source /home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2/set_env.sh +export ASCEND_HOME_PATH=/home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2 +export PTOAS_BIN=$PWD/build/tools/ptoas/ptoas +export PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +``` + +## Canonical Commands + +Single case: + +```bash +WORK_SPACE=$PWD/.work/runs/pto-gym-single \ +ASCEND_HOME_PATH=$ASCEND_HOME_PATH \ +PTOAS_BIN=$PTOAS_BIN \ +PTOAS_FLAGS="$PTOAS_FLAGS" \ +CASE_NAME=micro-op/binary-vector/vadd \ +DEVICE=SIM \ +bash 3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation.sh +``` + +Parallel micro-op sweep: + +```bash +WORK_SPACE=$PWD/.work/runs/pto-gym-microop \ +ASCEND_HOME_PATH=$ASCEND_HOME_PATH \ +PTOAS_BIN=$PTOAS_BIN \ +PTOAS_FLAGS="$PTOAS_FLAGS" \ +CASE_PREFIX=micro-op \ +DEVICE=SIM \ +JOBS=64 \ +bash 3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation_parallel.sh +``` + +## Reporting Back + +Report: +- the exact `PTOAS_FLAGS` used +- the final `PASS/FAIL` counts +- the summary file path under `.work/runs/...` + +If a run fails, identify the first failing case from `parallel-summary.tsv` and +then inspect that case directory under `WORK_SPACE`. From 495a02c1a951b7a72039e2f99cb3faf07ed38399 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 3 Apr 2026 13:49:00 +0800 Subject: [PATCH 015/192] Add tilelang dsl implementation --- tilelang-dsl/CMakeLists.txt | 27 + tilelang-dsl/README.md | 19 + tilelang-dsl/docs/README.md | 12 + tilelang-dsl/docs/v1-surface.md | 235 +++++++ tilelang-dsl/examples/README.md | 16 + tilelang-dsl/examples/v1_emit_mlir_demo.py | 60 ++ tilelang-dsl/python/README.md | 3 + tilelang-dsl/python/tilelang_dsl/__init__.py | 59 ++ tilelang-dsl/python/tilelang_dsl/kernel.py | 643 +++++++++++++++++++ tilelang-dsl/python/tilelang_dsl/types.py | 103 +++ tilelang-dsl/tests/README.md | 5 + tilelang-dsl/tests/import_tilelang_dsl.py | 12 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 153 +++++ 13 files changed, 1347 insertions(+) create mode 100644 tilelang-dsl/CMakeLists.txt create mode 100644 tilelang-dsl/README.md create mode 100644 tilelang-dsl/docs/README.md create mode 100644 tilelang-dsl/docs/v1-surface.md create mode 100644 tilelang-dsl/examples/README.md create mode 100644 tilelang-dsl/examples/v1_emit_mlir_demo.py create mode 100644 tilelang-dsl/python/README.md create mode 100644 tilelang-dsl/python/tilelang_dsl/__init__.py create mode 100644 tilelang-dsl/python/tilelang_dsl/kernel.py create mode 100644 tilelang-dsl/python/tilelang_dsl/types.py create mode 100644 tilelang-dsl/tests/README.md create mode 100644 tilelang-dsl/tests/import_tilelang_dsl.py create mode 100644 tilelang-dsl/tests/test_tilelang_dsl_v1.py diff --git a/tilelang-dsl/CMakeLists.txt b/tilelang-dsl/CMakeLists.txt new file mode 100644 index 000000000..445d920b2 --- /dev/null +++ b/tilelang-dsl/CMakeLists.txt @@ -0,0 +1,27 @@ +# ========================================================= +# TileLang DSL package wiring +# ========================================================= + +set(TILELANG_DSL_PACKAGE_SRC_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/python/tilelang_dsl") +set(TILELANG_DSL_BUILD_ROOT "${CMAKE_BINARY_DIR}/python") +set(TILELANG_DSL_BUILD_PACKAGE_DIR + "${TILELANG_DSL_BUILD_ROOT}/tilelang_dsl") + +add_custom_target(TileLangDSLPackage ALL + COMMAND ${CMAKE_COMMAND} -E make_directory "${TILELANG_DSL_BUILD_ROOT}" + COMMAND ${CMAKE_COMMAND} -E remove_directory "${TILELANG_DSL_BUILD_PACKAGE_DIR}" + COMMAND ${CMAKE_COMMAND} -E copy_directory + "${TILELANG_DSL_PACKAGE_SRC_DIR}" + "${TILELANG_DSL_BUILD_PACKAGE_DIR}" + COMMENT "Staging tilelang_dsl package into build/python" + VERBATIM +) + +install( + DIRECTORY "${TILELANG_DSL_PACKAGE_SRC_DIR}" + DESTINATION "." + COMPONENT PTOAS_Runtime + PATTERN "__pycache__" EXCLUDE + PATTERN "*.pyc" EXCLUDE +) diff --git a/tilelang-dsl/README.md b/tilelang-dsl/README.md new file mode 100644 index 000000000..1bf15d1f8 --- /dev/null +++ b/tilelang-dsl/README.md @@ -0,0 +1,19 @@ +TileLang DSL v1 lives under this directory. + +This subtree is the source of truth for the new frontend introduced by +`add-tilelang-dsl-core-foundation`. + +Boundary with the existing `python/pto/dialects/pto.py` module: +- `tilelang-dsl/` owns new TileLang DSL v1 core implementation work +- `python/pto/dialects/pto.py` keeps PTO dialect bindings and the legacy + experimental VPTO Python DSL surface +- Root-level wiring into build/install/test is allowed, but TileLang DSL core + logic must not move back into `python/pto/dialects/pto.py` + +Layout: +- `python/tilelang_dsl/`: package sources +- `tests/`: TileLang DSL focused tests +- `examples/`: self-contained examples +- `docs/`: local documentation for this frontend + +Root-level wiring belongs to follow-up tasks and must stay minimal. diff --git a/tilelang-dsl/docs/README.md b/tilelang-dsl/docs/README.md new file mode 100644 index 000000000..357cc14a2 --- /dev/null +++ b/tilelang-dsl/docs/README.md @@ -0,0 +1,12 @@ +TileLang DSL local documentation lives here. + +Current docs: +- `v1-surface.md`: the TileLang DSL v1 contract implemented by + `add-tilelang-dsl-core-foundation` + +Documentation boundary: +- `tilelang-dsl/docs/` is the local documentation source of truth for the new + `tilelang_dsl` frontend +- repository-level docs may link here, but should not redefine this package's + implemented v1 boundary +- `python/pto/dialects/pto.py` is not the source of truth for TileLang DSL v1 diff --git a/tilelang-dsl/docs/v1-surface.md b/tilelang-dsl/docs/v1-surface.md new file mode 100644 index 000000000..225b566cc --- /dev/null +++ b/tilelang-dsl/docs/v1-surface.md @@ -0,0 +1,235 @@ +# TileLang DSL v1 Surface + +## Scope + +This document records the implemented v1 boundary for the standalone +`tilelang_dsl` package introduced by +`add-tilelang-dsl-core-foundation`. + +It covers: +- package entrypoints +- supported `@vkernel` decorator metadata +- parameter typing rules +- Tile specialization requirements +- current frontend diagnostics boundary +- deferred features that belong to follow-up changes + +It does not define: +- DSL to VPTO lowering details +- matcher and priority semantics +- advanced vector-family surface +- implicit vecscope inference + +## Source Of Truth + +TileLang DSL v1 source of truth lives under: +- `tilelang-dsl/python/tilelang_dsl/` +- `tilelang-dsl/tests/` +- `tilelang-dsl/examples/` +- `tilelang-dsl/docs/` + +`python/pto/dialects/pto.py` is not the source of truth for TileLang DSL v1. +That file still exists for PTO dialect bindings and the legacy experimental VPTO +Python DSL surface. Root-level wiring into build, install, and test is allowed, +but new TileLang DSL core behavior must land under `tilelang-dsl/`. + +## Package Entry + +Examples and tests should import the standalone package: + +```python +import tilelang_dsl as pto +``` + +The package currently exports: +- `vkernel` +- `VKernelDescriptor` +- `BoundKernelParameter` +- `MaterializedMLIRModule` +- `TileLangFrontendError` +- `TensorView` +- `Tile` +- scalar dtypes such as `f16`, `bf16`, `f32`, `i8`, `i16`, `i32`, `i64` +- Tile specialization helpers: `MemorySpace`, `TileConfig`, `TileSpecialization` + +## v1 Decorator Surface + +The supported v1 decorator surface is: + +```python +@pto.vkernel( + target="a5", + op="some_op_name", + dtypes=[(pto.f32, pto.f16, pto.i32)], + name="optional_name", + verify=True, +) +def kernel(...): + ... +``` + +Current rules: +- `target` only accepts `"a5"` +- `op` is required and must be a non-empty string +- `dtypes` must contain exactly one monomorphic signature tuple +- `name` is optional and defaults to the Python function name +- `verify` is optional and must be a bool + +The descriptor keeps these metadata fields: +- `target` +- `op` +- `dtypes` +- `name` +- `verify` + +## Parameter Typing + +v1 accepts these parameter categories: +- bare `TensorView` +- bare `Tile` +- explicit scalar annotations such as `pto.i32`, `pto.f16`, `pto.f32` + +Binding rules: +- the single `dtypes` signature binds parameter element types positionally +- `TensorView` parameters get their element dtype from the same position in + `dtypes` +- `Tile` parameters get their element dtype from the same position in `dtypes` +- scalar parameters must use an explicit scalar annotation +- scalar annotations must exactly match the dtype at the same position in + `dtypes` + +Example: + +```python +@pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.bf16, pto.i32)]) +def kernel(inp: pto.TensorView, tmp: pto.Tile, scale: pto.i32): + return None +``` + +In this example: +- `inp` binds to `f32` +- `tmp` binds to `bf16` +- `scale` binds to `i32` + +## Tile Specialization + +Bare `Tile` parameters are incomplete until descriptor-level specialization is +provided. + +The only supported completion path is: + +```python +specialized = descriptor.specialize( + tmp=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping({"layout": "row_major"}), + ) +) +``` + +Current v1 Tile profile rules: +- Tile physical shape must be static +- Tile dimensions must be positive integers +- Tile rank must be 1D or 2D +- Tile memory space must be `MemorySpace.UB` +- `config` may be omitted, provided as `TileConfig`, or built from a dict + +Before all bare `Tile` parameters are specialized, the descriptor must reject: +- `mlir_text()` +- `mlir_module()` +- `verify()` +- `emit(path)` + +## Materialization API + +After all bare `Tile` parameters are specialized, the descriptor exposes: +- `mlir_text()` +- `mlir_module()` +- `verify()` +- `emit(path)` + +At this stage of the workflow, these APIs provide a stable descriptor/materialization +surface for the new package. They do not yet define the final TileLang DSL to +VPTO lowering behavior; that work belongs to +`add-tilelang-dsl-authoring-vpto-lowering`. + +## Frontend Diagnostics + +The v1 frontend fails fast for: +- unsupported decorator matcher features +- unsupported Python syntax +- arbitrary external calls +- unsupported `pto.*` op surface +- missing Tile specialization +- dynamic physical Tile shape +- illegal Tile profile + +Diagnostics are frontend errors, not deferred verifier failures. When source is +available, errors include file, line, and column information. + +## Minimal Validation + +The following commands are the minimal validation set for +`add-tilelang-dsl-core-foundation`: + +```bash +cmake --build build --target TileLangDSLPackage +python3 -c "import sys; sys.path.insert(0, 'build/python'); import tilelang_dsl; print(tilelang_dsl.__file__)" +ctest --test-dir build -R tilelang_dsl_import --output-on-failure +ctest --test-dir build -R tilelang_dsl_unittest --output-on-failure +``` + +What these commands confirm: +- the standalone `tilelang_dsl` package is staged into `build/python/` +- Python can import the staged package directly +- the dedicated import smoke test passes +- the focused unittest suite passes for descriptor API, specialization, and + diagnostics coverage + +For a direct source-location diagnostics smoke, run: + +```bash +tmp=$(mktemp /tmp/tilelang_dsl_diag_XXXX.py) +cat > "$tmp" <<'PY' +import tilelang_dsl as pto + +try: + @pto.vkernel(op="x", dtypes=[(pto.f32,)]) + def kernel(x: pto.TensorView): + while True: + return None +except pto.TileLangFrontendError as exc: + print(exc) +PY +PYTHONPATH=build/python python3 "$tmp" +rm -f "$tmp" +``` + +Expected output shape: + +```text +/tmp/tilelang_dsl_diag_XXXX.py:6:5: unsupported Python syntax `while` in TileLang DSL v1 +``` + +This confirms diagnostics are emitted against the authored DSL source file +rather than an internal lowering location. + +## Deferred Features + +The following are intentionally out of scope for v1 and belong to follow-up +changes: +- multiple `dtypes` signatures +- `constraints` +- `priority` +- `AnyFloat`, `AnyInt`, `AnyType`, `AnyMask` +- `TypeVar` +- matcher registry and deterministic selection +- implicit vecscope inference +- raw pointer authoring surface +- advanced vector-family support +- final TileLang DSL to VPTO lowering implementation + +Matcher-related extensions are deferred to +`extend-tilelang-dsl-matcher-and-advanced-surface`. +Lowering work is deferred to `add-tilelang-dsl-authoring-vpto-lowering`. diff --git a/tilelang-dsl/examples/README.md b/tilelang-dsl/examples/README.md new file mode 100644 index 000000000..1c9aac4d7 --- /dev/null +++ b/tilelang-dsl/examples/README.md @@ -0,0 +1,16 @@ +TileLang DSL examples live here. + +Examples in this subtree should import `tilelang_dsl` as their package +entrypoint once the package wiring is added. + +Current example: +- `v1_emit_mlir_demo.py`: define a v1 `@pto.vkernel`, specialize a bare + `Tile`, and materialize the result as MLIR text or an `.mlir` file + +Typical usage from the repository root: + +```bash +cmake --build build --target TileLangDSLPackage +python3 tilelang-dsl/examples/v1_emit_mlir_demo.py +python3 tilelang-dsl/examples/v1_emit_mlir_demo.py /tmp/tilelang_demo.mlir +``` diff --git a/tilelang-dsl/examples/v1_emit_mlir_demo.py b/tilelang-dsl/examples/v1_emit_mlir_demo.py new file mode 100644 index 000000000..770bf12ea --- /dev/null +++ b/tilelang-dsl/examples/v1_emit_mlir_demo.py @@ -0,0 +1,60 @@ +"""Minimal TileLang DSL v1 demo that materializes a kernel into MLIR.""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + try: + import tilelang_dsl as pto + + return pto + except ModuleNotFoundError: + repo_root = Path(__file__).resolve().parents[2] + sys.path.insert(0, str(repo_root / "build" / "python")) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="eltwise_with_tile", + dtypes=[(pto.f32, pto.f16, pto.i32)], + name="tilelang_v1_demo_kernel", +) +def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + return None + + +def build_specialized_kernel(): + return kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping({"layout": "row_major"}), + ) + ) + + +def main(argv: list[str]) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/python/README.md b/tilelang-dsl/python/README.md new file mode 100644 index 000000000..39272be3d --- /dev/null +++ b/tilelang-dsl/python/README.md @@ -0,0 +1,3 @@ +This directory hosts the TileLang DSL Python package sources. + +The package root is `tilelang_dsl/`. diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py new file mode 100644 index 000000000..19ad88cd6 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -0,0 +1,59 @@ +"""TileLang DSL v1 package.""" + +from .kernel import ( + BoundKernelParameter, + MaterializedMLIRModule, + TileLangFrontendError, + VKernelDescriptor, + vkernel, +) +from .types import ( + AnyFloat, + AnyInt, + AnyMask, + AnyType, + MemorySpace, + ScalarType, + TensorView, + Tile, + TileConfig, + TileSpecialization, + TypeVar, + TypeVariable, + WildcardType, + bf16, + f16, + f32, + i8, + i16, + i32, + i64, +) + +__all__ = [ + "BoundKernelParameter", + "MaterializedMLIRModule", + "TileLangFrontendError", + "VKernelDescriptor", + "vkernel", + "ScalarType", + "WildcardType", + "TypeVariable", + "TypeVar", + "TensorView", + "Tile", + "MemorySpace", + "TileConfig", + "TileSpecialization", + "i8", + "i16", + "i32", + "i64", + "f16", + "bf16", + "f32", + "AnyFloat", + "AnyInt", + "AnyType", + "AnyMask", +] diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py new file mode 100644 index 000000000..621b19e80 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -0,0 +1,643 @@ +"""Kernel descriptor surface for TileLang DSL v1.""" + +from __future__ import annotations + +import inspect +import re +import textwrap +import ast +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable + +from .types import ( + MemorySpace, + ScalarType, + TensorView, + Tile, + TileConfig, + TileSpecialization, + TypeVariable, + WildcardType, +) + + +_UNSET = object() +_MATCHER_FOLLOW_UP_CHANGE = "extend-tilelang-dsl-matcher-and-advanced-surface" + + +def _unsupported_feature_message(feature: str) -> str: + return ( + f"{feature} is not supported in TileLang DSL v1; " + f"see follow-up change `{_MATCHER_FOLLOW_UP_CHANGE}`" + ) + + +def _reject_unsupported_decorator_feature(name: str, value: Any) -> None: + if value is _UNSET: + return + raise ValueError(_unsupported_feature_message(f"decorator feature `{name}`")) + + +def _reject_unsupported_dtype_feature(dtype: Any) -> None: + if isinstance(dtype, WildcardType): + raise ValueError( + _unsupported_feature_message(f"dtype wildcard `{dtype.name}`") + ) + if isinstance(dtype, TypeVariable): + raise ValueError( + _unsupported_feature_message(f"dtype type variable `{dtype.name}`") + ) + + +class TileLangFrontendError(ValueError): + """Source-located frontend diagnostic for TileLang DSL.""" + + def __init__(self, path: str, line: int, column: int, message: str): + self.path = path + self.line = line + self.column = column + self.message = message + super().__init__(f"{path}:{line}:{column}: {message}") + + +@dataclass(frozen=True) +class _FunctionSourceInfo: + path: str + start_line: int + function_def: ast.FunctionDef + + def location(self, node: ast.AST) -> tuple[int, int]: + line = self.start_line + getattr(node, "lineno", 1) - 1 + column = getattr(node, "col_offset", 0) + 1 + return line, column + + def error(self, node: ast.AST, message: str) -> TileLangFrontendError: + line, column = self.location(node) + return TileLangFrontendError(self.path, line, column, message) + + def parameter_node(self, param_name: str) -> ast.AST | None: + for arg in self.function_def.args.args: + if arg.arg == param_name: + return arg.annotation or arg + return None + + +class _KernelBodyValidator(ast.NodeVisitor): + def __init__(self, source_info: _FunctionSourceInfo): + self.source_info = source_info + + def validate(self) -> None: + for stmt in self.source_info.function_def.body: + self.visit(stmt) + + def visit_While(self, node: ast.While) -> None: + raise self.source_info.error(node, "unsupported Python syntax `while` in TileLang DSL v1") + + def visit_ListComp(self, node: ast.ListComp) -> None: + raise self.source_info.error( + node, "unsupported Python syntax `list comprehension` in TileLang DSL v1" + ) + + def visit_DictComp(self, node: ast.DictComp) -> None: + raise self.source_info.error( + node, "unsupported Python syntax `dict comprehension` in TileLang DSL v1" + ) + + def visit_SetComp(self, node: ast.SetComp) -> None: + raise self.source_info.error( + node, "unsupported Python syntax `set comprehension` in TileLang DSL v1" + ) + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None: + raise self.source_info.error( + node, "unsupported Python syntax `generator expression` in TileLang DSL v1" + ) + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.value.id == "pto": + raise self.source_info.error( + node, + f"unsupported op surface `pto.{node.func.attr}` in TileLang DSL v1", + ) + raise self.source_info.error( + node, + f"arbitrary external call `{node.func.value.id}.{node.func.attr}` is not supported " + "in TileLang DSL v1", + ) + + if isinstance(node.func, ast.Name): + raise self.source_info.error( + node, + f"arbitrary external call `{node.func.id}` is not supported in TileLang DSL v1", + ) + + raise self.source_info.error( + node, + "unsupported call surface in TileLang DSL v1", + ) + + +def _load_function_source_info(py_fn: Callable[..., Any]) -> _FunctionSourceInfo | None: + try: + source_lines, start_line = inspect.getsourcelines(py_fn) + path = inspect.getsourcefile(py_fn) or inspect.getfile(py_fn) + except (OSError, IOError, TypeError): + return None + + source = textwrap.dedent("".join(source_lines)) + module = ast.parse(source) + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == py_fn.__name__: + return _FunctionSourceInfo(path=path, start_line=start_line, function_def=node) + return None + + +def _validate_function_body(source_info: _FunctionSourceInfo | None) -> None: + if source_info is None: + return + _KernelBodyValidator(source_info).validate() + + +def _raise_tile_param_error( + source_info: _FunctionSourceInfo | None, + param_name: str, + message: str, + fallback_exception: type[Exception] = ValueError, +) -> None: + if source_info is not None: + node = source_info.parameter_node(param_name) + if node is not None: + raise source_info.error(node, message) + raise fallback_exception(message) + + +def _freeze_dtypes(dtypes: Any) -> tuple[tuple[Any, ...], ...]: + if not isinstance(dtypes, (list, tuple)): + raise TypeError("dtypes must be a sequence of signature tuples") + + frozen_signatures = [] + for signature in dtypes: + if not isinstance(signature, (list, tuple)): + raise TypeError("each dtypes entry must be a signature tuple") + frozen_signature = tuple(signature) + for dtype in frozen_signature: + _reject_unsupported_dtype_feature(dtype) + frozen_signatures.append(frozen_signature) + + if not frozen_signatures: + raise ValueError("dtypes must contain at least one signature tuple") + + if len(frozen_signatures) != 1: + raise ValueError( + _unsupported_feature_message("multiple dtypes signatures") + ) + + return tuple(frozen_signatures) + + +@dataclass(frozen=True) +class BoundKernelParameter: + """One parameter after v1 monomorphic dtype binding.""" + + name: str + kind: str + annotation: Any + dtype: ScalarType + + @property + def element_dtype(self) -> ScalarType | None: + if self.kind in ("tensorview", "tile"): + return self.dtype + return None + + +@dataclass(frozen=True) +class VKernelDescriptor: + """Descriptor returned by `@tilelang_dsl.vkernel`.""" + + target: str + op: str + dtypes: tuple[tuple[Any, ...], ...] + name: str + verify_enabled: bool + parameters: tuple[BoundKernelParameter, ...] + _py_fn: Callable[..., Any] = field(repr=False) + _source_info: _FunctionSourceInfo | None = field(repr=False, compare=False, default=None) + specializations: tuple[tuple[str, TileSpecialization], ...] = () + + @property + def py_fn(self) -> Callable[..., Any]: + return self._py_fn + + @property + def dtype_signature(self) -> tuple[ScalarType, ...]: + return self.dtypes[0] + + @property + def metadata(self) -> dict[str, Any]: + return { + "target": self.target, + "op": self.op, + "dtypes": self.dtypes, + "name": self.name, + "verify": self.verify_enabled, + } + + @property + def tile_parameters(self) -> tuple[BoundKernelParameter, ...]: + return tuple(param for param in self.parameters if param.kind == "tile") + + @property + def specializations_by_name(self) -> dict[str, TileSpecialization]: + return dict(self.specializations) + + def specialize(self, **bindings: Any) -> "VKernelDescriptor": + tile_params = {param.name: param for param in self.tile_parameters} + if not tile_params: + if bindings: + unknown = ", ".join(sorted(bindings)) + raise TypeError( + f"specialize() received bindings for non-Tile parameters: {unknown}" + ) + return self + + unknown = sorted(set(bindings) - set(tile_params)) + if unknown: + unknown_names = ", ".join(unknown) + raise TypeError( + f"specialize() only accepts bare Tile parameters; got: {unknown_names}" + ) + + updated = self.specializations_by_name + for name, binding in bindings.items(): + updated[name] = _coerce_tile_specialization(name, binding, self._source_info) + + return VKernelDescriptor( + target=self.target, + op=self.op, + dtypes=self.dtypes, + name=self.name, + verify_enabled=self.verify_enabled, + parameters=self.parameters, + _source_info=self._source_info, + specializations=tuple(sorted(updated.items())), + _py_fn=self._py_fn, + ) + + def _require_specialized_tiles(self, api_name: str) -> None: + tile_names = [param.name for param in self.tile_parameters] + if not tile_names: + return + + specialized = self.specializations_by_name + missing = [name for name in tile_names if name not in specialized] + if missing: + missing_names = ", ".join(missing) + _raise_tile_param_error( + self._source_info, + missing[0], + f"{api_name}() requires specialize() bindings for bare Tile parameters: " + f"{missing_names}", + ) + + def _format_symbol_name(self) -> str: + if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$.]*", self.name): + return f"@{self.name}" + escaped = self.name.replace("\\", "\\\\").replace('"', '\\"') + return f'@"{escaped}"' + + def mlir_text(self) -> str: + self._require_specialized_tiles("mlir_text") + + lines = [ + f"// tilelang.target = {self.target}", + f"// tilelang.op = {self.op}", + f"// tilelang.dtypes = {self.dtypes}", + f"// tilelang.verify = {self.verify_enabled}", + ] + for name, spec in self.specializations: + lines.append( + "// tilelang.specialize " + f"{name} shape={spec.shape} memory_space={spec.memory_space.value} " + f"config={spec.config}" + ) + lines.extend( + [ + "module {", + f" func.func {self._format_symbol_name()}() {{", + " return", + " }", + "}", + "", + ] + ) + return "\n".join(lines) + + def mlir_module(self) -> "MaterializedMLIRModule": + self._require_specialized_tiles("mlir_module") + return MaterializedMLIRModule(self.mlir_text()) + + def verify(self) -> bool: + self._require_specialized_tiles("verify") + self.mlir_module() + return True + + def emit(self, path: str | Path) -> None: + self._require_specialized_tiles("emit") + output_path = Path(path) + output_path.write_text(self.mlir_text(), encoding="utf-8") + + +@dataclass(frozen=True) +class MaterializedMLIRModule: + text: str + + def __str__(self) -> str: + return self.text + + def verify(self) -> bool: + return True + + +def _validate_target(target: str) -> str: + if not isinstance(target, str): + raise TypeError("target must be a string") + if target != "a5": + raise ValueError("TileLang DSL v1 currently only supports target='a5'") + return target + + +def _validate_op(op: Any) -> str: + if not isinstance(op, str) or not op: + raise TypeError("op must be a non-empty string") + return op + + +def _validate_name(py_fn: Callable[..., Any], name: Any) -> str: + if name is None: + return py_fn.__name__ + if not isinstance(name, str) or not name: + raise TypeError("name must be a non-empty string") + return name + + +def _validate_verify(verify: Any) -> bool: + if not isinstance(verify, bool): + raise TypeError("verify must be a bool") + return verify + + +def _coerce_memory_space(value: Any, param_name: str) -> MemorySpace: + if isinstance(value, MemorySpace): + return value + if isinstance(value, str): + normalized = value.strip().upper() + try: + return MemorySpace[normalized] + except KeyError as exc: + raise ValueError( + f"specialization for '{param_name}' uses unsupported memory_space {value!r}" + ) from exc + raise TypeError( + f"specialization for '{param_name}' must provide MemorySpace or string memory_space" + ) + + +def _coerce_tile_config(value: Any, param_name: str) -> TileConfig | None: + if value is None: + return None + if isinstance(value, TileConfig): + return value + if isinstance(value, dict): + return TileConfig.from_mapping(value) + raise TypeError( + f"specialization for '{param_name}' must provide TileConfig, dict, or None for config" + ) + + +def _coerce_tile_specialization( + param_name: str, + binding: Any, + source_info: _FunctionSourceInfo | None, +) -> TileSpecialization: + if isinstance(binding, TileSpecialization): + spec = binding + elif isinstance(binding, dict): + if "shape" not in binding: + _raise_tile_param_error( + source_info, + param_name, + f"specialization for '{param_name}' must provide a static physical Tile shape", + TypeError, + ) + if "memory_space" not in binding: + _raise_tile_param_error( + source_info, + param_name, + f"specialization for '{param_name}' must provide memory_space", + TypeError, + ) + spec = TileSpecialization( + shape=tuple(binding["shape"]), + memory_space=_coerce_memory_space(binding["memory_space"], param_name), + config=_coerce_tile_config(binding.get("config"), param_name), + ) + else: + _raise_tile_param_error( + source_info, + param_name, + f"specialization for '{param_name}' must be a TileSpecialization or dict", + TypeError, + ) + + if not spec.shape: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': shape must be non-empty", + ) + for dim in spec.shape: + if not isinstance(dim, int) or isinstance(dim, bool): + _raise_tile_param_error( + source_info, + param_name, + f"dynamic physical Tile shape is not supported for '{param_name}'", + TypeError, + ) + if dim <= 0: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': dimensions must be positive", + ) + if len(spec.shape) not in (1, 2): + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': v1 only supports rank-1 or rank-2 Tile shapes", + ) + if spec.memory_space != MemorySpace.UB: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': v1 only supports MemorySpace.UB", + ) + return spec + + +def _validate_scalar_dtype(dtype: Any, param_name: str) -> ScalarType: + if not isinstance(dtype, ScalarType): + raise TypeError( + f"dtypes entry for parameter '{param_name}' must be a TileLang scalar dtype" + ) + return dtype + + +def _bind_parameter( + param: inspect.Parameter, dtype: Any +) -> BoundKernelParameter: + if param.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + raise TypeError( + f"parameter '{param.name}' uses unsupported parameter kind for TileLang DSL v1" + ) + if param.default is not inspect._empty: + raise TypeError( + f"parameter '{param.name}' must not declare a default value in TileLang DSL v1" + ) + if param.annotation is inspect._empty: + raise TypeError( + f"parameter '{param.name}' must declare a TileLang DSL type annotation" + ) + + annotation = param.annotation + scalar_dtype = _validate_scalar_dtype(dtype, param.name) + + if annotation is TensorView: + return BoundKernelParameter( + name=param.name, + kind="tensorview", + annotation=annotation, + dtype=scalar_dtype, + ) + if annotation is Tile: + return BoundKernelParameter( + name=param.name, + kind="tile", + annotation=annotation, + dtype=scalar_dtype, + ) + if isinstance(annotation, ScalarType): + if annotation != scalar_dtype: + raise TypeError( + f"scalar parameter '{param.name}' annotation {annotation!r} " + f"does not match dtypes entry {scalar_dtype!r}" + ) + return BoundKernelParameter( + name=param.name, + kind="scalar", + annotation=annotation, + dtype=scalar_dtype, + ) + + raise TypeError( + f"parameter '{param.name}' uses unsupported annotation {annotation!r}" + ) + + +def _bind_parameters( + py_fn: Callable[..., Any], dtypes: tuple[tuple[Any, ...], ...] +) -> tuple[BoundKernelParameter, ...]: + if len(dtypes) != 1: + raise ValueError( + "TileLang DSL v1 requires dtypes to contain exactly one monomorphic signature tuple" + ) + + signature = inspect.signature(py_fn) + params = tuple(signature.parameters.values()) + dtype_signature = dtypes[0] + + if len(dtype_signature) != len(params): + raise ValueError( + "single dtypes signature must match the decorated function parameter count" + ) + + return tuple( + _bind_parameter(param, dtype) + for param, dtype in zip(params, dtype_signature) + ) + + +def _build_descriptor( + py_fn: Callable[..., Any], + *, + target: str, + op: Any, + dtypes: Any, + name: Any, + verify: Any, +) -> VKernelDescriptor: + if not callable(py_fn): + raise TypeError("@vkernel can only decorate callables") + + source_info = _load_function_source_info(py_fn) + _validate_function_body(source_info) + frozen_dtypes = _freeze_dtypes(dtypes) + + return VKernelDescriptor( + target=_validate_target(target), + op=_validate_op(op), + dtypes=frozen_dtypes, + name=_validate_name(py_fn, name), + verify_enabled=_validate_verify(verify), + parameters=_bind_parameters(py_fn, frozen_dtypes), + _py_fn=py_fn, + _source_info=source_info, + ) + + +def vkernel( + py_fn: Callable[..., Any] | None = None, + *, + target: str = "a5", + op: str | None = None, + dtypes: Any = None, + name: str | None = None, + verify: bool = True, + constraints: Any = _UNSET, + priority: Any = _UNSET, +) -> VKernelDescriptor | Callable[[Callable[..., Any]], VKernelDescriptor]: + """Create a TileLang DSL v1 kernel descriptor. + + v1 keeps only the minimal descriptor metadata surface: + `target`, `op`, `dtypes`, `name`, and `verify`. + """ + _reject_unsupported_decorator_feature("constraints", constraints) + _reject_unsupported_decorator_feature("priority", priority) + + def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: + return _build_descriptor( + fn, + target=target, + op=op, + dtypes=dtypes, + name=name, + verify=verify, + ) + + if py_fn is None: + return wrap + return wrap(py_fn) + + +__all__ = [ + "BoundKernelParameter", + "MaterializedMLIRModule", + "TileLangFrontendError", + "VKernelDescriptor", + "vkernel", +] diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py new file mode 100644 index 000000000..161a3e0cd --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -0,0 +1,103 @@ +"""Public type markers for the TileLang DSL v1 surface.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Mapping + + +@dataclass(frozen=True) +class ScalarType: + name: str + + def __repr__(self) -> str: + return self.name + + +class TensorView: + """Bare TensorView annotation marker for TileLang DSL v1.""" + + +class Tile: + """Bare Tile annotation marker for TileLang DSL v1.""" + + +@dataclass(frozen=True) +class WildcardType: + name: str + + def __repr__(self) -> str: + return self.name + + +@dataclass(frozen=True) +class TypeVariable: + name: str + + def __repr__(self) -> str: + return f"TypeVar({self.name!r})" + + +class MemorySpace(str, Enum): + GM = "gm" + UB = "ub" + + +@dataclass(frozen=True) +class TileConfig: + fields: tuple[tuple[str, Any], ...] = () + + @classmethod + def from_mapping(cls, mapping: Mapping[str, Any]) -> "TileConfig": + return cls(tuple(sorted(mapping.items()))) + + +@dataclass(frozen=True) +class TileSpecialization: + shape: tuple[int, ...] + memory_space: MemorySpace + config: TileConfig | None = None + + +i8 = ScalarType("i8") +i16 = ScalarType("i16") +i32 = ScalarType("i32") +i64 = ScalarType("i64") +f16 = ScalarType("f16") +bf16 = ScalarType("bf16") +f32 = ScalarType("f32") +AnyFloat = WildcardType("AnyFloat") +AnyInt = WildcardType("AnyInt") +AnyType = WildcardType("AnyType") +AnyMask = WildcardType("AnyMask") + + +def TypeVar(name: str) -> TypeVariable: + if not isinstance(name, str) or not name: + raise TypeError("TypeVar name must be a non-empty string") + return TypeVariable(name) + + +__all__ = [ + "ScalarType", + "WildcardType", + "TypeVariable", + "TypeVar", + "TensorView", + "Tile", + "MemorySpace", + "TileConfig", + "TileSpecialization", + "i8", + "i16", + "i32", + "i64", + "f16", + "bf16", + "f32", + "AnyFloat", + "AnyInt", + "AnyType", + "AnyMask", +] diff --git a/tilelang-dsl/tests/README.md b/tilelang-dsl/tests/README.md new file mode 100644 index 000000000..c1370a85a --- /dev/null +++ b/tilelang-dsl/tests/README.md @@ -0,0 +1,5 @@ +TileLang DSL tests live here. + +Keep tests for this frontend isolated from the legacy `test/python/` and +other repository-wide test trees unless a follow-up task explicitly wires +shared coverage. diff --git a/tilelang-dsl/tests/import_tilelang_dsl.py b/tilelang-dsl/tests/import_tilelang_dsl.py new file mode 100644 index 000000000..a410cd95d --- /dev/null +++ b/tilelang-dsl/tests/import_tilelang_dsl.py @@ -0,0 +1,12 @@ +import tilelang_dsl + + +def main() -> None: + package_file = getattr(tilelang_dsl, "__file__", None) + if not package_file: + raise SystemExit("tilelang_dsl import did not expose __file__") + print(package_file) + + +if __name__ == "__main__": + main() diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py new file mode 100644 index 000000000..bef057aeb --- /dev/null +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -0,0 +1,153 @@ +import tempfile +import unittest +from importlib import util +from pathlib import Path + +import tilelang_dsl as pto + + +class TileLangDSLPackageTests(unittest.TestCase): + def test_package_exports_surface(self) -> None: + self.assertIsNotNone(pto.__file__) + self.assertTrue(hasattr(pto, "vkernel")) + self.assertTrue(hasattr(pto, "TensorView")) + self.assertTrue(hasattr(pto, "Tile")) + self.assertTrue(hasattr(pto, "TileSpecialization")) + + +class TileLangDSLDescriptorTests(unittest.TestCase): + def test_descriptor_metadata_and_parameter_binding(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], verify=False) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + return None + + self.assertEqual(kernel.target, "a5") + self.assertEqual(kernel.op, "eltwise") + self.assertEqual(kernel.name, "kernel") + self.assertFalse(kernel.verify_enabled) + self.assertEqual(kernel.metadata["verify"], False) + self.assertEqual(kernel.dtype_signature, (pto.f32, pto.f16, pto.i32)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in kernel.parameters], + [("inp", "tensorview", pto.f32), ("tile", "tile", pto.f16), ("scale", "scalar", pto.i32)], + ) + self.assertEqual(kernel.parameters[0].element_dtype, pto.f32) + self.assertEqual(kernel.parameters[1].element_dtype, pto.f16) + self.assertIsNone(kernel.parameters[2].element_dtype) + + def test_specialization_enables_materialization_apis(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping({"layout": "row_major"}), + ) + ) + + self.assertIn("tile", specialized.specializations_by_name) + text = specialized.mlir_text() + self.assertIn("// tilelang.target = a5", text) + self.assertIn("// tilelang.specialize tile shape=(16, 32) memory_space=ub", text) + module = specialized.mlir_module() + self.assertEqual(type(module).__name__, "MaterializedMLIRModule") + self.assertTrue(module.verify()) + self.assertTrue(specialized.verify()) + + with tempfile.TemporaryDirectory() as tmpdir: + out = Path(tmpdir) / "kernel.mlir" + specialized.emit(out) + self.assertEqual(out.read_text(encoding="utf-8"), text) + + +class TileLangDSLDiagnosticsTests(unittest.TestCase): + def test_matcher_feature_diagnostics_point_to_follow_up_change(self) -> None: + cases = [ + lambda: pto.vkernel(op="x", dtypes=[(pto.f32,)], constraints=[])(lambda x: None), + lambda: pto.vkernel(op="x", dtypes=[(pto.f32,)], priority=1)(lambda x: None), + lambda: pto.vkernel(op="x", dtypes=[(pto.f32,), (pto.f16,)])(lambda x: None), + lambda: pto.vkernel(op="x", dtypes=[(pto.AnyFloat,)])(lambda x: None), + lambda: pto.vkernel(op="x", dtypes=[(pto.TypeVar("T"),)])(lambda x: None), + ] + + for thunk in cases: + with self.assertRaises(ValueError) as ctx: + thunk() + self.assertIn( + "extend-tilelang-dsl-matcher-and-advanced-surface", + str(ctx.exception), + ) + + def test_unsupported_python_syntax_reports_source_location(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32,)]) + def kernel(x: pto.TensorView): + while True: + return None + + self.assertIn("unsupported Python syntax `while`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_arbitrary_external_call_reports_source_location(self) -> None: + def helper(): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32,)]) + def kernel(x: pto.TensorView): + helper() + return None + + self.assertIn("arbitrary external call `helper`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_unsupported_pto_surface_reports_source_location(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32,)]) + def kernel(x: pto.TensorView): + pto.vadd(x) + return None + + self.assertIn("unsupported op surface `pto.vadd`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_missing_specialization_reports_source_location(self) -> None: + @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f16)]) + def kernel(x: pto.TensorView, tile: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + kernel.mlir_text() + + self.assertIn("requires specialize() bindings for bare Tile parameters", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_dynamic_shape_and_illegal_profile_report_source_location(self) -> None: + @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f16)]) + def kernel(x: pto.TensorView, tile: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as dynamic_ctx: + kernel.specialize(tile={"shape": (16, "n"), "memory_space": "ub"}) + self.assertIn("dynamic physical Tile shape is not supported", str(dynamic_ctx.exception)) + self.assertIn(f"{__file__}:", str(dynamic_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as rank_ctx: + kernel.specialize(tile={"shape": (4, 4, 4), "memory_space": "ub"}) + self.assertIn("v1 only supports rank-1 or rank-2 Tile shapes", str(rank_ctx.exception)) + self.assertIn(f"{__file__}:", str(rank_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as space_ctx: + kernel.specialize(tile={"shape": (4, 4), "memory_space": "gm"}) + self.assertIn("v1 only supports MemorySpace.UB", str(space_ctx.exception)) + self.assertIn(f"{__file__}:", str(space_ctx.exception)) + + +if __name__ == "__main__": + unittest.main() From f5ee6764fdc84566c225a6ba04a901a7a7640b02 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 3 Apr 2026 13:49:32 +0800 Subject: [PATCH 016/192] Update openspec --- .../.openspec.yaml | 2 + .../design.md | 177 ++++++++++++++++++ .../proposal.md | 89 +++++++++ .../specs/tilelang-dsl-diagnostics/spec.md | 61 ++++++ .../specs/tilelang-dsl-surface/spec.md | 54 ++++++ .../specs/vpto-ir-legality/spec.md | 34 ++++ .../tasks.md | 29 +++ 7 files changed, 446 insertions(+) create mode 100644 openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/.openspec.yaml create mode 100644 openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/design.md create mode 100644 openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/proposal.md create mode 100644 openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/tilelang-dsl-diagnostics/spec.md create mode 100644 openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/tilelang-dsl-surface/spec.md create mode 100644 openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/vpto-ir-legality/spec.md create mode 100644 openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/tasks.md diff --git a/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/.openspec.yaml b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/.openspec.yaml new file mode 100644 index 000000000..c430c5fa6 --- /dev/null +++ b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-03 diff --git a/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/design.md b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/design.md new file mode 100644 index 000000000..0bc8b9253 --- /dev/null +++ b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/design.md @@ -0,0 +1,177 @@ +## Context + +### 范围 + +本 design 只覆盖 TileLang DSL v1 的基础前端契约,不覆盖 DSL -> VPTO lowering 本身。 +它回答四个问题: + +1. TileLang DSL v1 的代码、测试、样例、文档应该放在哪里 +2. v1 `@pto.vkernel` 的 public surface 到哪一层 +3. bare `TensorView` / `Tile` 参数如何定型 +4. frontend 必须在哪些点 fail-fast,而不是把非法输入拖到 lowering / verifier 阶段 + +### 当前状态 + +当前仓库里有三套彼此未完全对齐的事实: + +1. `docs/tilelang-dsl-guide.md` + +- 它描述了高层 TileLang DSL 的理想 surface,包括 `TensorView`、`Tile`、高层 DMA、mask inference、matcher、implicit vecscope inference 等。 +- 该文档覆盖面很大,但没有被 OpenSpec capability 拆成可实现的 v1/v2 边界。 + +2. `python/pto/dialects/pto.py` + +- 该文件当前主要服务于 PTO dialect Python bindings。 +- 其中夹带了一个实验性的 `@pto.vkernel` parser,surface 更接近直接 author VPTO,不等于 TileLang DSL guide。 +- 继续在这个文件里叠加 TileLang DSL,会把两套不同 DSL 和两套不同约束混在一起。 + +3. 真实的 VPTO legality contract + +- 当前 verifier 与 `test/vpto_validate/` 明确要求 dedicated `pto.vecscope/pto.strict_vecscope`。 +- `openspec/specs/vpto-ir-legality/spec.md` 仍残留 `llvm.loop.aivector_scope` 版本的旧 requirement。 +- 如果不先修正 spec,后续 lowering 会天然踩在错误契约上。 + +### 实现约束 + +- 本特性相关源码、样例、测试、局部文档必须集中在 `tilelang-dsl/`。 +- 根目录其他位置只允许做最小 build/install/test wiring,不得把核心实现重新塞回 `python/` 或 `test/` 现有目录树。 +- v1 只支持 `a5`,并且只支持单个 monomorphic `dtypes` signature。 +- bare `Tile` 参数不能依赖 Python runtime 值自动推导 physical shape;必须由显式 specialization 提供。 +- diagnostics 需要在 frontend 分层给出,不得把“unsupported feature”伪装成底层 verifier failure。 + +## Goals / Non-Goals + +**Goals:** + +- 建立独立 `tilelang_dsl` package 和 `tilelang-dsl/` 目录边界。 +- 固定 v1 `@pto.vkernel` 的 descriptor API 与参数定型规则。 +- 固定 v1 bare `Tile` specialization 机制。 +- 固定 v1 frontend diagnostics 的分层和失败行为。 +- 修正 `vpto-ir-legality` 的 vecscope requirement,使后续 lowering change 可依附真实 contract。 + +**Non-Goals:** + +- 不在本 change 中设计具体 lowering pass、builder 或 codegen pipeline。 +- 不在本 change 中给 `constraints`、`priority`、`Any*`、`TypeVar` 定义运行语义。 +- 不引入公开的 TileLang 中间 IR。 +- 不要求现有 `python/pto/dialects/pto.py` 与新 package 共用内部实现。 + +## Decisions + +### 1. 采用独立 package `tilelang_dsl`,而不是扩展现有 `python/pto` + +决策: + +- TileLang DSL v1 实现放在 `tilelang-dsl/python/tilelang_dsl/` +- 示例统一使用 `import tilelang_dsl as pto` +- `tilelang-dsl/tests/`、`tilelang-dsl/examples/`、`tilelang-dsl/docs/` 一并承载本特性工件 + +原因: + +- 用户已经明确要求不考虑现有其他 Python binding 实现。 +- 现有 `python/pto/dialects/pto.py` 同时承担 dialect binding 和实验 DSL,继续叠加会扩大耦合面。 +- 独立 package 让 OpenSpec、测试和后续实现边界都更清晰。 + +备选方案: + +- 直接扩展 `python/pto/dialects/pto.py` + - 放弃原因:会把 TileLang DSL 和现有实验 VPTO DSL 混在同一入口,难以隔离行为和文档口径。 + +### 2. v1 decorator surface 固定为 `a5` 单 target + 单一 monomorphic `dtypes` + +决策: + +- v1 `@pto.vkernel` 仅接受 `target="a5"` +- `op` 为必填 metadata +- `dtypes` 必须是仅含一个 tuple 的 monomorphic signature +- `name`、`verify` 保留 +- `constraints`、`priority`、多 signature `dtypes`、`Any*`、`TypeVar` 一律在 frontend reject,并由 diagnostics 明确指向 follow-up change + +原因: + +- 这是当前最小且可实现的契约,能支撑后续 v1 lowering,不会把 matcher 语义混入基础 change。 +- 这让参数定型规则稳定:每个参数位置只存在一个最终类型绑定结果。 + +备选方案: + +- 直接支持完整 matcher surface + - 放弃原因:会把 kernel registry、constraint evaluation、tie-breaking、wildcard typing 一并引入,超出 v1 基础 change。 + +### 3. bare `TensorView` / `Tile` 注解继续沿用 guide 风格,元素类型通过 `dtypes` 绑定 + +决策: + +- `TensorView` / `Tile` 参数在函数签名中使用 bare annotation +- 单个 `dtypes` signature 按参数位置绑定元素类型 +- 标量参数仍使用显式标量注解,并在 `dtypes` 的同位置写出同类型 + +原因: + +- 这与 `docs/tilelang-dsl-guide.md` 的核心书写方式一致。 +- 对后续 matcher change 友好,不会先人为引入另一套“参数注解写满全部类型”的平行 surface。 + +备选方案: + +- 要求每个参数都在注解中写完整类型/shape + - 放弃原因:与 guide 差异过大,也会把 Tile specialization 和动态 TensorView profile 搅进签名层。 + +### 4. bare `Tile` 参数采用 descriptor-level specialization,而不是 Python runtime 推导 + +决策: + +- bare `Tile` 参数的 physical shape / memory space / config 不在函数定义时写死 +- `descriptor.specialize(**bindings)` 是唯一合法的补全入口 +- 只有所有 bare `Tile` 参数都 specialization 完成后,才能调用 `mlir_text()` / `mlir_module()` / `verify()` + +原因: + +- Tile physical shape 必须静态,但内核定义时未必能知道具体实例。 +- 显式 specialization 比隐式 runtime 推导更稳定,也更适合后续 matcher / registry 场景。 + +备选方案: + +- 让 `Tile` 参数从 runtime Python object 自动推导 + - 放弃原因:会把编译期 contract 和运行期 object 混在一起,难以保证 deterministic IR surface。 + +### 5. diagnostics 在 frontend 分层 fail-fast,不把 unsupported feature 甩给后端 verifier + +决策: + +- decorator-level unsupported feature +- syntax-level unsupported Python construct +- type/profile-level illegal shape / missing specialization +- lowering前非法 vector-scope 前提 + +以上都必须在 TileLang frontend 直接报错,并附带源码位置。 + +原因: + +- 这些错误是 TileLang surface 语义问题,不应该等到底层 VPTO verifier 再以“语义不合法 IR”形式暴露。 +- fail-fast diagnostics 才能稳定区分“DSL 不支持”与“lowering 出 bug”。 + +### 6. 先在 OpenSpec 中修正 `vpto-ir-legality` 的 vecscope contract + +决策: + +- 本 change 直接带一个 `vpto-ir-legality` delta +- 明确 authoring-form 只接受 dedicated `pto.vecscope/pto.strict_vecscope` +- 继续拒绝 legacy `scf.for {llvm.loop.aivector_scope}` + +原因: + +- 这是后续 lowering change 的前置条件。 +- 继续保留错误 spec 会让实现与契约长期漂移。 + +## Risks / Trade-offs + +- [Risk] `tilelang_dsl` 与现有 `pto` 相关命名空间容易混淆 + Mitigation:独立 package 名固定为 `tilelang_dsl`,示例中只通过 `import tilelang_dsl as pto` 复用书写风格。 + +- [Risk] v1 对 matcher feature 的 reject 可能被误解为“设计缺失” + Mitigation:在 proposal、diagnostics 和 follow-up change 中明确这些能力被延期到 `extend-tilelang-dsl-matcher-and-advanced-surface`。 + +- [Risk] 先修 spec 再做实现可能暴露与现有 verifier 更多不一致点 + Mitigation:本 change 只修正已经被 `lib/PTO/Transforms/PTOValidateVPTOIR.cpp` 与 `test/vpto_validate/` 明确证明的 vecscope 冲突,不额外扩张。 + +- [Risk] bare `Tile` specialization API 若定义不清,会把 shape/profile 责任拖到后续 change + Mitigation:在本 change 中直接固定 `specialize()` 是唯一入口,并把缺失 specialization 归类为 frontend error。 diff --git a/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/proposal.md b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/proposal.md new file mode 100644 index 000000000..a60bdbc89 --- /dev/null +++ b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/proposal.md @@ -0,0 +1,89 @@ +# Proposal: 建立 `tilelang-dsl/` 独立前端的 v1 基础契约 + +## 概述 + +`docs/tilelang-dsl-guide.md` 已经描述了一套面向 Tile/TensorView authoring 的 Python DSL,但仓库中还没有与该文档对齐的 OpenSpec change,也没有把这套 surface 与现有 `python/pto/dialects/pto.py` 的实验性 VPTO Python DSL 做清晰切分。 +本 change 先落定 TileLang DSL v1 的基础契约:在 `tilelang-dsl/` 下建立独立前端目录、独立 package、独立测试与文档边界,并把首版 public surface、参数定型方式、Tile specialization 机制和 frontend diagnostics 固化为 OpenSpec。 + +## 背景与动机 + +当前存在三类直接问题: + +1. `docs/tilelang-dsl-guide.md` 的 surface 尚未被 OpenSpec 约束 + +- 文档里已经暴露了 `@pto.vkernel`、`TensorView`、`Tile`、高层 `dma_load/dma_store`、typed-mask、element-indexing 等 surface。 +- 这些 surface 没有对应 capability spec,后续实现边界、失败行为和测试覆盖都无法稳定收敛。 + +2. 现有 `python/pto/dialects/pto.py` 中的实验实现不是合适的 source of truth + +- 当前文件把 PTO dialect Python bindings 和一个实验性的 `@pto.vkernel` parser 混在一起。 +- 它使用的是另一套更接近 hand-written VPTO 的 surface,和 `docs/tilelang-dsl-guide.md` 的目标 DSL 并不等价。 +- 用户已明确要求本特性不依赖现有其他 Python binding 实现,并要求把本特性相关工作集中在 `tilelang-dsl/` 目录。 + +3. `vpto-ir-legality` 的 vecscope OpenSpec 契约与真实实现/测试不一致 + +- 现有 `openspec/specs/vpto-ir-legality/spec.md` 仍把 `scf.for {llvm.loop.aivector_scope}` 记作 authoring-form carrier。 +- 当前 verifier 与回归已经明确拒绝 legacy `scf.for {llvm.loop.aivector_scope}`,要求使用 dedicated `pto.vecscope/pto.strict_vecscope`。 +- TileLang DSL 如果要稳定 lower 到 authoring-form VPTO,必须先把这一契约纠正到与真实实现一致。 + +## 目标 + +- 在 `tilelang-dsl/` 下定义独立的 TileLang DSL v1 package、源码边界、测试边界和文档边界。 +- 固定 v1 `@pto.vkernel` 的最小 public surface:`a5` 单 target、monomorphic `dtypes`、bare `TensorView`/`Tile` 参数注解、descriptor API、Tile specialization。 +- 固定 frontend diagnostics 契约,确保 unsupported matcher feature、unsupported Python syntax、Tile specialization 缺失、非法 shape profile 都能 fail-fast。 +- 通过 OpenSpec 修正 `vpto-ir-legality` 中 vecscope 相关 requirement,使其与当前 verifier / lit 回归一致。 + +## 非目标 + +- 不在本 change 中实现 DSL -> VPTO lowering;该部分由后续 `add-tilelang-dsl-authoring-vpto-lowering` change 覆盖。 +- 不在本 change 中引入 kernel matcher、`Any*` / `TypeVar`、多 signature `dtypes`、`constraints`、`priority`。 +- 不要求复用或改造现有 `python/pto/dialects/pto.py` 的实验 `@pto.vkernel` 实现。 +- 不为 `a5` 之外的 target 建模。 +- 不在本 change 中扩展到 implicit vecscope inference、raw pointer authoring、advanced vector family。 + +## 变更内容 + +- 新增 `tilelang-dsl-surface` capability,定义独立 package、repo layout、v1 `@pto.vkernel` surface、参数定型方式、descriptor API 与 Tile specialization 契约。 +- 新增 `tilelang-dsl-diagnostics` capability,定义 frontend 对 unsupported feature、unsupported syntax、specialization 缺失和 shape profile 错误的诊断义务。 +- 修改 `vpto-ir-legality` capability,修正 authoring-form vecscope carrier 的 requirement:以 dedicated `pto.vecscope/pto.strict_vecscope` 为准,并继续拒绝 legacy `scf.for {llvm.loop.aivector_scope}`。 + +## Capabilities + +### New Capabilities + +- `tilelang-dsl-surface`: 定义 `tilelang-dsl/` 独立前端的 v1 public surface、package 入口、descriptor API、参数定型与 Tile specialization 契约。 +- `tilelang-dsl-diagnostics`: 定义 TileLang DSL v1 frontend 的 fail-fast diagnostics、错误定位与错误分层契约。 + +### Modified Capabilities + +- `vpto-ir-legality`: 修正 authoring-form VPTO vecscope carrier 契约,使 OpenSpec 与当前 verifier / regression 使用的 dedicated `pto.vecscope/pto.strict_vecscope` 语义一致。 + +## 预期结果 + +- `tilelang-dsl/` 成为本特性的唯一源码、样例、测试和局部文档承载目录;根目录只保留最小 build/install/test 接线。 +- TileLang DSL v1 的 public surface 和 diagnostics 行为不再依赖 `docs/tilelang-dsl-guide.md` 的口头描述或现有实验实现,而是有明确 OpenSpec 契约。 +- 后续 DSL -> VPTO lowering change 可以直接依附真实的 authoring-form VPTO legality contract,而不是继续踩在错误的 `llvm.loop.aivector_scope` requirement 上。 + +## 成功标准 + +- 新增 `openspec/changes/add-tilelang-dsl-core-foundation/`,包含 proposal、design、tasks。 +- 新增 `specs/tilelang-dsl-surface/spec.md` 和 `specs/tilelang-dsl-diagnostics/spec.md`。 +- 新增 `specs/vpto-ir-legality/spec.md` delta,明确 legacy `scf.for {llvm.loop.aivector_scope}` 不再是合法 authoring-form carrier。 +- proposal/design/tasks 明确写清: + - `tilelang-dsl/` 是本特性的唯一工作目录; + - v1 只接受 `a5` 单 target 和单一 monomorphic `dtypes` signature; + - bare `Tile` 参数必须先 specialization 再 materialize IR; + - unsupported matcher feature 与 unsupported syntax 必须 fail-fast。 + +## 影响 + +- 受影响目录: + - `tilelang-dsl/` + - `openspec/specs/vpto-ir-legality/spec.md` + - 必要的根级 CMake / 安装 / 测试入口接线 +- 受影响 public API: + - 新增独立 package `tilelang_dsl` + - 新增 v1 descriptor API:`specialize()`, `mlir_text()`, `mlir_module()`, `verify()`, `emit(path)` +- 对现有 `python/pto/dialects/pto.py` 的要求: + - 不再作为本特性的 source of truth + - 如需接线,只允许最小兼容或安装 wiring,不允许把 TileLang DSL 核心逻辑继续堆叠在该文件内 diff --git a/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/tilelang-dsl-diagnostics/spec.md b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/tilelang-dsl-diagnostics/spec.md new file mode 100644 index 000000000..dcd4bb029 --- /dev/null +++ b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/tilelang-dsl-diagnostics/spec.md @@ -0,0 +1,61 @@ +# tilelang-dsl-diagnostics Specification + +## ADDED Requirements + +### Requirement: v1 MUST fail fast on unsupported matcher and decorator features + +TileLang DSL v1 frontend 对以下 surface MUST fail-fast,而不是静默忽略或拖到 lowering 阶段: + +- 多个 `dtypes` signature +- `constraints` +- `priority` +- `AnyFloat` / `AnyInt` / `AnyType` / `AnyMask` +- `TypeVar` + +#### Scenario: unsupported matcher feature is rejected at decorator parse time + +- **WHEN** 用户在 v1 kernel decorator 中写入 `constraints`、`priority`、多 signature `dtypes`、`Any*` 或 `TypeVar` +- **THEN** frontend MUST 直接报错 +- **AND** 诊断 MUST 明确指出该 feature 不属于 v1 范围 +- **AND** 诊断 SHOULD 指向 follow-up change `extend-tilelang-dsl-matcher-and-advanced-surface`,而不是伪装成底层 type error + +### Requirement: v1 MUST reject unsupported Python syntax and unsupported DSL calls before IR generation + +TileLang DSL v1 frontend MUST 只接受受限 Python 子集。 +`while`、list/dict/set comprehension、arbitrary external function call、未注册 DSL op、以及其他超出 v1 surface 的 Python 结构 MUST 在 frontend 被拒绝。 + +#### Scenario: unsupported Python construct is rejected before lowering + +- **WHEN** kernel body 使用 `while`、comprehension、任意非 `pto.*` function call 或未纳入 v1 support matrix 的 DSL call +- **THEN** frontend MUST 在生成任何 VPTO IR 之前报错 +- **AND** 诊断 MUST 指明违规的 Python construct 或 DSL call 名称 + +### Requirement: Tile specialization and shape-profile errors MUST be diagnosed in the frontend + +TileLang DSL v1 frontend MUST 把以下错误归类为前端错误: + +- bare `Tile` 参数未完成 specialization +- Tile physical shape 不是静态编译期常量 +- Tile profile 与 v1 支持的 rank / memory-space 约束不匹配 + +#### Scenario: unspecialized or dynamically-shaped tile fails before materialization + +- **WHEN** kernel 含 bare `Tile` 参数但调用方未完成 `specialize()`,或 specialization 试图给出 dynamic physical tile shape +- **THEN** frontend MUST 在 `mlir_text()` / `mlir_module()` / `verify()` 之前直接报错 +- **AND** MUST NOT 继续尝试生成不完整的 authoring-form VPTO IR + +### Requirement: frontend diagnostics MUST include source location and semantic cause + +TileLang DSL v1 的 frontend diagnostics MUST 包含 DSL 源位置和语义原因。 +错误消息 MUST 能区分: + +- decorator surface 不支持 +- Python 语法子集不支持 +- 参数定型失败 +- Tile specialization/profile 非法 + +#### Scenario: user sees actionable diagnostic with source location + +- **WHEN** frontend 因 unsupported feature、unsupported syntax、type binding failure 或 specialization error 拒绝一个 kernel +- **THEN** 诊断 MUST 至少包含 DSL 源文件位置、行列号或等价的 source span +- **AND** MUST 明确指出失败原因属于哪一层 frontend 语义,而不是只给出底层 verifier 或 parser 的通用报错 diff --git a/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/tilelang-dsl-surface/spec.md b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/tilelang-dsl-surface/spec.md new file mode 100644 index 000000000..24e1cd4e6 --- /dev/null +++ b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/tilelang-dsl-surface/spec.md @@ -0,0 +1,54 @@ +# tilelang-dsl-surface Specification + +## ADDED Requirements + +### Requirement: TileLang DSL v1 MUST live under `tilelang-dsl/` and expose a dedicated `tilelang_dsl` package + +TileLang DSL v1 的实现、样例、测试和局部文档 MUST 集中在 `tilelang-dsl/`。 +对外 import 入口 MUST 是独立 package `tilelang_dsl`,不得继续把本特性的核心逻辑建立在现有 `python/pto/dialects/pto.py` 的实验 DSL 上。 +根目录其他路径若有改动,MUST 仅限最小 build/install/test wiring。 + +#### Scenario: TileLang DSL source stays isolated from existing Python binding code + +- **WHEN** 仓库为 TileLang DSL v1 新增源码、样例、测试和局部文档 +- **THEN** 这些工件 MUST 放在 `tilelang-dsl/` 下 +- **AND** repo root 或 `python/` 现有目录树的改动 MUST 只承担最小接线职责 +- **AND** `python/pto/dialects/pto.py` MUST NOT 继续作为 TileLang DSL v1 的 source of truth + +### Requirement: v1 `@pto.vkernel` surface MUST be limited to the monomorphic `a5` profile + +TileLang DSL v1 的 `@pto.vkernel` MUST 只接受 `target="a5"`。 +`op` MUST 作为必填 metadata 保留。 +`dtypes` MUST 只包含一个 monomorphic signature tuple。 +`name` 和 `verify` MAY 保留为可选字段。 +v1 不在 public surface 中支持多 signature `dtypes`、`constraints`、`priority`、`Any*` 或 `TypeVar`。 + +#### Scenario: monomorphic a5 kernel descriptor is accepted + +- **WHEN** 用户定义 `@pto.vkernel(target="a5", op="scale", dtypes=[(pto.f32, pto.f32, pto.f32)])` +- **THEN** frontend MUST 接受该 decorator surface +- **AND** descriptor MUST 保留 `target/op/dtypes/name/verify` metadata 用于后续编译和调试 + +### Requirement: bare `TensorView` and `Tile` annotations MUST bind element types through the single `dtypes` signature + +在 v1 中,`TensorView` 和 `Tile` 参数 MUST 允许使用 bare annotation。 +其元素类型 MUST 由 decorator 的单个 `dtypes` signature 按参数位置绑定。 +标量参数 MUST 继续使用显式标量注解,并与 `dtypes` 中对应位置的标量类型保持一致。 + +#### Scenario: `dtypes` binds operand element types positionally + +- **WHEN** kernel 参数按位置写成 `TensorView, TensorView, Tile, pto.f32` +- **THEN** 单个 `dtypes` signature MUST 按同样的位置顺序提供两个 GM operand 的元素类型、一个 Tile operand 的元素类型和一个标量类型 +- **AND** frontend MUST 使用该 signature 作为参数定型的唯一来源 + +### Requirement: bare `Tile` parameters MUST require explicit specialization before IR materialization + +对 bare `Tile` 参数,frontend MUST 在 descriptor 上提供显式 specialization 入口。 +Tile 的 physical shape、memory space 和配置 MUST 在 specialization 阶段补全。 +在所有 bare `Tile` 参数完成 specialization 之前,descriptor MUST NOT 允许执行 `mlir_text()`, `mlir_module()`, `verify()` 或 `emit(path)`。 + +#### Scenario: specialized tile kernel can materialize IR + +- **WHEN** kernel 含 bare `Tile` 参数,且调用方通过 `descriptor.specialize(**bindings)` 为所有 bare `Tile` 参数补齐静态 shape / space / config +- **THEN** 返回的 specialized descriptor MUST 允许调用 `mlir_text()`, `mlir_module()`, `verify()` 和 `emit(path)` +- **AND** specialization 之后的 Tile physical shape MUST 作为编译期静态契约固定下来 diff --git a/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/vpto-ir-legality/spec.md b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/vpto-ir-legality/spec.md new file mode 100644 index 000000000..5f0e9b308 --- /dev/null +++ b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/specs/vpto-ir-legality/spec.md @@ -0,0 +1,34 @@ +# vpto-ir-legality Specification + +## MODIFIED Requirements + +### Requirement: VPTO vector and predicate structure MUST stay inside a single dedicated vecscope carrier + +legacy `scf.for {llvm.loop.aivector_scope}` authoring form MUST NOT be accepted any longer. +在 authoring-form VPTO 中,所有消费或产生 `!pto.vreg`、`!pto.mask<...>`、`!pto.align` 的 VPTO op MUST 位于 dedicated `pto.vecscope` 或 `pto.strict_vecscope` 作用域内。 +同时,dedicated `pto.vecscope/pto.strict_vecscope` carrier MUST NOT 相互嵌套。 + +#### Scenario: vector or predicate VPTO op outside dedicated vecscope is rejected + +- **WHEN** authoring-form VPTO IR 中出现消费或产生 `!pto.vreg`、`!pto.mask<...>`、`!pto.align` 的 VPTO op,且该 op 不在任何 `pto.vecscope` 或 `pto.strict_vecscope` 内 +- **THEN** authoring-stage verifier MUST 拒绝该 IR +- **AND** 诊断 MUST 明确指出违规 op 缺少 enclosing dedicated vecscope + +#### Scenario: legacy `scf.for {llvm.loop.aivector_scope}` carrier is rejected + +- **WHEN** authoring-form VPTO IR 试图继续使用带 `llvm.loop.aivector_scope` attr 的 `scf.for` 作为 vector carrier +- **THEN** authoring-stage verifier MUST 拒绝该 IR +- **AND** 诊断 MUST 明确指出该 form 已是 legacy authoring surface +- **AND** 诊断 MUST 要求改用 dedicated `pto.vecscope/pto.strict_vecscope` + +#### Scenario: nested dedicated vecscope carriers are rejected + +- **WHEN** 某个 `pto.vecscope` 或 `pto.strict_vecscope` 作用域内再次出现 dedicated `pto.vecscope` 或 `pto.strict_vecscope` +- **THEN** authoring-stage verifier MUST 拒绝该 IR +- **AND** 诊断 MUST 明确指出存在 nested dedicated vecscope + +#### Scenario: shared scalar and control-flow surface is still allowed outside dedicated vecscope + +- **WHEN** `arith`、`scf`、pointer-building、copy programming 或 sync programming 相关 op 本身不产生也不消费 `!pto.vreg`、`!pto.mask<...>`、`!pto.align` +- **THEN** authoring-stage verifier MUST NOT 仅因为这些 op 位于 dedicated vecscope 外就拒绝 IR +- **AND** vecscope 约束 MUST 只针对 VPTO vector / predicate / align surface 生效 diff --git a/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/tasks.md b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/tasks.md new file mode 100644 index 000000000..bab43ba2b --- /dev/null +++ b/openspec/changes/archive/2026-04-03-add-tilelang-dsl-core-foundation/tasks.md @@ -0,0 +1,29 @@ +## 1. OpenSpec 基础契约 + +- [x] 1.1 新增 `openspec/changes/add-tilelang-dsl-core-foundation/specs/tilelang-dsl-surface/spec.md`,固定 v1 package、decorator surface、参数定型和 Tile specialization 契约。 +- [x] 1.2 新增 `openspec/changes/add-tilelang-dsl-core-foundation/specs/tilelang-dsl-diagnostics/spec.md`,固定 fail-fast diagnostics、错误定位和 unsupported feature 行为。 +- [x] 1.3 新增 `openspec/changes/add-tilelang-dsl-core-foundation/specs/vpto-ir-legality/spec.md` delta,修正 vecscope carrier requirement 与真实 verifier 一致。 + +## 2. 目录与接线 + +- [x] 2.1 在 `tilelang-dsl/` 下创建 `python/`, `tests/`, `examples/`, `docs/` 基础布局,并保持本特性源码集中在该目录。 +- [x] 2.2 增加最小根级 build/install/test wiring,让 `tilelang_dsl` package 能被本仓库构建和测试系统发现,但不把核心逻辑迁回 `python/` 现有目录。 +- [x] 2.3 明确 `tilelang-dsl/` 与现有 `python/pto/dialects/pto.py` 的边界,避免在旧文件中继续堆叠 TileLang DSL 核心实现。 + +## 3. Surface 与 descriptor API + +- [x] 3.1 实现 v1 `@pto.vkernel` descriptor skeleton,固定 `target/op/dtypes/name/verify` 字段和 `a5` 单 target 约束。 +- [x] 3.2 实现 bare `TensorView` / `Tile` 参数的单一 monomorphic `dtypes` 绑定规则。 +- [x] 3.3 实现 bare `Tile` 参数的 `specialize(**bindings)` 机制,并把 `mlir_text()`, `mlir_module()`, `verify()`, `emit(path)` 挂到 descriptor 上。 + +## 4. Frontend diagnostics + +- [x] 4.1 为 `constraints`、`priority`、多 signature `dtypes`、`Any*`、`TypeVar` 提供 fail-fast diagnostics,并在消息中指向 follow-up change。 +- [x] 4.2 为 unsupported Python syntax / arbitrary call / unsupported op surface 提供 source-located diagnostics。 +- [x] 4.3 为缺失 Tile specialization、dynamic physical tile shape、非法 shape profile 提供前端错误,而不是把错误拖到 lowering 或 verifier。 + +## 5. 测试与文档 + +- [x] 5.1 在 `tilelang-dsl/tests/` 增加 package/import、descriptor API、specialization 和 diagnostics 的正反向测试。 +- [x] 5.2 在 `tilelang-dsl/docs/` 写明 v1 surface 与延期 feature,明确现有 `python/pto/dialects/pto.py` 不是本特性的 source of truth。 +- [x] 5.3 运行与记录最小验证命令,确认 `tilelang_dsl` package 可被构建/导入,diagnostics 能稳定定位到 DSL 源位置。 From 47a102f5f078842a2b8a38db3e7f675c0b85d098 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 3 Apr 2026 23:51:06 +0800 Subject: [PATCH 017/192] Support more syntax in tilelang dsl --- CMakeLists.txt | 26 +- tilelang-dsl/python/tilelang_dsl/__init__.py | 14 + .../python/tilelang_dsl/frontend_ast.py | 381 +++++ tilelang-dsl/python/tilelang_dsl/kernel.py | 131 +- tilelang-dsl/python/tilelang_dsl/lowering.py | 719 ++++++++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 1265 +++++++++++++++++ tilelang-dsl/python/tilelang_dsl/types.py | 39 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 282 +++- 8 files changed, 2825 insertions(+), 32 deletions(-) create mode 100644 tilelang-dsl/python/tilelang_dsl/frontend_ast.py create mode 100644 tilelang-dsl/python/tilelang_dsl/lowering.py create mode 100644 tilelang-dsl/python/tilelang_dsl/semantic.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 01470d268..68db8312b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,11 @@ include_directories(${PROJECT_BINARY_DIR}/include) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) +# ========================================================= +# 3.1 Testing option setup +# ========================================================= +include(CTest) + # 开启 Python 绑定选项 option(PTO_ENABLE_PYTHON_BINDING "Enable Python bindings" ON) @@ -101,6 +106,7 @@ if(PTO_ENABLE_PYTHON_BINDING) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/python/pto) add_subdirectory(python) + add_subdirectory(tilelang-dsl) endif() # ========================================================= @@ -113,9 +119,27 @@ add_subdirectory(tools) # ========================================================= # 4.1 Tests (ctest) # ========================================================= -include(CTest) if(BUILD_TESTING) enable_testing() + if(PTO_ENABLE_PYTHON_BINDING) + add_test( + NAME tilelang_dsl_import + COMMAND "${Python3_EXECUTABLE}" + "${CMAKE_CURRENT_SOURCE_DIR}/tilelang-dsl/tests/import_tilelang_dsl.py" + ) + set_tests_properties(tilelang_dsl_import PROPERTIES + ENVIRONMENT "PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}" + ) + add_test( + NAME tilelang_dsl_unittest + COMMAND "${Python3_EXECUTABLE}" -m unittest discover + -s "${CMAKE_CURRENT_SOURCE_DIR}/tilelang-dsl/tests" + -p "test_*.py" + ) + set_tests_properties(tilelang_dsl_unittest PROPERTIES + ENVIRONMENT "PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}" + ) + endif() add_subdirectory(tools/ptobc/tests) endif() diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 19ad88cd6..ca9ea233b 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -12,7 +12,13 @@ AnyInt, AnyMask, AnyType, + EVENT, + PIPE, + Event, MemorySpace, + MaskPattern, + PAT, + Pipe, ScalarType, TensorView, Tile, @@ -24,6 +30,7 @@ bf16, f16, f32, + i1, i8, i16, i32, @@ -43,8 +50,15 @@ "TensorView", "Tile", "MemorySpace", + "Pipe", + "Event", + "PIPE", + "EVENT", + "MaskPattern", + "PAT", "TileConfig", "TileSpecialization", + "i1", "i8", "i16", "i32", diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py new file mode 100644 index 000000000..5c875d797 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -0,0 +1,381 @@ +"""Frontend AST nodes for TileLang DSL descriptor materialization.""" + +from __future__ import annotations + +import ast +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class FrontendParameterNode: + name: str + kind: str + annotation: Any + dtype: Any + + +@dataclass(frozen=True) +class FrontendTileSpecializationNode: + name: str + shape: tuple[int, ...] + memory_space: str + config: Any + + +class FrontendExprNode: + """Base class for lowered frontend expressions.""" + + +@dataclass(frozen=True) +class FrontendNameExpr(FrontendExprNode): + name: str + + +@dataclass(frozen=True) +class FrontendConstantExpr(FrontendExprNode): + value: Any + + +@dataclass(frozen=True) +class FrontendSymbolExpr(FrontendExprNode): + namespace: str + name: str + + +@dataclass(frozen=True) +class FrontendSliceExpr(FrontendExprNode): + start: FrontendExprNode | None + stop: FrontendExprNode | None + step: FrontendExprNode | None + + +@dataclass(frozen=True) +class FrontendTupleExpr(FrontendExprNode): + elements: tuple[FrontendExprNode, ...] + + +@dataclass(frozen=True) +class FrontendAttributeExpr(FrontendExprNode): + base: FrontendExprNode + attr: str + + +@dataclass(frozen=True) +class FrontendSubscriptExpr(FrontendExprNode): + base: FrontendExprNode + index: FrontendExprNode + + +@dataclass(frozen=True) +class FrontendBinaryExpr(FrontendExprNode): + lhs: FrontendExprNode + op: str + rhs: FrontendExprNode + + +@dataclass(frozen=True) +class FrontendCallExpr(FrontendExprNode): + namespace: str | None + name: str + args: tuple[FrontendExprNode, ...] + + +class FrontendTargetNode: + """Base class for assignment targets.""" + + +@dataclass(frozen=True) +class FrontendNameTarget(FrontendTargetNode): + name: str + + +@dataclass(frozen=True) +class FrontendTupleTarget(FrontendTargetNode): + elements: tuple[FrontendNameTarget, ...] + + +class FrontendStmtNode: + """Base class for lowered frontend statements.""" + + +@dataclass(frozen=True) +class FrontendAssignStmt(FrontendStmtNode): + target: FrontendTargetNode + value: FrontendExprNode + annotation: Any | None = None + + +@dataclass(frozen=True) +class FrontendExprStmt(FrontendStmtNode): + expr: FrontendExprNode + + +@dataclass(frozen=True) +class FrontendReturnStmt(FrontendStmtNode): + value: FrontendExprNode | None + + +@dataclass(frozen=True) +class FrontendForStmt(FrontendStmtNode): + target: str + lower_bound: FrontendExprNode + upper_bound: FrontendExprNode + step: FrontendExprNode + body: tuple[FrontendStmtNode, ...] + + +@dataclass(frozen=True) +class FrontendIfStmt(FrontendStmtNode): + condition: FrontendExprNode + then_body: tuple[FrontendStmtNode, ...] + else_body: tuple[FrontendStmtNode, ...] + + +@dataclass(frozen=True) +class FrontendStrictVecscopeStmt(FrontendStmtNode): + captures: tuple[FrontendExprNode, ...] + block_arguments: tuple[str, ...] + body: tuple[FrontendStmtNode, ...] + + +@dataclass(frozen=True) +class FrontendKernelNode: + target: str + op: str + name: str + verify_enabled: bool + dtype_signature: tuple[Any, ...] + parameters: tuple[FrontendParameterNode, ...] + tile_specializations: tuple[FrontendTileSpecializationNode, ...] + body: tuple[FrontendStmtNode, ...] + + +_BINARY_OP_NAMES = { + ast.Add: "add", + ast.Sub: "sub", + ast.Mult: "mul", + ast.FloorDiv: "floordiv", +} + + +def _attribute_path(node: ast.AST) -> tuple[str, ...] | None: + if isinstance(node, ast.Name): + return (node.id,) + if isinstance(node, ast.Attribute): + base_path = _attribute_path(node.value) + if base_path is None: + return None + return base_path + (node.attr,) + return None + + +def _build_expr(node: ast.AST, source_info: Any) -> FrontendExprNode: + if isinstance(node, ast.Name): + return FrontendNameExpr(name=node.id) + if isinstance(node, ast.Constant): + return FrontendConstantExpr(value=node.value) + if isinstance(node, ast.Slice): + start = None if node.lower is None else _build_expr(node.lower, source_info) + stop = None if node.upper is None else _build_expr(node.upper, source_info) + step = None if node.step is None else _build_expr(node.step, source_info) + return FrontendSliceExpr(start=start, stop=stop, step=step) + if isinstance(node, ast.Tuple): + return FrontendTupleExpr( + elements=tuple(_build_expr(elt, source_info) for elt in node.elts) + ) + if isinstance(node, ast.Attribute): + path = _attribute_path(node) + if path is not None and path[0] in {"pto", "PAT", "PIPE", "EVENT"} and len(path) >= 2: + return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) + return FrontendAttributeExpr(base=_build_expr(node.value, source_info), attr=node.attr) + if isinstance(node, ast.Subscript): + return FrontendSubscriptExpr( + base=_build_expr(node.value, source_info), + index=_build_expr(node.slice, source_info), + ) + if isinstance(node, ast.BinOp): + op_name = _BINARY_OP_NAMES.get(type(node.op)) + if op_name is None: + raise source_info.error( + node, + f"unsupported binary operator `{type(node.op).__name__}` in TileLang DSL v1", + ) + return FrontendBinaryExpr( + lhs=_build_expr(node.left, source_info), + op=op_name, + rhs=_build_expr(node.right, source_info), + ) + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + return FrontendCallExpr( + namespace=None, + name=node.func.id, + args=tuple(_build_expr(arg, source_info) for arg in node.args), + ) + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + return FrontendCallExpr( + namespace=node.func.value.id, + name=node.func.attr, + args=tuple(_build_expr(arg, source_info) for arg in node.args), + ) + raise source_info.error( + node, + f"unsupported expression `{type(node).__name__}` in TileLang DSL v1", + ) + + +def _build_target(node: ast.AST, source_info: Any) -> FrontendTargetNode: + if isinstance(node, ast.Name): + return FrontendNameTarget(name=node.id) + if isinstance(node, ast.Tuple): + elements = [] + for elt in node.elts: + if not isinstance(elt, ast.Name): + raise source_info.error(elt, "tuple assignment only supports names in TileLang DSL v1") + elements.append(FrontendNameTarget(name=elt.id)) + return FrontendTupleTarget(elements=tuple(elements)) + raise source_info.error( + node, + f"unsupported assignment target `{type(node).__name__}` in TileLang DSL v1", + ) + + +def _build_stmt(node: ast.stmt, source_info: Any) -> FrontendStmtNode: + if isinstance(node, ast.Assign): + if len(node.targets) != 1: + raise source_info.error(node, "multiple assignment targets are not supported in TileLang DSL v1") + return FrontendAssignStmt( + target=_build_target(node.targets[0], source_info), + value=_build_expr(node.value, source_info), + ) + if isinstance(node, ast.AnnAssign): + if node.value is None: + raise source_info.error(node, "annotation-only assignments are not supported in TileLang DSL v1") + return FrontendAssignStmt( + target=_build_target(node.target, source_info), + value=_build_expr(node.value, source_info), + annotation=node.annotation, + ) + if isinstance(node, ast.Expr): + return FrontendExprStmt(expr=_build_expr(node.value, source_info)) + if isinstance(node, ast.Return): + value = None + if node.value is not None: + if not (isinstance(node.value, ast.Constant) and node.value.value is None): + value = _build_expr(node.value, source_info) + return FrontendReturnStmt(value=value) + if isinstance(node, ast.For): + if not isinstance(node.target, ast.Name): + raise source_info.error(node.target, "for target must be a single name") + if not isinstance(node.iter, ast.Call) or not isinstance(node.iter.func, ast.Name) or node.iter.func.id != "range": + raise source_info.error(node.iter, "only Python range(lb, ub, step) loops are supported") + if len(node.iter.args) != 3: + raise source_info.error(node.iter, "range() expects exactly 3 arguments in TileLang DSL v1") + return FrontendForStmt( + target=node.target.id, + lower_bound=_build_expr(node.iter.args[0], source_info), + upper_bound=_build_expr(node.iter.args[1], source_info), + step=_build_expr(node.iter.args[2], source_info), + body=tuple(_build_stmt(stmt, source_info) for stmt in node.body), + ) + if isinstance(node, ast.If): + return FrontendIfStmt( + condition=_build_expr(node.test, source_info), + then_body=tuple(_build_stmt(stmt, source_info) for stmt in node.body), + else_body=tuple(_build_stmt(stmt, source_info) for stmt in node.orelse), + ) + if isinstance(node, ast.With): + if len(node.items) != 1: + raise source_info.error(node, "only a single with-item is supported in TileLang DSL v1") + item = node.items[0] + if not isinstance(item.context_expr, ast.Call): + raise source_info.error(item.context_expr, "with context must be a call in TileLang DSL v1") + if not ( + isinstance(item.context_expr.func, ast.Attribute) + and isinstance(item.context_expr.func.value, ast.Name) + and item.context_expr.func.value.id == "pto" + and item.context_expr.func.attr == "strict_vecscope" + ): + raise source_info.error(item.context_expr, "only pto.strict_vecscope is supported in TileLang DSL v1") + if not isinstance(item.optional_vars, ast.Tuple): + raise source_info.error(item, "pto.strict_vecscope requires tuple binding in 'as'") + block_arguments = [] + for elt in item.optional_vars.elts: + if not isinstance(elt, ast.Name): + raise source_info.error(elt, "pto.strict_vecscope bindings must be names") + block_arguments.append(elt.id) + return FrontendStrictVecscopeStmt( + captures=tuple(_build_expr(arg, source_info) for arg in item.context_expr.args), + block_arguments=tuple(block_arguments), + body=tuple(_build_stmt(stmt, source_info) for stmt in node.body), + ) + raise source_info.error( + node, + f"unsupported statement `{type(node).__name__}` in TileLang DSL v1", + ) + + +def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: + """Project the core-foundation descriptor into a lowering-owned AST.""" + + parameters = tuple( + FrontendParameterNode( + name=param.name, + kind=param.kind, + annotation=param.annotation, + dtype=param.dtype, + ) + for param in descriptor.parameters + ) + tile_specializations = tuple( + FrontendTileSpecializationNode( + name=name, + shape=spec.shape, + memory_space=spec.memory_space.value, + config=spec.config, + ) + for name, spec in descriptor.specializations + ) + source_info = descriptor._source_info + body = () + if source_info is not None: + body = tuple(_build_stmt(stmt, source_info) for stmt in source_info.function_def.body) + return FrontendKernelNode( + target=descriptor.target, + op=descriptor.op, + name=descriptor.name, + verify_enabled=descriptor.verify_enabled, + dtype_signature=descriptor.dtype_signature, + parameters=parameters, + tile_specializations=tile_specializations, + body=body, + ) + + +__all__ = [ + "FrontendAssignStmt", + "FrontendAttributeExpr", + "FrontendBinaryExpr", + "FrontendCallExpr", + "FrontendConstantExpr", + "FrontendExprNode", + "FrontendExprStmt", + "FrontendForStmt", + "FrontendIfStmt", + "FrontendKernelNode", + "FrontendNameExpr", + "FrontendNameTarget", + "FrontendParameterNode", + "FrontendReturnStmt", + "FrontendSliceExpr", + "FrontendStrictVecscopeStmt", + "FrontendStmtNode", + "FrontendSubscriptExpr", + "FrontendSymbolExpr", + "FrontendTargetNode", + "FrontendTileSpecializationNode", + "FrontendTupleExpr", + "FrontendTupleTarget", + "build_frontend_kernel_node", +] diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 621b19e80..f27d00e77 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -3,7 +3,6 @@ from __future__ import annotations import inspect -import re import textwrap import ast from dataclasses import dataclass, field @@ -20,10 +19,46 @@ TypeVariable, WildcardType, ) +from .frontend_ast import build_frontend_kernel_node +from .lowering import lower_semantic_kernel +from .semantic import analyze_frontend_kernel _UNSET = object() _MATCHER_FOLLOW_UP_CHANGE = "extend-tilelang-dsl-matcher-and-advanced-surface" +_V1_ALLOWED_TOPLEVEL_PTO_CALLS = { + "strict_vecscope", + "dma_load", + "dma_store", + "set_flag", + "wait_flag", + "pipe_barrier", + "barrier", +} +_V1_ALLOWED_VECSCOPE_PTO_CALLS = { + "make_mask", + "vlds", + "vsts", + "vabs", + "vrelu", + "vexp", + "vnot", + "vadd", + "vsub", + "vmul", + "vdiv", + "vmax", + "vmin", + "vand", + "vor", + "vxor", + "vadds", + "vsubs", + "vmuls", + "vdivs", + "vmaxs", + "vmins", +} def _unsupported_feature_message(feature: str) -> str: @@ -86,6 +121,7 @@ def parameter_node(self, param_name: str) -> ast.AST | None: class _KernelBodyValidator(ast.NodeVisitor): def __init__(self, source_info: _FunctionSourceInfo): self.source_info = source_info + self._vecscope_depth = 0 def validate(self) -> None: for stmt in self.source_info.function_def.body: @@ -114,8 +150,65 @@ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None: node, "unsupported Python syntax `generator expression` in TileLang DSL v1" ) + def visit_For(self, node: ast.For) -> None: + if not isinstance(node.target, ast.Name): + raise self.source_info.error(node.target, "for target must be a single name") + if not isinstance(node.iter, ast.Call) or not isinstance(node.iter.func, ast.Name): + raise self.source_info.error(node.iter, "only Python range(lb, ub, step) loops are supported") + if node.iter.func.id != "range": + raise self.source_info.error(node.iter, "only Python range(lb, ub, step) loops are supported") + if len(node.iter.args) != 3: + raise self.source_info.error(node.iter, "range() expects exactly 3 arguments in TileLang DSL v1") + for stmt in node.body: + self.visit(stmt) + for stmt in node.orelse: + self.visit(stmt) + + def visit_If(self, node: ast.If) -> None: + for stmt in node.body: + self.visit(stmt) + for stmt in node.orelse: + self.visit(stmt) + + def visit_With(self, node: ast.With) -> None: + if len(node.items) != 1: + raise self.source_info.error(node, "only single with item is supported in TileLang DSL v1") + item = node.items[0] + if not isinstance(item.context_expr, ast.Call): + raise self.source_info.error(item.context_expr, "with context must be a call in TileLang DSL v1") + if not ( + isinstance(item.context_expr.func, ast.Attribute) + and isinstance(item.context_expr.func.value, ast.Name) + and item.context_expr.func.value.id == "pto" + and item.context_expr.func.attr == "strict_vecscope" + ): + raise self.source_info.error( + item.context_expr, + "only pto.strict_vecscope is supported as a with-context in TileLang DSL v1", + ) + if not isinstance(item.optional_vars, ast.Tuple): + raise self.source_info.error(item, "pto.strict_vecscope requires tuple binding in 'as'") + for elt in item.optional_vars.elts: + if not isinstance(elt, ast.Name): + raise self.source_info.error(elt, "pto.strict_vecscope bindings must be names") + self._vecscope_depth += 1 + try: + for stmt in node.body: + self.visit(stmt) + finally: + self._vecscope_depth -= 1 + def visit_Call(self, node: ast.Call) -> None: if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.value.id == "pto" and node.func.attr in _V1_ALLOWED_TOPLEVEL_PTO_CALLS: + return + if node.func.value.id == "pto" and node.func.attr in _V1_ALLOWED_VECSCOPE_PTO_CALLS: + if self._vecscope_depth <= 0: + raise self.source_info.error( + node, + f"vector op surface `pto.{node.func.attr}` requires explicit pto.strict_vecscope in TileLang DSL v1", + ) + return if node.func.value.id == "pto": raise self.source_info.error( node, @@ -128,6 +221,8 @@ def visit_Call(self, node: ast.Call) -> None: ) if isinstance(node.func, ast.Name): + if node.func.id == "range": + return raise self.source_info.error( node, f"arbitrary external call `{node.func.id}` is not supported in TileLang DSL v1", @@ -302,38 +397,14 @@ def _require_specialized_tiles(self, api_name: str) -> None: f"{missing_names}", ) - def _format_symbol_name(self) -> str: - if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$.]*", self.name): - return f"@{self.name}" - escaped = self.name.replace("\\", "\\\\").replace('"', '\\"') - return f'@"{escaped}"' + def _build_authoring_module(self): + frontend_kernel = build_frontend_kernel_node(self) + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + return lower_semantic_kernel(semantic_kernel) def mlir_text(self) -> str: self._require_specialized_tiles("mlir_text") - - lines = [ - f"// tilelang.target = {self.target}", - f"// tilelang.op = {self.op}", - f"// tilelang.dtypes = {self.dtypes}", - f"// tilelang.verify = {self.verify_enabled}", - ] - for name, spec in self.specializations: - lines.append( - "// tilelang.specialize " - f"{name} shape={spec.shape} memory_space={spec.memory_space.value} " - f"config={spec.config}" - ) - lines.extend( - [ - "module {", - f" func.func {self._format_symbol_name()}() {{", - " return", - " }", - "}", - "", - ] - ) - return "\n".join(lines) + return self._build_authoring_module().render() def mlir_module(self) -> "MaterializedMLIRModule": self._require_specialized_tiles("mlir_module") diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py new file mode 100644 index 000000000..04b30ba5e --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -0,0 +1,719 @@ +"""Authoring-form VPTO lowering skeleton for TileLang DSL v1.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +from .semantic import ( + SemanticAssignStmt, + SemanticAttributeAccess, + SemanticBinaryExpr, + SemanticBindingRef, + SemanticCallExpr, + SemanticDmaLoadStmt, + SemanticDmaStoreStmt, + SemanticExpr, + SemanticExprStmt, + SemanticForStmt, + SemanticIfStmt, + SemanticIndexType, + SemanticIfResult, + SemanticKernel, + SemanticLiteralExpr, + SemanticMaskType, + SemanticPipeBarrierStmt, + SemanticReturnStmt, + SemanticScalarType, + SemanticSetFlagStmt, + SemanticStmt, + SemanticStrictVecscopeStmt, + SemanticSubscriptAccess, + SemanticSymbolExpr, + SemanticTensorSliceExpr, + SemanticTensorViewType, + SemanticTileType, + SemanticType, + SemanticVRegType, + SemanticVectorStoreStmt, + SemanticWaitFlagStmt, +) +from .types import MaskPattern, ScalarType + + +_I1_TYPE = SemanticScalarType(dtype=ScalarType("i1")) +_I64_TYPE = SemanticScalarType(dtype=ScalarType("i64")) + + +def _format_symbol_name(symbol_name: str) -> str: + if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$.]*", symbol_name): + return f"@{symbol_name}" + escaped = symbol_name.replace("\\", "\\\\").replace('"', '\\"') + return f'@"{escaped}"' + + +@dataclass(frozen=True) +class AuthoringModule: + """Lowering result that owns authoring-form VPTO text emission.""" + + kernel: SemanticKernel + + def render(self) -> str: + return _AuthoringRenderer(self.kernel).render() + + +@dataclass(frozen=True) +class _RenderedValue: + name: str + type: SemanticType + + +class _AuthoringRenderer: + def __init__(self, kernel: SemanticKernel): + self.kernel = kernel + self._constant_lines: list[str] = [] + self._constant_cache: dict[tuple[str, object], str] = {} + self._temp_counter = 0 + self._loop_counter = 0 + + def render(self) -> str: + parameter_list = ", ".join( + f"{param.ssa_name}: {self._render_type(param.type)}" + for param in self.kernel.parameters + ) + env = { + param.name: _RenderedValue(name=param.ssa_name, type=param.type) + for param in self.kernel.parameters + } + body_lines = self._render_block(self.kernel.body, env, indent=4) + + lines = [ + f"// tilelang.target = {self.kernel.target}", + f"// tilelang.op = {self.kernel.op}", + f"// tilelang.dtypes = {self.kernel.dtype_signature}", + f"// tilelang.verify = {self.kernel.verify_enabled}", + ] + for binding in self.kernel.tile_bindings: + lines.append( + "// tilelang.specialize " + f"{binding.name} shape={binding.shape} memory_space={binding.memory_space} " + f"config={binding.config}" + ) + lines.append(f'module attributes {{pto.target_arch = "{self.kernel.target}"}} {{') + lines.append( + f" func.func {_format_symbol_name(self.kernel.symbol_name)}({parameter_list}) {{" + ) + lines.extend(self._constant_lines) + lines.extend(body_lines) + lines.append(" }") + lines.append("}") + lines.append("") + return "\n".join(lines) + + def _render_block( + self, + statements: tuple[SemanticStmt, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + for stmt in statements: + lines.extend(self._render_stmt(stmt, env, indent=indent)) + return lines + + def _render_stmt( + self, + stmt: SemanticStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + if isinstance(stmt, SemanticAssignStmt): + return self._render_assign(stmt, env, indent=indent) + if isinstance(stmt, SemanticExprStmt): + self._lower_expr(stmt.expr, env, indent=indent) + return [] + if isinstance(stmt, SemanticDmaLoadStmt): + return self._render_dma_load(stmt, env, indent=indent) + if isinstance(stmt, SemanticDmaStoreStmt): + return self._render_dma_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticVectorStoreStmt): + return self._render_vector_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticSetFlagStmt): + return [ + self._indent(indent) + + f'pto.set_flag["{stmt.src_pipe}", "{stmt.dst_pipe}", "{stmt.event}"]' + ] + if isinstance(stmt, SemanticWaitFlagStmt): + return [ + self._indent(indent) + + f'pto.wait_flag["{stmt.src_pipe}", "{stmt.dst_pipe}", "{stmt.event}"]' + ] + if isinstance(stmt, SemanticPipeBarrierStmt): + return [self._indent(indent) + f"pto.barrier #pto.pipe<{stmt.pipe}>"] + if isinstance(stmt, SemanticReturnStmt): + if stmt.value is None: + return [self._indent(indent) + "return"] + value = self._lower_expr(stmt.value, env, indent=indent) + return [self._indent(indent) + f"return {value.name} : {self._render_type(value.type)}"] + if isinstance(stmt, SemanticStrictVecscopeStmt): + return self._render_strict_vecscope(stmt, env, indent=indent) + if isinstance(stmt, SemanticForStmt): + return self._render_for(stmt, env, indent=indent) + if isinstance(stmt, SemanticIfStmt): + return self._render_if(stmt, env, indent=indent) + raise ValueError(f"unsupported semantic statement {type(stmt).__name__}") + + def _render_assign( + self, + stmt: SemanticAssignStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + if len(stmt.targets) != 1: + raise NotImplementedError("multiple-result assignment is not supported in TileLang DSL v1 yet") + target = stmt.targets[0] + lines: list[str] = [] + lowered = self._lower_expr( + stmt.value, + env, + indent=indent, + desired_name=target.ssa_name, + into=lines, + ) + env[target.name] = lowered + return lines + + def _render_dma_load( + self, + stmt: SemanticDmaLoadStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + src = self._lower_expr(stmt.src.base, env, indent=indent) + dst = self._lower_expr(stmt.dst, env, indent=indent) + row_count, col_count = self._tensor_slice_extents(stmt.src) + element_bytes = self._dtype_byte_width(stmt.src.type.element_dtype) + burst_bytes = col_count * element_bytes + + c0_i64 = self._materialize_constant(0, _I64_TYPE) + c1_i64 = self._materialize_constant(1, _I64_TYPE) + n_burst = self._materialize_constant(row_count, _I64_TYPE) + len_burst = self._materialize_constant(burst_bytes, _I64_TYPE) + false_bit = self._materialize_constant(False, _I1_TYPE) + + return [ + self._indent(indent) + + f"pto.set_loop_size_outtoub {c1_i64}, {c1_i64} : i64, i64", + self._indent(indent) + + "pto.copy_gm_to_ubuf " + + f"{src.name}, {dst.name}, {c0_i64}, {n_burst}, {len_burst}, {c0_i64}, {c0_i64}, " + + f"{false_bit}, {c0_i64}, {len_burst}, {len_burst} : " + + f"{self._render_type(src.type)}, {self._render_type(dst.type)}, " + + "i64, i64, i64, i64, i64, i1, i64, i64, i64", + ] + + def _render_dma_store( + self, + stmt: SemanticDmaStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + src = self._lower_expr(stmt.src, env, indent=indent) + dst = self._lower_expr(stmt.dst.base, env, indent=indent) + row_count, col_count = self._tensor_slice_extents(stmt.dst) + element_bytes = self._dtype_byte_width(stmt.dst.type.element_dtype) + burst_bytes = col_count * element_bytes + + c0_i64 = self._materialize_constant(0, _I64_TYPE) + c1_i64 = self._materialize_constant(1, _I64_TYPE) + n_burst = self._materialize_constant(row_count, _I64_TYPE) + len_burst = self._materialize_constant(burst_bytes, _I64_TYPE) + + return [ + self._indent(indent) + + f"pto.set_loop_size_ubtoout {c1_i64}, {c1_i64} : i64, i64", + self._indent(indent) + + "pto.copy_ubuf_to_gm " + + f"{src.name}, {dst.name}, {c0_i64}, {n_burst}, {len_burst}, {c0_i64}, " + + f"{len_burst}, {len_burst} : {self._render_type(src.type)}, {self._render_type(dst.type)}, " + + "i64, i64, i64, i64, i64, i64", + ] + + def _render_vector_store( + self, + stmt: SemanticVectorStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + value = self._lower_expr(stmt.value, env, indent=indent) + destination = self._lower_expr(stmt.destination, env, indent=indent) + offset = self._lower_expr(stmt.offset, env, indent=indent) + mask = self._lower_expr(stmt.mask, env, indent=indent) + return [ + self._indent(indent) + + "pto.vsts " + + f"{value.name}, {destination.name}[{offset.name}], {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(mask.type)}" + ] + + def _tensor_slice_extents(self, expr: SemanticTensorSliceExpr) -> tuple[int, int]: + if expr.type.rank != 2 or len(expr.type.extents) != 2: + raise NotImplementedError("TileLang DSL v1 DMA lowering currently only supports rank-2 TensorView slices") + return expr.type.extents + + def _render_strict_vecscope( + self, + stmt: SemanticStrictVecscopeStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + capture_values = [self._lower_expr(expr, env, indent=indent) for expr in stmt.captures] + capture_names = ", ".join(value.name for value in capture_values) + block_args = ", ".join( + f"{binding.ssa_name}: {self._render_type(binding.type)}" + for binding in stmt.block_arguments + ) + function_type = ", ".join( + self._render_type(binding.type) for binding in stmt.block_arguments + ) + + scope_env = { + binding.name: _RenderedValue(name=binding.ssa_name, type=binding.type) + for binding in stmt.block_arguments + } + + lines = [self._indent(indent) + f"pto.strict_vecscope({capture_names}) {{"] + lines.append(self._indent(indent) + f"^bb0({block_args}):") + lines.extend(self._render_block(stmt.body, scope_env, indent=indent + 2)) + lines.append(self._indent(indent) + f"}} : ({function_type}) -> ()") + return lines + + def _render_for( + self, + stmt: SemanticForStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lower_bound = self._lower_expr(stmt.lower_bound, env, indent=indent) + upper_bound = self._lower_expr(stmt.upper_bound, env, indent=indent) + step = self._lower_expr(stmt.step, env, indent=indent) + + body_env = dict(env) + body_env[stmt.induction_variable.name] = _RenderedValue( + name=stmt.induction_variable.ssa_name, + type=stmt.induction_variable.type, + ) + + if not stmt.loop_carried: + lines = [ + self._indent(indent) + + f"scf.for {stmt.induction_variable.ssa_name} = {lower_bound.name} " + f"to {upper_bound.name} step {step.name} {{" + ] + lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) + lines.append(self._indent(indent) + "}") + return lines + + if len(stmt.loop_carried) != 1: + raise NotImplementedError( + "TileLang DSL v1 lowering currently supports at most one loop-carried binding" + ) + + carried_binding = stmt.loop_carried[0] + initial_value = env[carried_binding.name] + iter_arg_name = f"%{carried_binding.name}_iter_{self._loop_counter}" + self._loop_counter += 1 + body_env[carried_binding.name] = _RenderedValue( + name=iter_arg_name, + type=carried_binding.type, + ) + + lines = [ + self._indent(indent) + + f"{carried_binding.ssa_name}:1 = scf.for {stmt.induction_variable.ssa_name} = " + f"{lower_bound.name} to {upper_bound.name} step {step.name} " + f"iter_args({iter_arg_name} = {initial_value.name}) -> " + f"({self._render_type(carried_binding.type)}) {{" + ] + lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) + yielded_value = body_env[carried_binding.name] + lines.append( + self._indent(indent + 2) + + f"scf.yield {yielded_value.name} : {self._render_type(yielded_value.type)}" + ) + lines.append(self._indent(indent) + "}") + env[carried_binding.name] = _RenderedValue( + name=carried_binding.ssa_name, + type=carried_binding.type, + ) + return lines + + def _render_if( + self, + stmt: SemanticIfStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + cond_lines: list[str] = [] + condition = self._lower_condition(stmt.condition, env, indent=indent, into=cond_lines) + then_env = dict(env) + else_env = dict(env) + + if not stmt.results: + lines = list(cond_lines) + lines.append(self._indent(indent) + f"scf.if {condition.name} {{") + lines.extend(self._render_block(stmt.then_body, then_env, indent=indent + 2)) + if stmt.else_body: + lines.append(self._indent(indent) + "} else {") + lines.extend(self._render_block(stmt.else_body, else_env, indent=indent + 2)) + lines.append(self._indent(indent) + "}") + return lines + + if len(stmt.results) != 1: + raise NotImplementedError( + "TileLang DSL v1 lowering currently supports at most one merged if/else binding" + ) + + result = stmt.results[0] + lines = list(cond_lines) + lines.append( + self._indent(indent) + + f"{result.result_binding.ssa_name} = scf.if {condition.name} -> " + + f"({self._render_type(result.result_binding.type)}) {{" + ) + lines.extend(self._render_block(stmt.then_body, then_env, indent=indent + 2)) + then_value = then_env.get(result.result_binding.name, then_env.get(result.then_binding.name)) + if then_value is None: + then_value = _RenderedValue(result.then_binding.ssa_name, result.then_binding.type) + lines.append( + self._indent(indent + 2) + + f"scf.yield {then_value.name} : {self._render_type(then_value.type)}" + ) + lines.append(self._indent(indent) + "} else {") + lines.extend(self._render_block(stmt.else_body, else_env, indent=indent + 2)) + else_value = else_env.get(result.result_binding.name, else_env.get(result.else_binding.name)) + if else_value is None: + else_value = _RenderedValue(result.else_binding.ssa_name, result.else_binding.type) + lines.append( + self._indent(indent + 2) + + f"scf.yield {else_value.name} : {self._render_type(else_value.type)}" + ) + lines.append(self._indent(indent) + "}") + env[result.result_binding.name] = _RenderedValue( + name=result.result_binding.ssa_name, + type=result.result_binding.type, + ) + return lines + + def _lower_condition( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent) + if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i1": + return value + + zero_type: SemanticType + predicate: str + if isinstance(value.type, SemanticIndexType): + zero_type = SemanticIndexType() + predicate = "arith.cmpi ne" + elif isinstance(value.type, SemanticScalarType): + zero_type = value.type + if value.type.dtype.name in {"f16", "bf16", "f32"}: + predicate = "arith.cmpf une" + else: + predicate = "arith.cmpi ne" + else: + raise NotImplementedError(f"unsupported if condition type {value.type!r}") + + zero = self._materialize_constant(0, zero_type) + result_name = self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = {predicate}, {value.name}, {zero} : {self._render_type(value.type)}" + ) + return _RenderedValue(name=result_name, type=_I1_TYPE) + + def _lower_expr( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None = None, + into: list[str] | None = None, + ) -> _RenderedValue: + if isinstance(expr, SemanticBindingRef): + return env.get(expr.binding.name, _RenderedValue(expr.binding.ssa_name, expr.type)) + if isinstance(expr, SemanticLiteralExpr): + if desired_name is not None and into is not None: + into.append( + self._indent(indent) + + f"{desired_name} = arith.constant {self._format_constant(expr.value, expr.type)} : " + f"{self._render_type(expr.type)}" + ) + return _RenderedValue(name=desired_name, type=expr.type) + return _RenderedValue( + name=self._materialize_constant(expr.value, expr.type), + type=expr.type, + ) + if isinstance(expr, SemanticSubscriptAccess): + if desired_name is not None and into is not None: + value = self._extract_static_subscript_value(expr, env) + into.append( + self._indent(indent) + + f"{desired_name} = arith.constant {self._format_constant(value, expr.type)} : " + f"{self._render_type(expr.type)}" + ) + return _RenderedValue(name=desired_name, type=expr.type) + constant_name = self._lower_static_subscript(expr, env) + return _RenderedValue(name=constant_name, type=expr.type) + if isinstance(expr, SemanticBinaryExpr): + lhs = self._lower_expr(expr.lhs, env, indent=indent) + rhs = self._lower_expr(expr.rhs, env, indent=indent) + if into is None: + into = [] + result_name = desired_name or self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = {self._render_binary_op(expr.op, expr.type)} " + f"{lhs.name}, {rhs.name} : {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if isinstance(expr, SemanticCallExpr): + return self._lower_call_expr(expr, env, indent=indent, desired_name=desired_name, into=into) + if isinstance(expr, SemanticAttributeAccess): + raise NotImplementedError("bare shape attribute values are not materialized directly") + if isinstance(expr, SemanticTensorSliceExpr): + raise NotImplementedError("TensorView slices are only lowered through DMA statements in TileLang DSL v1") + if isinstance(expr, SemanticSymbolExpr): + raise NotImplementedError("symbol expressions are only lowered through specialized TileLang DSL ops") + raise NotImplementedError(f"unsupported semantic expression {type(expr).__name__}") + + def _lower_call_expr( + self, + expr: SemanticCallExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None, + into: list[str] | None, + ) -> _RenderedValue: + if expr.namespace != "pto": + raise NotImplementedError(f"unsupported call namespace {expr.namespace!r}") + if into is None: + into = [] + result_name = desired_name or self._new_temp() + + if expr.name == "make_mask": + dtype_expr, pattern_expr = expr.args + if not isinstance(dtype_expr, SemanticSymbolExpr): + raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") + if not isinstance(pattern_expr, SemanticSymbolExpr) or not isinstance(pattern_expr.value, MaskPattern): + raise NotImplementedError("make_mask pattern lowering expects a MaskPattern symbol") + suffix = expr.type.granularity + into.append( + self._indent(indent) + + f'{result_name} = pto.pset_{suffix} "{pattern_expr.value.value}" : {self._render_type(expr.type)}' + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vlds": + source = self._lower_expr(expr.args[0], env, indent=indent) + offset = self._lower_expr(expr.args[1], env, indent=indent) + into.append( + self._indent(indent) + + f"{result_name} = pto.vlds {source.name}[{offset.name}] : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vabs", "vrelu", "vexp", "vnot"}: + value = self._lower_expr(expr.args[0], env, indent=indent) + mask = self._lower_expr(expr.args[1], env, indent=indent) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {value.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", "vand", "vor", "vxor"}: + lhs = self._lower_expr(expr.args[0], env, indent=indent) + rhs = self._lower_expr(expr.args[1], env, indent=indent) + mask = self._lower_expr(expr.args[2], env, indent=indent) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {lhs.name}, {rhs.name}, {mask.name} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins"}: + value = self._lower_expr(expr.args[0], env, indent=indent) + scalar = self._lower_expr(expr.args[1], env, indent=indent) + mask = self._lower_expr(expr.args[2], env, indent=indent) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {value.name}, {scalar.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(scalar.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + raise NotImplementedError(f"unsupported pto call `{expr.name}` in lowering") + + def _lower_static_subscript( + self, + expr: SemanticSubscriptAccess, + env: dict[str, _RenderedValue], + ) -> str: + value = self._extract_static_subscript_value(expr, env) + return self._materialize_constant(value, expr.type) + + def _extract_static_subscript_value( + self, + expr: SemanticSubscriptAccess, + env: dict[str, _RenderedValue], + ) -> int: + if not isinstance(expr.base, SemanticAttributeAccess): + raise NotImplementedError("only shape indexing is supported in TileLang DSL v1 lowering") + if expr.base.attr != "shape": + raise NotImplementedError("only `.shape[...]` indexing is supported in TileLang DSL v1 lowering") + if not isinstance(expr.index, SemanticLiteralExpr) or not isinstance(expr.index.value, int): + raise NotImplementedError("shape indices must be integer literals in TileLang DSL v1 lowering") + if not isinstance(expr.base.base, SemanticBindingRef): + raise NotImplementedError("shape indexing expects a bound TensorView or Tile value") + + base_binding = expr.base.base.binding + base_value = env.get(base_binding.name, _RenderedValue(base_binding.ssa_name, base_binding.type)) + base_type = base_value.type + index = expr.index.value + + if isinstance(base_type, SemanticTileType): + if base_type.shape is None: + raise NotImplementedError("dynamic Tile shapes are not supported in TileLang DSL v1 lowering") + return base_type.shape[index] + + if isinstance(base_type, SemanticTensorViewType): + raise NotImplementedError( + "dynamic TensorView shape materialization is not implemented in TileLang DSL v1 lowering yet" + ) + + raise NotImplementedError("shape indexing expects a Tile or TensorView operand") + + def _materialize_constant(self, value: object, ty: SemanticType) -> str: + cache_key = (self._render_type(ty), value) + if cache_key in self._constant_cache: + return self._constant_cache[cache_key] + + name = self._constant_name(value, ty) + self._constant_cache[cache_key] = name + self._constant_lines.append( + self._indent(4) + + f"{name} = arith.constant {self._format_constant(value, ty)} : {self._render_type(ty)}" + ) + return name + + def _constant_name(self, value: object, ty: SemanticType) -> str: + if isinstance(ty, SemanticIndexType): + stem = f"c{value}" + elif isinstance(ty, SemanticScalarType): + if ty.dtype.name == "i1" and isinstance(value, bool): + stem = "true" if value else "false" + else: + stem = f"c{value}_{ty.dtype.name}" + else: + stem = "cst" + name = f"%{stem}" + existing = {line.split(" = ", 1)[0].strip() for line in self._constant_lines} + if name not in existing: + return name + suffix = 0 + while f"{name}_{suffix}" in existing: + suffix += 1 + return f"{name}_{suffix}" + + def _format_constant(self, value: object, ty: SemanticType) -> str: + if isinstance(ty, SemanticIndexType): + return str(value) + if isinstance(ty, SemanticScalarType): + if ty.dtype.name == "i1" and isinstance(value, bool): + return "true" if value else "false" + return str(value) + raise NotImplementedError(f"unsupported constant type {ty!r}") + + def _render_binary_op(self, op: str, ty: SemanticType) -> str: + if isinstance(ty, (SemanticIndexType, SemanticScalarType)): + if op == "add": + return "arith.addi" + if op == "sub": + return "arith.subi" + if op == "mul": + return "arith.muli" + if op == "floordiv": + return "arith.floordivsi" + raise NotImplementedError(f"unsupported binary op '{op}' for type {ty!r}") + + def _render_type(self, ty: SemanticType) -> str: + if isinstance(ty, SemanticIndexType): + return "index" + if isinstance(ty, SemanticScalarType): + return ty.dtype.name + if isinstance(ty, SemanticTensorViewType): + return f"!pto.ptr<{ty.element_dtype.name}, gm>" + if isinstance(ty, SemanticTileType): + memory_space = ty.memory_space or "ub" + return f"!pto.ptr<{ty.element_dtype.name}, {memory_space}>" + if isinstance(ty, SemanticMaskType): + return f"!pto.mask<{ty.granularity}>" + if isinstance(ty, SemanticVRegType): + return f"!pto.vreg<{ty.lanes}x{ty.element_dtype.name}>" + raise NotImplementedError(f"unsupported semantic type {ty!r}") + + def _dtype_byte_width(self, dtype: ScalarType) -> int: + widths = { + "i8": 1, + "i16": 2, + "i32": 4, + "i64": 8, + "f16": 2, + "bf16": 2, + "f32": 4, + } + width = widths.get(dtype.name) + if width is None: + raise NotImplementedError(f"unsupported DMA dtype '{dtype.name}' in TileLang DSL v1 lowering") + return width + + def _indent(self, indent: int) -> str: + return " " * indent + + def _new_temp(self) -> str: + name = f"%tmp_{self._temp_counter}" + self._temp_counter += 1 + return name + + +def lower_semantic_kernel(kernel: SemanticKernel) -> AuthoringModule: + """Lower the semantic model to the current authoring-form VPTO builder.""" + + return AuthoringModule(kernel=kernel) + + +__all__ = ["AuthoringModule", "lower_semantic_kernel"] diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py new file mode 100644 index 000000000..6ec2cb168 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -0,0 +1,1265 @@ +"""Semantic model for TileLang DSL descriptor lowering.""" + +from __future__ import annotations + +import ast +from dataclasses import dataclass +from typing import Any + +from .frontend_ast import ( + FrontendAssignStmt, + FrontendAttributeExpr, + FrontendBinaryExpr, + FrontendCallExpr, + FrontendConstantExpr, + FrontendExprNode, + FrontendExprStmt, + FrontendForStmt, + FrontendIfStmt, + FrontendKernelNode, + FrontendNameExpr, + FrontendNameTarget, + FrontendReturnStmt, + FrontendSliceExpr, + FrontendStrictVecscopeStmt, + FrontendStmtNode, + FrontendSubscriptExpr, + FrontendSymbolExpr, + FrontendTargetNode, + FrontendTupleExpr, + FrontendTupleTarget, +) +from .types import Event, MaskPattern, Pipe, ScalarType, bf16, f16, f32, i1, i8, i16, i32 + + +_DTYPE_SYMBOLS = { + "i1": i1, + "i8": i8, + "i16": i16, + "i32": i32, + "f16": f16, + "bf16": bf16, + "f32": f32, +} +_PATTERN_SYMBOLS = {pattern.name: pattern for pattern in MaskPattern} +_PIPE_SYMBOLS = {pipe.name: pipe for pipe in Pipe} +_EVENT_SYMBOLS = {event.name: event for event in Event} +_UNARY_VECTOR_OPS = {"vabs", "vrelu", "vexp", "vnot"} +_BINARY_VECTOR_OPS = {"vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", "vand", "vor", "vxor"} +_VECTOR_SCALAR_OPS = {"vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins"} + + +class SemanticType: + """Base class for semantic value types.""" + + +@dataclass(frozen=True) +class SemanticTensorViewType(SemanticType): + element_dtype: ScalarType + rank: int = 2 + + +@dataclass(frozen=True) +class SemanticTensorSliceType(SemanticType): + element_dtype: ScalarType + rank: int + extents: tuple[int, ...] + + +@dataclass(frozen=True) +class SemanticTileType(SemanticType): + element_dtype: ScalarType + rank: int + shape: tuple[int, ...] | None + memory_space: str | None + + +@dataclass(frozen=True) +class SemanticScalarType(SemanticType): + dtype: ScalarType + + +@dataclass(frozen=True) +class SemanticIndexType(SemanticType): + pass + + +@dataclass(frozen=True) +class SemanticShapeType(SemanticType): + rank: int + + +@dataclass(frozen=True) +class SemanticSliceType(SemanticType): + pass + + +@dataclass(frozen=True) +class SemanticTupleType(SemanticType): + elements: tuple[SemanticType, ...] + + +@dataclass(frozen=True) +class SemanticMetaType(SemanticType): + kind: str + + +@dataclass(frozen=True) +class SemanticMaskType(SemanticType): + granularity: str + + +@dataclass(frozen=True) +class SemanticVRegType(SemanticType): + element_dtype: ScalarType + lanes: int + + +@dataclass(frozen=True) +class SemanticBinding: + name: str + ssa_name: str + type: SemanticType + origin: str + + +@dataclass(frozen=True) +class SemanticTileBinding: + name: str + shape: tuple[int, ...] + memory_space: str + config: Any + + +class SemanticExpr: + """Base class for typed semantic expressions.""" + + +@dataclass(frozen=True) +class SemanticBindingRef(SemanticExpr): + binding: SemanticBinding + type: SemanticType + + +@dataclass(frozen=True) +class SemanticLiteralExpr(SemanticExpr): + value: Any + type: SemanticType + + +@dataclass(frozen=True) +class SemanticSymbolExpr(SemanticExpr): + namespace: str + name: str + value: Any + type: SemanticMetaType + + +@dataclass(frozen=True) +class SemanticSliceExpr(SemanticExpr): + start: SemanticExpr | None + stop: SemanticExpr | None + step: SemanticExpr | None + type: SemanticSliceType + + +@dataclass(frozen=True) +class SemanticTupleExpr(SemanticExpr): + elements: tuple[SemanticExpr, ...] + type: SemanticTupleType + + +@dataclass(frozen=True) +class SemanticAttributeAccess(SemanticExpr): + base: SemanticExpr + attr: str + type: SemanticType + + +@dataclass(frozen=True) +class SemanticSubscriptAccess(SemanticExpr): + base: SemanticExpr + index: SemanticExpr + type: SemanticType + + +@dataclass(frozen=True) +class SemanticTensorSliceExpr(SemanticExpr): + base: SemanticExpr + slices: tuple[SemanticSliceExpr, ...] + type: SemanticTensorSliceType + + +@dataclass(frozen=True) +class SemanticBinaryExpr(SemanticExpr): + lhs: SemanticExpr + op: str + rhs: SemanticExpr + type: SemanticType + + +@dataclass(frozen=True) +class SemanticCallExpr(SemanticExpr): + namespace: str | None + name: str + args: tuple[SemanticExpr, ...] + type: SemanticType | None + + +class SemanticStmt: + """Base class for semantic statements.""" + + +@dataclass(frozen=True) +class SemanticAssignStmt(SemanticStmt): + targets: tuple[SemanticBinding, ...] + value: SemanticExpr + annotation: Any | None = None + + +@dataclass(frozen=True) +class SemanticExprStmt(SemanticStmt): + expr: SemanticExpr + + +@dataclass(frozen=True) +class SemanticDmaLoadStmt(SemanticStmt): + src: SemanticTensorSliceExpr + dst: SemanticExpr + + +@dataclass(frozen=True) +class SemanticDmaStoreStmt(SemanticStmt): + src: SemanticExpr + dst: SemanticTensorSliceExpr + + +@dataclass(frozen=True) +class SemanticVectorStoreStmt(SemanticStmt): + value: SemanticExpr + destination: SemanticExpr + offset: SemanticExpr + mask: SemanticExpr + + +@dataclass(frozen=True) +class SemanticSetFlagStmt(SemanticStmt): + src_pipe: str + dst_pipe: str + event: str + + +@dataclass(frozen=True) +class SemanticWaitFlagStmt(SemanticStmt): + src_pipe: str + dst_pipe: str + event: str + + +@dataclass(frozen=True) +class SemanticPipeBarrierStmt(SemanticStmt): + pipe: str + + +@dataclass(frozen=True) +class SemanticIfResult: + result_binding: SemanticBinding + then_binding: SemanticBinding + else_binding: SemanticBinding + + +@dataclass(frozen=True) +class SemanticIfStmt(SemanticStmt): + condition: SemanticExpr + then_body: tuple[SemanticStmt, ...] + else_body: tuple[SemanticStmt, ...] + results: tuple[SemanticIfResult, ...] + + +@dataclass(frozen=True) +class SemanticReturnStmt(SemanticStmt): + value: SemanticExpr | None + + +@dataclass(frozen=True) +class SemanticForStmt(SemanticStmt): + induction_variable: SemanticBinding + lower_bound: SemanticExpr + upper_bound: SemanticExpr + step: SemanticExpr + body: tuple[SemanticStmt, ...] + loop_carried: tuple[SemanticBinding, ...] + + +@dataclass(frozen=True) +class SemanticStrictVecscopeStmt(SemanticStmt): + captures: tuple[SemanticExpr, ...] + block_arguments: tuple[SemanticBinding, ...] + body: tuple[SemanticStmt, ...] + + +@dataclass(frozen=True) +class SemanticParameter: + binding: SemanticBinding + + @property + def name(self) -> str: + return self.binding.name + + @property + def kind(self) -> str: + return self.binding.origin + + @property + def type(self) -> SemanticType: + return self.binding.type + + @property + def ssa_name(self) -> str: + return self.binding.ssa_name + + +@dataclass(frozen=True) +class SemanticKernel: + target: str + op: str + symbol_name: str + verify_enabled: bool + dtype_signature: tuple[Any, ...] + parameters: tuple[SemanticParameter, ...] + tile_bindings: tuple[SemanticTileBinding, ...] + body: tuple[SemanticStmt, ...] + + +class _SemanticAnalyzer: + def __init__(self, node: FrontendKernelNode): + self.node = node + self._counter = 0 + self._tile_specializations = { + spec.name: spec for spec in node.tile_specializations + } + + def analyze(self) -> SemanticKernel: + env: dict[str, SemanticBinding] = {} + parameters = [] + for index, param in enumerate(self.node.parameters): + binding = SemanticBinding( + name=param.name, + ssa_name=f"%arg{index}", + type=self._parameter_type(param), + origin=param.kind, + ) + env[param.name] = binding + parameters.append(SemanticParameter(binding=binding)) + body, _ = self._analyze_block(self.node.body, env, allow_outer_lookup=True) + tile_bindings = tuple( + SemanticTileBinding( + name=spec.name, + shape=spec.shape, + memory_space=spec.memory_space, + config=spec.config, + ) + for spec in self.node.tile_specializations + ) + return SemanticKernel( + target=self.node.target, + op=self.node.op, + symbol_name=self.node.name, + verify_enabled=self.node.verify_enabled, + dtype_signature=self.node.dtype_signature, + parameters=tuple(parameters), + tile_bindings=tile_bindings, + body=body, + ) + + def _parameter_type(self, param: Any) -> SemanticType: + if param.kind == "tensorview": + return SemanticTensorViewType(element_dtype=param.dtype) + if param.kind == "tile": + spec = self._tile_specializations.get(param.name) + rank = 2 if spec is None else len(spec.shape) + shape = None if spec is None else spec.shape + memory_space = None if spec is None else spec.memory_space + return SemanticTileType( + element_dtype=param.dtype, + rank=rank, + shape=shape, + memory_space=memory_space, + ) + if param.kind == "scalar": + return SemanticScalarType(dtype=param.dtype) + raise ValueError(f"unsupported parameter kind {param.kind!r}") + + def _new_ssa_name(self, stem: str) -> str: + name = f"%{stem}_{self._counter}" + self._counter += 1 + return name + + def _make_binding(self, name: str, ty: SemanticType, origin: str) -> SemanticBinding: + stem = name if name.isidentifier() else "v" + return SemanticBinding( + name=name, + ssa_name=self._new_ssa_name(stem), + type=ty, + origin=origin, + ) + + def _analyze_block( + self, + statements: tuple[FrontendStmtNode, ...], + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + current_env = dict(env) + semantic_statements = [] + for stmt in statements: + semantic_stmt, current_env = self._analyze_stmt( + stmt, + current_env, + allow_outer_lookup=allow_outer_lookup, + ) + semantic_statements.append(semantic_stmt) + return tuple(semantic_statements), current_env + + def _analyze_stmt( + self, + stmt: FrontendStmtNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + if isinstance(stmt, FrontendAssignStmt): + value = self._analyze_expr(stmt.value, env, allow_outer_lookup=allow_outer_lookup) + updated_env = dict(env) + targets = self._bind_assignment_target( + stmt.target, + value, + updated_env, + stmt.annotation, + ) + return ( + SemanticAssignStmt(targets=targets, value=value, annotation=stmt.annotation), + updated_env, + ) + if isinstance(stmt, FrontendExprStmt): + if self._is_dma_call(stmt.expr): + return self._analyze_dma_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_sync_call(stmt.expr): + return self._analyze_sync_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_vector_store_call(stmt.expr): + return self._analyze_vector_store_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + expr = self._analyze_expr(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + return SemanticExprStmt(expr=expr), dict(env) + if isinstance(stmt, FrontendReturnStmt): + value = None + if stmt.value is not None: + value = self._analyze_expr(stmt.value, env, allow_outer_lookup=allow_outer_lookup) + return SemanticReturnStmt(value=value), dict(env) + if isinstance(stmt, FrontendForStmt): + return self._analyze_for(stmt, env, allow_outer_lookup=allow_outer_lookup) + if isinstance(stmt, FrontendIfStmt): + return self._analyze_if(stmt, env, allow_outer_lookup=allow_outer_lookup) + if isinstance(stmt, FrontendStrictVecscopeStmt): + return self._analyze_strict_vecscope(stmt, env) + raise ValueError(f"unsupported frontend statement {type(stmt).__name__}") + + def _is_dma_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in {"dma_load", "dma_store"} + ) + + def _is_vector_store_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name == "vsts" + ) + + def _is_sync_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in {"set_flag", "wait_flag", "pipe_barrier", "barrier"} + ) + + def _analyze_dma_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if expr.name == "dma_load": + if len(args) != 2: + raise TypeError("pto.dma_load expects exactly 2 positional arguments in TileLang DSL v1") + src = self._require_tensor_slice(args[0], "pto.dma_load source") + dst = self._require_tile_expr(args[1], "pto.dma_load destination") + self._validate_dma_shape_match(src.type, dst.type, "pto.dma_load") + return SemanticDmaLoadStmt(src=src, dst=dst), dict(env) + if expr.name == "dma_store": + if len(args) != 2: + raise TypeError("pto.dma_store expects exactly 2 positional arguments in TileLang DSL v1") + src = self._require_tile_expr(args[0], "pto.dma_store source") + dst = self._require_tensor_slice(args[1], "pto.dma_store destination") + self._validate_dma_shape_match(dst.type, src.type, "pto.dma_store") + return SemanticDmaStoreStmt(src=src, dst=dst), dict(env) + raise ValueError(f"unsupported DMA stmt pto.{expr.name}") + + def _analyze_vector_store_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 4: + raise TypeError("pto.vsts expects exactly 4 positional arguments in TileLang DSL v1") + value, destination, offset, mask = args + self._require_vreg_expr(value, "pto.vsts value") + self._require_tile_expr(destination, "pto.vsts destination") + self._require_index_typed_expr(offset) + self._require_mask_for_vreg(mask, value.type, "pto.vsts") + self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") + return ( + SemanticVectorStoreStmt( + value=value, + destination=destination, + offset=offset, + mask=mask, + ), + dict(env), + ) + + def _analyze_sync_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if expr.name in {"set_flag", "wait_flag"}: + if len(args) != 3: + raise TypeError(f"pto.{expr.name} expects exactly 3 positional arguments in TileLang DSL v1") + src_pipe = self._require_sync_pipe(args[0], f"pto.{expr.name} source pipe") + dst_pipe = self._require_sync_pipe(args[1], f"pto.{expr.name} destination pipe") + event = self._require_sync_event(args[2], f"pto.{expr.name} event") + if expr.name == "set_flag": + return SemanticSetFlagStmt(src_pipe=src_pipe, dst_pipe=dst_pipe, event=event), dict(env) + return SemanticWaitFlagStmt(src_pipe=src_pipe, dst_pipe=dst_pipe, event=event), dict(env) + if expr.name in {"pipe_barrier", "barrier"}: + if len(args) != 1: + raise TypeError(f"pto.{expr.name} expects exactly 1 positional argument in TileLang DSL v1") + pipe = self._require_sync_pipe(args[0], f"pto.{expr.name} pipe") + return SemanticPipeBarrierStmt(pipe=pipe), dict(env) + raise ValueError(f"unsupported sync stmt pto.{expr.name}") + + def _require_tensor_slice( + self, + expr: SemanticExpr, + context: str, + ) -> SemanticTensorSliceExpr: + if not isinstance(expr, SemanticTensorSliceExpr): + raise TypeError(f"{context} must be a TensorView slice in TileLang DSL v1") + return expr + + def _require_tile_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if not isinstance(expr.type, SemanticTileType): + raise TypeError(f"{context} must be a Tile value in TileLang DSL v1") + if expr.type.rank != 2: + raise TypeError(f"{context} currently only supports rank-2 Tile values in TileLang DSL v1") + if expr.type.shape is None: + raise TypeError(f"{context} requires a statically specialized Tile shape in TileLang DSL v1") + if expr.type.memory_space != "ub": + raise TypeError(f"{context} currently only supports MemorySpace.UB Tile values in TileLang DSL v1") + return expr + + def _validate_dma_shape_match( + self, + tensor_slice_type: SemanticTensorSliceType, + tile_type: SemanticTileType, + op_name: str, + ) -> None: + if tensor_slice_type.rank != 2: + raise TypeError(f"{op_name} currently only supports rank-2 TensorView slices in TileLang DSL v1") + if tile_type.rank != 2 or tile_type.shape is None: + raise TypeError(f"{op_name} requires a statically specialized rank-2 Tile in TileLang DSL v1") + if tensor_slice_type.element_dtype != tile_type.element_dtype: + raise TypeError(f"{op_name} requires matching TensorView/Tile element dtypes in TileLang DSL v1") + if tensor_slice_type.extents != tile_type.shape: + raise TypeError( + f"{op_name} requires TensorView slice extents {tensor_slice_type.extents!r} " + f"to match Tile shape {tile_type.shape!r}" + ) + + def _bind_assignment_target( + self, + target: FrontendTargetNode, + value: SemanticExpr, + env: dict[str, SemanticBinding], + annotation: Any | None, + ) -> tuple[SemanticBinding, ...]: + if isinstance(target, FrontendNameTarget): + annotated_type = self._annotation_type(annotation, value.type) + binding = self._make_binding( + target.name, + annotated_type if annotated_type is not None else value.type, + "ssa", + ) + env[target.name] = binding + return (binding,) + if isinstance(target, FrontendTupleTarget): + if not isinstance(value, SemanticCallExpr) or value.type is not None: + raise ValueError("tuple assignment expects a multi-result call") + raise ValueError("tuple assignment is not supported in TileLang DSL v1 yet") + raise ValueError(f"unsupported frontend assignment target {type(target).__name__}") + + def _annotation_type( + self, + annotation: Any | None, + inferred_type: SemanticType | None, + ) -> SemanticType | None: + if annotation is None: + return inferred_type + if isinstance(annotation, ast.Attribute) and isinstance(annotation.value, ast.Name): + if annotation.value.id == "pto" and isinstance(inferred_type, SemanticScalarType): + if inferred_type.dtype.name != annotation.attr: + raise TypeError( + f"annotated scalar type `pto.{annotation.attr}` does not match inferred {inferred_type.dtype!r}" + ) + return inferred_type + raise TypeError("unsupported annotated assignment type in TileLang DSL v1") + + def _analyze_for( + self, + stmt: FrontendForStmt, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + lower_bound = self._analyze_expr(stmt.lower_bound, env, allow_outer_lookup=allow_outer_lookup) + upper_bound = self._analyze_expr(stmt.upper_bound, env, allow_outer_lookup=allow_outer_lookup) + step = self._analyze_expr(stmt.step, env, allow_outer_lookup=allow_outer_lookup) + for expr in (lower_bound, upper_bound, step): + self._require_loop_bound_type(expr.type) + + body_env = dict(env) + induction_variable = self._make_binding(stmt.target, SemanticIndexType(), "loop_iv") + body_env[stmt.target] = induction_variable + body, final_body_env = self._analyze_block( + stmt.body, + body_env, + allow_outer_lookup=allow_outer_lookup, + ) + + updated_env = dict(env) + loop_carried = [] + for name, outer_binding in env.items(): + final_binding = final_body_env.get(name) + if final_binding is None or final_binding is outer_binding: + continue + if final_binding.type != outer_binding.type: + raise TypeError( + f"loop-carried binding '{name}' changes type from {outer_binding.type!r} to {final_binding.type!r}" + ) + merged = self._make_binding(name, outer_binding.type, "loop_result") + updated_env[name] = merged + loop_carried.append(merged) + + return ( + SemanticForStmt( + induction_variable=induction_variable, + lower_bound=lower_bound, + upper_bound=upper_bound, + step=step, + body=body, + loop_carried=tuple(loop_carried), + ), + updated_env, + ) + + def _analyze_if( + self, + stmt: FrontendIfStmt, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + condition = self._analyze_expr(stmt.condition, env, allow_outer_lookup=allow_outer_lookup) + self._require_condition_type(condition.type) + + then_body, then_env = self._analyze_block( + stmt.then_body, + dict(env), + allow_outer_lookup=allow_outer_lookup, + ) + else_body, else_env = self._analyze_block( + stmt.else_body, + dict(env), + allow_outer_lookup=allow_outer_lookup, + ) + + updated_env = dict(env) + merged_results: list[SemanticIfResult] = [] + for name, outer_binding in env.items(): + then_binding = then_env.get(name, outer_binding) + else_binding = else_env.get(name, outer_binding) + if then_binding is outer_binding and else_binding is outer_binding: + continue + if then_binding.type != else_binding.type: + raise TypeError( + f"if/else merge for '{name}' changes type between branches: " + f"{then_binding.type!r} vs {else_binding.type!r}" + ) + merged_binding = self._make_binding(name, then_binding.type, "if_result") + updated_env[name] = merged_binding + merged_results.append( + SemanticIfResult( + result_binding=merged_binding, + then_binding=then_binding, + else_binding=else_binding, + ) + ) + + return ( + SemanticIfStmt( + condition=condition, + then_body=then_body, + else_body=else_body, + results=tuple(merged_results), + ), + updated_env, + ) + + def _analyze_strict_vecscope( + self, + stmt: FrontendStrictVecscopeStmt, + env: dict[str, SemanticBinding], + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + if len(stmt.captures) != len(stmt.block_arguments): + raise ValueError("strict_vecscope capture arity must match block arguments") + + captures = tuple( + self._analyze_expr(expr, env, allow_outer_lookup=True) + for expr in stmt.captures + ) + scope_env: dict[str, SemanticBinding] = {} + block_arguments = [] + for name, capture in zip(stmt.block_arguments, captures): + if capture.type is None: + raise TypeError( + f"strict_vecscope block argument '{name}' type could not be inferred" + ) + block_binding = self._make_binding(name, capture.type, "strict_vecscope_arg") + scope_env[name] = block_binding + block_arguments.append(block_binding) + body, _ = self._analyze_block( + stmt.body, + scope_env, + allow_outer_lookup=False, + ) + return ( + SemanticStrictVecscopeStmt( + captures=captures, + block_arguments=tuple(block_arguments), + body=body, + ), + dict(env), + ) + + def _analyze_expr( + self, + expr: FrontendExprNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + if isinstance(expr, FrontendNameExpr): + binding = env.get(expr.name) + if binding is None: + if allow_outer_lookup: + raise ValueError(f"unknown name '{expr.name}'") + raise ValueError( + f"implicit capture of '{expr.name}' is not allowed in pto.strict_vecscope" + ) + return SemanticBindingRef(binding=binding, type=binding.type) + if isinstance(expr, FrontendConstantExpr): + if isinstance(expr.value, bool): + raise TypeError("bool constants are not supported in TileLang DSL v1 yet") + if isinstance(expr.value, int): + return SemanticLiteralExpr(value=expr.value, type=SemanticIndexType()) + if isinstance(expr.value, str): + return SemanticLiteralExpr( + value=expr.value, + type=SemanticMetaType(kind="string"), + ) + if expr.value is None: + return SemanticLiteralExpr(value=None, type=SemanticIndexType()) + raise TypeError(f"unsupported constant {expr.value!r} in TileLang DSL v1") + if isinstance(expr, FrontendSymbolExpr): + return self._analyze_symbol_expr(expr) + if isinstance(expr, FrontendSliceExpr): + start = None if expr.start is None else self._analyze_expr(expr.start, env, allow_outer_lookup=allow_outer_lookup) + stop = None if expr.stop is None else self._analyze_expr(expr.stop, env, allow_outer_lookup=allow_outer_lookup) + step = None if expr.step is None else self._analyze_expr(expr.step, env, allow_outer_lookup=allow_outer_lookup) + for item in (start, stop, step): + if item is not None: + self._require_index_typed_expr(item) + return SemanticSliceExpr( + start=start, + stop=stop, + step=step, + type=SemanticSliceType(), + ) + if isinstance(expr, FrontendTupleExpr): + elements = tuple( + self._analyze_expr(element, env, allow_outer_lookup=allow_outer_lookup) + for element in expr.elements + ) + return SemanticTupleExpr( + elements=elements, + type=SemanticTupleType(elements=tuple(element.type for element in elements)), + ) + if isinstance(expr, FrontendAttributeExpr): + base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) + attr_type = self._attribute_type(base, expr.attr) + return SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type) + if isinstance(expr, FrontendSubscriptExpr): + base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) + index = self._analyze_expr(expr.index, env, allow_outer_lookup=allow_outer_lookup) + result_type = self._subscript_type(base, index) + if isinstance(result_type, SemanticTensorSliceType): + slices = self._normalize_tensor_slice(index, base.type.rank) + return SemanticTensorSliceExpr(base=base, slices=slices, type=result_type) + return SemanticSubscriptAccess(base=base, index=index, type=result_type) + if isinstance(expr, FrontendBinaryExpr): + lhs = self._analyze_expr(expr.lhs, env, allow_outer_lookup=allow_outer_lookup) + rhs = self._analyze_expr(expr.rhs, env, allow_outer_lookup=allow_outer_lookup) + result_type = self._binary_type(lhs, rhs, expr.op) + return SemanticBinaryExpr(lhs=lhs, op=expr.op, rhs=rhs, type=result_type) + if isinstance(expr, FrontendCallExpr): + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_call_expr(expr.namespace, expr.name, args) + raise ValueError(f"unsupported frontend expression {type(expr).__name__}") + + def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: + if expr.namespace == "pto": + dtype = _DTYPE_SYMBOLS.get(expr.name) + if dtype is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dtype, + type=SemanticMetaType(kind="dtype"), + ) + if expr.namespace in {"PAT", "pto.PAT", "pto.MaskPattern"}: + pattern = _PATTERN_SYMBOLS.get(expr.name) + if pattern is None and expr.name.startswith("PAT_"): + canonical = expr.name[len("PAT_") :] + pattern = _PATTERN_SYMBOLS.get(canonical) + if pattern is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pattern, + type=SemanticMetaType(kind="mask_pattern"), + ) + if expr.namespace in {"PIPE", "pto.PIPE"}: + pipe = _PIPE_SYMBOLS.get(expr.name) + if pipe is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pipe, + type=SemanticMetaType(kind="pipe"), + ) + if expr.namespace in {"EVENT", "pto.EVENT"}: + event = _EVENT_SYMBOLS.get(expr.name) + if event is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=event, + type=SemanticMetaType(kind="event"), + ) + raise TypeError( + f"symbol `{expr.namespace}.{expr.name}` is not supported in TileLang DSL v1" + ) + + def _attribute_type(self, base: SemanticExpr, attr: str) -> SemanticType: + base_type = base.type + if isinstance(base_type, SemanticTensorViewType) and attr == "shape": + return SemanticShapeType(rank=base_type.rank) + if isinstance(base_type, SemanticTileType) and attr == "shape": + return SemanticShapeType(rank=base_type.rank) + raise TypeError(f"unsupported attribute access '{attr}' in TileLang DSL v1") + + def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticType: + if isinstance(base.type, SemanticShapeType): + if not isinstance(index.type, SemanticIndexType): + raise TypeError("shape subscript index must be an index value in TileLang DSL v1") + return SemanticIndexType() + if isinstance(base.type, SemanticTensorViewType): + if not isinstance(index, SemanticTupleExpr): + raise TypeError("TensorView slicing expects a tuple of slices in TileLang DSL v1") + return self._tensor_slice_type(base.type, index) + raise TypeError("unsupported subscript base in TileLang DSL v1") + + def _tensor_slice_type( + self, + tensor_type: SemanticTensorViewType, + index: SemanticTupleExpr, + ) -> SemanticTensorSliceType: + if len(index.elements) != tensor_type.rank: + raise TypeError( + f"TensorView slice rank {len(index.elements)} does not match TensorView rank {tensor_type.rank}" + ) + extents = [] + for axis, element in enumerate(index.elements): + if not isinstance(element, SemanticSliceExpr): + raise TypeError( + f"TensorView slicing axis {axis} must use a Python slice in TileLang DSL v1" + ) + start = self._static_index_value(element.start, default=0) + stop = self._static_index_value(element.stop, default=None) + step = self._static_index_value(element.step, default=1) + if stop is None: + raise TypeError("TensorView slicing requires explicit stop bounds in TileLang DSL v1") + if start != 0: + raise TypeError("TensorView slicing currently only supports zero-based starts in TileLang DSL v1") + if step != 1: + raise TypeError("TensorView slicing currently only supports unit stride in TileLang DSL v1") + extent = stop - start + if extent <= 0: + raise TypeError("TensorView slicing requires positive static extents in TileLang DSL v1") + extents.append(extent) + return SemanticTensorSliceType( + element_dtype=tensor_type.element_dtype, + rank=tensor_type.rank, + extents=tuple(extents), + ) + + def _normalize_tensor_slice( + self, + index: SemanticExpr, + rank: int, + ) -> tuple[SemanticSliceExpr, ...]: + if not isinstance(index, SemanticTupleExpr): + raise TypeError("TensorView slicing expects a tuple index in TileLang DSL v1") + if len(index.elements) != rank: + raise TypeError(f"TensorView slicing expects {rank} slice elements in TileLang DSL v1") + slices = [] + for element in index.elements: + if not isinstance(element, SemanticSliceExpr): + raise TypeError("TensorView slicing only supports slice syntax in TileLang DSL v1") + slices.append(element) + return tuple(slices) + + def _binary_type( + self, + lhs: SemanticExpr, + rhs: SemanticExpr, + op: str, + ) -> SemanticType: + if op not in {"add", "sub", "mul", "floordiv"}: + raise TypeError(f"unsupported binary operator '{op}' in TileLang DSL v1") + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + return SemanticIndexType() + raise TypeError("binary expressions currently only support index-typed operands") + + def _analyze_call_expr( + self, + namespace: str | None, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if namespace is None and name == "range": + return SemanticCallExpr(namespace=namespace, name=name, args=args, type=None) + if namespace != "pto": + raise TypeError( + f"call surface `{namespace + '.' if namespace else ''}{name}` is not supported in TileLang DSL v1 yet" + ) + if name == "make_mask": + return self._analyze_make_mask(args) + if name == "vlds": + return self._analyze_vlds(args) + if name in _UNARY_VECTOR_OPS: + return self._analyze_unary_vector_op(name, args) + if name in _BINARY_VECTOR_OPS: + return self._analyze_binary_vector_op(name, args) + if name in _VECTOR_SCALAR_OPS: + return self._analyze_vector_scalar_op(name, args) + raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") + + def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.make_mask expects exactly 2 positional arguments in TileLang DSL v1") + dtype_expr, value_expr = args + dtype = self._require_dtype_symbol(dtype_expr, "pto.make_mask element type") + if not ( + isinstance(value_expr, SemanticSymbolExpr) + and value_expr.type.kind == "mask_pattern" + ): + raise TypeError( + "pto.make_mask currently only supports PAT.* pattern lowering in TileLang DSL v1" + ) + return SemanticCallExpr( + namespace="pto", + name="make_mask", + args=args, + type=SemanticMaskType(granularity=self._mask_granularity_for_dtype(dtype)), + ) + + def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.vlds expects exactly 2 positional arguments in TileLang DSL v1") + source, offset = args + tile = self._require_tile_expr(source, "pto.vlds source") + self._require_index_typed_expr(offset) + return SemanticCallExpr( + namespace="pto", + name="vlds", + args=args, + type=self._vreg_type_for_dtype(tile.type.element_dtype), + ) + + def _analyze_unary_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") + value, mask = args + vreg = self._require_vreg_expr(value, f"pto.{name} value") + self._require_mask_for_vreg(mask, vreg, f"pto.{name}") + self._validate_unary_dtype(name, vreg.element_dtype) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vreg) + + def _analyze_binary_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") + lhs_expr, rhs_expr, mask = args + lhs = self._require_vreg_expr(lhs_expr, f"pto.{name} lhs") + rhs = self._require_vreg_expr(rhs_expr, f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + self._require_mask_for_vreg(mask, lhs, f"pto.{name}") + self._validate_binary_dtype(name, lhs.element_dtype) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) + + def _analyze_vector_scalar_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") + vector_expr, scalar_expr, mask = args + vreg = self._require_vreg_expr(vector_expr, f"pto.{name} vector") + scalar = self._require_scalar_expr(scalar_expr, f"pto.{name} scalar") + if scalar.dtype != vreg.element_dtype: + raise TypeError(f"pto.{name} scalar dtype must match vector element dtype") + self._require_mask_for_vreg(mask, vreg, f"pto.{name}") + self._validate_vector_scalar_dtype(name, vreg.element_dtype) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vreg) + + def _require_dtype_symbol(self, expr: SemanticExpr, context: str) -> ScalarType: + if not ( + isinstance(expr, SemanticSymbolExpr) + and expr.type.kind == "dtype" + and isinstance(expr.value, ScalarType) + ): + raise TypeError(f"{context} must be a TileLang scalar dtype symbol in TileLang DSL v1") + return expr.value + + def _require_vreg_expr(self, expr: SemanticExpr, context: str) -> SemanticVRegType: + if not isinstance(expr.type, SemanticVRegType): + raise TypeError(f"{context} must be a vector register value in TileLang DSL v1") + return expr.type + + def _require_scalar_expr(self, expr: SemanticExpr, context: str) -> SemanticScalarType: + if not isinstance(expr.type, SemanticScalarType): + raise TypeError(f"{context} must be a scalar value in TileLang DSL v1") + return expr.type + + def _require_mask_for_vreg( + self, + mask_expr: SemanticExpr, + vreg_type: SemanticVRegType, + context: str, + ) -> None: + if not isinstance(mask_expr.type, SemanticMaskType): + raise TypeError(f"{context} requires a mask operand in TileLang DSL v1") + expected = self._mask_granularity_for_dtype(vreg_type.element_dtype) + if mask_expr.type.granularity != expected: + raise TypeError( + f"{context} requires mask granularity {expected} for vector dtype {vreg_type.element_dtype!r}" + ) + + def _require_matching_vector_pointer( + self, + vreg_type: SemanticVRegType, + pointer_type: SemanticTileType, + context: str, + ) -> None: + if pointer_type.element_dtype != vreg_type.element_dtype: + raise TypeError(f"{context} requires destination Tile dtype to match vector dtype") + + def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: + if dtype.name in {"f32", "i32"}: + return "b32" + if dtype.name in {"f16", "bf16", "i16"}: + return "b16" + if dtype.name == "i8": + return "b8" + raise TypeError(f"dtype `{dtype.name}` is not supported by make_mask/vector lowering in TileLang DSL v1") + + def _vreg_type_for_dtype(self, dtype: ScalarType) -> SemanticVRegType: + byte_widths = { + "i8": 1, + "i16": 2, + "i32": 4, + "f16": 2, + "bf16": 2, + "f32": 4, + } + width = byte_widths.get(dtype.name) + if width is None: + raise TypeError(f"dtype `{dtype.name}` is not supported by vlds/vsts in TileLang DSL v1") + return SemanticVRegType(element_dtype=dtype, lanes=256 // width) + + def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vexp" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vexp only supports f16/f32 in TileLang DSL v1") + if name == "vrelu" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vrelu only supports f16/f32 in TileLang DSL v1") + if name == "vnot" and dtype.name not in {"i8", "i16", "i32"}: + raise TypeError("pto.vnot only supports integer vector dtypes in TileLang DSL v1") + if name == "vabs" and dtype.name not in {"i8", "i16", "i32", "f16", "f32"}: + raise TypeError("pto.vabs does not support this dtype in TileLang DSL v1") + + def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vdiv" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vdiv only supports f16/f32 in TileLang DSL v1") + if name in {"vand", "vor", "vxor"} and dtype.name not in {"i8", "i16", "i32"}: + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name == "vmul" and dtype.name not in {"i16", "i32", "f16", "f32"}: + raise TypeError("pto.vmul only supports i16/i32/f16/f32 in TileLang DSL v1") + if name in {"vadd", "vsub", "vmax", "vmin"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + + def _validate_vector_scalar_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vdivs" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vdivs only supports f16/f32 in TileLang DSL v1") + if name in {"vadds", "vsubs", "vmuls", "vmaxs", "vmins"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + + def _require_sync_pipe(self, expr: SemanticExpr, context: str) -> str: + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "pipe": + return expr.value.value + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": + return expr.value + raise TypeError(f"{context} must be a PIPE symbol or pipe string in TileLang DSL v1") + + def _require_sync_event(self, expr: SemanticExpr, context: str) -> str: + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "event": + return expr.value.value + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": + return expr.value + raise TypeError(f"{context} must be an EVENT symbol or event string in TileLang DSL v1") + + def _require_loop_bound_type(self, ty: SemanticType) -> None: + if isinstance(ty, (SemanticIndexType, SemanticScalarType)): + return + raise TypeError(f"loop bound must be scalar/index typed, got {ty!r}") + + def _require_condition_type(self, ty: SemanticType) -> None: + if isinstance(ty, SemanticIndexType): + return + if isinstance(ty, SemanticScalarType): + return + raise TypeError(f"if condition must be scalar/index typed, got {ty!r}") + + def _require_index_typed_expr(self, expr: SemanticExpr) -> None: + if not isinstance(expr.type, SemanticIndexType): + raise TypeError("slice bounds and vector offsets must be index-typed in TileLang DSL v1") + + def _static_index_value(self, expr: SemanticExpr | None, *, default: int | None) -> int | None: + if expr is None: + return default + if not isinstance(expr, SemanticLiteralExpr) or not isinstance(expr.value, int): + raise TypeError("TensorView slice bounds must be static integer literals in TileLang DSL v1") + return expr.value + + +def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: + """Normalize descriptor-owned AST into a lowering semantic model.""" + + return _SemanticAnalyzer(node).analyze() + + +__all__ = [ + "SemanticAssignStmt", + "SemanticAttributeAccess", + "SemanticBinaryExpr", + "SemanticBinding", + "SemanticBindingRef", + "SemanticCallExpr", + "SemanticDmaLoadStmt", + "SemanticDmaStoreStmt", + "SemanticExpr", + "SemanticExprStmt", + "SemanticForStmt", + "SemanticIfResult", + "SemanticIfStmt", + "SemanticIndexType", + "SemanticKernel", + "SemanticLiteralExpr", + "SemanticMaskType", + "SemanticParameter", + "SemanticPipeBarrierStmt", + "SemanticReturnStmt", + "SemanticScalarType", + "SemanticSetFlagStmt", + "SemanticShapeType", + "SemanticSliceExpr", + "SemanticSliceType", + "SemanticStmt", + "SemanticStrictVecscopeStmt", + "SemanticSubscriptAccess", + "SemanticSymbolExpr", + "SemanticTensorSliceExpr", + "SemanticTensorSliceType", + "SemanticTensorViewType", + "SemanticTileBinding", + "SemanticTileType", + "SemanticTupleExpr", + "SemanticTupleType", + "SemanticType", + "SemanticVRegType", + "SemanticVectorStoreStmt", + "SemanticWaitFlagStmt", + "analyze_frontend_kernel", +] diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 161a3e0cd..82bca103a 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -44,6 +44,34 @@ class MemorySpace(str, Enum): UB = "ub" +class Pipe(str, Enum): + MTE1 = "PIPE_MTE1" + MTE2 = "PIPE_MTE2" + V = "PIPE_V" + MTE3 = "PIPE_MTE3" + ALL = "PIPE_ALL" + + +class Event(str, Enum): + ID0 = "EVENT_ID0" + ID1 = "EVENT_ID1" + ID2 = "EVENT_ID2" + ID3 = "EVENT_ID3" + ID4 = "EVENT_ID4" + ID5 = "EVENT_ID5" + ID6 = "EVENT_ID6" + ID7 = "EVENT_ID7" + + +class MaskPattern(str, Enum): + ALL = "PAT_ALL" + ALLF = "PAT_ALLF" + EVEN = "PAT_EVEN" + ODD = "PAT_ODD" + VL16 = "PAT_VL16" + VL32 = "PAT_VL32" + + @dataclass(frozen=True) class TileConfig: fields: tuple[tuple[str, Any], ...] = () @@ -61,12 +89,16 @@ class TileSpecialization: i8 = ScalarType("i8") +i1 = ScalarType("i1") i16 = ScalarType("i16") i32 = ScalarType("i32") i64 = ScalarType("i64") f16 = ScalarType("f16") bf16 = ScalarType("bf16") f32 = ScalarType("f32") +PIPE = Pipe +EVENT = Event +PAT = MaskPattern AnyFloat = WildcardType("AnyFloat") AnyInt = WildcardType("AnyInt") AnyType = WildcardType("AnyType") @@ -87,8 +119,15 @@ def TypeVar(name: str) -> TypeVariable: "TensorView", "Tile", "MemorySpace", + "Pipe", + "Event", + "PIPE", + "EVENT", + "MaskPattern", + "PAT", "TileConfig", "TileSpecialization", + "i1", "i8", "i16", "i32", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index bef057aeb..c067788a7 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -4,6 +4,27 @@ from pathlib import Path import tilelang_dsl as pto +from tilelang_dsl.frontend_ast import build_frontend_kernel_node +from tilelang_dsl.lowering import AuthoringModule, lower_semantic_kernel +from tilelang_dsl.semantic import ( + SemanticAssignStmt, + SemanticCallExpr, + SemanticDmaLoadStmt, + SemanticDmaStoreStmt, + SemanticForStmt, + SemanticIfStmt, + SemanticIndexType, + SemanticMaskType, + SemanticPipeBarrierStmt, + SemanticScalarType, + SemanticSetFlagStmt, + SemanticStrictVecscopeStmt, + SemanticTensorViewType, + SemanticTileType, + SemanticVectorStoreStmt, + SemanticWaitFlagStmt, + analyze_frontend_kernel, +) class TileLangDSLPackageTests(unittest.TestCase): @@ -13,6 +34,9 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "TensorView")) self.assertTrue(hasattr(pto, "Tile")) self.assertTrue(hasattr(pto, "TileSpecialization")) + self.assertTrue(hasattr(pto, "PAT")) + self.assertTrue(hasattr(pto, "PIPE")) + self.assertTrue(hasattr(pto, "EVENT")) class TileLangDSLDescriptorTests(unittest.TestCase): @@ -52,6 +76,8 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): text = specialized.mlir_text() self.assertIn("// tilelang.target = a5", text) self.assertIn("// tilelang.specialize tile shape=(16, 32) memory_space=ub", text) + self.assertIn('module attributes {pto.target_arch = "a5"} {', text) + self.assertIn("func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) {", text) module = specialized.mlir_module() self.assertEqual(type(module).__name__, "MaterializedMLIRModule") self.assertTrue(module.verify()) @@ -62,6 +88,260 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): specialized.emit(out) self.assertEqual(out.read_text(encoding="utf-8"), text) + def test_descriptor_materialization_flows_through_pipeline(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertEqual(frontend_kernel.name, "kernel") + self.assertEqual( + [(param.name, param.kind) for param in frontend_kernel.parameters], + [("inp", "tensorview"), ("tile", "tile"), ("scale", "scalar")], + ) + self.assertEqual(frontend_kernel.tile_specializations[0].shape, (8, 16)) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + self.assertEqual(semantic_kernel.symbol_name, "kernel") + self.assertEqual(semantic_kernel.tile_bindings[0].memory_space, "ub") + + authoring_module = lower_semantic_kernel(semantic_kernel) + self.assertIsInstance(authoring_module, AuthoringModule) + self.assertEqual(authoring_module.render(), specialized.mlir_text()) + self.assertIn("return", authoring_module.render()) + + def test_semantic_pipeline_binds_parameter_loop_and_strict_vecscope_types(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + rows = tile.shape[0] + step = rows + with pto.strict_vecscope(inp, tile, scale, 0, rows, step) as ( + vin, + vtmp, + factor, + lb, + ub, + vec_step, + ): + for lane in range(lb, ub, vec_step): + current = factor + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertEqual(len(frontend_kernel.body), 4) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + self.assertIsInstance(semantic_kernel.parameters[0].type, SemanticTensorViewType) + self.assertIsInstance(semantic_kernel.parameters[1].type, SemanticTileType) + self.assertEqual(semantic_kernel.parameters[1].type.shape, (8, 16)) + self.assertIsInstance(semantic_kernel.parameters[2].type, SemanticScalarType) + + rows_assign = semantic_kernel.body[0] + self.assertIsInstance(rows_assign, SemanticAssignStmt) + self.assertIsInstance(rows_assign.targets[0].type, SemanticIndexType) + self.assertTrue(rows_assign.targets[0].ssa_name.startswith("%rows_")) + + vecscope_stmt = semantic_kernel.body[2] + self.assertIsInstance(vecscope_stmt, SemanticStrictVecscopeStmt) + self.assertEqual( + [binding.name for binding in vecscope_stmt.block_arguments], + ["vin", "vtmp", "factor", "lb", "ub", "vec_step"], + ) + self.assertIsInstance(vecscope_stmt.block_arguments[0].type, SemanticTensorViewType) + self.assertIsInstance(vecscope_stmt.block_arguments[1].type, SemanticTileType) + self.assertIsInstance(vecscope_stmt.block_arguments[2].type, SemanticScalarType) + self.assertIsInstance(vecscope_stmt.block_arguments[3].type, SemanticIndexType) + self.assertIsInstance(vecscope_stmt.block_arguments[4].type, SemanticIndexType) + self.assertIsInstance(vecscope_stmt.block_arguments[5].type, SemanticIndexType) + self.assertTrue(vecscope_stmt.block_arguments[0].ssa_name.startswith("%vin_")) + + loop_stmt = vecscope_stmt.body[0] + self.assertIsInstance(loop_stmt, SemanticForStmt) + self.assertEqual(loop_stmt.induction_variable.name, "lane") + self.assertIsInstance(loop_stmt.induction_variable.type, SemanticIndexType) + self.assertTrue(loop_stmt.induction_variable.ssa_name.startswith("%lane_")) + self.assertEqual(loop_stmt.loop_carried, ()) + + text = specialized.mlir_text() + self.assertIn("%rows_", text) + self.assertIn("= arith.constant 8 : index", text) + self.assertIn("pto.strict_vecscope(%arg0, %arg1, %arg2, %c0, %rows_", text) + self.assertIn("^bb0(", text) + self.assertIn("scf.for %lane_", text) + self.assertIn("to %ub_6 step %vec_step_7 {", text) + + def test_dma_load_and_store_lower_to_dma_programming_and_copy_ops(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile): + pto.dma_load(inp[0:16, 0:16], tile) + pto.dma_store(tile, out[0:16, 0:16]) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIsInstance(semantic_kernel.body[0], SemanticDmaLoadStmt) + self.assertIsInstance(semantic_kernel.body[1], SemanticDmaStoreStmt) + + text = specialized.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) {", + text, + ) + self.assertIn("pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64", text) + self.assertIn( + "pto.copy_gm_to_ubuf %arg0, %arg2, %c0_i64, %c16_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64", + text, + ) + self.assertIn("pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64", text) + self.assertIn( + "pto.copy_ubuf_to_gm %arg2, %arg1, %c0_i64, %c16_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64", + text, + ) + + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.f32): + pto.dma_load(inp[0:16, 0:16], tile) + with pto.strict_vecscope(tile, tile, scale, 0, 256, 64) as ( + src, + dst, + factor, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + biased = pto.vadds(vec, factor, mask) + summed = pto.vadd(biased, vec, mask) + activated = pto.vrelu(summed, mask) + pto.vsts(activated, dst, lane, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = semantic_kernel.body[1] + self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) + loop_stmt = vecscope.body[0] + self.assertIsInstance(loop_stmt, SemanticForStmt) + mask_assign = loop_stmt.body[0] + self.assertIsInstance(mask_assign, SemanticAssignStmt) + self.assertIsInstance(mask_assign.value, SemanticCallExpr) + self.assertEqual(mask_assign.value.name, "make_mask") + self.assertIsInstance(mask_assign.targets[0].type, SemanticMaskType) + self.assertIsInstance(loop_stmt.body[-1], SemanticVectorStoreStmt) + + text = specialized.mlir_text() + self.assertIn('%mask_7 = pto.pset_b32 "PAT_ALL" : !pto.mask', text) + self.assertIn("%vec_8 = pto.vlds %src_0[%lane_6] : !pto.ptr -> !pto.vreg<64xf32>", text) + self.assertIn( + "%biased_9 = pto.vadds %vec_8, %factor_2, %mask_7 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32>", + text, + ) + self.assertIn( + "%summed_10 = pto.vadd %biased_9, %vec_8, %mask_7 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>", + text, + ) + self.assertIn( + "%activated_11 = pto.vrelu %summed_10, %mask_7 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>", + text, + ) + self.assertIn( + "pto.vsts %activated_11, %dst_1[%lane_6], %mask_7 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask", + text, + ) + + def test_if_else_and_sync_ops_lower_to_scf_if_and_authoring_sync_ops(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile, flag: pto.i32): + pto.set_flag(pto.PIPE.MTE2, pto.PIPE.V, pto.EVENT.ID0) + pto.wait_flag(pto.PIPE.MTE2, pto.PIPE.V, pto.EVENT.ID0) + step = 64 + if flag: + step = 64 + pto.set_flag(pto.PIPE.V, pto.PIPE.MTE3, pto.EVENT.ID0) + else: + step = 128 + pto.wait_flag(pto.PIPE.V, pto.PIPE.MTE3, pto.EVENT.ID0) + with pto.strict_vecscope(tile, tile, 0, 256, step) as (src, dst, lb, ub, vec_step): + for lane in range(lb, ub, vec_step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + pto.pipe_barrier(pto.PIPE.ALL) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIsInstance(semantic_kernel.body[0], SemanticSetFlagStmt) + self.assertIsInstance(semantic_kernel.body[1], SemanticWaitFlagStmt) + self.assertIsInstance(semantic_kernel.body[3], SemanticIfStmt) + self.assertIsInstance(semantic_kernel.body[5], SemanticPipeBarrierStmt) + + text = specialized.mlir_text() + self.assertIn('pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]', text) + self.assertIn('pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]', text) + self.assertIn("= arith.cmpi ne, %arg2, %c0_i32 : i32", text) + self.assertIn("%step_3 = scf.if %tmp_0 -> (index) {", text) + self.assertIn('pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]', text) + self.assertIn('pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]', text) + self.assertRegex(text, r"scf\.yield %step_\d+ : index") + self.assertIn("%step_2 = arith.constant 128 : index", text) + self.assertIn("pto.strict_vecscope(%arg1, %arg1, %c0, %c256, %step_3)", text) + self.assertIn("scf.for %lane_", text) + self.assertIn("pto.barrier #pto.pipe", text) + + def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + with pto.strict_vecscope(inp, tile) as (vin, vtmp): + leaked = scale + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + with self.assertRaises(ValueError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn("implicit capture of 'scale' is not allowed", str(ctx.exception)) + class TileLangDSLDiagnosticsTests(unittest.TestCase): def test_matcher_feature_diagnostics_point_to_follow_up_change(self) -> None: @@ -114,7 +394,7 @@ def kernel(x: pto.TensorView): pto.vadd(x) return None - self.assertIn("unsupported op surface `pto.vadd`", str(ctx.exception)) + self.assertIn("vector op surface `pto.vadd` requires explicit pto.strict_vecscope", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) def test_missing_specialization_reports_source_location(self) -> None: From 5ac3dba7a828ef42442d6134ef5445492f988e90 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 3 Apr 2026 23:53:15 +0800 Subject: [PATCH 018/192] Update openspec --- .../.openspec.yaml | 2 + .../design.md | 167 ++++++++++++++++++ .../proposal.md | 97 ++++++++++ .../specs/tilelang-dsl-vpto-lowering/spec.md | 98 ++++++++++ .../tasks.md | 30 ++++ .../.openspec.yaml | 2 + .../design.md | 138 +++++++++++++++ .../proposal.md | 79 +++++++++ .../tilelang-dsl-advanced-surface/spec.md | 63 +++++++ .../specs/tilelang-dsl-kernel-matcher/spec.md | 57 ++++++ .../tasks.md | 24 +++ .../specs/tilelang-dsl-diagnostics/spec.md | 63 +++++++ openspec/specs/tilelang-dsl-surface/spec.md | 56 ++++++ 13 files changed, 876 insertions(+) create mode 100644 openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/.openspec.yaml create mode 100644 openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/design.md create mode 100644 openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/proposal.md create mode 100644 openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md create mode 100644 openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/tasks.md create mode 100644 openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/.openspec.yaml create mode 100644 openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/design.md create mode 100644 openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/proposal.md create mode 100644 openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md create mode 100644 openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md create mode 100644 openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/tasks.md create mode 100644 openspec/specs/tilelang-dsl-diagnostics/spec.md create mode 100644 openspec/specs/tilelang-dsl-surface/spec.md diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/.openspec.yaml b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/.openspec.yaml new file mode 100644 index 000000000..c430c5fa6 --- /dev/null +++ b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-03 diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/design.md b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/design.md new file mode 100644 index 000000000..ab160ec9e --- /dev/null +++ b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/design.md @@ -0,0 +1,167 @@ +## Context + +### 范围 + +本 design 只覆盖 TileLang DSL v1 的 lowering 主线: + +- 输入:`tilelang-dsl/` 中定义的 v1 surface +- 输出:authoring-form VPTO IR(`func.func + arith/scf + pto.*`) +- 验证:通过 repo 当前 `ptoas --pto-backend=vpto` 的 authoring-stage legality contract + +它不覆盖: + +- matcher / registry +- implicit vecscope inference +- A5 text / LLVM emission +- advanced vector family + +### 当前状态 + +当前仓库里与本 change 直接相关的事实有: + +1. 真实 authoring-form VPTO carrier 是 dedicated `pto.vecscope/pto.strict_vecscope` + +- `lib/PTO/Transforms/PTOValidateVPTOIR.cpp` +- `test/vpto_validate/vpto_validate_authoring_legacy_scope_negative.mlir` + +都已经说明 legacy `scf.for {llvm.loop.aivector_scope}` 已被拒绝。 + +2. `docs/tilelang-dsl-guide.md` 的 surface 比 v1 计划范围更大 + +- 文档包含 matcher、implicit vecscope inference、advanced family、低层 DMA programming 等。 +- 本 change 必须明确压缩到 elementwise 套餐,否则 lowering 无法在短期内闭合。 + +3. 现有实验 `python/pto/dialects/pto.py` parser 不可直接作为实现基线 + +- 它的 surface 接近 hand-written VPTO,并不等于 TileLang DSL guide。 +- 用户要求本特性工作集中在 `tilelang-dsl/`,并明确不以现有其他 Python binding 实现为前提。 + +### 实现约束 + +- lowering 输出必须符合当前真实的 authoring-form VPTO legality contract。 +- v1 只支持 `strict_vecscope` 显式 vector region,不做 implicit inference。 +- support matrix 必须固定为 elementwise 套餐,避免“实现到哪算哪”。 +- `verify()` 需要给出稳定行为;在无法访问 `ptoas` binary 时不能静默成功。 + +## Goals / Non-Goals + +**Goals:** + +- 定义并实现 TileLang DSL v1 的 fixed support matrix lowering。 +- 让 `dma_load/dma_store`、`make_mask`、`vlds/vsts`、elementwise unary/binary/vector-scalar family、`for`/`if`、基础 sync 都有明确 VPTO lowering 目标。 +- 保证输出 IR 能通过当前 `ptoas --pto-backend=vpto` authoring-stage legality。 +- 为 `verify()` 定义一个明确、可落地、与 repo 当前验证路径一致的契约。 + +**Non-Goals:** + +- 不在本 change 中扩展 matcher surface。 +- 不支持 implicit vecscope inference。 +- 不支持 compare/select/reduction/rearrangement/carry/UB-to-UB copy 等 advanced family。 +- 不把 generated IR 直接送进 A5 text / LLVM emission 作为本 change 的完成标准。 + +## Decisions + +### 1. v1 只接受显式 `strict_vecscope` 作为 Python surface 的 vector carrier + +决策: + +- 用户写 vector op 时,必须显式使用 `with pto.strict_vecscope(...) as (...):` +- frontend 直接 lower 为 dedicated `pto.strict_vecscope` +- v1 不做 implicit `pto.vecscope` inference + +原因: + +- 这与当前真实 authoring contract 一致。 +- 显式 `strict_vecscope` 能明确 capture 边界、block 参数和类型来源,降低 v1 lowering 复杂度。 + +备选方案: + +- 直接实现 implicit vecscope inference + - 放弃原因:需要引入 CFG 分析、scope boundary 规则、与 scalar/control-flow 边界的交互,超出 v1。 + +### 2. v1 lowering support matrix 固定为 elementwise 套餐 + +决策: + +- 支持: + - 2D `TensorView` + - 1D/2D `Tile` + - `dma_load` + - `dma_store` + - `make_mask(dtype, PAT.*)` / `make_mask(dtype, remaining)` + - `vlds` / `vsts` + - unary:`vabs`, `vrelu`, `vexp`, `vnot` + - binary:`vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor` + - vector-scalar:`vadds`, `vsubs`, `vmuls`, `vdivs`, `vmaxs`, `vmins` + - `for range(lb, ub, step)` + - `if/else` + - `set_flag`, `wait_flag`, `pipe_barrier` +- 其余 family 在 frontend reject + +原因: + +- 这套矩阵足以覆盖 `docs/tilelang-dsl-guide.md` 的代表性 elementwise kernel。 +- 它避免把 advanced families 和 low-level authoring surface 混入 v1。 + +### 3. `dma_load/dma_store` 在 frontend 直接展开为 VPTO copy programming + copy op + +决策: + +- `dma_load` lower 到必要的 `set_loop*_stride_outtoub` / `set_loop_size_outtoub` + `copy_gm_to_ubuf` +- `dma_store` lower 到必要的 `set_loop*_stride_ubtoout` / `set_loop_size_ubtoout` + `copy_ubuf_to_gm` +- 参数由 TensorView slice、Tile shape/config、padding mode 推导 + +原因: + +- 当前 authoring-form VPTO 已经以这些 op 作为合法 surface。 +- 对 v1 来说,直接在 frontend materialize 这些 op 比再引入一层 TileLang-specific DMA IR 更简单。 + +### 4. shape profile 固定为“静态 physical Tile + 动态 view/bound” + +决策: + +- Tile physical shape 必须是静态编译期常量 +- TensorView 的 shape、slice 边界、loop bound 可以包含 runtime value +- `valid_shape` 仅支持: + - 静态值 + - 由 TensorView partition 直接推导 + +原因: + +- 这能覆盖 guide 中动态 TensorView / tail handling 的主要场景。 +- 同时避免在 v1 引入 fully-dynamic tile allocation 语义。 + +### 5. `verify()` 通过 `ptoas` subprocess 复用 repo 当前 legality contract + +决策: + +- `descriptor.verify()` 以临时文件或等价 stdin 方式调用 repo 中可用的 `ptoas` binary +- 命令路径按以下顺序解析: + - 显式传入或环境变量覆盖 + - `build/tools/ptoas/ptoas` +- 验证命令以 `--pto-backend=vpto --emit-vpto` 或等价 authoring-stage legality 路径运行 +- binary 缺失或不可执行时,返回结构化 `verifier unavailable` 结果,而不是静默成功 + +原因: + +- 当前 repo 没有直接暴露 custom VPTO legality pass 的 Python binding。 +- 复用 `ptoas` binary 是最直接且与现有回归一致的验证方式。 + +备选方案: + +- 在 Python 中直接调用 verifier pass + - 放弃原因:当前不存在稳定的 Python 入口,短期内实现成本高于收益。 + +## Risks / Trade-offs + +- [Risk] 直接展开 `dma_load/dma_store` 会把 DMA parameter inference 复杂度推到 frontend + Mitigation:v1 只支持固定 profile 的 TensorView slice / Tile layout,超出矩阵的场景一律 reject。 + +- [Risk] `verify()` 依赖 `ptoas` binary,环境不完整时可能影响开发体验 + Mitigation:定义清晰的 binary 查找顺序和结构化 unavailable 结果,避免模糊失败。 + +- [Risk] support matrix 过窄可能与 guide 的完整愿景存在落差 + Mitigation:在 change3 中单独扩展 matcher 和 advanced surface,并在 v1 diagnostics 中明确延期边界。 + +- [Risk] 显式 `strict_vecscope` 可能让首版示例看起来较啰嗦 + Mitigation:把 implicit vecscope inference 明确放入 follow-up change,不在 v1 做半成品推断。 diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/proposal.md b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/proposal.md new file mode 100644 index 000000000..7095bfce9 --- /dev/null +++ b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/proposal.md @@ -0,0 +1,97 @@ +# Proposal: 实现 TileLang DSL v1 到 authoring-form VPTO IR 的 lowering + +## 概述 + +在 `add-tilelang-dsl-core-foundation` 固定 package、surface 和 diagnostics 之后,本 change 负责把 v1 核心子集真正 lower 到 authoring-form VPTO IR。 +目标不是直接产出 A5 text/LLVM 发射结果,而是在 `tilelang-dsl/` 下建立一条稳定的 `TileLang DSL -> func/arith/scf/pto` authoring pipeline,并要求生成结果能够通过当前 `ptoas --pto-backend=vpto` 的 authoring-stage legality contract。 + +## 背景与动机 + +当前仓库里虽然已经有: + +- `docs/tilelang-dsl-guide.md` 对高层 DSL 的表述 +- `docs/vpto-spec.md` / `docs/vpto-verify.md` 对 VPTO IR 的表述 +- 一个实验性的 `python/pto/dialects/pto.py` parser + +但还没有一条与 guide 对齐、且以 `tilelang-dsl/` 为承载目录的正式 lowering 路径。 +如果没有这条路径,TileLang DSL 只能停留在文档层,后续 sample、regression 和 capability 都无法真正收敛。 + +同时,v1 必须先压缩到固定 support matrix: + +- 只做 `a5` +- 只做 elementwise 套餐 +- 只做显式 `strict_vecscope` +- 只做到 authoring-form VPTO + +否则会把 matcher、implicit vecscope inference、advanced family 一起引入,导致 v1 无法闭合。 + +## 目标 + +- 在 `tilelang-dsl/` 下实现 TileLang DSL v1 到 authoring-form VPTO IR 的 lowering。 +- 固定 v1 lowering 仅支持 elementwise 套餐: + - 2D `TensorView` + - 1D/2D `Tile` + - `dma_load` / `dma_store` + - `make_mask` + - `vlds` / `vsts` + - 常用 unary / binary / vector-scalar family + - `for` / `if` + - `set_flag` / `wait_flag` / `pipe_barrier` +- 明确 vector surface 必须位于显式 `strict_vecscope` 内,v1 不做 implicit vecscope inference。 +- 提供 `verify()` 契约,使 generated IR 能按当前 repo 的 VPTO authoring-stage legality 路径进行验证。 + +## 非目标 + +- 不在本 change 中实现 kernel matcher、`constraints`、`priority`、`Any*`、`TypeVar`。 +- 不在本 change 中实现 implicit vecscope inference。 +- 不在本 change 中扩展到 compare/select/reduction/rearrangement 等 advanced family。 +- 不在本 change 中直接产出 A5 text/LLVM emission 结果。 +- 不在本 change 中把实现回填到现有 `python/pto/dialects/pto.py` 实验 parser。 + +## 变更内容 + +- 新增 `tilelang-dsl-vpto-lowering` capability,定义 v1 fixed support matrix 的 lowering 行为。 +- 固定 `strict_vecscope` 是 v1 唯一合法的 vector-surface Python carrier;vector op 出现在显式 scope 外必须由 frontend 拒绝。 +- 固定 `dma_load` / `dma_store` 到 `copy_gm_to_ubuf` / `copy_ubuf_to_gm` 以及必需 DMA programming op 的 lowering 规则。 +- 固定 `verify()` 通过 `ptoas --pto-backend=vpto` authoring-stage legality 契约校验 generated module;当环境缺少 `ptoas` binary 时返回结构化 “verifier unavailable” 结果。 + +## Capabilities + +### New Capabilities + +- `tilelang-dsl-vpto-lowering`: 定义 TileLang DSL v1 从高层 Tile/TensorView surface 到 authoring-form VPTO IR 的 lowering 目标、support matrix、dynamic-bound 轮廓与验证接口。 + +### Modified Capabilities + +- 无 + +## 预期结果 + +- `tilelang-dsl/` 下的 v1 kernel 能产出稳定的 authoring-form VPTO IR 文本/模块。 +- 生成的 IR 使用当前真实 contract:显式 `pto.strict_vecscope`、typed mask、authoring-form buffer-like address。 +- v1 support matrix 外的 surface 在 frontend 直接 reject,不让未实现的 family 混入 lowering。 +- `verify()` 能复用 repo 当前 `ptoas` legality 路径,对 generated IR 给出 pass/fail 结果。 + +## 成功标准 + +- 新增 `openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/`,包含 proposal、design、tasks。 +- 新增 `specs/tilelang-dsl-vpto-lowering/spec.md`。 +- proposal/design/tasks 明确写清: + - v1 只支持显式 `strict_vecscope` + - v1 support matrix 的 family 列表 + - `dma_load/dma_store` 的 lowering 目标 + - `verify()` 必须走与 `ptoas --pto-backend=vpto` 一致的 authoring-stage legality contract + +## 影响 + +- 受影响目录: + - `tilelang-dsl/python/` + - `tilelang-dsl/tests/` + - `tilelang-dsl/examples/` + - `tilelang-dsl/docs/` +- 受影响 public API: + - `descriptor.mlir_text()` + - `descriptor.mlir_module()` + - `descriptor.verify()` +- 受影响验证路径: + - 生成物必须兼容 `ptoas --pto-backend=vpto` 的 authoring-stage legality diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md new file mode 100644 index 000000000..d06207721 --- /dev/null +++ b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md @@ -0,0 +1,98 @@ +# tilelang-dsl-vpto-lowering Specification + +## ADDED Requirements + +### Requirement: TileLang DSL v1 MUST lower vector surface through explicit `pto.strict_vecscope` + +TileLang DSL v1 中所有产生或消费 `!pto.vreg` / `!pto.mask<...>` 的 surface MUST 位于显式 `with pto.strict_vecscope(...) as (...):` 内。 +frontend MUST 将该 surface 直接 lower 为 dedicated `pto.strict_vecscope` authoring-form VPTO carrier。 +v1 MUST NOT 对用户省略的 vector region 做 implicit `pto.vecscope` inference。 + +#### Scenario: explicit `strict_vecscope` is preserved in authoring-form VPTO + +- **WHEN** 用户在 DSL 中显式书写 `with pto.strict_vecscope(...) as (...):` +- **THEN** lowering 结果 MUST 生成对应的 `pto.strict_vecscope` +- **AND** region argument、capture operand 和 block argument 类型 MUST 与 DSL surface 中的显式 capture 一一对应 + +#### Scenario: vector op outside explicit scope is rejected before IR generation + +- **WHEN** 用户在 `strict_vecscope` 外直接书写 `vlds`、`vsts`、vector ALU 或 predicate-producing surface +- **THEN** frontend MUST 在生成 VPTO IR 之前报错 +- **AND** MUST NOT 试图在 v1 中自动推断隐式 vecscope + +### Requirement: TileLang DSL v1 MUST support the fixed elementwise lowering profile + +TileLang DSL v1 lowering MUST 支持以下固定 support matrix: + +- 2D `TensorView` +- 1D/2D `Tile` +- `dma_load` +- `dma_store` +- `make_mask(dtype, PAT.*)` / `make_mask(dtype, remaining)` +- `vlds` +- `vsts` +- unary:`vabs`, `vrelu`, `vexp`, `vnot` +- binary:`vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor` +- vector-scalar:`vadds`, `vsubs`, `vmuls`, `vdivs`, `vmaxs`, `vmins` +- `for range(lb, ub, step)` +- `if/else` +- `set_flag`, `wait_flag`, `pipe_barrier` + +support matrix 外的 surface MUST 在 frontend reject。 + +#### Scenario: representative elementwise kernel lowers to authoring-form VPTO + +- **WHEN** 用户编写由 `TensorView`、`Tile`、高层 DMA、typed mask、elementwise vector op、`for`、`if` 和基础 sync 组成的 kernel +- **THEN** frontend MUST 产出只包含 `func.func`、`arith`、`scf` 和合法 `pto.*` authoring surface 的 VPTO IR +- **AND** 该 IR MUST 不依赖 matcher、implicit vecscope inference 或 advanced family 才能成立 + +#### Scenario: unsupported advanced family is rejected in v1 + +- **WHEN** 用户在 v1 kernel 中使用 compare/select/reduction/rearrangement、UB-to-UB copy 或其他不在 support matrix 内的 family +- **THEN** frontend MUST 直接报错 +- **AND** MUST NOT 静默降级为其他 family 或生成半合法 VPTO IR + +### Requirement: TileLang DSL v1 MUST support static physical Tile shape with dynamic TensorView views and loop bounds + +TileLang DSL v1 中,Tile physical shape MUST 是静态编译期常量。 +TensorView shape、slice 边界、loop bound 和 tail 相关 remaining value MAY 包含 runtime value。 +`valid_shape` 仅可使用静态值或由 TensorView partition 直接推导。 + +#### Scenario: dynamic TensorView slice and tail mask lower successfully + +- **WHEN** 用户使用 dynamic TensorView slice、dynamic loop bound,并在 loop 中通过 `make_mask(dtype, remaining)` 处理尾块 +- **THEN** frontend MUST 生成合法的 authoring-form VPTO IR +- **AND** tail mask MUST lower 为与元素类型匹配的 typed predicate family +- **AND** Tile physical shape MUST 继续保持静态契约 + +### Requirement: `dma_load` and `dma_store` MUST lower to VPTO DMA programming plus copy ops + +TileLang DSL 的高层 `dma_load` / `dma_store` MUST 在 frontend lower 到当前合法 VPTO authoring surface: + +- GM -> UB:必要的 `set_loop*_stride_outtoub` / `set_loop_size_outtoub` + `copy_gm_to_ubuf` +- UB -> GM:必要的 `set_loop*_stride_ubtoout` / `set_loop_size_ubtoout` + `copy_ubuf_to_gm` + +参数 MUST 由 TensorView slice、Tile shape/config 和 padding mode 推导。 + +#### Scenario: high-level DMA becomes legal VPTO copy programming + +- **WHEN** 用户在 DSL 中编写 `dma_load(input_tensor[slice], ub_tile)` 或 `dma_store(ub_tile, output_tensor[slice])` +- **THEN** lowering MUST 显式生成对应的 DMA programming op 和 copy op +- **AND** 生成结果 MUST 符合当前 VPTO copy-family 的 authoring contract + +### Requirement: `verify()` MUST validate generated IR through the repo VPTO authoring-stage legality path + +TileLang DSL descriptor 的 `verify()` MUST 以 repo 当前 `ptoas` legality 路径验证生成结果。 +当可用的 `ptoas` binary 缺失、不可执行或环境不完整时,`verify()` MUST 返回结构化 `verifier unavailable` 结果,而不是静默通过。 + +#### Scenario: generated VPTO module is checked by `ptoas` + +- **WHEN** 用户对一个已 specialization 的 kernel 调用 `verify()` +- **THEN** implementation MUST 使用与 `ptoas --pto-backend=vpto` 一致的 authoring-stage legality contract 对生成 module 进行校验 +- **AND** 成功结果 MUST 代表 generated IR 已通过当前 repo 的 VPTO authoring legality + +#### Scenario: verifier-unavailable is reported explicitly + +- **WHEN** `verify()` 无法找到或执行 `ptoas` binary +- **THEN** implementation MUST 返回结构化 `verifier unavailable` 结果 +- **AND** MUST NOT 把“未验证”误报成“验证通过” diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/tasks.md b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/tasks.md new file mode 100644 index 000000000..be8775819 --- /dev/null +++ b/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/tasks.md @@ -0,0 +1,30 @@ +## 1. OpenSpec 契约落定 + +- [x] 1.1 新增 `openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md`,固定 v1 lowering support matrix、dynamic-bound 轮廓和 `verify()` 契约。 +- [x] 1.2 在 `proposal.md` 和 `design.md` 中明确 v1 只支持显式 `strict_vecscope`,不做 implicit vecscope inference。 + +## 2. Frontend lowering 骨架 + +- [x] 2.1 在 `tilelang-dsl/python/` 中建立独立的 AST/语义/lowering pipeline,把 core-foundation 的 descriptor 接到 authoring-form VPTO builder。 +- [x] 2.2 实现 `TensorView`、`Tile`、标量、loop bound 和 `strict_vecscope` block argument 的类型绑定与 SSA 环境管理。 +- [x] 2.3 让 lowering 输出稳定的 `func.func + arith/scf + pto.*` authoring-form VPTO module。 + +## 3. Elementwise support matrix + +- [x] 3.1 实现 `dma_load` / `dma_store` 的 TensorView slice 到 DMA programming + copy op lowering。 +- [x] 3.2 实现 `make_mask`、`vlds`、`vsts` 以及 v1 unary/binary/vector-scalar family 的 lowering。 +- [x] 3.3 实现 `for range(lb, ub, step)`、`if/else`、`set_flag`、`wait_flag`、`pipe_barrier` 的 lowering。 +- [ ] 3.4 对 support matrix 外的 family 保持 fail-fast reject,不允许 silent fallback。 + +## 4. Dynamic-bound 与合法性验证 + +- [ ] 4.1 实现“静态 physical Tile + 动态 TensorView slice/loop bound”的 shape profile,拒绝 dynamic physical tile shape。 +- [ ] 4.2 实现 tail `make_mask(dtype, remaining)` 的 typed-mask lowering,确保输出满足当前 VPTO legality contract。 +- [ ] 4.3 实现 `descriptor.verify()`,通过 `ptoas` binary 运行与 `--pto-backend=vpto` 一致的 authoring-stage legality 验证,并对 binary 缺失返回结构化 unavailable 结果。 + +## 5. 测试、样例与文档 + +- [ ] 5.1 在 `tilelang-dsl/tests/` 增加 elementwise kernel 的 positive regression,覆盖 `dma_load/store`、`strict_vecscope`、typed-mask、dynamic loop bound。 +- [ ] 5.2 增加 negative regression,覆盖 vector op 出 scope、unsupported family、非法 shape profile、verifier unavailable。 +- [ ] 5.3 在 `tilelang-dsl/examples/` 和 `tilelang-dsl/docs/` 提供与 guide 对齐的 v1 示例,并明确记录 support matrix 与延期 feature。 +- [ ] 5.4 运行并记录最小验证命令,确认生成的 IR 能通过 `build/tools/ptoas/ptoas --pto-backend=vpto` 的 authoring-stage legality 路径。 diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/.openspec.yaml b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/.openspec.yaml new file mode 100644 index 000000000..c430c5fa6 --- /dev/null +++ b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-03 diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/design.md b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/design.md new file mode 100644 index 000000000..4c5850412 --- /dev/null +++ b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/design.md @@ -0,0 +1,138 @@ +## Context + +### 范围 + +本 design 覆盖两个 follow-up 能力: + +1. `tilelang-dsl-kernel-matcher` +2. `tilelang-dsl-advanced-surface` + +它们都建立在 core-foundation 与 authoring-vpto-lowering 两个 change 之上。 +目标不是推翻 v1,而是在 v1 已稳定的前提下,把被延期的 guide surface 升级成正式 capability。 + +### 当前状态 + +当前 v1 规划已经明确: + +- 只接受单一 monomorphic `dtypes` +- 只接受显式 `strict_vecscope` +- 只支持固定 elementwise 套餐 + +因此,以下 guide 能力仍处于“文档存在、实现未承诺”的状态: + +- 多 kernel / constraint / priority matcher +- `Any*` / `TypeVar` +- implicit vecscope inference +- raw pointer / low-level DMA authoring +- compare/select、predicate movement、carry、rearrangement、reduction 等 advanced family + +如果不把这些能力定义成单独的 follow-up change,v1 diagnostics 将长期缺少正式的迁移目标。 + +### 实现约束 + +- 继续保持 `tilelang-dsl/` 是本特性的源码与测试承载目录。 +- matcher 和 advanced surface 都必须最终收敛到当前 repo 的 authoring-form VPTO legality contract。 +- 显式 `strict_vecscope` 仍是强边界,advanced inference 不能破坏这一点。 +- 新 API 必须 deterministic,不能让“同样的 kernel 集合、同样的输入”出现不稳定选择结果。 + +## Goals / Non-Goals + +**Goals:** + +- 定义 kernel registry / selection 的明确接口和匹配顺序。 +- 让 `constraints`、`priority`、多 signature `dtypes`、`Any*`、`TypeVar` 进入正式契约。 +- 让 implicit vecscope inference、raw pointer surface、low-level DMA 和 advanced family 进入正式 capability。 + +**Non-Goals:** + +- 不改变 v1 核心 capability 中已经定义的 package/目录边界。 +- 不支持 `a5` 之外的 target。 +- 不把 TileLang DSL 改造成任意 Python 语法执行器。 + +## Decisions + +### 1. matcher 采用显式 registry + selection API,而不是隐式扫描所有 Python function + +决策: + +- `@pto.vkernel` 定义的 descriptor 自动注册到 module-level `KernelRegistry` +- 公开 `pto.select_kernel(target, op, operand_types, context_attrs, registry=None)` 作为 selection 入口 +- 选中的 descriptor 继续复用 v1 的 `specialize()` / `mlir_text()` / `verify()` 流程 + +原因: + +- registry + selection API 比“由外部框架自己 introspect Python globals”更稳定、可测试。 +- 这让 matcher capability 能独立存在,而不强绑某个上游 compiler integration。 + +### 2. selection 顺序固定为 target -> op -> dtype signature -> constraints -> priority -> tie error + +决策: + +- 先按 `target` +- 再按 `op` +- 再按 `dtypes` signature 做 concrete / wildcard / type-variable 匹配 +- 再评估 `constraints` +- 剩余候选按最高 `priority` 选择 +- 若最高 `priority` 仍有多个候选,则报 deterministic tie error,不做隐式 tiebreak + +原因: + +- 这与 guide 中的叙述一致,同时避免“靠定义顺序兜底”的隐式行为。 + +### 3. `Any*` 与 `TypeVar` 进入 matcher capability,而不是回写 v1 surface + +决策: + +- `AnyFloat`、`AnyInt`、`AnyType`、`AnyMask` 只在 matcher capability 中生效 +- `TypeVar("T")` 只用于单个 signature 内的位置一致性约束 +- v1 核心 surface 不回溯修改;advanced capability 启用后再开放这些写法 + +原因: + +- 这样可以保持 v1 core 仍然简单,同时让 follow-up change 独立定义 wildcard/type-variable 语义。 + +### 4. implicit vecscope inference 作为 advanced surface 的默认行为,但 `strict_vecscope` 继续是硬边界 + +决策: + +- 当用户在 advanced mode 下省略显式 scope,并书写连续的 supported vector chain 时,frontend 默认推断 `pto.vecscope` +- scalar op、控制流边界、外部 call、以及显式 `strict_vecscope` 都会切断 inference +- `strict_vecscope` 继续保留,且 inference MUST NOT 穿越其边界 + +原因: + +- 这与 guide 的默认 authoring 体验一致。 +- 同时保留 `strict_vecscope` 作为 deterministic 边界,避免 inference 影响关键 kernel 的资源边界。 + +### 5. advanced surface 扩展 raw pointer / low-level DMA / advanced family,但继续收敛到 authoring-form VPTO + +决策: + +- raw pointer / low-level DMA surface 增加: + - `castptr` + - `addptr` + - raw UBRef load/store + - low-level DMA programming + - `copy_ubuf_to_ubuf` +- advanced vector family 增加: + - compare/select + - predicate movement + - carry family + - rearrangement + - reduction + +这些 surface 仍必须 lower 到当前真实的 authoring-form VPTO,而不是发明新的公开中间 IR。 + +## Risks / Trade-offs + +- [Risk] matcher capability 引入 registry 和 selection API,会让 package surface 明显扩大 + Mitigation:把 registry/query API 单独收敛在 matcher capability,避免污染 v1 core descriptor API。 + +- [Risk] implicit vecscope inference 可能让 scope boundary 难以调试 + Mitigation:保留 `strict_vecscope` 作为显式硬边界,并要求 inference 在 control-flow / scalar boundary 上切断。 + +- [Risk] advanced family 范围过宽,容易再次失控 + Mitigation:按 capability 明确列出 family 分组,并用 regression 锁定首批支持面,其他 family 继续 reject。 + +- [Risk] raw pointer / low-level DMA authoring 可能让用户绕过高层安全网 + Mitigation:advanced surface 继续要求最终输出通过同一套 VPTO legality contract,不因“更底层”而放宽最终收口。 diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/proposal.md b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/proposal.md new file mode 100644 index 000000000..676b94291 --- /dev/null +++ b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/proposal.md @@ -0,0 +1,79 @@ +# Proposal: 扩展 TileLang DSL 的 matcher 与 advanced surface + +## 概述 + +`add-tilelang-dsl-core-foundation` 和 `add-tilelang-dsl-authoring-vpto-lowering` 会先收敛出一个可闭合的 v1 核心子集,但 `docs/tilelang-dsl-guide.md` 还定义了更完整的方向:kernel matcher、多 signature `dtypes`、`Any*` / `TypeVar`、constraint-based selection、implicit vecscope inference、raw pointer surface 和 advanced vector family。 +本 change 作为明确的 follow-up capability,负责把这些能力从“v1 diagnostics 中被拒绝的延期项”升级为正式契约,并继续要求相关工作集中在 `tilelang-dsl/` 下实现。 + +## 背景与动机 + +v1 核心 change 有意做了三项收缩: + +1. 只保留单一 monomorphic `dtypes` +2. 只接受显式 `strict_vecscope` +3. 只支持 elementwise 套餐 + +这些收缩让首版可实现,但也与 `docs/tilelang-dsl-guide.md` 的完整愿景存在差距: + +- guide 已经定义 matcher / priority / constraints / wildcard typing +- guide 已经把 implicit vecscope inference 作为默认 authoring 体验 +- guide 还覆盖 raw pointer、低层 DMA、compare/select、predicate movement、rearrangement、reduction 等更大 surface + +如果不把这些延期项显式收口成 follow-up change,v1 中的 reject diagnostics 就会长期停留在“未来再说”,缺少明确的能力落点。 + +## 目标 + +- 为 TileLang DSL 建立正式的 kernel matcher capability:多 signature `dtypes`、`Any*`、`TypeVar`、`constraints`、`priority` 和 deterministic selection。 +- 为 TileLang DSL 建立 advanced surface capability:implicit vecscope inference、raw pointer / low-level DMA surface、advanced vector family。 +- 保持核心实现继续集中在 `tilelang-dsl/`,不把 matcher 或 advanced lowering 回填到现有其他 Python binding 入口。 + +## 非目标 + +- 不修改 v1 基础 change 中已经固定的 package/目录边界。 +- 不重新设计 `verify()` 的基本验证路径;advanced change 仍以当前 repo 的 VPTO legality 契约为输出收口。 +- 不在本 change 中扩展到 `a5` 之外的 target。 + +## 变更内容 + +- 新增 `tilelang-dsl-kernel-matcher` capability,定义 kernel registry、match order、wildcard/type-variable 语义、constraint evaluation 和 selection tie-breaking。 +- 新增 `tilelang-dsl-advanced-surface` capability,定义 implicit vecscope inference、raw pointer/UBRef authoring、low-level DMA surface 以及 advanced vector family 的扩展 lowering 契约。 +- 要求 core-foundation change 中对延期 feature 的 reject diagnostics 在本 change 落地后转为正式支持路径。 + +## Capabilities + +### New Capabilities + +- `tilelang-dsl-kernel-matcher`: 定义多 kernel 注册、target/op/type/constraint/priority 匹配、wildcard typing 和 deterministic selection 契约。 +- `tilelang-dsl-advanced-surface`: 定义 implicit vecscope inference、raw pointer / low-level DMA / UBRef surface 以及 advanced vector family 的扩展 lowering 契约。 + +### Modified Capabilities + +- 无 + +## 预期结果 + +- TileLang DSL 从 v1 的“固定单 kernel elementwise 子集”扩展到可注册、可选择、可约束的 kernel authoring 体系。 +- 当用户省略显式 scope 时,frontend 能按规则推断 `pto.vecscope`,同时继续保留 `strict_vecscope` 作为硬边界。 +- raw pointer / low-level DMA / advanced family 有清晰 capability,而不再只是文档愿景。 + +## 成功标准 + +- 新增 `openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/`,包含 proposal、design、tasks。 +- 新增 `specs/tilelang-dsl-kernel-matcher/spec.md` 和 `specs/tilelang-dsl-advanced-surface/spec.md`。 +- proposal/design/tasks 明确写清: + - kernel registry / selection API + - 多 signature `dtypes`、`Any*`、`TypeVar`、constraint evaluation 和 priority 决策顺序 + - implicit vecscope inference 的默认行为和边界 + - raw pointer / low-level DMA / advanced family 的支持范围 + +## 影响 + +- 受影响目录: + - `tilelang-dsl/python/` + - `tilelang-dsl/tests/` + - `tilelang-dsl/examples/` + - `tilelang-dsl/docs/` +- 受影响 public API: + - `@pto.vkernel(... dtypes=[...], constraints=[...], priority=...)` + - `pto.select_kernel(...)` 或等价 registry 查询入口 + - implicit vecscope inference 相关的 compile behavior diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md new file mode 100644 index 000000000..9f3db0d6d --- /dev/null +++ b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md @@ -0,0 +1,63 @@ +# tilelang-dsl-advanced-surface Specification + +## ADDED Requirements + +### Requirement: advanced mode MUST infer `pto.vecscope` for eligible vector chains while preserving `strict_vecscope` boundaries + +在 advanced mode 下,当用户省略显式 scope 且书写连续的 supported vector chain 时,frontend MUST 自动推断 dedicated `pto.vecscope`。 +scalar op、控制流边界、外部 call 和显式 `strict_vecscope` MUST 切断该推断。 +`strict_vecscope` 继续作为硬边界,inference MUST NOT 穿越其边界。 + +#### Scenario: contiguous vector chain becomes one inferred `pto.vecscope` + +- **WHEN** 用户在 advanced mode 下连续书写一条由 load -> vector ALU -> store 组成的纯 vector chain,且中间没有 scalar/control-flow boundary +- **THEN** frontend MUST 将该 chain lower 为一个 dedicated `pto.vecscope` +- **AND** 该 inferred vecscope MUST 满足当前 VPTO authoring legality contract + +#### Scenario: explicit `strict_vecscope` remains an inference barrier + +- **WHEN** 用户在 advanced mode 下显式书写 `strict_vecscope` +- **THEN** frontend MUST 保留该 `strict_vecscope` 原样语义 +- **AND** scope inference MUST NOT 跨越该显式边界去并合前后 vector chain + +### Requirement: advanced mode MUST support raw pointer, UBRef, low-level DMA, and `copy_ubuf_to_ubuf` authoring + +advanced mode MUST 将以下 surface 纳入正式契约: + +- `castptr` +- `addptr` +- raw UBRef load/store authoring +- low-level DMA programming +- `copy_ubuf_to_ubuf` + +这些 surface 仍 MUST lower 到当前合法的 authoring-form VPTO,不得发明另一套公开中间 IR。 + +#### Scenario: low-level pointer and DMA surface lowers to legal authoring-form VPTO + +- **WHEN** 用户使用 `castptr`、`addptr`、raw UBRef、低层 DMA programming 或 `copy_ubuf_to_ubuf` +- **THEN** frontend MUST 生成对应的合法 authoring-form VPTO surface +- **AND** 输出结果 MUST 继续满足当前 copy/buffer-like/ptr-only 地址契约 + +### Requirement: advanced mode MUST extend lowering to advanced vector families in grouped capability sets + +advanced mode MUST 将以下 family 分组纳入正式 lowering capability: + +- compare/select +- predicate movement +- carry family +- rearrangement +- reduction + +对未进入这些 capability set 的 family,frontend MUST 继续显式 reject。 + +#### Scenario: advanced family kernel lowers without leaving the authoring-form VPTO contract + +- **WHEN** 用户在 advanced mode 下使用 compare/select、predicate movement、carry、rearrangement 或 reduction family 编写 kernel +- **THEN** frontend MUST 为该 family 生成合法的 authoring-form VPTO IR +- **AND** typed-mask、vecscope 和地址形态契约 MUST 与当前 VPTO legality contract 保持一致 + +#### Scenario: family outside the declared advanced capability set is still rejected + +- **WHEN** 用户使用未纳入上述 capability set 的 family +- **THEN** frontend MUST 继续报 unsupported-feature 错误 +- **AND** MUST NOT 因启用了 advanced mode 就默认放开全部 VPTO family diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md new file mode 100644 index 000000000..69579df93 --- /dev/null +++ b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md @@ -0,0 +1,57 @@ +# tilelang-dsl-kernel-matcher Specification + +## ADDED Requirements + +### Requirement: TileLang DSL MUST provide a deterministic kernel registry and selection API + +当同一 `target/op` 下存在多个 `@pto.vkernel` descriptor 时,TileLang DSL MUST 将它们注册到可查询的 registry。 +系统 MUST 提供显式 selection API,用于在给定 `target`、`op`、operand type 信息和上下文属性时选择唯一 kernel。 +选择过程 MUST deterministic。 + +#### Scenario: selector returns the unique best kernel + +- **WHEN** registry 中存在多个针对同一 `target/op` 的 kernel descriptor,且其中一个在全部匹配步骤后成为唯一最佳候选 +- **THEN** `pto.select_kernel(...)` MUST 返回该 descriptor +- **AND** 返回结果 MUST 可继续走 `specialize()` / `mlir_text()` / `verify()` 流程 + +### Requirement: matcher MUST support concrete types, `Any*`, and `TypeVar` across multiple signatures + +matcher MUST 支持: + +- 多个 `dtypes` signature +- `AnyFloat` +- `AnyInt` +- `AnyType` +- `AnyMask` +- `TypeVar` + +`TypeVar` 在单个 signature 内 MUST 约束所有同名位置绑定到同一最终类型。 + +#### Scenario: wildcard and type-variable signatures match deterministically + +- **WHEN** 某个 kernel 使用多个 `dtypes` signature,并在其中混用 concrete type、`Any*` 与 `TypeVar` +- **THEN** matcher MUST 对每个 signature 独立求值 +- **AND** 只有满足所有 `TypeVar` 一致性约束的 signature 才能视为匹配成功 + +### Requirement: constraint evaluation MUST happen after type matching and before priority resolution + +对同一 `target/op` 的候选集合,matcher MUST 先完成 dtype matching,再评估 `constraints`。 +只有通过 constraint evaluation 的候选,才允许进入 `priority` 比较阶段。 + +#### Scenario: higher-priority kernel with failing constraint does not win + +- **WHEN** 一个更高 `priority` 的 kernel 在 target/op/type 层面匹配成功,但 `constraints` 评估失败 +- **THEN** 该 kernel MUST 从候选集合中移除 +- **AND** selector MUST 继续在剩余候选中选择合法 kernel + +### Requirement: priority ties MUST raise an explicit selection error + +若在 target/op/type/constraint 全部通过后,最高 `priority` 仍对应多个候选,matcher MUST 报显式选择错误。 +系统 MUST NOT 依赖定义顺序、导入顺序或其他隐式规则做 tiebreak。 + +#### Scenario: equal-priority winners cause deterministic tie error + +- **WHEN** 多个 kernel 在 target/op/type/constraint 匹配后拥有相同的最高 `priority` +- **THEN** selector MUST 报错 +- **AND** 错误消息 MUST 指出发生 tie 的 kernel 集合 +- **AND** MUST NOT 静默选择第一个已注册 kernel diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/tasks.md b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/tasks.md new file mode 100644 index 000000000..184bc7fe5 --- /dev/null +++ b/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/tasks.md @@ -0,0 +1,24 @@ +## 1. OpenSpec 契约落定 + +- [ ] 1.1 新增 `openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md`,固定 registry、selection API、wildcard/type-variable、constraint 与 priority 规则。 +- [ ] 1.2 新增 `openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md`,固定 implicit vecscope inference、raw pointer / low-level DMA 和 advanced family 的扩展契约。 +- [ ] 1.3 在 `proposal.md` 和 `design.md` 中明确该 change 依赖 core-foundation 与 authoring-vpto-lowering 两个前置 change。 + +## 2. Matcher 能力 + +- [ ] 2.1 在 `tilelang-dsl/python/` 中实现 `KernelRegistry` 和 `pto.select_kernel(...)` 入口。 +- [ ] 2.2 实现多 signature `dtypes`、`Any*`、`TypeVar` 的 matcher 语义和 deterministic selection 顺序。 +- [ ] 2.3 实现 `constraints` evaluation 与 `priority` 决策;对最高优先级 tie 保持显式报错。 + +## 3. Advanced surface + +- [ ] 3.1 实现 implicit vecscope inference,并保证 `strict_vecscope` 仍然是硬边界。 +- [ ] 3.2 扩展 raw pointer / UBRef / low-level DMA / `copy_ubuf_to_ubuf` surface 到 authoring-form VPTO lowering。 +- [ ] 3.3 扩展 compare/select、predicate movement、carry、rearrangement、reduction family 的 lowering 支持。 + +## 4. 测试与文档 + +- [ ] 4.1 在 `tilelang-dsl/tests/` 增加 matcher regression,覆盖 wildcard/type-variable、constraint fallback、priority tie error。 +- [ ] 4.2 增加 vecscope inference regression,覆盖连续 vector chain 自动分组、scalar/control-flow 边界切断、`strict_vecscope` 边界保留。 +- [ ] 4.3 增加 raw pointer / low-level DMA / advanced family regression,确认输出仍满足当前 VPTO legality contract。 +- [ ] 4.4 在 `tilelang-dsl/docs/` 更新从 v1 core 到 matcher/advanced-surface 的迁移说明,并同步 `docs/tilelang-dsl-guide.md` 的已支持/延期状态。 diff --git a/openspec/specs/tilelang-dsl-diagnostics/spec.md b/openspec/specs/tilelang-dsl-diagnostics/spec.md new file mode 100644 index 000000000..2d7d1988b --- /dev/null +++ b/openspec/specs/tilelang-dsl-diagnostics/spec.md @@ -0,0 +1,63 @@ +# TileLang DSL Diagnostics Specification + +## Purpose +TBD - created by archiving change add-tilelang-dsl-core-foundation. Update Purpose after archive. + +## Requirements +### Requirement: v1 MUST fail fast on unsupported matcher and decorator features + +TileLang DSL v1 frontend 对以下 surface MUST fail-fast,而不是静默忽略或拖到 lowering 阶段: + +- 多个 `dtypes` signature +- `constraints` +- `priority` +- `AnyFloat` / `AnyInt` / `AnyType` / `AnyMask` +- `TypeVar` + +#### Scenario: unsupported matcher feature is rejected at decorator parse time + +- **WHEN** 用户在 v1 kernel decorator 中写入 `constraints`、`priority`、多 signature `dtypes`、`Any*` 或 `TypeVar` +- **THEN** frontend MUST 直接报错 +- **AND** 诊断 MUST 明确指出该 feature 不属于 v1 范围 +- **AND** 诊断 SHOULD 指向 follow-up change `extend-tilelang-dsl-matcher-and-advanced-surface`,而不是伪装成底层 type error + +### Requirement: v1 MUST reject unsupported Python syntax and unsupported DSL calls before IR generation + +TileLang DSL v1 frontend MUST 只接受受限 Python 子集。 +`while`、list/dict/set comprehension、arbitrary external function call、未注册 DSL op、以及其他超出 v1 surface 的 Python 结构 MUST 在 frontend 被拒绝。 + +#### Scenario: unsupported Python construct is rejected before lowering + +- **WHEN** kernel body 使用 `while`、comprehension、任意非 `pto.*` function call 或未纳入 v1 support matrix 的 DSL call +- **THEN** frontend MUST 在生成任何 VPTO IR 之前报错 +- **AND** 诊断 MUST 指明违规的 Python construct 或 DSL call 名称 + +### Requirement: Tile specialization and shape-profile errors MUST be diagnosed in the frontend + +TileLang DSL v1 frontend MUST 把以下错误归类为前端错误: + +- bare `Tile` 参数未完成 specialization +- Tile physical shape 不是静态编译期常量 +- Tile profile 与 v1 支持的 rank / memory-space 约束不匹配 + +#### Scenario: unspecialized or dynamically-shaped tile fails before materialization + +- **WHEN** kernel 含 bare `Tile` 参数但调用方未完成 `specialize()`,或 specialization 试图给出 dynamic physical tile shape +- **THEN** frontend MUST 在 `mlir_text()` / `mlir_module()` / `verify()` 之前直接报错 +- **AND** MUST NOT 继续尝试生成不完整的 authoring-form VPTO IR + +### Requirement: frontend diagnostics MUST include source location and semantic cause + +TileLang DSL v1 的 frontend diagnostics MUST 包含 DSL 源位置和语义原因。 +错误消息 MUST 能区分: + +- decorator surface 不支持 +- Python 语法子集不支持 +- 参数定型失败 +- Tile specialization/profile 非法 + +#### Scenario: user sees actionable diagnostic with source location + +- **WHEN** frontend 因 unsupported feature、unsupported syntax、type binding failure 或 specialization error 拒绝一个 kernel +- **THEN** 诊断 MUST 至少包含 DSL 源文件位置、行列号或等价的 source span +- **AND** MUST 明确指出失败原因属于哪一层 frontend 语义,而不是只给出底层 verifier 或 parser 的通用报错 diff --git a/openspec/specs/tilelang-dsl-surface/spec.md b/openspec/specs/tilelang-dsl-surface/spec.md new file mode 100644 index 000000000..a3824074b --- /dev/null +++ b/openspec/specs/tilelang-dsl-surface/spec.md @@ -0,0 +1,56 @@ +# TileLang DSL Surface Specification + +## Purpose +TBD - created by archiving change add-tilelang-dsl-core-foundation. Update Purpose after archive. + +## Requirements +### Requirement: TileLang DSL v1 MUST live under `tilelang-dsl/` and expose a dedicated `tilelang_dsl` package + +TileLang DSL v1 的实现、样例、测试和局部文档 MUST 集中在 `tilelang-dsl/`。 +对外 import 入口 MUST 是独立 package `tilelang_dsl`,不得继续把本特性的核心逻辑建立在现有 `python/pto/dialects/pto.py` 的实验 DSL 上。 +根目录其他路径若有改动,MUST 仅限最小 build/install/test wiring。 + +#### Scenario: TileLang DSL source stays isolated from existing Python binding code + +- **WHEN** 仓库为 TileLang DSL v1 新增源码、样例、测试和局部文档 +- **THEN** 这些工件 MUST 放在 `tilelang-dsl/` 下 +- **AND** repo root 或 `python/` 现有目录树的改动 MUST 只承担最小接线职责 +- **AND** `python/pto/dialects/pto.py` MUST NOT 继续作为 TileLang DSL v1 的 source of truth + +### Requirement: v1 `@pto.vkernel` surface MUST be limited to the monomorphic `a5` profile + +TileLang DSL v1 的 `@pto.vkernel` MUST 只接受 `target="a5"`。 +`op` MUST 作为必填 metadata 保留。 +`dtypes` MUST 只包含一个 monomorphic signature tuple。 +`name` 和 `verify` MAY 保留为可选字段。 +v1 不在 public surface 中支持多 signature `dtypes`、`constraints`、`priority`、`Any*` 或 `TypeVar`。 + +#### Scenario: monomorphic a5 kernel descriptor is accepted + +- **WHEN** 用户定义 `@pto.vkernel(target="a5", op="scale", dtypes=[(pto.f32, pto.f32, pto.f32)])` +- **THEN** frontend MUST 接受该 decorator surface +- **AND** descriptor MUST 保留 `target/op/dtypes/name/verify` metadata 用于后续编译和调试 + +### Requirement: bare `TensorView` and `Tile` annotations MUST bind element types through the single `dtypes` signature + +在 v1 中,`TensorView` 和 `Tile` 参数 MUST 允许使用 bare annotation。 +其元素类型 MUST 由 decorator 的单个 `dtypes` signature 按参数位置绑定。 +标量参数 MUST 继续使用显式标量注解,并与 `dtypes` 中对应位置的标量类型保持一致。 + +#### Scenario: `dtypes` binds operand element types positionally + +- **WHEN** kernel 参数按位置写成 `TensorView, TensorView, Tile, pto.f32` +- **THEN** 单个 `dtypes` signature MUST 按同样的位置顺序提供两个 GM operand 的元素类型、一个 Tile operand 的元素类型和一个标量类型 +- **AND** frontend MUST 使用该 signature 作为参数定型的唯一来源 + +### Requirement: bare `Tile` parameters MUST require explicit specialization before IR materialization + +对 bare `Tile` 参数,frontend MUST 在 descriptor 上提供显式 specialization 入口。 +Tile 的 physical shape、memory space 和配置 MUST 在 specialization 阶段补全。 +在所有 bare `Tile` 参数完成 specialization 之前,descriptor MUST NOT 允许执行 `mlir_text()`, `mlir_module()`, `verify()` 或 `emit(path)`。 + +#### Scenario: specialized tile kernel can materialize IR + +- **WHEN** kernel 含 bare `Tile` 参数,且调用方通过 `descriptor.specialize(**bindings)` 为所有 bare `Tile` 参数补齐静态 shape / space / config +- **THEN** 返回的 specialized descriptor MUST 允许调用 `mlir_text()`, `mlir_module()`, `verify()` 和 `emit(path)` +- **AND** specialization 之后的 Tile physical shape MUST 作为编译期静态契约固定下来 From 0fdf4d3154630414680b7feb176664dbb7348ca8 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 4 Apr 2026 15:52:34 +0800 Subject: [PATCH 019/192] tilelang dsl first working version --- tilelang-dsl/README.md | 105 ++++ tilelang-dsl/docs/README.md | 2 + tilelang-dsl/docs/v1-lowering.md | 125 ++++ tilelang-dsl/docs/v1-surface.md | 5 +- tilelang-dsl/examples/README.md | 19 +- .../examples/v1_elementwise_tail_demo.py | 75 +++ tilelang-dsl/examples/v1_emit_mlir_demo.py | 19 +- .../v1_tadd_implicit_vecscope_demo.py | 82 +++ .../v1_tbinop_2d_nopostupdate_demo.py | 124 ++++ tilelang-dsl/examples/v1_verify_smoke.py | 67 +++ tilelang-dsl/python/tilelang_dsl/__init__.py | 2 + .../python/tilelang_dsl/frontend_ast.py | 2 + tilelang-dsl/python/tilelang_dsl/kernel.py | 279 +++++++-- tilelang-dsl/python/tilelang_dsl/lowering.py | 320 +++++++--- tilelang-dsl/python/tilelang_dsl/semantic.py | 555 +++++++++++++++++- .../python/tilelang_dsl/support_matrix.py | 91 +++ tilelang-dsl/python/tilelang_dsl/types.py | 18 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 349 ++++++++++- 18 files changed, 2057 insertions(+), 182 deletions(-) create mode 100644 tilelang-dsl/docs/v1-lowering.md create mode 100644 tilelang-dsl/examples/v1_elementwise_tail_demo.py create mode 100644 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py create mode 100644 tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py create mode 100644 tilelang-dsl/examples/v1_verify_smoke.py create mode 100644 tilelang-dsl/python/tilelang_dsl/support_matrix.py diff --git a/tilelang-dsl/README.md b/tilelang-dsl/README.md index 1bf15d1f8..37d13f015 100644 --- a/tilelang-dsl/README.md +++ b/tilelang-dsl/README.md @@ -16,4 +16,109 @@ Layout: - `examples/`: self-contained examples - `docs/`: local documentation for this frontend +## How To Generate MLIR From A `.py` + +Run the examples from the repository root. + +If you are developing against the in-tree Python sources, point `PYTHONPATH` +at `tilelang-dsl/python`: + +```bash +cd /home/zhangzhendong/ptoas-workspace/PTOAS +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_emit_mlir_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_emit_mlir_demo.py /tmp/tilelang_demo.mlir +``` + +If you already built and installed the Python package into the repo build tree, +you can also point `PYTHONPATH` at `build/python`: + +```bash +cd /home/zhangzhendong/ptoas-workspace/PTOAS +PYTHONPATH=$PWD/build/python python3 tilelang-dsl/examples/v1_emit_mlir_demo.py +``` + +Behavior: +- without an output path, the script prints MLIR to stdout +- with an output path, the script writes MLIR to that file through `emit(path)` + +Useful examples: + +```bash +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py /tmp/tilelang_v1_elementwise.mlir +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_verify_smoke.py /tmp/tilelang_v1_verify.mlir +``` + +## Advanced Mode + +The default v1 surface still requires explicit `pto.strict_vecscope`. + +If you want the follow-up advanced surface for: +- implicit `pto.vecscope` inference +- `pto.vlds(tile[row, col:])` +- `pto.vsts(vec, tile[row, col:], mask)` + +set `advanced=True` on `@pto.vkernel` and follow +[`tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py`](/home/zhangzhendong/ptoas-workspace/PTOAS/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py). + +## Minimal Script Pattern + +Your own `.py` only needs to: +- import `tilelang_dsl` +- define a `@pto.vkernel` +- call `specialize(...)` +- call `mlir_text()` or `emit(path)` + +Minimal example: + +```python +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + op="eltwise_with_tile", + dtypes=[(pto.f32, pto.f16, pto.i32)], + name="my_kernel", +) +def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + return None + + +specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + ) +) + +print(specialized.mlir_text()) +specialized.emit(Path("/tmp/my_kernel.mlir")) +``` + +If `python3 your_script.py` reports `ModuleNotFoundError: tilelang_dsl`, it +means the package import path is missing. Re-run with one of: + +```bash +PYTHONPATH=$PWD/tilelang-dsl/python python3 your_script.py +PYTHONPATH=$PWD/build/python python3 your_script.py +``` + +## Optional Verifier Check + +To check that the generated MLIR passes the current repo VPTO authoring-stage +legality path: + +```bash +source scripts/ptoas_env.sh +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_verify_smoke.py /tmp/tilelang_v1_verify.mlir +build/tools/ptoas/ptoas --pto-arch a5 --pto-backend=vpto --emit-vpto \ + /tmp/tilelang_v1_verify.mlir -o /tmp/tilelang_v1_verify.checked.mlir +``` + +For the implemented authoring-form VPTO lowering contract, support matrix, +examples, and minimal validation commands, see +`tilelang-dsl/docs/v1-lowering.md`. + Root-level wiring belongs to follow-up tasks and must stay minimal. diff --git a/tilelang-dsl/docs/README.md b/tilelang-dsl/docs/README.md index 357cc14a2..a4c489fef 100644 --- a/tilelang-dsl/docs/README.md +++ b/tilelang-dsl/docs/README.md @@ -3,6 +3,8 @@ TileLang DSL local documentation lives here. Current docs: - `v1-surface.md`: the TileLang DSL v1 contract implemented by `add-tilelang-dsl-core-foundation` +- `v1-lowering.md`: the TileLang DSL v1 authoring-form VPTO lowering contract + implemented by `add-tilelang-dsl-authoring-vpto-lowering` Documentation boundary: - `tilelang-dsl/docs/` is the local documentation source of truth for the new diff --git a/tilelang-dsl/docs/v1-lowering.md b/tilelang-dsl/docs/v1-lowering.md new file mode 100644 index 000000000..ef4da2752 --- /dev/null +++ b/tilelang-dsl/docs/v1-lowering.md @@ -0,0 +1,125 @@ +# TileLang DSL v1 Authoring Lowering + +## Scope + +This document records the implemented TileLang DSL v1 lowering contract for +`add-tilelang-dsl-authoring-vpto-lowering`. + +It covers: +- the current v1 lowering support matrix +- dynamic-bound and shape-profile behavior +- examples that match the implemented surface +- minimal validation commands, including the repo `ptoas` legality path + +It does not define: +- matcher-driven dispatch +- implicit vecscope inference +- raw pointer authoring surface +- advanced vector-family lowering beyond the fixed v1 matrix + +## Source Of Truth + +The implemented lowering surface lives under: +- `tilelang-dsl/python/tilelang_dsl/` +- `tilelang-dsl/tests/` +- `tilelang-dsl/examples/` +- `tilelang-dsl/docs/` + +OpenSpec source of truth for this capability: +- `openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/` + +## Implemented v1 Support Matrix + +The current v1 lowering contract supports: +- 2D `TensorView` +- 1D/2D `Tile` +- `dma_load` +- `dma_store` +- `make_mask(dtype, PAT.*)` +- `make_mask(dtype, remaining)` +- `vlds` +- `vsts` +- unary vector family: `vabs`, `vrelu`, `vexp`, `vnot` +- binary vector family: `vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor` +- vector-scalar family: `vadds`, `vsubs`, `vmuls`, `vdivs`, `vmaxs`, `vmins` +- `for range(lb, ub, step)` +- `if/else` +- `set_flag`, `wait_flag`, `pipe_barrier` + +Current lowering shape: +- emits stable `func.func + arith/scf + pto.*` authoring-form VPTO modules +- requires explicit `pto.strict_vecscope` +- rejects support-matrix-external surface in the frontend + +## Dynamic-Bound Profile + +The implemented shape profile is: +- Tile physical shape must stay static +- TensorView shape access may lower through hidden shape arguments +- TensorView slice bounds may be dynamic +- loop bounds may be dynamic +- tail `remaining` values may be dynamic + +The current DMA lowering still uses the static physical Tile shape when the +TensorView slice extent is dynamic. This keeps v1 inside the current +authoring-form contract without introducing fully dynamic Tile allocation or +tail-DMA semantics. + +## Examples + +Examples aligned with the implemented surface: +- `tilelang-dsl/examples/v1_elementwise_tail_demo.py` + - emits a guide-style elementwise authoring kernel + - covers DMA, explicit `strict_vecscope`, dynamic loop bound, and typed tail mask +- `tilelang-dsl/examples/v1_verify_smoke.py` + - emits a minimal module that is expected to pass the current repo + `ptoas --pto-backend=vpto` legality path + +Typical usage from the repository root: + +```bash +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py + +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py /tmp/tilelang_v1_elementwise.mlir + +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 tilelang-dsl/examples/v1_verify_smoke.py +``` + +## Deferred Features + +The following remain outside v1 and belong to follow-up changes: +- implicit vecscope inference +- matcher registry and deterministic selection +- raw pointer / low-level DMA / `copy_ubuf_to_ubuf` authoring surface +- compare/select, predicate movement, carry, rearrangement, reduction families +- wildcard / type-variable dtypes +- multiple `dtypes` signatures + +Primary follow-up change: +- `extend-tilelang-dsl-matcher-and-advanced-surface` + +## Minimal Validation + +The minimal validation set for the implemented v1 lowering is: + +```bash +python3 -m py_compile tilelang-dsl/python/tilelang_dsl/*.py + +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 -m unittest $PWD/tilelang-dsl/tests/test_tilelang_dsl_v1.py + +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 tilelang-dsl/examples/v1_verify_smoke.py /tmp/tilelang_v1_verify.mlir + +build/tools/ptoas/ptoas --pto-arch a5 --pto-backend=vpto --emit-vpto \ + /tmp/tilelang_v1_verify.mlir -o /tmp/tilelang_v1_verify.checked.mlir +``` + +What these commands confirm: +- the standalone source-tree package imports and compiles +- the focused unittest suite passes for lowering, diagnostics, and verify behavior +- a generated TileLang DSL v1 module can be emitted to MLIR +- the emitted verify-smoke module passes the repo VPTO authoring-stage legality path diff --git a/tilelang-dsl/docs/v1-surface.md b/tilelang-dsl/docs/v1-surface.md index 225b566cc..c3a542684 100644 --- a/tilelang-dsl/docs/v1-surface.md +++ b/tilelang-dsl/docs/v1-surface.md @@ -20,6 +20,9 @@ It does not define: - advanced vector-family surface - implicit vecscope inference +For implemented lowering details, examples, and `verify()` behavior, see +`tilelang-dsl/docs/v1-lowering.md`. + ## Source Of Truth TileLang DSL v1 source of truth lives under: @@ -228,8 +231,6 @@ changes: - implicit vecscope inference - raw pointer authoring surface - advanced vector-family support -- final TileLang DSL to VPTO lowering implementation Matcher-related extensions are deferred to `extend-tilelang-dsl-matcher-and-advanced-surface`. -Lowering work is deferred to `add-tilelang-dsl-authoring-vpto-lowering`. diff --git a/tilelang-dsl/examples/README.md b/tilelang-dsl/examples/README.md index 1c9aac4d7..116c2439a 100644 --- a/tilelang-dsl/examples/README.md +++ b/tilelang-dsl/examples/README.md @@ -3,14 +3,25 @@ TileLang DSL examples live here. Examples in this subtree should import `tilelang_dsl` as their package entrypoint once the package wiring is added. -Current example: -- `v1_emit_mlir_demo.py`: define a v1 `@pto.vkernel`, specialize a bare - `Tile`, and materialize the result as MLIR text or an `.mlir` file +Current examples: +- `v1_emit_mlir_demo.py`: minimal descriptor/materialization demo +- `v1_elementwise_tail_demo.py`: guide-aligned elementwise authoring demo that + covers DMA, explicit `strict_vecscope`, dynamic loop bound, and typed tail + mask lowering +- `v1_tadd_implicit_vecscope_demo.py`: advanced-mode flattened `TADD` example + with implicit `pto.vecscope` inference and `vlds`/`vsts` tile indexing sugar +- `v1_tbinop_2d_nopostupdate_demo.py`: a representative TileLang DSL v1 + expansion of `pto::TBinOps_2D_NoPostUpdate` using `vadd` +- `v1_verify_smoke.py`: minimal verify smoke that is expected to pass the repo + `ptoas --pto-backend=vpto` legality path Typical usage from the repository root: ```bash -cmake --build build --target TileLangDSLPackage python3 tilelang-dsl/examples/v1_emit_mlir_demo.py python3 tilelang-dsl/examples/v1_emit_mlir_demo.py /tmp/tilelang_demo.mlir +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_verify_smoke.py ``` diff --git a/tilelang-dsl/examples/v1_elementwise_tail_demo.py b/tilelang-dsl/examples/v1_elementwise_tail_demo.py new file mode 100644 index 000000000..75b3ec4dd --- /dev/null +++ b/tilelang-dsl/examples/v1_elementwise_tail_demo.py @@ -0,0 +1,75 @@ +"""Guide-aligned TileLang DSL v1 elementwise authoring demo.""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="eltwise", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.i32)], + name="tilelang_v1_elementwise_tail_demo", +) +def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: pto.i32): + rows = inp.shape[0] + pto.dma_load(inp[0:rows, 0:16], tile) + with pto.strict_vecscope(tile, tile, remaining, 0, rows, 64) as ( + src, + dst, + rem, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask, rem = pto.make_mask(pto.f32, rem) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + pto.dma_store(tile, out[0:rows, 0:16]) + return None + + +def build_specialized_kernel(): + return kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + +def main(argv) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_emit_mlir_demo.py b/tilelang-dsl/examples/v1_emit_mlir_demo.py index 770bf12ea..68736477c 100644 --- a/tilelang-dsl/examples/v1_emit_mlir_demo.py +++ b/tilelang-dsl/examples/v1_emit_mlir_demo.py @@ -5,16 +5,17 @@ def _import_tilelang_dsl(): - try: - import tilelang_dsl as pto - - return pto - except ModuleNotFoundError: - repo_root = Path(__file__).resolve().parents[2] - sys.path.insert(0, str(repo_root / "build" / "python")) - import tilelang_dsl as pto + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto - return pto + return pto pto = _import_tilelang_dsl() diff --git a/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py b/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py new file mode 100644 index 000000000..429de8a33 --- /dev/null +++ b/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py @@ -0,0 +1,82 @@ +"""Flattened TileLang DSL advanced-mode version of A5 `TADD_IMPL`. + +This example mirrors the user-facing `TADD_IMPL -> TAdd -> BinaryInstr -> +TBinOps_2D_NoPostUpdate` flow from `pto/npu/a5/TAdd.hpp`, but spells the final +2D row-major vector body directly in Python: + +- top-level interface uses `dst, src0, src1` Tile parameters like `TADD` +- `advanced=True` enables implicit `pto.vecscope` inference +- `pto.vlds(tile[row, col:])` / `pto.vsts(vec, tile[row, col:], mask)` use + tile indexing sugar instead of manual offset arithmetic +""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="tadd", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + name="tilelang_advanced_tadd_demo", +) +def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # Flattened equivalent of the TAddCheck/TADD_IMPL parameter plumbing. + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + +def build_specialized_kernel(): + return kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + +def main(argv) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py b/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py new file mode 100644 index 000000000..5809bbe3a --- /dev/null +++ b/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py @@ -0,0 +1,124 @@ +"""Representative TileLang DSL v1 form of `TBinOps_2D_NoPostUpdate`. + +This example mirrors the key structure from `pto::TBinOps_2D_NoPostUpdate`: +- two source UB tiles and one destination UB tile +- row-major 2D traversal +- explicit non-post-update absolute offsets: `row * row_stride + lane` +- binary vector op lowered as `pto.vadd` + +The TileLang DSL surface does not expose the C++ helper template directly, so +this example spells out the row/repeat loops and tail mask construction in the +authored Python kernel. +""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="eltwise", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)], + name="tilelang_v1_tbinop_2d_nopostupdate_demo", +) +def kernel( + lhs_gm: pto.TensorView, + rhs_gm: pto.TensorView, + out_gm: pto.TensorView, + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, +): + rows = lhs_gm.shape[0] + cols = lhs_gm.shape[1] + row_stride = lhs_tile.shape[1] + + pto.dma_load(lhs_gm[0:rows, 0:cols], lhs_tile) + pto.dma_load(rhs_gm[0:rows, 0:cols], rhs_tile) + + with pto.strict_vecscope( + lhs_tile, + rhs_tile, + dst_tile, + rows, + cols, + row_stride, + 0, + rows, + 1, + ) as ( + lhs, + rhs, + dst, + valid_rows, + valid_cols, + stride, + row_lb, + row_ub, + row_step, + ): + for row in range(row_lb, row_ub, row_step): + for lane in range(0, valid_cols, 64): + offset = row * stride + lane + mask, next_remaining = pto.make_mask(pto.f32, valid_cols - lane) + lhs_vec = pto.vlds(lhs, offset) + rhs_vec = pto.vlds(rhs, offset) + summed = pto.vadd(lhs_vec, rhs_vec, mask) + pto.vsts(summed, dst, offset, mask) + + pto.dma_store(dst_tile, out_gm[0:rows, 0:cols]) + return None + + +def build_specialized_kernel(): + return kernel.specialize( + lhs_tile=pto.TileSpecialization( + shape=(8, 64), + memory_space=pto.MemorySpace.UB, + ), + rhs_tile=pto.TileSpecialization( + shape=(8, 64), + memory_space=pto.MemorySpace.UB, + ), + dst_tile=pto.TileSpecialization( + shape=(8, 64), + memory_space=pto.MemorySpace.UB, + ), + ) + + +def main(argv) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_verify_smoke.py b/tilelang-dsl/examples/v1_verify_smoke.py new file mode 100644 index 000000000..adba64fa4 --- /dev/null +++ b/tilelang-dsl/examples/v1_verify_smoke.py @@ -0,0 +1,67 @@ +"""Minimal TileLang DSL v1 verify smoke for the repo PTOAS legality path.""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="eltwise", + dtypes=[(pto.f32, pto.f32)], + name="tilelang_v1_verify_smoke", +) +def kernel(inp: pto.TensorView, tile: pto.Tile): + return None + + +def build_specialized_kernel(): + return kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + +def main(argv) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + result = specialized.verify() + print(f"status={result.status}") + print(f"available={result.available}") + print(f"passed={result.passed}") + if result.command is not None: + print("command=" + " ".join(result.command)) + if result.message: + print(f"message={result.message}") + return 0 if result else 1 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index ca9ea233b..f1bf887af 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -30,6 +30,7 @@ bf16, f16, f32, + get_lanes, i1, i8, i16, @@ -70,4 +71,5 @@ "AnyInt", "AnyType", "AnyMask", + "get_lanes", ] diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 5c875d797..ca7a9c361 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -145,6 +145,7 @@ class FrontendKernelNode: op: str name: str verify_enabled: bool + advanced_enabled: bool dtype_signature: tuple[Any, ...] parameters: tuple[FrontendParameterNode, ...] tile_specializations: tuple[FrontendTileSpecializationNode, ...] @@ -346,6 +347,7 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: op=descriptor.op, name=descriptor.name, verify_enabled=descriptor.verify_enabled, + advanced_enabled=descriptor.advanced_enabled, dtype_signature=descriptor.dtype_signature, parameters=parameters, tile_specializations=tile_specializations, diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index f27d00e77..cdd16e3a5 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -2,9 +2,12 @@ from __future__ import annotations +import os import inspect -import textwrap import ast +import subprocess +import tempfile +import textwrap from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable @@ -22,66 +25,33 @@ from .frontend_ast import build_frontend_kernel_node from .lowering import lower_semantic_kernel from .semantic import analyze_frontend_kernel +from .support_matrix import ( + DEFERRED_PTO_SURFACES, + SUPPORTED_TOPLEVEL_PTO_CALLS, + SUPPORTED_VECSCOPE_PTO_CALLS, + unsupported_feature_message, + deferred_surface_message, +) _UNSET = object() -_MATCHER_FOLLOW_UP_CHANGE = "extend-tilelang-dsl-matcher-and-advanced-surface" -_V1_ALLOWED_TOPLEVEL_PTO_CALLS = { - "strict_vecscope", - "dma_load", - "dma_store", - "set_flag", - "wait_flag", - "pipe_barrier", - "barrier", -} -_V1_ALLOWED_VECSCOPE_PTO_CALLS = { - "make_mask", - "vlds", - "vsts", - "vabs", - "vrelu", - "vexp", - "vnot", - "vadd", - "vsub", - "vmul", - "vdiv", - "vmax", - "vmin", - "vand", - "vor", - "vxor", - "vadds", - "vsubs", - "vmuls", - "vdivs", - "vmaxs", - "vmins", -} - - -def _unsupported_feature_message(feature: str) -> str: - return ( - f"{feature} is not supported in TileLang DSL v1; " - f"see follow-up change `{_MATCHER_FOLLOW_UP_CHANGE}`" - ) +_PTOAS_BIN_ENV = "PTOAS_BIN" def _reject_unsupported_decorator_feature(name: str, value: Any) -> None: if value is _UNSET: return - raise ValueError(_unsupported_feature_message(f"decorator feature `{name}`")) + raise ValueError(unsupported_feature_message(f"decorator feature `{name}`")) def _reject_unsupported_dtype_feature(dtype: Any) -> None: if isinstance(dtype, WildcardType): raise ValueError( - _unsupported_feature_message(f"dtype wildcard `{dtype.name}`") + unsupported_feature_message(f"dtype wildcard `{dtype.name}`") ) if isinstance(dtype, TypeVariable): raise ValueError( - _unsupported_feature_message(f"dtype type variable `{dtype.name}`") + unsupported_feature_message(f"dtype type variable `{dtype.name}`") ) @@ -119,8 +89,9 @@ def parameter_node(self, param_name: str) -> ast.AST | None: class _KernelBodyValidator(ast.NodeVisitor): - def __init__(self, source_info: _FunctionSourceInfo): + def __init__(self, source_info: _FunctionSourceInfo, *, advanced_enabled: bool): self.source_info = source_info + self.advanced_enabled = advanced_enabled self._vecscope_depth = 0 def validate(self) -> None: @@ -200,15 +171,22 @@ def visit_With(self, node: ast.With) -> None: def visit_Call(self, node: ast.Call) -> None: if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): - if node.func.value.id == "pto" and node.func.attr in _V1_ALLOWED_TOPLEVEL_PTO_CALLS: + if node.func.value.id == "pto" and node.func.attr in SUPPORTED_TOPLEVEL_PTO_CALLS: return - if node.func.value.id == "pto" and node.func.attr in _V1_ALLOWED_VECSCOPE_PTO_CALLS: + if node.func.value.id == "pto" and node.func.attr in SUPPORTED_VECSCOPE_PTO_CALLS: + if self.advanced_enabled: + return if self._vecscope_depth <= 0: raise self.source_info.error( node, f"vector op surface `pto.{node.func.attr}` requires explicit pto.strict_vecscope in TileLang DSL v1", ) return + if node.func.value.id == "pto" and node.func.attr in DEFERRED_PTO_SURFACES: + raise self.source_info.error( + node, + deferred_surface_message(node.func.attr), + ) if node.func.value.id == "pto": raise self.source_info.error( node, @@ -249,10 +227,17 @@ def _load_function_source_info(py_fn: Callable[..., Any]) -> _FunctionSourceInfo return None -def _validate_function_body(source_info: _FunctionSourceInfo | None) -> None: +def _validate_function_body( + source_info: _FunctionSourceInfo | None, + *, + advanced_enabled: bool, +) -> None: if source_info is None: return - _KernelBodyValidator(source_info).validate() + _KernelBodyValidator( + source_info, + advanced_enabled=advanced_enabled, + ).validate() def _raise_tile_param_error( @@ -286,7 +271,7 @@ def _freeze_dtypes(dtypes: Any) -> tuple[tuple[Any, ...], ...]: if len(frozen_signatures) != 1: raise ValueError( - _unsupported_feature_message("multiple dtypes signatures") + unsupported_feature_message("multiple dtypes signatures") ) return tuple(frozen_signatures) @@ -317,6 +302,7 @@ class VKernelDescriptor: dtypes: tuple[tuple[Any, ...], ...] name: str verify_enabled: bool + advanced_enabled: bool parameters: tuple[BoundKernelParameter, ...] _py_fn: Callable[..., Any] = field(repr=False) _source_info: _FunctionSourceInfo | None = field(repr=False, compare=False, default=None) @@ -338,6 +324,7 @@ def metadata(self) -> dict[str, Any]: "dtypes": self.dtypes, "name": self.name, "verify": self.verify_enabled, + "advanced": self.advanced_enabled, } @property @@ -375,6 +362,7 @@ def specialize(self, **bindings: Any) -> "VKernelDescriptor": dtypes=self.dtypes, name=self.name, verify_enabled=self.verify_enabled, + advanced_enabled=self.advanced_enabled, parameters=self.parameters, _source_info=self._source_info, specializations=tuple(sorted(updated.items())), @@ -408,12 +396,11 @@ def mlir_text(self) -> str: def mlir_module(self) -> "MaterializedMLIRModule": self._require_specialized_tiles("mlir_module") - return MaterializedMLIRModule(self.mlir_text()) + return MaterializedMLIRModule(text=self.mlir_text(), target=self.target) - def verify(self) -> bool: + def verify(self, *, ptoas_bin: str | Path | None = None) -> "VerificationResult": self._require_specialized_tiles("verify") - self.mlir_module() - return True + return self.mlir_module().verify(ptoas_bin=ptoas_bin) def emit(self, path: str | Path) -> None: self._require_specialized_tiles("emit") @@ -424,12 +411,175 @@ def emit(self, path: str | Path) -> None: @dataclass(frozen=True) class MaterializedMLIRModule: text: str + target: str = "a5" def __str__(self) -> str: return self.text - def verify(self) -> bool: - return True + def verify(self, *, ptoas_bin: str | Path | None = None) -> "VerificationResult": + return _run_ptoas_verifier(self.text, target=self.target, ptoas_bin=ptoas_bin) + + +@dataclass(frozen=True) +class VerificationResult: + status: str + available: bool + passed: bool + message: str + command: tuple[str, ...] | None = None + returncode: int | None = None + stdout: str = "" + stderr: str = "" + + @property + def ok(self) -> bool: + return self.available and self.passed + + def __bool__(self) -> bool: + return self.ok + + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[3] + + +def _resolve_ptoas_bin(ptoas_bin: str | Path | None) -> Path: + if ptoas_bin is not None: + return Path(ptoas_bin) + env_path = os.environ.get(_PTOAS_BIN_ENV) + if env_path: + return Path(env_path) + return _repo_root() / "build/tools/ptoas/ptoas" + + +def _unavailable_result( + message: str, + *, + command: tuple[str, ...] | None = None, + stderr: str = "", +) -> VerificationResult: + return VerificationResult( + status="unavailable", + available=False, + passed=False, + message=message, + command=command, + stderr=stderr, + ) + + +def _failed_result( + message: str, + *, + command: tuple[str, ...], + returncode: int, + stdout: str, + stderr: str, +) -> VerificationResult: + return VerificationResult( + status="failed", + available=True, + passed=False, + message=message, + command=command, + returncode=returncode, + stdout=stdout, + stderr=stderr, + ) + + +def _passed_result( + *, + command: tuple[str, ...], + stdout: str, + stderr: str, +) -> VerificationResult: + return VerificationResult( + status="passed", + available=True, + passed=True, + message="generated IR passed the repo VPTO authoring-stage legality verifier", + command=command, + returncode=0, + stdout=stdout, + stderr=stderr, + ) + + +def _is_verifier_unavailable_process_failure(stderr: str) -> bool: + lowered = stderr.lower() + return ( + "error while loading shared libraries" in lowered + or "cannot open shared object file" in lowered + or "image not found" in lowered + or "dll load failed" in lowered + ) + + +def _run_ptoas_verifier( + mlir_text: str, + *, + target: str, + ptoas_bin: str | Path | None, +) -> VerificationResult: + binary = _resolve_ptoas_bin(ptoas_bin) + command = ( + str(binary), + "--pto-arch", + target, + "--pto-backend=vpto", + "--emit-vpto", + ) + if not binary.exists(): + return _unavailable_result( + f"verifier unavailable: missing ptoas binary at {binary}", + command=command, + ) + if not os.access(binary, os.X_OK): + return _unavailable_result( + f"verifier unavailable: ptoas binary is not executable: {binary}", + command=command, + ) + + try: + with tempfile.TemporaryDirectory(prefix="tilelang_dsl_verify_") as tmpdir: + tmpdir_path = Path(tmpdir) + input_path = tmpdir_path / "kernel.mlir" + output_path = tmpdir_path / "verified.mlir" + input_path.write_text(mlir_text, encoding="utf-8") + full_command = command + (str(input_path), "-o", str(output_path)) + completed = subprocess.run( + full_command, + cwd=_repo_root(), + text=True, + capture_output=True, + check=False, + ) + except OSError as exc: + return _unavailable_result( + f"verifier unavailable: failed to execute ptoas: {exc}", + command=command, + stderr=str(exc), + ) + + stderr = completed.stderr.strip() + stdout = completed.stdout.strip() + if completed.returncode == 0: + return _passed_result(command=full_command, stdout=stdout, stderr=stderr) + if _is_verifier_unavailable_process_failure(stderr): + return _unavailable_result( + "verifier unavailable: failed to launch repo ptoas legality path", + command=full_command, + stderr=stderr, + ) + message = stderr or stdout or "generated IR failed the repo VPTO authoring-stage legality verifier" + return _failed_result( + message, + command=full_command, + returncode=completed.returncode, + stdout=stdout, + stderr=stderr, + ) def _validate_target(target: str) -> str: @@ -460,6 +610,12 @@ def _validate_verify(verify: Any) -> bool: return verify +def _validate_advanced(advanced: Any) -> bool: + if not isinstance(advanced, bool): + raise TypeError("advanced must be a bool") + return advanced + + def _coerce_memory_space(value: Any, param_name: str) -> MemorySpace: if isinstance(value, MemorySpace): return value @@ -651,12 +807,14 @@ def _build_descriptor( dtypes: Any, name: Any, verify: Any, + advanced: Any, ) -> VKernelDescriptor: if not callable(py_fn): raise TypeError("@vkernel can only decorate callables") source_info = _load_function_source_info(py_fn) - _validate_function_body(source_info) + advanced_enabled = _validate_advanced(advanced) + _validate_function_body(source_info, advanced_enabled=advanced_enabled) frozen_dtypes = _freeze_dtypes(dtypes) return VKernelDescriptor( @@ -665,6 +823,7 @@ def _build_descriptor( dtypes=frozen_dtypes, name=_validate_name(py_fn, name), verify_enabled=_validate_verify(verify), + advanced_enabled=advanced_enabled, parameters=_bind_parameters(py_fn, frozen_dtypes), _py_fn=py_fn, _source_info=source_info, @@ -679,13 +838,14 @@ def vkernel( dtypes: Any = None, name: str | None = None, verify: bool = True, + advanced: bool = False, constraints: Any = _UNSET, priority: Any = _UNSET, ) -> VKernelDescriptor | Callable[[Callable[..., Any]], VKernelDescriptor]: """Create a TileLang DSL v1 kernel descriptor. v1 keeps only the minimal descriptor metadata surface: - `target`, `op`, `dtypes`, `name`, and `verify`. + `target`, `op`, `dtypes`, `name`, `verify`, and opt-in `advanced`. """ _reject_unsupported_decorator_feature("constraints", constraints) _reject_unsupported_decorator_feature("priority", priority) @@ -698,6 +858,7 @@ def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: dtypes=dtypes, name=name, verify=verify, + advanced=advanced, ) if py_fn is None: diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 04b30ba5e..d9037c3e5 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -22,11 +22,14 @@ SemanticKernel, SemanticLiteralExpr, SemanticMaskType, + SemanticMetaType, + SemanticBindingRef, SemanticPipeBarrierStmt, SemanticReturnStmt, SemanticScalarType, SemanticSetFlagStmt, SemanticStmt, + SemanticVecscopeStmt, SemanticStrictVecscopeStmt, SemanticSubscriptAccess, SemanticSymbolExpr, @@ -34,6 +37,8 @@ SemanticTensorViewType, SemanticTileType, SemanticType, + SemanticTupleExpr, + SemanticTupleType, SemanticVRegType, SemanticVectorStoreStmt, SemanticWaitFlagStmt, @@ -42,6 +47,7 @@ _I1_TYPE = SemanticScalarType(dtype=ScalarType("i1")) +_I32_TYPE = SemanticScalarType(dtype=ScalarType("i32")) _I64_TYPE = SemanticScalarType(dtype=ScalarType("i64")) @@ -92,6 +98,7 @@ def render(self) -> str: f"// tilelang.op = {self.kernel.op}", f"// tilelang.dtypes = {self.kernel.dtype_signature}", f"// tilelang.verify = {self.kernel.verify_enabled}", + f"// tilelang.advanced = {self.kernel.advanced_enabled}", ] for binding in self.kernel.tile_bindings: lines.append( @@ -157,6 +164,8 @@ def _render_stmt( return [self._indent(indent) + "return"] value = self._lower_expr(stmt.value, env, indent=indent) return [self._indent(indent) + f"return {value.name} : {self._render_type(value.type)}"] + if isinstance(stmt, SemanticVecscopeStmt): + return self._render_vecscope(stmt, env, indent=indent) if isinstance(stmt, SemanticStrictVecscopeStmt): return self._render_strict_vecscope(stmt, env, indent=indent) if isinstance(stmt, SemanticForStmt): @@ -173,8 +182,13 @@ def _render_assign( indent: int, ) -> list[str]: if len(stmt.targets) != 1: - raise NotImplementedError("multiple-result assignment is not supported in TileLang DSL v1 yet") + if isinstance(stmt.value, SemanticTupleExpr): + return self._render_tuple_expr_assign(stmt, env, indent=indent) + return self._render_multi_result_assign(stmt, env, indent=indent) target = stmt.targets[0] + if isinstance(target.type, SemanticMetaType): + env[target.name] = _RenderedValue(name=target.ssa_name, type=target.type) + return [] lines: list[str] = [] lowered = self._lower_expr( stmt.value, @@ -186,6 +200,66 @@ def _render_assign( env[target.name] = lowered return lines + def _render_tuple_expr_assign( + self, + stmt: SemanticAssignStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + if not isinstance(stmt.value, SemanticTupleExpr): + raise NotImplementedError("tuple expression assignment expects a SemanticTupleExpr") + if len(stmt.targets) != len(stmt.value.elements): + raise NotImplementedError("tuple expression assignment arity mismatch") + + lines: list[str] = [] + for target, element in zip(stmt.targets, stmt.value.elements): + lowered = self._lower_expr( + element, + env, + indent=indent, + desired_name=target.ssa_name, + into=lines, + ) + env[target.name] = lowered + return lines + + def _render_multi_result_assign( + self, + stmt: SemanticAssignStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + if not isinstance(stmt.value, SemanticCallExpr): + raise NotImplementedError("multi-result assignment expects a call expression in TileLang DSL v1") + if stmt.value.namespace != "pto" or stmt.value.name != "make_mask": + raise NotImplementedError( + f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" + ) + if len(stmt.targets) != 2: + raise NotImplementedError("tail make_mask lowering expects exactly two assignment targets") + if not isinstance(stmt.value.type, SemanticTupleType) or len(stmt.value.type.elements) != 2: + raise NotImplementedError("tail make_mask lowering expects a two-result tuple type") + + dtype_expr, remaining_expr = stmt.value.args + if not self._is_dtype_meta_expr(dtype_expr): + raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") + + lines: list[str] = [] + remaining = self._lower_remaining_to_i32(remaining_expr, env, indent=indent, into=lines) + mask_target, remaining_target = stmt.targets + mask_type, remaining_type = stmt.value.type.elements + suffix = self._mask_suffix(mask_type) + lines.append( + self._indent(indent) + + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = pto.plt_{suffix} {remaining.name} : " + + f"i32 -> {self._render_type(mask_type)}, {self._render_type(remaining_type)}" + ) + env[mask_target.name] = _RenderedValue(name=mask_target.ssa_name, type=mask_type) + env[remaining_target.name] = _RenderedValue(name=remaining_target.ssa_name, type=remaining_type) + return lines + def _render_dma_load( self, stmt: SemanticDmaLoadStmt, @@ -193,9 +267,10 @@ def _render_dma_load( *, indent: int, ) -> list[str]: - src = self._lower_expr(stmt.src.base, env, indent=indent) - dst = self._lower_expr(stmt.dst, env, indent=indent) - row_count, col_count = self._tensor_slice_extents(stmt.src) + lines: list[str] = [] + src = self._lower_expr(stmt.src.base, env, indent=indent, into=lines) + dst = self._lower_expr(stmt.dst, env, indent=indent, into=lines) + row_count, col_count = self._dma_transfer_extents(stmt.src, stmt.dst.type) element_bytes = self._dtype_byte_width(stmt.src.type.element_dtype) burst_bytes = col_count * element_bytes @@ -205,16 +280,19 @@ def _render_dma_load( len_burst = self._materialize_constant(burst_bytes, _I64_TYPE) false_bit = self._materialize_constant(False, _I1_TYPE) - return [ - self._indent(indent) - + f"pto.set_loop_size_outtoub {c1_i64}, {c1_i64} : i64, i64", - self._indent(indent) - + "pto.copy_gm_to_ubuf " - + f"{src.name}, {dst.name}, {c0_i64}, {n_burst}, {len_burst}, {c0_i64}, {c0_i64}, " - + f"{false_bit}, {c0_i64}, {len_burst}, {len_burst} : " - + f"{self._render_type(src.type)}, {self._render_type(dst.type)}, " - + "i64, i64, i64, i64, i64, i1, i64, i64, i64", - ] + lines.extend( + [ + self._indent(indent) + + f"pto.set_loop_size_outtoub {c1_i64}, {c1_i64} : i64, i64", + self._indent(indent) + + "pto.copy_gm_to_ubuf " + + f"{src.name}, {dst.name}, {c0_i64}, {n_burst}, {len_burst}, {c0_i64}, {c0_i64}, " + + f"{false_bit}, {c0_i64}, {len_burst}, {len_burst} : " + + f"{self._render_type(src.type)}, {self._render_type(dst.type)}, " + + "i64, i64, i64, i64, i64, i1, i64, i64, i64", + ] + ) + return lines def _render_dma_store( self, @@ -223,9 +301,10 @@ def _render_dma_store( *, indent: int, ) -> list[str]: - src = self._lower_expr(stmt.src, env, indent=indent) - dst = self._lower_expr(stmt.dst.base, env, indent=indent) - row_count, col_count = self._tensor_slice_extents(stmt.dst) + lines: list[str] = [] + src = self._lower_expr(stmt.src, env, indent=indent, into=lines) + dst = self._lower_expr(stmt.dst.base, env, indent=indent, into=lines) + row_count, col_count = self._dma_transfer_extents(stmt.dst, stmt.src.type) element_bytes = self._dtype_byte_width(stmt.dst.type.element_dtype) burst_bytes = col_count * element_bytes @@ -234,15 +313,18 @@ def _render_dma_store( n_burst = self._materialize_constant(row_count, _I64_TYPE) len_burst = self._materialize_constant(burst_bytes, _I64_TYPE) - return [ - self._indent(indent) - + f"pto.set_loop_size_ubtoout {c1_i64}, {c1_i64} : i64, i64", - self._indent(indent) - + "pto.copy_ubuf_to_gm " - + f"{src.name}, {dst.name}, {c0_i64}, {n_burst}, {len_burst}, {c0_i64}, " - + f"{len_burst}, {len_burst} : {self._render_type(src.type)}, {self._render_type(dst.type)}, " - + "i64, i64, i64, i64, i64, i64", - ] + lines.extend( + [ + self._indent(indent) + + f"pto.set_loop_size_ubtoout {c1_i64}, {c1_i64} : i64, i64", + self._indent(indent) + + "pto.copy_ubuf_to_gm " + + f"{src.name}, {dst.name}, {c0_i64}, {n_burst}, {len_burst}, {c0_i64}, " + + f"{len_burst}, {len_burst} : {self._render_type(src.type)}, {self._render_type(dst.type)}, " + + "i64, i64, i64, i64, i64, i64", + ] + ) + return lines def _render_vector_store( self, @@ -251,22 +333,36 @@ def _render_vector_store( *, indent: int, ) -> list[str]: - value = self._lower_expr(stmt.value, env, indent=indent) - destination = self._lower_expr(stmt.destination, env, indent=indent) - offset = self._lower_expr(stmt.offset, env, indent=indent) - mask = self._lower_expr(stmt.mask, env, indent=indent) - return [ + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + offset = self._lower_expr(stmt.offset, env, indent=indent, into=lines) + mask = self._lower_expr(stmt.mask, env, indent=indent, into=lines) + lines.append( self._indent(indent) + "pto.vsts " + f"{value.name}, {destination.name}[{offset.name}], {mask.name} : " + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(mask.type)}" - ] + ) + return lines def _tensor_slice_extents(self, expr: SemanticTensorSliceExpr) -> tuple[int, int]: if expr.type.rank != 2 or len(expr.type.extents) != 2: raise NotImplementedError("TileLang DSL v1 DMA lowering currently only supports rank-2 TensorView slices") return expr.type.extents + def _dma_transfer_extents( + self, + slice_expr: SemanticTensorSliceExpr, + tile_type: SemanticTileType, + ) -> tuple[int, int]: + row_count, col_count = self._tensor_slice_extents(slice_expr) + if row_count is not None and col_count is not None: + return row_count, col_count + if tile_type.shape is None or len(tile_type.shape) != 2: + raise NotImplementedError("DMA lowering requires a statically specialized rank-2 Tile shape") + return tile_type.shape + def _render_strict_vecscope( self, stmt: SemanticStrictVecscopeStmt, @@ -274,7 +370,11 @@ def _render_strict_vecscope( *, indent: int, ) -> list[str]: - capture_values = [self._lower_expr(expr, env, indent=indent) for expr in stmt.captures] + lines: list[str] = [] + capture_values = [ + self._lower_expr(expr, env, indent=indent, into=lines) + for expr in stmt.captures + ] capture_names = ", ".join(value.name for value in capture_values) block_args = ", ".join( f"{binding.ssa_name}: {self._render_type(binding.type)}" @@ -289,12 +389,25 @@ def _render_strict_vecscope( for binding in stmt.block_arguments } - lines = [self._indent(indent) + f"pto.strict_vecscope({capture_names}) {{"] + lines.append(self._indent(indent) + f"pto.strict_vecscope({capture_names}) {{") lines.append(self._indent(indent) + f"^bb0({block_args}):") lines.extend(self._render_block(stmt.body, scope_env, indent=indent + 2)) lines.append(self._indent(indent) + f"}} : ({function_type}) -> ()") return lines + def _render_vecscope( + self, + stmt: SemanticVecscopeStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + scope_env = dict(env) + lines = [self._indent(indent) + "pto.vecscope {"] + lines.extend(self._render_block(stmt.body, scope_env, indent=indent + 2)) + lines.append(self._indent(indent) + "}") + return lines + def _render_for( self, stmt: SemanticForStmt, @@ -302,9 +415,10 @@ def _render_for( *, indent: int, ) -> list[str]: - lower_bound = self._lower_expr(stmt.lower_bound, env, indent=indent) - upper_bound = self._lower_expr(stmt.upper_bound, env, indent=indent) - step = self._lower_expr(stmt.step, env, indent=indent) + lines: list[str] = [] + lower_bound = self._lower_expr(stmt.lower_bound, env, indent=indent, into=lines) + upper_bound = self._lower_expr(stmt.upper_bound, env, indent=indent, into=lines) + step = self._lower_expr(stmt.step, env, indent=indent, into=lines) body_env = dict(env) body_env[stmt.induction_variable.name] = _RenderedValue( @@ -313,11 +427,11 @@ def _render_for( ) if not stmt.loop_carried: - lines = [ + lines.append( self._indent(indent) + f"scf.for {stmt.induction_variable.ssa_name} = {lower_bound.name} " f"to {upper_bound.name} step {step.name} {{" - ] + ) lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) lines.append(self._indent(indent) + "}") return lines @@ -336,13 +450,13 @@ def _render_for( type=carried_binding.type, ) - lines = [ + lines.append( self._indent(indent) + f"{carried_binding.ssa_name}:1 = scf.for {stmt.induction_variable.ssa_name} = " f"{lower_bound.name} to {upper_bound.name} step {step.name} " f"iter_args({iter_arg_name} = {initial_value.name}) -> " f"({self._render_type(carried_binding.type)}) {{" - ] + ) lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) yielded_value = body_env[carried_binding.name] lines.append( @@ -422,7 +536,7 @@ def _lower_condition( indent: int, into: list[str], ) -> _RenderedValue: - value = self._lower_expr(expr, env, indent=indent) + value = self._lower_expr(expr, env, indent=indent, into=into) if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i1": return value @@ -472,21 +586,18 @@ def _lower_expr( type=expr.type, ) if isinstance(expr, SemanticSubscriptAccess): - if desired_name is not None and into is not None: - value = self._extract_static_subscript_value(expr, env) - into.append( - self._indent(indent) - + f"{desired_name} = arith.constant {self._format_constant(value, expr.type)} : " - f"{self._render_type(expr.type)}" - ) - return _RenderedValue(name=desired_name, type=expr.type) - constant_name = self._lower_static_subscript(expr, env) - return _RenderedValue(name=constant_name, type=expr.type) + return self._lower_subscript_access( + expr, + env, + indent=indent, + desired_name=desired_name, + into=into, + ) if isinstance(expr, SemanticBinaryExpr): - lhs = self._lower_expr(expr.lhs, env, indent=indent) - rhs = self._lower_expr(expr.rhs, env, indent=indent) if into is None: into = [] + lhs = self._lower_expr(expr.lhs, env, indent=indent, into=into) + rhs = self._lower_expr(expr.rhs, env, indent=indent, into=into) result_name = desired_name or self._new_temp() into.append( self._indent(indent) @@ -515,13 +626,15 @@ def _lower_call_expr( ) -> _RenderedValue: if expr.namespace != "pto": raise NotImplementedError(f"unsupported call namespace {expr.namespace!r}") + if isinstance(expr.type, SemanticTupleType): + raise NotImplementedError("multi-result call values must be assigned directly in TileLang DSL v1") if into is None: into = [] result_name = desired_name or self._new_temp() if expr.name == "make_mask": dtype_expr, pattern_expr = expr.args - if not isinstance(dtype_expr, SemanticSymbolExpr): + if not self._is_dtype_meta_expr(dtype_expr): raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") if not isinstance(pattern_expr, SemanticSymbolExpr) or not isinstance(pattern_expr.value, MaskPattern): raise NotImplementedError("make_mask pattern lowering expects a MaskPattern symbol") @@ -533,8 +646,8 @@ def _lower_call_expr( return _RenderedValue(name=result_name, type=expr.type) if expr.name == "vlds": - source = self._lower_expr(expr.args[0], env, indent=indent) - offset = self._lower_expr(expr.args[1], env, indent=indent) + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) into.append( self._indent(indent) + f"{result_name} = pto.vlds {source.name}[{offset.name}] : " @@ -543,8 +656,8 @@ def _lower_call_expr( return _RenderedValue(name=result_name, type=expr.type) if expr.name in {"vabs", "vrelu", "vexp", "vnot"}: - value = self._lower_expr(expr.args[0], env, indent=indent) - mask = self._lower_expr(expr.args[1], env, indent=indent) + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) into.append( self._indent(indent) + f"{result_name} = pto.{expr.name} {value.name}, {mask.name} : " @@ -553,9 +666,9 @@ def _lower_call_expr( return _RenderedValue(name=result_name, type=expr.type) if expr.name in {"vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", "vand", "vor", "vxor"}: - lhs = self._lower_expr(expr.args[0], env, indent=indent) - rhs = self._lower_expr(expr.args[1], env, indent=indent) - mask = self._lower_expr(expr.args[2], env, indent=indent) + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) into.append( self._indent(indent) + f"{result_name} = pto.{expr.name} {lhs.name}, {rhs.name}, {mask.name} : " @@ -565,9 +678,9 @@ def _lower_call_expr( return _RenderedValue(name=result_name, type=expr.type) if expr.name in {"vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins"}: - value = self._lower_expr(expr.args[0], env, indent=indent) - scalar = self._lower_expr(expr.args[1], env, indent=indent) - mask = self._lower_expr(expr.args[2], env, indent=indent) + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + scalar = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) into.append( self._indent(indent) + f"{result_name} = pto.{expr.name} {value.name}, {scalar.name}, {mask.name} : " @@ -578,19 +691,74 @@ def _lower_call_expr( raise NotImplementedError(f"unsupported pto call `{expr.name}` in lowering") - def _lower_static_subscript( + def _lower_remaining_to_i32( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i32": + return value + if isinstance(value.type, SemanticIndexType): + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.index_cast {value.name} : index to i32" + ) + return _RenderedValue(name=cast_name, type=_I32_TYPE) + raise NotImplementedError("tail make_mask lowering expects an i32 or index remaining operand") + + def _mask_suffix(self, ty: SemanticType) -> str: + if not isinstance(ty, SemanticMaskType): + raise NotImplementedError("tail make_mask lowering expects a mask result type") + return ty.granularity + + def _is_dtype_meta_expr(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticSymbolExpr): + return isinstance(expr.value, ScalarType) and expr.type.kind == "dtype" + if isinstance(expr, SemanticBindingRef): + return ( + isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and isinstance(expr.binding.value, ScalarType) + ) + return False + + def _lower_subscript_access( self, expr: SemanticSubscriptAccess, env: dict[str, _RenderedValue], - ) -> str: - value = self._extract_static_subscript_value(expr, env) - return self._materialize_constant(value, expr.type) + *, + indent: int, + desired_name: str | None, + into: list[str] | None, + ) -> _RenderedValue: + value = self._extract_shape_subscript_value(expr, env) + if isinstance(value, _RenderedValue): + return value + if desired_name is not None and into is not None: + into.append( + self._indent(indent) + + f"{desired_name} = arith.constant {self._format_constant(value, expr.type)} : " + f"{self._render_type(expr.type)}" + ) + return _RenderedValue(name=desired_name, type=expr.type) + return _RenderedValue( + name=self._materialize_constant(value, expr.type), + type=expr.type, + ) + + def _tensor_shape_binding_name(self, tensor_name: str, axis: int) -> str: + return f"__shape_{tensor_name}_{axis}" - def _extract_static_subscript_value( + def _extract_shape_subscript_value( self, expr: SemanticSubscriptAccess, env: dict[str, _RenderedValue], - ) -> int: + ) -> int | _RenderedValue: if not isinstance(expr.base, SemanticAttributeAccess): raise NotImplementedError("only shape indexing is supported in TileLang DSL v1 lowering") if expr.base.attr != "shape": @@ -611,9 +779,13 @@ def _extract_static_subscript_value( return base_type.shape[index] if isinstance(base_type, SemanticTensorViewType): - raise NotImplementedError( - "dynamic TensorView shape materialization is not implemented in TileLang DSL v1 lowering yet" - ) + hidden_name = self._tensor_shape_binding_name(base_binding.name, index) + hidden_value = env.get(hidden_name) + if hidden_value is None: + raise NotImplementedError( + f"missing TensorView shape binding for '{base_binding.name}.shape[{index}]'" + ) + return hidden_value raise NotImplementedError("shape indexing expects a Tile or TensorView operand") diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 6ec2cb168..6075c8d81 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -29,6 +29,11 @@ FrontendTupleExpr, FrontendTupleTarget, ) +from .support_matrix import ( + DEFERRED_PTO_SURFACES, + deferred_surface_message, + unsupported_feature_message, +) from .types import Event, MaskPattern, Pipe, ScalarType, bf16, f16, f32, i1, i8, i16, i32 @@ -63,7 +68,7 @@ class SemanticTensorViewType(SemanticType): class SemanticTensorSliceType(SemanticType): element_dtype: ScalarType rank: int - extents: tuple[int, ...] + extents: tuple[int | None, ...] @dataclass(frozen=True) @@ -115,12 +120,16 @@ class SemanticVRegType(SemanticType): lanes: int +_I32_TYPE = SemanticScalarType(dtype=i32) + + @dataclass(frozen=True) class SemanticBinding: name: str ssa_name: str type: SemanticType origin: str + value: Any | None = None @dataclass(frozen=True) @@ -242,6 +251,11 @@ class SemanticVectorStoreStmt(SemanticStmt): mask: SemanticExpr +@dataclass(frozen=True) +class SemanticVecscopeStmt(SemanticStmt): + body: tuple[SemanticStmt, ...] + + @dataclass(frozen=True) class SemanticSetFlagStmt(SemanticStmt): src_pipe: str @@ -325,6 +339,7 @@ class SemanticKernel: op: str symbol_name: str verify_enabled: bool + advanced_enabled: bool dtype_signature: tuple[Any, ...] parameters: tuple[SemanticParameter, ...] tile_bindings: tuple[SemanticTileBinding, ...] @@ -335,9 +350,11 @@ class _SemanticAnalyzer: def __init__(self, node: FrontendKernelNode): self.node = node self._counter = 0 + self._disable_inference_depth = 0 self._tile_specializations = { spec.name: spec for spec in node.tile_specializations } + self._tensor_shape_parameters: list[SemanticParameter] = [] def analyze(self) -> SemanticKernel: env: dict[str, SemanticBinding] = {} @@ -351,7 +368,8 @@ def analyze(self) -> SemanticKernel: ) env[param.name] = binding parameters.append(SemanticParameter(binding=binding)) - body, _ = self._analyze_block(self.node.body, env, allow_outer_lookup=True) + body, _ = self._analyze_kernel_body(env) + parameters.extend(self._tensor_shape_parameters) tile_bindings = tuple( SemanticTileBinding( name=spec.name, @@ -366,12 +384,52 @@ def analyze(self) -> SemanticKernel: op=self.node.op, symbol_name=self.node.name, verify_enabled=self.node.verify_enabled, + advanced_enabled=self.node.advanced_enabled, dtype_signature=self.node.dtype_signature, parameters=tuple(parameters), tile_bindings=tile_bindings, body=body, ) + def _analyze_kernel_body( + self, + env: dict[str, SemanticBinding], + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + if not self.node.advanced_enabled: + return self._analyze_block(self.node.body, env, allow_outer_lookup=True) + + body_without_return = self.node.body + trailing_return: FrontendReturnStmt | None = None + if body_without_return and isinstance(body_without_return[-1], FrontendReturnStmt): + trailing_return = body_without_return[-1] + body_without_return = body_without_return[:-1] + + if body_without_return and self._can_wrap_whole_kernel_vecscope(body_without_return): + self._disable_inference_depth += 1 + try: + scoped_body, scoped_env = self._analyze_block_without_inference( + body_without_return, + env, + allow_outer_lookup=True, + ) + finally: + self._disable_inference_depth -= 1 + semantic_body: list[SemanticStmt] = [] + if self._semantic_block_contains_vector_activity(scoped_body): + semantic_body.append(SemanticVecscopeStmt(body=scoped_body)) + else: + semantic_body.extend(scoped_body) + if trailing_return is not None: + return_stmt, scoped_env = self._analyze_stmt( + trailing_return, + scoped_env, + allow_outer_lookup=True, + ) + semantic_body.append(return_stmt) + return tuple(semantic_body), scoped_env + + return self._analyze_block(self.node.body, env, allow_outer_lookup=True) + def _parameter_type(self, param: Any) -> SemanticType: if param.kind == "tensorview": return SemanticTensorViewType(element_dtype=param.dtype) @@ -395,13 +453,42 @@ def _new_ssa_name(self, stem: str) -> str: self._counter += 1 return name - def _make_binding(self, name: str, ty: SemanticType, origin: str) -> SemanticBinding: + def _tensor_shape_binding_name(self, tensor_name: str, axis: int) -> str: + return f"__shape_{tensor_name}_{axis}" + + def _ensure_tensor_shape_parameter( + self, + tensor_binding: SemanticBinding, + axis: int, + ) -> SemanticBinding: + hidden_name = self._tensor_shape_binding_name(tensor_binding.name, axis) + for parameter in self._tensor_shape_parameters: + if parameter.name == hidden_name: + return parameter.binding + binding = SemanticBinding( + name=hidden_name, + ssa_name=f"%arg{len(self.node.parameters) + len(self._tensor_shape_parameters)}", + type=SemanticIndexType(), + origin="tensorview_shape", + ) + self._tensor_shape_parameters.append(SemanticParameter(binding=binding)) + return binding + + def _make_binding( + self, + name: str, + ty: SemanticType, + origin: str, + *, + value: Any | None = None, + ) -> SemanticBinding: stem = name if name.isidentifier() else "v" return SemanticBinding( name=name, ssa_name=self._new_ssa_name(stem), type=ty, origin=origin, + value=value, ) def _analyze_block( @@ -410,6 +497,101 @@ def _analyze_block( env: dict[str, SemanticBinding], *, allow_outer_lookup: bool, + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + current_env = dict(env) + semantic_statements = [] + index = 0 + while index < len(statements): + if self._should_infer_vecscope(statements[index], allow_outer_lookup=allow_outer_lookup): + end = index + 1 + while end < len(statements) and self._should_infer_vecscope( + statements[end], + allow_outer_lookup=allow_outer_lookup, + ): + end += 1 + run = statements[index:end] + if self._run_contains_vector_op(run): + semantic_statements.append( + self._analyze_inferred_vecscope( + run, + current_env, + allow_outer_lookup=allow_outer_lookup, + ) + ) + else: + for stmt in run: + semantic_stmt, current_env = self._analyze_stmt( + stmt, + current_env, + allow_outer_lookup=allow_outer_lookup, + ) + semantic_statements.append(semantic_stmt) + index = end + continue + + semantic_stmt, current_env = self._analyze_stmt( + statements[index], + current_env, + allow_outer_lookup=allow_outer_lookup, + ) + semantic_statements.append(semantic_stmt) + index += 1 + return tuple(semantic_statements), current_env + + def _should_infer_vecscope( + self, + stmt: FrontendStmtNode, + *, + allow_outer_lookup: bool, + ) -> bool: + if self._disable_inference_depth > 0: + return False + if not self.node.advanced_enabled or not allow_outer_lookup: + return False + name = self._frontend_vector_call_name(stmt) + return name in {"make_mask", "vlds", "vsts"} | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS + + def _run_contains_vector_op(self, statements: tuple[FrontendStmtNode, ...]) -> bool: + for stmt in statements: + name = self._frontend_vector_call_name(stmt) + if name is None or name == "make_mask": + continue + return True + return False + + def _frontend_vector_call_name(self, stmt: FrontendStmtNode) -> str | None: + expr: FrontendExprNode | None = None + if isinstance(stmt, FrontendAssignStmt): + expr = stmt.value + elif isinstance(stmt, FrontendExprStmt): + expr = stmt.expr + if ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + ): + return expr.name + return None + + def _analyze_inferred_vecscope( + self, + statements: tuple[FrontendStmtNode, ...], + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticVecscopeStmt: + body, _ = self._analyze_block_without_inference( + statements, + env, + allow_outer_lookup=allow_outer_lookup, + ) + return SemanticVecscopeStmt(body=body) + + def _analyze_block_without_inference( + self, + statements: tuple[FrontendStmtNode, ...], + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: current_env = dict(env) semantic_statements = [] @@ -422,6 +604,70 @@ def _analyze_block( semantic_statements.append(semantic_stmt) return tuple(semantic_statements), current_env + def _can_wrap_whole_kernel_vecscope( + self, + statements: tuple[FrontendStmtNode, ...], + ) -> bool: + for stmt in statements: + if isinstance(stmt, FrontendStrictVecscopeStmt): + return False + if isinstance(stmt, FrontendExprStmt) and ( + self._is_dma_call(stmt.expr) or self._is_sync_call(stmt.expr) + ): + return False + nested_blocks: tuple[tuple[FrontendStmtNode, ...], ...] = () + if isinstance(stmt, FrontendForStmt): + nested_blocks = (stmt.body,) + elif isinstance(stmt, FrontendIfStmt): + nested_blocks = (stmt.then_body, stmt.else_body) + for block in nested_blocks: + if not self._can_wrap_whole_kernel_vecscope(block): + return False + return True + + def _semantic_block_contains_vector_activity( + self, + statements: tuple[SemanticStmt, ...], + ) -> bool: + for stmt in statements: + if isinstance(stmt, SemanticVecscopeStmt): + return True + if isinstance(stmt, SemanticStrictVecscopeStmt): + return True + if isinstance(stmt, SemanticVectorStoreStmt): + return True + if isinstance(stmt, SemanticAssignStmt) and self._expr_contains_vector_activity(stmt.value): + return True + if isinstance(stmt, SemanticExprStmt) and self._expr_contains_vector_activity(stmt.expr): + return True + if isinstance(stmt, SemanticForStmt) and self._semantic_block_contains_vector_activity(stmt.body): + return True + if isinstance(stmt, SemanticIfStmt) and ( + self._semantic_block_contains_vector_activity(stmt.then_body) + or self._semantic_block_contains_vector_activity(stmt.else_body) + ): + return True + return False + + def _expr_contains_vector_activity(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticCallExpr): + if expr.namespace == "pto" and expr.name in ( + {"make_mask", "vlds"} | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS + ): + return True + return any(self._expr_contains_vector_activity(arg) for arg in expr.args) + if isinstance(expr, SemanticBinaryExpr): + return self._expr_contains_vector_activity(expr.lhs) or self._expr_contains_vector_activity(expr.rhs) + if isinstance(expr, SemanticTupleExpr): + return any(self._expr_contains_vector_activity(element) for element in expr.elements) + if isinstance(expr, SemanticAttributeAccess): + return self._expr_contains_vector_activity(expr.base) + if isinstance(expr, SemanticSubscriptAccess): + return self._expr_contains_vector_activity(expr.base) or self._expr_contains_vector_activity(expr.index) + if isinstance(expr, SemanticTensorSliceExpr): + return self._expr_contains_vector_activity(expr.base) + return False + def _analyze_stmt( self, stmt: FrontendStmtNode, @@ -519,13 +765,23 @@ def _analyze_vector_store_stmt( *, allow_outer_lookup: bool, ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: - args = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args - ) - if len(args) != 4: - raise TypeError("pto.vsts expects exactly 4 positional arguments in TileLang DSL v1") - value, destination, offset, mask = args + if len(expr.args) == 3: + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, offset = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsts destination", + ) + mask = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 4: + raise TypeError("pto.vsts expects 3 or 4 positional arguments in TileLang DSL v1") + value, destination, offset, mask = args self._require_vreg_expr(value, "pto.vsts value") self._require_tile_expr(destination, "pto.vsts destination") self._require_index_typed_expr(offset) @@ -600,11 +856,12 @@ def _validate_dma_shape_match( raise TypeError(f"{op_name} requires a statically specialized rank-2 Tile in TileLang DSL v1") if tensor_slice_type.element_dtype != tile_type.element_dtype: raise TypeError(f"{op_name} requires matching TensorView/Tile element dtypes in TileLang DSL v1") - if tensor_slice_type.extents != tile_type.shape: - raise TypeError( - f"{op_name} requires TensorView slice extents {tensor_slice_type.extents!r} " - f"to match Tile shape {tile_type.shape!r}" - ) + for axis, (extent, tile_dim) in enumerate(zip(tensor_slice_type.extents, tile_type.shape)): + if extent is not None and extent != tile_dim: + raise TypeError( + f"{op_name} requires TensorView slice extent axis {axis}={extent!r} " + f"to match Tile shape axis {axis}={tile_dim!r}" + ) def _bind_assignment_target( self, @@ -614,20 +871,53 @@ def _bind_assignment_target( annotation: Any | None, ) -> tuple[SemanticBinding, ...]: if isinstance(target, FrontendNameTarget): + if isinstance(value.type, SemanticTupleType): + raise ValueError("multi-result call assignment requires tuple binding in TileLang DSL v1") annotated_type = self._annotation_type(annotation, value.type) binding = self._make_binding( target.name, annotated_type if annotated_type is not None else value.type, "ssa", + value=self._binding_value_for_expr(value), ) env[target.name] = binding return (binding,) if isinstance(target, FrontendTupleTarget): - if not isinstance(value, SemanticCallExpr) or value.type is not None: - raise ValueError("tuple assignment expects a multi-result call") - raise ValueError("tuple assignment is not supported in TileLang DSL v1 yet") + if not isinstance(value.type, SemanticTupleType): + raise ValueError("tuple assignment expects a tuple-typed value") + if annotation is not None: + raise TypeError("annotated tuple assignment is not supported in TileLang DSL v1") + if len(target.elements) != len(value.type.elements): + raise ValueError("tuple assignment arity must match the tuple value") + tuple_values: tuple[SemanticExpr, ...] + if isinstance(value, SemanticTupleExpr): + tuple_values = value.elements + elif isinstance(value, SemanticCallExpr): + tuple_values = value.args + else: + tuple_values = tuple(SemanticLiteralExpr(value=None, type=element_type) for element_type in value.type.elements) + bindings = [] + for element, element_type, element_value in zip(target.elements, value.type.elements, tuple_values): + binding = self._make_binding( + element.name, + element_type, + "ssa", + value=self._binding_value_for_expr(element_value), + ) + env[element.name] = binding + bindings.append(binding) + return tuple(bindings) raise ValueError(f"unsupported frontend assignment target {type(target).__name__}") + def _binding_value_for_expr(self, expr: SemanticExpr) -> Any | None: + if isinstance(expr, SemanticSymbolExpr): + return expr.value + if isinstance(expr, SemanticLiteralExpr): + return expr.value + if isinstance(expr, SemanticBindingRef): + return expr.binding.value + return None + def _annotation_type( self, annotation: Any | None, @@ -672,11 +962,12 @@ def _analyze_for( final_binding = final_body_env.get(name) if final_binding is None or final_binding is outer_binding: continue - if final_binding.type != outer_binding.type: + merged_type = self._merge_loop_carried_types(outer_binding.type, final_binding.type) + if merged_type is None: raise TypeError( f"loop-carried binding '{name}' changes type from {outer_binding.type!r} to {final_binding.type!r}" ) - merged = self._make_binding(name, outer_binding.type, "loop_result") + merged = self._make_binding(name, merged_type, "loop_result") updated_env[name] = merged loop_carried.append(merged) @@ -836,6 +1127,10 @@ def _analyze_expr( ) if isinstance(expr, FrontendAttributeExpr): base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) + if expr.attr == "element_type": + return self._element_type_expr(base) + if expr.attr == "valid_shape": + return self._valid_shape_expr(base) attr_type = self._attribute_type(base, expr.attr) return SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type) if isinstance(expr, FrontendSubscriptExpr): @@ -852,6 +1147,14 @@ def _analyze_expr( result_type = self._binary_type(lhs, rhs, expr.op) return SemanticBinaryExpr(lhs=lhs, op=expr.op, rhs=rhs, type=result_type) if isinstance(expr, FrontendCallExpr): + if expr.namespace == "pto" and expr.name == "vlds" and len(expr.args) == 1: + base, offset = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vlds source", + ) + return self._analyze_vlds((base, offset)) args = tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) for arg in expr.args @@ -911,10 +1214,58 @@ def _attribute_type(self, base: SemanticExpr, attr: str) -> SemanticType: return SemanticShapeType(rank=base_type.rank) raise TypeError(f"unsupported attribute access '{attr}' in TileLang DSL v1") + def _element_type_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticTileType)): + return SemanticSymbolExpr( + namespace="pto", + name=base_type.element_dtype.name, + value=base_type.element_dtype, + type=SemanticMetaType(kind="dtype"), + ) + raise TypeError("unsupported attribute access 'element_type' in TileLang DSL v1") + + def _valid_shape_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if not isinstance(base_type, (SemanticTensorViewType, SemanticTileType)): + raise TypeError("unsupported attribute access 'valid_shape' in TileLang DSL v1") + shape_access = SemanticAttributeAccess( + base=base, + attr="shape", + type=SemanticShapeType(rank=base_type.rank), + ) + elements = [] + for axis in range(base_type.rank): + if isinstance(base, SemanticBindingRef) and isinstance(base.type, SemanticTensorViewType): + self._ensure_tensor_shape_parameter(base.binding, axis) + elements.append( + SemanticSubscriptAccess( + base=shape_access, + index=SemanticLiteralExpr(value=axis, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + ) + return SemanticTupleExpr( + elements=tuple(elements), + type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in elements)), + ) + def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticType: if isinstance(base.type, SemanticShapeType): if not isinstance(index.type, SemanticIndexType): raise TypeError("shape subscript index must be an index value in TileLang DSL v1") + if ( + isinstance(base, SemanticAttributeAccess) + and isinstance(base.base, SemanticBindingRef) + and isinstance(index, SemanticLiteralExpr) + and isinstance(index.value, int) + ): + if index.value < 0 or index.value >= base.type.rank: + raise TypeError( + f"shape subscript index {index.value} is out of bounds for rank {base.type.rank}" + ) + if isinstance(base.base.type, SemanticTensorViewType): + self._ensure_tensor_shape_parameter(base.base.binding, index.value) return SemanticIndexType() if isinstance(base.type, SemanticTensorViewType): if not isinstance(index, SemanticTupleExpr): @@ -922,6 +1273,80 @@ def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticTy return self._tensor_slice_type(base.type, index) raise TypeError("unsupported subscript base in TileLang DSL v1") + def _analyze_tile_vector_access( + self, + expr: FrontendExprNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + ) -> tuple[SemanticExpr, SemanticExpr]: + if not self.node.advanced_enabled: + raise TypeError(unsupported_feature_message(f"{context} tile indexing sugar")) + if not isinstance(expr, FrontendSubscriptExpr): + raise TypeError( + f"{context} expects Tile element-indexing syntax in advanced TileLang DSL mode" + ) + base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) + tile = self._require_tile_expr(base, context) + offset = self._tile_vector_offset_expr( + expr.index, + tile.type, + env, + allow_outer_lookup=allow_outer_lookup, + context=context, + ) + return base, offset + + def _tile_vector_offset_expr( + self, + index_expr: FrontendExprNode, + tile_type: SemanticTileType, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + ) -> SemanticExpr: + if tile_type.rank == 1: + if not isinstance(index_expr, FrontendSliceExpr): + raise TypeError(f"{context} expects Tile[start:] syntax for rank-1 Tile values") + if index_expr.stop is not None: + raise TypeError(f"{context} does not support explicit slice stop in TileLang DSL advanced mode") + if index_expr.step is not None: + raise TypeError(f"{context} does not support stepped Tile vector slices in TileLang DSL advanced mode") + if index_expr.start is None: + return SemanticLiteralExpr(value=0, type=SemanticIndexType()) + start = self._analyze_expr(index_expr.start, env, allow_outer_lookup=allow_outer_lookup) + self._require_index_typed_expr(start) + return start + + if tile_type.rank != 2 or tile_type.shape is None: + raise TypeError(f"{context} currently only supports statically specialized rank-1 or rank-2 Tiles") + if not isinstance(index_expr, FrontendTupleExpr) or len(index_expr.elements) != 2: + raise TypeError(f"{context} expects Tile[row, col:] syntax for rank-2 Tile values") + + row_expr, col_expr = index_expr.elements + if not isinstance(col_expr, FrontendSliceExpr): + raise TypeError(f"{context} expects Tile[row, col:] syntax for rank-2 Tile values") + if col_expr.stop is not None: + raise TypeError(f"{context} does not support explicit slice stop in TileLang DSL advanced mode") + if col_expr.step is not None: + raise TypeError(f"{context} does not support stepped Tile vector slices in TileLang DSL advanced mode") + + row = self._analyze_expr(row_expr, env, allow_outer_lookup=allow_outer_lookup) + self._require_index_typed_expr(row) + if col_expr.start is None: + col = SemanticLiteralExpr(value=0, type=SemanticIndexType()) + else: + col = self._analyze_expr(col_expr.start, env, allow_outer_lookup=allow_outer_lookup) + self._require_index_typed_expr(col) + + stride = SemanticLiteralExpr(value=tile_type.shape[1], type=SemanticIndexType()) + row_offset = SemanticBinaryExpr(lhs=row, op="mul", rhs=stride, type=SemanticIndexType()) + if isinstance(col, SemanticLiteralExpr) and col.value == 0: + return row_offset + return SemanticBinaryExpr(lhs=row_offset, op="add", rhs=col, type=SemanticIndexType()) + def _tensor_slice_type( self, tensor_type: SemanticTensorViewType, @@ -937,18 +1362,27 @@ def _tensor_slice_type( raise TypeError( f"TensorView slicing axis {axis} must use a Python slice in TileLang DSL v1" ) + self._require_optional_index_typed_expr(element.start) + self._require_optional_index_typed_expr(element.stop) + self._require_optional_index_typed_expr(element.step) + start = self._static_index_value(element.start, default=0) stop = self._static_index_value(element.stop, default=None) step = self._static_index_value(element.step, default=1) - if stop is None: + if element.stop is None: raise TypeError("TensorView slicing requires explicit stop bounds in TileLang DSL v1") if start != 0: raise TypeError("TensorView slicing currently only supports zero-based starts in TileLang DSL v1") + if element.step is not None and step is None: + raise TypeError("TensorView slicing currently only supports unit stride in TileLang DSL v1") if step != 1: raise TypeError("TensorView slicing currently only supports unit stride in TileLang DSL v1") - extent = stop - start - if extent <= 0: - raise TypeError("TensorView slicing requires positive static extents in TileLang DSL v1") + if stop is None: + extent = None + else: + extent = stop - start + if extent <= 0: + raise TypeError("TensorView slicing requires positive extents in TileLang DSL v1") extents.append(extent) return SemanticTensorSliceType( element_dtype=tensor_type.element_dtype, @@ -996,6 +1430,10 @@ def _analyze_call_expr( raise TypeError( f"call surface `{namespace + '.' if namespace else ''}{name}` is not supported in TileLang DSL v1 yet" ) + if name in DEFERRED_PTO_SURFACES: + raise TypeError(deferred_surface_message(name)) + if name == "get_lanes": + return self._analyze_get_lanes(args) if name == "make_mask": return self._analyze_make_mask(args) if name == "vlds": @@ -1013,20 +1451,32 @@ def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: raise TypeError("pto.make_mask expects exactly 2 positional arguments in TileLang DSL v1") dtype_expr, value_expr = args dtype = self._require_dtype_symbol(dtype_expr, "pto.make_mask element type") - if not ( - isinstance(value_expr, SemanticSymbolExpr) - and value_expr.type.kind == "mask_pattern" - ): - raise TypeError( - "pto.make_mask currently only supports PAT.* pattern lowering in TileLang DSL v1" + if isinstance(value_expr, SemanticSymbolExpr) and value_expr.type.kind == "mask_pattern": + return SemanticCallExpr( + namespace="pto", + name="make_mask", + args=args, + type=SemanticMaskType(granularity=self._mask_granularity_for_dtype(dtype)), ) + self._require_tail_remaining_expr(value_expr, "pto.make_mask tail remaining") return SemanticCallExpr( namespace="pto", name="make_mask", args=args, - type=SemanticMaskType(granularity=self._mask_granularity_for_dtype(dtype)), + type=SemanticTupleType( + elements=( + SemanticMaskType(granularity=self._mask_granularity_for_dtype(dtype)), + _I32_TYPE, + ) + ), ) + def _analyze_get_lanes(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 1: + raise TypeError("pto.get_lanes expects exactly 1 positional argument in TileLang DSL v1") + dtype = self._require_dtype_symbol(args[0], "pto.get_lanes dtype") + return SemanticLiteralExpr(value=self._vreg_type_for_dtype(dtype).lanes, type=SemanticIndexType()) + def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 2: raise TypeError("pto.vlds expects exactly 2 positional arguments in TileLang DSL v1") @@ -1091,6 +1541,13 @@ def _require_dtype_symbol(self, expr: SemanticExpr, context: str) -> ScalarType: and expr.type.kind == "dtype" and isinstance(expr.value, ScalarType) ): + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and isinstance(expr.binding.value, ScalarType) + ): + return expr.binding.value raise TypeError(f"{context} must be a TileLang scalar dtype symbol in TileLang DSL v1") return expr.value @@ -1104,6 +1561,13 @@ def _require_scalar_expr(self, expr: SemanticExpr, context: str) -> SemanticScal raise TypeError(f"{context} must be a scalar value in TileLang DSL v1") return expr.type + def _require_tail_remaining_expr(self, expr: SemanticExpr, context: str) -> None: + if isinstance(expr.type, SemanticIndexType): + return + if isinstance(expr.type, SemanticScalarType) and expr.type.dtype.name == "i32": + return + raise TypeError(f"{context} must be an i32 or index value in TileLang DSL v1") + def _require_mask_for_vreg( self, mask_expr: SemanticExpr, @@ -1202,6 +1666,27 @@ def _require_condition_type(self, ty: SemanticType) -> None: return raise TypeError(f"if condition must be scalar/index typed, got {ty!r}") + def _merge_loop_carried_types( + self, + outer_type: SemanticType, + final_type: SemanticType, + ) -> SemanticType | None: + if final_type == outer_type: + return outer_type + if ( + isinstance(outer_type, SemanticIndexType) + and isinstance(final_type, SemanticScalarType) + and final_type.dtype == i32 + ): + return final_type + if ( + isinstance(final_type, SemanticIndexType) + and isinstance(outer_type, SemanticScalarType) + and outer_type.dtype == i32 + ): + return outer_type + return None + def _require_index_typed_expr(self, expr: SemanticExpr) -> None: if not isinstance(expr.type, SemanticIndexType): raise TypeError("slice bounds and vector offsets must be index-typed in TileLang DSL v1") @@ -1210,9 +1695,14 @@ def _static_index_value(self, expr: SemanticExpr | None, *, default: int | None) if expr is None: return default if not isinstance(expr, SemanticLiteralExpr) or not isinstance(expr.value, int): - raise TypeError("TensorView slice bounds must be static integer literals in TileLang DSL v1") + return None return expr.value + def _require_optional_index_typed_expr(self, expr: SemanticExpr | None) -> None: + if expr is None: + return + self._require_index_typed_expr(expr) + def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: """Normalize descriptor-owned AST into a lowering semantic model.""" @@ -1247,6 +1737,7 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticSliceExpr", "SemanticSliceType", "SemanticStmt", + "SemanticVecscopeStmt", "SemanticStrictVecscopeStmt", "SemanticSubscriptAccess", "SemanticSymbolExpr", diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py new file mode 100644 index 000000000..de92cd701 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -0,0 +1,91 @@ +"""Support-matrix definitions and diagnostics for TileLang DSL v1.""" + +from __future__ import annotations + +FOLLOW_UP_CHANGE = "extend-tilelang-dsl-matcher-and-advanced-surface" + +SUPPORTED_TOPLEVEL_PTO_CALLS = frozenset( + { + "strict_vecscope", + "dma_load", + "dma_store", + "set_flag", + "wait_flag", + "pipe_barrier", + "barrier", + } +) + +SUPPORTED_VECSCOPE_PTO_CALLS = frozenset( + { + "make_mask", + "vlds", + "vsts", + "vabs", + "vrelu", + "vexp", + "vnot", + "vadd", + "vsub", + "vmul", + "vdiv", + "vmax", + "vmin", + "vand", + "vor", + "vxor", + "vadds", + "vsubs", + "vmuls", + "vdivs", + "vmaxs", + "vmins", + } +) + +DEFERRED_PTO_SURFACES = frozenset( + { + "castptr", + "addptr", + "copy_ubuf_to_ubuf", + "vcmp", + "vcmps", + "vsel", + "vselr", + "vselrv2", + "pnot", + "psel", + "ppack", + "punpack", + "vaddc", + "vsubc", + "vaddcs", + "vsubcs", + "vintlv", + "vdintlv", + "vintlvv2", + "vdintlvv2", + "vreduce", + } +) + + +def unsupported_feature_message(feature: str) -> str: + return ( + f"{feature} is not supported in TileLang DSL v1; " + f"see follow-up change `{FOLLOW_UP_CHANGE}`" + ) + + +def deferred_surface_message(name: str) -> str: + return unsupported_feature_message(f"advanced family surface `pto.{name}`") + + +__all__ = [ + "DEFERRED_PTO_SURFACES", + "FOLLOW_UP_CHANGE", + "SUPPORTED_TOPLEVEL_PTO_CALLS", + "SUPPORTED_VECSCOPE_PTO_CALLS", + "deferred_surface_message", + "unsupported_feature_message", +] diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 82bca103a..0e41114de 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -111,6 +111,23 @@ def TypeVar(name: str) -> TypeVariable: return TypeVariable(name) +def get_lanes(dtype: ScalarType) -> int: + if not isinstance(dtype, ScalarType): + raise TypeError("get_lanes expects a TileLang scalar dtype") + byte_widths = { + "i8": 1, + "i16": 2, + "i32": 4, + "f16": 2, + "bf16": 2, + "f32": 4, + } + width = byte_widths.get(dtype.name) + if width is None: + raise TypeError(f"dtype `{dtype.name}` is not supported by get_lanes") + return 256 // width + + __all__ = [ "ScalarType", "WildcardType", @@ -139,4 +156,5 @@ def TypeVar(name: str) -> TypeVariable: "AnyInt", "AnyType", "AnyMask", + "get_lanes", ] diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index c067788a7..07986b73a 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -1,9 +1,11 @@ import tempfile import unittest +from unittest import mock from importlib import util from pathlib import Path import tilelang_dsl as pto +import tilelang_dsl.kernel as kernel_impl from tilelang_dsl.frontend_ast import build_frontend_kernel_node from tilelang_dsl.lowering import AuthoringModule, lower_semantic_kernel from tilelang_dsl.semantic import ( @@ -21,6 +23,7 @@ SemanticStrictVecscopeStmt, SemanticTensorViewType, SemanticTileType, + SemanticVecscopeStmt, SemanticVectorStoreStmt, SemanticWaitFlagStmt, analyze_frontend_kernel, @@ -34,6 +37,7 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "TensorView")) self.assertTrue(hasattr(pto, "Tile")) self.assertTrue(hasattr(pto, "TileSpecialization")) + self.assertTrue(hasattr(pto, "get_lanes")) self.assertTrue(hasattr(pto, "PAT")) self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) @@ -49,7 +53,9 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): self.assertEqual(kernel.op, "eltwise") self.assertEqual(kernel.name, "kernel") self.assertFalse(kernel.verify_enabled) + self.assertFalse(kernel.advanced_enabled) self.assertEqual(kernel.metadata["verify"], False) + self.assertEqual(kernel.metadata["advanced"], False) self.assertEqual(kernel.dtype_signature, (pto.f32, pto.f16, pto.i32)) self.assertEqual( [(param.name, param.kind, param.dtype) for param in kernel.parameters], @@ -80,14 +86,44 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): self.assertIn("func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) {", text) module = specialized.mlir_module() self.assertEqual(type(module).__name__, "MaterializedMLIRModule") - self.assertTrue(module.verify()) - self.assertTrue(specialized.verify()) + mocked_result = kernel_impl.VerificationResult( + status="passed", + available=True, + passed=True, + message="ok", + command=("ptoas",), + returncode=0, + ) + with mock.patch("tilelang_dsl.kernel._run_ptoas_verifier", return_value=mocked_result): + self.assertTrue(module.verify()) + self.assertTrue(specialized.verify()) + self.assertEqual(module.verify().status, "passed") + self.assertEqual(specialized.verify().status, "passed") with tempfile.TemporaryDirectory() as tmpdir: out = Path(tmpdir) / "kernel.mlir" specialized.emit(out) self.assertEqual(out.read_text(encoding="utf-8"), text) + def test_verify_reports_structured_unavailable_when_ptoas_is_missing(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + ) + ) + + result = specialized.verify(ptoas_bin="/definitely-missing/ptoas") + self.assertFalse(result) + self.assertEqual(result.status, "unavailable") + self.assertFalse(result.available) + self.assertFalse(result.passed) + self.assertIn("verifier unavailable", result.message) + def test_descriptor_materialization_flows_through_pipeline(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)]) def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): @@ -218,6 +254,51 @@ def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile): text, ) + def test_dynamic_tensorview_shape_profile_supports_runtime_bound_and_slice(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + rows = inp.shape[0] + pto.dma_load(inp[0:rows, 0:16], tile) + for lane in range(0, rows, 1): + current = lane + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("inp", "tensorview"), ("tile", "tile"), ("__shape_inp_0", "tensorview_shape")], + ) + + rows_assign = semantic_kernel.body[0] + self.assertIsInstance(rows_assign, SemanticAssignStmt) + self.assertIsInstance(rows_assign.targets[0].type, SemanticIndexType) + + dma_stmt = semantic_kernel.body[1] + self.assertIsInstance(dma_stmt, SemanticDmaLoadStmt) + self.assertEqual(dma_stmt.src.type.extents, (None, 16)) + + loop_stmt = semantic_kernel.body[2] + self.assertIsInstance(loop_stmt, SemanticForStmt) + + text = specialized.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: index) {", + text, + ) + self.assertIn( + "pto.copy_gm_to_ubuf %arg0, %arg1, %c0_i64, %c16_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64", + text, + ) + self.assertIn("scf.for %lane_", text) + self.assertIn("to %arg2 step %c1 {", text) + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32)]) def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.f32): @@ -278,6 +359,253 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.f32): text, ) + def test_tail_make_mask_lowers_to_typed_plt_and_updates_remaining(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): + pto.dma_load(inp[0:16, 0:16], tile) + with pto.strict_vecscope(tile, tile, remaining, 0, 64, 64) as (src, dst, rem_in, lb, ub, step): + mask, next_remaining = pto.make_mask(pto.f32, rem_in) + vec = pto.vlds(src, lb) + pto.vsts(vec, dst, lb, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = semantic_kernel.body[1] + self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) + mask_assign = vecscope.body[0] + self.assertIsInstance(mask_assign, SemanticAssignStmt) + self.assertEqual(mask_assign.value.name, "make_mask") + self.assertEqual(len(mask_assign.targets), 2) + self.assertIsInstance(mask_assign.targets[0].type, SemanticMaskType) + self.assertIsInstance(mask_assign.targets[1].type, SemanticScalarType) + self.assertEqual(mask_assign.targets[1].type.dtype, pto.i32) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"%mask_\d+, %next_remaining_\d+ = pto\.plt_b32 %rem_in_\d+ : i32 -> !pto\.mask, i32", + ) + self.assertIn( + "pto.vsts %vec_", + text, + ) + + def test_nested_index_arithmetic_lowers_before_vector_accesses(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)]) + def kernel( + lhs_gm: pto.TensorView, + rhs_gm: pto.TensorView, + out_gm: pto.TensorView, + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, + ): + rows = lhs_gm.shape[0] + cols = lhs_gm.shape[1] + row_stride = lhs_tile.shape[1] + + pto.dma_load(lhs_gm[0:rows, 0:cols], lhs_tile) + pto.dma_load(rhs_gm[0:rows, 0:cols], rhs_tile) + with pto.strict_vecscope( + lhs_tile, + rhs_tile, + dst_tile, + rows, + cols, + row_stride, + 0, + rows, + 1, + ) as (lhs, rhs, dst, valid_rows, valid_cols, stride, row_lb, row_ub, row_step): + for row in range(row_lb, row_ub, row_step): + for lane in range(0, valid_cols, 64): + offset = row * stride + lane + mask, next_remaining = pto.make_mask(pto.f32, valid_cols - lane) + summed = pto.vadd(pto.vlds(lhs, offset), pto.vlds(rhs, offset), mask) + pto.vsts(summed, dst, offset, mask) + pto.dma_store(dst_tile, out_gm[0:rows, 0:cols]) + return None + + specialized = kernel.specialize( + lhs_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + rhs_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"%tmp_\d+ = arith\.muli %row_\d+, %stride_\d+ : index") + self.assertRegex(text, r"%offset_\d+ = arith\.addi %tmp_\d+, %lane_\d+ : index") + self.assertRegex(text, r"%tmp_\d+ = arith\.subi %valid_cols_\d+, %lane_\d+ : index") + self.assertRegex(text, r"%tmp_\d+ = arith\.index_cast %tmp_\d+ : index to i32") + self.assertIn("pto.plt_b32", text) + self.assertIn("pto.vadd", text) + + def test_advanced_mode_infers_vecscope_and_lowers_tile_vector_sugar(self) -> None: + @pto.vkernel(op="tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + all_mask = pto.make_mask(dtype, pto.PAT.ALL) + for row in range(0, rows, 1): + for col in range(0, cols, pto.get_lanes(dtype)): + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, all_mask) + pto.vsts(summed, dst[row, col:], all_mask) + return None + + self.assertTrue(kernel.advanced_enabled) + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual(len(semantic_kernel.body), 2) + self.assertIsInstance(semantic_kernel.body[0], SemanticVecscopeStmt) + vecscope = semantic_kernel.body[0] + self.assertIsInstance(vecscope, SemanticVecscopeStmt) + outer_loop = next(stmt for stmt in vecscope.body if isinstance(stmt, SemanticForStmt)) + self.assertIsInstance(outer_loop, SemanticForStmt) + inner_loop = outer_loop.body[0] + self.assertIsInstance(inner_loop, SemanticForStmt) + self.assertTrue(inner_loop.body) + + text = specialized.mlir_text() + self.assertIn("// tilelang.advanced = True", text) + self.assertIn("pto.vecscope {", text) + self.assertNotIn("pto.strict_vecscope(", text) + self.assertRegex(text, r"pto\.vecscope \{\n(?:.|\n)*scf\.for %row_") + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertRegex(text, r"%tmp_\d+ = arith\.muli %row_\d+, %c64 : index") + self.assertRegex(text, r"%tmp_\d+ = arith\.addi %tmp_\d+, %col_\d+ : index") + self.assertIn("pto.vlds %arg1[", text) + self.assertIn("pto.vlds %arg2[", text) + self.assertIn("pto.vsts %summed_", text) + + def test_element_type_valid_shape_and_get_lanes_surface_lower_in_advanced_mode(self) -> None: + @pto.vkernel(op="tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + remained = valid_cols + for row in range(0, valid_rows, 1): + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + summed = pto.vadd(pto.vlds(src0[row, col:]), pto.vlds(src1[row, col:]), mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("step %c64", text) + self.assertRegex(text, r"%mask_\d+, %remained_\d+ = pto\.plt_b32 %remained_iter_\d+ : i32 -> !pto\.mask, i32") + self.assertIn("pto.vadd", text) + self.assertIn("pto.vsts", text) + + def test_advanced_mode_keeps_strict_vecscope_as_hard_boundary(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + rows = src.shape[0] + for row in range(0, rows, 1): + vec = pto.vlds(src[row, 0:]) + pto.vsts(vec, dst[row, 0:], all_mask) + with pto.strict_vecscope(src, dst, all_mask, 0, 64, 64) as (vin, vout, mask, lb, ub, step): + for lane in range(lb, ub, step): + scoped = pto.vlds(vin, lane) + pto.vsts(scoped, vout, lane, mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertEqual(text.count("pto.strict_vecscope("), 1) + + def test_elementwise_kernel_positive_regression_covers_dma_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32, pto.i32)]) + def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: pto.i32): + rows = inp.shape[0] + pto.dma_load(inp[0:rows, 0:16], tile) + with pto.strict_vecscope(tile, tile, remaining, 0, rows, 64) as ( + src, + dst, + rem, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask, rem = pto.make_mask(pto.f32, rem) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + pto.dma_store(tile, out[0:rows, 0:16]) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual(len(semantic_kernel.body), 5) + self.assertIsInstance(semantic_kernel.body[1], SemanticDmaLoadStmt) + self.assertIsInstance(semantic_kernel.body[2], SemanticStrictVecscopeStmt) + self.assertIsInstance(semantic_kernel.body[3], SemanticDmaStoreStmt) + + vecscope = semantic_kernel.body[2] + self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) + loop_stmt = vecscope.body[0] + self.assertIsInstance(loop_stmt, SemanticForStmt) + self.assertEqual(len(loop_stmt.loop_carried), 1) + self.assertEqual(loop_stmt.loop_carried[0].name, "rem") + + text = specialized.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: index) {", + text, + ) + self.assertIn( + "pto.copy_gm_to_ubuf %arg0, %arg2, %c0_i64, %c16_i64, %c64_i64", + text, + ) + self.assertIn( + "pto.strict_vecscope(%arg2, %arg2, %arg3, %c0, %arg4, %c64)", + text, + ) + self.assertRegex( + text, + r"scf\.for %lane_\d+ = %lb_\d+ to %ub_\d+ step %step_\d+ iter_args\(%rem_iter_\d+ = %rem_\d+\) -> \(i32\) \{", + ) + self.assertRegex( + text, + r"%mask_\d+, %rem_\d+ = pto\.plt_b32 %rem_iter_\d+ : i32 -> !pto\.mask, i32", + ) + self.assertIn( + "pto.copy_ubuf_to_gm %arg2, %arg1, %c0_i64, %c16_i64, %c64_i64", + text, + ) + def test_if_else_and_sync_ops_lower_to_scf_if_and_authoring_sync_ops(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)]) def kernel(inp: pto.TensorView, tile: pto.Tile, flag: pto.i32): @@ -397,6 +725,23 @@ def kernel(x: pto.TensorView): self.assertIn("vector op surface `pto.vadd` requires explicit pto.strict_vecscope", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) + def test_unsupported_advanced_family_points_to_follow_up_change(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f32)]) + def kernel(x: pto.TensorView, tile: pto.Tile): + with pto.strict_vecscope(tile, tile, 0, 256, 64) as (lhs, rhs, lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vcmp(lhs, rhs, mask, "lt") + return None + + self.assertIn("advanced family surface `pto.vcmp`", str(ctx.exception)) + self.assertIn( + "extend-tilelang-dsl-matcher-and-advanced-surface", + str(ctx.exception), + ) + self.assertIn(f"{__file__}:", str(ctx.exception)) + def test_missing_specialization_reports_source_location(self) -> None: @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f16)]) def kernel(x: pto.TensorView, tile: pto.Tile): From 0d8e71e8235f612a22c398b8c44ba46d6c571bdd Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 4 Apr 2026 17:03:14 +0800 Subject: [PATCH 020/192] Update openspec --- .../.openspec.yaml | 0 .../design.md | 0 .../proposal.md | 0 .../specs/tilelang-dsl-vpto-lowering/spec.md | 0 .../tasks.md | 16 ++++++++-------- 5 files changed, 8 insertions(+), 8 deletions(-) rename openspec/changes/{add-tilelang-dsl-authoring-vpto-lowering => archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering}/.openspec.yaml (100%) rename openspec/changes/{add-tilelang-dsl-authoring-vpto-lowering => archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering}/design.md (100%) rename openspec/changes/{add-tilelang-dsl-authoring-vpto-lowering => archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering}/proposal.md (100%) rename openspec/changes/{add-tilelang-dsl-authoring-vpto-lowering => archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering}/specs/tilelang-dsl-vpto-lowering/spec.md (100%) rename openspec/changes/{add-tilelang-dsl-authoring-vpto-lowering => archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering}/tasks.md (78%) diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/.openspec.yaml b/openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/.openspec.yaml similarity index 100% rename from openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/.openspec.yaml rename to openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/.openspec.yaml diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/design.md b/openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/design.md similarity index 100% rename from openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/design.md rename to openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/design.md diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/proposal.md b/openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/proposal.md similarity index 100% rename from openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/proposal.md rename to openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/proposal.md diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md b/openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md similarity index 100% rename from openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md rename to openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/specs/tilelang-dsl-vpto-lowering/spec.md diff --git a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/tasks.md b/openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/tasks.md similarity index 78% rename from openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/tasks.md rename to openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/tasks.md index be8775819..2f99431f3 100644 --- a/openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/tasks.md +++ b/openspec/changes/archive/2026-04-04-add-tilelang-dsl-authoring-vpto-lowering/tasks.md @@ -14,17 +14,17 @@ - [x] 3.1 实现 `dma_load` / `dma_store` 的 TensorView slice 到 DMA programming + copy op lowering。 - [x] 3.2 实现 `make_mask`、`vlds`、`vsts` 以及 v1 unary/binary/vector-scalar family 的 lowering。 - [x] 3.3 实现 `for range(lb, ub, step)`、`if/else`、`set_flag`、`wait_flag`、`pipe_barrier` 的 lowering。 -- [ ] 3.4 对 support matrix 外的 family 保持 fail-fast reject,不允许 silent fallback。 +- [x] 3.4 对 support matrix 外的 family 保持 fail-fast reject,不允许 silent fallback。 ## 4. Dynamic-bound 与合法性验证 -- [ ] 4.1 实现“静态 physical Tile + 动态 TensorView slice/loop bound”的 shape profile,拒绝 dynamic physical tile shape。 -- [ ] 4.2 实现 tail `make_mask(dtype, remaining)` 的 typed-mask lowering,确保输出满足当前 VPTO legality contract。 -- [ ] 4.3 实现 `descriptor.verify()`,通过 `ptoas` binary 运行与 `--pto-backend=vpto` 一致的 authoring-stage legality 验证,并对 binary 缺失返回结构化 unavailable 结果。 +- [x] 4.1 实现“静态 physical Tile + 动态 TensorView slice/loop bound”的 shape profile,拒绝 dynamic physical tile shape。 +- [x] 4.2 实现 tail `make_mask(dtype, remaining)` 的 typed-mask lowering,确保输出满足当前 VPTO legality contract。 +- [x] 4.3 实现 `descriptor.verify()`,通过 `ptoas` binary 运行与 `--pto-backend=vpto` 一致的 authoring-stage legality 验证,并对 binary 缺失返回结构化 unavailable 结果。 ## 5. 测试、样例与文档 -- [ ] 5.1 在 `tilelang-dsl/tests/` 增加 elementwise kernel 的 positive regression,覆盖 `dma_load/store`、`strict_vecscope`、typed-mask、dynamic loop bound。 -- [ ] 5.2 增加 negative regression,覆盖 vector op 出 scope、unsupported family、非法 shape profile、verifier unavailable。 -- [ ] 5.3 在 `tilelang-dsl/examples/` 和 `tilelang-dsl/docs/` 提供与 guide 对齐的 v1 示例,并明确记录 support matrix 与延期 feature。 -- [ ] 5.4 运行并记录最小验证命令,确认生成的 IR 能通过 `build/tools/ptoas/ptoas --pto-backend=vpto` 的 authoring-stage legality 路径。 +- [x] 5.1 在 `tilelang-dsl/tests/` 增加 elementwise kernel 的 positive regression,覆盖 `dma_load/store`、`strict_vecscope`、typed-mask、dynamic loop bound。 +- [x] 5.2 增加 negative regression,覆盖 vector op 出 scope、unsupported family、非法 shape profile、verifier unavailable。 +- [x] 5.3 在 `tilelang-dsl/examples/` 和 `tilelang-dsl/docs/` 提供与 guide 对齐的 v1 示例,并明确记录 support matrix 与延期 feature。 +- [x] 5.4 运行并记录最小验证命令,确认生成的 IR 能通过 `build/tools/ptoas/ptoas --pto-backend=vpto` 的 authoring-stage legality 路径。 From 5d96301b01217a348e40e949c29fb176a5db1a2f Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 7 Apr 2026 10:16:13 +0800 Subject: [PATCH 021/192] Support more dsl syntax --- tilelang-dsl/docs/README.md | 4 + .../matcher-and-advanced-surface-migration.md | 195 +++++ tilelang-dsl/docs/v1-lowering.md | 15 +- tilelang-dsl/docs/v1-surface.md | 13 +- tilelang-dsl/python/tilelang_dsl/__init__.py | 8 + tilelang-dsl/python/tilelang_dsl/kernel.py | 458 ++++++++++-- tilelang-dsl/python/tilelang_dsl/lowering.py | 315 +++++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 682 ++++++++++++++++-- .../python/tilelang_dsl/support_matrix.py | 40 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 529 +++++++++++++- 10 files changed, 2071 insertions(+), 188 deletions(-) create mode 100644 tilelang-dsl/docs/matcher-and-advanced-surface-migration.md diff --git a/tilelang-dsl/docs/README.md b/tilelang-dsl/docs/README.md index a4c489fef..b93493489 100644 --- a/tilelang-dsl/docs/README.md +++ b/tilelang-dsl/docs/README.md @@ -5,6 +5,10 @@ Current docs: `add-tilelang-dsl-core-foundation` - `v1-lowering.md`: the TileLang DSL v1 authoring-form VPTO lowering contract implemented by `add-tilelang-dsl-authoring-vpto-lowering` +- `matcher-and-advanced-surface-migration.md`: migration notes from the + original v1 core/lowering boundary to the matcher and advanced-surface + capability implemented by + `extend-tilelang-dsl-matcher-and-advanced-surface` Documentation boundary: - `tilelang-dsl/docs/` is the local documentation source of truth for the new diff --git a/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md b/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md new file mode 100644 index 000000000..c28240ea0 --- /dev/null +++ b/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md @@ -0,0 +1,195 @@ +# TileLang DSL Matcher And Advanced-Surface Migration + +## Scope + +This document explains how to move from the original v1 core contract +(`add-tilelang-dsl-core-foundation` + +`add-tilelang-dsl-authoring-vpto-lowering`) to the matcher and +advanced-surface capability implemented by +`extend-tilelang-dsl-matcher-and-advanced-surface`. + +It focuses on: +- matcher-driven kernel selection +- implicit vecscope inference +- raw pointer / low-level DMA authoring +- advanced vector-family coverage that is implemented today +- the remaining deferred boundary + +## What Changed + +The original v1 core profile assumed: +- one monomorphic `dtypes` signature +- no matcher registry or selection API +- explicit `pto.strict_vecscope` for vector code +- no raw-pointer or low-level DMA authoring surface +- no advanced vector-family lowering beyond the fixed elementwise set + +The current package now adds: +- `KernelRegistry` +- `pto.select_kernel(...)` +- multi-signature `dtypes` +- `AnyFloat`, `AnyInt`, `AnyType`, `AnyMask` +- `TypeVar(...)` +- `constraints=[...]` +- `priority=` +- implicit vecscope inference in `advanced=True` kernels +- `ptr(...)` / `PointerType` +- `castptr`, `addptr` +- low-level DMA config/copy surface +- compare/select, predicate movement, carry, and rearrangement families + +## Matcher Migration + +### Before + +The original v1 contract only supported one concrete signature: + +```python +@pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)]) +def kernel(inp: pto.TensorView, out: pto.Tile): + return None +``` + +### After + +You can now register multiple polymorphic descriptors and let the matcher pick +the concrete specialization: + +```python +@pto.vkernel( + op="eltwise", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat), + (pto.AnyInt, pto.AnyInt), + ], + constraints=[lambda attrs: attrs.get("enabled", True)], + priority=10, +) +def kernel(inp: pto.TensorView, out: pto.Tile): + return None + +selected = pto.select_kernel( + "a5", + "eltwise", + (pto.f32, pto.f32), + context_attrs={"enabled": True}, +) +``` + +Matcher rules in the implemented package: +- matching is deterministic +- selection order is `target -> op -> dtypes -> constraints -> priority` +- highest-priority ties raise an explicit error +- `TypeVar` only binds within one signature + +## Vecscope Migration + +### Before + +Vector code needed an explicit `pto.strict_vecscope` boundary: + +```python +with pto.strict_vecscope(tile, tile, 0, 256, 64) as (src, dst, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) +``` + +### After + +In `advanced=True` kernels, the frontend now infers `pto.vecscope` for +contiguous vector-active regions: + +```python +@pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) +def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src[0, 0:]) + pto.vsts(vec, dst[0, 0:], mask) +``` + +Inference boundaries in the implemented package: +- scalar statements cut inference +- `if` / `for` structure is respected +- sync and DMA statements cut inference +- explicit `pto.strict_vecscope` remains a hard boundary + +Use `pto.strict_vecscope` when you need a deterministic region ABI or do not +want inference to merge adjacent vector chains. + +## Pointer And DMA Migration + +### New Pointer Surface + +The package now exposes: +- `pto.ptr(dtype, memory_space)` +- pointer-typed parameters such as `pto.ptr(pto.f32, pto.MemorySpace.UB)` +- `pto.castptr(...)` +- `pto.addptr(...)` + +Example: + +```python +@pto.vkernel(op="copy", dtypes=[(pto.f32, pto.i64)], advanced=True) +def kernel(dst: pto.ptr(pto.f32, pto.MemorySpace.UB), addr: pto.i64): + src = pto.castptr(addr, pto.ptr(pto.f32, pto.MemorySpace.UB)) + next_src = pto.addptr(src, 64) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, next_src, 0, mask) +``` + +### New Low-Level DMA Surface + +The package now lowers: +- `set_loop2_stride_outtoub` +- `set_loop1_stride_outtoub` +- `set_loop_size_outtoub` +- `set_loop2_stride_ubtoout` +- `set_loop1_stride_ubtoout` +- `set_loop_size_ubtoout` +- `copy_gm_to_ubuf` +- `copy_ubuf_to_gm` +- `copy_ubuf_to_ubuf` + +High-level `dma_load` / `dma_store` remain the preferred default. Use the +low-level surface only when you need manual DMA programming. + +## Advanced Vector Families + +The currently implemented advanced-family groups are: +- compare/select: + `vcmp`, `vcmps`, `vsel`, `vselr`, `vselrv2` +- predicate movement: + `pnot`, `psel`, `ppack`, `punpack` +- carry family: + `vaddc`, `vsubc`, `vaddcs`, `vsubcs` +- rearrangement: + `vintlv`, `vdintlv`, `vintlvv2`, `vdintlvv2` + +These lower directly to authoring-form VPTO and are covered by +`tilelang-dsl/tests/test_tilelang_dsl_v1.py`. + +## Still Deferred + +The following boundary remains intentionally deferred: +- reduction family authoring + +Reason: +- the current repo does not expose a public authoring-form VPTO reduction op + that TileLang DSL can target directly +- existing reduction logic lives in other lowering paths such as OpLib / EmitC + and cannot be treated as the public TileLang DSL authoring contract + +Current package behavior: +- reduction-family surface remains an explicit frontend reject +- no extra helper IR is introduced to fake reduction support + +## Recommended Reading Order + +For the current package contract, read in this order: +1. `tilelang-dsl/docs/v1-surface.md` +2. `tilelang-dsl/docs/v1-lowering.md` +3. `tilelang-dsl/docs/matcher-and-advanced-surface-migration.md` +4. `docs/tilelang-dsl-guide.md` diff --git a/tilelang-dsl/docs/v1-lowering.md b/tilelang-dsl/docs/v1-lowering.md index ef4da2752..9c0c6c8f4 100644 --- a/tilelang-dsl/docs/v1-lowering.md +++ b/tilelang-dsl/docs/v1-lowering.md @@ -17,6 +17,10 @@ It does not define: - raw pointer authoring surface - advanced vector-family lowering beyond the fixed v1 matrix +For migration from that original v1 lowering boundary to the current matcher +and advanced-surface implementation, see +`tilelang-dsl/docs/matcher-and-advanced-surface-migration.md`. + ## Source Of Truth The implemented lowering surface lives under: @@ -88,9 +92,10 @@ PYTHONPATH=$PWD/tilelang-dsl/python \ python3 tilelang-dsl/examples/v1_verify_smoke.py ``` -## Deferred Features +## Historical Deferred Features -The following remain outside v1 and belong to follow-up changes: +The following remained outside the original v1 lowering boundary and were +assigned to follow-up changes: - implicit vecscope inference - matcher registry and deterministic selection - raw pointer / low-level DMA / `copy_ubuf_to_ubuf` authoring surface @@ -101,6 +106,12 @@ The following remain outside v1 and belong to follow-up changes: Primary follow-up change: - `extend-tilelang-dsl-matcher-and-advanced-surface` +In the current package head, that follow-up has implemented matcher dispatch, +implicit vecscope inference, raw pointer / low-level DMA authoring, and +compare/select + predicate movement + carry + rearrangement families. +Reduction remains deferred because the repo still does not expose a public +authoring-form VPTO reduction op for TileLang DSL to target directly. + ## Minimal Validation The minimal validation set for the implemented v1 lowering is: diff --git a/tilelang-dsl/docs/v1-surface.md b/tilelang-dsl/docs/v1-surface.md index c3a542684..53e3b2dab 100644 --- a/tilelang-dsl/docs/v1-surface.md +++ b/tilelang-dsl/docs/v1-surface.md @@ -22,6 +22,9 @@ It does not define: For implemented lowering details, examples, and `verify()` behavior, see `tilelang-dsl/docs/v1-lowering.md`. +For migration from the original v1 core boundary to the current matcher and +advanced-surface package capabilities, see +`tilelang-dsl/docs/matcher-and-advanced-surface-migration.md`. ## Source Of Truth @@ -218,10 +221,10 @@ Expected output shape: This confirms diagnostics are emitted against the authored DSL source file rather than an internal lowering location. -## Deferred Features +## Historical Deferred Features -The following are intentionally out of scope for v1 and belong to follow-up -changes: +The following were intentionally out of scope for the original v1 core boundary +and were assigned to follow-up changes: - multiple `dtypes` signatures - `constraints` - `priority` @@ -234,3 +237,7 @@ changes: Matcher-related extensions are deferred to `extend-tilelang-dsl-matcher-and-advanced-surface`. +That follow-up capability is now implemented in the current package head; use +`tilelang-dsl/docs/matcher-and-advanced-surface-migration.md` for the updated +surface boundary instead of reading the list above as a statement about current +head behavior. diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index f1bf887af..73c2694c8 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -2,9 +2,11 @@ from .kernel import ( BoundKernelParameter, + KernelRegistry, MaterializedMLIRModule, TileLangFrontendError, VKernelDescriptor, + select_kernel, vkernel, ) from .types import ( @@ -18,6 +20,7 @@ MemorySpace, MaskPattern, PAT, + PointerType, Pipe, ScalarType, TensorView, @@ -36,13 +39,16 @@ i16, i32, i64, + ptr, ) __all__ = [ "BoundKernelParameter", + "KernelRegistry", "MaterializedMLIRModule", "TileLangFrontendError", "VKernelDescriptor", + "select_kernel", "vkernel", "ScalarType", "WildcardType", @@ -50,6 +56,8 @@ "TypeVar", "TensorView", "Tile", + "PointerType", + "ptr", "MemorySpace", "Pipe", "Event", diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index cdd16e3a5..69b305318 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -10,10 +10,11 @@ import textwrap from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, Mapping from .types import ( MemorySpace, + PointerType, ScalarType, TensorView, Tile, @@ -26,10 +27,13 @@ from .lowering import lower_semantic_kernel from .semantic import analyze_frontend_kernel from .support_matrix import ( + ADVANCED_EXPR_PTO_CALLS, + ADVANCED_TOPLEVEL_PTO_CALLS, + ADVANCED_VECSCOPE_PTO_CALLS, DEFERRED_PTO_SURFACES, SUPPORTED_TOPLEVEL_PTO_CALLS, SUPPORTED_VECSCOPE_PTO_CALLS, - unsupported_feature_message, + advanced_mode_message, deferred_surface_message, ) @@ -38,21 +42,10 @@ _PTOAS_BIN_ENV = "PTOAS_BIN" -def _reject_unsupported_decorator_feature(name: str, value: Any) -> None: - if value is _UNSET: - return - raise ValueError(unsupported_feature_message(f"decorator feature `{name}`")) - - -def _reject_unsupported_dtype_feature(dtype: Any) -> None: - if isinstance(dtype, WildcardType): - raise ValueError( - unsupported_feature_message(f"dtype wildcard `{dtype.name}`") - ) - if isinstance(dtype, TypeVariable): - raise ValueError( - unsupported_feature_message(f"dtype type variable `{dtype.name}`") - ) +def _validate_dtype_pattern(dtype: Any) -> ScalarType | WildcardType | TypeVariable: + if isinstance(dtype, (ScalarType, WildcardType, TypeVariable)): + return dtype + raise TypeError(f"unsupported dtype pattern {dtype!r}") class TileLangFrontendError(ValueError): @@ -182,6 +175,23 @@ def visit_Call(self, node: ast.Call) -> None: f"vector op surface `pto.{node.func.attr}` requires explicit pto.strict_vecscope in TileLang DSL v1", ) return + if node.func.value.id == "pto" and node.func.attr in ADVANCED_VECSCOPE_PTO_CALLS: + if self.advanced_enabled: + return + raise self.source_info.error( + node, + advanced_mode_message(node.func.attr), + ) + if node.func.value.id == "pto" and ( + node.func.attr in ADVANCED_EXPR_PTO_CALLS + or node.func.attr in ADVANCED_TOPLEVEL_PTO_CALLS + ): + if self.advanced_enabled: + return + raise self.source_info.error( + node, + advanced_mode_message(node.func.attr), + ) if node.func.value.id == "pto" and node.func.attr in DEFERRED_PTO_SURFACES: raise self.source_info.error( node, @@ -263,17 +273,12 @@ def _freeze_dtypes(dtypes: Any) -> tuple[tuple[Any, ...], ...]: raise TypeError("each dtypes entry must be a signature tuple") frozen_signature = tuple(signature) for dtype in frozen_signature: - _reject_unsupported_dtype_feature(dtype) + _validate_dtype_pattern(dtype) frozen_signatures.append(frozen_signature) if not frozen_signatures: raise ValueError("dtypes must contain at least one signature tuple") - if len(frozen_signatures) != 1: - raise ValueError( - unsupported_feature_message("multiple dtypes signatures") - ) - return tuple(frozen_signatures) @@ -288,11 +293,20 @@ class BoundKernelParameter: @property def element_dtype(self) -> ScalarType | None: - if self.kind in ("tensorview", "tile"): + if self.kind in ("tensorview", "tile", "ptr"): return self.dtype return None +@dataclass(frozen=True) +class KernelParameterSpec: + """One validated Python function parameter before dtype selection.""" + + name: str + kind: str + annotation: Any + + @dataclass(frozen=True) class VKernelDescriptor: """Descriptor returned by `@tilelang_dsl.vkernel`.""" @@ -303,10 +317,14 @@ class VKernelDescriptor: name: str verify_enabled: bool advanced_enabled: bool - parameters: tuple[BoundKernelParameter, ...] + _parameter_specs: tuple[KernelParameterSpec, ...] _py_fn: Callable[..., Any] = field(repr=False) _source_info: _FunctionSourceInfo | None = field(repr=False, compare=False, default=None) specializations: tuple[tuple[str, TileSpecialization], ...] = () + constraints: tuple[Callable[[Mapping[str, Any]], Any], ...] = field(default=(), repr=False) + priority: int = 0 + _selected_dtype_signature: tuple[ScalarType, ...] | None = None + _parameters: tuple[BoundKernelParameter, ...] | None = field(default=None, repr=False) @property def py_fn(self) -> Callable[..., Any]: @@ -314,7 +332,21 @@ def py_fn(self) -> Callable[..., Any]: @property def dtype_signature(self) -> tuple[ScalarType, ...]: - return self.dtypes[0] + if self._selected_dtype_signature is None: + raise ValueError( + "descriptor requires pto.select_kernel(...) to choose a concrete dtype signature " + "before materialization" + ) + return self._selected_dtype_signature + + @property + def parameters(self) -> tuple[BoundKernelParameter, ...]: + if self._parameters is None: + raise ValueError( + "descriptor requires pto.select_kernel(...) to bind concrete parameter dtypes " + "before materialization" + ) + return self._parameters @property def metadata(self) -> dict[str, Any]: @@ -325,6 +357,8 @@ def metadata(self) -> dict[str, Any]: "name": self.name, "verify": self.verify_enabled, "advanced": self.advanced_enabled, + "constraints": self.constraints, + "priority": self.priority, } @property @@ -335,9 +369,34 @@ def tile_parameters(self) -> tuple[BoundKernelParameter, ...]: def specializations_by_name(self) -> dict[str, TileSpecialization]: return dict(self.specializations) + def _tile_parameter_names(self) -> tuple[str, ...]: + return tuple(param.name for param in self._parameter_specs if param.kind == "tile") + + def _bind_selected_dtype_signature( + self, + dtype_signature: tuple[ScalarType, ...], + ) -> "VKernelDescriptor": + bound_parameters = _bind_parameters(self._parameter_specs, dtype_signature) + return VKernelDescriptor( + target=self.target, + op=self.op, + dtypes=self.dtypes, + name=self.name, + verify_enabled=self.verify_enabled, + advanced_enabled=self.advanced_enabled, + _parameter_specs=self._parameter_specs, + _py_fn=self._py_fn, + _source_info=self._source_info, + specializations=self.specializations, + constraints=self.constraints, + priority=self.priority, + _selected_dtype_signature=dtype_signature, + _parameters=bound_parameters, + ) + def specialize(self, **bindings: Any) -> "VKernelDescriptor": - tile_params = {param.name: param for param in self.tile_parameters} - if not tile_params: + tile_param_names = set(self._tile_parameter_names()) + if not tile_param_names: if bindings: unknown = ", ".join(sorted(bindings)) raise TypeError( @@ -345,7 +404,7 @@ def specialize(self, **bindings: Any) -> "VKernelDescriptor": ) return self - unknown = sorted(set(bindings) - set(tile_params)) + unknown = sorted(set(bindings) - tile_param_names) if unknown: unknown_names = ", ".join(unknown) raise TypeError( @@ -363,14 +422,18 @@ def specialize(self, **bindings: Any) -> "VKernelDescriptor": name=self.name, verify_enabled=self.verify_enabled, advanced_enabled=self.advanced_enabled, - parameters=self.parameters, + _parameter_specs=self._parameter_specs, _source_info=self._source_info, specializations=tuple(sorted(updated.items())), + constraints=self.constraints, + priority=self.priority, + _selected_dtype_signature=self._selected_dtype_signature, + _parameters=self._parameters, _py_fn=self._py_fn, ) def _require_specialized_tiles(self, api_name: str) -> None: - tile_names = [param.name for param in self.tile_parameters] + tile_names = list(self._tile_parameter_names()) if not tile_names: return @@ -386,6 +449,7 @@ def _require_specialized_tiles(self, api_name: str) -> None: ) def _build_authoring_module(self): + self.parameters frontend_kernel = build_frontend_kernel_node(self) semantic_kernel = analyze_frontend_kernel(frontend_kernel) return lower_semantic_kernel(semantic_kernel) @@ -408,6 +472,34 @@ def emit(self, path: str | Path) -> None: output_path.write_text(self.mlir_text(), encoding="utf-8") +class KernelRegistry: + """Explicit registry for TileLang kernel descriptors.""" + + def __init__(self, descriptors: tuple[VKernelDescriptor, ...] = ()): + self._descriptors: list[VKernelDescriptor] = [] + for descriptor in descriptors: + self.register(descriptor) + + def register(self, descriptor: VKernelDescriptor) -> VKernelDescriptor: + if not isinstance(descriptor, VKernelDescriptor): + raise TypeError("KernelRegistry.register() expects a VKernelDescriptor") + self._descriptors.append(descriptor) + return descriptor + + @property + def descriptors(self) -> tuple[VKernelDescriptor, ...]: + return tuple(self._descriptors) + + def __iter__(self): + return iter(self._descriptors) + + def __len__(self) -> int: + return len(self._descriptors) + + +_DEFAULT_KERNEL_REGISTRY = KernelRegistry() + + @dataclass(frozen=True) class MaterializedMLIRModule: text: str @@ -616,6 +708,28 @@ def _validate_advanced(advanced: Any) -> bool: return advanced +def _validate_constraints(constraints: Any) -> tuple[Callable[[Mapping[str, Any]], Any], ...]: + if constraints is _UNSET: + return () + if not isinstance(constraints, (list, tuple)): + raise TypeError("constraints must be a sequence of predicate callables") + + frozen_constraints = [] + for index, constraint in enumerate(constraints): + if not callable(constraint): + raise TypeError(f"constraints[{index}] must be callable") + frozen_constraints.append(constraint) + return tuple(frozen_constraints) + + +def _validate_priority(priority: Any) -> int: + if priority is _UNSET: + return 0 + if isinstance(priority, bool) or not isinstance(priority, int): + raise TypeError("priority must be an int") + return priority + + def _coerce_memory_space(value: Any, param_name: str) -> MemorySpace: if isinstance(value, MemorySpace): return value @@ -722,9 +836,65 @@ def _validate_scalar_dtype(dtype: Any, param_name: str) -> ScalarType: return dtype -def _bind_parameter( - param: inspect.Parameter, dtype: Any -) -> BoundKernelParameter: +def _freeze_operand_types(operand_types: Any) -> tuple[ScalarType, ...]: + if not isinstance(operand_types, (list, tuple)): + raise TypeError("operand_types must be a sequence of TileLang scalar dtypes") + return tuple(_validate_scalar_dtype(dtype, f"operand_types[{index}]") for index, dtype in enumerate(operand_types)) + + +def _matches_wildcard(pattern: WildcardType, actual: ScalarType) -> bool: + if pattern.name == "AnyType": + return True + if pattern.name == "AnyFloat": + return actual.name in {"f16", "bf16", "f32"} + if pattern.name == "AnyInt": + return actual.name.startswith("i") + if pattern.name == "AnyMask": + return actual.name == "i1" + raise TypeError(f"unsupported wildcard matcher {pattern.name!r}") + + +def _match_dtype_signature( + dtype_signature: tuple[Any, ...], + operand_types: tuple[ScalarType, ...], +) -> tuple[ScalarType, ...] | None: + if len(dtype_signature) != len(operand_types): + return None + + typevar_bindings: dict[str, ScalarType] = {} + for pattern, actual in zip(dtype_signature, operand_types): + if isinstance(pattern, ScalarType): + if pattern != actual: + return None + continue + if isinstance(pattern, WildcardType): + if not _matches_wildcard(pattern, actual): + return None + continue + if isinstance(pattern, TypeVariable): + bound = typevar_bindings.get(pattern.name) + if bound is None: + typevar_bindings[pattern.name] = actual + continue + if bound != actual: + return None + continue + raise TypeError(f"unsupported dtype pattern {pattern!r}") + return operand_types + + +def _match_descriptor_dtype_signature( + descriptor: VKernelDescriptor, + operand_types: tuple[ScalarType, ...], +) -> tuple[ScalarType, ...] | None: + for dtype_signature in descriptor.dtypes: + matched = _match_dtype_signature(dtype_signature, operand_types) + if matched is not None: + return matched + return None + + +def _validate_parameter_spec(param: inspect.Parameter) -> KernelParameterSpec: if param.kind not in ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, @@ -742,33 +912,29 @@ def _bind_parameter( ) annotation = param.annotation - scalar_dtype = _validate_scalar_dtype(dtype, param.name) - if annotation is TensorView: - return BoundKernelParameter( + return KernelParameterSpec( name=param.name, kind="tensorview", annotation=annotation, - dtype=scalar_dtype, ) if annotation is Tile: - return BoundKernelParameter( + return KernelParameterSpec( name=param.name, kind="tile", annotation=annotation, - dtype=scalar_dtype, + ) + if isinstance(annotation, PointerType): + return KernelParameterSpec( + name=param.name, + kind="ptr", + annotation=annotation, ) if isinstance(annotation, ScalarType): - if annotation != scalar_dtype: - raise TypeError( - f"scalar parameter '{param.name}' annotation {annotation!r} " - f"does not match dtypes entry {scalar_dtype!r}" - ) - return BoundKernelParameter( + return KernelParameterSpec( name=param.name, kind="scalar", annotation=annotation, - dtype=scalar_dtype, ) raise TypeError( @@ -776,26 +942,77 @@ def _bind_parameter( ) -def _bind_parameters( - py_fn: Callable[..., Any], dtypes: tuple[tuple[Any, ...], ...] -) -> tuple[BoundKernelParameter, ...]: - if len(dtypes) != 1: - raise ValueError( - "TileLang DSL v1 requires dtypes to contain exactly one monomorphic signature tuple" +def _collect_parameter_specs(py_fn: Callable[..., Any]) -> tuple[KernelParameterSpec, ...]: + signature = inspect.signature(py_fn) + return tuple(_validate_parameter_spec(param) for param in signature.parameters.values()) + + +def _validate_dtype_arity( + parameter_specs: tuple[KernelParameterSpec, ...], + dtypes: tuple[tuple[Any, ...], ...], +) -> None: + for dtype_signature in dtypes: + if len(dtype_signature) != len(parameter_specs): + raise ValueError( + "each dtypes signature must match the decorated function parameter count" + ) + + +def _bind_parameter( + param_spec: KernelParameterSpec, + dtype: Any, +) -> BoundKernelParameter: + scalar_dtype = _validate_scalar_dtype(dtype, param_spec.name) + if param_spec.kind == "tensorview": + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=scalar_dtype, + ) + if param_spec.kind == "tile": + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=scalar_dtype, ) + if param_spec.kind == "ptr": + if param_spec.annotation.element_dtype != scalar_dtype: + raise TypeError( + f"pointer parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {scalar_dtype!r}" + ) + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=scalar_dtype, + ) + if param_spec.annotation != scalar_dtype: + raise TypeError( + f"scalar parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {scalar_dtype!r}" + ) + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=scalar_dtype, + ) - signature = inspect.signature(py_fn) - params = tuple(signature.parameters.values()) - dtype_signature = dtypes[0] - if len(dtype_signature) != len(params): +def _bind_parameters( + parameter_specs: tuple[KernelParameterSpec, ...], + dtype_signature: tuple[ScalarType, ...], +) -> tuple[BoundKernelParameter, ...]: + if len(dtype_signature) != len(parameter_specs): raise ValueError( - "single dtypes signature must match the decorated function parameter count" + "selected dtype signature must match the decorated function parameter count" ) - return tuple( - _bind_parameter(param, dtype) - for param, dtype in zip(params, dtype_signature) + _bind_parameter(param_spec, dtype) + for param_spec, dtype in zip(parameter_specs, dtype_signature) ) @@ -808,6 +1025,8 @@ def _build_descriptor( name: Any, verify: Any, advanced: Any, + constraints: Any, + priority: Any, ) -> VKernelDescriptor: if not callable(py_fn): raise TypeError("@vkernel can only decorate callables") @@ -816,6 +1035,14 @@ def _build_descriptor( advanced_enabled = _validate_advanced(advanced) _validate_function_body(source_info, advanced_enabled=advanced_enabled) frozen_dtypes = _freeze_dtypes(dtypes) + parameter_specs = _collect_parameter_specs(py_fn) + _validate_dtype_arity(parameter_specs, frozen_dtypes) + + selected_dtype_signature: tuple[ScalarType, ...] | None = None + bound_parameters: tuple[BoundKernelParameter, ...] | None = None + if len(frozen_dtypes) == 1 and all(isinstance(dtype, ScalarType) for dtype in frozen_dtypes[0]): + selected_dtype_signature = tuple(frozen_dtypes[0]) + bound_parameters = _bind_parameters(parameter_specs, selected_dtype_signature) return VKernelDescriptor( target=_validate_target(target), @@ -824,12 +1051,107 @@ def _build_descriptor( name=_validate_name(py_fn, name), verify_enabled=_validate_verify(verify), advanced_enabled=advanced_enabled, - parameters=_bind_parameters(py_fn, frozen_dtypes), + _parameter_specs=parameter_specs, _py_fn=py_fn, _source_info=source_info, + constraints=_validate_constraints(constraints), + priority=_validate_priority(priority), + _selected_dtype_signature=selected_dtype_signature, + _parameters=bound_parameters, ) +def _evaluate_constraints( + descriptor: VKernelDescriptor, + context_attrs: Mapping[str, Any], +) -> bool: + for index, constraint in enumerate(descriptor.constraints): + try: + result = constraint(context_attrs) + except Exception as exc: + raise TypeError( + f"constraint {index} for kernel {descriptor.name!r} raised {type(exc).__name__}: {exc}" + ) from exc + if not result: + return False + return True + + +def _format_descriptor_identity(descriptor: VKernelDescriptor) -> str: + dtype_signature = descriptor._selected_dtype_signature + if dtype_signature is None: + dtype_signature = tuple("?" for _ in descriptor.dtypes[0]) if descriptor.dtypes else () + return f"{descriptor.name}(priority={descriptor.priority}, dtypes={dtype_signature!r})" + + +def select_kernel( + target: str, + op: str, + operand_types: Any, + context_attrs: Mapping[str, Any] | None = None, + registry: KernelRegistry | None = None, +) -> VKernelDescriptor: + """Select one registered kernel descriptor for the given query.""" + + normalized_target = _validate_target(target) + normalized_op = _validate_op(op) + normalized_operand_types = _freeze_operand_types(operand_types) + + if context_attrs is None: + normalized_context_attrs: dict[str, Any] = {} + elif isinstance(context_attrs, Mapping): + normalized_context_attrs = dict(context_attrs) + else: + raise TypeError("context_attrs must be a mapping or None") + + active_registry = _DEFAULT_KERNEL_REGISTRY if registry is None else registry + if not isinstance(active_registry, KernelRegistry): + raise TypeError("registry must be a KernelRegistry or None") + + type_matched_candidates = [ + descriptor._bind_selected_dtype_signature(matched_signature) + if descriptor._selected_dtype_signature != matched_signature + else descriptor + for descriptor in active_registry + if descriptor.target == normalized_target + and descriptor.op == normalized_op + for matched_signature in (_match_descriptor_dtype_signature(descriptor, normalized_operand_types),) + if matched_signature is not None + ] + + if not type_matched_candidates: + raise LookupError( + "select_kernel() found no registered kernel for " + f"target={normalized_target!r}, op={normalized_op!r}, operand_types={normalized_operand_types!r}" + ) + + constrained_candidates = [ + descriptor + for descriptor in type_matched_candidates + if _evaluate_constraints(descriptor, normalized_context_attrs) + ] + if not constrained_candidates: + raise LookupError( + "select_kernel() found no registered kernel after constraint evaluation for " + f"target={normalized_target!r}, op={normalized_op!r}, operand_types={normalized_operand_types!r}" + ) + + highest_priority = max(descriptor.priority for descriptor in constrained_candidates) + winners = [ + descriptor + for descriptor in constrained_candidates + if descriptor.priority == highest_priority + ] + if len(winners) > 1: + winner_set = ", ".join(sorted(_format_descriptor_identity(descriptor) for descriptor in winners)) + raise LookupError( + "select_kernel() found multiple highest-priority kernels for " + f"target={normalized_target!r}, op={normalized_op!r}, operand_types={normalized_operand_types!r}: " + f"{winner_set}" + ) + return winners[0] + + def vkernel( py_fn: Callable[..., Any] | None = None, *, @@ -845,13 +1167,12 @@ def vkernel( """Create a TileLang DSL v1 kernel descriptor. v1 keeps only the minimal descriptor metadata surface: - `target`, `op`, `dtypes`, `name`, `verify`, and opt-in `advanced`. + `target`, `op`, `dtypes`, `constraints`, `priority`, `name`, `verify`, + and opt-in `advanced`. """ - _reject_unsupported_decorator_feature("constraints", constraints) - _reject_unsupported_decorator_feature("priority", priority) def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: - return _build_descriptor( + descriptor = _build_descriptor( fn, target=target, op=op, @@ -859,7 +1180,10 @@ def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: name=name, verify=verify, advanced=advanced, + constraints=constraints, + priority=priority, ) + return _DEFAULT_KERNEL_REGISTRY.register(descriptor) if py_fn is None: return wrap @@ -868,8 +1192,10 @@ def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: __all__ = [ "BoundKernelParameter", + "KernelRegistry", "MaterializedMLIRModule", "TileLangFrontendError", "VKernelDescriptor", + "select_kernel", "vkernel", ] diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index d9037c3e5..e75e84094 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -11,6 +11,7 @@ SemanticBinaryExpr, SemanticBindingRef, SemanticCallExpr, + SemanticDmaConfigStmt, SemanticDmaLoadStmt, SemanticDmaStoreStmt, SemanticExpr, @@ -21,10 +22,11 @@ SemanticIfResult, SemanticKernel, SemanticLiteralExpr, + SemanticLowLevelCopyStmt, SemanticMaskType, SemanticMetaType, - SemanticBindingRef, SemanticPipeBarrierStmt, + SemanticPtrType, SemanticReturnStmt, SemanticScalarType, SemanticSetFlagStmt, @@ -159,6 +161,10 @@ def _render_stmt( ] if isinstance(stmt, SemanticPipeBarrierStmt): return [self._indent(indent) + f"pto.barrier #pto.pipe<{stmt.pipe}>"] + if isinstance(stmt, SemanticDmaConfigStmt): + return self._render_dma_config(stmt, env, indent=indent) + if isinstance(stmt, SemanticLowLevelCopyStmt): + return self._render_low_level_copy(stmt, env, indent=indent) if isinstance(stmt, SemanticReturnStmt): if stmt.value is None: return [self._indent(indent) + "return"] @@ -174,6 +180,53 @@ def _render_stmt( return self._render_if(stmt, env, indent=indent) raise ValueError(f"unsupported semantic statement {type(stmt).__name__}") + def _render_dma_config( + self, + stmt: SemanticDmaConfigStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + first = self._lower_to_i64(stmt.first, env, indent=indent, into=lines) + second = self._lower_to_i64(stmt.second, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.{stmt.name} {first.name}, {second.name} : i64, i64" + ) + return lines + + def _render_low_level_copy( + self, + stmt: SemanticLowLevelCopyStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + source = self._lower_expr(stmt.source, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + + rendered_operands = [] + rendered_types = [] + for index, operand in enumerate(stmt.operands): + if stmt.name == "copy_gm_to_ubuf" and index == 5: + lowered = self._lower_to_i1(operand, env, indent=indent, into=lines) + else: + lowered = self._lower_to_i64(operand, env, indent=indent, into=lines) + rendered_operands.append(lowered.name) + rendered_types.append(self._render_type(lowered.type)) + + operand_text = ", ".join([source.name, destination.name, *rendered_operands]) + type_text = ", ".join( + [self._render_type(source.type), self._render_type(destination.type), *rendered_types] + ) + lines.append( + self._indent(indent) + + f"pto.{stmt.name} {operand_text} : {type_text}" + ) + return lines + def _render_assign( self, stmt: SemanticAssignStmt, @@ -233,32 +286,91 @@ def _render_multi_result_assign( ) -> list[str]: if not isinstance(stmt.value, SemanticCallExpr): raise NotImplementedError("multi-result assignment expects a call expression in TileLang DSL v1") - if stmt.value.namespace != "pto" or stmt.value.name != "make_mask": + if stmt.value.namespace != "pto": raise NotImplementedError( f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" ) if len(stmt.targets) != 2: - raise NotImplementedError("tail make_mask lowering expects exactly two assignment targets") + raise NotImplementedError("multi-result lowering expects exactly two assignment targets") if not isinstance(stmt.value.type, SemanticTupleType) or len(stmt.value.type.elements) != 2: - raise NotImplementedError("tail make_mask lowering expects a two-result tuple type") + raise NotImplementedError("multi-result lowering expects a two-result tuple type") - dtype_expr, remaining_expr = stmt.value.args - if not self._is_dtype_meta_expr(dtype_expr): - raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") + if stmt.value.name == "make_mask": + dtype_expr, remaining_expr = stmt.value.args + if not self._is_dtype_meta_expr(dtype_expr): + raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") - lines: list[str] = [] - remaining = self._lower_remaining_to_i32(remaining_expr, env, indent=indent, into=lines) - mask_target, remaining_target = stmt.targets - mask_type, remaining_type = stmt.value.type.elements - suffix = self._mask_suffix(mask_type) - lines.append( - self._indent(indent) - + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = pto.plt_{suffix} {remaining.name} : " - + f"i32 -> {self._render_type(mask_type)}, {self._render_type(remaining_type)}" + lines: list[str] = [] + remaining = self._lower_remaining_to_i32(remaining_expr, env, indent=indent, into=lines) + mask_target, remaining_target = stmt.targets + mask_type, remaining_type = stmt.value.type.elements + suffix = self._mask_suffix(mask_type) + lines.append( + self._indent(indent) + + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = pto.plt_{suffix} {remaining.name} : " + + f"i32 -> {self._render_type(mask_type)}, {self._render_type(remaining_type)}" + ) + env[mask_target.name] = _RenderedValue(name=mask_target.ssa_name, type=mask_type) + env[remaining_target.name] = _RenderedValue(name=remaining_target.ssa_name, type=remaining_type) + return lines + + if stmt.value.name in {"vaddc", "vsubc"}: + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + mask = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + result_target, carry_target = stmt.targets + result_type, carry_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{result_target.ssa_name}, {carry_target.ssa_name} = pto.{stmt.value.name} " + + f"{lhs.name}, {rhs.name}, {mask.name} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(result_type)}, {self._render_type(carry_type)}" + ) + env[result_target.name] = _RenderedValue(name=result_target.ssa_name, type=result_type) + env[carry_target.name] = _RenderedValue(name=carry_target.ssa_name, type=carry_type) + return lines + + if stmt.value.name in {"vaddcs", "vsubcs"}: + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + carry_in = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + mask = self._lower_expr(stmt.value.args[3], env, indent=indent, into=lines) + result_target, carry_target = stmt.targets + result_type, carry_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{result_target.ssa_name}, {carry_target.ssa_name} = pto.{stmt.value.name} " + + f"{lhs.name}, {rhs.name}, {carry_in.name}, {mask.name} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, " + + f"{self._render_type(carry_in.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(result_type)}, {self._render_type(carry_type)}" + ) + env[result_target.name] = _RenderedValue(name=result_target.ssa_name, type=result_type) + env[carry_target.name] = _RenderedValue(name=carry_target.ssa_name, type=carry_type) + return lines + + if stmt.value.name in {"vintlv", "vdintlv"}: + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + low_target, high_target = stmt.targets + low_type, high_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{low_target.ssa_name}, {high_target.ssa_name} = pto.{stmt.value.name} " + + f"{lhs.name}, {rhs.name} : {self._render_type(lhs.type)}, {self._render_type(rhs.type)} " + + f"-> {self._render_type(low_type)}, {self._render_type(high_type)}" + ) + env[low_target.name] = _RenderedValue(name=low_target.ssa_name, type=low_type) + env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) + return lines + + raise NotImplementedError( + f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" ) - env[mask_target.name] = _RenderedValue(name=mask_target.ssa_name, type=mask_type) - env[remaining_target.name] = _RenderedValue(name=remaining_target.ssa_name, type=remaining_type) - return lines def _render_dma_load( self, @@ -655,6 +767,118 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "castptr": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + if isinstance(expr.type, SemanticPtrType) and isinstance(value.type, SemanticIndexType): + value = self._coerce_rendered_to_i64(value, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.castptr {value.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "addptr": + pointer = self._lower_expr(expr.args[0], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.addptr {pointer.name}, {offset.name} : " + + f"{self._render_type(pointer.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"ppack", "punpack"}: + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {value.name}, {part} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "pnot": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.pnot {value.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "psel": + src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.psel {src0.name}, {src1.name}, {mask.name} : " + + f"{self._render_type(src0.type)}, {self._render_type(src1.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vcmp": + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + seed = self._lower_expr(expr.args[2], env, indent=indent, into=into) + cmp_mode = self._render_string_literal(expr.args[3]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vcmp {lhs.name}, {rhs.name}, {seed.name}, {cmp_mode} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(seed.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vcmps": + vector = self._lower_expr(expr.args[0], env, indent=indent, into=into) + scalar = self._lower_expr(expr.args[1], env, indent=indent, into=into) + seed = self._lower_expr(expr.args[2], env, indent=indent, into=into) + cmp_mode = self._render_string_literal(expr.args[3]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vcmps {vector.name}, {scalar.name}, {seed.name}, {cmp_mode} : " + + f"{self._render_type(vector.type)}, {self._render_type(scalar.type)}, {self._render_type(seed.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vsel": + src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vsel {src0.name}, {src1.name}, {mask.name} : " + + f"{self._render_type(src0.type)}, {self._render_type(src1.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vselr", "vselrv2"}: + src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {src0.name}, {src1.name} : " + + f"{self._render_type(src0.type)}, {self._render_type(src1.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vintlvv2", "vdintlvv2"}: + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {lhs.name}, {rhs.name}, {part} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name in {"vabs", "vrelu", "vexp", "vnot"}: value = self._lower_expr(expr.args[0], env, indent=indent, into=into) mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) @@ -691,6 +915,57 @@ def _lower_call_expr( raise NotImplementedError(f"unsupported pto call `{expr.name}` in lowering") + def _render_string_literal(self, expr: SemanticExpr) -> str: + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.value, str): + escaped = expr.value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + if isinstance(expr, SemanticBindingRef) and isinstance(expr.binding.value, str): + escaped = expr.binding.value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + raise NotImplementedError("expected a string literal for TileLang DSL advanced-family lowering") + + def _lower_to_i1( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i1": + return value + raise NotImplementedError("expected an i1 operand during TileLang DSL v1 lowering") + + def _lower_to_i64( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + return self._coerce_rendered_to_i64(value, indent=indent, into=into) + + def _coerce_rendered_to_i64( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i64": + return value + if isinstance(value.type, SemanticIndexType): + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.index_castui {value.name} : index to i64" + ) + return _RenderedValue(name=cast_name, type=_I64_TYPE) + raise NotImplementedError("expected an i64 or index operand during TileLang DSL v1 lowering") + def _lower_remaining_to_i32( self, expr: SemanticExpr, @@ -847,6 +1122,8 @@ def _render_type(self, ty: SemanticType) -> str: return "index" if isinstance(ty, SemanticScalarType): return ty.dtype.name + if isinstance(ty, SemanticPtrType): + return f"!pto.ptr<{ty.element_dtype.name}, {ty.memory_space}>" if isinstance(ty, SemanticTensorViewType): return f"!pto.ptr<{ty.element_dtype.name}, gm>" if isinstance(ty, SemanticTileType): diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 6075c8d81..497b1557c 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -34,7 +34,22 @@ deferred_surface_message, unsupported_feature_message, ) -from .types import Event, MaskPattern, Pipe, ScalarType, bf16, f16, f32, i1, i8, i16, i32 +from .types import ( + Event, + MaskPattern, + MemorySpace, + Pipe, + PointerType, + ScalarType, + bf16, + f16, + f32, + i1, + i8, + i16, + i32, + i64, +) _DTYPE_SYMBOLS = { @@ -42,6 +57,7 @@ "i8": i8, "i16": i16, "i32": i32, + "i64": i64, "f16": f16, "bf16": bf16, "f32": f32, @@ -49,9 +65,33 @@ _PATTERN_SYMBOLS = {pattern.name: pattern for pattern in MaskPattern} _PIPE_SYMBOLS = {pipe.name: pipe for pipe in Pipe} _EVENT_SYMBOLS = {event.name: event for event in Event} +_MEMORY_SPACE_SYMBOLS = {memory_space.name: memory_space for memory_space in MemorySpace} _UNARY_VECTOR_OPS = {"vabs", "vrelu", "vexp", "vnot"} _BINARY_VECTOR_OPS = {"vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", "vand", "vor", "vxor"} _VECTOR_SCALAR_OPS = {"vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins"} +_LOW_LEVEL_DMA_CONFIG_OPS = { + "set_loop2_stride_outtoub", + "set_loop1_stride_outtoub", + "set_loop_size_outtoub", + "set_loop2_stride_ubtoout", + "set_loop1_stride_ubtoout", + "set_loop_size_ubtoout", +} +_LOW_LEVEL_DMA_COPY_OPS = { + "copy_gm_to_ubuf", + "copy_ubuf_to_gm", + "copy_ubuf_to_ubuf", +} +_COMPARE_SELECT_OPS = {"vcmp", "vcmps", "vsel", "vselr", "vselrv2"} +_PREDICATE_MOVEMENT_OPS = {"pnot", "psel", "ppack", "punpack"} +_CARRY_OPS = {"vaddc", "vsubc", "vaddcs", "vsubcs"} +_REARRANGEMENT_OPS = {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"} +_ADVANCED_VECTOR_ACTIVITY_OPS = ( + _COMPARE_SELECT_OPS + | _PREDICATE_MOVEMENT_OPS + | _CARRY_OPS + | _REARRANGEMENT_OPS +) class SemanticType: @@ -84,6 +124,12 @@ class SemanticScalarType(SemanticType): dtype: ScalarType +@dataclass(frozen=True) +class SemanticPtrType(SemanticType): + element_dtype: ScalarType + memory_space: str + + @dataclass(frozen=True) class SemanticIndexType(SemanticType): pass @@ -275,6 +321,21 @@ class SemanticPipeBarrierStmt(SemanticStmt): pipe: str +@dataclass(frozen=True) +class SemanticDmaConfigStmt(SemanticStmt): + name: str + first: SemanticExpr + second: SemanticExpr + + +@dataclass(frozen=True) +class SemanticLowLevelCopyStmt(SemanticStmt): + name: str + source: SemanticExpr + destination: SemanticExpr + operands: tuple[SemanticExpr, ...] + + @dataclass(frozen=True) class SemanticIfResult: result_binding: SemanticBinding @@ -395,39 +456,6 @@ def _analyze_kernel_body( self, env: dict[str, SemanticBinding], ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: - if not self.node.advanced_enabled: - return self._analyze_block(self.node.body, env, allow_outer_lookup=True) - - body_without_return = self.node.body - trailing_return: FrontendReturnStmt | None = None - if body_without_return and isinstance(body_without_return[-1], FrontendReturnStmt): - trailing_return = body_without_return[-1] - body_without_return = body_without_return[:-1] - - if body_without_return and self._can_wrap_whole_kernel_vecscope(body_without_return): - self._disable_inference_depth += 1 - try: - scoped_body, scoped_env = self._analyze_block_without_inference( - body_without_return, - env, - allow_outer_lookup=True, - ) - finally: - self._disable_inference_depth -= 1 - semantic_body: list[SemanticStmt] = [] - if self._semantic_block_contains_vector_activity(scoped_body): - semantic_body.append(SemanticVecscopeStmt(body=scoped_body)) - else: - semantic_body.extend(scoped_body) - if trailing_return is not None: - return_stmt, scoped_env = self._analyze_stmt( - trailing_return, - scoped_env, - allow_outer_lookup=True, - ) - semantic_body.append(return_stmt) - return tuple(semantic_body), scoped_env - return self._analyze_block(self.node.body, env, allow_outer_lookup=True) def _parameter_type(self, param: Any) -> SemanticType: @@ -444,6 +472,12 @@ def _parameter_type(self, param: Any) -> SemanticType: shape=shape, memory_space=memory_space, ) + if param.kind == "ptr": + memory_space = param.annotation.memory_space.value + return SemanticPtrType( + element_dtype=param.dtype, + memory_space=memory_space, + ) if param.kind == "scalar": return SemanticScalarType(dtype=param.dtype) raise ValueError(f"unsupported parameter kind {param.kind!r}") @@ -548,11 +582,65 @@ def _should_infer_vecscope( return False if not self.node.advanced_enabled or not allow_outer_lookup: return False + if isinstance(stmt, FrontendForStmt): + return self._block_can_live_in_inferred_vecscope(stmt.body) name = self._frontend_vector_call_name(stmt) - return name in {"make_mask", "vlds", "vsts"} | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS + return name in ( + {"make_mask", "vlds", "vsts"} + | _UNARY_VECTOR_OPS + | _BINARY_VECTOR_OPS + | _VECTOR_SCALAR_OPS + | _ADVANCED_VECTOR_ACTIVITY_OPS + ) + + def _block_can_live_in_inferred_vecscope( + self, + statements: tuple[FrontendStmtNode, ...], + ) -> bool: + saw_vector_activity = False + for stmt in statements: + if isinstance(stmt, FrontendStrictVecscopeStmt): + return False + if isinstance(stmt, FrontendIfStmt): + return False + if isinstance(stmt, FrontendExprStmt) and ( + self._is_dma_call(stmt.expr) or self._is_sync_call(stmt.expr) + ): + return False + if isinstance(stmt, FrontendForStmt): + if not self._block_can_live_in_inferred_vecscope(stmt.body): + return False + saw_vector_activity = True + continue + if self._frontend_stmt_contains_vector_activity(stmt): + saw_vector_activity = True + continue + return False + return saw_vector_activity + + def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> bool: + expr: FrontendExprNode | None = None + if isinstance(stmt, FrontendAssignStmt): + expr = stmt.value + elif isinstance(stmt, FrontendExprStmt): + expr = stmt.expr + if not isinstance(expr, FrontendCallExpr): + return False + return ( + expr.namespace == "pto" + and expr.name in ( + {"make_mask", "vlds", "vsts"} + | _UNARY_VECTOR_OPS + | _BINARY_VECTOR_OPS + | _VECTOR_SCALAR_OPS + | _ADVANCED_VECTOR_ACTIVITY_OPS + ) + ) def _run_contains_vector_op(self, statements: tuple[FrontendStmtNode, ...]) -> bool: for stmt in statements: + if isinstance(stmt, FrontendForStmt) and self._block_can_live_in_inferred_vecscope(stmt.body): + return True name = self._frontend_vector_call_name(stmt) if name is None or name == "make_mask": continue @@ -579,11 +667,15 @@ def _analyze_inferred_vecscope( *, allow_outer_lookup: bool, ) -> SemanticVecscopeStmt: - body, _ = self._analyze_block_without_inference( - statements, - env, - allow_outer_lookup=allow_outer_lookup, - ) + self._disable_inference_depth += 1 + try: + body, _ = self._analyze_block_without_inference( + statements, + env, + allow_outer_lookup=allow_outer_lookup, + ) + finally: + self._disable_inference_depth -= 1 return SemanticVecscopeStmt(body=body) def _analyze_block_without_inference( @@ -604,27 +696,6 @@ def _analyze_block_without_inference( semantic_statements.append(semantic_stmt) return tuple(semantic_statements), current_env - def _can_wrap_whole_kernel_vecscope( - self, - statements: tuple[FrontendStmtNode, ...], - ) -> bool: - for stmt in statements: - if isinstance(stmt, FrontendStrictVecscopeStmt): - return False - if isinstance(stmt, FrontendExprStmt) and ( - self._is_dma_call(stmt.expr) or self._is_sync_call(stmt.expr) - ): - return False - nested_blocks: tuple[tuple[FrontendStmtNode, ...], ...] = () - if isinstance(stmt, FrontendForStmt): - nested_blocks = (stmt.body,) - elif isinstance(stmt, FrontendIfStmt): - nested_blocks = (stmt.then_body, stmt.else_body) - for block in nested_blocks: - if not self._can_wrap_whole_kernel_vecscope(block): - return False - return True - def _semantic_block_contains_vector_activity( self, statements: tuple[SemanticStmt, ...], @@ -652,7 +723,11 @@ def _semantic_block_contains_vector_activity( def _expr_contains_vector_activity(self, expr: SemanticExpr) -> bool: if isinstance(expr, SemanticCallExpr): if expr.namespace == "pto" and expr.name in ( - {"make_mask", "vlds"} | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS + {"make_mask", "vlds"} + | _UNARY_VECTOR_OPS + | _BINARY_VECTOR_OPS + | _VECTOR_SCALAR_OPS + | _ADVANCED_VECTOR_ACTIVITY_OPS ): return True return any(self._expr_contains_vector_activity(arg) for arg in expr.args) @@ -693,6 +768,12 @@ def _analyze_stmt( return self._analyze_dma_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) if self._is_sync_call(stmt.expr): return self._analyze_sync_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_low_level_dma_call(stmt.expr): + return self._analyze_low_level_dma_stmt( + stmt.expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) if self._is_vector_store_call(stmt.expr): return self._analyze_vector_store_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) expr = self._analyze_expr(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) @@ -731,6 +812,13 @@ def _is_sync_call(self, expr: FrontendExprNode) -> bool: and expr.name in {"set_flag", "wait_flag", "pipe_barrier", "barrier"} ) + def _is_low_level_dma_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in _LOW_LEVEL_DMA_CONFIG_OPS | _LOW_LEVEL_DMA_COPY_OPS + ) + def _analyze_dma_stmt( self, expr: FrontendCallExpr, @@ -783,7 +871,7 @@ def _analyze_vector_store_stmt( raise TypeError("pto.vsts expects 3 or 4 positional arguments in TileLang DSL v1") value, destination, offset, mask = args self._require_vreg_expr(value, "pto.vsts value") - self._require_tile_expr(destination, "pto.vsts destination") + self._require_vector_pointer_expr(destination, "pto.vsts destination") self._require_index_typed_expr(offset) self._require_mask_for_vreg(mask, value.type, "pto.vsts") self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") @@ -824,6 +912,106 @@ def _analyze_sync_stmt( return SemanticPipeBarrierStmt(pipe=pipe), dict(env) raise ValueError(f"unsupported sync stmt pto.{expr.name}") + def _analyze_low_level_dma_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if expr.name in _LOW_LEVEL_DMA_CONFIG_OPS: + if len(args) != 2: + raise TypeError(f"pto.{expr.name} expects exactly 2 positional arguments in TileLang DSL") + self._require_i64_like_expr(args[0], f"pto.{expr.name} first operand") + self._require_i64_like_expr(args[1], f"pto.{expr.name} second operand") + return ( + SemanticDmaConfigStmt( + name=expr.name, + first=args[0], + second=args[1], + ), + dict(env), + ) + if expr.name == "copy_gm_to_ubuf": + if len(args) != 11: + raise TypeError("pto.copy_gm_to_ubuf expects exactly 11 positional arguments in TileLang DSL") + source = self._require_pointer_expr(args[0], "pto.copy_gm_to_ubuf source", memory_space="gm") + destination = self._require_pointer_expr(args[1], "pto.copy_gm_to_ubuf destination", memory_space="ub") + for operand, label in zip( + args[2:7] + args[8:], + ( + "sid", + "n_burst", + "len_burst", + "left_padding_count", + "right_padding_count", + "l2_cache_ctl", + "gm_stride", + "ub_stride", + ), + ): + self._require_i64_like_expr(operand, f"pto.copy_gm_to_ubuf {label}") + self._require_i1_expr(args[7], "pto.copy_gm_to_ubuf data_select_bit") + return ( + SemanticLowLevelCopyStmt( + name=expr.name, + source=source, + destination=destination, + operands=args[2:], + ), + dict(env), + ) + if expr.name == "copy_ubuf_to_gm": + if len(args) != 8: + raise TypeError("pto.copy_ubuf_to_gm expects exactly 8 positional arguments in TileLang DSL") + source = self._require_pointer_expr(args[0], "pto.copy_ubuf_to_gm source", memory_space="ub") + destination = self._require_pointer_expr(args[1], "pto.copy_ubuf_to_gm destination", memory_space="gm") + for operand, label in zip( + args[2:], + ( + "sid", + "n_burst", + "len_burst", + "reserved", + "burst_dst_stride", + "burst_src_stride", + ), + ): + self._require_i64_like_expr(operand, f"pto.copy_ubuf_to_gm {label}") + return ( + SemanticLowLevelCopyStmt( + name=expr.name, + source=source, + destination=destination, + operands=args[2:], + ), + dict(env), + ) + if expr.name == "copy_ubuf_to_ubuf": + if len(args) != 7: + raise TypeError("pto.copy_ubuf_to_ubuf expects exactly 7 positional arguments in TileLang DSL") + source = self._require_pointer_expr(args[0], "pto.copy_ubuf_to_ubuf source", memory_space="ub") + destination = self._require_pointer_expr(args[1], "pto.copy_ubuf_to_ubuf destination", memory_space="ub") + for operand, label in zip( + args[2:], + ("sid", "n_burst", "len_burst", "src_stride", "dst_stride"), + ): + self._require_i64_like_expr(operand, f"pto.copy_ubuf_to_ubuf {label}") + return ( + SemanticLowLevelCopyStmt( + name=expr.name, + source=source, + destination=destination, + operands=args[2:], + ), + dict(env), + ) + raise ValueError(f"unsupported low-level DMA stmt pto.{expr.name}") + def _require_tensor_slice( self, expr: SemanticExpr, @@ -844,6 +1032,24 @@ def _require_tile_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: raise TypeError(f"{context} currently only supports MemorySpace.UB Tile values in TileLang DSL v1") return expr + def _require_pointer_expr( + self, + expr: SemanticExpr, + context: str, + *, + memory_space: str | None = None, + ) -> SemanticExpr: + if not isinstance(expr.type, SemanticPtrType): + raise TypeError(f"{context} must be a pointer value in TileLang DSL") + if memory_space is not None and expr.type.memory_space != memory_space: + raise TypeError(f"{context} requires MemorySpace.{memory_space.upper()} pointers in TileLang DSL") + return expr + + def _require_vector_pointer_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if isinstance(expr.type, SemanticTileType): + return self._require_tile_expr(expr, context) + return self._require_pointer_expr(expr, context, memory_space="ub") + def _validate_dma_shape_match( self, tensor_slice_type: SemanticTensorSliceType, @@ -1090,7 +1296,7 @@ def _analyze_expr( return SemanticBindingRef(binding=binding, type=binding.type) if isinstance(expr, FrontendConstantExpr): if isinstance(expr.value, bool): - raise TypeError("bool constants are not supported in TileLang DSL v1 yet") + return SemanticLiteralExpr(value=expr.value, type=SemanticScalarType(dtype=i1)) if isinstance(expr.value, int): return SemanticLiteralExpr(value=expr.value, type=SemanticIndexType()) if isinstance(expr.value, str): @@ -1202,6 +1408,15 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=event, type=SemanticMetaType(kind="event"), ) + if expr.namespace in {"pto.MemorySpace"}: + memory_space = _MEMORY_SPACE_SYMBOLS.get(expr.name) + if memory_space is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=memory_space, + type=SemanticMetaType(kind="memory_space"), + ) raise TypeError( f"symbol `{expr.namespace}.{expr.name}` is not supported in TileLang DSL v1" ) @@ -1432,12 +1647,30 @@ def _analyze_call_expr( ) if name in DEFERRED_PTO_SURFACES: raise TypeError(deferred_surface_message(name)) + if name == "ptr": + return self._analyze_ptr_type(args) + if name == "castptr": + return self._analyze_castptr(args) + if name == "addptr": + return self._analyze_addptr(args) if name == "get_lanes": return self._analyze_get_lanes(args) if name == "make_mask": return self._analyze_make_mask(args) if name == "vlds": return self._analyze_vlds(args) + if name in {"ppack", "punpack"}: + return self._analyze_mask_part_op(name, args) + if name in {"pnot", "psel"}: + return self._analyze_mask_logic_op(name, args) + if name in {"vcmp", "vcmps"}: + return self._analyze_compare_op(name, args) + if name in {"vsel", "vselr", "vselrv2"}: + return self._analyze_select_op(name, args) + if name in {"vaddc", "vsubc", "vaddcs", "vsubcs"}: + return self._analyze_carry_op(name, args) + if name in {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"}: + return self._analyze_rearrangement_op(name, args) if name in _UNARY_VECTOR_OPS: return self._analyze_unary_vector_op(name, args) if name in _BINARY_VECTOR_OPS: @@ -1471,6 +1704,35 @@ def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: ), ) + def _analyze_ptr_type(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.ptr expects exactly 2 positional arguments in TileLang DSL") + dtype = self._require_dtype_symbol(args[0], "pto.ptr element type") + memory_space = self._require_memory_space_symbol(args[1], "pto.ptr memory space") + return SemanticLiteralExpr( + value=PointerType(element_dtype=dtype, memory_space=memory_space), + type=SemanticMetaType(kind="ptr_type"), + ) + + def _analyze_castptr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.castptr expects exactly 2 positional arguments in TileLang DSL") + value, target = args + target_type = self._require_cast_target_type(target) + if isinstance(target_type, SemanticPtrType): + self._require_castptr_input(value, target_type) + else: + self._require_pointer_expr(value, "pto.castptr input") + return SemanticCallExpr(namespace="pto", name="castptr", args=args, type=target_type) + + def _analyze_addptr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.addptr expects exactly 2 positional arguments in TileLang DSL") + pointer, offset = args + ptr = self._require_pointer_expr(pointer, "pto.addptr pointer") + self._require_index_typed_expr(offset) + return SemanticCallExpr(namespace="pto", name="addptr", args=(ptr, offset), type=ptr.type) + def _analyze_get_lanes(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 1: raise TypeError("pto.get_lanes expects exactly 1 positional argument in TileLang DSL v1") @@ -1481,13 +1743,16 @@ def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 2: raise TypeError("pto.vlds expects exactly 2 positional arguments in TileLang DSL v1") source, offset = args - tile = self._require_tile_expr(source, "pto.vlds source") + if isinstance(source_type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vlds source") + else: + source = self._require_pointer_expr(source, "pto.vlds source", memory_space="ub") self._require_index_typed_expr(offset) return SemanticCallExpr( namespace="pto", name="vlds", args=args, - type=self._vreg_type_for_dtype(tile.type.element_dtype), + type=self._vreg_type_for_dtype(source.type.element_dtype), ) def _analyze_unary_vector_op( @@ -1535,6 +1800,165 @@ def _analyze_vector_scalar_op( self._validate_vector_scalar_dtype(name, vreg.element_dtype) return SemanticCallExpr(namespace="pto", name=name, args=args, type=vreg) + def _analyze_mask_part_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") + mask = self._require_mask_expr(args[0], f"pto.{name} mask") + self._require_string_expr(args[1], f"pto.{name} part") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=mask) + + def _analyze_mask_logic_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "pnot": + if len(args) != 2: + raise TypeError("pto.pnot expects exactly 2 positional arguments in TileLang DSL") + value = self._require_mask_expr(args[0], "pto.pnot input") + mask = self._require_mask_expr(args[1], "pto.pnot mask") + self._require_matching_mask_types(value, mask, "pto.pnot") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=value) + if len(args) != 3: + raise TypeError("pto.psel expects exactly 3 positional arguments in TileLang DSL") + src0 = self._require_mask_expr(args[0], "pto.psel src0") + src1 = self._require_mask_expr(args[1], "pto.psel src1") + mask = self._require_mask_expr(args[2], "pto.psel mask") + self._require_matching_mask_types(src0, src1, "pto.psel") + self._require_matching_mask_types(src0, mask, "pto.psel") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + + def _analyze_compare_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "vcmp": + if len(args) != 4: + raise TypeError("pto.vcmp expects exactly 4 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], "pto.vcmp lhs") + rhs = self._require_vreg_expr(args[1], "pto.vcmp rhs") + if lhs != rhs: + raise TypeError("pto.vcmp requires lhs/rhs vector types to match") + seed = self._require_mask_expr(args[2], "pto.vcmp seed mask") + self._require_mask_for_vreg(args[2], lhs, "pto.vcmp") + self._require_string_expr(args[3], "pto.vcmp compare mode") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticMaskType(granularity=seed.granularity), + ) + + if len(args) != 4: + raise TypeError("pto.vcmps expects exactly 4 positional arguments in TileLang DSL") + vector = self._require_vreg_expr(args[0], "pto.vcmps vector") + scalar = self._require_scalar_expr(args[1], "pto.vcmps scalar") + if scalar.dtype != vector.element_dtype: + raise TypeError("pto.vcmps scalar dtype must match vector element dtype") + seed = self._require_mask_expr(args[2], "pto.vcmps seed mask") + self._require_mask_for_vreg(args[2], vector, "pto.vcmps") + self._require_string_expr(args[3], "pto.vcmps compare mode") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticMaskType(granularity=seed.granularity), + ) + + def _analyze_select_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "vsel": + if len(args) != 3: + raise TypeError("pto.vsel expects exactly 3 positional arguments in TileLang DSL") + src0 = self._require_vreg_expr(args[0], "pto.vsel src0") + src1 = self._require_vreg_expr(args[1], "pto.vsel src1") + if src0 != src1: + raise TypeError("pto.vsel requires src0/src1 vector types to match") + self._require_mask_for_vreg(args[2], src0, "pto.vsel") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") + src0 = self._require_vreg_expr(args[0], f"pto.{name} src0") + src1 = self._require_vreg_expr(args[1], f"pto.{name} src1") + if src0 != src1: + raise TypeError(f"pto.{name} requires src0/src1 vector types to match") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + + def _analyze_carry_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name in {"vaddc", "vsubc"}: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], f"pto.{name} lhs") + rhs = self._require_vreg_expr(args[1], f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + self._require_mask_for_vreg(args[2], lhs, f"pto.{name}") + carry_type = args[2].type + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, carry_type)), + ) + + if len(args) != 4: + raise TypeError(f"pto.{name} expects exactly 4 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], f"pto.{name} lhs") + rhs = self._require_vreg_expr(args[1], f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + carry_in = self._require_mask_expr(args[2], f"pto.{name} carry_in") + self._require_mask_for_vreg(args[3], lhs, f"pto.{name}") + carry_mask = self._require_mask_expr(args[3], f"pto.{name} mask") + self._require_matching_mask_types(carry_in, carry_mask, f"pto.{name}") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, carry_in)), + ) + + def _analyze_rearrangement_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name in {"vintlv", "vdintlv"}: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], f"pto.{name} lhs") + rhs = self._require_vreg_expr(args[1], f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, lhs)), + ) + + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], f"pto.{name} lhs") + rhs = self._require_vreg_expr(args[1], f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + self._require_string_expr(args[2], f"pto.{name} part") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) + def _require_dtype_symbol(self, expr: SemanticExpr, context: str) -> ScalarType: if not ( isinstance(expr, SemanticSymbolExpr) @@ -1551,6 +1975,79 @@ def _require_dtype_symbol(self, expr: SemanticExpr, context: str) -> ScalarType: raise TypeError(f"{context} must be a TileLang scalar dtype symbol in TileLang DSL v1") return expr.value + def _require_memory_space_symbol(self, expr: SemanticExpr, context: str) -> MemorySpace: + if ( + isinstance(expr, SemanticSymbolExpr) + and expr.type.kind == "memory_space" + and isinstance(expr.value, MemorySpace) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "memory_space" + and isinstance(expr.binding.value, MemorySpace) + ): + return expr.binding.value + raise TypeError(f"{context} must be a TileLang MemorySpace symbol") + + def _require_ptr_type_expr(self, expr: SemanticExpr, context: str) -> PointerType: + if ( + isinstance(expr, SemanticLiteralExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "ptr_type" + and isinstance(expr.value, PointerType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "ptr_type" + and isinstance(expr.binding.value, PointerType) + ): + return expr.binding.value + raise TypeError(f"{context} must be a pointer type constructed with pto.ptr(...)") + + def _require_cast_target_type(self, expr: SemanticExpr) -> SemanticType: + if self._is_i64_dtype_expr(expr): + return SemanticScalarType(dtype=i64) + ptr_type = self._require_ptr_type_expr(expr, "pto.castptr target type") + return SemanticPtrType( + element_dtype=ptr_type.element_dtype, + memory_space=ptr_type.memory_space.value, + ) + + def _require_castptr_input(self, expr: SemanticExpr, target_type: SemanticPtrType) -> None: + if isinstance(expr.type, SemanticIndexType): + return + if isinstance(expr.type, SemanticScalarType) and expr.type.dtype == i64: + return + if isinstance(expr.type, SemanticPtrType): + if expr.type.memory_space != target_type.memory_space: + raise TypeError("pto.castptr pointer-to-pointer casts must stay within one PTO memory space") + return + if isinstance(expr.type, SemanticTensorViewType): + if target_type.memory_space != "gm": + raise TypeError("pto.castptr TensorView casts require a GM pointer target") + return + if isinstance(expr.type, SemanticTileType): + tile_memory_space = expr.type.memory_space or "ub" + if tile_memory_space != target_type.memory_space: + raise TypeError("pto.castptr Tile casts must preserve the Tile memory space") + return + raise TypeError("pto.castptr input must be an index/i64, pointer, TensorView, or Tile value") + + def _is_i64_dtype_expr(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticSymbolExpr): + return expr.type.kind == "dtype" and expr.value == i64 + if isinstance(expr, SemanticBindingRef): + return ( + isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and expr.binding.value == i64 + ) + return False + def _require_vreg_expr(self, expr: SemanticExpr, context: str) -> SemanticVRegType: if not isinstance(expr.type, SemanticVRegType): raise TypeError(f"{context} must be a vector register value in TileLang DSL v1") @@ -1561,6 +2058,44 @@ def _require_scalar_expr(self, expr: SemanticExpr, context: str) -> SemanticScal raise TypeError(f"{context} must be a scalar value in TileLang DSL v1") return expr.type + def _require_mask_expr(self, expr: SemanticExpr, context: str) -> SemanticMaskType: + if not isinstance(expr.type, SemanticMaskType): + raise TypeError(f"{context} must be a mask value in TileLang DSL") + return expr.type + + def _require_matching_mask_types( + self, + lhs: SemanticMaskType, + rhs: SemanticMaskType, + context: str, + ) -> None: + if lhs != rhs: + raise TypeError(f"{context} requires all mask operands to use the same mask granularity") + + def _require_string_expr(self, expr: SemanticExpr, context: str) -> str: + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "string" + and isinstance(expr.binding.value, str) + ): + return expr.binding.value + raise TypeError(f"{context} must be a string literal in TileLang DSL") + + def _require_i1_expr(self, expr: SemanticExpr, context: str) -> None: + scalar = self._require_scalar_expr(expr, context) + if scalar.dtype != i1: + raise TypeError(f"{context} must be an i1 value in TileLang DSL") + + def _require_i64_like_expr(self, expr: SemanticExpr, context: str) -> None: + if isinstance(expr.type, SemanticIndexType): + return + scalar = self._require_scalar_expr(expr, context) + if scalar.dtype != i64: + raise TypeError(f"{context} must be an i64 or index value in TileLang DSL") + def _require_tail_remaining_expr(self, expr: SemanticExpr, context: str) -> None: if isinstance(expr.type, SemanticIndexType): return @@ -1585,11 +2120,20 @@ def _require_mask_for_vreg( def _require_matching_vector_pointer( self, vreg_type: SemanticVRegType, - pointer_type: SemanticTileType, + pointer_type: SemanticType, context: str, ) -> None: - if pointer_type.element_dtype != vreg_type.element_dtype: - raise TypeError(f"{context} requires destination Tile dtype to match vector dtype") + if isinstance(pointer_type, SemanticTileType): + if pointer_type.element_dtype != vreg_type.element_dtype: + raise TypeError(f"{context} requires destination Tile dtype to match vector dtype") + return + if isinstance(pointer_type, SemanticPtrType): + if pointer_type.memory_space != "ub": + raise TypeError(f"{context} requires a UB pointer destination in TileLang DSL") + if pointer_type.element_dtype != vreg_type.element_dtype: + raise TypeError(f"{context} requires destination pointer dtype to match vector dtype") + return + raise TypeError(f"{context} requires a Tile or pointer destination in TileLang DSL") def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: if dtype.name in {"f32", "i32"}: diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index de92cd701..ea8984f22 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -43,11 +43,8 @@ } ) -DEFERRED_PTO_SURFACES = frozenset( +ADVANCED_VECSCOPE_PTO_CALLS = frozenset( { - "castptr", - "addptr", - "copy_ubuf_to_ubuf", "vcmp", "vcmps", "vsel", @@ -65,6 +62,33 @@ "vdintlv", "vintlvv2", "vdintlvv2", + } +) + +ADVANCED_EXPR_PTO_CALLS = frozenset( + { + "ptr", + "castptr", + "addptr", + } +) + +ADVANCED_TOPLEVEL_PTO_CALLS = frozenset( + { + "copy_gm_to_ubuf", + "copy_ubuf_to_gm", + "copy_ubuf_to_ubuf", + "set_loop2_stride_outtoub", + "set_loop1_stride_outtoub", + "set_loop_size_outtoub", + "set_loop2_stride_ubtoout", + "set_loop1_stride_ubtoout", + "set_loop_size_ubtoout", + } +) + +DEFERRED_PTO_SURFACES = frozenset( + { "vreduce", } ) @@ -81,11 +105,19 @@ def deferred_surface_message(name: str) -> str: return unsupported_feature_message(f"advanced family surface `pto.{name}`") +def advanced_mode_message(name: str) -> str: + return f"surface `pto.{name}` requires advanced=True in TileLang DSL" + + __all__ = [ "DEFERRED_PTO_SURFACES", "FOLLOW_UP_CHANGE", + "ADVANCED_EXPR_PTO_CALLS", + "ADVANCED_TOPLEVEL_PTO_CALLS", + "ADVANCED_VECSCOPE_PTO_CALLS", "SUPPORTED_TOPLEVEL_PTO_CALLS", "SUPPORTED_VECSCOPE_PTO_CALLS", + "advanced_mode_message", "deferred_surface_message", "unsupported_feature_message", ] diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 07986b73a..f6c6de2ff 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -11,13 +11,16 @@ from tilelang_dsl.semantic import ( SemanticAssignStmt, SemanticCallExpr, + SemanticDmaConfigStmt, SemanticDmaLoadStmt, SemanticDmaStoreStmt, SemanticForStmt, SemanticIfStmt, SemanticIndexType, + SemanticLowLevelCopyStmt, SemanticMaskType, SemanticPipeBarrierStmt, + SemanticPtrType, SemanticScalarType, SemanticSetFlagStmt, SemanticStrictVecscopeStmt, @@ -34,15 +37,230 @@ class TileLangDSLPackageTests(unittest.TestCase): def test_package_exports_surface(self) -> None: self.assertIsNotNone(pto.__file__) self.assertTrue(hasattr(pto, "vkernel")) + self.assertTrue(hasattr(pto, "KernelRegistry")) + self.assertTrue(hasattr(pto, "select_kernel")) self.assertTrue(hasattr(pto, "TensorView")) self.assertTrue(hasattr(pto, "Tile")) self.assertTrue(hasattr(pto, "TileSpecialization")) + self.assertTrue(hasattr(pto, "PointerType")) + self.assertTrue(hasattr(pto, "ptr")) self.assertTrue(hasattr(pto, "get_lanes")) self.assertTrue(hasattr(pto, "PAT")) self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) +class TileLangDSLMatcherEntryTests(unittest.TestCase): + def test_select_kernel_returns_descriptor_from_default_registry(self) -> None: + @pto.vkernel(op="matcher_entry_default_registry_unique", dtypes=[(pto.f32, pto.i32)]) + def kernel(inp: pto.TensorView, scale: pto.i32): + return None + + selected = pto.select_kernel( + "a5", + "matcher_entry_default_registry_unique", + (pto.f32, pto.i32), + ) + + self.assertIs(selected, kernel) + + def test_select_kernel_uses_explicit_registry_without_falling_back(self) -> None: + @pto.vkernel(op="matcher_entry_registry_isolation_unique", dtypes=[(pto.f32,)]) + def default_kernel(inp: pto.TensorView): + return None + + empty_registry = pto.KernelRegistry() + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_entry_registry_isolation_unique", + (pto.f32,), + registry=empty_registry, + ) + self.assertIn("found no registered kernel", str(ctx.exception)) + + isolated_registry = pto.KernelRegistry() + isolated_registry.register(default_kernel) + selected = pto.select_kernel( + "a5", + "matcher_entry_registry_isolation_unique", + (pto.f32,), + registry=isolated_registry, + ) + + self.assertIs(selected, default_kernel) + self.assertEqual(len(isolated_registry.descriptors), 1) + + def test_select_kernel_binds_concrete_signature_from_multi_signature_descriptor(self) -> None: + @pto.vkernel( + op="matcher_multi_signature_unique", + dtypes=[ + (pto.f16, pto.f16), + (pto.f32, pto.f32), + ], + ) + def kernel(inp: pto.TensorView, tile: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_multi_signature_unique", + (pto.f32, pto.f32), + ) + + self.assertEqual(selected.dtype_signature, (pto.f32, pto.f32)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [("inp", "tensorview", pto.f32), ("tile", "tile", pto.f32)], + ) + specialized = selected.specialize( + tile=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ) + self.assertIn("memref<8x16xf32", specialized.mlir_text()) + + def test_select_kernel_matches_wildcards_deterministically(self) -> None: + @pto.vkernel( + op="matcher_wildcard_unique", + dtypes=[ + (pto.AnyInt, pto.AnyType), + (pto.AnyFloat, pto.AnyType), + ], + ) + def kernel(lhs: pto.TensorView, rhs: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_wildcard_unique", + (pto.f32, pto.i32), + ) + + self.assertEqual(selected.dtype_signature, (pto.f32, pto.i32)) + self.assertEqual(selected.parameters[0].dtype, pto.f32) + self.assertEqual(selected.parameters[1].dtype, pto.i32) + + def test_select_kernel_enforces_typevar_consistency_per_signature(self) -> None: + @pto.vkernel( + op="matcher_typevar_unique", + dtypes=[(pto.TypeVar("T"), pto.TypeVar("T"))], + ) + def kernel(lhs: pto.TensorView, rhs: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_typevar_unique", + (pto.f32, pto.f32), + ) + self.assertEqual(selected.dtype_signature, (pto.f32, pto.f32)) + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_typevar_unique", + (pto.f32, pto.i32), + ) + self.assertIn("found no registered kernel", str(ctx.exception)) + + def test_polymorphic_descriptor_requires_select_kernel_before_materialization(self) -> None: + @pto.vkernel( + op="matcher_materialization_gate_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(ValueError) as ctx: + kernel.mlir_text() + self.assertIn("requires pto.select_kernel(...)", str(ctx.exception)) + + def test_select_kernel_evaluates_constraints_before_priority(self) -> None: + def requires_large_batch(context_attrs): + return context_attrs.get("batch", 0) >= 1024 + + @pto.vkernel( + op="matcher_constraint_priority_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + constraints=[requires_large_batch], + priority=100, + ) + def high_priority_kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_constraint_priority_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + constraints=[], + priority=10, + ) + def fallback_kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_constraint_priority_unique", + (pto.f32, pto.f32), + context_attrs={"batch": 128}, + ) + self.assertIs(selected.py_fn, fallback_kernel.py_fn) + self.assertEqual(selected.priority, 10) + + selected = pto.select_kernel( + "a5", + "matcher_constraint_priority_unique", + (pto.f32, pto.f32), + context_attrs={"batch": 4096}, + ) + self.assertIs(selected.py_fn, high_priority_kernel.py_fn) + self.assertEqual(selected.priority, 100) + + def test_select_kernel_raises_tie_error_for_equal_highest_priority(self) -> None: + @pto.vkernel( + op="matcher_priority_tie_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + priority=50, + ) + def lhs(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_priority_tie_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + priority=50, + ) + def rhs(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_priority_tie_unique", + (pto.f32, pto.f32), + ) + self.assertIn("multiple highest-priority kernels", str(ctx.exception)) + self.assertIn("lhs(priority=50", str(ctx.exception)) + self.assertIn("rhs(priority=50", str(ctx.exception)) + + def test_select_kernel_reports_no_candidate_after_constraint_evaluation(self) -> None: + @pto.vkernel( + op="matcher_constraint_empty_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + constraints=[lambda context_attrs: context_attrs.get("enabled", False)], + priority=1, + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_constraint_empty_unique", + (pto.f32, pto.f32), + context_attrs={"enabled": False}, + ) + self.assertIn("after constraint evaluation", str(ctx.exception)) + + class TileLangDSLDescriptorTests(unittest.TestCase): def test_descriptor_metadata_and_parameter_binding(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], verify=False) @@ -65,6 +283,16 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): self.assertEqual(kernel.parameters[1].element_dtype, pto.f16) self.assertIsNone(kernel.parameters[2].element_dtype) + def test_pointer_parameter_annotation_binds_as_ptr_kind(self) -> None: + @pto.vkernel(op="ptr_surface", dtypes=[(pto.f32, pto.i64)], advanced=True) + def kernel(src: pto.ptr(pto.f32, pto.MemorySpace.UB), addr: pto.i64): + return None + + self.assertEqual(kernel.parameters[0].kind, "ptr") + self.assertEqual(kernel.parameters[0].dtype, pto.f32) + self.assertEqual(kernel.parameters[0].annotation, pto.ptr(pto.f32, pto.MemorySpace.UB)) + self.assertEqual(kernel.parameters[0].element_dtype, pto.f32) + def test_specialization_enables_materialization_apis(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16)]) def kernel(inp: pto.TensorView, tile: pto.Tile): @@ -470,9 +698,9 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): ) semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - self.assertEqual(len(semantic_kernel.body), 2) - self.assertIsInstance(semantic_kernel.body[0], SemanticVecscopeStmt) - vecscope = semantic_kernel.body[0] + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + vecscope = vecscope_stmts[0] self.assertIsInstance(vecscope, SemanticVecscopeStmt) outer_loop = next(stmt for stmt in vecscope.body if isinstance(stmt, SemanticForStmt)) self.assertIsInstance(outer_loop, SemanticForStmt) @@ -486,6 +714,8 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): self.assertNotIn("pto.strict_vecscope(", text) self.assertRegex(text, r"pto\.vecscope \{\n(?:.|\n)*scf\.for %row_") self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertLess(text.index("%rows_1 = arith.constant 8 : index"), text.index("pto.vecscope {")) + self.assertLess(text.index("%cols_2 = arith.constant 64 : index"), text.index("pto.vecscope {")) self.assertRegex(text, r"%tmp_\d+ = arith\.muli %row_\d+, %c64 : index") self.assertRegex(text, r"%tmp_\d+ = arith\.addi %tmp_\d+, %col_\d+ : index") self.assertIn("pto.vlds %arg1[", text) @@ -517,6 +747,51 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): self.assertIn("pto.vadd", text) self.assertIn("pto.vsts", text) + def test_advanced_mode_scalar_boundary_cuts_inferred_vecscope(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + dtype = src.element_type + first_mask = pto.make_mask(dtype, pto.PAT.ALL) + first = pto.vlds(src[0, 0:]) + pto.vsts(first, dst[0, 0:], first_mask) + boundary = 1 + second_mask = pto.make_mask(dtype, pto.PAT.ALL) + second = pto.vlds(src[1, 0:]) + pto.vsts(second, dst[1, 0:], second_mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 2) + self.assertLess(text.index("%boundary_"), text.rindex("pto.vecscope {")) + + def test_advanced_mode_control_flow_boundary_cuts_inferred_vecscope(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile, flag: pto.i32): + dtype = src.element_type + all_mask = pto.make_mask(dtype, pto.PAT.ALL) + if flag: + first = pto.vlds(src[0, 0:]) + pto.vsts(first, dst[0, 0:], all_mask) + else: + second = pto.vlds(src[1, 0:]) + pto.vsts(second, dst[1, 0:], all_mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("scf.if", text) + self.assertEqual(text.count("pto.vecscope {"), 2) + self.assertLess(text.index("scf.if"), text.index("pto.vecscope {")) + def test_advanced_mode_keeps_strict_vecscope_as_hard_boundary(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) def kernel(src: pto.Tile, dst: pto.Tile): @@ -540,6 +815,209 @@ def kernel(src: pto.Tile, dst: pto.Tile): self.assertEqual(text.count("pto.vecscope {"), 1) self.assertEqual(text.count("pto.strict_vecscope("), 1) + def test_advanced_mode_lowers_raw_pointer_and_low_level_dma_surface(self) -> None: + @pto.vkernel(op="ptr_dma", dtypes=[(pto.f32, pto.f32, pto.i64)], advanced=True) + def kernel( + src_gm: pto.ptr(pto.f32, pto.MemorySpace.GM), + dst_gm: pto.ptr(pto.f32, pto.MemorySpace.GM), + addr: pto.i64, + ): + ub_src = pto.castptr(addr, pto.ptr(pto.f32, pto.MemorySpace.UB)) + ub_dst = pto.addptr(ub_src, 64) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(ub_src, 0) + pto.vsts(vec, ub_dst, 0, mask) + + src_bytes = pto.castptr(src_gm, pto.ptr(pto.i8, pto.MemorySpace.GM)) + dst_bytes = pto.castptr(dst_gm, pto.ptr(pto.i8, pto.MemorySpace.GM)) + src_offset = pto.addptr(src_bytes, 0) + dst_offset = pto.addptr(dst_bytes, 0) + typed_src = pto.castptr(src_offset, pto.ptr(pto.f32, pto.MemorySpace.GM)) + typed_dst = pto.castptr(dst_offset, pto.ptr(pto.f32, pto.MemorySpace.GM)) + + pto.set_loop2_stride_outtoub(4096, 4096) + pto.set_loop1_stride_outtoub(4096, 4096) + pto.set_loop_size_outtoub(1, 1) + pto.copy_gm_to_ubuf(typed_src, ub_src, 0, 32, 128, 0, 0, False, 0, 128, 128) + + pto.set_loop2_stride_ubtoout(4096, 4096) + pto.set_loop1_stride_ubtoout(4096, 4096) + pto.set_loop_size_ubtoout(1, 1) + pto.copy_ubuf_to_ubuf(ub_src, ub_dst, 0, 32, 128, 128, 128) + pto.copy_ubuf_to_gm(ub_dst, typed_dst, 0, 32, 128, 0, 128, 128) + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIsInstance(semantic_kernel.parameters[0].type, SemanticPtrType) + self.assertEqual(semantic_kernel.parameters[0].type.memory_space, "gm") + self.assertIsInstance(semantic_kernel.parameters[1].type, SemanticPtrType) + self.assertEqual(semantic_kernel.parameters[1].type.memory_space, "gm") + self.assertTrue(any(isinstance(stmt, SemanticVecscopeStmt) for stmt in semantic_kernel.body)) + self.assertTrue(any(isinstance(stmt, SemanticDmaConfigStmt) for stmt in semantic_kernel.body)) + self.assertTrue(any(isinstance(stmt, SemanticLowLevelCopyStmt) for stmt in semantic_kernel.body)) + + text = kernel.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: i64) {", + text, + ) + self.assertRegex( + text, + r"%ub_src_\d+ = pto\.castptr %arg2 : i64 -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%ub_dst_\d+ = pto\.addptr %ub_src_\d+, %c64 : !pto\.ptr -> !pto\.ptr", + ) + self.assertIn("pto.vecscope {", text) + self.assertRegex( + text, + r"%vec_\d+ = pto\.vlds %ub_src_\d+\[%c0\] : !pto\.ptr -> !pto\.vreg<64xf32>", + ) + self.assertRegex( + text, + r"pto\.vsts %vec_\d+, %ub_dst_\d+\[%c0\], %mask_\d+ : !pto\.vreg<64xf32>, !pto\.ptr, !pto\.mask", + ) + self.assertRegex( + text, + r"%src_bytes_\d+ = pto\.castptr %arg0 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%dst_bytes_\d+ = pto\.castptr %arg1 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%src_offset_\d+ = pto\.addptr %src_bytes_\d+, %c0 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%dst_offset_\d+ = pto\.addptr %dst_bytes_\d+, %c0 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"pto\.set_loop2_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop1_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop_size_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.copy_gm_to_ubuf %typed_src_\d+, %ub_src_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %false, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + self.assertIn( + ": !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64", + text, + ) + self.assertRegex( + text, + r"pto\.set_loop2_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop1_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop_size_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.copy_ubuf_to_ubuf %ub_src_\d+, %ub_dst_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + self.assertIn( + ": !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64", + text, + ) + self.assertRegex( + text, + r"pto\.copy_ubuf_to_gm %ub_dst_\d+, %typed_dst_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + + def test_advanced_mode_lowers_compare_predicate_carry_and_rearrangement_families(self) -> None: + @pto.vkernel(op="advanced_family", dtypes=[(pto.i32, pto.i32, pto.i32, pto.i32)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + lhs = pto.vlds(src0[0, 0:]) + rhs = pto.vlds(src1[0, 0:]) + cmp_mask = pto.vcmp(lhs, rhs, all_mask, "lt") + cmp_scalar_mask = pto.vcmps(lhs, scalar, all_mask, "gt") + negated = pto.pnot(cmp_mask, all_mask) + picked = pto.psel(cmp_mask, negated, cmp_scalar_mask) + packed = pto.ppack(picked, "PART_EVEN") + unpacked = pto.punpack(packed, "PART_ODD") + sum_vec, carry_mask = pto.vaddc(lhs, rhs, all_mask) + diff_vec, borrow_mask = pto.vsubc(lhs, rhs, all_mask) + sum_with_carry, carry_mask2 = pto.vaddcs(sum_vec, diff_vec, carry_mask, all_mask) + diff_with_borrow, borrow_mask2 = pto.vsubcs(sum_with_carry, diff_vec, borrow_mask, all_mask) + low, high = pto.vintlv(sum_with_carry, diff_with_borrow) + dlow, dhigh = pto.vdintlv(low, high) + even = pto.vintlvv2(dlow, dhigh, "PART_EVEN") + odd = pto.vdintlvv2(dlow, dhigh, "PART_ODD") + selected = pto.vsel(even, odd, unpacked) + selected_r = pto.vselr(selected, sum_with_carry) + final = pto.vselrv2(selected_r, diff_with_borrow) + pto.vsts(final, dst[0, 0:], all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertIn("pto.vecscope {", text) + self.assertIn('pto.vcmp ', text) + self.assertIn(', "lt" : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask', text) + self.assertIn('pto.vcmps ', text) + self.assertIn(', "gt" : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask', text) + self.assertIn(" = pto.pnot ", text) + self.assertIn(" = pto.psel ", text) + self.assertIn(' = pto.ppack ', text) + self.assertIn('"PART_EVEN"', text) + self.assertIn(' = pto.punpack ', text) + self.assertIn('"PART_ODD"', text) + self.assertRegex( + text, + r"%sum_vec_\d+, %carry_mask_\d+ = pto\.vaddc %lhs_\d+, %rhs_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", + ) + self.assertRegex( + text, + r"%diff_vec_\d+, %borrow_mask_\d+ = pto\.vsubc %lhs_\d+, %rhs_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", + ) + self.assertRegex( + text, + r"%sum_with_carry_\d+, %carry_mask2_\d+ = pto\.vaddcs %sum_vec_\d+, %diff_vec_\d+, %carry_mask_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", + ) + self.assertRegex( + text, + r"%diff_with_borrow_\d+, %borrow_mask2_\d+ = pto\.vsubcs %sum_with_carry_\d+, %diff_vec_\d+, %borrow_mask_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", + ) + self.assertRegex( + text, + r"%low_\d+, %high_\d+ = pto\.vintlv %sum_with_carry_\d+, %diff_with_borrow_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32> -> !pto\.vreg<64xi32>, !pto\.vreg<64xi32>", + ) + self.assertRegex( + text, + r"%dlow_\d+, %dhigh_\d+ = pto\.vdintlv %low_\d+, %high_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32> -> !pto\.vreg<64xi32>, !pto\.vreg<64xi32>", + ) + self.assertIn(" = pto.vintlvv2 ", text) + self.assertIn(" = pto.vdintlvv2 ", text) + self.assertIn(" = pto.vsel ", text) + self.assertIn(" = pto.vselr ", text) + self.assertIn(" = pto.vselrv2 ", text) + self.assertIn("pto.vsts ", text) + def test_elementwise_kernel_positive_regression_covers_dma_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32, pto.i32)]) def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: pto.i32): @@ -672,22 +1150,27 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): class TileLangDSLDiagnosticsTests(unittest.TestCase): - def test_matcher_feature_diagnostics_point_to_follow_up_change(self) -> None: - cases = [ - lambda: pto.vkernel(op="x", dtypes=[(pto.f32,)], constraints=[])(lambda x: None), - lambda: pto.vkernel(op="x", dtypes=[(pto.f32,)], priority=1)(lambda x: None), - lambda: pto.vkernel(op="x", dtypes=[(pto.f32,), (pto.f16,)])(lambda x: None), - lambda: pto.vkernel(op="x", dtypes=[(pto.AnyFloat,)])(lambda x: None), - lambda: pto.vkernel(op="x", dtypes=[(pto.TypeVar("T"),)])(lambda x: None), - ] - - for thunk in cases: - with self.assertRaises(ValueError) as ctx: - thunk() - self.assertIn( - "extend-tilelang-dsl-matcher-and-advanced-surface", - str(ctx.exception), - ) + def test_matcher_feature_validation_rejects_invalid_constraints_and_priority(self) -> None: + def kernel(x: pto.TensorView): + return None + + with self.assertRaises(TypeError) as constraints_ctx: + pto.vkernel(op="x", dtypes=[(pto.f32,)], constraints=[123])(kernel) + self.assertIn("constraints[0] must be callable", str(constraints_ctx.exception)) + + with self.assertRaises(TypeError) as priority_ctx: + pto.vkernel(op="x", dtypes=[(pto.f32,)], priority=True)(kernel) + self.assertIn("priority must be an int", str(priority_ctx.exception)) + + def test_advanced_mode_keeps_vreduce_rejected_until_authoring_op_exists(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.i32,)], advanced=True) + def kernel(x: pto.Tile): + pto.vreduce(x) + return None + + self.assertIn("advanced family surface `pto.vreduce`", str(ctx.exception)) def test_unsupported_python_syntax_reports_source_location(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as ctx: @@ -725,7 +1208,7 @@ def kernel(x: pto.TensorView): self.assertIn("vector op surface `pto.vadd` requires explicit pto.strict_vecscope", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) - def test_unsupported_advanced_family_points_to_follow_up_change(self) -> None: + def test_advanced_family_requires_advanced_mode(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as ctx: @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f32)]) @@ -735,11 +1218,7 @@ def kernel(x: pto.TensorView, tile: pto.Tile): pto.vcmp(lhs, rhs, mask, "lt") return None - self.assertIn("advanced family surface `pto.vcmp`", str(ctx.exception)) - self.assertIn( - "extend-tilelang-dsl-matcher-and-advanced-surface", - str(ctx.exception), - ) + self.assertIn("surface `pto.vcmp` requires advanced=True", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) def test_missing_specialization_reports_source_location(self) -> None: From 0945be0df0e85f2c76365b266652718b2c818133 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 7 Apr 2026 10:20:25 +0800 Subject: [PATCH 022/192] Amend --- tilelang-dsl/python/tilelang_dsl/types.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 0e41114de..ca94ba895 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -23,6 +23,15 @@ class Tile: """Bare Tile annotation marker for TileLang DSL v1.""" +@dataclass(frozen=True) +class PointerType: + element_dtype: ScalarType + memory_space: "MemorySpace" + + def __repr__(self) -> str: + return f"ptr({self.element_dtype!r}, {self.memory_space!r})" + + @dataclass(frozen=True) class WildcardType: name: str @@ -111,6 +120,14 @@ def TypeVar(name: str) -> TypeVariable: return TypeVariable(name) +def ptr(dtype: ScalarType, memory_space: MemorySpace) -> PointerType: + if not isinstance(dtype, ScalarType): + raise TypeError("ptr() expects a TileLang scalar dtype") + if not isinstance(memory_space, MemorySpace): + raise TypeError("ptr() expects a TileLang MemorySpace") + return PointerType(element_dtype=dtype, memory_space=memory_space) + + def get_lanes(dtype: ScalarType) -> int: if not isinstance(dtype, ScalarType): raise TypeError("get_lanes expects a TileLang scalar dtype") @@ -135,6 +152,8 @@ def get_lanes(dtype: ScalarType) -> int: "TypeVar", "TensorView", "Tile", + "PointerType", + "ptr", "MemorySpace", "Pipe", "Event", From bd86d21272a6698d866441b96ccb5c116b03a62f Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 7 Apr 2026 10:28:23 +0800 Subject: [PATCH 023/192] Update docs --- docs/tilelang-dsl-guide.md | 127 +++++++++++++----- .../.openspec.yaml | 0 .../design.md | 0 .../proposal.md | 0 .../tilelang-dsl-advanced-surface/spec.md | 0 .../specs/tilelang-dsl-kernel-matcher/spec.md | 0 .../tasks.md | 0 .../tilelang-dsl-advanced-surface/spec.md | 88 ++++++++++++ .../specs/tilelang-dsl-kernel-matcher/spec.md | 85 ++++++++++++ 9 files changed, 264 insertions(+), 36 deletions(-) rename openspec/changes/{extend-tilelang-dsl-matcher-and-advanced-surface => archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface}/.openspec.yaml (100%) rename openspec/changes/{extend-tilelang-dsl-matcher-and-advanced-surface => archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface}/design.md (100%) rename openspec/changes/{extend-tilelang-dsl-matcher-and-advanced-surface => archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface}/proposal.md (100%) rename openspec/changes/{extend-tilelang-dsl-matcher-and-advanced-surface => archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface}/specs/tilelang-dsl-advanced-surface/spec.md (100%) rename openspec/changes/{extend-tilelang-dsl-matcher-and-advanced-surface => archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface}/specs/tilelang-dsl-kernel-matcher/spec.md (100%) rename openspec/changes/{extend-tilelang-dsl-matcher-and-advanced-surface => archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface}/tasks.md (100%) create mode 100644 openspec/specs/tilelang-dsl-advanced-surface/spec.md create mode 100644 openspec/specs/tilelang-dsl-kernel-matcher/spec.md diff --git a/docs/tilelang-dsl-guide.md b/docs/tilelang-dsl-guide.md index dfdd41263..b6dfabf3f 100644 --- a/docs/tilelang-dsl-guide.md +++ b/docs/tilelang-dsl-guide.md @@ -4,6 +4,33 @@ The TileLang Python DSL provides a high-level, Pythonic interface for authoring The DSL is designed to generate MLIR function libraries rather than direct binary executables. These MLIR libraries are intended to be consumed by other compilation frameworks that transform high-level tile semantics into low-level vector operations. This enables library developers to focus on hardware-aware kernel authoring while relying on upstream compilers for tile-level optimizations and code generation. +## Current Implementation Status + +The current `tilelang_dsl` package in this repository implements: +- matcher support: + `KernelRegistry`, `pto.select_kernel(...)`, multi-signature `dtypes`, + `AnyFloat` / `AnyInt` / `AnyType` / `AnyMask`, `TypeVar`, `constraints`, + `priority` +- advanced authoring support: + implicit vecscope inference in `advanced=True` kernels +- raw pointer / low-level DMA support: + `ptr(...)`, `castptr`, `addptr`, low-level DMA config ops, + `copy_gm_to_ubuf`, `copy_ubuf_to_gm`, `copy_ubuf_to_ubuf` +- advanced vector-family lowering: + compare/select, predicate movement, carry, rearrangement + +Still deferred in the current package head: +- reduction-family authoring + +Reason: +- the repo does not yet expose a public authoring-form VPTO reduction op that + the standalone TileLang DSL can target directly + +For the package-local source of truth, see: +- `tilelang-dsl/docs/v1-surface.md` +- `tilelang-dsl/docs/v1-lowering.md` +- `tilelang-dsl/docs/matcher-and-advanced-surface-migration.md` + ## Quick Start **Note on mask pattern enums**: For brevity, examples in this guide use `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). You can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. @@ -224,6 +251,21 @@ When a PTO operation needs implementation, the system performs the following mat 5. **Priority Selection**: From the remaining kernels, select the one with the highest `priority` value. 6. **Fallback**: If no kernel matches, compilation fails with an error. +The package also exposes explicit selection utilities: + +```python +registry = pto.KernelRegistry() +registry.register(my_kernel) + +selected = pto.select_kernel( + "a5", + "matmul", + (pto.f16, pto.f16, pto.f32), + context_attrs={"k_aligned": True}, + registry=registry, +) +``` + #### Examples ##### Matmul with Multiple Implementations @@ -1869,63 +1911,66 @@ mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode ``` -#### `pto.ppack(mask: MaskType) -> pto.i32` +#### `pto.ppack(mask: MaskType, part: str) -> MaskType` -**Description**: Packs mask bits into a 32-bit integer. +**Description**: Rearranges a mask according to the requested `part` selector. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | +| `part` | `str` | Part selector such as `"PART_EVEN"` or `"PART_ODD"` | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `packed` | `pto.i32` | Packed mask bits | +| `packed` | `MaskType` | Reordered mask | -#### `pto.punpack(packed: pto.i32) -> MaskType` +#### `pto.punpack(mask: MaskType, part: str) -> MaskType` -**Description**: Unpacks 32-bit integer to mask (granularity determined by context). +**Description**: Applies the inverse mask-part rearrangement selected by `part`. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `packed` | `pto.i32` | Packed mask bits | +| `mask` | `MaskType` | Input mask | +| `part` | `str` | Part selector such as `"PART_EVEN"` or `"PART_ODD"` | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `mask` | `MaskType` | Unpacked mask | +| `mask` | `MaskType` | Reordered mask | -#### `pto.pnot(mask: MaskType) -> MaskType` +#### `pto.pnot(mask: MaskType, gate: MaskType) -> MaskType` -**Description**: Logical negation of mask bits. +**Description**: Predicate negation under a same-granularity mask gate. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `mask` | `MaskType` | Input mask | +| `gate` | `MaskType` | Gating mask with the same granularity | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| | `negated` | `MaskType` | Negated mask | -#### `pto.psel(mask: MaskType, true_val: ScalarType, false_val: ScalarType) -> ScalarType` +#### `pto.psel(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` -**Description**: Selects between two scalar values based on mask. +**Description**: Selects between two masks using a third mask as selector. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | | `mask` | `MaskType` | Selection mask | -| `true_val` | `ScalarType` | Value selected when mask bit is 1 | -| `false_val` | `ScalarType` | Value selected when mask bit is 0 | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `ScalarType` | Selected scalar value | +| `result` | `MaskType` | Selected mask | ### Unary Vector Operations @@ -2369,52 +2414,58 @@ scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) Operations with carry propagation and selection. -#### `pto.vaddc(vec1: VRegType, vec2: VRegType, carry_in: ScalarType, mask: MaskType) -> (VRegType, ScalarType)` +Implemented current-package carry/select surface also includes: +- `pto.vcmp(vec0, vec1, seed_mask, cmp_mode) -> MaskType` +- `pto.vcmps(vec, scalar, seed_mask, cmp_mode) -> MaskType` +- `pto.vselr(vec0, vec1) -> VRegType` +- `pto.vselrv2(vec0, vec1) -> VRegType` +- `pto.vaddcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` +- `pto.vsubcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` -**Description**: Vector addition with carry input and output. +#### `pto.vaddc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` + +**Description**: Vector addition with carry output. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `vec1` | `VRegType` | First input vector | | `vec2` | `VRegType` | Second input vector | -| `carry_in` | `ScalarType` | Input carry bit | | `mask` | `MaskType` | Predicate mask | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| | `result` | `VRegType` | Sum vector | -| `carry_out` | `ScalarType` | Output carry bit | +| `carry_out` | `MaskType` | Output carry mask | -#### `pto.vsubc(vec1: VRegType, vec2: VRegType, borrow_in: ScalarType, mask: MaskType) -> (VRegType, ScalarType)` +#### `pto.vsubc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` -**Description**: Vector subtraction with borrow input and output. +**Description**: Vector subtraction with borrow output. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `vec1` | `VRegType` | First input vector | | `vec2` | `VRegType` | Second input vector | -| `borrow_in` | `ScalarType` | Input borrow bit | | `mask` | `MaskType` | Predicate mask | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| | `result` | `VRegType` | Difference vector | -| `borrow_out` | `ScalarType` | Output borrow bit | +| `borrow_out` | `MaskType` | Output borrow mask | -#### `pto.vsel(mask: MaskType, true_vec: VRegType, false_vec: VRegType) -> VRegType` +#### `pto.vsel(true_vec: VRegType, false_vec: VRegType, mask: MaskType) -> VRegType` **Description**: Vector select based on mask. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `mask` | `MaskType` | Selection mask | | `true_vec` | `VRegType` | Vector selected when mask bit is 1 | | `false_vec` | `VRegType` | Vector selected when mask bit is 0 | +| `mask` | `MaskType` | Selection mask | **Returns**: | Return Value | Type | Description | @@ -2423,7 +2474,7 @@ Operations with carry propagation and selection. **Example**: ```python -result = pto.vsel(mask32, scaled_vec, original_vec) +result = pto.vsel(scaled_vec, original_vec, mask32) ``` ### Data Rearrangement @@ -2458,31 +2509,35 @@ Operations for rearranging data within vectors. |--------------|------|-------------| | `result` | `pto.mask_b16` | Interleaved mask | -#### `pto.vintlv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` +Implemented current-package rearrangement surface also includes: +- `pto.vintlvv2(vec0, vec1, part) -> VRegType` +- `pto.vdintlvv2(vec0, vec1, part) -> VRegType` + +#### `pto.vintlv(vec1: VRegType, vec2: VRegType) -> (VRegType, VRegType)` -**Description**: Interleave two vectors. +**Description**: Interleave two vectors and return the low/high results. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `vec1` | `VRegType` | First input vector | | `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `VRegType` | Interleaved vector | +| `low` | `VRegType` | Low interleaved result | +| `high` | `VRegType` | High interleaved result | -#### `pto.vdintlv(vec: VRegType, mask: MaskType) -> (VRegType, VRegType)` +#### `pto.vdintlv(vec0: VRegType, vec1: VRegType) -> (VRegType, VRegType)` -**Description**: Deinterleave vector into two vectors. +**Description**: Deinterleave a pair of vectors into low/high results. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | +| `vec0` | `VRegType` | First input vector | +| `vec1` | `VRegType` | Second input vector | **Returns**: | Return Value | Type | Description | @@ -2832,7 +2887,7 @@ def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), scaled = pto.vmuls(vec, pto.f32(2.0), mask) # Keep original values below threshold - result = pto.vsel(mask, scaled, vec) + result = pto.vsel(scaled, vec, mask) pto.vsts(result, vout, i, all_mask) ``` @@ -2844,11 +2899,11 @@ def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), def prefix_sum(src: pto.ptr(pto.i32, MemorySpace.UB), dst: pto.ptr(pto.i32, MemorySpace.UB)): all_mask = pto.make_mask(pto.i32, PAT.ALL) - carry = pto.i32(0) + carry = all_mask for i in range(0, 256, 64): vec = pto.vlds(src, i) - result, carry = pto.vaddcs(vec, carry, all_mask) + result, carry = pto.vaddcs(vec, vec, carry, all_mask) pto.vsts(result, dst, i, all_mask) ``` diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/.openspec.yaml b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/.openspec.yaml similarity index 100% rename from openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/.openspec.yaml rename to openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/.openspec.yaml diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/design.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/design.md similarity index 100% rename from openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/design.md rename to openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/design.md diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/proposal.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/proposal.md similarity index 100% rename from openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/proposal.md rename to openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/proposal.md diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md similarity index 100% rename from openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md rename to openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-advanced-surface/spec.md diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md similarity index 100% rename from openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md rename to openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/specs/tilelang-dsl-kernel-matcher/spec.md diff --git a/openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/tasks.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/tasks.md similarity index 100% rename from openspec/changes/extend-tilelang-dsl-matcher-and-advanced-surface/tasks.md rename to openspec/changes/archive/2026-04-07-extend-tilelang-dsl-matcher-and-advanced-surface/tasks.md diff --git a/openspec/specs/tilelang-dsl-advanced-surface/spec.md b/openspec/specs/tilelang-dsl-advanced-surface/spec.md new file mode 100644 index 000000000..4a7f37752 --- /dev/null +++ b/openspec/specs/tilelang-dsl-advanced-surface/spec.md @@ -0,0 +1,88 @@ +# tilelang-dsl-advanced-surface Specification + +## ADDED Requirements + +### Requirement: advanced mode MUST infer `pto.vecscope` for eligible vector chains while preserving `strict_vecscope` boundaries + +在 advanced mode 下,当用户省略显式 scope 且书写连续的 supported vector chain 时,frontend MUST 自动推断 dedicated `pto.vecscope`。 +scalar op、控制流边界、外部 call 和显式 `strict_vecscope` MUST 切断该推断。 +`strict_vecscope` 继续作为硬边界,inference MUST NOT 穿越其边界。 +inference 结果 MUST 继续满足当前 authoring-form VPTO legality contract,不得因为自动推断而放宽 typed-mask、capture operand、地址形态或 vecscope carrier 约束。 + +#### Scenario: contiguous vector chain becomes one inferred `pto.vecscope` + +- **WHEN** 用户在 advanced mode 下连续书写一条由 load -> vector ALU -> store 组成的纯 vector chain,且中间没有 scalar/control-flow boundary +- **THEN** frontend MUST 将该 chain lower 为一个 dedicated `pto.vecscope` +- **AND** 该 inferred vecscope MUST 满足当前 VPTO authoring legality contract + +#### Scenario: scalar or control-flow boundary cuts vecscope inference + +- **WHEN** 一条候选 vector chain 中间穿插 scalar op、`if` / `for` 边界或外部 call +- **THEN** frontend MUST 在该边界处切断 inference +- **AND** MUST NOT 把边界两侧的 vector 片段合并到同一个隐式 `pto.vecscope` + +#### Scenario: explicit `strict_vecscope` remains an inference barrier + +- **WHEN** 用户在 advanced mode 下显式书写 `strict_vecscope` +- **THEN** frontend MUST 保留该 `strict_vecscope` 原样语义 +- **AND** scope inference MUST NOT 跨越该显式边界去并合前后 vector chain + +### Requirement: advanced mode MUST support raw pointer, UBRef, low-level DMA, and `copy_ubuf_to_ubuf` authoring + +advanced mode MUST 将以下 surface 纳入正式契约: + +- `castptr` +- `addptr` +- raw UBRef load/store authoring +- low-level DMA programming +- `copy_ubuf_to_ubuf` + +这些 surface 仍 MUST lower 到当前合法的 authoring-form VPTO,不得发明另一套公开中间 IR。 +对于 raw pointer 与 UBRef 相关 surface,frontend MUST 继续遵守当前 ptr-only / buffer-like / copy-family 地址契约。 + +#### Scenario: low-level pointer and DMA surface lowers to legal authoring-form VPTO + +- **WHEN** 用户使用 `castptr`、`addptr`、raw UBRef、低层 DMA programming 或 `copy_ubuf_to_ubuf` +- **THEN** frontend MUST 生成对应的合法 authoring-form VPTO surface +- **AND** 输出结果 MUST 继续满足当前 copy/buffer-like/ptr-only 地址契约 + +#### Scenario: `copy_ubuf_to_ubuf` remains inside the existing DMA and address contract + +- **WHEN** 用户在 advanced mode 下书写 `copy_ubuf_to_ubuf` +- **THEN** lowering MUST 只生成当前 VPTO 已允许的 UB-to-UB copy programming 与 copy surface +- **AND** MUST NOT 通过额外的公开 helper IR 绕过现有 legality 检查 + +### Requirement: advanced mode MUST extend lowering to compare/select, predicate movement, carry, and rearrangement family capability sets + +advanced mode MUST 将以下 family 分组纳入正式 lowering capability: + +- compare/select +- predicate movement +- carry family +- rearrangement + +这些 family 的 lowering MUST 继续落在当前 authoring-form VPTO contract 内,并与已有 typed-mask、vecscope、pointer/buffer legality 规则兼容。 +对未进入这些 capability set 的 family,frontend MUST 继续显式 reject。 + +#### Scenario: advanced family kernel lowers without leaving the authoring-form VPTO contract + +- **WHEN** 用户在 advanced mode 下使用 compare/select、predicate movement、carry 或 rearrangement family 编写 kernel +- **THEN** frontend MUST 为该 family 生成合法的 authoring-form VPTO IR +- **AND** typed-mask、vecscope 和地址形态契约 MUST 与当前 VPTO legality contract 保持一致 + +#### Scenario: family outside the declared advanced capability set is still rejected + +- **WHEN** 用户使用未纳入上述 capability set 的 family +- **THEN** frontend MUST 继续报 unsupported-feature 错误 +- **AND** MUST NOT 因启用了 advanced mode 就默认放开全部 VPTO family + +### Requirement: advanced mode MUST keep reduction-family authoring rejected until a public authoring-form VPTO op exists + +当前 repo 尚未暴露可供 TileLang DSL 直接复用的 reduction authoring-form VPTO op。 +因此,在该 authoring 契约存在之前,frontend MUST 继续显式 reject reduction family surface,MUST NOT 通过额外公开 helper IR 或绕经 OpLib/EmitC 专用路径把 reduction 伪装成当前 capability 的一部分。 + +#### Scenario: reduction family remains deferred without a public authoring-form VPTO op + +- **WHEN** 用户在 advanced mode 下尝试书写 reduction family surface +- **THEN** frontend MUST 报 unsupported-feature 错误 +- **AND** MUST 说明该 family 仍处于 follow-up / deferred 状态 diff --git a/openspec/specs/tilelang-dsl-kernel-matcher/spec.md b/openspec/specs/tilelang-dsl-kernel-matcher/spec.md new file mode 100644 index 000000000..fde2f1717 --- /dev/null +++ b/openspec/specs/tilelang-dsl-kernel-matcher/spec.md @@ -0,0 +1,85 @@ +# tilelang-dsl-kernel-matcher Specification + +## ADDED Requirements + +### Requirement: TileLang DSL MUST provide an explicit kernel registry and selection API + +当同一 `target/op` 下存在多个 `@pto.vkernel` descriptor 时,TileLang DSL MUST 将它们注册到显式、可查询的 `KernelRegistry`。 +默认 registry MUST 是 module-level 对象;调用方 MAY 传入自定义 registry 以获得隔离的候选集合。 +系统 MUST 提供显式 selection API `pto.select_kernel(target, op, operand_types, context_attrs, registry=None)`,用于在给定 `target`、`op`、operand type 信息和上下文属性时选择唯一 kernel。 +实现 MUST NOT 依赖扫描 Python globals、locals 或导入顺序来隐式发现候选。 + +#### Scenario: selector returns the unique best kernel + +- **WHEN** registry 中存在多个针对同一 `target/op` 的 kernel descriptor,且其中一个在全部匹配步骤后成为唯一最佳候选 +- **THEN** `pto.select_kernel(...)` MUST 返回该 descriptor +- **AND** 返回结果 MUST 可继续走 `specialize()` / `mlir_text()` / `verify()` 流程 + +#### Scenario: custom registry restricts the candidate set explicitly + +- **WHEN** 调用方显式传入一个只含局部 kernel 的 `KernelRegistry` +- **THEN** selector MUST 只在该 registry 的候选集合内做匹配和决策 +- **AND** MUST NOT 回退去查询 module-level 默认 registry + +### Requirement: matcher MUST support concrete types, `Any*`, and `TypeVar` across multiple signatures + +matcher MUST 支持: + +- 多个 `dtypes` signature +- `AnyFloat` +- `AnyInt` +- `AnyType` +- `AnyMask` +- `TypeVar` + +`TypeVar` 在单个 signature 内 MUST 约束所有同名位置绑定到同一最终类型。 +多个 `dtypes` signature MUST 逐个独立求值;某个 signature 的 `TypeVar` 绑定状态 MUST NOT 泄漏到另一个 signature。 + +#### Scenario: wildcard and type-variable signatures match deterministically + +- **WHEN** 某个 kernel 使用多个 `dtypes` signature,并在其中混用 concrete type、`Any*` 与 `TypeVar` +- **THEN** matcher MUST 对每个 signature 独立求值 +- **AND** 只有满足所有 `TypeVar` 一致性约束的 signature 才能视为匹配成功 + +### Requirement: selection order MUST be target -> op -> dtype signature -> constraints -> priority -> tie error + +对一个 registry 中的候选集合,selector MUST 按以下固定顺序求值: + +1. `target` +2. `op` +3. `dtypes` signature 的 concrete / wildcard / type-variable 匹配 +4. `constraints` +5. `priority` +6. highest-priority tie error + +实现 MUST 保持该顺序 deterministic。 +系统 MUST NOT 依赖注册顺序、定义顺序、导入顺序或其他隐式规则来打破同一阶段的歧义。 + +#### Scenario: type match happens before constraints and priority + +- **WHEN** 一个候选在 `target/op` 上匹配,但没有任何 `dtypes` signature 能通过 concrete / wildcard / `TypeVar` 规则 +- **THEN** 该候选 MUST 在进入 `constraints` 评估前被移除 +- **AND** 其 `priority` MUST NOT 参与后续决策 + +### Requirement: constraint evaluation MUST happen after type matching and before priority resolution + +对同一 `target/op` 的候选集合,matcher MUST 先完成 dtype matching,再评估 `constraints`。 +只有通过 constraint evaluation 的候选,才允许进入 `priority` 比较阶段。 + +#### Scenario: higher-priority kernel with failing constraint does not win + +- **WHEN** 一个更高 `priority` 的 kernel 在 target/op/type 层面匹配成功,但 `constraints` 评估失败 +- **THEN** 该 kernel MUST 从候选集合中移除 +- **AND** selector MUST 继续在剩余候选中选择合法 kernel + +### Requirement: priority ties MUST raise an explicit selection error + +若在 target/op/type/constraint 全部通过后,最高 `priority` 仍对应多个候选,matcher MUST 报显式选择错误。 +系统 MUST NOT 依赖定义顺序、导入顺序或其他隐式规则做 tiebreak。 + +#### Scenario: equal-priority winners cause deterministic tie error + +- **WHEN** 多个 kernel 在 target/op/type/constraint 匹配后拥有相同的最高 `priority` +- **THEN** selector MUST 报错 +- **AND** 错误消息 MUST 指出发生 tie 的 kernel 集合 +- **AND** MUST NOT 静默选择第一个已注册 kernel From 11bc085105fb4285df5b289cdd263988f179cad5 Mon Sep 17 00:00:00 2001 From: qukelin Date: Tue, 7 Apr 2026 09:36:05 +0800 Subject: [PATCH 024/192] docs: add TileOp expand design and demo --- docs/designs/ptoas-tileop-expand-design.md | 972 +++------------------ 1 file changed, 140 insertions(+), 832 deletions(-) diff --git a/docs/designs/ptoas-tileop-expand-design.md b/docs/designs/ptoas-tileop-expand-design.md index 1ad4ad63d..f233d4a11 100644 --- a/docs/designs/ptoas-tileop-expand-design.md +++ b/docs/designs/ptoas-tileop-expand-design.md @@ -64,7 +64,8 @@ func.func @TADD( %b: !pto.tile_buf, %c: !pto.tile_buf) { + blayout=row_major, slayout=none_box, fractal=512, pad=0>) + attributes { pto.tile_function = "pto.tadd" } { %vecA = pto.tile_buf_addr %a : !pto.tile_buf -> memref<16x64xf32, strided<[64, 1]>, #pto.address_space> @@ -104,125 +105,79 @@ func.func @TADD( ### 2.1 总体思路 -为了降低开发门槛并解决参数组合的穷举问题,我们采用 **TileLang Python DSL** 来编写 -Tile Lib 的向量库实现。库开发者使用 Python 编写 vkernel 函数,PTOAS 编译器在编译时 -根据具体的 Tile op 以及操作数类型进行匹配、特化(specialization)和实例化(instantiation)。 - -TileLang DSL 的完整语法定义在 `tilelang-dsl/docs/tilelang-dsl-guide.md`,本章在该文档 -基础上,聚焦于本方案所依赖的语言子集及其语义约束。 +为了降低开发门槛并解决参数组合的穷举问题,我们采用 PTO DSL 来编写 Tile Lib 的向量库实现。这套语法定义在 TileLang 中,库开发者使用 Python 编写模板函数,由 PTOAS 编译器在编译时进行实例化。 整体方案: -1. **用 TileLang Python DSL 编写 vkernel**:以 `@pto.vkernel` 装饰器声明匹配元数据 - (`target` / `op` 或 `ops` / `dtypes` / `constraints` / `priority`),函数体使用 - `pto.Tile` 数据类型和基础向量指令(`make_mask` / `vlds` / `vsts` / `vadd` / …) - 按 Tile 指令语义编写向量实现。 -2. **编译器匹配并特化 vkernel**:PTOAS 遇到 Tile op 时,通过 DSL 提供的 - `pto.select_kernel(target, concrete_op, operand_types, …)` 匹配候选 vkernel,按 DSL Guide - §Kernel Selection Mechanism 的规则(target → op → dtypes → constraints → priority) - 选出一条,再以调用点的具体 `tile_buf` 类型作为 specialization key 进行特化,生成 - 以 `tile_buf` 为形参的向量实现函数。 -3. **inline 到调用点**:特化后的向量 IR 以 `func.call` 形式插入到原 Tile op 的位置, - 随后由 `PTOInlineLibCall` pass inline 到调用点,继续后续优化和 lowering 流程。 +1. **用 Python DSL 编写模板函数**:使用 `pto.Tile` 数据类型和向量操作接口,按 Tile 指令语义编写向量实现。 +2. **编译器实例化模板**:PTOAS 在编译过程中遇到 Tile op 时,调用对应的模板函数,填入具体的 `tile_buf` 类型参数,生成特化后的向量 IR。 +3. **inline 到调用点**:特化后的向量 IR 直接 inline 到原 Tile op 的位置,继续后续优化和 lowering 流程。 ### 2.2 TADD 模板示例 -以 `pto.tadd`(逐元素加法)为例,TileLang DSL 编写的 vkernel 如下(`PAT` 是 -`pto.MaskPattern` 的别名;算子名按 DSL Guide §Kernel Declaration 约定,不带 `pto.` 前缀): +以 `pto.tadd`(逐元素加法)为例,使用 Python DSL 编写的模板函数如下: ```python -from pto import MaskPattern as PAT - -@pto.vkernel( - target="a5", - op="tadd", # 匹配 pto.tadd - dtypes=[(pto.f32, pto.f32, pto.f32)], # 操作数类型签名 (src0, src1, dst) - advanced=True, # 启用隐式 vecscope 推断 -) -def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> None: - dtype = dst.element_type # 编译期静态 - valid_rows, valid_cols = dst.valid_shape # 静态或动态 - - for row in range(0, valid_rows, 1): - remained = valid_cols - for col in range(0, valid_cols, pto.get_lanes(dtype)): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - summed = pto.vadd(lhs, rhs, mask) - pto.vsts(summed, dst[row, col:], mask) - return None +@pto.tile_template(target="a5", op="pto.tadd") +def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = src0.element_type + elem_size = src0.element_size + rows, cols = src0.shape + v_rows, v_cols = src0.valid_shape + + for i in range(0, v_rows, 1): + remaining = v_cols + for j in range(0, v_cols, 256 / elem_size): + all_mask, remaining = pto.make_mask(dtype, remaining) + vec_a = pto.vlds(a[i, j]) + vec_b = pto.vlds(b[i, j]) + result = pto.vadd(vec_a, vec_b, all_mask) + pto.vsts(result, c[i, j], all_mask) ``` 代码解读: -- **`@pto.vkernel`** 装饰器声明本 kernel 匹配 `a5` 架构下的 `tadd` 算子、操作数签名 - `(f32, f32, f32)`。`advanced=True` 让编译器对函数体内的 `vlds`/`vadd`/`vsts` 序列 - 自动推断 `pto.vecscope`,无需显式 `with pto.vecscope():` 包裹 - (详见 DSL Guide §Implicit Scope Inference)。 -- **kernel 参数**为 3 个 `pto.Tile` 对象(2 个输入 `src0` / `src1`,1 个输出 `dst`), - 对应 VPTO IR 中的 `!pto.tile_buf` 类型,它们是实例化时被特化的 symbolic value。 - **参数顺序必须与 PTOAS 中对应指令的操作数顺序一致**(即 `ins` 在前、`outs` 在后), - 因为 `ExpandTileOp` 按位置索引直接传递操作数。 -- 通过 **`Tile` 属性接口**读取元素类型 `element_type` 和 `valid_shape`。参考 DSL Guide - §Tile Attributes:`shape` / `element_type` / `memory_space` / `config` 都是编译期静态 - 值,`valid_shape` 允许为静态或动态。 -- **2 层循环**分别遍历 tile 的行和列。外层步长 1,内层步长为 `pto.get_lanes(dtype)` - (单个向量寄存器可容纳的元素数,f32→64,f16→128)。 -- **`pto.make_mask(dtype, remained)`** 按 DSL Guide §Typed Masks 的 tail-processing 语义, - 返回 `(mask, new_remaining)`,并根据 `dtype` 自动选择正确的 mask 粒度 - (`f32` → `mask_b32`、`f16` → `mask_b16`、`i8` → `mask_b8`)。 -- **Tile 元素级索引语法糖** `src0[row, col:]` 实现向量宽度的 load/store - (DSL Guide §Address Generation Syntax Sugar):`col:` 后缀表示以 `col` 为起点、按 - 向量宽度连续读取;编译器按 `element_size` 和 layout 自动计算字节偏移,避免手写 - `i * cols * 4` 之类易错的算术。 -- **`pto.vadd(lhs, rhs, mask)`** 执行逐元素加法;**`pto.vsts(summed, dst[row, col:], mask)`** - 将结果带 mask 写回 `dst`。 +- **`@pto.tile_template`** 装饰器指示这是一个 `pto.tadd` 指令的模板,会在编译时进行实例化。 +- **输入参数**为 3 个 `pto.Tile` 数据类型参数,2 个输入(`src0`、`src1`),1 个输出(`dst`)。 +- 通过 **`Tile` 数据类型接口**获取元素类型(`element_type`)、元素大小(`element_size`)、静态 shape(`shape`)和 valid shape(`valid_shape`)信息。 +- 通过 **2 层循环**分别遍历 tile 的行和列。 +- 通过 **`pto.make_mask`** 指令,根据基础数据类型大小及有效数据数量设置 mask 寄存器。 +- 通过 **`pto.vlds`** 指令,以 `a[i, j]` 和 `b[i, j]` 为起始地址分别将数据读入向量寄存器。 +- 通过 **`pto.vadd`** 计算相加结果,写入寄存器 `result`。 +- 通过 **`pto.vsts`** 将 `result` 写入以 `c[i, j]` 为起始的地址区间。 ### 2.3 值模型与 Staging 语义 -TileLang DSL 按 DSL Guide §Value Model 的定义,采用 **symbolic value** 模型——函数体中的 -值并非 Python 运行时的 `int`/`float`,而是编译器构造的 SSA 值或编译期常量。在 vkernel -实例化过程中,`pto.Tile` 参数的属性按两种 stage 区分处理: +模板函数中使用的 `pto.Tile` 属性,在模板执行时分为两类不同阶段(stage)的值: #### 编译期静态值(Compile-time Static) -以下属性在 vkernel 实例化时已经确定,由 TileLang Codegen 在编译期折叠为字面量, -**不会**生成 MLIR SSA 值: +以下属性在模板实例化时已经确定,由 Python Codegen 在编译期折叠为字面量,**不会**生成 MLIR SSA 值: | 属性 | 来源 | 说明 | |------|------|------| -| `element_type` | `tile_buf` 的 `dtype` 字段 | 决定 vreg 类型和向量宽度;参与 specialization key | +| `element_type` | `tile_buf` 的 `dtype` 字段 | 决定 vreg 类型和向量宽度 | | `element_size` | 由 `dtype` 推导 | f32→4, f16→2, i8→1 | -| `shape` | `tile_buf` 的 `rows`, `cols` 字段 | **必须是编译期静态值**,参与 specialization key | -| `memory_space` | `tile_buf` 的 `loc` 字段 | `MemorySpace.GM` / `MemorySpace.UB`;参与 specialization key | -| `config` | `tile_buf` 的 blayout / slayout / fractal / pad | 决定 stride 模式和偏移计算方式 | +| `shape` | `tile_buf` 的 `rows`, `cols` 字段 | **必须是编译期静态值**,参与模板实例化的 specialization key | +| `config` | `tile_buf` 的 blayout/slayout/fractal/pad | 布局和配置信息 | -这些值在 Python 层直接参与运算(如 `pto.get_lanes(dtype)`、`rows * cols * element_size`), -结果在编译期确定。DSL Guide §Tile Types 明确规定 **Static Shape Requirement**: -`shape` 必须是 compile-time constant。 +这些值在 Python 层直接参与运算(如 `256 / elem_size`),结果在编译期确定。 #### 运行时 SSA 值(Runtime Dynamic) -以下属性可能在编译期未知,生成为实例化函数的参数或 SSA 值: +以下属性可能在编译期未知,生成为 MLIR 函数参数或 SSA 值: | 属性 | 来源 | 说明 | |------|------|------| -| `valid_shape` | `tile_buf` 的 `v_row`, `v_col` 字段 | **可以是静态也可以是动态**(DSL Guide §Tile Shape Concepts) | +| `valid_shape` | `tile_buf` 的 `v_row`, `v_col` 字段 | **可以是静态也可以是动态** | -当 `valid_shape` 为静态值时,TileLang Codegen 在编译期折叠(与 `shape` 相同处理方式); -当为动态值时,生成为实例化函数的 `index` 类型参数,循环边界等依赖它的地方生成 -`scf.for`。该参数在 PTOAS 侧由 `pto.bind_tile` 的 `valid_row` / `valid_col` -操作数承载(参见第三章)。 +当 `valid_shape` 为静态值时,Python Codegen 在编译期折叠(与 `shape` 相同处理方式);当为动态值时,生成为 MLIR 函数参数(`index` 类型),循环边界等依赖它的地方生成 `scf.for`。 #### 正式约束 -1. **`shape` 必须是编译期静态值**,并参与 specialization key。若 `shape` 为动态值, - vkernel 实例化应报错拒绝。 -2. **`valid_shape` 可以是静态也可以是动态**。当为静态值时,TileLang Codegen 应检查 - `valid_shape[i] ≤ shape[i]`(逐维度),对齐 DSL Guide §Tile Shape Concepts 的约束。 -3. **`element_type`、`element_size`、`memory_space`、`config` 必须是编译期静态值**, - 它们决定了函数体的结构(vreg 类型、向量宽度、stride 模式等)。 +1. **`shape` 必须是编译期静态值**,并参与模板实例化的 specialization key。如果 `shape` 为动态值,模板实例化应报错拒绝。 +2. **`valid_shape` 可以是静态也可以是动态**。当为静态值时,Python Codegen 侧应检查 `valid_shape <= shape`(逐维度)。 +3. **`element_type`、`element_size`、`config` 必须是编译期静态值**,它们决定了模板函数体的结构(vreg 类型、向量宽度、stride 模式等)。 #### 对控制流的影响 @@ -241,162 +196,101 @@ for i in range(0, rows, 1): # rows=16 静态 → Python 展开 16 次迭 ### 2.4 TileLang DSL 语法参考 -本节摘录本方案所依赖的 DSL 子集;完整定义见 `tilelang-dsl/docs/tilelang-dsl-guide.md`。 - -#### 2.4.1 基础标量类型 +#### 2.4.1 基础数据类型 | DSL 类型 | 说明 | 位宽 | |----------|------|------| -| `pto.i1` | 布尔 | 1 | -| `pto.i8` | 8 位整数 | 8 | +| `pto.i8` | 8 位整数 | 8 | | `pto.i16` | 16 位整数 | 16 | | `pto.i32` | 32 位整数 | 32 | | `pto.i64` | 64 位整数 | 64 | | `pto.f16` | 半精度浮点 | 16 | -| `pto.bf16`| Brain float 16 | 16 | +| `pto.bf16` | BFloat16 | 16 | | `pto.f32` | 单精度浮点 | 32 | -Python 字面量自动推导类型:`bool` → `pto.i1`,`int` → 上下文决定(通常 `pto.i32`/`pto.i64`), -`float` → `pto.f32`。需要显式类型时可用 `x = pto.i32(1024)` 或类型注解。 +Python 字面量自动推导类型:`int` → `pto.i32`,`float` → `pto.f32`。 -DSL 还提供类型通配符 `pto.AnyFloat` / `pto.AnyInt` / `pto.AnyType` / `pto.AnyMask` -和类型变量 `pto.TypeVar(...)`,用于在 `dtypes=` 中写多态签名。 +#### 2.4.2 Tile 数据类型 -#### 2.4.2 向量与 Mask 类型 +`pto.Tile` 表示一个带有布局和配置信息的数据块,对应 MLIR 中的 `!pto.tile_buf` 类型。 -向量寄存器固定 **256 字节** 宽度: - -```python -pto.vreg(64, pto.f32) # 64 lanes × 32 bit = 2048 bit -pto.vreg(128, pto.f16) # 128 lanes × 16 bit = 2048 bit -``` - -约束:`lanes × bitwidth(element_type) == 2048`。可用 `pto.get_lanes(dtype)` 获得 lane 数。 - -Mask 按位粒度分型(DSL Guide §Typed Masks),必须与 vreg 元素族匹配: - -| DSL 类型 | VPTO 类型 | 对应元素族 | -|----------|-----------|-----------| -| `pto.mask_b8` | `!pto.mask` | `i8` 向量 | -| `pto.mask_b16` | `!pto.mask` | `f16` / `bf16` / `i16` 向量 | -| `pto.mask_b32` | `!pto.mask` | `f32` / `i32` 向量 | - -粒度不匹配(例如 `f32` 向量配 `mask_b16`)会在类型检查阶段报错。 - -#### 2.4.3 Tile 数据类型 - -`pto.Tile` 表示一个带有布局和配置信息的数据块,对应 VPTO IR 中的 `!pto.tile_buf` 类型。 - -**Tile 属性接口**(DSL Guide §Tile Attributes): +**Tile 属性接口:** | 属性 | 类型 | 说明 | |------|------|------| -| `shape` | `tuple[int, ...]` | **编译期静态**的物理维度(rows, cols) | -| `valid_shape` | `tuple[int, ...]` | 有效数据维度(v_row, v_col),可为静态或动态,须 ≤ `shape` | -| `element_type` | `Type` | 元素类型,如 `pto.f32` | -| `element_size` | `int` | 元素字节大小 | -| `memory_space` | `MemorySpace` | `MemorySpace.GM` / `MemorySpace.UB` | -| `config` | `TileConfig` | 布局与 padding 配置 | -| `rank` / `num_elements` / `valid_elements` | `int` | 派生属性 | +| `shape` | `tuple[int, ...]` | Tile 的完整维度(rows, cols) | +| `valid_shape` | `tuple[int, ...]` | 有效数据维度(v_row, v_col),可能小于 shape | +| `element_type` | `Type` | 元素数据类型(如 `pto.f32`) | +| `element_size` | `int` | 元素字节大小(如 f32 → 4) | +| `memory_space` | `MemorySpace` | 内存空间(GM, UB) | +| `config` | `TileConfig` | 布局和 padding 配置 | -**Tile 配置枚举**: +**Tile 配置:** ```python -pto.BLayout.ROW_MAJOR / pto.BLayout.COL_MAJOR # 基础布局 -pto.SLayout.NONE_BOX / pto.SLayout.ROW_MAJOR / pto.SLayout.COL_MAJOR -pto.PadValue.NULL / pto.PadValue.ZERO / pto.PadValue.MAX / pto.PadValue.MIN +pto.BLayout.ROW_MAJOR # 行主序 +pto.BLayout.COL_MAJOR # 列主序 +pto.SLayout.NONE_BOX # 无二级布局 +pto.PadValue.NULL # 无 padding +pto.PadValue.ZERO # 零填充 ``` -**地址生成语法糖**(DSL Guide §Address Generation Syntax Sugar)——向量级读写使用 -元素索引语法,编译器自动按 layout 计算字节偏移: - -| 语法 | 含义 | -|------|------| -| `tile[row, col:]` | 行主序:从 `(row, col)` 起按向量宽度连续读 | -| `tile[row:, col]` | 列主序:从 `(row, col)` 起按向量宽度连续读 | -| `tile[start:]` | 1D tile:从 `start` 起按向量宽度连续读 | -| `tile[row, col]` | 单元素(仅 `pto.vsld` 等 broadcast load 使用) | - -#### 2.4.4 向量操作接口 +#### 2.4.3 向量操作接口 -本方案依赖 DSL Guide §Operations 中列在 **`stable`** tier 的 base vector ops: +向量寄存器固定 256 字节宽度,每次处理的元素数量由数据类型决定:f32 → 64 个元素,f16 → 128 个元素。 -**Mask 生成**(DSL Guide §`pto.make_mask`): +**Mask 操作:** -| 形式 | 说明 | +| 操作 | 说明 | |------|------| -| `pto.make_mask(dtype, remaining: pto.i32)` | Tail processing:返回 `(mask, new_remaining)` | -| `pto.make_mask(dtype, PAT.ALL)` | 固定 pattern:返回单值 `mask`。其它 pattern 包括 `PAT.EVEN`/`PAT.ODD` 等 | +| `pto.make_mask(dtype, remaining)` | 根据数据类型和剩余元素数量生成 mask,返回 `(mask, new_remaining)` | +| `pto.make_mask(dtype, PAT.ALL)` | 生成全 1 mask | -**向量 Load / Store**: +**向量 Load/Store:** | 操作 | 说明 | |------|------| -| `pto.vlds(tile[row, col:])` | 从 tile 的 `(row, col)` 按向量宽度加载到 vreg | -| `pto.vsts(vec, tile[row, col:], mask)` | 将 vreg 按 mask 写入 tile 的 `(row, col)` | +| `pto.vlds(tile[i, j])` | 从 Tile 的 `[i, j]` 位置加载一个向量寄存器的数据 | +| `pto.vsts(vec, tile[i, j], mask)` | 将向量寄存器数据写入 Tile 的 `[i, j]` 位置 | -上述两条也支持 DSL Guide 中的 byte-offset 形式 `pto.vlds(buf, offset)` / `pto.vsts(vec, buf, offset, mask)` -(Advanced Tier),但模板库优先使用元素索引语法。 - -**基础二元/一元算子**(用于常见 Tile op 的展开): +**二元向量运算:** | 操作 | 说明 | |------|------| -| `pto.vadd / vsub / vmul / vdiv(vec1, vec2, mask)` | 逐元素二元运算 | -| `pto.vmax / vmin(vec1, vec2, mask)` | 逐元素比较 | -| `pto.vabs / vexp / vln / vsqrt / vrelu(vec, mask)` | 逐元素一元运算 | -| `pto.vmuls / vadds(vec, scalar, mask)` | 向量-标量运算 | - -#### 2.4.5 控制流 +| `pto.vadd(vec1, vec2, mask)` | 逐元素加法 | +| `pto.vsub(vec1, vec2, mask)` | 逐元素减法 | +| `pto.vmul(vec1, vec2, mask)` | 逐元素乘法 | +| `pto.vdiv(vec1, vec2, mask)` | 逐元素除法 | +| `pto.vmax(vec1, vec2, mask)` | 逐元素取大 | +| `pto.vmin(vec1, vec2, mask)` | 逐元素取小 | -**循环**使用 Python 的 `range` 语法: +**一元向量运算:** -```python -for i in range(0, valid_rows, 1): - for j in range(0, valid_cols, pto.get_lanes(dtype)): - ... -``` +| 操作 | 说明 | +|------|------| +| `pto.vabs(vec, mask)` | 逐元素绝对值 | +| `pto.vexp(vec, mask)` | 逐元素指数 | +| `pto.vln(vec, mask)` | 逐元素对数 | +| `pto.vsqrt(vec, mask)` | 逐元素开方 | +| `pto.vrelu(vec, mask)` | 逐元素 ReLU | -当循环边界来自 `shape`(编译期常量)时,DSL 在 Python 层直接展开循环;当来自 -`valid_shape`(可能是动态值)时,生成 `scf.for` MLIR 循环。 +**向量-标量运算:** -**向量作用域**:本方案的 vkernel 统一使用 `advanced=True`,由编译器的 Scope Inference Pass -对连续、数据依赖的 `vlds`/`vadd`/`vsts` 序列自动推断 `pto.vecscope` 边界,库开发者无需 -显式书写 `with pto.vecscope(): ...`。需要精确控制时可使用 `strict_vecscope`(Advanced Tier)。 +| 操作 | 说明 | +|------|------| +| `pto.vmuls(vec, scalar, mask)` | 向量乘标量 | +| `pto.vadds(vec, scalar, mask)` | 向量加标量 | -#### 2.4.6 多算子模板(template slots) +#### 2.4.4 控制流 -对于计算骨架相同、仅核心算子不同的一组 Tile op(如 `tadd`/`tsub`/`tmul`/`tdiv`), -可用 DSL Guide §Template-based Kernel Authoring 的 `ops=[...]` + `templates=` + `pto.tpl(...)` -在一个 vkernel 中共享实现: +**循环**使用 Python 的 `range` 语法: ```python -@pto.vkernel( - target="a5", - ops=["tadd", "tsub", "tmul", "tdiv"], - dtypes=[(T, T, T)], - advanced=True, - templates={ - "core": {"tadd": "vadd", "tsub": "vsub", - "tmul": "vmul", "tdiv": "vdiv"}, - }, -) -def elementwise_arithmetic(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - dtype = dst.element_type - rows, cols = dst.valid_shape - for row in range(0, rows, 1): - remained = cols - for col in range(0, cols, pto.get_lanes(dtype)): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - out = pto.tpl("core", lhs, rhs, mask) # 按选中的具体算子替换 - pto.vsts(out, dst[row, col:], mask) +for i in range(0, v_rows, 1): + # 循环体 ``` -编译期 `pto.select_kernel(...)` 会把具体的 `tadd`/`tsub`/… 绑定到 `selected_op`, -`pto.tpl("core", ...)` 再按 `templates["core"]` 的映射展开为真正的 `vadd`/`vsub`/… 调用。 -这样本方案的 Tile Lib 可以用一份模板覆盖四条逐元素算子,显著收敛维护成本。 +当循环边界来自 `shape`(编译期常量)时,DSL 在 Python 层展开循环;当来自 `valid_shape`(可能是运行时动态值)时,生成 `scf.for` MLIR 循环。 ## 第三章 PTOAS 编译器:TileOp Expand @@ -417,7 +311,7 @@ PTOAS 编译器的输入可以是 Tile 指令、向量指令、或两者的混 ↓ Inline ← 将模板函数体 inline 到调用点 ↓ - Fold TileBuf Intrinsics ← 折叠 tile_buf / tensor_view intrinsic,解析到具体值 + Fold TileBuf Intrinsics ← 折叠 tile_buf_addr / tile_valid_rows / tile_valid_cols ↓ VF Fusion ← 合并相邻向量循环,消除中间 UB 读写 ↓ @@ -428,20 +322,20 @@ Tile 指令到向量指令的展开由三个 pass 协作完成: 1. **Expand TileOp**:核心 pass。调用 TileLang Python DSL 实例化模板库,生成以 `tile_buf` 为参数的向量实现函数,将原 Tile op 替换为对该函数的 `func.call`。 2. **Inline**:将模板函数体 inline 到调用点,使模板函数的 `tile_buf` 形参与调用点的实际 `tile_buf` 值绑定。 -3. **Fold TileBuf Intrinsics**:折叠 inline 后留下的 tile_buf 系列(`pto.tile_buf_addr`、`pto.tile_valid_rows`、`pto.tile_valid_cols`)和 tensor_view 系列(`pto.tensor_view_addr`、`pto.get_tensor_view_dim`、`pto.get_tensor_view_stride`)intrinsic,将 `tile_buf` / `partition_tensor_view` 的属性折叠为具体的 memref、常量和 SSA 值。 +3. **Fold TileBuf Intrinsics**:折叠 inline 后留下的 `pto.tile_buf_addr`、`pto.tile_valid_rows`、`pto.tile_valid_cols` 等 intrinsic,将 `tile_buf` 的静态属性(地址、shape、布局)折叠为具体的 memref 和常量。 ### 3.2 Expand TileOp Pass 的工作流程 以编译时遇到 `pto.tadd` 为例,Expand TileOp pass 的处理步骤如下: ``` -Step 1: 识别 Tile Op 并分类操作数 -──────────────────────────────── - 遍历函数体中所有 Tile op(pto.tadd, pto.tload, ...) - 每个操作数按 IR 类型分为三类: - Tile — TileBufType(如 pto.tadd 的输入/输出 tile_buf) - View — MemRefType(如 pto.tload 的 src,由 PTOViewToMemref 降级的 partition_tensor_view) - Scalar — 标量类型(如 pto.tadds 的 scalar 操作数) +Step 1: 识别 Tile Op +─────────────────── + 遍历函数体中所有 Tile op(pto.tadd, pto.tsub, ...) + 遇到 pto.tadd ins(%a, %b) outs(%c) + 从所有操作数的 tile_buf 类型提取属性: + dtype=f32, rows=16, cols=64, v_row=16, v_col=64, + blayout=row_major, slayout=none_box, fractal=512, pad=0 Step 2: 构造 Specialization Key + 查询缓存 ────────────────────────────────────────── @@ -451,81 +345,66 @@ Step 2: 构造 Specialization Key + 查询缓存 Step 3: 实例化模板(缓存未命中时执行) ───────────────────────────────────── - 调用 TileLang Python DSL,传入 op 名称和各操作数的类型信息 - Python DSL 查找匹配的 @vkernel 模板,填入具体参数进行特化 - 输出实例化后的 MLIR 函数,解析文本,克隆到目标 Module,写入缓存 + 调用 TileLang Python DSL,传入 op 名称和各操作数的 tile_buf 类型信息 + Python DSL 查找匹配的 @vkernel 模板,填入具体 tile_buf 参数进行特化 + 输出实例化后的 MLIR 函数(以 tile_buf 为参数,内含向量循环体) + 解析 MLIR 文本,克隆函数到目标 Module,写入缓存 Step 4: 生成调用并替换原 Tile Op ─────────────────────────────── - 在原 Tile op 位置插入 func.call @__pto_tilelang_...(%a, %b, %c) - - Tile 操作数:类型一致,直接传递 - - View 操作数:调用方类型为 memref,模板参数类型为 partition_tensor_view, - 插入 builtin.unrealized_conversion_cast 桥接(由后续 FoldTileBufIntrinsics 消除) - - Scalar 操作数:直接传递 + 在原 Tile op 位置插入 func.call @__pto_tilelang_tadd_f32_16_64(%a, %b, %c) + 操作数直接传递(类型均为 tile_buf,无需桥接转换) 删除原 Tile op ``` #### 3.2.1 Specialization Key 与缓存 -模板展开本质上是一个特化过程。当同一个 module 中存在多个相同类型的 Tile op(如多处 `pto.tadd` 且所有操作数类型完全相同),应复用已实例化的结果而非重复展开。 +模板展开本质上是一个特化过程。当同一个 module 中存在多个相同类型的 Tile op(如多处 `pto.tadd` 且所有 `tile_buf` 操作数类型完全相同),应复用已实例化的结果而非重复展开。 -**重要**:SpecKey 必须基于 **所有操作数** 的类型构建,而不仅仅是第一个操作数。因为同一个 op 的不同操作数可能有不同的类型(如不同的 dtype 或 shape),仅用第一个操作数无法区分这些情况。 +**重要**:SpecKey 必须基于 **所有操作数** 的 `tile_buf` 类型构建,而不仅仅是第一个操作数。因为同一个 op 的不同操作数可能有不同的类型(如不同的 dtype 或 shape),仅用第一个操作数无法区分这些情况。 -操作数按 IR 类型分为三类,每类参与 SpecKey 的字段不同: +Expand TileOp pass 维护一个实例化缓存,key 包含以下字段: -| 操作数类型 | IR 类型 | 参与 SpecKey 的字段 | 不参与 SpecKey 但传给 Python DSL 的字段 | -|-----------|---------|--------------------|-----------------------------------------| -| **Tile** | `TileBufType` | `dtype` + `shape` + `valid_shape` + `memorySpace` + `config`(blayout/slayout/fractal/pad) | — | -| **View** | `MemRefType`(降级后的 `PartitionTensorViewType`) | `dtype` | `shape`、`strides`、`memorySpace`(仅用于约束检查) | -| **Scalar** | 标量类型 | `dtype` | — | +| Key 字段 | 说明 | +|----------|------| +| `op_name` | Tile op 名称(如 `tadd`) | +| `operand_types` | **所有操作数**的 tile_buf 类型签名,每个操作数包含以下信息 | +| ├─ `dtype` | 元素数据类型(如 `f32`) | +| ├─ `shape` | Tile 的静态 shape(如 `(16, 64)`) | +| └─ `config` | blayout、slayout、fractal、pad 等配置 | -**View 操作数的特化策略**:View 对应的模板参数类型为 `!pto.partition_tensor_view`,维度全部动态,shape/strides 通过 intrinsic 在运行时查询。因此不同 view shape 的 Tile op 可以共享同一份模板实例——`shape`/`strides`/`memorySpace` 不参与 SpecKey 的判等和 hash。这些字段通过 `--operand-specs` JSON 传给 Python DSL 的 `expand_helper`,先按操作数位置构造成 `arg0_*`、`arg1_*` 一类的位置化上下文,再在 constraint evaluation 阶段按模板参数顺序映射到当前参数名(如 `src` / `dst`)后参与约束检查;它们不直接影响模板代码生成。 - -**Tile 操作数的特化策略**:当前实现中,`valid_shape` 参与 SpecKey,并与 `shape`、`memorySpace`、`config` 一起决定模板实例和缓存 key。也就是说,相同 `(op, operand_types)` 但不同 `valid_shape` 的 Tile op 当前会生成不同的实例化结果。约束检查和缓存命名都基于这一实现语义。 +`valid_shape` **不参与** key——因为它可能是动态的,作为运行时值在 inline 后通过 `pto.tile_valid_rows`/`pto.tile_valid_cols` 提取。相同 `(op, operand_types)` 但不同 `valid_shape` 的 Tile op 可以共享同一份实例化结果。 #### 3.2.2 模板实例化过程 Expand TileOp 通过调用 Python 子进程来实例化模板。具体流程: -1. **调用 Python helper**:`python3 -m tilelang_dsl.expand_helper --target --op pto. --operand-specs `,其中 JSON 描述每个操作数的类型信息。 +1. **调用 Python helper**:`python3 -m tilelang_dsl.expand_helper`,传入 op 名称、各操作数的 dtype/shape/memory_space 等参数。 2. **Python 端处理**: - 扫描模板目录下的 `.py` 文件,查找标注了 `@pto.vkernel` 装饰器的模板函数 - - 先按操作数个数和参数种类(`tile` / `view` / `scalar`)做 schema 预过滤 - - 基于 `operand_specs` 构造按位置组织的上下文属性(如 `arg0_shape`、`arg0_strides`、`arg1_config`) - - 调用 `pto.select_kernel(target, concrete_op, operand_types, context_attrs, registry)` 按 `target → op → dtypes → constraints → priority` 规则选择模板 - - 对 `pto.Tile` 参数使用给定的 shape / valid_shape / memory_space / config 进行特化 - - 对 `pto.PartitionTensorView` 参数,不做 `specialize()`,而是通过位置化上下文把 shape/strides/memorySpace 提供给前置条件检查(参数类型保持全动态) + - 按 `op` 名称和 `dtype` 签名匹配模板 + - 对所有 `pto.Tile` 参数使用给定的 shape 和 memory_space 进行特化 - 输出特化后的 MLIR 文本 3. **C++ 端处理**: - 解析 MLIR 文本为 `ModuleOp` - 提取 `func.func`,克隆到目标 Module 末尾 - - 重命名为 `__pto_tilelang___tile____view__...`(Tile 操作数拼 shape/valid_shape/config,View/Scalar 只拼 dtype),设为 `private` 可见性 - - 按 `target + op + operand schema` 存入 specCache + - 重命名为 `__pto_tilelang____`(如 `__pto_tilelang_tadd_f32_16_64`),设为 `private` 可见性 + - 存入 specCache **关键约束**:Python DSL 实例化输出的函数需要满足以下要求: -1. **参数类型**可以是 `!pto.tile_buf`、`!pto.partition_tensor_view` 或标量类型。DSL 在实例化时将 Tile 参数的元素类型、静态 shape、布局配置等信息编码进 `tile_buf` 类型;View 参数保持全动态维度(`!pto.partition_tensor_view`)。 -2. **函数必须带有 `pto.tilelang.instance` 属性**(UnitAttr)。Inline pass 通过此属性识别需要内联的模板实例函数。 - -函数体内部通过以下 intrinsic 提取信息: +1. **参数类型为 `!pto.tile_buf`**,而非 memref。DSL 在实例化时将具体的元素类型、静态 shape、布局配置等信息编码进 `tile_buf` 类型参数。 +2. **函数必须带有 `pto.tilelang.instance` 属性**(UnitAttr)。Inline pass 通过此属性识别需要内联的模板实例函数,而非依赖函数名前缀。 -**tile_buf 系列**(从 `!pto.tile_buf` 提取): +函数体内部通过以下 intrinsic 从 `tile_buf` 中提取信息: | Intrinsic | 功能 | 输出类型 | |-----------|------|----------| -| `pto.tile_buf_addr` | 提取数据区域的 memref 指针 | `memref, #pto.address_space<...>>` | -| `pto.tile_valid_rows` | 提取有效行数 | `index` | -| `pto.tile_valid_cols` | 提取有效列数 | `index` | - -**tensor_view 系列**(从 `!pto.partition_tensor_view` 提取): +| `pto.tile_buf_addr` | 从 tile_buf 提取数据区域的 memref 指针 | `memref, #pto.address_space<...>>` | +| `pto.tile_valid_rows` | 从 tile_buf 提取有效行数 | `index` | +| `pto.tile_valid_cols` | 从 tile_buf 提取有效列数 | `index` | -| Intrinsic | 功能 | 输出类型 | -|-----------|------|----------| -| `pto.tensor_view_addr` | 提取 memref/ptr 基地址 | `memref<...>` 或 `!pto.ptr<...>` | -| `pto.get_tensor_view_dim` | 按维度索引提取 shape 大小 | `index` | -| `pto.get_tensor_view_stride` | 按维度索引提取 stride | `index` | - -对于 Tile 操作数,Expand TileOp 直接将 `tile_buf` 透传。对于 View 操作数,调用方类型为 `memref`,模板参数类型为 `!pto.partition_tensor_view`,因此 Expand TileOp 在调用点插入 `builtin.unrealized_conversion_cast` 桥接。类型转换和 intrinsic 折叠统一在后续的 Fold pass 中处理。 +这样设计的好处是:Expand TileOp pass 的调用点不需要做任何类型桥接,直接将 `tile_buf` 操作数透传给实例化的函数。类型转换和属性提取的工作统一在后续的 Fold pass 中处理。 ### 3.3 实例化模板函数的 IR 结构 @@ -665,83 +544,11 @@ func.func @TADD(%a: !pto.tile_buf<...>, %b: !pto.tile_buf<...>, %c: !pto.tile_bu #### 3.4.4 经过 Fold TileBuf Intrinsics 后 -Fold pass 处理两族 intrinsic,通过严格的模式匹配将它们解析回调用点的具体 SSA 值。 - -##### tile_buf 系列折叠 - -每一个被折叠的 tile_buf intrinsic,其 `tile_buf` 操作数必须由如下固定链定义 -(由 `MemrefToTileBuf` pass 保证),否则 pass 直接报错并失败: - -```mlir -%0 = pto.pointer_cast(%addr) {config = ...} - : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> -%1 = pto.bind_tile %0, %v_row, %v_col {config = ...} - : memref<16x64xf32, strided<[64, 1]>, ...> - -> memref<16x64xf32, strided<[64, 1], offset: ?>, ...> -%2 = builtin.unrealized_conversion_cast %1 - : memref<...> to !pto.tile_buf -``` - -也即:`tile_buf ← unrealized_conversion_cast ← pto.bind_tile ← pto.pointer_cast`。 - -**三条折叠规则**(均锚定到 `pto.bind_tile`): - -- `pto.tile_buf_addr %a` → 折叠为 `bind_tile` 的 **第一个操作数**(即 `pto.pointer_cast` 的结果)。 - 注意这里**绕过**了 `bind_tile` 自身产出的、带 `offset: ?` 的动态布局 memref, - 直接复用上游的 `strided<[64, 1]>` 静态布局 memref。这样下游的 `pto.vlds`/`pto.vsts` - 在被规范化、最终下沉到 VPTO 后端时,看到的始终是干净的 `strided<[..], offset: 0>` 布局, - 避免了 `pto.vlds does not support dynamic memref layout offsets` 这类下游错误。 - 若 `tile_buf_addr` 声明的结果类型与 `bind_tile` 源 memref 的实际布局不一致, - 会就地把结果类型替换为源 memref 的真实类型——下游向量算子对相同 element type / shape - 的 strided 布局是多态的。 -- `pto.tile_valid_rows %a` → 优先按 `TileBufType.validShape[0]` 静态折叠: - 若是静态值(如 `v_row=16`),折叠为 `arith.constant 16 : index`; - 若是动态值(`v_row=?`),折叠为 `bind_tile` 的 **第二个操作数**(`valid_row`,已经是 `index` 类型)。 -- `pto.tile_valid_cols %a` → 同理,使用 `validShape[1]` 或 `bind_tile` 的 **第三个操作数**。 +Fold pass 将 `pto.tile_buf_addr`、`pto.tile_valid_rows`、`pto.tile_valid_cols` 替换为具体值: -##### tensor_view 系列折叠 - -每一个被折叠的 tensor_view intrinsic,其 `partition_tensor_view` 操作数必须由如下固定链定义 -(由 `ExpandTileOp` 和 `PTOViewToMemref` pass 保证),否则 pass 直接报错并失败: - -```mlir -%rc = memref.reinterpret_cast %arg0 - to offset: [0], sizes: [%c1, %c1, %c1, %c16, %c64], - strides: [%c1024, %c1024, %c1024, %c64, %c1] - : memref → memref, gm> - -%sv = memref.subview %rc [0,0,0,0,0] [1,1,1,16,64] [1,1,1,1,1] - : → memref<1x1x1x16x64xf32, strided<[?,?,?,?,?], offset:?>, gm> - -%tv = builtin.unrealized_conversion_cast %sv - : memref<...> → !pto.partition_tensor_view<...> -``` - -也即:`partition_tensor_view ← unrealized_conversion_cast ← memref.subview ← memref.reinterpret_cast`。 - -pass 贯穿整条链,**一步到位**折叠到最终结果,不生成中间的 `memref.dim`、`memref.extract_strided_metadata` 或 `pto.castptr %subview`: - -- `pto.get_tensor_view_dim %tv, %cN` → - - subview 结果类型 shape[N] 是静态的:折叠为 `arith.constant`(如 dim 3 → `arith.constant 16`) - - shape[N] 是动态的:取 subview 的 `getMixedSizes()[N]`(可能追溯到 reinterpret_cast 的 size operand) - -- `pto.get_tensor_view_stride %tv, %cN` → - 直接取 reinterpret_cast 的 stride operand(通过 `getMixedStrides()[N]`)。 - 若 subview 的 stride[N] 不为 1,则生成 `arith.muli(rc_stride, sv_stride)`。 - reinterpret_cast 的 stride 可以是静态属性(生成 `arith.constant`)或动态 SSA 值(直接复用)。 - -- `pto.tensor_view_addr %tv` → - - subview 和 reinterpret_cast 的 offset 均为 0:折叠为 `pto.castptr %arg0`(直接用 base memref) - - 有非零 offset:折叠为 `pto.addptr(pto.castptr %arg0, linear_offset)`, - 其中 `linear_offset = rc_offset + sum(sv_offset[i] * rc_stride[i])` - -##### 通用规则 - -**跳过 TileLang 模板实例**:被 `PTOInlineLibCall` 内联完且作为 dead callee 删除之前, -带 `pto.tilelang.instance` 属性的私有模板函数仍可能保留在 module 中。这些函数体内的 -`pto.tile_buf_addr` 等 intrinsic 直接作用在 `tile_buf` 类型的 BlockArgument 上, -没有 `bind_tile` 可供折叠——pass 通过检测 `pto.tilelang.instance` 属性跳过这些函数, -留给下游 DCE 清理。 +- `pto.tile_buf_addr %a` → 折叠为调用点已知的 memref 值(从 tile_buf 提取底层地址) +- `pto.tile_valid_rows %a` → 如果 `v_row=16` 是静态的,折叠为 `arith.constant 16 : index`;如果是动态的(`v_row=?`),折叠为调用点传入的动态 index 值 +- `pto.tile_valid_cols %a` → 同理 折叠后得到最终的纯向量 IR,不再包含任何 tile_buf 引用: @@ -811,7 +618,7 @@ lib/TileOp/ ← 模板库根目录 advanced=True, # 启用隐式 vecscope 推断 name="template_", ) -def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): +def template_xxx(dst: pto.Tile, src0: pto.Tile, ...): # 向量化实现体 dtype = dst.element_type valid_rows, valid_cols = dst.valid_shape @@ -823,12 +630,7 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): return None ``` -**关键约束:模板参数顺序必须与 PTOAS 中对应指令的操作数顺序严格一致。** -`ExpandTileOp` 按位置索引将指令操作数直接传递给模板函数参数。对于 DPS 风格的 -算子,这意味着 `ins` 操作数在前、`outs` 在后。例如 `pto.tadd ins(%a, %b) outs(%c)` -的操作数顺序为 `(src0, src1, dst)`,模板参数必须为 `(src0, src1, dst)`。 - -`expand_helper.py` 自动扫描目录下所有 `.py` 文件,先按参数 schema 过滤候选,再通过 `select_kernel()` 按 `target`、`op`、`dtype`、`constraints` 和 `priority` 选择模板。模板约束读取的位置化上下文由 `argN_*` 键提供,并在 constraint evaluation 阶段按参数顺序映射到模板自己的参数名。 +`expand_helper.py` 自动扫描目录下所有 `.py` 文件,按 `op` 名称和 `dtype` 签名匹配模板。 ## 第四章 前置工作 @@ -851,503 +653,9 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): | MLIR 解析与 inline | 解析生成的 MLIR 文本,inline 到调用点,绑定参数 | | Cleanup | 实例化后运行 canonicalize 清理冗余 | -### 4.3 PTOAS 编译器:Fold TileBuf Intrinsics Pass - -**tile_buf 系列**: - -| 工作项 | 说明 | -|--------|------| -| 严格模式匹配 | 要求 `tile_buf` 由 `unrealized_conversion_cast ← pto.bind_tile` 链定义,否则 emit error 并 fail pass | -| `tile_buf_addr` 折叠 | 替换为 `bind_tile.getSource()`(即 `pto.pointer_cast` 的静态布局 memref),绕过 `bind_tile` 产出的动态 offset 布局 | -| 结果类型自适应 | 若 `tile_buf_addr` 声明类型与 source memref 实际布局不一致,就地更新结果类型 | -| `tile_valid_rows/cols` 折叠 | 优先按 `TileBufType.validShape` 静态折叠为 `arith.constant`;动态时取 `bind_tile` 的 `valid_row`/`valid_col` 操作数 | - -**tensor_view 系列**: - -| 工作项 | 说明 | -|--------|------| -| 严格模式匹配 | 要求 `partition_tensor_view` 由 `unrealized_conversion_cast ← memref.subview ← memref.reinterpret_cast` 链定义,否则 emit error 并 fail pass | -| `tensor_view_addr` 折叠 | 贯穿 subview → reinterpret_cast 链,折叠为 `pto.castptr %base_memref`;有非零 offset 时生成 `pto.addptr` | -| `get_tensor_view_dim` 折叠 | 静态 shape 维度折叠为 `arith.constant`;动态维度取 subview 的 `getMixedSizes()` operand | -| `get_tensor_view_stride` 折叠 | 直接取 reinterpret_cast 的 stride operand(`getMixedStrides()`),乘以 subview stride(通常为 1 可短路) | -| Dead op 清理 | 折叠完成后清理无 user 的 `unrealized_conversion_cast`、`memref.subview`、`memref.reinterpret_cast` | - -**通用**: - -| 工作项 | 说明 | -|--------|------| -| 跳过模板实例 | 检测 `pto.tilelang.instance` 属性,跳过 `PTOInlineLibCall` 删除前残留的私有模板函数 | - -### 4.4 测试与文档 +### 4.3 测试与文档 - Python DSL 模板编写和实例化的单元测试 - 以当前 `lib/TileOps/tadd_template.py` 为例,新增/维护 - `test/basic/expand_tile_op_tilelang.pto` - 作为 `pto.tadd` TileLang 模板实例化的基础回归。该用例覆盖: - 1. `ExpandTileOp` 是否能匹配 `pto.tadd` 并调用 Python DSL helper; - 2. 模板实例化后的 `func.call` 是否能被 inline; - 3. `FoldTileBufIntrinsics` 之后是否得到 `pto.vlds` / `pto.vadd` / `pto.vsts` 形式的 Vector IR。 - - 当前 `pto.tadd` 的向量库模板实现如下: - - ```python - import sys - from pathlib import Path - import tilelang_dsl as pto - - - @pto.vkernel( - target="a5", - op="pto.tadd" - ) - def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - dtype = dst.element_type - valid_rows, valid_cols = dst.valid_shape - - for row in range(0, valid_rows, 1): - remained = valid_cols - for col in range(0, valid_cols, pto.get_lanes(dtype)): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - summed = pto.vadd(lhs, rhs, mask) - pto.vsts(summed, dst[row, col:], mask) - return - ``` - - 对应的单元测试用例如下: - - ```mlir - // Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline - // expands pto.tadd via the default TileLang Python DSL template - // lib/TileOps/tadd_template.py. - // - // Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics - // - // RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s - - // After the full tile-op-expand path on the VPTO backend, the original - // pto.tadd should be lowered to vector-style VPTO IR. - // CHECK: func.func @TADD - // CHECK-NOT: pto.tadd ins - // CHECK: pto.vecscope - // CHECK: pto.castptr - // CHECK: pto.vlds - // CHECK: pto.vadd - // CHECK: pto.vsts - - module { - func.func @TADD() { - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tile_buf = pto.alloc_tile - : !pto.tile_buf - - pto.tadd ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%tile_buf : !pto.tile_buf) - return - } - } - ``` - Expand TileOp pass 的端到端测试(`pto.tadd` → Vector IR) - 使用以下命令生成最终 LLVM IR,并继续交给 Bisheng 做设备侧编译校验: - - ```bash - ./build/tools/ptoas/ptoas test/basic/expand_tile_op_tilelang.pto \ - --pto-arch=a5 \ - --pto-backend=vpto \ - --enable-tile-op-expand \ - --vpto-emit-hivm-llvm \ - -o - \ - > add.ll - ``` - - 说明: - - `stdout` 中的最终产物是 textual LLVM IR,因此这里使用 `-o - > add.ll` 显式落盘。 - - 随后将生成的 `add.ll` 交给 Bisheng: - - ```bash - bisheng \ - --target=hiipu64-hisilicon-cce \ - -march=dav-c310-vec \ - --cce-aicore-arch=dav-c310-vec \ - --cce-aicore-only \ - -c -x ir add.ll \ - -o add.o - ``` - - 若上述命令成功生成 `add.o`,则说明当前 `pto.tadd` 的向量库模板已经完成: - - TileLang 模板实例化; - - `pto.tadd -> Vector IR -> LLVM IR` 的端到端 lowering; - - Bisheng 设备侧编译校验。 - 融合场景测试(多个 Tile op 连续使用后的 VF Fusion) - 更新 `PTO_IR_manual.md` 和 TileLang DSL Guide - -#### 4.4.1 ST 精度验证 - -IR 回归测试只能验证"模板展开后 IR 长什么样",无法回答"最终在 simulator / NPU 上跑出来的数值是否正确"。 -`test/tilelang_st` 框架提供了端到端精度验证能力,详细设计参见 [`tilelang-st-framework.md`](tilelang-st-framework.md)。 - -本节面向库开发者,说明在完成一个新 TileLang 库实现(如 `lib/TileOps/_template.py`)后,如何接入 ST 框架验证精度。 - -##### 完整执行链路概览 - -ST 框架的统一入口是: - -```bash -python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -``` - -它不是只做“编译 `.pto`”,而是把编译、生成输入、运行二进制和精度比较串成一条完整流水线: - -```text -run_st.py - ├─ set_env_variables() - │ └─ 配置 simulator / NPU 运行环境 - ├─ build_project() - │ ├─ cmake -DRUN_MODE=... -DSOC_VERSION=... -DTEST_CASE=... -DPTOAS_BIN=... - │ ├─ ptoas: .pto -> _kernel.ll - │ │ flags: - │ │ --pto-arch=a5 - │ │ --pto-backend=vpto - │ │ --enable-insert-sync - │ │ --enable-tile-op-expand - │ │ --vpto-emit-hivm-llvm - │ ├─ bisheng -x ir: _kernel.ll -> _kernel_device.o - │ ├─ repack_tilelang_kernel.sh: - │ │ _kernel_device.o -> _kernel_repack.o - │ ├─ bisheng -xcce: launch.cpp + _kernel_repack.o -> lib_kernel.so - │ └─ bisheng -xc++: main.cpp -> - ├─ run_gen_data() - │ └─ 在 build/testcase// 下生成每个 case 的 input/golden - ├─ run_binary() - │ └─ 在 build/testcase// 下执行 ../../bin/ [case] - └─ run_compare() - └─ 在 build/testcase// 下逐 case 比较 golden/output -``` - -其中编译子链可以单独理解为: - -```text -.pto - ──ptoas──> _kernel.ll (LLVM IR) - ──bisheng -x ir──> _kernel_device.o (device-only 对象) - ──repack_tilelang_kernel.sh──> _kernel_repack.o - (host-linkable fatobj) - ──bisheng -xcce launch.cpp + repack.o──> lib_kernel.so - (共享库) - ──bisheng -xc++ main.cpp + .so──> (host 可执行文件) -``` - -其中 repack 步骤是 TileLang ST 与 pto-isa ST 的核心区别:`ptoas + bisheng -x ir` 产出的 -`*_kernel_device.o` 是 device-only 对象,不能直接作为 host 侧链接输入。repack 脚本从 -`launch.cpp` 中抽取 kernel 声明生成 stub,通过 `-fcce-include-aibinary` 嵌入 device binary, -产出 host 可链接的 fatobj。 - -运行阶段同样是 ST 框架的一部分,而不是“编译完以后开发者手工处理”的额外步骤: - -- `gen_data.py` 会基于 `cases.py` 中的 `CASES` 为每个 case 生成 `input*.bin` 和 `golden.bin` -- host 可执行文件会按 `main.cpp` 中的 case table 逐个读取 `.//input*.bin`,运行 kernel,并写回 `.//output.bin` -- `compare.py` 再基于同一份 `CASES` 定义逐 case 读取并裁剪需要比较的数据,最后调用公共 `result_cmp()` -- 若传入 `-c `,则运行和比较都只针对单个 case - -因此,TileLang ST 的验证对象不是“某一份中间 IR 是否长得对”,而是: - -1. TileLang 模板是否成功展开并编译到可执行产物; -2. 生成的数据、运行时读取的 case 目录、以及 compare 使用的 golden/output 是否保持一致; -3. 最终 simulator / NPU 上的数值结果是否正确。 - -编译子链由 `testcase/CMakeLists.txt` 中的 `pto_tilelang_vec_st()` 宏自动接管,整条执行链路则由 -`run_st.py` 统一调度。 - -##### 新增 testcase 所需文件(七个文件 + 一个注册修改) - -以新增 `pto.tsub` 为例,需在 `test/tilelang_st/npu/a5/src/st/testcase/tsub/` 下准备 -7 个文件,并修改 1 个注册文件: - -**1. `CMakeLists.txt`** — 通常只有一行: - -```cmake -pto_tilelang_vec_st(tsub) -``` - -宏自动查找同目录下的 `tsub.pto`、`launch.cpp`、`main.cpp`,串联上述五步编译。 - -**2. `cases.py`** — **case 定义的单一来源**,`gen_data.py` 和 `compare.py` 均从此导入: - -```python -import numpy as np - -CASES = [ - { - "name": "f32_16x64", - "dtype": np.float32, - "shape": (16, 64), - "valid_shape": (16, 64), - "eps": 1e-6, - }, -] -``` - -常规 case 必须包含 `name`/`dtype`/`shape`/`valid_shape`/`eps` 五个字段,`valid_shape` 为必填。 -如果输出 shape 与输入不同(如 `trowsum`),再额外补 `dst_shape`/`dst_valid_shape`,供 -`compare.py` 和 `gen_data.py` 使用。 - -**3. `tsub.pto`** — kernel 描述,一个文件中放多个 case 对应的函数。每个函数 -代表一种 dtype/shape 组合。以 tadd 为参考,kernel 结构为: - -```mlir -module { - // Case: f32 16x64 - func.func @TSUB_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, - %c_ptr: !pto.ptr) { - // 1. make_tensor_view: 从 !pto.ptr 构造 5D tensor_view (1×1×1×rows×cols) - // 2. partition_view: 提取 tile 区域 - // 3. alloc_tile: 分配 UB 上的 tile_buf - // 4. tload: 从 GM 加载到 UB - // 5. pto.tsub: 执行计算 - // 6. tstore: 从 UB 写回 GM - return - } - // Case: f32 32x32 - func.func @TSUB_f32_32x32(...) { ... } -} -``` - -函数命名约定:`__x`,例如 `TSUB_f32_16x64`、`TSUB_bf16_32x32`。 - -注意:`.pto` 中 `make_tensor_view` 的 shape 维度是 5D(`1×1×1×rows×cols`),strides 需要 -与 shape 一致(最内维 stride=1,逐维累乘)。函数参数顺序决定了后续所有文件的参数顺序。 - -**4. `launch.cpp`** — 为每个 kernel 声明 entry 和 launch wrapper: - -```cpp -#include - -#ifndef AICORE -#define AICORE [aicore] -#endif - -extern "C" __global__ AICORE void TSUB_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream) { - TSUB_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} -``` - -关键约束: -- `extern "C" __global__ AICORE void ...` 这一声明形态不可改变,repack 脚本用 sed 从中抽取 stub -- kernel 参数类型和顺序必须与 `.pto` 中函数签名一致 -- `<<<1, nullptr, stream>>>` 表示单核启动 - -**5. `main.cpp`** — host driver,核心是 case table 和 `RunCase()` 函数: - -```cpp -#include "acl/acl.h" -#include "test_common.h" // PtoTestCommon::ReadFile / WriteFile + ACL_CHECK - -using LaunchFn = void (*)(float *, float *, float *, void *); - -struct TestCase { - const char *name; // 对应 cases.py 中的 name 和运行时子目录 - LaunchFn launch; - size_t rows; // allocated tile rows - size_t cols; // allocated tile cols - size_t validRows; // effective computation rows (<= rows) - size_t validCols; // effective computation cols (<= cols) - size_t elemSize; -}; - -static const TestCase kCases[] = { - {"f32_16x64", LaunchTSUB_f32_16x64, 16, 64, 16, 64, sizeof(float)}, - {"f32_32x32", LaunchTSUB_f32_32x32, 32, 32, 32, 32, sizeof(float)}, -}; -``` - -注意:`ACL_CHECK` 宏由公共头 `test_common.h` 提供(需在 `acl/acl.h` 之后包含),无需在每个 testcase 中重复定义。 - -`RunCase()` 的职责: -1. 从 `.//input*.bin` 读取输入到 host 内存 -2. `aclrtMemcpy` 拷贝到 device -3. 调用 `tc.launch(...)` 启动 kernel -4. `aclrtSynchronizeStream` 等待完成 -5. 拷贝结果回 host -6. 写 `.//output.bin` - -`main()` 支持可选 `argv[1]` 作为 case filter,实现单 case 执行。 - -**6. `gen_data.py`** — 生成每个 case 的输入和 golden,从 `cases.py` 导入 `CASES`: - -```python -from cases import CASES -from st_common import validate_cases, setup_case_rng, save_case_data - -validate_cases(CASES) - -for case in CASES: - setup_case_rng(case) # per-case seed,新增 case 不影响已有数据 - dtype, shape = case["dtype"], case["shape"] - valid_shape = case["valid_shape"] - - input1 = np.random.randint(1, 10, size=shape).astype(dtype) - input2 = np.random.randint(1, 10, size=shape).astype(dtype) - golden = np.zeros(shape, dtype=dtype) - vr, vc = valid_shape - golden[:vr, :vc] = (input1[:vr, :vc] - input2[:vr, :vc]).astype(dtype, copy=False) # tsub: 减法 - - save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) -``` - -注意 golden 的计算逻辑必须与 op 语义一致(tadd 是加法,tsub 是减法),且只在 `valid_shape` 区域内计算。 - -**7. `compare.py`** — 每个 testcase 自己维护比较脚本。公共层只提供 -`st_common.result_cmp(golden, output, eps)`,具体比较哪些数据由 testcase 自己决定。 - -以 `tsub` 这种输入输出 shape 一致的 case 为例,核心逻辑通常是: - -```python -from cases import CASES -from st_common import result_cmp, style_fail, style_pass, validate_cases - -validate_cases(CASES) - -for case in CASES: - shape = case["shape"] - vr, vc = case["valid_shape"] - golden = np.fromfile(os.path.join(case["name"], "golden.bin"), dtype=case["dtype"]).reshape(shape) - output = np.fromfile(os.path.join(case["name"], "output.bin"), dtype=case["dtype"]).reshape(shape) - ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) -``` - -如果是 `trowsum` 这类输出 shape 不同的 op,则 `compare.py` 可以自己按 `dst_shape` reshape, -并只比较 `dst_valid_shape` 对应的有效区域。exit code 2 表示失败。 - -精度阈值参考: - -| dtype | 建议 eps | -|---|---| -| `float32` | `1e-6` | -| `float16` | `1e-3` | -| `bfloat16` | `1e-2` | -| `int8/int16/int32` | `0`(精确匹配) | - -**8. 注册** — 修改 `testcase/CMakeLists.txt`,将新 op 加入 `ALL_TESTCASES`: - -```cmake -set(ALL_TESTCASES - tadd - tsub # ← 新增 -) -``` - -##### 文件间一致性约束 - -新增 testcase 时最容易出错的是以下几处必须严格一致: - -| 约束 | 涉及文件 | 示例 | -|---|---|---| -| kernel 函数名 | `.pto` ↔ `launch.cpp` | `@TSUB_f32_16x64` ↔ `TSUB_f32_16x64` | -| Launch wrapper 名 | `launch.cpp` ↔ `main.cpp` | `LaunchTSUB_f32_16x64` | -| case 名 | `cases.py` ↔ `main.cpp` kCases[] ↔ 运行时目录 | `f32_16x64` | -| 参数顺序 | `.pto` → `launch.cpp` → `main.cpp` 的 launch 调用 | `(a, b) → c` | -| shape / valid_shape | `cases.py` ↔ `.pto` tile shape ↔ `main.cpp` rows/cols/validRows/validCols | `16×64` / `(16, 64)` | - -Python 侧的 case 名、dtype、shape、valid_shape、eps(以及必要时的 `dst_shape` / -`dst_valid_shape`)已通过 `cases.py` 收敛为单一来源。但 C++ 侧 `main.cpp` 的 `kCases[]` -和 `.pto` 仍需手动与 `cases.py` 保持一致。 - -任何一处不一致都可能导致:编译成功但运行时 segfault,或运行成功但比较结果错误且难以定位。 - -##### 运行方式 - -统一入口为 `test/tilelang_st/script/run_st.py`。前置条件: -- `ptoas` 已编译(默认路径 `build/tools/ptoas/ptoas`,也可通过 `-p` 指定或 `PTOAS_BIN` 环境变量) -- `ASCEND_HOME_PATH` 已设置 -- 建议先执行 `source scripts/ptoas_env.sh` - -```bash -# simulator 上跑 tsub 全部 case -python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub - -# NPU 上跑 tsub 全部 case -python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tsub - -# 只跑单个 case -python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub -c f32_16x64 - -# 复用已有 build,跳过重新编译(只重新生成数据、执行、比较) -python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub -c f32_16x64 -w -``` - -`run_st.py` 执行顺序:`set_env_variables()` → `build_project()` → `run_gen_data()` → -`run_binary()` → `run_compare()`。产物输出到 -`test/tilelang_st/npu/a5/src/st/build/testcase//` 下: - -```text -build/testcase/tsub/ -├── st_common.py # 从 testcase/ 公共目录拷贝 -├── cases.py # 从 testcase/tsub/ 拷贝 -├── gen_data.py # 从 testcase/tsub/ 拷贝 -├── compare.py # 从 testcase/tsub/ 拷贝 -├── f32_16x64/ -│ ├── input1.bin -│ ├── input2.bin -│ ├── golden.bin -│ └── output.bin -└── f32_32x32/ - └── ... -``` - -##### 建议的开发验证节奏 - -1. **最小 case 先行**:先写一个最小 case(如 `f32_16x64`),在 simulator 上跑通: - ```bash - python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub -c f32_16x64 - ``` - -2. **快速迭代**:修改 `.pto` 或 host 代码后,用 `-w` 跳过 cmake/make 重编译。 - 注意:如果改了 `.pto` 本身,仍需重新编译(不加 `-w`),`-w` 只适合改 `gen_data.py` / - `compare.py` / `main.cpp` 中非编译相关逻辑的情况。 - -3. **扩充 case**:单 case 稳定后,补充更多 shape / dtype 组合。建议覆盖: - - 不同 dtype(f32 / f16 / bf16) - - 不同 tile 形状(正方形、长条形) - - 边界情况(valid 行列不是整 tile 的场景) - -4. **全量验证**:跑全量 case 确认无回归。 - -5. **NPU 验证**:切到 `-r npu` 在真实硬件上验证。simulator 和 NPU 的行为可能存在差异。 - -##### 调试建议 - -| 阶段 | 排查方向 | -|---|---| -| `ptoas` 编译失败 | 检查 `.pto` 语法、TileLang 模板是否匹配、是否缺少 `--enable-tile-op-expand` | -| `bisheng -x ir` 失败 | 检查 `build/testcase//_kernel.ll` 中的 LLVM IR | -| repack 失败 | 检查 `launch.cpp` 中的 kernel 声明是否符合 `extern "C" __global__ AICORE void` 格式 | -| 链接失败 | 检查共享库符号名一致性、ACL 运行时依赖 | -| kernel 执行失败 | 确认 `build/testcase///input*.bin` 是否已生成 | -| compare fail | 先检查 `output.bin` vs `golden.bin` 差异,再检查 `.pto` 语义和参数顺序 | - -##### 已有 testcase 下新增 case - -如果只是在已有 testcase(如 `tadd`)下新增一个 case(如 `f32_8x128`),只需同步修改 4 个文件: - -| 文件 | 修改内容 | -|---|---| -| `cases.py` | 在 `CASES` 中加入 `{"name": "f32_8x128", "dtype": np.float32, "shape": (8, 128), "valid_shape": (8, 128), "eps": 1e-6}` | -| `tadd.pto` | 新增 `func.func @TADD_f32_8x128(...)` 函数体 | -| `launch.cpp` | 新增 `extern "C"` kernel 声明和 `LaunchTADD_f32_8x128` wrapper | -| `main.cpp` | 在 `kCases[]` 中加入 `{"f32_8x128", LaunchTADD_f32_8x128, 8, 128, 8, 128, sizeof(float)}` | - -`gen_data.py` 和 `compare.py` 无需修改,自动从 `cases.py` 读取。 From b424d215313f4dacb0b6f2db25107d82572c12d6 Mon Sep 17 00:00:00 2001 From: qukelin Date: Tue, 7 Apr 2026 10:42:28 +0800 Subject: [PATCH 025/192] =?UTF-8?q?feat(transforms):=20add=20Tile=E2=86=92?= =?UTF-8?q?Vector=20template=20lowering=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a three-pass pipeline that lowers PTO tile ops to vector-level implementations via TileLang DSL templates: - ExpandTileOp: invokes TileLang Python DSL to instantiate template functions and replaces tile ops with func.call. SpecKey covers all operands; tile_buf operands are passed through without bridging. - PTOInlineLibCall: extended to recognize tilelang instance functions via the attribute set by the DSL frontend. - FoldTileBufIntrinsics: resolves pto.tile_buf_addr / tile_valid_rows / tile_valid_cols, including dynamic valid-shape via pto.bind_tile chain tracing. - MemrefToTileBuf: recovers tile_buf types from memref + bind_tile metadata after PlanMemory/InsertSync. - PTOViewToMemref: insert pto.bind_tile anchors for tile_buf function args so MemrefToTileBuf can recover them. Adds new PTO ops (tile_buf_addr/tile_valid_rows/tile_valid_cols), ptoas pipeline wiring, design docs, and unit tests. --- include/PTO/IR/PTOOps.td | 58 +++ include/PTO/Transforms/Passes.h | 4 + include/PTO/Transforms/Passes.td | 79 ++++ lib/PTO/Transforms/CMakeLists.txt | 5 + lib/PTO/Transforms/ExpandTileOp.cpp | 442 ++++++++++++++++++ lib/PTO/Transforms/FoldTileBufIntrinsics.cpp | 207 ++++++++ lib/PTO/Transforms/MemrefToTileBuf.cpp | 245 ++++++++++ .../PTOInstantiateAndInlineOpLib.cpp | 246 ++++++++++ lib/PTO/Transforms/PTOViewToMemref.cpp | 37 ++ test/basic/expand_tile_op_tilelang.pto | 49 ++ test/basic/fold_tile_buf_intrinsics.pto | 90 ++++ test/tilelang_templates/tadd_template.py | 32 ++ tilelang-dsl/examples/tadd_demo.py | 72 +++ .../python/tilelang_dsl/expand_helper.py | 154 ++++++ tools/ptoas/ptoas.cpp | 49 ++ 15 files changed, 1769 insertions(+) create mode 100644 lib/PTO/Transforms/ExpandTileOp.cpp create mode 100644 lib/PTO/Transforms/FoldTileBufIntrinsics.cpp create mode 100644 lib/PTO/Transforms/MemrefToTileBuf.cpp create mode 100644 lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp create mode 100644 test/basic/expand_tile_op_tilelang.pto create mode 100644 test/basic/fold_tile_buf_intrinsics.pto create mode 100644 test/tilelang_templates/tadd_template.py create mode 100644 tilelang-dsl/examples/tadd_demo.py create mode 100644 tilelang-dsl/python/tilelang_dsl/expand_helper.py diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 0fc4dbabd..5e53b3363 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -2165,6 +2165,64 @@ def TSyncOp : PTO_TOp<"tsync"> { }]; } +// --------------------------------------------------------------------------- +// TileBuf intrinsics — used in TileLang DSL-generated template functions. +// These ops extract memref address and valid shape from tile_buf parameters. +// After inline, FoldTileBufIntrinsics resolves them to concrete values. +// --------------------------------------------------------------------------- + +def TileBufAddrOp : PTO_Op<"tile_buf_addr", [Pure]> { + let summary = "Extract memref address from a tile_buf."; + let description = [{ + Returns a memref view of the data region of a `tile_buf` value. + The result memref has the same element type, shape (from tile_buf's static + shape), and address space as the source tile_buf, with row-major strides. + + This op is emitted by TileLang DSL templates and resolved by the + FoldTileBufIntrinsics pass after inlining. + }]; + + let arguments = (ins TileBufType:$src); + let results = (outs AnyMemRef:$dst); + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($dst) + }]; +} + +def TileValidRowsOp : PTO_Op<"tile_valid_rows", [Pure]> { + let summary = "Extract valid row count from a tile_buf."; + let description = [{ + Returns the valid row count (v_row) of a `tile_buf` as an `index` value. + When the tile_buf has a static v_row, FoldTileBufIntrinsics folds this + into `arith.constant`. When v_row is dynamic (`?`), the fold resolves + it to the runtime index value carried by the tile_buf. + }]; + + let arguments = (ins TileBufType:$src); + let results = (outs Index:$result); + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($result) + }]; +} + +def TileValidColsOp : PTO_Op<"tile_valid_cols", [Pure]> { + let summary = "Extract valid column count from a tile_buf."; + let description = [{ + Returns the valid column count (v_col) of a `tile_buf` as an `index` value. + When the tile_buf has a static v_col, FoldTileBufIntrinsics folds this + into `arith.constant`. When v_col is dynamic (`?`), the fold resolves + it to the runtime index value carried by the tile_buf. + }]; + + let arguments = (ins TileBufType:$src); + let results = (outs Index:$result); + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($result) + }]; +} //===----------------------------------------------------------------------===// // FFT Configuration Operation diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 4193bc7d5..fadd4e143 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -70,6 +70,10 @@ std::unique_ptr createPTOValidateVPTOIRPass(); std::unique_ptr createPTOValidateVPTOEmissionIRPass(); std::unique_ptr createLowerPTOToVPTOPass(); std::unique_ptr createLowerPTOToVPTOPass(StringRef loweringStrategy); +std::unique_ptr createMemrefToTileBufPass(); +std::unique_ptr createExpandTileOpPass(); +std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); +std::unique_ptr createFoldTileBufIntrinsicsPass(); //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 5094c356a..0ed6e8299 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -138,6 +138,26 @@ def PTOWrapFunctionsInSections : Pass<"pto-wrap-functions-in-sections", "func::F ]; } +def MemrefToTileBuf : Pass<"pto-memref-to-tile-buf", "ModuleOp"> { + let summary = "Recover tile_buf types from memref + pto.bind_tile metadata"; + let description = [{ + After PTOViewToMemref + PlanMemory + InsertSync, the IR uses memref types + with pto.bind_tile ops carrying tile metadata (TileBufConfigAttr, valid + dims). This pass reconstructs tile_buf types from those anchors so that + the subsequent Tile->Vector lowering (Expand TileOp) can operate on + tile_buf semantics. + + The pass does NOT redo memory planning or synchronisation. + }]; + let constructor = "mlir::pto::createMemrefToTileBufPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect", + "mlir::func::FuncDialect" + ]; +} + def PTOLowerFrontendPipeOps : Pass<"pto-lower-frontend-pipe-ops", "func::FuncOp"> { let summary = "Lower frontend TPUSH/TPOP pipe ops to internal pipe ops"; let description = [{ @@ -218,6 +238,65 @@ def PTOResolveReservedBuffers : Pass<"pto-resolve-reserved-buffers", "ModuleOp"> ]; } +def ExpandTileOp : Pass<"pto-expand-tile-op", "ModuleOp"> { + let summary = "Expand tile ops into calls to TileLang DSL template functions"; + let description = [{ + Expands tile-level operations (pto.tadd, pto.tsub, etc.) by invoking the + TileLang Python DSL to instantiate template libraries. The generated + template functions use tile_buf parameters and contain vector-level + implementations (pto.vecscope, pto.vlds, pto.vadd, pto.vsts, etc.). + + Each tile op is replaced by a func.call to the generated template function, + with tile_buf operands passed directly (no type bridging). + + After this pass, the Inline pass inlines template bodies, and + FoldTileBufIntrinsics resolves tile_buf_addr / tile_valid_rows / + tile_valid_cols. + }]; + let constructor = "mlir::pto::createExpandTileOpPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect", + "mlir::func::FuncDialect", + "mlir::scf::SCFDialect", + "mlir::vector::VectorDialect" + ]; + let options = [ + Option<"tilelangPath", "tilelang-path", "std::string", + /*default=*/"\"\"", + "Path to directory of .py tilelang DSL template files">, + Option<"tilelangPkgPath", "tilelang-pkg-path", "std::string", + /*default=*/"\"\"", + "PYTHONPATH for tilelang_dsl package (added to env)">, + Option<"pythonExe", "python-exe", "std::string", + /*default=*/"\"python3\"", + "Python executable for tilelang DSL invocation"> + ]; +} + +def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::FuncOp"> { + let summary = "Fold tile_buf_addr / tile_valid_rows / tile_valid_cols"; + let description = [{ + After TileLang DSL template functions are inlined, the IR contains + pto.tile_buf_addr, pto.tile_valid_rows, and pto.tile_valid_cols ops + whose tile_buf operands are now bound to concrete values. + + This pass resolves them: + - pto.tile_buf_addr → replaced by pto.simd.tile_to_memref (extracts + memref address from tile_buf) + - pto.tile_valid_rows → folded to arith.constant if v_row is static, + or replaced with the dynamic index value from tile_buf + - pto.tile_valid_cols → same as above for v_col + }]; + let constructor = "mlir::pto::createFoldTileBufIntrinsicsPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect" + ]; +} + def PTOVerifyTFree : Pass<"pto-verify-tfree", "func::FuncOp"> { let summary = "Verify explicit matching pto.tfree placement for pto.tpop"; let description = [{ diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 732db768e..79d8ad2f8 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -24,6 +24,9 @@ add_mlir_dialect_library(PTOTransforms InsertSync/PTOInsertSync.cpp InsertSync/InsertSyncDebug.cpp PTOViewToMemref.cpp + MemrefToTileBuf.cpp + ExpandTileOp.cpp + FoldTileBufIntrinsics.cpp PTOToEmitC.cpp Utils.cpp OptMemPlanForPipeline.cpp @@ -75,6 +78,8 @@ add_mlir_dialect_library(PTOTransforms MLIRTransforms MLIRTensorDialect MLIRSCFDialect + MLIRVectorDialect + MLIRParser MLIRSCFToEmitC MLIRSCFToControlFlow MLIRConvertToLLVMPass diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp new file mode 100644 index 000000000..945d2cc2e --- /dev/null +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -0,0 +1,442 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- ExpandTileOp.cpp ---------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +// Expand tile-level ops (pto.tadd, pto.tsub, ...) by invoking the TileLang +// Python DSL to instantiate template libraries. +// +// The generated template functions use tile_buf parameters. After this pass, +// the Inline pass inlines the template body, and FoldTileBufIntrinsics +// resolves tile_buf_addr / tile_valid_rows / tile_valid_cols. +// +// Workflow per tile op: +// 1. Extract SpecKey from ALL operands' tile_buf types. +// 2. Invoke Python DSL helper to generate a specialized MLIR function +// (with tile_buf parameters). +// 3. Parse the generated MLIR and clone the function into the module. +// 4. Replace the original tile op with func.call, passing tile_buf +// operands directly (no type bridging needed). +// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Parser/Parser.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +extern "C" { +extern char **environ; +} + +using namespace mlir; + +namespace mlir { +namespace pto { + namespace func = ::mlir::func; + + #define GEN_PASS_DEF_EXPANDTILEOP + #include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +namespace { + +// ============================================================================ +// OperandTypeInfo: captures the tile_buf type info for one operand. +// ============================================================================ +struct OperandTypeInfo { + std::string dtype; + SmallVector shape; + int32_t blayout = 0; + int32_t slayout = 0; + int32_t fractal = 0; + int32_t pad = 0; + + bool operator==(const OperandTypeInfo &rhs) const { + return dtype == rhs.dtype && shape == rhs.shape && + blayout == rhs.blayout && slayout == rhs.slayout && + fractal == rhs.fractal && pad == rhs.pad; + } +}; + +// ============================================================================ +// SpecKey: identifies a specialized template instance using ALL operands. +// ============================================================================ +struct SpecKey { + std::string opName; + SmallVector operands; + + bool operator==(const SpecKey &rhs) const { + return opName == rhs.opName && operands == rhs.operands; + } +}; + +struct SpecKeyInfo : public llvm::DenseMapInfo { + static inline SpecKey getEmptyKey() { return {"", {}}; } + static inline SpecKey getTombstoneKey() { return {"__tombstone__", {}}; } + static unsigned getHashValue(const SpecKey &key) { + unsigned h = llvm::hash_value(key.opName); + for (const auto &op : key.operands) { + h = llvm::hash_combine(h, op.dtype, op.blayout, op.slayout, + op.fractal, op.pad); + for (int64_t d : op.shape) + h = llvm::hash_combine(h, d); + } + return h; + } + static bool isEqual(const SpecKey &lhs, const SpecKey &rhs) { + return lhs == rhs; + } +}; + +// ============================================================================ +// Helpers +// ============================================================================ +static std::string getDtypeString(Type elemTy) { + if (elemTy.isF32()) return "f32"; + if (elemTy.isF16()) return "f16"; + if (elemTy.isBF16()) return "bf16"; + if (elemTy.isSignlessInteger(32)) return "i32"; + if (elemTy.isSignlessInteger(16)) return "i16"; + if (elemTy.isSignlessInteger(8)) return "i8"; + return ""; +} + +static StringRef getTileOpName(Operation *op) { + return op->getName().stripDialect(); +} + +static std::string getMemorySpaceString(pto::TileBufType tbTy) { + auto msAttr = dyn_cast_or_null(tbTy.getMemorySpace()); + if (!msAttr) return "ub"; + if (msAttr.getAddressSpace() == pto::AddressSpace::GM) return "gm"; + return "ub"; +} + +static std::optional +buildOperandTypeInfo(pto::TileBufType tbTy) { + OperandTypeInfo info; + info.dtype = getDtypeString(tbTy.getElementType()); + if (info.dtype.empty()) + return std::nullopt; + info.shape.assign(tbTy.getShape().begin(), tbTy.getShape().end()); + if (auto config = tbTy.getConfigAttr()) { + info.blayout = static_cast(config.getBLayout().getValue()); + info.slayout = static_cast(config.getSLayout().getValue()); + info.fractal = config.getSFractalSize() + ? static_cast(config.getSFractalSize().getInt()) + : 0; + info.pad = static_cast(config.getPad().getValue()); + } + return info; +} + +static std::optional buildSpecKey(Operation *op) { + SpecKey key; + key.opName = getTileOpName(op).str(); + + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + auto tbTy = dyn_cast(op->getOperand(i).getType()); + if (!tbTy) + return std::nullopt; + auto info = buildOperandTypeInfo(tbTy); + if (!info) + return std::nullopt; + key.operands.push_back(*info); + } + if (key.operands.empty()) + return std::nullopt; + + return key; +} + +// ============================================================================ +// ExpandState: runtime state for a single pass invocation. +// ============================================================================ +struct ExpandState { + std::vector> parsedModules; + llvm::DenseMap specCache; + + std::string tilelangPath; + std::string tilelangPkgPath; + std::string pythonExe; + + func::FuncOp invokeTilelangDSL(const SpecKey &key, Operation *tileOp, + ModuleOp mod, MLIRContext *ctx); + + LogicalResult expandTileOpsInFunction(func::FuncOp func, ModuleOp mod, + MLIRContext *ctx); +}; + +// ============================================================================ +// The Pass +// ============================================================================ +struct ExpandTileOpPass + : public mlir::pto::impl::ExpandTileOpBase { + using ExpandTileOpBase::ExpandTileOpBase; + + void runOnOperation() override; +}; + +// ============================================================================ +// Invoke Python DSL helper to generate a specialized template function. +// ============================================================================ +func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, + Operation *tileOp, + ModuleOp mod, MLIRContext *ctx) { + // Check cache first. + auto cacheIt = specCache.find(key); + if (cacheIt != specCache.end()) + return cacheIt->second; + + // 1. Locate the Python executable. + auto pythonPath = llvm::sys::findProgramByName(pythonExe); + if (!pythonPath) { + llvm::errs() << "ExpandTileOp: cannot find '" << pythonExe << "'\n"; + return nullptr; + } + + // 2. Build shape string from the first operand (e.g. "16,64"). + // TODO: extend expand_helper to accept per-operand shapes if needed. + const auto &firstOp = key.operands[0]; + std::string shapeStr; + for (unsigned i = 0; i < firstOp.shape.size(); ++i) { + if (i > 0) shapeStr += ","; + shapeStr += std::to_string(firstOp.shape[i]); + } + + // Get memory space from the first tile_buf operand. + auto firstTbTy = dyn_cast(tileOp->getOperand(0).getType()); + std::string memSpace = firstTbTy ? getMemorySpaceString(firstTbTy) : "ub"; + + // 3. Create temp file for stdout redirect. + SmallString<128> tmpPath; + int tmpFD; + if (auto ec = llvm::sys::fs::createTemporaryFile("tilelang_expand", "mlir", + tmpFD, tmpPath)) { + llvm::errs() << "ExpandTileOp: cannot create temp file: " + << ec.message() << "\n"; + return nullptr; + } + ::close(tmpFD); + + // 4. Build command args. + std::string opName = "pto." + key.opName; + SmallVector args = { + *pythonPath, "-m", "tilelang_dsl.expand_helper", + "--template-dir", tilelangPath, + "--op", opName, + "--dtype", firstOp.dtype, + "--shape", shapeStr, + "--memory-space", memSpace, + }; + + // 5. Set up environment with PYTHONPATH. + std::optional redirects[] = {std::nullopt, StringRef(tmpPath), + std::nullopt}; + + SmallVector envp; + std::string pythonPathEnv; + std::vector envStorage; + bool hasPythonPath = !tilelangPkgPath.empty(); + if (hasPythonPath) { + const char *existingPath = ::getenv("PYTHONPATH"); + pythonPathEnv = "PYTHONPATH=" + tilelangPkgPath; + if (existingPath && existingPath[0] != '\0') { + pythonPathEnv += ":"; + pythonPathEnv += existingPath; + } + for (char **e = environ; *e; ++e) { + StringRef entry(*e); + if (entry.starts_with("PYTHONPATH=")) + continue; + envStorage.push_back(std::string(entry)); + } + envStorage.push_back(pythonPathEnv); + for (auto &s : envStorage) + envp.push_back(s); + } + + // 6. Execute. + std::string errMsg; + int rc = llvm::sys::ExecuteAndWait( + *pythonPath, args, + hasPythonPath ? std::optional>(envp) : std::nullopt, + redirects, /*secondsToWait=*/30, /*memoryLimit=*/0, &errMsg); + + if (rc != 0) { + llvm::errs() << "ExpandTileOp: tilelang DSL helper failed (rc=" << rc + << "): " << errMsg << "\n"; + llvm::sys::fs::remove(tmpPath); + return nullptr; + } + + // 7. Read the generated MLIR. + auto bufOrErr = llvm::MemoryBuffer::getFile(tmpPath); + llvm::sys::fs::remove(tmpPath); + if (!bufOrErr) { + llvm::errs() << "ExpandTileOp: cannot read DSL output\n"; + return nullptr; + } + StringRef mlirText = (*bufOrErr)->getBuffer(); + if (mlirText.empty()) { + llvm::errs() << "ExpandTileOp: empty DSL output\n"; + return nullptr; + } + + // 8. Parse the MLIR text. + auto parsedMod = parseSourceString(mlirText, ctx); + if (!parsedMod) { + llvm::errs() << "ExpandTileOp: failed to parse DSL output\n"; + return nullptr; + } + + // 9. Find func.func in the parsed module and clone into target module. + func::FuncOp srcFn; + for (auto fn : parsedMod->getOps()) { + srcFn = fn; + break; + } + if (!srcFn) { + llvm::errs() << "ExpandTileOp: no func.func in DSL output\n"; + return nullptr; + } + + OpBuilder builder(ctx); + builder.setInsertionPointToEnd(mod.getBody()); + IRMapping mapping; + auto cloned = cast(builder.clone(*srcFn, mapping)); + + // Build a unique name from all operand types. + std::string uniqueName = "__pto_tilelang_" + key.opName; + for (const auto &op : key.operands) { + uniqueName += "_" + op.dtype; + for (int64_t d : op.shape) + uniqueName += "_" + std::to_string(d); + } + cloned.setName(uniqueName); + cloned.setVisibility(SymbolTable::Visibility::Private); + // The pto.tilelang.instance attribute should already be set by the + // TileLang DSL frontend in the generated MLIR. Verify it exists. + if (!cloned->hasAttr("pto.tilelang.instance")) { + llvm::errs() << "ExpandTileOp: warning: DSL output function @" + << cloned.getSymName() + << " missing pto.tilelang.instance attribute\n"; + } + + // Keep the parsed module alive. + parsedModules.push_back(std::move(parsedMod)); + + specCache[key] = cloned; + return cloned; +} + +// ============================================================================ +// Expand tile ops in a single function. +// ============================================================================ +LogicalResult ExpandState::expandTileOpsInFunction(func::FuncOp func, + ModuleOp mod, + MLIRContext *ctx) { + OpBuilder builder(ctx); + + // Collect tile ops first (avoid modifying while iterating). + SmallVector tileOps; + func.walk([&](Operation *op) { + if (isa(op)) + tileOps.push_back(op); + }); + + for (auto *op : tileOps) { + auto specKeyOpt = buildSpecKey(op); + if (!specKeyOpt) { + op->emitWarning("ExpandTileOp: cannot build specialization key, skipping"); + continue; + } + + // Invoke tilelang DSL (with caching). + func::FuncOp dslFn = invokeTilelangDSL(*specKeyOpt, op, mod, ctx); + if (!dslFn) { + StringRef opName = getTileOpName(op); + op->emitWarning("ExpandTileOp: no tilelang template for " + opName + + ", skipping"); + continue; + } + + // Replace tile op with func.call, passing tile_buf operands directly. + builder.setInsertionPoint(op); + SmallVector operands(op->getOperands()); + builder.create(op->getLoc(), dslFn, operands); + op->erase(); + } + + return success(); +} + +// ============================================================================ +// Main entry point. +// ============================================================================ +void ExpandTileOpPass::runOnOperation() { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + if (tilelangPath.empty()) { + return; + } + + ExpandState state; + state.tilelangPath = std::string(tilelangPath); + state.tilelangPkgPath = std::string(tilelangPkgPath); + state.pythonExe = std::string(pythonExe); + + for (auto func : mod.getOps()) { + if (func.isExternal()) + continue; + if (failed(state.expandTileOpsInFunction(func, mod, ctx))) + return signalPassFailure(); + } +} + +} // namespace + +namespace mlir { +namespace pto { + +std::unique_ptr createExpandTileOpPass() { + return std::make_unique(); +} + +std::unique_ptr +createExpandTileOpPass(const ExpandTileOpOptions &options) { + return std::make_unique(options); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp new file mode 100644 index 000000000..dd5764d66 --- /dev/null +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- FoldTileBufIntrinsics.cpp ------------------------------------------===// +// +// After TileLang DSL template functions are inlined, the IR contains: +// - pto.tile_buf_addr → extract memref address from tile_buf +// - pto.tile_valid_rows → extract valid row count +// - pto.tile_valid_cols → extract valid column count +// +// This pass resolves them against the concrete tile_buf values at the +// call site. +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace mlir { +namespace pto { + #define GEN_PASS_DEF_FOLDTILEBUFINTRINSICS + #include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +namespace { + +/// Compute the row-major strided memref type for a tile_buf. +static MemRefType computeBridgeMemrefType(pto::TileBufType tbTy, + MLIRContext *ctx) { + ArrayRef shape = tbTy.getShape(); + ArrayRef validShape = tbTy.getValidShape(); + + SmallVector memrefDims; + for (unsigned d = 0; d < shape.size(); ++d) { + if (d < validShape.size() && validShape[d] != ShapedType::kDynamic) + memrefDims.push_back(validShape[d]); + else + memrefDims.push_back(ShapedType::kDynamic); + } + + SmallVector strides(shape.size(), 1); + for (int s = static_cast(shape.size()) - 2; s >= 0; --s) + strides[s] = strides[s + 1] * shape[s + 1]; + + auto stridedLayout = StridedLayoutAttr::get(ctx, /*offset=*/0, strides); + return MemRefType::get(memrefDims, tbTy.getElementType(), stridedLayout, + tbTy.getMemorySpace()); +} + +/// Try to find the dynamic valid_row index from the tile_buf's defining op +/// chain (e.g. pto.bind_tile carries optional valid_row/valid_col operands). +static Value findDynamicValidRow(Value tileBuf) { + Value cur = tileBuf; + while (cur) { + if (auto bindOp = cur.getDefiningOp()) { + if (bindOp.getValidRow()) + return bindOp.getValidRow(); + // bind_tile may chain — trace further through its source. + cur = bindOp.getSource(); + continue; + } + break; + } + return nullptr; +} + +/// Try to find the dynamic valid_col index from the tile_buf's defining op. +static Value findDynamicValidCol(Value tileBuf) { + Value cur = tileBuf; + while (cur) { + if (auto bindOp = cur.getDefiningOp()) { + if (bindOp.getValidCol()) + return bindOp.getValidCol(); + cur = bindOp.getSource(); + continue; + } + break; + } + return nullptr; +} + +struct FoldTileBufIntrinsicsPass + : public pto::impl::FoldTileBufIntrinsicsBase { + using FoldTileBufIntrinsicsBase::FoldTileBufIntrinsicsBase; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + MLIRContext *ctx = &getContext(); + OpBuilder builder(ctx); + + SmallVector addrOps; + SmallVector rowsOps; + SmallVector colsOps; + + func.walk([&](Operation *op) { + if (auto addr = dyn_cast(op)) + addrOps.push_back(addr); + else if (auto rows = dyn_cast(op)) + rowsOps.push_back(rows); + else if (auto cols = dyn_cast(op)) + colsOps.push_back(cols); + }); + + // Fold pto.tile_buf_addr → pto.simd.tile_to_memref. + for (auto addrOp : addrOps) { + builder.setInsertionPoint(addrOp); + auto tbTy = dyn_cast(addrOp.getSrc().getType()); + if (!tbTy) { + addrOp.emitError("tile_buf_addr source is not tile_buf"); + return signalPassFailure(); + } + + MemRefType bridgeMemref = computeBridgeMemrefType(tbTy, ctx); + auto bridge = builder.create( + addrOp.getLoc(), bridgeMemref, addrOp.getSrc()); + + Value result = bridge.getDst(); + if (result.getType() != addrOp.getDst().getType()) { + result = builder.create( + addrOp.getLoc(), addrOp.getDst().getType(), result); + } + + addrOp.getDst().replaceAllUsesWith(result); + addrOp.erase(); + } + + // Fold pto.tile_valid_rows → arith.constant or dynamic index. + for (auto rowsOp : rowsOps) { + builder.setInsertionPoint(rowsOp); + auto tbTy = dyn_cast(rowsOp.getSrc().getType()); + if (!tbTy || tbTy.getValidShape().empty()) { + rowsOp.emitError("tile_valid_rows: invalid tile_buf type"); + return signalPassFailure(); + } + + int64_t vRow = tbTy.getValidShape()[0]; + Value replacement; + if (vRow != ShapedType::kDynamic) { + replacement = + builder.create(rowsOp.getLoc(), vRow); + } else { + replacement = findDynamicValidRow(rowsOp.getSrc()); + if (!replacement) { + rowsOp.emitError( + "tile_valid_rows: dynamic v_row but cannot find runtime value " + "(expected pto.bind_tile with valid_row operand)"); + return signalPassFailure(); + } + } + rowsOp.getResult().replaceAllUsesWith(replacement); + rowsOp.erase(); + } + + // Fold pto.tile_valid_cols → arith.constant or dynamic index. + for (auto colsOp : colsOps) { + builder.setInsertionPoint(colsOp); + auto tbTy = dyn_cast(colsOp.getSrc().getType()); + if (!tbTy || tbTy.getValidShape().size() < 2) { + colsOp.emitError("tile_valid_cols: invalid tile_buf type"); + return signalPassFailure(); + } + + int64_t vCol = tbTy.getValidShape()[1]; + Value replacement; + if (vCol != ShapedType::kDynamic) { + replacement = + builder.create(colsOp.getLoc(), vCol); + } else { + replacement = findDynamicValidCol(colsOp.getSrc()); + if (!replacement) { + colsOp.emitError( + "tile_valid_cols: dynamic v_col but cannot find runtime value " + "(expected pto.bind_tile with valid_col operand)"); + return signalPassFailure(); + } + } + colsOp.getResult().replaceAllUsesWith(replacement); + colsOp.erase(); + } + } +}; + +} // namespace + +namespace mlir { +namespace pto { + +std::unique_ptr createFoldTileBufIntrinsicsPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/MemrefToTileBuf.cpp b/lib/PTO/Transforms/MemrefToTileBuf.cpp new file mode 100644 index 000000000..40e86006b --- /dev/null +++ b/lib/PTO/Transforms/MemrefToTileBuf.cpp @@ -0,0 +1,245 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- MemrefToTileBuf.cpp ------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +// After PTOViewToMemref + PlanMemory + InsertSync, the IR uses memref types +// with pto.bind_tile ops carrying tile metadata. This pass recovers tile_buf +// types from those anchors so that the subsequent Tile→Vector lowering +// (Expand TileOp) can operate on tile_buf semantics. +// +// The pass does NOT redo memory planning or synchronisation; it only re-wraps +// planned memref values into tile_buf through unrealized_conversion_cast. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace mlir { +namespace pto { + namespace func = ::mlir::func; + + #define GEN_PASS_DEF_MEMREFTOTILEBUF + #include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +namespace { + +// ============================================================================ +// Helper: reconstruct TileBufType from a BindTileOp +// ============================================================================ +// BindTileOp carries: +// - source memref → shape, elementType, memorySpace +// - config attr → bLayout, sLayout, fractal, pad +// - valid_row/col → validShape (static or dynamic) +static pto::TileBufType reconstructTileBufType(pto::BindTileOp bindOp) { + auto memrefTy = cast(bindOp.getSource().getType()); + MLIRContext *ctx = bindOp.getContext(); + + ArrayRef shape = memrefTy.getShape(); + Type elemTy = memrefTy.getElementType(); + Attribute memSpace = memrefTy.getMemorySpace(); + pto::TileBufConfigAttr config = bindOp.getConfig(); + + // Recover valid shape: if BindTileOp provides valid_row/valid_col, check + // whether they are static constants. Otherwise mark as dynamic. + SmallVector validShape; + if (shape.size() == 2) { + auto resolveValidDim = [](Value v, int64_t staticDim) -> int64_t { + if (!v) + return staticDim; // no dynamic override → use static shape + if (auto cOp = v.getDefiningOp()) + return cOp.value(); + if (auto cInt = v.getDefiningOp()) + return cInt.value(); + return ShapedType::kDynamic; + }; + validShape.push_back(resolveValidDim(bindOp.getValidRow(), shape[0])); + validShape.push_back(resolveValidDim(bindOp.getValidCol(), shape[1])); + } else { + // Fallback: validShape = shape + validShape.assign(shape.begin(), shape.end()); + } + + return pto::TileBufType::get(ctx, shape, elemTy, memSpace, validShape, + config); +} + +// ============================================================================ +// Helper: check whether an op is a tile-level op (needs tile_buf operands) +// ============================================================================ +static bool isTileOp(Operation *op) { + return isa(op); +} + +// ============================================================================ +// The Pass +// ============================================================================ +struct MemrefToTileBufPass + : public mlir::pto::impl::MemrefToTileBufBase { + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + for (auto func : mod.getOps()) { + if (func.isExternal()) + continue; + if (failed(processFunction(func, ctx))) + return signalPassFailure(); + } + } + +private: + LogicalResult processFunction(func::FuncOp func, MLIRContext *ctx); +}; + +LogicalResult MemrefToTileBufPass::processFunction(func::FuncOp func, + MLIRContext *ctx) { + OpBuilder builder(ctx); + + // Phase 1: For each BindTileOp, reconstruct tile_buf type and create a + // cast from the BindTileOp result (memref) to tile_buf. + // + // We build a mapping: BindTileOp result (memref Value) → tile_buf Value + // so that Phase 2 can replace tile op operands. + llvm::DenseMap memrefToTileBuf; + + SmallVector bindOps; + func.walk([&](pto::BindTileOp op) { bindOps.push_back(op); }); + + for (auto bindOp : bindOps) { + pto::TileBufType tileBufTy = reconstructTileBufType(bindOp); + + // Insert unrealized_conversion_cast right after the BindTileOp. + builder.setInsertionPointAfter(bindOp); + auto cast = builder.create( + bindOp.getLoc(), tileBufTy, bindOp.getResult()); + + memrefToTileBuf[bindOp.getResult()] = cast.getResult(0); + } + + // Phase 2: For each tile op, replace memref operands that have a + // corresponding tile_buf value. + func.walk([&](Operation *op) { + if (!isTileOp(op)) + return; + for (OpOperand &operand : op->getOpOperands()) { + auto it = memrefToTileBuf.find(operand.get()); + if (it != memrefToTileBuf.end()) { + operand.set(it->second); + } + } + }); + + // Phase 3: For function arguments that feed directly into BindTileOp, + // convert the argument type to tile_buf and propagate. + // + // Pattern: func @f(%arg: memref<...>) { %b = bind_tile %arg ... } + // → func @f(%arg: tile_buf<...>) { ... } + // + // We track which args were converted so we can update the function type. + Block &entry = func.front(); + auto fnTy = func.getFunctionType(); + SmallVector newInputTypes(fnTy.getInputs().begin(), + fnTy.getInputs().end()); + bool sigChanged = false; + + for (auto bindOp : bindOps) { + Value source = bindOp.getSource(); + auto blockArg = dyn_cast(source); + if (!blockArg || blockArg.getOwner() != &entry) + continue; + + // This argument was originally tile_buf before PTOViewToMemref. + unsigned idx = blockArg.getArgNumber(); + pto::TileBufType tileBufTy = reconstructTileBufType(bindOp); + + // Replace all tile op uses of the cast with the block arg directly. + auto castIt = memrefToTileBuf.find(bindOp.getResult()); + if (castIt == memrefToTileBuf.end()) + continue; + Value tileBufVal = castIt->second; + + // Save the original memref type before mutating. + Type origMemrefTy = blockArg.getType(); + + // Change the block argument type to tile_buf. + blockArg.setType(tileBufTy); + newInputTypes[idx] = tileBufTy; + sigChanged = true; + + // Insert a tile_buf → memref cast so that existing memref users of the + // block arg continue to work (PlanMemory / InsertSync results). + builder.setInsertionPointToStart(&entry); + auto backCast = builder.create( + func.getLoc(), origMemrefTy, blockArg); + + // Replace all non-tile-op uses of the original block arg with the + // back-cast memref. (Tile ops will use the tile_buf directly.) + // + // We must be careful: the BindTileOp itself uses the block arg as source. + // After this rewrite, BindTileOp.source should use the back-cast. + blockArg.replaceAllUsesWith(backCast.getResult(0)); + // But the back-cast's own operand must remain the block arg. + backCast.getInputsMutable().assign(ValueRange{blockArg}); + + // Now replace tile op uses: they should use the tile_buf block arg + // directly instead of going through the unrealized_conversion_cast + // chain. + tileBufVal.replaceAllUsesWith(blockArg); + // Erase the now-dead forward cast (memref → tile_buf). + if (auto castOp = + tileBufVal.getDefiningOp()) { + if (castOp->use_empty()) + castOp->erase(); + } + } + + // Update function signature if any arguments changed. + if (sigChanged) { + func.setFunctionType( + FunctionType::get(ctx, newInputTypes, fnTy.getResults())); + } + + // Phase 4: Clean up BindTileOps whose results are only used by the + // (now-erased) forward casts. If a BindTileOp still has memref users + // (e.g. memref.subview), keep it. + for (auto bindOp : bindOps) { + if (bindOp->use_empty()) + bindOp->erase(); + } + + return success(); +} + +} // namespace + +namespace mlir { +namespace pto { + +std::unique_ptr createMemrefToTileBufPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp new file mode 100644 index 000000000..b08d5f23b --- /dev/null +++ b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp @@ -0,0 +1,246 @@ +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOINLINELIBCALL +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static constexpr llvm::StringLiteral kOpLibAttrInstVariantId = + "pto.oplib.instance.variant_id"; +static constexpr llvm::StringLiteral kOpLibAttrInstOp = "pto.oplib.instance.op"; +static constexpr llvm::StringLiteral kOpLibAttrInstDType = + "pto.oplib.instance.dtype"; +static constexpr llvm::StringLiteral kErrInstanceBodyMissing = + "E_OPLIB_INSTANCE_BODY_MISSING"; + +static bool isInstanceFunc(func::FuncOp fn) { + return fn->hasAttr(kOpLibAttrInstVariantId); +} + +static bool isTilelangFunc(func::FuncOp fn) { + return fn->hasAttr("pto.tilelang.instance"); +} + +static bool isInlineableLibFunc(func::FuncOp fn) { + return isInstanceFunc(fn) || isTilelangFunc(fn); +} + +static Value maybeUnwrapCastToExpected(Value operand, Type expectedType) { + if (operand.getType() == expectedType) + return operand; + + auto cast = operand.getDefiningOp(); + if (!cast || cast->getNumOperands() != 1 || cast->getNumResults() != 1) + return operand; + + if (cast.getOperand(0).getType() == expectedType) + return cast.getOperand(0); + return operand; +} + +static Operation *cloneOpForInlineWithFix(OpBuilder &builder, Operation &op, + IRMapping &mapping) { + if (auto alloc = dyn_cast(&op)) { + auto mapOperand = [&](Value operand, Type expectedType) -> Value { + if (!operand) + return Value(); + Value mapped = mapping.lookupOrNull(operand); + if (!mapped) + mapped = operand; + return maybeUnwrapCastToExpected(mapped, expectedType); + }; + + Value mappedAddr = mapOperand( + alloc.getAddr(), alloc.getAddr() ? alloc.getAddr().getType() : Type()); + Value mappedValidRow = mapOperand( + alloc.getValidRow(), + alloc.getValidRow() ? alloc.getValidRow().getType() : Type()); + Value mappedValidCol = mapOperand( + alloc.getValidCol(), + alloc.getValidCol() ? alloc.getValidCol().getType() : Type()); + + auto cloned = builder.create( + alloc.getLoc(), alloc.getType(), mappedAddr, mappedValidRow, + mappedValidCol); + cloned->setAttrs(alloc->getAttrs()); + return cloned.getOperation(); + } + + return builder.clone(op, mapping); +} + +static void eraseDeadBridgeCasts(func::FuncOp func) { + bool changed = true; + while (changed) { + changed = false; + + SmallVector deadUnrealized; + func.walk([&](UnrealizedConversionCastOp cast) { + if (cast->use_empty()) + deadUnrealized.push_back(cast); + }); + + SmallVector deadMemrefCasts; + func.walk([&](memref::CastOp cast) { + if (cast->use_empty()) + deadMemrefCasts.push_back(cast); + }); + + if (deadUnrealized.empty() && deadMemrefCasts.empty()) + break; + + for (UnrealizedConversionCastOp cast : llvm::reverse(deadUnrealized)) + cast.erase(); + for (memref::CastOp cast : llvm::reverse(deadMemrefCasts)) + cast.erase(); + changed = true; + } +} + +static LogicalResult inlineCall(func::CallOp call, func::FuncOp callee) { + if (call.getNumResults() != 0) + return call.emitOpError("OP-Lib inline expects call without results"); + if (callee.isExternal()) + return call.emitOpError("callee must have a body before inlining"); + + Block &entry = callee.getBody().front(); + if (entry.getNumArguments() != call.getNumOperands()) + return call.emitOpError("callee argument count mismatch during inlining"); + + OpBuilder builder(call); + IRMapping mapping; + for (auto [arg, operand] : + llvm::zip(entry.getArguments(), call.getOperands())) + mapping.map(arg, operand); + + for (Operation &op : entry.without_terminator()) { + Operation *newOp = cloneOpForInlineWithFix(builder, op, mapping); + for (auto [oldRes, newRes] : + llvm::zip(op.getResults(), newOp->getResults())) + mapping.map(oldRes, newRes); + } + + call.erase(); + return success(); +} + +struct PTOInlineLibCallPass + : public pto::impl::PTOInlineLibCallBase { + using pto::impl::PTOInlineLibCallBase< + PTOInlineLibCallPass>::PTOInlineLibCallBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + + int inlinedCalls = 0; + int touchedFuncs = 0; + + for (func::FuncOp func : module.getOps()) { + if (func.isExternal()) + continue; + if (isInlineableLibFunc(func)) + continue; + if (func.empty()) + continue; + + SmallVector calls; + func.walk([&](func::CallOp call) { calls.push_back(call); }); + + bool changedThisFunc = false; + for (func::CallOp oldCall : calls) { + if (!oldCall || !oldCall->getBlock()) + continue; + + auto calleeAttr = oldCall.getCalleeAttr(); + if (!calleeAttr) + continue; + + func::FuncOp callee = + module.lookupSymbol(calleeAttr.getValue()); + if (!callee || !isInlineableLibFunc(callee)) + continue; + + if (callee.isExternal()) { + oldCall.emitError() << kErrInstanceBodyMissing + << ": OP-Lib instance body is missing for @" + << callee.getSymName(); + if (auto variant = + callee->getAttrOfType(kOpLibAttrInstVariantId)) { + oldCall.emitRemark() << "variant_id=" << variant.getValue(); + } + if (auto op = callee->getAttrOfType(kOpLibAttrInstOp)) { + oldCall.emitRemark() << "op=" << op.getValue(); + } + if (auto dtype = + callee->getAttrOfType(kOpLibAttrInstDType)) { + oldCall.emitRemark() << "dtype=" << dtype.getValue(); + } + signalPassFailure(); + return; + } + + func::CallOp call = oldCall; + SmallVector concreteOperands; + concreteOperands.reserve(call.getNumOperands()); + for (auto [operand, expectedTy] : llvm::zip( + call.getOperands(), callee.getFunctionType().getInputs())) { + concreteOperands.push_back( + maybeUnwrapCastToExpected(operand, expectedTy)); + } + + OpBuilder builder(call); + auto newCall = builder.create(call.getLoc(), callee, + concreteOperands); + call.erase(); + + if (failed(inlineCall(newCall, callee))) { + signalPassFailure(); + return; + } + + ++inlinedCalls; + changedThisFunc = true; + if (debug) { + llvm::errs() << "[op-fusion] inline-libcall: inlined @" + << callee.getSymName() << " into @" << func.getSymName() + << "\n"; + } + } + + if (changedThisFunc) { + eraseDeadBridgeCasts(func); + ++touchedFuncs; + } + } + + if (debug) { + llvm::errs() << "[op-fusion] inline-libcall touched " << touchedFuncs + << " function(s), inlined " << inlinedCalls << " call(s)\n"; + } + } +}; + +} // namespace + +std::unique_ptr +mlir::pto::createPTOInlineLibCallPass(const PTOInlineLibCallOptions &options) { + return std::make_unique(options); +} diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 7fd339515..0b80849b1 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1149,6 +1149,43 @@ struct PTOViewToMemrefPass // Update function type func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); + // ------------------------------------------------------------------ + // Stage 0.25: Insert pto.bind_tile for function args that were tile_buf + // ------------------------------------------------------------------ + // When MemrefToTileBuf runs later, it needs BindTileOp as the anchor to + // recover tile_buf types. For function args, no such anchor exists after + // the Stage-0 type rewrite, so we create one here. + { + IRRewriter rewriter(ctx); + // Insert after existing block args, before any existing ops. + rewriter.setInsertionPointToStart(&entry); + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + Type origTy = fnTy.getInputs()[i]; + auto tbTy = dyn_cast(origTy); + if (!tbTy) + continue; + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + Value vRow, vCol; + auto vs = tbTy.getValidShape(); + if (vs.size() == 2) { + if (vs[0] != ShapedType::kDynamic) + vRow = rewriter.create(func.getLoc(), vs[0]); + if (vs[1] != ShapedType::kDynamic) + vCol = rewriter.create(func.getLoc(), vs[1]); + } + + auto bindOp = rewriter.create( + func.getLoc(), newInputs[i], entry.getArgument(i), + vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + + entry.getArgument(i).replaceAllUsesExcept(bindOp.getResult(), bindOp); + } + } + // ------------------------------------------------------------------ // Stage 0.5: lower pto.alloc_tile -> memref.alloc + pto.bind_tile // ------------------------------------------------------------------ diff --git a/test/basic/expand_tile_op_tilelang.pto b/test/basic/expand_tile_op_tilelang.pto new file mode 100644 index 000000000..e0733908b --- /dev/null +++ b/test/basic/expand_tile_op_tilelang.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tadd via TileLang Python DSL templates. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// REQUIRES: tilelang-dsl-tile-buf-params +// RUN: ptoas --pto-arch=a5 --enable-tile-to-vector --tilelang-path=%S/../tilelang_templates --tilelang-pkg-path=%S/../../tilelang-dsl/python %s 2>/dev/null | FileCheck %s +// +// NOTE: This test requires the TileLang DSL to generate functions with +// tile_buf parameters (not memref). The test is gated by REQUIRES until +// the DSL frontend is updated. + +// After the full pipeline, tile_buf_addr/tile_valid_rows/tile_valid_cols +// should be folded away, and the vector loop body should be inlined. +// CHECK: func.func @TADD +// CHECK-NOT: pto.tadd ins +// CHECK: pto.simd.tile_to_memref +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module { + func.func @TADD( + %a: !pto.tile_buf, + %b: !pto.tile_buf, + %c: !pto.tile_buf) + attributes { pto.tile_function = "pto.tadd" } { + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + return + } +} diff --git a/test/basic/fold_tile_buf_intrinsics.pto b/test/basic/fold_tile_buf_intrinsics.pto new file mode 100644 index 000000000..b6c565a83 --- /dev/null +++ b/test/basic/fold_tile_buf_intrinsics.pto @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Unit test for InlineLibCall + FoldTileBufIntrinsics passes. +// +// Input simulates ExpandTileOp output: a caller with func.call to a template +// function that has tile_buf params, tile_buf_addr / tile_valid_rows / +// tile_valid_cols intrinsics, and the pto.tilelang.instance attribute. +// +// RUN: ptoas --pto-arch=a5 --pass-pipeline="builtin.module(pto-inline-libcall,func.func(pto-fold-tile-buf-intrinsics))" %s | FileCheck %s + +// After inline + fold: +// - The call to @__pto_tilelang_tadd_f32_16_64 should be inlined +// - tile_buf_addr should be folded to pto.simd.tile_to_memref +// - tile_valid_rows/cols should be folded to arith.constant (static case) +// CHECK: func.func @TADD +// CHECK-NOT: call @__pto_tilelang_tadd +// CHECK: pto.simd.tile_to_memref +// CHECK: arith.constant 16 : index +// CHECK: arith.constant 64 : index +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module attributes {pto.target_arch = "a5"} { + + // Caller function: has a tile op replaced by func.call (simulating ExpandTileOp output). + func.func @TADD( + %a: !pto.tile_buf, + %b: !pto.tile_buf, + %c: !pto.tile_buf) + attributes { pto.tile_function = "pto.tadd" } { + call @__pto_tilelang_tadd_f32_16_64(%a, %b, %c) + : (!pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) -> () + return + } + + // TileLang DSL-generated template function (simulated). + // Parameters are tile_buf; body uses tile_buf_addr / tile_valid_rows / tile_valid_cols. + func.func private @__pto_tilelang_tadd_f32_16_64( + %dst: !pto.tile_buf, + %src0: !pto.tile_buf, + %src1: !pto.tile_buf) + attributes { pto.tilelang.instance } { + + %mDst = pto.tile_buf_addr %dst : !pto.tile_buf -> memref<16x64xf32, strided<[64, 1]>, #pto.address_space> + %mSrc0 = pto.tile_buf_addr %src0 : !pto.tile_buf -> memref<16x64xf32, strided<[64, 1]>, #pto.address_space> + %mSrc1 = pto.tile_buf_addr %src1 : !pto.tile_buf -> memref<16x64xf32, strided<[64, 1]>, #pto.address_space> + + %v_rows = pto.tile_valid_rows %dst : !pto.tile_buf -> index + %v_cols = pto.tile_valid_cols %dst : !pto.tile_buf -> index + %v_cols_i32 = arith.index_cast %v_cols : index to i32 + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + pto.vecscope { + scf.for %i = %c0 to %v_rows step %c1 { + %tmp = arith.index_cast %v_cols : index to i32 + %inner = scf.for %j = %c0 to %v_cols step %c64 iter_args(%remain = %tmp) -> (i32) { + %mask, %next = pto.plt_b32 %remain : i32 -> !pto.mask, i32 + %va = pto.vlds %mSrc0[%i, %j] : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> -> !pto.vreg<64xf32> + %vb = pto.vlds %mSrc1[%i, %j] : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> -> !pto.vreg<64xf32> + %vc = pto.vadd %va, %vb, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %vc, %mDst[%i, %j], %mask : !pto.vreg<64xf32>, memref<16x64xf32, strided<[64, 1]>, #pto.address_space>, !pto.mask + scf.yield %next : i32 + } + } + } + return + } +} diff --git a/test/tilelang_templates/tadd_template.py b/test/tilelang_templates/tadd_template.py new file mode 100644 index 000000000..be561220e --- /dev/null +++ b/test/tilelang_templates/tadd_template.py @@ -0,0 +1,32 @@ +"""TileLang DSL template for pto.tadd — used by ExpandTileOp tests.""" + +import sys +from pathlib import Path + +_repo = Path(__file__).resolve().parents[2] +_pkg = _repo / "tilelang-dsl" / "python" +if str(_pkg) not in sys.path: + sys.path.insert(0, str(_pkg)) + +import tilelang_dsl as pto + + +@pto.vkernel( + op="pto.tadd", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + name="template_tadd", +) +def template_tadd(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None diff --git a/tilelang-dsl/examples/tadd_demo.py b/tilelang-dsl/examples/tadd_demo.py new file mode 100644 index 000000000..d1843f450 --- /dev/null +++ b/tilelang-dsl/examples/tadd_demo.py @@ -0,0 +1,72 @@ +"""TileLang DSL v1 demo: pto.tadd (element-wise add) using Tile parameters. + +Note: v1 surface only supports 1D vectorized iteration within strict_vecscope. +The canonical 2D row×col loop with dynamic masking requires v2 features. +This demo demonstrates a 1D inner-loop pattern over the tile's column extent. +""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + try: + import tilelang_dsl as pto + return pto + except ModuleNotFoundError: + repo_root = Path(__file__).resolve().parents[2] + sys.path.insert(0, str(repo_root / "python")) + import tilelang_dsl as pto + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="pto.tadd", + dtypes=[(pto.f32, pto.f32, pto.f32)], + name="template_tadd", +) +def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + # v1 strict_vecscope: all referenced values must be passed in explicitly, + # and only scalar offsets (not 2D subscripts) are supported for vlds/vsts. + with pto.strict_vecscope(src0, src1, dst, 0, 256, 64) as ( + a, b, c, lb, ub, step + ): + for j in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec_a = pto.vlds(a, j) + vec_b = pto.vlds(b, j) + result = pto.vadd(vec_a, vec_b, mask) + pto.vsts(result, c, j, mask) + + +def main(argv: list[str]) -> int: + specialized = template_tadd.specialize( + src0=pto.TileSpecialization( + shape=(16, 64), + memory_space=pto.MemorySpace.UB, + ), + src1=pto.TileSpecialization( + shape=(16, 64), + memory_space=pto.MemorySpace.UB, + ), + dst=pto.TileSpecialization( + shape=(16, 64), + memory_space=pto.MemorySpace.UB, + ), + ) + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py new file mode 100644 index 000000000..0a93b4818 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -0,0 +1,154 @@ +"""CLI helper invoked by ExpandTileOp to instantiate a tilelang DSL template. + +Usage: + python3 -m tilelang_dsl.expand_helper \ + --template-dir /path/to/templates \ + --op pto.tadd \ + --dtype f32 \ + --shape 16,64 \ + --memory-space ub + +Scans --template-dir for .py files, finds a @vkernel whose `op` matches, +specializes every Tile parameter with the given shape/memory_space, and +prints the materialized MLIR module to stdout. +""" + +from __future__ import annotations + +import argparse +import importlib.util +import sys +from pathlib import Path + +from .kernel import VKernelDescriptor +from .types import MemorySpace, ScalarType, TileSpecialization + + +_DTYPE_MAP: dict[str, ScalarType] = {} + + +def _populate_dtype_map() -> None: + from . import types as _t + + for name in ("f16", "bf16", "f32", "i8", "i16", "i32", "i64"): + obj = getattr(_t, name, None) + if isinstance(obj, ScalarType): + _DTYPE_MAP[name] = obj + + +_populate_dtype_map() + +_MEMSPACE_MAP = { + "ub": MemorySpace.UB, + "gm": MemorySpace.GM, +} + + +def _find_descriptors(module) -> list[VKernelDescriptor]: + """Return all VKernelDescriptor instances found as module-level attributes.""" + result = [] + for attr_name in dir(module): + obj = getattr(module, attr_name, None) + if isinstance(obj, VKernelDescriptor): + result.append(obj) + return result + + +def _import_py_file(path: Path): + """Import a .py file as a module and return it.""" + spec = importlib.util.spec_from_file_location(f"_tl_template_{path.stem}", str(path)) + if spec is None or spec.loader is None: + return None + mod = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(mod) + except Exception as exc: + print(f"expand_helper: warning: failed to import {path}: {exc}", file=sys.stderr) + return None + return mod + + +def _match_descriptor( + descriptors: list[VKernelDescriptor], + op_name: str, + dtype_name: str, +) -> VKernelDescriptor | None: + """Find the first descriptor matching (op, dtype).""" + target_dtype = _DTYPE_MAP.get(dtype_name) + if target_dtype is None: + return None + + for desc in descriptors: + if desc.op != op_name: + continue + # Check dtype signature: all entries must match the target dtype. + sig = desc.dtype_signature + if all(d == target_dtype for d in sig): + return desc + return None + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="TileLang DSL expand helper") + parser.add_argument("--template-dir", required=True, help="Directory of .py templates") + parser.add_argument("--op", required=True, help="Tile op name, e.g. pto.tadd") + parser.add_argument("--dtype", required=True, help="Element dtype, e.g. f32") + parser.add_argument("--shape", required=True, help="Tile shape, e.g. 16,64") + parser.add_argument("--memory-space", default="ub", help="Memory space (ub or gm)") + args = parser.parse_args(argv) + + template_dir = Path(args.template_dir) + if not template_dir.is_dir(): + print(f"expand_helper: error: {template_dir} is not a directory", file=sys.stderr) + return 1 + + shape = tuple(int(d) for d in args.shape.split(",")) + mem_space = _MEMSPACE_MAP.get(args.memory_space) + if mem_space is None: + print(f"expand_helper: error: unknown memory-space '{args.memory_space}'", file=sys.stderr) + return 1 + + # Scan all .py files for descriptors. + all_descriptors: list[VKernelDescriptor] = [] + for py_path in sorted(template_dir.glob("*.py")): + mod = _import_py_file(py_path) + if mod is None: + continue + all_descriptors.extend(_find_descriptors(mod)) + + if not all_descriptors: + print(f"expand_helper: error: no @vkernel descriptors found in {template_dir}", file=sys.stderr) + return 1 + + # Match. + desc = _match_descriptor(all_descriptors, args.op, args.dtype) + if desc is None: + print( + f"expand_helper: error: no template matches op={args.op} dtype={args.dtype}", + file=sys.stderr, + ) + return 1 + + # Specialize all Tile parameters with the same shape/memory_space. + tile_specs = {} + for param in desc.tile_parameters: + tile_specs[param.name] = TileSpecialization( + shape=shape, + memory_space=mem_space, + ) + + specialized = desc.specialize(**tile_specs) + + # Emit MLIR to stdout. + try: + mlir_text = specialized.mlir_text() + except Exception as exc: + print(f"expand_helper: error: materialization failed: {exc}", file=sys.stderr) + return 1 + + sys.stdout.write(mlir_text) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index f90b95f3a..5daacf40b 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -187,6 +187,22 @@ static llvm::cl::opt enableInsertSync("enable-insert-sync", llvm::cl::desc("Enable automatic synchronization insertion pass"), llvm::cl::init(false)); +static llvm::cl::opt enableTileToVector( + "enable-tile-to-vector", + llvm::cl::desc( + "Enable Tile-to-Vector lowering path (memref->tile_buf recovery)"), + llvm::cl::init(false)); + +static llvm::cl::opt tilelangPath( + "tilelang-path", + llvm::cl::desc("Path to directory of .py tilelang DSL template files"), + llvm::cl::init("")); + +static llvm::cl::opt tilelangPkgPath( + "tilelang-pkg-path", + llvm::cl::desc("PYTHONPATH for tilelang_dsl package"), + llvm::cl::init("")); + static llvm::cl::opt disableInferLayout( "disable-infer-layout", llvm::cl::desc("Disable PTO layout inference pass (static-only)"), @@ -1158,6 +1174,11 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); + mlir::registerAllPasses(); + ::registerPTOInlineLibCall(); + ::registerFoldTileBufIntrinsics(); + ::registerExpandTileOp(); + mlir::registerPassManagerCLOptions(); llvm::cl::SetVersionPrinter(printPTOASVersion); @@ -1460,11 +1481,39 @@ int main(int argc, char **argv) { pm.addPass(emitc::createFormExpressionsPass()); pm.addPass(mlir::createCSEPass()); + if (enableTileToVector) { + // Tile→Vector path: + // 1. MemrefToTileBuf: recover tile_buf from memref + // 2. ExpandTileOp: instantiate TileLang DSL templates, replace tile ops + // with func.call to template functions (tile_buf params) + // 3. InlineLibCall: inline template function bodies + // 4. FoldTileBufIntrinsics: fold tile_buf_addr / tile_valid_rows / + // tile_valid_cols to concrete memref/constant values + pm.addPass(pto::createMemrefToTileBufPass()); + + pto::ExpandTileOpOptions expandOpts; + expandOpts.tilelangPath = tilelangPath; + expandOpts.tilelangPkgPath = tilelangPkgPath; + pm.addPass(pto::createExpandTileOpPass(expandOpts)); + + pm.addPass(pto::createPTOInlineLibCallPass()); + pm.addNestedPass( + pto::createFoldTileBufIntrinsicsPass()); + } + if (failed(pm.run(*module))) { llvm::errs() << "Error: Pass execution failed.\n"; return 1; } + // Tile→Vector path: print MLIR IR and exit (no C++ emission). + if (enableTileToVector) { + module->print(outputFile.os()); + outputFile.os() << "\n"; + outputFile.keep(); + return 0; + } + dropEmptyEmitCExpressions(module.get()); materializeControlFlowOperands(module.get()); if (failed(reorderEmitCFunctions(module.get()))) { From f64012916b7bf3d0ea803870694b33f08fcefc09 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Mon, 13 Apr 2026 00:11:29 +0800 Subject: [PATCH 026/192] brough back the lost pass and pto op after rebase --- include/PTO/IR/PTOOps.td | 26 ++++++ include/PTO/Transforms/Passes.h | 2 + include/PTO/Transforms/Passes.td | 18 ++++ lib/PTO/IR/PTO.cpp | 145 ++++++++++++++++++++++++++++++ lib/PTO/Transforms/CMakeLists.txt | 1 + 5 files changed, 192 insertions(+) diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 5e53b3363..e471b0f96 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -2165,6 +2165,32 @@ def TSyncOp : PTO_TOp<"tsync"> { }]; } +//===----------------------------------------------------------------------===// +// SIMD Bridge Ops +//===----------------------------------------------------------------------===// + +def SimdTileToMemrefOp : PTO_Op<"simd.tile_to_memref", [Pure]> { + let summary = "Bridge cast from tile_buf to memref in OP-Lib bodies."; + let description = [{ + This op is the canonical bridge marker for OP-Lib templates to expose a + memref view from tile-like values while keeping external ABI as + !pto.tile_buf. + In tile_buf world, src is !pto.tile_buf and dst is the corresponding + memref bridge type. + After memref-world lowering, src may already be memref and this op remains + as a marker for backend lowering (EmitC) to materialize tile data access. + }]; + + let arguments = (ins TileBufOrMemRef:$src); + let results = (outs AnyMemRef:$dst); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `to` type($dst) + }]; +} + // --------------------------------------------------------------------------- // TileBuf intrinsics — used in TileLang DSL-generated template functions. // These ops extract memref address and valid shape from tile_buf parameters. diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index fadd4e143..52ca26798 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -74,6 +74,8 @@ std::unique_ptr createMemrefToTileBufPass(); std::unique_ptr createExpandTileOpPass(); std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); std::unique_ptr createFoldTileBufIntrinsicsPass(); +std::unique_ptr +createPTOInlineLibCallPass(const PTOInlineLibCallOptions &options = {}); //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 0ed6e8299..4ac92e7d6 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -297,6 +297,24 @@ def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::Fu ]; } +def PTOInlineLibCall : Pass<"pto-inline-libcall", "ModuleOp"> { + let summary = "Materialize OP-Lib instance bodies and inline OP-Lib calls"; + let description = [{ + Resolves OP-Lib instance declarations generated by OP-Lib lowering, + materializes instance bodies, and inlines OP-Lib calls into caller/fused + helper functions. Function signatures stay in !pto.tile_buf form. + }]; + let constructor = "mlir::pto::createPTOInlineLibCallPass()"; + let dependentDialects = ["mlir::func::FuncDialect", "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect", + "mlir::scf::SCFDialect"]; + let options = [Option< + "debug", "debug", "bool", + /*default=*/"false", + "Enable verbose debug logging for OP-Lib instantiation/inlining">]; +} + def PTOVerifyTFree : Pass<"pto-verify-tfree", "func::FuncOp"> { let summary = "Verify explicit matching pto.tfree placement for pto.tpop"; let description = [{ diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 74751c313..1737234e3 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -9730,6 +9730,151 @@ static LogicalResult computeInnerShape(TileBufConfigAttr cfg, Type elemTy, return failure(); } +static LogicalResult +computeExpectedTileBufMemrefStrides(TileBufType tileTy, + SmallVectorImpl &expectedStrides) { + if (tileTy.getRank() != 2) + return failure(); + + ArrayRef shape = tileTy.getShape(); + if (shape.size() != 2) + return failure(); + if (shape[0] == ShapedType::kDynamic || shape[1] == ShapedType::kDynamic) + return failure(); + + auto cfg = tileTy.getConfigAttr(); + if (!cfg) + cfg = TileBufConfigAttr::getDefault(tileTy.getContext()); + + int64_t innerRows = 1, innerCols = 1; + bool boxed = false; + int32_t bl = 0, sl = 0; + if (failed(computeInnerShape(cfg, tileTy.getElementType(), innerRows, innerCols, + boxed, bl, sl))) + return failure(); + + expectedStrides.clear(); + if (!boxed) { + if (bl == 1) { + expectedStrides.push_back(1); + expectedStrides.push_back(shape[0]); + } else { + expectedStrides.push_back(shape[1]); + expectedStrides.push_back(1); + } + return success(); + } + + if (bl == 1) { + if (sl != 1) + return failure(); + expectedStrides.push_back(innerCols); + expectedStrides.push_back(shape[0]); + return success(); + } + + expectedStrides.push_back(shape[1]); + expectedStrides.push_back(innerRows); + return success(); +} + +mlir::LogicalResult mlir::pto::SimdTileToMemrefOp::verify() { + auto memTy = dyn_cast(getDst().getType()); + if (!memTy) + return emitOpError("expects result to be memref"); + + Type srcTy = getSrc().getType(); + if (auto tileTy = dyn_cast(srcTy)) { + if (memTy.getElementType() != tileTy.getElementType()) + return emitOpError( + "expects memref element type to match tile_buf element type"); + + if (memTy.getMemorySpace() != tileTy.getMemorySpace()) + return emitOpError( + "expects memref memory space to match tile_buf memory space"); + + if (memTy.getRank() != tileTy.getRank()) + return emitOpError("expects memref rank to match tile_buf rank"); + + ArrayRef tileShape = tileTy.getShape(); + ArrayRef validShape = tileTy.getValidShape(); + ArrayRef memShape = memTy.getShape(); + if (tileShape.size() != memShape.size()) + return emitOpError( + "expects memref shape rank to match tile_buf shape rank"); + + if (validShape.size() != memShape.size()) + return emitOpError( + "expects tile_buf valid shape rank to match memref shape rank"); + + for (unsigned i = 0; i < validShape.size(); ++i) { + int64_t expect = validShape[i]; + if (expect < 0) { + if (memShape[i] >= 0 && memShape[i] != tileShape[i]) { + return emitOpError() + << "expects memref dim " << i + << " to be dynamic or match physical tile dim " << tileShape[i] + << " because tile_buf valid dim is ?"; + } + continue; + } + + if (memShape[i] != expect) { + return emitOpError() << "expects memref dim " << i + << " to match tile_buf valid dim; got " + << memShape[i] << ", expected " << expect; + } + } + + SmallVector expectedStrides; + if (failed(computeExpectedTileBufMemrefStrides(tileTy, expectedStrides))) + return emitOpError("cannot infer expected strides from tile_buf layout"); + + SmallVector memStrides; + int64_t memOffset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(memTy, memStrides, memOffset))) + return emitOpError("expects memref to use strided layout"); + if (memOffset != 0) + return emitOpError("expects memref offset to be 0"); + if (memStrides.size() != expectedStrides.size()) + return emitOpError("expects memref stride rank to match tile_buf rank"); + for (unsigned i = 0; i < expectedStrides.size(); ++i) { + if (memStrides[i] != expectedStrides[i]) { + return emitOpError() + << "expects memref strides to match tile_buf layout; got " + << memStrides[i] << " at dim " << i << ", expected " + << expectedStrides[i]; + } + } + return success(); + } + + auto srcMemTy = dyn_cast(srcTy); + if (!srcMemTy) + return emitOpError("expects src to be !pto.tile_buf or memref"); + + if (srcMemTy.getElementType() != memTy.getElementType()) + return emitOpError("expects src/result memref element types to match"); + + if (srcMemTy.getMemorySpace() != memTy.getMemorySpace()) + return emitOpError("expects src/result memref memory spaces to match"); + + if (srcMemTy.getRank() != memTy.getRank()) + return emitOpError("expects src/result memref ranks to match"); + + ArrayRef srcShape = srcMemTy.getShape(); + ArrayRef dstShape = memTy.getShape(); + for (unsigned i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] >= 0 && dstShape[i] >= 0 && srcShape[i] != dstShape[i]) { + return emitOpError() + << "expects compatible src/result memref shapes; dim " << i + << " mismatches (" << srcShape[i] << " vs " << dstShape[i] << ")"; + } + } + + return success(); +} + mlir::LogicalResult mlir::pto::SubViewOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return success(); diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 79d8ad2f8..8891f2b8e 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -27,6 +27,7 @@ add_mlir_dialect_library(PTOTransforms MemrefToTileBuf.cpp ExpandTileOp.cpp FoldTileBufIntrinsics.cpp + PTOInstantiateAndInlineOpLib.cpp PTOToEmitC.cpp Utils.cpp OptMemPlanForPipeline.cpp From 40b2c934781a06d596d761f2ae371dd6adc8f032 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 7 Apr 2026 11:29:23 +0800 Subject: [PATCH 027/192] feat: tile lang dsl --- docs/PTO_IR_manual.md | 38 + docs/build_with_installed_llvm.md | 153 + docs/designs/ptoas-tileop-expand-design.md | 467 ++- docs/tilelang-dsl-guide.md | 2978 ----------------- docs/vpto-spec.md | 54 + include/PTO/IR/PTOOps.td | 46 +- include/PTO/IR/VPTOOps.td | 53 +- include/PTO/Transforms/Passes.h | 2 + include/PTO/Transforms/Passes.td | 36 +- lib/PTO/IR/VPTO.cpp | 82 + lib/PTO/Transforms/CMakeLists.txt | 3 + lib/PTO/Transforms/ExpandTileOp.cpp | 117 +- lib/PTO/Transforms/FoldTileBufIntrinsics.cpp | 158 +- .../PTOInstantiateAndInlineOpLib.cpp | 65 +- lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp | 209 ++ lib/PTO/Transforms/PTOLowerToOpLibCalls.h | 17 + lib/PTO/Transforms/PTOViewToMemref.cpp | 52 +- lib/PTO/Transforms/VPTOPtrCastCleanup.cpp | 73 + lib/PTO/Transforms/VPTOPtrNormalize.cpp | 418 +++ lib/TileOps/render_template_mlir.py | 342 ++ .../TileOps}/tadd_template.py | 18 +- lib/TileOps/tadds_template.py | 23 + lib/TileOps/tload_template.py | 91 + lib/TileOps/tstore_template.py | 88 + .../.openspec.yaml | 2 + .../design.md | 273 ++ .../proposal.md | 90 + .../specs/tilelang-dsl-kernel-matcher/spec.md | 85 + .../specs/tilelang-dsl-template-slots/spec.md | 68 + .../tasks.md | 24 + .../tilelang-dsl-advanced-surface/spec.md | 12 +- .../specs/tilelang-dsl-kernel-matcher/spec.md | 51 +- .../specs/tilelang-dsl-template-slots/spec.md | 70 + test/basic/expand_tile_op_tilelang.pto | 38 +- test/basic/fold_tile_buf_intrinsics.pto | 9 +- .../inline_libcall_filter_tilelang_scope.pto | 38 + test/basic/inline_libcall_result_rewrite.pto | 36 + .../tilelang_inline_proc_backend_inline.pto | 30 + .../vpto_mainline_inline_proc_cleanup.pto | 28 + test/dsl/expand_tile_op_tilelang_tadds.pto | 36 + tilelang-dsl/docs/README.md | 67 +- .../matcher-and-advanced-surface-migration.md | 163 +- tilelang-dsl/docs/unsupported-features.md | 249 ++ .../docs/user_guide/01-introduction.md | 48 + .../docs/user_guide/02-quick-start.md | 78 + .../docs/user_guide/03-kernel-declaration.md | 313 ++ .../docs/user_guide/04-template-kernels.md | 332 ++ .../docs/user_guide/05-type-system.md | 142 + tilelang-dsl/docs/user_guide/06-tensorview.md | 97 + tilelang-dsl/docs/user_guide/07-tile-types.md | 187 ++ .../docs/user_guide/08-control-flow.md | 142 + .../docs/user_guide/09-frontend-operations.md | 172 + .../docs/user_guide/10-sync-dma-operations.md | 460 +++ .../user_guide/11-vector-memory-operations.md | 986 ++++++ .../user_guide/12-predicate-operations.md | 498 +++ .../13-vector-arithmetic-operations.md | 1414 ++++++++ tilelang-dsl/docs/user_guide/14-examples.md | 154 + .../docs/user_guide/15-common-errors.md | 51 + .../docs/user_guide/16-compatibility-notes.md | 9 + tilelang-dsl/docs/user_guide/17-next-steps.md | 7 + tilelang-dsl/docs/v1-lowering.md | 20 +- tilelang-dsl/docs/v1-surface.md | 19 +- tilelang-dsl/examples/README.md | 13 +- tilelang-dsl/examples/tadd_demo.py | 7 +- .../examples/v1_elementwise_tail_demo.py | 1 + .../v1_tadd_implicit_vecscope_demo.py | 97 +- .../v1_tbinop_2d_nopostupdate_demo.py | 1 + .../examples/v1_template_slot_multiop_demo.py | 146 + tilelang-dsl/python/tilelang_dsl/__init__.py | 38 + .../python/tilelang_dsl/expand_helper.py | 144 +- .../python/tilelang_dsl/frontend_ast.py | 948 +++++- tilelang-dsl/python/tilelang_dsl/kernel.py | 1046 +++++- tilelang-dsl/python/tilelang_dsl/lowering.py | 2083 +++++++++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 2083 +++++++++++- .../python/tilelang_dsl/support_matrix.py | 281 +- tilelang-dsl/python/tilelang_dsl/types.py | 86 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 2857 +++++++++++++++- tools/ptoas/CMakeLists.txt | 6 + tools/ptoas/ptoas.cpp | 88 +- 79 files changed, 18079 insertions(+), 3927 deletions(-) create mode 100644 docs/build_with_installed_llvm.md delete mode 100644 docs/tilelang-dsl-guide.md create mode 100644 lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp create mode 100644 lib/PTO/Transforms/PTOLowerToOpLibCalls.h create mode 100644 lib/PTO/Transforms/VPTOPtrCastCleanup.cpp create mode 100644 lib/PTO/Transforms/VPTOPtrNormalize.cpp create mode 100644 lib/TileOps/render_template_mlir.py rename {test/tilelang_templates => lib/TileOps}/tadd_template.py (56%) create mode 100644 lib/TileOps/tadds_template.py create mode 100644 lib/TileOps/tload_template.py create mode 100644 lib/TileOps/tstore_template.py create mode 100644 openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/.openspec.yaml create mode 100644 openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/design.md create mode 100644 openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/proposal.md create mode 100644 openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/specs/tilelang-dsl-kernel-matcher/spec.md create mode 100644 openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/specs/tilelang-dsl-template-slots/spec.md create mode 100644 openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/tasks.md create mode 100644 openspec/specs/tilelang-dsl-template-slots/spec.md create mode 100644 test/basic/inline_libcall_filter_tilelang_scope.pto create mode 100644 test/basic/inline_libcall_result_rewrite.pto create mode 100644 test/basic/tilelang_inline_proc_backend_inline.pto create mode 100644 test/basic/vpto_mainline_inline_proc_cleanup.pto create mode 100644 test/dsl/expand_tile_op_tilelang_tadds.pto create mode 100644 tilelang-dsl/docs/unsupported-features.md create mode 100644 tilelang-dsl/docs/user_guide/01-introduction.md create mode 100644 tilelang-dsl/docs/user_guide/02-quick-start.md create mode 100644 tilelang-dsl/docs/user_guide/03-kernel-declaration.md create mode 100644 tilelang-dsl/docs/user_guide/04-template-kernels.md create mode 100644 tilelang-dsl/docs/user_guide/05-type-system.md create mode 100644 tilelang-dsl/docs/user_guide/06-tensorview.md create mode 100644 tilelang-dsl/docs/user_guide/07-tile-types.md create mode 100644 tilelang-dsl/docs/user_guide/08-control-flow.md create mode 100644 tilelang-dsl/docs/user_guide/09-frontend-operations.md create mode 100644 tilelang-dsl/docs/user_guide/10-sync-dma-operations.md create mode 100644 tilelang-dsl/docs/user_guide/11-vector-memory-operations.md create mode 100644 tilelang-dsl/docs/user_guide/12-predicate-operations.md create mode 100644 tilelang-dsl/docs/user_guide/13-vector-arithmetic-operations.md create mode 100644 tilelang-dsl/docs/user_guide/14-examples.md create mode 100644 tilelang-dsl/docs/user_guide/15-common-errors.md create mode 100644 tilelang-dsl/docs/user_guide/16-compatibility-notes.md create mode 100644 tilelang-dsl/docs/user_guide/17-next-steps.md create mode 100644 tilelang-dsl/examples/v1_template_slot_multiop_demo.py diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index 39d969fa1..f2aac7bcf 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -363,6 +363,7 @@ or between two `!pto.ptr` types. **Constraints & Verification:** - Integer-to-integer casts are rejected; use normal integer cast ops instead +- Descriptor values such as `!pto.tensor_view<...>` and `!pto.partition_tensor_view<...>` are not legal direct inputs; extract a memref address first - Pointer-to-pointer casts are only legal when source and destination stay in the same PTO memory space (`gm` or `ub`) - The operation is pure (no side effects) @@ -454,6 +455,43 @@ This op is primarily defined on `!pto.tensor_view`. --- +##### `pto.get_tensor_view_stride` - Get Tensor View Dimension Stride + +**Summary:** Returns the logical stride of a given dimension of a tensor view. + +**Semantics:** + +```mlir +stride = get_tensor_view_stride(tv_or_mr, dim_index) +``` + +This op is defined on `!pto.tensor_view`. During internal lowering, the same +query may temporarily appear on the memref form lowered from the tensor view. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `tensor_view` | `!pto.tensor_view<...>` or `memref<...>` | Logical tensor view or its lowered memref form | +| `dim_index` | `index` | Dimension index (0-based) | + +**Results:** `index` — the logical stride of the requested dimension, measured +in elements rather than bytes. + +**Notes:** + +- This op is the IR counterpart of the DSL-side `TensorView.strides` metadata access. +- After lowering to memref, static strides may be folded into constants, while dynamic strides are derived from memref metadata. + +**Basic Example:** + +```mlir +%s0 = pto.get_tensor_view_stride %tv, %c0 : !pto.tensor_view -> index +%s1 = pto.get_tensor_view_stride %tv, %c1 : !pto.tensor_view -> index +``` + +--- + ##### `pto.partition_view` - Partition Tensor View **Summary:** Creates a logical window on a tensor_view using offsets and sizes, producing a `partition_tensor_view`. diff --git a/docs/build_with_installed_llvm.md b/docs/build_with_installed_llvm.md new file mode 100644 index 000000000..b57b376fd --- /dev/null +++ b/docs/build_with_installed_llvm.md @@ -0,0 +1,153 @@ +# 基于已安装 LLVM 的 PTOAS 构建说明 + +本文档按 [README.md](../README.md) 第 3 章的逻辑整理,适用于: + +- LLVM/MLIR `19.1.7` 已经构建并安装完成。 +- LLVM 安装路径固定为 `/opt/llvm`。 +- `/opt/llvm` 是共享目录,不希望 `ptoas` 的安装步骤写入其中。 + +## 3.0 环境变量配置 + +先按 README 第 3.0 节的思路把变量定好。区别是这里不再使用 LLVM 源码目录和 LLVM build tree,而是直接使用 LLVM install tree。 + +```bash +# ================= 配置区域 (请按实际环境调整) ================= +export WORKSPACE_DIR=$HOME/llvm-workspace + +# LLVM 已安装完成,直接指向 install 根目录 +export LLVM_INSTALL_DIR=/opt/llvm + +# 为兼容仓库内部分脚本 / lit 变量命名,这里额外保留 LLVM_BUILD_DIR +export LLVM_BUILD_DIR=$LLVM_INSTALL_DIR + +# ptoas 源码与安装路径 +export PTO_SOURCE_DIR=$WORKSPACE_DIR/PTOAS +export PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install-optllvm +# ============================================================ + +mkdir -p "$WORKSPACE_DIR" +``` + +说明: + +- 这里的 `LLVM_BUILD_DIR` 只是为了兼容仓库内已有变量名,实际指向的是 LLVM install 根目录 `/opt/llvm`。 +- `PTO_INSTALL_DIR` 建议单独放到 PTOAS 自己目录下,避免与共享 LLVM 安装混用。 + +## 3.1 环境准备 + +沿用 README 第 3.1 节即可,重点确认这些依赖已经满足: + +- Linux +- GCC >= 9 或 Clang +- CMake >= 3.20 +- Ninja +- Python 3.8+ +- `pybind11` +- `numpy` + +```bash +pip3 install pybind11 numpy +``` + +## 跳过 3.2 + +README 第 3.2 节是 LLVM/MLIR 的下载和编译步骤。当前场景下 LLVM 已经安装在 `/opt/llvm`,这一节可以直接跳过。 + +已验证: + +```bash +/opt/llvm/bin/llvm-config --version +``` + +输出为: + +```text +19.1.7 +``` + +## 3.3 第二步:构建 ptoas + +这里沿用 README 第 3.3 节的流程,但有两处需要改动: + +1. `LLVM_DIR` 和 `MLIR_DIR` 改为 `/opt/llvm/lib/cmake/...` +2. `MLIR_PYTHON_PACKAGE_DIR` 不再指向共享的 `/opt/llvm/python_packages/mlir_core`,而是指向 `PTO_INSTALL_DIR` + +如果继续沿用 README 里的 `MLIR_PYTHON_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core`,在 `/opt/llvm` 场景下会把 `_pto.cpython-*.so` 安装到共享 LLVM 目录,不适合多人共用。 + +```bash +cd "$PTO_SOURCE_DIR" + +# 1. 获取 pybind11 的 CMake 路径 +export PYBIND11_CMAKE_DIR=$(python3 -m pybind11 --cmakedir) + +# 2. 配置 CMake +cmake -G Ninja \ + -S . \ + -B build \ + -DLLVM_DIR=$LLVM_INSTALL_DIR/lib/cmake/llvm \ + -DMLIR_DIR=$LLVM_INSTALL_DIR/lib/cmake/mlir \ + -DPython3_EXECUTABLE=$(which python3) \ + -DPython3_FIND_STRATEGY=LOCATION \ + -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DMLIR_PYTHON_PACKAGE_DIR="$PTO_INSTALL_DIR" \ + -DCMAKE_INSTALL_PREFIX="$PTO_INSTALL_DIR" + +# 3. 编译并安装 +ninja -C build +cmake --install build +``` + +## 构建后关键产物 + +按上面的配置,关键产物位置如下: + +- build 目录: + - `$PTO_SOURCE_DIR/build/tools/ptoas/ptoas` + - `$PTO_SOURCE_DIR/build/tools/ptobc/ptobc` + - `$PTO_SOURCE_DIR/build/python/mlir/_mlir_libs/_pto.cpython-*.so` + - `$PTO_SOURCE_DIR/build/python/mlir/dialects/pto.py` +- install 目录: + - `$PTO_INSTALL_DIR/bin/ptoas` + - `$PTO_INSTALL_DIR/mlir/_mlir_libs/_pto.cpython-*.so` + - `$PTO_INSTALL_DIR/mlir/dialects/pto.py` + - `$PTO_INSTALL_DIR/share/ptoas/oplib/level3` + +## 补充:运行环境 + +### 使用 build 目录中的 `ptoas` + +```bash +export PATH=$PTO_SOURCE_DIR/build/tools/ptoas:$PATH +export PYTHONPATH=$LLVM_INSTALL_DIR/python_packages/mlir_core:$PTO_SOURCE_DIR/build/python:$PYTHONPATH +export LD_LIBRARY_PATH=$LLVM_INSTALL_DIR/lib:$PTO_SOURCE_DIR/build/lib:$LD_LIBRARY_PATH +``` + +### 使用 install 目录中的 `ptoas` + +```bash +export PATH=$PTO_INSTALL_DIR/bin:$PATH +export PYTHONPATH=$LLVM_INSTALL_DIR/python_packages/mlir_core:$PTO_INSTALL_DIR:$PYTHONPATH +export LD_LIBRARY_PATH=$LLVM_INSTALL_DIR/lib:$PTO_INSTALL_DIR/lib:$LD_LIBRARY_PATH +``` + +注意: + +- install 版 `ptoas` 仍然需要从 `/opt/llvm/lib` 加载 LLVM/MLIR 共享库。 +- 如果直接运行 `$PTO_INSTALL_DIR/bin/ptoas` 而没有设置 `LD_LIBRARY_PATH=$LLVM_INSTALL_DIR/lib:...`,会报缺少 `libMLIR*.so`。 + +## 本地验证结果 + +当前仓库已验证通过以下组合: + +- `LLVM_DIR=/opt/llvm/lib/cmake/llvm` +- `MLIR_DIR=/opt/llvm/lib/cmake/mlir` +- `MLIR_PYTHON_PACKAGE_DIR=$PTO_INSTALL_DIR` +- `CMAKE_INSTALL_PREFIX=$PTO_INSTALL_DIR` + +最小验证结果: + +- build 版 `ptoas --version` 输出 `ptoas 0.22` +- build 版 `ptoas` 可成功处理 `test/basic/empty_func.pto` +- install 版 Python 绑定可在 `PYTHONPATH=/opt/llvm/python_packages/mlir_core:$PTO_INSTALL_DIR` 下正常导入 +- 若 install 版 `ptoas` 配合 `LD_LIBRARY_PATH=/opt/llvm/lib:$PTO_INSTALL_DIR/lib`,可正常执行 diff --git a/docs/designs/ptoas-tileop-expand-design.md b/docs/designs/ptoas-tileop-expand-design.md index f233d4a11..4a71b25a8 100644 --- a/docs/designs/ptoas-tileop-expand-design.md +++ b/docs/designs/ptoas-tileop-expand-design.md @@ -64,8 +64,7 @@ func.func @TADD( %b: !pto.tile_buf, %c: !pto.tile_buf) - attributes { pto.tile_function = "pto.tadd" } { + blayout=row_major, slayout=none_box, fractal=512, pad=0>) { %vecA = pto.tile_buf_addr %a : !pto.tile_buf -> memref<16x64xf32, strided<[64, 1]>, #pto.address_space> @@ -105,79 +104,125 @@ func.func @TADD( ### 2.1 总体思路 -为了降低开发门槛并解决参数组合的穷举问题,我们采用 PTO DSL 来编写 Tile Lib 的向量库实现。这套语法定义在 TileLang 中,库开发者使用 Python 编写模板函数,由 PTOAS 编译器在编译时进行实例化。 +为了降低开发门槛并解决参数组合的穷举问题,我们采用 **TileLang Python DSL** 来编写 +Tile Lib 的向量库实现。库开发者使用 Python 编写 vkernel 函数,PTOAS 编译器在编译时 +根据具体的 Tile op 以及操作数类型进行匹配、特化(specialization)和实例化(instantiation)。 + +TileLang DSL 的完整语法定义在 `tilelang-dsl/docs/tilelang-dsl-guide.md`,本章在该文档 +基础上,聚焦于本方案所依赖的语言子集及其语义约束。 整体方案: -1. **用 Python DSL 编写模板函数**:使用 `pto.Tile` 数据类型和向量操作接口,按 Tile 指令语义编写向量实现。 -2. **编译器实例化模板**:PTOAS 在编译过程中遇到 Tile op 时,调用对应的模板函数,填入具体的 `tile_buf` 类型参数,生成特化后的向量 IR。 -3. **inline 到调用点**:特化后的向量 IR 直接 inline 到原 Tile op 的位置,继续后续优化和 lowering 流程。 +1. **用 TileLang Python DSL 编写 vkernel**:以 `@pto.vkernel` 装饰器声明匹配元数据 + (`target` / `op` 或 `ops` / `dtypes` / `constraints` / `priority`),函数体使用 + `pto.Tile` 数据类型和基础向量指令(`make_mask` / `vlds` / `vsts` / `vadd` / …) + 按 Tile 指令语义编写向量实现。 +2. **编译器匹配并特化 vkernel**:PTOAS 遇到 Tile op 时,通过 DSL 提供的 + `pto.select_kernel(target, concrete_op, operand_types, …)` 匹配候选 vkernel,按 DSL Guide + §Kernel Selection Mechanism 的规则(target → op → dtypes → constraints → priority) + 选出一条,再以调用点的具体 `tile_buf` 类型作为 specialization key 进行特化,生成 + 以 `tile_buf` 为形参的向量实现函数。 +3. **inline 到调用点**:特化后的向量 IR 以 `func.call` 形式插入到原 Tile op 的位置, + 随后由 `PTOInlineLibCall` pass inline 到调用点,继续后续优化和 lowering 流程。 ### 2.2 TADD 模板示例 -以 `pto.tadd`(逐元素加法)为例,使用 Python DSL 编写的模板函数如下: +以 `pto.tadd`(逐元素加法)为例,TileLang DSL 编写的 vkernel 如下(`PAT` 是 +`pto.MaskPattern` 的别名;算子名按 DSL Guide §Kernel Declaration 约定,不带 `pto.` 前缀): ```python -@pto.tile_template(target="a5", op="pto.tadd") -def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - dtype = src0.element_type - elem_size = src0.element_size - rows, cols = src0.shape - v_rows, v_cols = src0.valid_shape - - for i in range(0, v_rows, 1): - remaining = v_cols - for j in range(0, v_cols, 256 / elem_size): - all_mask, remaining = pto.make_mask(dtype, remaining) - vec_a = pto.vlds(a[i, j]) - vec_b = pto.vlds(b[i, j]) - result = pto.vadd(vec_a, vec_b, all_mask) - pto.vsts(result, c[i, j], all_mask) +from pto import MaskPattern as PAT + +@pto.vkernel( + target="a5", + op="tadd", # 匹配 pto.tadd + dtypes=[(pto.f32, pto.f32, pto.f32)], # 操作数类型签名 (src0, src1, dst) + advanced=True, # 启用隐式 vecscope 推断 +) +def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> None: + dtype = dst.element_type # 编译期静态 + valid_rows, valid_cols = dst.valid_shape # 静态或动态 + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None ``` 代码解读: -- **`@pto.tile_template`** 装饰器指示这是一个 `pto.tadd` 指令的模板,会在编译时进行实例化。 -- **输入参数**为 3 个 `pto.Tile` 数据类型参数,2 个输入(`src0`、`src1`),1 个输出(`dst`)。 -- 通过 **`Tile` 数据类型接口**获取元素类型(`element_type`)、元素大小(`element_size`)、静态 shape(`shape`)和 valid shape(`valid_shape`)信息。 -- 通过 **2 层循环**分别遍历 tile 的行和列。 -- 通过 **`pto.make_mask`** 指令,根据基础数据类型大小及有效数据数量设置 mask 寄存器。 -- 通过 **`pto.vlds`** 指令,以 `a[i, j]` 和 `b[i, j]` 为起始地址分别将数据读入向量寄存器。 -- 通过 **`pto.vadd`** 计算相加结果,写入寄存器 `result`。 -- 通过 **`pto.vsts`** 将 `result` 写入以 `c[i, j]` 为起始的地址区间。 +- **`@pto.vkernel`** 装饰器声明本 kernel 匹配 `a5` 架构下的 `tadd` 算子、操作数签名 + `(f32, f32, f32)`。`advanced=True` 让编译器对函数体内的 `vlds`/`vadd`/`vsts` 序列 + 自动推断 `pto.vecscope`,无需显式 `with pto.vecscope():` 包裹 + (详见 DSL Guide §Implicit Scope Inference)。 +- **kernel 参数**为 3 个 `pto.Tile` 对象(2 个输入 `src0` / `src1`,1 个输出 `dst`), + 对应 VPTO IR 中的 `!pto.tile_buf` 类型,它们是实例化时被特化的 symbolic value。 + **参数顺序必须与 PTOAS 中对应指令的操作数顺序一致**(即 `ins` 在前、`outs` 在后), + 因为 `ExpandTileOp` 按位置索引直接传递操作数。 +- 通过 **`Tile` 属性接口**读取元素类型 `element_type` 和 `valid_shape`。参考 DSL Guide + §Tile Attributes:`shape` / `element_type` / `memory_space` / `config` 都是编译期静态 + 值,`valid_shape` 允许为静态或动态。 +- **2 层循环**分别遍历 tile 的行和列。外层步长 1,内层步长为 `pto.get_lanes(dtype)` + (单个向量寄存器可容纳的元素数,f32→64,f16→128)。 +- **`pto.make_mask(dtype, remained)`** 按 DSL Guide §Typed Masks 的 tail-processing 语义, + 返回 `(mask, new_remaining)`,并根据 `dtype` 自动选择正确的 mask 粒度 + (`f32` → `mask_b32`、`f16` → `mask_b16`、`i8` → `mask_b8`)。 +- **Tile 元素级索引语法糖** `src0[row, col:]` 实现向量宽度的 load/store + (DSL Guide §Address Generation Syntax Sugar):`col:` 后缀表示以 `col` 为起点、按 + 向量宽度连续读取;编译器按 `element_size` 和 layout 自动计算字节偏移,避免手写 + `i * cols * 4` 之类易错的算术。 +- **`pto.vadd(lhs, rhs, mask)`** 执行逐元素加法;**`pto.vsts(summed, dst[row, col:], mask)`** + 将结果带 mask 写回 `dst`。 ### 2.3 值模型与 Staging 语义 -模板函数中使用的 `pto.Tile` 属性,在模板执行时分为两类不同阶段(stage)的值: +TileLang DSL 按 DSL Guide §Value Model 的定义,采用 **symbolic value** 模型——函数体中的 +值并非 Python 运行时的 `int`/`float`,而是编译器构造的 SSA 值或编译期常量。在 vkernel +实例化过程中,`pto.Tile` 参数的属性按两种 stage 区分处理: #### 编译期静态值(Compile-time Static) -以下属性在模板实例化时已经确定,由 Python Codegen 在编译期折叠为字面量,**不会**生成 MLIR SSA 值: +以下属性在 vkernel 实例化时已经确定,由 TileLang Codegen 在编译期折叠为字面量, +**不会**生成 MLIR SSA 值: | 属性 | 来源 | 说明 | |------|------|------| -| `element_type` | `tile_buf` 的 `dtype` 字段 | 决定 vreg 类型和向量宽度 | +| `element_type` | `tile_buf` 的 `dtype` 字段 | 决定 vreg 类型和向量宽度;参与 specialization key | | `element_size` | 由 `dtype` 推导 | f32→4, f16→2, i8→1 | -| `shape` | `tile_buf` 的 `rows`, `cols` 字段 | **必须是编译期静态值**,参与模板实例化的 specialization key | -| `config` | `tile_buf` 的 blayout/slayout/fractal/pad | 布局和配置信息 | +| `shape` | `tile_buf` 的 `rows`, `cols` 字段 | **必须是编译期静态值**,参与 specialization key | +| `memory_space` | `tile_buf` 的 `loc` 字段 | `MemorySpace.GM` / `MemorySpace.UB`;参与 specialization key | +| `config` | `tile_buf` 的 blayout / slayout / fractal / pad | 决定 stride 模式和偏移计算方式 | -这些值在 Python 层直接参与运算(如 `256 / elem_size`),结果在编译期确定。 +这些值在 Python 层直接参与运算(如 `pto.get_lanes(dtype)`、`rows * cols * element_size`), +结果在编译期确定。DSL Guide §Tile Types 明确规定 **Static Shape Requirement**: +`shape` 必须是 compile-time constant。 #### 运行时 SSA 值(Runtime Dynamic) -以下属性可能在编译期未知,生成为 MLIR 函数参数或 SSA 值: +以下属性可能在编译期未知,生成为实例化函数的参数或 SSA 值: | 属性 | 来源 | 说明 | |------|------|------| -| `valid_shape` | `tile_buf` 的 `v_row`, `v_col` 字段 | **可以是静态也可以是动态** | +| `valid_shape` | `tile_buf` 的 `v_row`, `v_col` 字段 | **可以是静态也可以是动态**(DSL Guide §Tile Shape Concepts) | -当 `valid_shape` 为静态值时,Python Codegen 在编译期折叠(与 `shape` 相同处理方式);当为动态值时,生成为 MLIR 函数参数(`index` 类型),循环边界等依赖它的地方生成 `scf.for`。 +当 `valid_shape` 为静态值时,TileLang Codegen 在编译期折叠(与 `shape` 相同处理方式); +当为动态值时,生成为实例化函数的 `index` 类型参数,循环边界等依赖它的地方生成 +`scf.for`。该参数在 PTOAS 侧由 `pto.bind_tile` 的 `valid_row` / `valid_col` +操作数承载(参见第三章)。 #### 正式约束 -1. **`shape` 必须是编译期静态值**,并参与模板实例化的 specialization key。如果 `shape` 为动态值,模板实例化应报错拒绝。 -2. **`valid_shape` 可以是静态也可以是动态**。当为静态值时,Python Codegen 侧应检查 `valid_shape <= shape`(逐维度)。 -3. **`element_type`、`element_size`、`config` 必须是编译期静态值**,它们决定了模板函数体的结构(vreg 类型、向量宽度、stride 模式等)。 +1. **`shape` 必须是编译期静态值**,并参与 specialization key。若 `shape` 为动态值, + vkernel 实例化应报错拒绝。 +2. **`valid_shape` 可以是静态也可以是动态**。当为静态值时,TileLang Codegen 应检查 + `valid_shape[i] ≤ shape[i]`(逐维度),对齐 DSL Guide §Tile Shape Concepts 的约束。 +3. **`element_type`、`element_size`、`memory_space`、`config` 必须是编译期静态值**, + 它们决定了函数体的结构(vreg 类型、向量宽度、stride 模式等)。 #### 对控制流的影响 @@ -196,101 +241,162 @@ for i in range(0, rows, 1): # rows=16 静态 → Python 展开 16 次迭 ### 2.4 TileLang DSL 语法参考 -#### 2.4.1 基础数据类型 +本节摘录本方案所依赖的 DSL 子集;完整定义见 `tilelang-dsl/docs/tilelang-dsl-guide.md`。 + +#### 2.4.1 基础标量类型 | DSL 类型 | 说明 | 位宽 | |----------|------|------| -| `pto.i8` | 8 位整数 | 8 | +| `pto.i1` | 布尔 | 1 | +| `pto.i8` | 8 位整数 | 8 | | `pto.i16` | 16 位整数 | 16 | | `pto.i32` | 32 位整数 | 32 | | `pto.i64` | 64 位整数 | 64 | | `pto.f16` | 半精度浮点 | 16 | -| `pto.bf16` | BFloat16 | 16 | +| `pto.bf16`| Brain float 16 | 16 | | `pto.f32` | 单精度浮点 | 32 | -Python 字面量自动推导类型:`int` → `pto.i32`,`float` → `pto.f32`。 +Python 字面量自动推导类型:`bool` → `pto.i1`,`int` → 上下文决定(通常 `pto.i32`/`pto.i64`), +`float` → `pto.f32`。需要显式类型时可用 `x = pto.i32(1024)` 或类型注解。 + +DSL 还提供类型通配符 `pto.AnyFloat` / `pto.AnyInt` / `pto.AnyType` / `pto.AnyMask` +和类型变量 `pto.TypeVar(...)`,用于在 `dtypes=` 中写多态签名。 -#### 2.4.2 Tile 数据类型 +#### 2.4.2 向量与 Mask 类型 -`pto.Tile` 表示一个带有布局和配置信息的数据块,对应 MLIR 中的 `!pto.tile_buf` 类型。 +向量寄存器固定 **256 字节** 宽度: -**Tile 属性接口:** +```python +pto.vreg(64, pto.f32) # 64 lanes × 32 bit = 2048 bit +pto.vreg(128, pto.f16) # 128 lanes × 16 bit = 2048 bit +``` + +约束:`lanes × bitwidth(element_type) == 2048`。可用 `pto.get_lanes(dtype)` 获得 lane 数。 + +Mask 按位粒度分型(DSL Guide §Typed Masks),必须与 vreg 元素族匹配: + +| DSL 类型 | VPTO 类型 | 对应元素族 | +|----------|-----------|-----------| +| `pto.mask_b8` | `!pto.mask` | `i8` 向量 | +| `pto.mask_b16` | `!pto.mask` | `f16` / `bf16` / `i16` 向量 | +| `pto.mask_b32` | `!pto.mask` | `f32` / `i32` 向量 | + +粒度不匹配(例如 `f32` 向量配 `mask_b16`)会在类型检查阶段报错。 + +#### 2.4.3 Tile 数据类型 + +`pto.Tile` 表示一个带有布局和配置信息的数据块,对应 VPTO IR 中的 `!pto.tile_buf` 类型。 + +**Tile 属性接口**(DSL Guide §Tile Attributes): | 属性 | 类型 | 说明 | |------|------|------| -| `shape` | `tuple[int, ...]` | Tile 的完整维度(rows, cols) | -| `valid_shape` | `tuple[int, ...]` | 有效数据维度(v_row, v_col),可能小于 shape | -| `element_type` | `Type` | 元素数据类型(如 `pto.f32`) | -| `element_size` | `int` | 元素字节大小(如 f32 → 4) | -| `memory_space` | `MemorySpace` | 内存空间(GM, UB) | -| `config` | `TileConfig` | 布局和 padding 配置 | +| `shape` | `tuple[int, ...]` | **编译期静态**的物理维度(rows, cols) | +| `valid_shape` | `tuple[int, ...]` | 有效数据维度(v_row, v_col),可为静态或动态,须 ≤ `shape` | +| `element_type` | `Type` | 元素类型,如 `pto.f32` | +| `element_size` | `int` | 元素字节大小 | +| `memory_space` | `MemorySpace` | `MemorySpace.GM` / `MemorySpace.UB` | +| `config` | `TileConfig` | 布局与 padding 配置 | +| `rank` / `num_elements` / `valid_elements` | `int` | 派生属性 | -**Tile 配置:** +**Tile 配置枚举**: ```python -pto.BLayout.ROW_MAJOR # 行主序 -pto.BLayout.COL_MAJOR # 列主序 -pto.SLayout.NONE_BOX # 无二级布局 -pto.PadValue.NULL # 无 padding -pto.PadValue.ZERO # 零填充 +pto.BLayout.ROW_MAJOR / pto.BLayout.COL_MAJOR # 基础布局 +pto.SLayout.NONE_BOX / pto.SLayout.ROW_MAJOR / pto.SLayout.COL_MAJOR +pto.PadValue.NULL / pto.PadValue.ZERO / pto.PadValue.MAX / pto.PadValue.MIN ``` -#### 2.4.3 向量操作接口 - -向量寄存器固定 256 字节宽度,每次处理的元素数量由数据类型决定:f32 → 64 个元素,f16 → 128 个元素。 +**地址生成语法糖**(DSL Guide §Address Generation Syntax Sugar)——向量级读写使用 +元素索引语法,编译器自动按 layout 计算字节偏移: -**Mask 操作:** - -| 操作 | 说明 | +| 语法 | 含义 | |------|------| -| `pto.make_mask(dtype, remaining)` | 根据数据类型和剩余元素数量生成 mask,返回 `(mask, new_remaining)` | -| `pto.make_mask(dtype, PAT.ALL)` | 生成全 1 mask | +| `tile[row, col:]` | 行主序:从 `(row, col)` 起按向量宽度连续读 | +| `tile[row:, col]` | 列主序:从 `(row, col)` 起按向量宽度连续读 | +| `tile[start:]` | 1D tile:从 `start` 起按向量宽度连续读 | +| `tile[row, col]` | 单元素(仅 `pto.vsld` 等 broadcast load 使用) | -**向量 Load/Store:** +#### 2.4.4 向量操作接口 -| 操作 | 说明 | -|------|------| -| `pto.vlds(tile[i, j])` | 从 Tile 的 `[i, j]` 位置加载一个向量寄存器的数据 | -| `pto.vsts(vec, tile[i, j], mask)` | 将向量寄存器数据写入 Tile 的 `[i, j]` 位置 | +本方案依赖 DSL Guide §Operations 中列在 **`stable`** tier 的 base vector ops: -**二元向量运算:** +**Mask 生成**(DSL Guide §`pto.make_mask`): -| 操作 | 说明 | +| 形式 | 说明 | |------|------| -| `pto.vadd(vec1, vec2, mask)` | 逐元素加法 | -| `pto.vsub(vec1, vec2, mask)` | 逐元素减法 | -| `pto.vmul(vec1, vec2, mask)` | 逐元素乘法 | -| `pto.vdiv(vec1, vec2, mask)` | 逐元素除法 | -| `pto.vmax(vec1, vec2, mask)` | 逐元素取大 | -| `pto.vmin(vec1, vec2, mask)` | 逐元素取小 | +| `pto.make_mask(dtype, remaining: pto.i32)` | Tail processing:返回 `(mask, new_remaining)` | +| `pto.make_mask(dtype, PAT.ALL)` | 固定 pattern:返回单值 `mask`。其它 pattern 包括 `PAT.EVEN`/`PAT.ODD` 等 | -**一元向量运算:** +**向量 Load / Store**: | 操作 | 说明 | |------|------| -| `pto.vabs(vec, mask)` | 逐元素绝对值 | -| `pto.vexp(vec, mask)` | 逐元素指数 | -| `pto.vln(vec, mask)` | 逐元素对数 | -| `pto.vsqrt(vec, mask)` | 逐元素开方 | -| `pto.vrelu(vec, mask)` | 逐元素 ReLU | +| `pto.vlds(tile[row, col:])` | 从 tile 的 `(row, col)` 按向量宽度加载到 vreg | +| `pto.vsts(vec, tile[row, col:], mask)` | 将 vreg 按 mask 写入 tile 的 `(row, col)` | -**向量-标量运算:** +上述两条也支持 DSL Guide 中的 byte-offset 形式 `pto.vlds(buf, offset)` / `pto.vsts(vec, buf, offset, mask)` +(Advanced Tier),但模板库优先使用元素索引语法。 + +**基础二元/一元算子**(用于常见 Tile op 的展开): | 操作 | 说明 | |------|------| -| `pto.vmuls(vec, scalar, mask)` | 向量乘标量 | -| `pto.vadds(vec, scalar, mask)` | 向量加标量 | +| `pto.vadd / vsub / vmul / vdiv(vec1, vec2, mask)` | 逐元素二元运算 | +| `pto.vmax / vmin(vec1, vec2, mask)` | 逐元素比较 | +| `pto.vabs / vexp / vln / vsqrt / vrelu(vec, mask)` | 逐元素一元运算 | +| `pto.vmuls / vadds(vec, scalar, mask)` | 向量-标量运算 | -#### 2.4.4 控制流 +#### 2.4.5 控制流 **循环**使用 Python 的 `range` 语法: ```python -for i in range(0, v_rows, 1): - # 循环体 +for i in range(0, valid_rows, 1): + for j in range(0, valid_cols, pto.get_lanes(dtype)): + ... +``` + +当循环边界来自 `shape`(编译期常量)时,DSL 在 Python 层直接展开循环;当来自 +`valid_shape`(可能是动态值)时,生成 `scf.for` MLIR 循环。 + +**向量作用域**:本方案的 vkernel 统一使用 `advanced=True`,由编译器的 Scope Inference Pass +对连续、数据依赖的 `vlds`/`vadd`/`vsts` 序列自动推断 `pto.vecscope` 边界,库开发者无需 +显式书写 `with pto.vecscope(): ...`。需要精确控制时可使用 `strict_vecscope`(Advanced Tier)。 + +#### 2.4.6 多算子模板(template slots) + +对于计算骨架相同、仅核心算子不同的一组 Tile op(如 `tadd`/`tsub`/`tmul`/`tdiv`), +可用 DSL Guide §Template-based Kernel Authoring 的 `ops=[...]` + `templates=` + `pto.tpl(...)` +在一个 vkernel 中共享实现: + +```python +@pto.vkernel( + target="a5", + ops=["tadd", "tsub", "tmul", "tdiv"], + dtypes=[(T, T, T)], + advanced=True, + templates={ + "core": {"tadd": "vadd", "tsub": "vsub", + "tmul": "vmul", "tdiv": "vdiv"}, + }, +) +def elementwise_arithmetic(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) # 按选中的具体算子替换 + pto.vsts(out, dst[row, col:], mask) ``` -当循环边界来自 `shape`(编译期常量)时,DSL 在 Python 层展开循环;当来自 `valid_shape`(可能是运行时动态值)时,生成 `scf.for` MLIR 循环。 +编译期 `pto.select_kernel(...)` 会把具体的 `tadd`/`tsub`/… 绑定到 `selected_op`, +`pto.tpl("core", ...)` 再按 `templates["core"]` 的映射展开为真正的 `vadd`/`vsub`/… 调用。 +这样本方案的 Tile Lib 可以用一份模板覆盖四条逐元素算子,显著收敛维护成本。 ## 第三章 PTOAS 编译器:TileOp Expand @@ -544,11 +650,44 @@ func.func @TADD(%a: !pto.tile_buf<...>, %b: !pto.tile_buf<...>, %c: !pto.tile_bu #### 3.4.4 经过 Fold TileBuf Intrinsics 后 -Fold pass 将 `pto.tile_buf_addr`、`pto.tile_valid_rows`、`pto.tile_valid_cols` 替换为具体值: +Fold pass 通过严格的模式匹配,将 `pto.tile_buf_addr`、`pto.tile_valid_rows`、`pto.tile_valid_cols` +解析回 `MemrefToTileBuf` pass 在调用点构造的具体 SSA 值。 -- `pto.tile_buf_addr %a` → 折叠为调用点已知的 memref 值(从 tile_buf 提取底层地址) -- `pto.tile_valid_rows %a` → 如果 `v_row=16` 是静态的,折叠为 `arith.constant 16 : index`;如果是动态的(`v_row=?`),折叠为调用点传入的动态 index 值 -- `pto.tile_valid_cols %a` → 同理 +**严格模式匹配**:每一个被折叠的 intrinsic,其 `tile_buf` 操作数必须由如下固定链定义 +(由 `MemrefToTileBuf` pass 保证),否则 pass 直接报错并失败: + +```mlir +%0 = pto.pointer_cast(%addr) {config = ...} + : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> +%1 = pto.bind_tile %0, %v_row, %v_col {config = ...} + : memref<16x64xf32, strided<[64, 1]>, ...> + -> memref<16x64xf32, strided<[64, 1], offset: ?>, ...> +%2 = builtin.unrealized_conversion_cast %1 + : memref<...> to !pto.tile_buf +``` + +也即:`tile_buf ← unrealized_conversion_cast ← pto.bind_tile ← pto.pointer_cast`。 + +**三条折叠规则**(均锚定到 `pto.bind_tile`): + +- `pto.tile_buf_addr %a` → 折叠为 `bind_tile` 的 **第一个操作数**(即 `pto.pointer_cast` 的结果)。 + 注意这里**绕过**了 `bind_tile` 自身产出的、带 `offset: ?` 的动态布局 memref, + 直接复用上游的 `strided<[64, 1]>` 静态布局 memref。这样下游的 `pto.vlds`/`pto.vsts` + 在被规范化、最终下沉到 VPTO 后端时,看到的始终是干净的 `strided<[..], offset: 0>` 布局, + 避免了 `pto.vlds does not support dynamic memref layout offsets` 这类下游错误。 + 若 `tile_buf_addr` 声明的结果类型与 `bind_tile` 源 memref 的实际布局不一致, + 会就地把结果类型替换为源 memref 的真实类型——下游向量算子对相同 element type / shape + 的 strided 布局是多态的。 +- `pto.tile_valid_rows %a` → 优先按 `TileBufType.validShape[0]` 静态折叠: + 若是静态值(如 `v_row=16`),折叠为 `arith.constant 16 : index`; + 若是动态值(`v_row=?`),折叠为 `bind_tile` 的 **第二个操作数**(`valid_row`,已经是 `index` 类型)。 +- `pto.tile_valid_cols %a` → 同理,使用 `validShape[1]` 或 `bind_tile` 的 **第三个操作数**。 + +**跳过 TileLang 模板实例**:被 `PTOInlineLibCall` 内联完且作为 dead callee 删除之前, +带 `pto.tilelang.instance` 属性的私有模板函数仍可能保留在 module 中。这些函数体内的 +`pto.tile_buf_addr` 等 intrinsic 直接作用在 `tile_buf` 类型的 BlockArgument 上, +没有 `bind_tile` 可供折叠——pass 通过检测 `pto.tilelang.instance` 属性跳过这些函数, +留给下游 DCE 清理。 折叠后得到最终的纯向量 IR,不再包含任何 tile_buf 引用: @@ -618,7 +757,7 @@ lib/TileOp/ ← 模板库根目录 advanced=True, # 启用隐式 vecscope 推断 name="template_", ) -def template_xxx(dst: pto.Tile, src0: pto.Tile, ...): +def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): # 向量化实现体 dtype = dst.element_type valid_rows, valid_cols = dst.valid_shape @@ -630,6 +769,11 @@ def template_xxx(dst: pto.Tile, src0: pto.Tile, ...): return None ``` +**关键约束:模板参数顺序必须与 PTOAS 中对应指令的操作数顺序严格一致。** +`ExpandTileOp` 按位置索引将指令操作数直接传递给模板函数参数。对于 DPS 风格的 +算子,这意味着 `ins` 操作数在前、`outs` 在后。例如 `pto.tadd ins(%a, %b) outs(%c)` +的操作数顺序为 `(src0, src1, dst)`,模板参数必须为 `(src0, src1, dst)`。 + `expand_helper.py` 自动扫描目录下所有 `.py` 文件,按 `op` 名称和 `dtype` 签名匹配模板。 @@ -653,9 +797,130 @@ def template_xxx(dst: pto.Tile, src0: pto.Tile, ...): | MLIR 解析与 inline | 解析生成的 MLIR 文本,inline 到调用点,绑定参数 | | Cleanup | 实例化后运行 canonicalize 清理冗余 | -### 4.3 测试与文档 +### 4.3 PTOAS 编译器:Fold TileBuf Intrinsics Pass + +| 工作项 | 说明 | +|--------|------| +| 严格模式匹配 | 要求 `tile_buf` 由 `unrealized_conversion_cast ← pto.bind_tile` 链定义,否则 emit error 并 fail pass | +| `tile_buf_addr` 折叠 | 替换为 `bind_tile.getSource()`(即 `pto.pointer_cast` 的静态布局 memref),绕过 `bind_tile` 产出的动态 offset 布局,避免下游 VPTO 后端无法处理 `offset: ?` | +| 结果类型自适应 | 若 `tile_buf_addr` 声明类型与 source memref 实际布局不一致,就地更新结果类型(下游向量算子对 strided 布局多态) | +| `tile_valid_rows/cols` 折叠 | 优先按 `TileBufType.validShape` 静态折叠为 `arith.constant`;动态时取 `bind_tile` 的 `valid_row`/`valid_col` 操作数(均为 `index` 类型,无需 cast) | +| 跳过模板实例 | 检测 `pto.tilelang.instance` 属性,跳过 `PTOInlineLibCall` 删除前残留的私有模板函数 | + +### 4.4 测试与文档 - Python DSL 模板编写和实例化的单元测试 + 以当前 `lib/TileOps/tadd_template.py` 为例,新增/维护 + `test/basic/expand_tile_op_tilelang.pto` + 作为 `pto.tadd` TileLang 模板实例化的基础回归。该用例覆盖: + 1. `ExpandTileOp` 是否能匹配 `pto.tadd` 并调用 Python DSL helper; + 2. 模板实例化后的 `func.call` 是否能被 inline; + 3. `FoldTileBufIntrinsics` 之后是否得到 `pto.vlds` / `pto.vadd` / `pto.vsts` 形式的 Vector IR。 + + 当前 `pto.tadd` 的向量库模板实现如下: + + ```python + import sys + from pathlib import Path + import tilelang_dsl as pto + + + @pto.vkernel( + target="a5", + op="pto.tadd" + ) + def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return + ``` + + 对应的单元测试用例如下: + + ```mlir + // Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline + // expands pto.tadd via the default TileLang Python DSL template + // lib/TileOps/tadd_template.py. + // + // Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics + // + // RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + + // After the full tile-op-expand path on the VPTO backend, the original + // pto.tadd should be lowered to vector-style VPTO IR. + // CHECK: func.func @TADD + // CHECK-NOT: pto.tadd ins + // CHECK: pto.vecscope + // CHECK: pto.castptr + // CHECK: pto.vlds + // CHECK: pto.vadd + // CHECK: pto.vsts + + module { + func.func @TADD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } + } + ``` - Expand TileOp pass 的端到端测试(`pto.tadd` → Vector IR) + 使用以下命令同时观察中间 IR 和最终 LLVM IR: + + ```bash + ./build/tools/ptoas/ptoas test/basic/expand_tile_op_tilelang.pto \ + --pto-arch=a5 \ + --print-ir-after-all \ + --pto-backend=vpto \ + --enable-tile-op-expand \ + --vpto-emit-hivm-llvm \ + -o - \ + > add.ll \ + 2> /tmp/expand_tile_op_tilelang.mlir + ``` + + 说明: + - `stderr` 中的 `/tmp/expand_tile_op_tilelang.mlir` 保存 `--print-ir-after-all` 打印的各阶段 MLIR/VPTO IR,可用于检查模板是否已经从 `pto.tadd` 展开为向量 IR。 + - `stdout` 中的最终产物是 textual LLVM IR,因此这里使用 `-o - > add.ll` 显式落盘,而不是依赖 `-o ` 与 `--print-ir-after-all` 混用时的输出行为。 + + 随后将生成的 `add.ll` 交给 Bisheng: + + ```bash + bisheng \ + --target=hiipu64-hisilicon-cce \ + -march=dav-c310-vec \ + --cce-aicore-arch=dav-c310-vec \ + --cce-aicore-only \ + -c -x ir add.ll \ + -o add.o + ``` + + 若上述命令成功生成 `add.o`,则说明当前 `pto.tadd` 的向量库模板已经完成: + - TileLang 模板实例化; + - `pto.tadd -> Vector IR -> LLVM IR` 的端到端 lowering; + - Bisheng 设备侧编译校验。 - 融合场景测试(多个 Tile op 连续使用后的 VF Fusion) - 更新 `PTO_IR_manual.md` 和 TileLang DSL Guide diff --git a/docs/tilelang-dsl-guide.md b/docs/tilelang-dsl-guide.md deleted file mode 100644 index b6dfabf3f..000000000 --- a/docs/tilelang-dsl-guide.md +++ /dev/null @@ -1,2978 +0,0 @@ -# TileLang Python DSL Guide - -The TileLang Python DSL provides a high-level, Pythonic interface for authoring vector compute kernels targeting the Ascend NPU hardware. This guide is intended for library developers and performance engineers who need to write efficient, hardware-aware kernels using the PTO micro instruction set. - -The DSL is designed to generate MLIR function libraries rather than direct binary executables. These MLIR libraries are intended to be consumed by other compilation frameworks that transform high-level tile semantics into low-level vector operations. This enables library developers to focus on hardware-aware kernel authoring while relying on upstream compilers for tile-level optimizations and code generation. - -## Current Implementation Status - -The current `tilelang_dsl` package in this repository implements: -- matcher support: - `KernelRegistry`, `pto.select_kernel(...)`, multi-signature `dtypes`, - `AnyFloat` / `AnyInt` / `AnyType` / `AnyMask`, `TypeVar`, `constraints`, - `priority` -- advanced authoring support: - implicit vecscope inference in `advanced=True` kernels -- raw pointer / low-level DMA support: - `ptr(...)`, `castptr`, `addptr`, low-level DMA config ops, - `copy_gm_to_ubuf`, `copy_ubuf_to_gm`, `copy_ubuf_to_ubuf` -- advanced vector-family lowering: - compare/select, predicate movement, carry, rearrangement - -Still deferred in the current package head: -- reduction-family authoring - -Reason: -- the repo does not yet expose a public authoring-form VPTO reduction op that - the standalone TileLang DSL can target directly - -For the package-local source of truth, see: -- `tilelang-dsl/docs/v1-surface.md` -- `tilelang-dsl/docs/v1-lowering.md` -- `tilelang-dsl/docs/matcher-and-advanced-surface-migration.md` - -## Quick Start - -**Note on mask pattern enums**: For brevity, examples in this guide use `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). You can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. - -Here's a minimal example of a tile scaling kernel using the new Tile type: - -```python -import pto - -@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) -def tile_scale(input_tensor: pto.TensorView, # Input tensor view (shape: 256x128, f32, GM) - output_tensor: pto.TensorView, # Output tensor view (same shape and type) - scale_factor: pto.f32): # Scaling factor - # Access tensor properties - rows, cols = input_tensor.shape # (256, 128) - dtype = input_tensor.element_type # pto.f32 - - # Create a temporary tile in UB for computation - ub_tile = pto.tile((rows, cols), dtype, pto.MemorySpace.UB) - - # Load input tensor from GM to UB using high-level DMA operation - pto.dma_load(input_tensor, ub_tile) - - # Vector computation: scale all elements in the tile - all_mask = pto.make_mask(dtype, PAT.ALL) - - # Process tile in row-major order - for row in range(0, rows): - # Process each row in vector chunks - # Vector width is hardware-defined: 256 bytes / element size - # For f32: 256/4 = 64 lanes, for f16: 256/2 = 128 lanes - vector_lanes = pto.get_lanes(dtype) # Compute vector lanes based on element type (e.g., 64 for f32, 128 for f16) - for col_start in range(0, cols, vector_lanes): - # Load vector using element-indexing syntax (no manual byte calculation) - vec = pto.vlds(ub_tile[row, col_start:]) - - # Scale vector - scaled = pto.vmuls(vec, scale_factor, all_mask) - - # Store result back using element-indexing syntax - pto.vsts(scaled, ub_tile[row, col_start:], all_mask) - - # Store result from UB back to GM output tensor using high-level DMA operation - pto.dma_store(ub_tile, output_tensor) -``` - -This example demonstrates: -1. **TensorView parameters** in kernel declaration -2. **TensorView property access** (shape, element_type) -3. **Tile creation** for temporary buffers -4. **High-level DMA operations** (`dma_load`/`dma_store`) for data movement -5. **Implicit tile→UBRef conversion** in vector load/store operations -6. **Automatic DMA parameter inference** from tensor slices and tile properties - -For an even more concise example showing pure computation on UB tiles (assuming data is already in UB): - -```python -@pto.vkernel(target="a5", op="elementwise", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) -def ub_tile_computation(a: pto.Tile, # UB tile - b: pto.Tile, # UB tile - c: pto.Tile): # UB tile (output) - dtype = a.element_type - - # All tiles are in UB memory space - all_mask = pto.make_mask(dtype, PAT.ALL) - rows, cols = a.shape - - # Element-wise: c = a + b * 2.0 - for i in range(0, rows * cols, 64): - # Load vectors from UB tiles using element-indexing syntax - vec_a = pto.vlds(a[i:]) # Implicit tile→UBRef with automatic offset calculation - vec_b = pto.vlds(b[i:]) - - # Compute: b * 2.0 - scaled_b = pto.vmuls(vec_b, 2.0, all_mask) - - # Compute: a + scaled_b - result = pto.vadd(vec_a, scaled_b, all_mask) - - # Store result to output tile using element-indexing syntax - pto.vsts(result, c[i:], all_mask) -``` - -## Core Concepts - -### Kernel Declaration - -Kernels are defined using the `@pto.vkernel` decorator with enhanced matching capabilities for PTO operations. The decorator specifies matching criteria for target architecture, operation type, data types, and additional constraints, along with a priority for disambiguation when multiple kernels match. - -#### Basic Syntax - -```python -@pto.vkernel( - target="a5", # Target architecture - op="matmul", # PTO operation name to match - dtypes=[(pto.f16, pto.f16, pto.f32)], # Type signatures - constraints=[ # Additional constraints - AnyOf(k_dim_aligned_64, continuous_memory), - Not(requires_ub_memory) - ], - priority=100 # Priority for selection -) -def matmul_fallback(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # kernel implementation -``` - -#### Decorator Parameters - -| Parameter | Type | Required | Description | -|-----------|------|----------|-------------| -| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | -| `op` | `str` | Yes | Name of the PTO operation to match (e.g., `"matmul"`, `"conv2d"`, `"add"`). | -| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands (inputs and outputs) in order. | -| `constraints` | `List[Constraint]` | No | Additional constraints that must be satisfied for the kernel to be selected. Can include logical combinations (`AnyOf`, `AllOf`, `Not`). Default: empty list. | -| `priority` | `int` | No | Selection priority when multiple kernels match. Higher values have higher priority. Default: `0`. | -| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | - -#### Type Matching Rules - -The `dtypes` parameter supports flexible type matching: - -1. **Concrete Types**: Exact type matches using DSL scalar types: - - `pto.f16`, `pto.f32`, `pto.bf16` - - `pto.i8`, `pto.i16`, `pto.i32`, `pto.i64` - - `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` - -2. **Type Wildcards**: Generic type patterns: - - `pto.AnyFloat`: Matches any floating-point type (`f16`, `bf16`, `f32`) - - `pto.AnyInt`: Matches any integer type (`i8`, `i16`, `i32`, `i64`) - - `pto.AnyType`: Matches any scalar type - - `pto.AnyMask`: Matches any mask type (`mask_b8`, `mask_b16`, `mask_b32`) - -3. **Type Variables**: Named type variables that enforce consistency within a signature: - ```python - T = pto.TypeVar('T') # Define a type variable - - @pto.vkernel( - target="a5", - op="elementwise", - dtypes=[(T, T, T)], # All three operands must have the same type - constraints=[] - ) - def elementwise_same_type(x: pto.Tile, y: pto.Tile, out: pto.Tile) -> None: - # x, y, and out must have identical element types - pass - ``` - -4. **Mixed Signatures**: Multiple type signatures for the same operation: - ```python - @pto.vkernel( - target="a5", - op="add", - dtypes=[ - (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), # Float addition - (pto.AnyInt, pto.AnyInt, pto.AnyInt) # Integer addition - ] - ) - def generic_add(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # Supports both float and integer types - pass - ``` - -#### Constraint System - -Constraints are compile-time predicates that refine kernel selection. The system supports logical combinations of constraints. - -##### Predefined Constraints - -| Constraint | Description | -|------------|-------------| -| `k_dim_aligned_64` | K dimension is aligned to 64 elements (for matmul kernels). | -| `continuous_memory` | Operands reside in contiguous memory regions. | -| `requires_ub_memory` | Operation requires Unified Buffer memory (vs. Global Memory). | -| `tensor_rank(rank)` | Operand tensor has specified rank (e.g., `tensor_rank(2)` for 2D tensors). | -| `broadcastable` | Operands are broadcastable according to NumPy-style broadcasting rules. | -| `static_shape` | All tensor dimensions are known at compile time (no dynamic shapes). | - -##### Logical Constraint Combinators - -| Combinator | Description | Example | -|------------|-------------|---------| -| `AnyOf(c1, c2, ...)` | At least one of the constraints must be satisfied. | `AnyOf(k_dim_aligned_64, continuous_memory)` | -| `AllOf(c1, c2, ...)` | All constraints must be satisfied. | `AllOf(tensor_rank(2), static_shape)` | -| `Not(c)` | The constraint must not be satisfied. | `Not(requires_ub_memory)` | - -##### Custom Constraints - -Users can define custom constraints using predicate functions: - -```python -# Define a custom constraint -def large_batch(batch_size: pto.i32) -> pto.Constraint: - """Batch size must be ≥ 1024.""" - return pto.Constraint(lambda op: op.batch_size >= batch_size) - -@pto.vkernel( - target="a5", - op="matmul", - dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], - constraints=[large_batch(1024)] -) -def large_batch_matmul(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # Optimized for large batch sizes - pass -``` - -#### Kernel Selection Mechanism - -When a PTO operation needs implementation, the system performs the following matching process: - -1. **Target Filtering**: Select kernels with matching `target` architecture. -2. **Operation Filtering**: Select kernels with matching `op` name. -3. **Type Matching**: For each kernel's `dtypes` list, check if any signature matches the operation's operand types: - - Concrete types must match exactly. - - Wildcard types match according to their category. - - Type variables must be consistent within the signature. -4. **Constraint Validation**: For each matching kernel, evaluate all `constraints`. If any constraint fails, the kernel is rejected. -5. **Priority Selection**: From the remaining kernels, select the one with the highest `priority` value. -6. **Fallback**: If no kernel matches, compilation fails with an error. - -The package also exposes explicit selection utilities: - -```python -registry = pto.KernelRegistry() -registry.register(my_kernel) - -selected = pto.select_kernel( - "a5", - "matmul", - (pto.f16, pto.f16, pto.f32), - context_attrs={"k_aligned": True}, - registry=registry, -) -``` - -#### Examples - -##### Matmul with Multiple Implementations - -```python -# High-performance kernel for aligned K dimension -@pto.vkernel( - target="a5", - op="matmul", - dtypes=[(pto.f16, pto.f16, pto.f32)], - constraints=[k_dim_aligned_64], - priority=200 -) -def matmul_aligned_k(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # Optimized implementation for aligned K - pass - -# General-purpose fallback -@pto.vkernel( - target="a5", - op="matmul", - dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], - constraints=[], - priority=100 -) -def matmul_general(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: - # Generic implementation - pass -``` - -##### Elementwise Operation with Type Polymorphism - -```python -@pto.vkernel( - target="a5", - op="add", - dtypes=[ - (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), - (pto.AnyInt, pto.AnyInt, pto.AnyInt) - ], - constraints=[broadcastable] -) -def polymorphic_add(a: pto.Tile, b: pto.Tile, out: pto.Tile) -> None: - # Single implementation handles both float and integer types - dtype = a.element_type - all_mask = pto.make_mask(dtype, PAT.ALL) - # ... implementation using generic vector operations - pass -``` - -##### Constrained Convolution Kernel - -```python -@pto.vkernel( - target="a5", - op="conv2d", - dtypes=[(pto.f16, pto.f16, pto.f32)], - constraints=[ - AllOf( - tensor_rank(4), # NHWC format - static_shape, # No dynamic dimensions - Not(requires_ub_memory) # GM memory preferred - ) - ], - priority=150 -) -def conv2d_nhwc_f16_f32(input: pto.Tile, filter: pto.Tile, output: pto.Tile) -> None: - # Optimized for NHWC layout with static shapes - pass -``` - -### Value Model - -The DSL operates on symbolic values, not Python runtime values: -- **Constants**: Python literals that are typed to machine types -- **Operation results**: Values produced by DSL operations -- **Block arguments**: Values introduced by control flow structures - -### Memory Spaces - -The DSL supports different memory spaces: -- `MemorySpace.GM`: Global Memory -- `MemorySpace.UB`: Unified Buffer (local storage for vector computation) - -## Type System - -### Scalar Types - -| DSL Type | Description | Bit Width | -|----------|-------------|-----------| -| `pto.i1` | Boolean | 1 | -| `pto.i8` | 8-bit integer | 8 | -| `pto.i16` | 16-bit integer | 16 | -| `pto.i32` | 32-bit integer | 32 | -| `pto.i64` | 64-bit integer | 64 | -| `pto.f16` | Half precision float | 16 | -| `pto.bf16` | Brain float 16 | 16 | -| `pto.f32` | Single precision float | 32 | - -Python literals are automatically typed: -- `bool` → `pto.i1` -- `int` → Context-dependent (typically `pto.i32` or `pto.i64`) -- `float` → `pto.f32` - -For explicit typing, use type constructors: -```python -x = pto.i32(1024) # Explicit i32 constant -y: pto.i32 = 1024 # Type annotation -``` - -### Vector Types - -Vector registers have fixed 256-byte width: - -```python -v64_f32 = pto.vreg(64, pto.f32) # 64 lanes of f32 (64 * 32b = 2048b) -v128_f16 = pto.vreg(128, pto.f16) # 128 lanes of f16 (128 * 16b = 2048b) -``` - -Constraint: `lanes × bitwidth(element_type) = 2048` - -### Typed Masks - -Masks are typed by their bit granularity: - -| DSL Type | VPTO Type | Description | -|----------|-----------|-------------| -| `pto.mask_b8` | `!pto.mask` | 8-bit granularity mask | -| `pto.mask_b16` | `!pto.mask` | 16-bit granularity mask | -| `pto.mask_b32` | `!pto.mask` | 32-bit granularity mask | - -Mask operations must match the vector element family: -- `f32` vectors use `mask_b32` -- `f16` vectors use `mask_b16` -- `i8` vectors use `mask_b8` - -```python -# Correct: f32 vector with b32 mask -mask32 = pto.make_mask(pto.f32, PAT.ALL) -vec_f32 = pto.vlds(ptr, offset) -out = pto.vabs(vec_f32, mask32) - -# Error: mismatched mask granularity -mask16 = pto.make_mask(pto.f16, PAT.ALL) -out = pto.vabs(vec_f32, mask16) # Type error! -``` - -### Pointer Types - -Pointers combine element type and memory space: - -```python -from pto import MemorySpace - -ptr_gm = pto.ptr(pto.f32, MemorySpace.GM) # GM pointer to f32 -ptr_ub = pto.ptr(pto.f16, MemorySpace.UB) # UB pointer to f16 -``` - -The `MemorySpace` enum provides type-safe memory space specification: - -| Enum Value | Description | -|------------|-------------| -| `MemorySpace.GM` | Global Memory (off-chip HBM/DDR) | -| `MemorySpace.UB` | Unified Buffer (on-chip SRAM, 256KB) | - -This replaces string literals (`MemorySpace.GM`/`MemorySpace.UB`) with compile-time checked enums. - -### Pointer Type Aliases - -For clarity in API documentation, the following type aliases are used: - -| Alias | Equivalent Type | Description | -|-------|----------------|-------------| -| `GMPtr` | `ptr(..., MemorySpace.GM)` | Pointer to Global Memory | -| `UBPtr` | `ptr(..., MemorySpace.UB)` | Pointer to Unified Buffer | -| `UBRef` | `Union[MemRefType, UBPtr]` | UB buffer or pointer (accepted by load/store ops) | -| `Tile` | `pto.tile(...)` | Tile buffer with layout and configuration | - -### MemRef Types - -For buffer-like authoring, use memref types: - -```python -buf1d = pto.memref(256, pto.f32, MemorySpace.UB) # 1D: 256-element f32 buffer in UB -buf2d = pto.memref((256, 128), pto.f32, MemorySpace.UB) # 2D: 256x128 f32 buffer in UB -``` - -- **1D shapes**: Use a scalar integer (e.g., `256`) -- **Multi-dimensional shapes**: Use a tuple (e.g., `(256, 128)`) - -MemRefs are used for stateless load/store operations that accept `buf_like` operands in VPTO. - - -### TensorView Types - -TensorView types represent views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. - -### TensorView Type Definition - -TensorView types are parameterized by shape and element type: - -```python -# Kernel parameter using TensorView -@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) -def tiled_kernel( - input_tensor: pto.TensorView, # GM tensor view - output_tensor: pto.TensorView, # GM tensor view - tile_buf: pto.Tile # UB tile -): - # Access tensor view properties - rows, cols = input_tensor.shape # (dynamic or static) - dtype = input_tensor.element_type # e.g., pto.f32 - strides = input_tensor.strides # stride in elements -``` - -**Important Notes:** -- TensorView is a **read-only descriptor** for GM data (though DMA store operations can write to it) -- Shape can be **static** (compile-time constants) or **dynamic** (determined at runtime) -- Strides are expressed in elements, not bytes -- Memory space is always GM (Global Memory) - -### TensorView Attributes - -| Attribute | Type | Description | -|-----------|------|-------------| -| `shape` | `tuple[int, ...]` | Tensor dimensions (2D only in current profile) | -| `element_type` | `Type` | Element data type (e.g., `pto.f32`, `pto.f16`) | -| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | -| `offset` | `pto.i64` | Byte offset from base pointer (internal) | - -### Padding Mode Enum - -Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: - -| Enum Value | Description | -|------------|-------------| -| `PadMode.PadNull` | No padding (out-of-bounds access is invalid) | -| `PadMode.PadFirstElem` | Pad using the first element of the source | -| `PadMode.PadValue` | Pad using a specified value (requires `pad_value` parameter) | - -**Usage:** -```python -from pto import PadMode - -# Load with zero padding -pto.dma_load(src_partition, dst_tile, - pad_mode=PadMode.PadValue, - pad_value=pto.f32(0.0)) - -# Load with first-element padding -pto.dma_load(src_partition, dst_tile, pad_mode=PadMode.PadFirstElem) - -# Load without padding (default) -pto.dma_load(src_partition, dst_tile) # pad_mode=PadMode.PadNull -``` - -### Slicing Syntax - -TensorView supports Python slicing syntax to create logical partitions: - -```python -# Create a partition from a tensor view -partition = tensor_view[row_start:row_end, col_start:col_end] - -# Example: extract a 16x16 tile from a larger tensor -tile_view = large_tensor[0:16, 0:16] - -# Dynamic offsets and sizes -start_row = pto.i32(0) -start_col = pto.i32(0) -dynamic_partition = tensor_view[start_row:start_row+16, start_col:start_col+16] -``` - -**Constraints:** -- Slicing returns a new TensorView representing the logical partition -- The partition must be within the original tensor bounds -- Slices can be static (constant bounds) or dynamic (runtime values) - -### Alignment Type - -The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. - -### Tile Types - -Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. - -#### Tile Type Definition - -```python -# Create a tile with shape, element type, and memory space -tile = pto.tile((256, 128), pto.f32, MemorySpace.UB) - -# With explicit configuration -config = pto.tile_config( - b_layout=pto.BLayout.ROW_MAJOR, - s_layout=pto.SLayout.NONE_BOX, - s_fractal_size=pto.i32(16), - pad_value=pto.PadValue.ZERO -) -tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, config=config) - -# With valid shape (actual data dimensions within tile) -tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, valid_shape=(240, 120)) -``` - -**Important Notes on Shape and Valid Shape:** -- **Static Shape Requirement**: The `shape` parameter must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. -- **Valid Shape Constraints**: The `valid_shape` parameter can be either static (compile-time constant) or dynamic (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. This allows for variable-sized data within a fixed tile allocation. -- **Default Behavior**: When `valid_shape` is not specified, it defaults to the full `shape`. - -#### Tile Attributes - -| Attribute | Type | Description | -|-----------|------|-------------| -| `shape` | `tuple[int, ...]` | **Static** full tile dimensions (compile-time constant) | -| `element_type` | `Type` | Element data type (e.g., `pto.f32`) | -| `memory_space` | `MemorySpace` | Memory space (GM, UB, etc.) | -| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within tile (can be static/compile-time or dynamic/runtime). Must be ≤ shape in each dimension. | -| `config` | `TileConfig` | Layout and padding configuration | - -#### Tile Configuration - -The tile configuration includes layout and padding information: - -```python -# Layout enums -pto.BLayout.ROW_MAJOR # 0: row-major base layout -pto.BLayout.COL_MAJOR # 1: column-major base layout - -pto.SLayout.NONE_BOX # 0: no secondary layout -pto.SLayout.ROW_MAJOR # 1: row-major secondary layout -pto.SLayout.COL_MAJOR # 2: column-major secondary layout - -pto.PadValue.NULL # 0: no padding -pto.PadValue.ZERO # 1: zero padding -pto.PadValue.MAX # 2: maximum value padding -pto.PadValue.MIN # 3: minimum value padding -``` - -#### Tile Shape Concepts - -- **Static Physical Shape**: The `shape` parameter represents the **static physical dimensions** of the tile allocated in memory. This must be a **compile-time constant** because tile memory allocation is fixed during compilation. The shape determines the total memory footprint and cannot change at runtime. - -- **Valid Shape**: The `valid_shape` parameter represents the logical dimensions of actual data within the tile. It can be either **static** (compile-time constant) or **dynamic** (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. When `valid_shape` is not specified, it defaults to the full `shape`. - -- **Key Distinction**: - - `shape`: **Static, compile-time** - Fixed tile allocation - - `valid_shape`: **Static or Dynamic** - Actual data region (must be ≤ shape) - -- **Constraints**: - - `valid_shape[i] ≤ shape[i]` for each dimension i - - `shape` must be compile-time constants - - `valid_shape` can be compile-time constants or runtime values - -- **Use Cases**: - - Fixed-size tile buffers with variable data (e.g., batch processing with different input sizes) - - Padding scenarios where physical allocation is larger than actual data - - Partial tile utilization in tiled algorithms - -- **Fractal Layout**: The `s_fractal_size` in tile configuration specifies the size of fractal blocks for secondary layout. This is used for optimized memory access patterns in matrix operations. - -- **Padding Behavior**: The `pad_value` determines how out-of-bounds accesses are handled when reading beyond `valid_shape` but within `shape`. Padding values are used for accesses in the padded region (between valid_shape and shape). - -> **⚠️ Important: Shape Constraints** -> -> The tile `shape` must be **compile-time constants**. `valid_shape` can be compile-time constants or determined at runtime, but must satisfy `valid_shape[i] ≤ shape[i]` for all dimensions i. - -### Tile Operations - -#### Basic Access Operations - -```python -# Get tile properties -shape = tile.shape # (256, 128) -elem_type = tile.element_type # pto.f32 -mem_space = tile.memory_space # MemorySpace.UB -valid_shape = tile.valid_shape # (240, 120) or same as shape - -# Get configuration properties -config = tile.config -b_layout = config.b_layout # pto.BLayout.ROW_MAJOR -s_layout = config.s_layout # pto.SLayout.NONE_BOX -s_fractal = config.s_fractal_size # pto.i32(16) -pad = config.pad_value # pto.PadValue.ZERO - -# Dynamic properties -rank = tile.rank # 2 -num_elements = tile.num_elements # 32768 (256 * 128) -valid_elements = tile.valid_elements # 28800 (240 * 120) -``` - -#### Layout and Stride Queries - -```python -# Get layout descriptors -layout_desc = tile.layout_descriptor # Returns layout description object - -# Get strides (in elements) -strides = tile.strides # (128, 1) for row-major 256x128 - -# Get byte strides -byte_strides = tile.byte_strides # (512, 4) for f32 row-major - -# Get base offset (in bytes) -offset = tile.offset # pto.i64(0) or specified offset -``` - -#### Conversion Operations - -Tiles support both explicit and implicit conversion to UBRef. When a tile is used in operations expecting a UBRef (e.g., `pto.vlds`, `pto.vsts`), it is automatically converted. - -```python -# Convert to UBRef (implicit in vector operations) -ub_ref = tile.to_ubref() # Explicit conversion -# or use tile as UBRef directly in vector ops -vec = pto.vlds(tile, offset) # Implicit conversion - -# Convert to typed pointer -ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) - -# Convert to MemRef (for compatibility) -memref = tile.to_memref() # Returns pto.memref((256, 128), pto.f32, MemorySpace.UB) - -# Extract slice of tile -slice_tile = tile.slice((0, 0), (64, 128)) # 64x128 slice from top-left corner - -# Reshape tile (logical reshape, no data movement) -reshaped = tile.reshape((32768,)) # 1D reshape of 256x128 tile -``` - -#### Kernel Parameter Usage - -```python -@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) -def tiled_kernel( - input_tile: pto.Tile, # Tile parameter - output_tile: pto.Tile, # Another tile parameter - scale: pto.f32 -): - # Convert tiles to UBRef for vector operations - ub_in = input_tile.to_ubref() - ub_out = output_tile.to_ubref() - - # Or use tiles directly (implicit conversion) - all_mask = pto.make_mask(pto.f32, PAT.ALL) - for i in range(0, 256, 64): - # tile implicitly converts to UBRef in vlds with element-indexing syntax - vec = pto.vlds(input_tile[i, 0:]) # Load from row i, columns 0 to vector_lanes-1 - scaled = pto.vmuls(vec, scale, all_mask) - pto.vsts(scaled, output_tile[i, 0:], all_mask) # Store to same position -``` - -#### Tile Creation from Existing Buffers - -```python -# Create tile from existing pointer with shape -ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) -tile = pto.tile_from_ptr(ptr, (256, 128), pto.f32) - -# Create tile from memref -memref = pto.memref((256, 128), pto.f32, MemorySpace.UB) -tile = pto.tile_from_memref(memref) - -# Create tile with explicit stride -tile = pto.tile_with_strides((256, 128), pto.f32, MemorySpace.UB, - strides=(256, 1)) # Column-major strides -``` - -## Control Flow - -### Vector Scopes - -The TileLang DSL supports implicit vector scope inference, allowing developers to write vector operations directly without explicit `pto.vecscope()` blocks. The compiler automatically groups consecutive, data-dependent vector operations into implicit vector scopes during lowering. - -#### Implicit Scope Inference - -**Note:** The explicit `pto.vecscope()` construct is deprecated. Vector operations are automatically grouped into implicit scopes by the compiler's Scope Inference Pass. - -When you write vector operations like `pto.vlds`, `pto.vadd`, `pto.vsts` directly in your code, the compiler's **Scope Inference Pass** analyzes the control flow graph and automatically creates vector scopes: - -```python -# No explicit vecscope needed - compiler infers scope boundaries -vec = pto.vlds(outer_ptr, offset) -result = pto.vadd(vec, vec, all_mask) -pto.vsts(result, dst_ptr, offset, all_mask) -``` - -The compiler automatically groups these three operations into a single implicit vector scope because they form a data-dependent chain. - -**Scope boundary rules:** -1. **Control flow boundaries**: Branches (`if`/`else`), loops (`for`/`while`), and function calls create implicit scope boundaries -2. **Scalar operations**: Non-vector operations (e.g., scalar arithmetic, pointer arithmetic) create boundaries -3. **Explicit strict_vecscope**: User-defined `strict_vecscope` blocks create hard boundaries - -#### Explicit Scope Boundaries with `strict_vecscope` - -For precise control over scope boundaries, use explicit `strict_vecscope` blocks. These create hard boundaries that prevent the compiler from merging operations across the block boundary: - -```python -with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): - # Operations inside this block are isolated from outside - # Compiler will not merge operations across this boundary - for i in range(lb, ub, 64): - vec = pto.vlds(s, i) - pto.vsts(vec, d, i, all_mask) -``` - -**Use cases for strict_vecscope:** -- Performance optimization: Isolate critical vector computation regions -- Debugging: Create explicit boundaries to isolate vector operations -- Resource management: Control vector register allocation boundaries -- Compatibility: Ensure deterministic scope placement for hardware constraints - -### Loops - -Counted loops use Python's `range` syntax: - -```python -for i in range(lb, ub, step): - # Loop body - mask, rem = pto.make_mask(pto.f32, remaining) - # ... -``` - -Loop-carried state is automatically handled through variable updates within the loop. - -### Conditionals - -`if` statements support value merging: - -```python -flag: pto.i1 = some_condition -step: pto.i32 = 0 - -if flag: - step = pto.i32(64) -else: - step = pto.i32(128) - -# 'step' here is the merged result from both branches -``` - -Variables defined in only one branch are local to that branch. - -## Operations - -The DSL provides operations grouped by functionality. All operations use the `pto.` prefix. Operations are organized by functional families following the VPTO instruction set architecture. - -### Pointer Construction - -Operations for creating and manipulating typed pointers. - -#### `pto.castptr(offset: pto.i64, ptr_type: Type) -> PtrType` - -**Description**: Creates a pointer with the specified offset and type. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `offset` | `pto.i64` | Byte offset from base address | -| `ptr_type` | `Type` | Target pointer type (e.g., `pto.ptr(pto.f32, MemorySpace.GM)`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `ptr` | `PtrType` | Typed pointer value | - -**Example**: -```python -ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) -``` - -#### `pto.addptr(ptr: PtrType, offset: pto.i64) -> PtrType` - -**Description**: Adds an offset to an existing pointer. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `ptr` | `PtrType` | Source pointer | -| `offset` | `pto.i64` | Byte offset to add | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `new_ptr` | `PtrType` | Pointer with offset applied | - -**Example**: -```python -next_ptr = pto.addptr(ub_ptr, 4096) -``` - -### Synchronization & Buffer Control - -Operations for pipeline synchronization and buffer management. - -#### `pto.set_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` - -**Description**: Sets a synchronization flag between hardware pipelines. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | -| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | -| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import PIPE, EVENT - -pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) -``` - -#### `pto.wait_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` - -**Description**: Waits for a synchronization flag between hardware pipelines. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | -| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | -| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import PIPE, EVENT - -pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) -``` - -#### `pto.pipe_barrier(pipes: PIPE) -> None` - -**Description**: Executes a barrier across specified pipelines. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pipes` | `PIPE` | Pipeline specification (e.g., `PIPE.ALL`) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import PIPE - -pto.pipe_barrier(PIPE.ALL) -``` - -#### `pto.get_buf(op_type: SyncOpType, buf_id: pto.i32, mode: pto.i32 = 0) -> None` - -**Description**: Acquires a buffer for producer-consumer synchronization. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `op_type` | `SyncOpType` | Operation type (e.g., `SyncOpType.TLOAD`) | -| `buf_id` | `pto.i32` | Buffer identifier | -| `mode` | `pto.i32` | Acquisition mode (default: 0) | - -**Returns**: None (side-effect operation) - -**Example**: -```python -from pto import SyncOpType - -# Acquire buffer for DMA load operation -pto.get_buf(SyncOpType.TLOAD, 0) -``` - -#### `pto.rls_buf(op_type: SyncOpType, buf_id: pto.i32, mode: pto.i32 = 0) -> None` - -**Description**: Releases a previously acquired buffer. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `op_type` | `SyncOpType` | Operation type (e.g., `SyncOpType.TLOAD`) | -| `buf_id` | `pto.i32` | Buffer identifier | -| `mode` | `pto.i32` | Release mode (default: 0) | - -**Returns**: None (side-effect操作) - -**Example**: -```python -from pto import SyncOpType - -# Release buffer for DMA load operation -pto.rls_buf(SyncOpType.TLOAD, 0) -``` - -### Low-level DMA Programming (Legacy) - -**Note**: These low-level DMA programming operations are automatically handled by `pto.dma_load` and `pto.dma_store` in most cases. They expose hardware DMA engine parameters directly and should only be used when the automatic inference provided by the high-level API is insufficient for specific optimization needs. - -This section contains both DMA configuration operations (setting loop strides and sizes) and DMA execution operations (copying data). Prefer the high-level `pto.dma_load` and `pto.dma_store` operations which automatically infer all parameters from TensorView slices and Tile properties. - -#### When to Use Low-level DMA Programming - -Consider using these low-level operations only in the following scenarios: - -1. **Performance micro-optimization**: When specific DMA parameter tuning is required for performance-critical code -2. **Non-standard access patterns**: When TensorView slicing syntax cannot express the desired memory access pattern -3. **Hardware-specific optimizations**: When targeting specific DMA engine characteristics not captured by the high-level API - -For 99% of use cases, `pto.dma_load` and `pto.dma_store` with TensorView slicing provide sufficient control and are much easier to use correctly. - -#### Manual Configuration Example - -```python -# Manual DMA configuration (discouraged for normal use) -pto.set_loop2_stride_outtoub(32, 128) # Outer loop strides -pto.set_loop1_stride_outtoub(1, 32) # Inner loop strides -pto.set_loop_size_outtoub(16, 16) # Transfer size -pto.copy_gm_to_ubuf(gm_ptr, ub_ptr, ...) - -# Equivalent using high-level API (recommended) -pto.dma_load(input_tensor[0:16, 0:16], ub_tile) -# All loop strides and sizes automatically inferred -``` - -#### `pto.set_loop2_stride_outtoub(stride0: pto.i64, stride1: pto.i64) -> None` - -**Description**: Configures DMA stride parameters for GM → UB transfers (loop2). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `stride0` | `pto.i64` | First dimension stride | -| `stride1` | `pto.i64` | Second dimension stride | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop1_stride_outtoub(stride0: pto.i64, stride1: pto.i64) -> None` - -**Description**: Configures DMA stride parameters for GM → UB transfers (loop1). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `stride0` | `pto.i64` | First dimension stride | -| `stride1` | `pto.i64` | Second dimension stride | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop_size_outtoub(size0: pto.i64, size1: pto.i64) -> None` - -**Description**: Configures DMA transfer size for GM → UB transfers. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `size0` | `pto.i64` | First dimension size | -| `size1` | `pto.i64` | Second dimension size | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.set_loop_size_outtoub(1, 1) -``` - -#### `pto.set_loop2_stride_ubtoout(stride0: pto.i64, stride1: pto.i64) -> None` - -**Description**: Configures DMA stride parameters for UB → GM transfers (loop2). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `stride0` | `pto.i64` | First dimension stride | -| `stride1` | `pto.i64` | Second dimension stride | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop1_stride_ubtoout(stride0: pto.i64, stride1: pto.i64) -> None` - -**Description**: Configures DMA stride parameters for UB → GM transfers (loop1). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `stride0` | `pto.i64` | First dimension stride | -| `stride1` | `pto.i64` | Second dimension stride | - -**Returns**: None (side-effect operation) - -#### `pto.set_loop_size_ubtoout(size0: pto.i64, size1: pto.i64) -> None` - -**Description**: Configures DMA transfer size for UB → GM transfers. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `size0` | `pto.i64` | First dimension size | -| `size1` | `pto.i64` | Second dimension size | - -**Returns**: None (side-effect operation) - -#### DMA Execution Operations - -**Note**: These operations execute DMA transfers but require manual configuration of DMA parameters (loop strides, loop sizes) using the `set_loop*_stride_*` and `set_loop_size_*` operations described above. The high-level `pto.dma_load` and `pto.dma_store` operations automatically handle both configuration and execution. - -The following operations provide direct control over DMA transfers but require manual stride and size configuration. Prefer the high-level Tile Data Movement operations for most use cases. - -#### `pto.copy_gm_to_ubuf(src: GMPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, transpose: pto.i1, pad_left: pto.i64, pad_right: pto.i64, pad_value: pto.i64) -> None` - -**Description**: Copies data from Global Memory (GM) to Unified Buffer (UB). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `GMPtr` | Source GM pointer | -| `dst` | `UBPtr` | Destination UB pointer | -| `src_offset` | `pto.i64` | Source offset | -| `src_stride0` | `pto.i64` | Source stride dimension 0 | -| `src_stride1` | `pto.i64` | Source stride dimension 1 | -| `dst_offset` | `pto.i64` | Destination offset | -| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | -| `transpose` | `pto.i1` | Transpose flag | -| `pad_left` | `pto.i64` | Left padding size | -| `pad_right` | `pto.i64` | Right padding size | -| `pad_value` | `pto.i64` | Padding value | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.copy_gm_to_ubuf(gm_ptr, ub_ptr, 0, 32, 128, 0, 0, False, 0, 128, 128) -``` - -#### `pto.copy_ubuf_to_ubuf(src: UBPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` - -**Description**: Copies data within Unified Buffer (UB → UB). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `UBPtr` | Source UB pointer | -| `dst` | `UBPtr` | Destination UB pointer | -| `src_offset` | `pto.i64` | Source offset | -| `src_stride0` | `pto.i64` | Source stride dimension 0 | -| `src_stride1` | `pto.i64` | Source stride dimension 1 | -| `dst_offset` | `pto.i64` | Destination offset | -| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | -| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | - -**Returns**: None (side-effect operation) - -#### `pto.copy_ubuf_to_gm(src: UBPtr, dst: GMPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` - -**Description**: Copies data from Unified Buffer (UB) to Global Memory (GM). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `UBPtr` | Source UB pointer | -| `dst` | `GMPtr` | Destination GM pointer | -| `src_offset` | `pto.i64` | Source offset | -| `src_stride0` | `pto.i64` | Source stride dimension 0 | -| `src_stride1` | `pto.i64` | Source stride dimension 1 | -| `dst_offset` | `pto.i64` | Destination offset | -| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | -| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | - -**Returns**: None (side-effect operation) - -**Example**: -```python -pto.copy_ubuf_to_gm(ub_ptr, gm_ptr, 0, 32, 128, 0, 128, 128) -``` - -### Tile Data Movement Operations - -High-level operations for moving data between TensorView partitions (GM) and Tile buffers (UB), as well as between Tile buffers. These operations **automatically handle all low-level DMA configuration** and provide an intuitive interface based on tile semantics. - -#### Automatic DMA Parameter Inference - -The `pto.dma_load` and `pto.dma_store` operations automatically infer DMA transfer parameters (loop strides, loop sizes) from: - -1. **TensorView slices** - Python slicing syntax captures stride information: - ```python - # Contiguous slice: [0:16, 0:16] - pto.dma_load(input_tensor[0:16, 0:16], ub_tile) - - # Strided slice: [0:64:2, 0:32] → stride=2 in first dimension - pto.dma_load(input_tensor[0:64:2, 0:32], ub_tile) - ``` - -2. **Tile properties** - Layout and memory space determine destination patterns: - ```python - # Row-major vs column-major layouts affect stride computation - row_major_tile = pto.tile((16, 16), pto.f32, pto.MemorySpace.UB, b_layout=pto.BLayout.ROW_MAJOR) - col_major_tile = pto.tile((16, 16), pto.f32, pto.MemorySpace.UB, b_layout=pto.BLayout.COL_MAJOR) - ``` - -3. **Transpose and padding requirements** - Specified via operation parameters. - -#### Benefits of Automatic Inference - -- **Simplified API**: No need to manually call `set_loop*_stride_*` and `set_loop_size_*` operations -- **Reduced errors**: Automatic parameter validation and consistency checking -- **Hardware abstraction**: Focus on data movement semantics, not DMA engine details -- **Portable code**: Same TileLang code works across different DMA implementations - -For advanced use cases requiring manual DMA parameter control, see the [Low-level DMA Programming (Legacy)](#low-level-dma-programming-legacy) section. - -#### `pto.dma_load(src: TensorView, dst: Tile, pad_mode: PadMode = PadMode.PadNull, pad_value: ScalarType = None, left_padding: Index = 0, right_padding: Index = 0, init_out_buffer: bool = False) -> None` - -**Description**: Loads data from a TensorView partition (GM) into a Tile buffer (UB). This maps to `pto.copy_gm_to_ubuf` operation in VPTO IR. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `TensorView` | Source tensor view partition (must be in GM) | -| `dst` | `Tile` | Destination tile buffer (must be in UB memory space) | -| `pad_mode` | `PadMode` | Padding mode (PadNull, PadFirstElem, PadValue) | -| `pad_value` | `ScalarType` | Padding value (required if `pad_mode == PadValue`) | -| `left_padding` | `Index` | Left padding element count | -| `right_padding` | `Index` | Right padding element count | -| `init_out_buffer` | `bool` | Initialize output buffer before loading | - -**Returns**: None (side-effect operation) - -**Constraints**: -- Destination tile must have `memory_space = MemorySpace.UB` -- Element types of source and destination must have same bitwidth -- Source partition shape must match destination tile valid shape (after accounting for padding) - -**Example**: -```python -# Load a 16x16 partition into a UB tile -pto.dma_load(input_tensor[0:16, 0:16], ub_tile) - -# Load with zero padding -pto.dma_load(input_tensor[0:16, 0:16], ub_tile, - pad_mode=PadMode.PadValue, - pad_value=pto.f32(0.0), - left_padding=2, - right_padding=2) -``` - -#### `pto.dma_store(src: Tile, dst: TensorView, pad_mode: PadMode = PadMode.PadNull, pad_value: ScalarType = None, left_padding: Index = 0, right_padding: Index = 0) -> None` - -**Description**: Stores data from a Tile buffer (UB) to a TensorView partition (GM). This maps to `pto.copy_ubuf_to_gm` operation in VPTO IR. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `Tile` | Source tile buffer (must be in UB memory space) | -| `dst` | `TensorView` | Destination tensor view partition (must be in GM) | -| `pad_mode` | `PadMode` | Padding mode (PadNull, PadFirstElem, PadValue) | -| `pad_value` | `ScalarType` | Padding value (required if `pad_mode == PadValue`) | -| `left_padding` | `Index` | Left padding element count | -| `right_padding` | `Index` | Right padding element count | - -**Returns**: None (side-effect operation) - -**Constraints**: -- Source tile must have `memory_space = MemorySpace.UB` -- Element types of source and destination must have same bitwidth -- Source tile valid shape must match destination partition shape (after accounting for padding) - -**Example**: -```python -# Store a UB tile to a GM partition -pto.dma_store(ub_tile, output_tensor[0:16, 0:16]) - -# Store with padding -pto.dma_store(ub_tile, output_tensor[0:16, 0:16], - pad_mode=PadMode.PadValue, - pad_value=pto.f32(0.0), - left_padding=1, - right_padding=1) -``` - -#### `pto.dma_copy(src: Tile, dst: Tile, src_offset: tuple[Index, Index] = (0, 0), dst_offset: tuple[Index, Index] = (0, 0), copy_shape: tuple[Index, Index] = None) -> None` - -**Description**: Copies data between Tile buffers within Unified Buffer (UB → UB). This maps to `pto.copy_ubuf_to_ubuf` operation in VPTO IR. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src` | `Tile` | Source tile buffer (must be in UB memory space) | -| `dst` | `Tile` | Destination tile buffer (must be in UB memory space) | -| `src_offset` | `tuple[Index, Index]` | Offset within source tile (row, col) in elements | -| `dst_offset` | `tuple[Index, Index]` | Offset within destination tile (row, col) in elements | -| `copy_shape` | `tuple[Index, Index]` | Shape of region to copy (rows, cols) in elements. If None, copies the maximum valid region starting from offsets. | - -**Returns**: None (side-effect operation) - -**Constraints**: -- Both tiles must have `memory_space = MemorySpace.UB` -- Element types of source and destination must match -- Source and destination regions must be within tile valid shapes - -**Example**: -```python -# Copy entire tile -pto.dma_copy(src_tile, dst_tile) - -# Copy subregion: copy 8x8 block from (2,2) in src to (0,0) in dst -pto.dma_copy(src_tile, dst_tile, - src_offset=(2, 2), - dst_offset=(0, 0), - copy_shape=(8, 8)) -``` - -**Note**: These high-level operations automatically handle DMA stride and size configuration based on tile shapes, layouts, and offsets. For low-level control, see the [Low-level DMA Programming (Legacy)](#low-level-dma-programming-legacy) section. - -#### VPTO IR Mapping - -The high-level DMA operations in TileLang DSL map to corresponding operations in VPTO IR: - -| TileLang DSL Operation | VPTO IR Operation | Description | -|------------------------|-------------------|-------------| -| `pto.dma_load` | `pto.copy_gm_to_ubuf` | Loads data from GM tensor view to UB tile buffer | -| `pto.dma_store` | `pto.copy_ubuf_to_gm` | Stores data from UB tile buffer to GM tensor view | -| `pto.dma_copy` | `pto.copy_ubuf_to_ubuf` | Copies data between UB tile buffers | - -These mappings allow the TileLang compiler to generate efficient VPTO IR code while providing a higher-level, more intuitive API for developers. The compiler automatically handles the conversion between Tile/TensorView abstractions and the low-level pointer/stride representation required by VPTO IR operations. - - -### Address Generation Syntax Sugar - -To simplify address calculation and reduce manual byte offset computation errors, TileLang DSL provides syntactic sugar for vector load/store operations using element-based indexing. This syntax automatically computes the byte offset based on tile shape, element type, and layout. - -#### Indexing Syntax - -The syntax supports two indexing modes for different operations: - -1. **Vector-range indexing** (for vector load/store operations): - - **Row-major layout (default)**: `tile[row_index, col_start:]` - - `row_index`: Row index (0-based) - - `col_start:`: Starting column index followed by colon, indicating a vector-width contiguous region starting from this column - - The colon (`:`) indicates an implicit vector-width range determined by hardware vector size (256 bytes) and element type - - - **Column-major layout**: `tile[row_start:, col_index]` - - `row_start:`: Starting row index followed by colon, indicating a vector-width contiguous region starting from this row - - `col_index`: Column index (0-based) - - Used for column-major tiles (`BLayout.COL_MAJOR`) where elements are stored column-wise - - - **1D tile indexing**: `tile[start:]` (or equivalently `tile[0, start:]` for row-major or `tile[start:, 0]` for column-major) - - `start:`: Starting element index followed by colon - -2. **Single-element indexing** (for scalar load operations like `pto.vsld`): - - **Row-major layout (default)**: `tile[row_index, col_index]` - - `row_index`: Row index (0-based) - - `col_index`: Column index (0-based) - - Loads a single element at the specified position and broadcasts it to all vector lanes - - - **Column-major layout**: `tile[row_index, col_index]` (same syntax) - - `row_index`: Row index (0-based) - - `col_index`: Column index (0-based) - - Same syntax as row-major; the layout determines how the offset is computed - - - **1D tile indexing**: `tile[pos]` - - `pos`: Element index (0-based) - - Loads a single element at the specified position and broadcasts it to all vector lanes - -#### Vector Width Calculation - -The number of elements loaded/stored in a single vector operation is determined by: - -``` -vector_lanes = 256 // element_size_bytes(element_type) -``` - -**Convenience API**: Use `pto.get_lanes(dtype)` to compute vector lanes for a given element type (e.g., `pto.get_lanes(pto.f32)` returns 64, `pto.get_lanes(pto.f16)` returns 128). - -Where `element_size_bytes` is: -- 1 byte for `i8` -- 2 bytes for `i16`, `f16`, `bf16` -- 4 bytes for `i32`, `f32` -- 8 bytes for `i64` - -#### Offset Computation - -The byte offset is automatically computed based on tile layout: - -- **Row-major layout** (`BLayout.ROW_MAJOR`): - ``` - offset = (row_index * stride_row + col_start) * element_size_bytes - ``` - where `stride_row` is the row stride in elements (typically `tile.shape[1]` for contiguous tiles). - -- **Column-major layout** (`BLayout.COL_MAJOR`): - - For syntax `tile[row_start:, col_index]`: - ``` - offset = (col_index * stride_col + row_start) * element_size_bytes - ``` - - For backward compatibility with traditional offset calculation: - ``` - offset = (col_start * stride_col + row_index) * element_size_bytes - ``` - where `stride_col` is the column stride in elements (typically `tile.shape[0]` for contiguous tiles), `row_start` is the starting row index, and `col_index` is the column index. - -**Note**: -- For single-element indexing (`tile[row, col]` or `tile[pos]`), the same offset formulas apply with `col_start` replaced by `col_index` (or `start` replaced by `pos` for 1D tiles). -- For column-major vector-range indexing (`tile[row_start:, col_index]`), the offset formula uses `row_start` as the starting position along the contiguous dimension. -- The compiler automatically handles the appropriate substitution based on the indexing syntax and tile layout. - -#### Constraints - -1. **Boundary checks**: The requested region must be within tile bounds: - - **For vector-range indexing** (`:` syntax): - - **Row-major layout** (`tile[row_index, col_start:]`): - - `row_index < tile.shape[0]` and `col_start + vector_lanes <= tile.shape[1]` - - **Column-major layout** (`tile[row_start:, col_index]`): - - `row_start + vector_lanes <= tile.shape[0]` and `col_index < tile.shape[1]` - - **1D tile indexing**: `tile[start:]` - - `start + vector_lanes <= tile.shape[0]` (or `tile.shape[1]` for 1D tiles) - - **For single-element indexing** (no `:` syntax): - - 2D: `row_index < tile.shape[0]` and `col_index < tile.shape[1]` (same for both layouts) - - 1D: `pos < tile.shape[0]` (or `tile.shape[1]` for 1D tiles) - -2. **Alignment**: The computed offset must satisfy hardware alignment requirements for the operation. - -3. **Full vectors only**: The `:` syntax always loads/stores a full vector width. For partial vectors, use the traditional byte offset approach with explicit mask handling. - -4. **Single-element operations**: The single-element indexing syntax (`tile[row, col]` or `tile[pos]`) is only supported for scalar load operations like `pto.vsld`. For other operations, use vector-range indexing with `:` syntax. - -#### Supported Operations - -The indexing syntax is supported for all vector load and store operations with the following syntax mapping: - -- **Vector-range indexing** (`tile[row, col:]` or `tile[start:]`): - - Load operations: `vlds`, `vldas`, `vldus`, `vplds`, `vldx2` - - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstx2` - -- **Single-element indexing** (`tile[row, col]` or `tile[pos]`): - - Load operations: `vsld` (scalar load with broadcast) - -#### Examples - -The following examples use row-major layout syntax. For column-major tiles, use `tile[row_start:, col_index]` syntax instead of `tile[row_index, col_start:]`. - -```python -# 2D tile indexing (row-major layout) -vec = pto.vlds(tile[i, j:]) # Load vector from row i, columns j to j+vector_lanes-1 -pto.vsts(vec, tile[i, j:], mask) # Store vector with mask - -# 1D tile indexing -vec = pto.vlds(tile[k:]) # Load vector from elements k to k+vector_lanes-1 -pto.vsts(vec, tile[k:], mask) # Store vector with mask - -# Dual load with indexing -vec1, vec2 = pto.vldx2(tile_a[i, j:], tile_b[i, j:]) - -# Aligned load with indexing -vec = pto.vldas(tile[i, j:], align) - -# Scalar load (broadcast) -vec = pto.vsld(tile[i, j]) # Load scalar at tile[i,j] and broadcast to vector -``` - -#### Comparison with Manual Offset Calculation - -**Traditional approach (error-prone):** -```python -# Manual byte offset calculation for f32 tile -rows, cols = tile.shape -row_offset = i * cols * 4 # Hard-coded 4 bytes for f32 -col_offset = j * 4 -offset = row_offset + col_offset -vec = pto.vlds(tile, offset) -``` - -**New syntax (type-safe):** -```python -# Automatic offset calculation -vec = pto.vlds(tile[i, j:]) # Compiler computes correct offset for any element type -``` - -The syntax sugar eliminates manual byte calculations, reduces errors, and makes code generic across different element types (e.g., the same kernel works for both `f16` and `f32` without modification). - -### Vector Load Operations - -Operations for loading data from memory into vector registers. - -#### `pto.vlds(buf: UBRef, offset: Index) -> VRegType` -#### `pto.vlds(tile[row, col:]) -> VRegType` -#### `pto.vlds(tile[start:]) -> VRegType` - -**Description**: Stateless vector load from buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `UBRef` | Buffer or pointer (UB memory space) | -| `offset` | `Index` | Byte offset | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Loaded vector register | - -**Constraints**: -- Buffer must be in UB memory space -- For byte-offset syntax: offset must be properly aligned based on element type -- For element-indexing syntax: the requested vector region must be within tile bounds and satisfy alignment requirements - -**Examples**: -```python -# Traditional byte-offset syntax -vec = pto.vlds(ub_ptr, lane * 256) - -# New element-indexing syntax -vec = pto.vlds(tile[i, j:]) # Load from row i, columns j to j+vector_lanes-1 -vec = pto.vlds(tile[k:]) # Load from 1D tile, elements k to k+vector_lanes-1 - -# Generic kernel that works for both f16 and f32 -@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) -def generic_scale(src: pto.Tile, dst: pto.Tile, scale: pto.f32): - rows, cols = src.shape - all_mask = pto.make_mask(src.element_type, PAT.ALL) - for i in range(0, rows): - for j in range(0, cols, vector_lanes): # vector_lanes computed from element type - # No manual byte calculation needed! - vec = pto.vlds(src[i, j:]) - scaled = pto.vmuls(vec, scale, all_mask) - pto.vsts(scaled, dst[i, j:], all_mask) -``` - -#### `pto.vldas(buf: UBRef, offset: Index, align: pto.align) -> VRegType` -#### `pto.vldas(tile[row, col:], align: pto.align) -> VRegType` -#### `pto.vldas(tile[start:], align: pto.align) -> VRegType` - -**Description**: Aligned vector load with explicit alignment carrier. Supports both byte-offset and element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `UBRef` | Buffer or pointer (UB memory space) | -| `offset` | `Index` | Byte offset | -| `align` | `pto.align` | Alignment specification | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | -| `align` | `pto.align` | Alignment specification | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Loaded vector register | - -**Examples**: -```python -# Byte-offset syntax -vec = pto.vldas(ub_ptr, offset, align) - -# Element-indexing syntax -vec = pto.vldas(tile[i, j:], align) -vec = pto.vldas(tile[k:], align) -``` - -#### `pto.vldus(buf: UBRef, offset: Index) -> VRegType` -#### `pto.vldus(tile[row, col:]) -> VRegType` -#### `pto.vldus(tile[start:]) -> VRegType` - -**Description**: Unaligned vector load. Supports both byte-offset and element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `UBRef` | Buffer or pointer (UB memory space) | -| `offset` | `Index` | Byte offset | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Loaded vector register | - -**Examples**: -```python -# Byte-offset syntax -vec = pto.vldus(ub_ptr, offset) - -# Element-indexing syntax -vec = pto.vldus(tile[i, j:]) -vec = pto.vldus(tile[k:]) -``` - -#### `pto.vplds(buf: UBRef, offset: Index, pred: MaskType) -> VRegType` -#### `pto.vplds(tile[row, col:], pred: MaskType) -> VRegType` -#### `pto.vplds(tile[start:], pred: MaskType) -> VRegType` - -**Description**: Predicated vector load stateless. Supports both byte-offset and element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `UBRef` | Buffer or pointer (UB memory space) | -| `offset` | `Index` | Byte offset | -| `pred` | `MaskType` | Predicate mask | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | -| `pred` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Loaded vector register | - -**Examples**: -```python -# Byte-offset syntax -vec = pto.vplds(ub_ptr, offset, mask) - -# Element-indexing syntax -vec = pto.vplds(tile[i, j:], mask) -vec = pto.vplds(tile[k:], mask) -``` - -#### `pto.vldx2(buf1: UBRef, buf2: UBRef, offset: Index) -> (VRegType, VRegType)` -#### `pto.vldx2(tile1[row, col:], tile2[row, col:]) -> (VRegType, VRegType)` -#### `pto.vldx2(tile1[start:], tile2[start:]) -> (VRegType, VRegType)` - -**Description**: Dual vector load from two buffers. Supports both byte-offset and element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf1` | `UBRef` | First buffer or pointer | -| `buf2` | `UBRef` | Second buffer or pointer | -| `offset` | `Index` | Byte offset (applied to both buffers) | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile1[row, col:]` | `Tile` with indexing | First 2D tile with row index and starting column | -| `tile2[row, col:]` | `Tile` with indexing | Second 2D tile with row index and starting column | -| _or_ | | | -| `tile1[start:]` | `Tile` with indexing | First 1D tile with starting element index | -| `tile2[start:]` | `Tile` with indexing | Second 1D tile with starting element index | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec1` | `VRegType` | Vector from first buffer | -| `vec2` | `VRegType` | Vector from second buffer | - -**Examples**: -```python -# Byte-offset syntax -vec1, vec2 = pto.vldx2(ub_ptr1, ub_ptr2, offset) - -# Element-indexing syntax -vec1, vec2 = pto.vldx2(tile_a[i, j:], tile_b[i, j:]) -vec1, vec2 = pto.vldx2(tile_a[k:], tile_b[k:]) -``` - -#### `pto.vsld(buf: UBRef, offset: Index) -> VRegType` -#### `pto.vsld(tile[row, col]) -> VRegType` -#### `pto.vsld(tile[pos]) -> VRegType` - -**Description**: Scalar load to vector (broadcast scalar to all lanes). Supports both byte-offset and element-indexing syntax. The element-indexing syntax loads a single element (not a vector) and broadcasts it to all lanes. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `buf` | `UBRef` | Buffer or pointer (UB memory space) | -| `offset` | `Index` | Byte offset | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | -| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec` | `VRegType` | Vector with scalar broadcast to all lanes | - -**Examples**: -```python -# Byte-offset syntax -vec = pto.vsld(ub_ptr, offset) - -# Element-indexing syntax -vec = pto.vsld(tile[i, j]) # Load single element at (i,j) and broadcast -vec = pto.vsld(tile[k]) # Load single element at position k and broadcast -``` - -### Predicate Operations - -Operations for creating and manipulating typed masks. - -**Recommended API**: For most use cases, prefer the unified `pto.make_mask()` function which automatically selects the appropriate mask granularity based on element type and supports both tail processing (remaining element count) and pattern-based mask generation. This eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` (tail processing) and `pset_b8`/`pset_b16`/`pset_b32` (pattern generation) operations. - -**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. - -#### `pto.pset_b8(pattern: pto.MaskPattern) -> pto.mask_b8` - -**Description**: Creates an 8-bit granularity mask from a pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b8` | 8-bit granularity mask | - -**Constraints**: -- Used with `i8` vector operations - -**Example**: -```python -mask8 = pto.make_mask(pto.i8, PAT.ALL) -``` - -#### `pto.pset_b16(pattern: pto.MaskPattern) -> pto.mask_b16` - -**Description**: Creates a 16-bit granularity mask from a pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b16` | 16-bit granularity mask | - -**Constraints**: -- Used with `f16`/`bf16`/`i16` vector operations - -**Example**: -```python -mask16 = pto.make_mask(pto.f16, PAT.ALL) -``` - -#### `pto.pset_b32(pattern: pto.MaskPattern) -> pto.mask_b32` - -**Description**: Creates a 32-bit granularity mask from a pattern. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b32` | 32-bit granularity mask | - -**Constraints**: -- Used with `f32`/`i32` vector operations - -**Example**: -```python -mask32 = pto.make_mask(pto.f32, PAT.ALL) -``` - -#### `pto.pge_b8(vec: VRegType, scalar: ScalarType) -> pto.mask_b8` - -**Description**: Creates 8-bit mask where vector elements ≥ scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (element type must match mask granularity) | -| `scalar` | `ScalarType` | Scalar comparison value | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b8` | 8-bit granularity mask | - -**Constraints**: -- Vector element type must be `i8` or compatible - -#### `pto.pge_b16(vec: VRegType, scalar: ScalarType) -> pto.mask_b16` - -**Description**: Creates 16-bit mask where vector elements ≥ scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (element type must match mask granularity) | -| `scalar` | `ScalarType` | Scalar comparison value | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b16` | 16-bit granularity mask | - -**Constraints**: -- Vector element type must be `f16`/`bf16`/`i16` - -#### `pto.pge_b32(vec: VRegType, scalar: ScalarType) -> pto.mask_b32` - -**Description**: Creates 32-bit mask where vector elements ≥ scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (element type must match mask granularity) | -| `scalar` | `ScalarType` | Scalar comparison value | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b32` | 32-bit granularity mask | - -**Constraints**: -- Vector element type must be `f32`/`i32` - -**Example**: -```python -mask = pto.pge_b32(vec_f32, pto.f32(0.0)) -``` - -#### `pto.plt_b8(vec: VRegType, scalar: ScalarType) -> pto.mask_b8` - -**Description**: Creates 8-bit mask where vector elements < scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (element type must match mask granularity) | -| `scalar` | `ScalarType` | Scalar comparison value | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b8` | 8-bit granularity mask | - -#### `pto.plt_b16(vec: VRegType, scalar: ScalarType) -> pto.mask_b16` - -**Description**: Creates 16-bit mask where vector elements < scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (element type must match mask granularity) | -| `scalar` | `ScalarType` | Scalar comparison value | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b16` | 16-bit granularity mask | - -#### `pto.plt_b32(vec: VRegType, scalar: ScalarType) -> (pto.mask_b32, pto.i32)` - -**Description**: Creates 32-bit mask where vector elements < scalar, returns mask and remaining count. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (element type must match mask granularity) | -| `scalar` | `ScalarType` | Scalar comparison value | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `pto.mask_b32` | 32-bit granularity mask | -| `remaining` | `pto.i32` | Remaining element count | - -**Example**: -```python -mask, remaining = pto.plt_b32(vec_f32, pto.f32(10.0)) -``` - -#### `pto.make_mask(element_type: Type, value: pto.i32 | pto.MaskPattern) -> MaskType | (MaskType, pto.i32)` - -**Description**: Creates a mask with appropriate bitwidth (8, 16, or 32) based on element type, automatically inferring whether to perform tail processing or pattern-based mask generation based on the `value` parameter type. This convenience function eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` and `pset_b8`/`pset_b16`/`pset_b32` operations. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `element_type` | `Type` | Element type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | -| `value` | `pto.i32` \| `pto.MaskPattern` | Either:
- Remaining element count (as `pto.i32`) for tail processing
- Mask pattern enum value for fixed mask generation (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_VL32`) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `MaskType` | Generated mask with appropriate granularity | -| `remaining` | `pto.i32` | Updated remaining element count (only returned when `value` is a `pto.i32` for tail processing) | - -**Constraints**: -- The `element_type` must be one of: `f32`, `i32`, `f16`, `bf16`, `i16`, `i8` -- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`, 16-bit for `f16`/`bf16`/`i16`, 8-bit for `i8` -- The function infers the operation mode from the `value` parameter type at compile time: - - `pto.i32` value → tail processing mode (returns `(mask, updated_remaining)`) - - `pto.MaskPattern` enum value → pattern mode (returns `mask` only) - -**Implementation Note**: This function is a DSL macro that performs type-based dispatch at compile time: -- When `value` is a `pto.i32` expression: expands to corresponding `plt_b` instruction (`plt_b32`, `plt_b16`, or `plt_b8`) -- When `value` is a `pto.MaskPattern` enum value: expands to corresponding `pset_b` instruction (`pset_b32`, `pset_b16`, or `pset_b8`) - -**Example**: -```python -# Tail processing with f32 vectors: value is pto.i32 → expands to plt_b32 -mask_f32, remaining_f32 = pto.make_mask(pto.f32, remaining_elements) - -# Tail processing with f16 vectors: value is pto.i32 → expands to plt_b16 -mask_f16, remaining_f16 = pto.make_mask(pto.f16, remaining_elements) - -# Tail processing with i8 vectors: value is pto.i32 → expands to plt_b8 -mask_i8, remaining_i8 = pto.make_mask(pto.i8, remaining_elements) - -# Pattern-based mask with f32 vectors: value is MaskPattern enum → expands to pset_b32 -mask_all_f32 = pto.make_mask(pto.f32, PAT.ALL) - -# Pattern-based mask with f16 vectors: value is MaskPattern enum → expands to pset_b16 -mask_even_f16 = pto.make_mask(pto.f16, PAT.EVEN) - -# Pattern-based mask with i8 vectors: value is MaskPattern enum → expands to pset_b8 -mask_all_i8 = pto.make_mask(pto.i8, PAT.ALL) - -# Type annotations help clarify expected parameter types -remaining: pto.i32 = 1024 -mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing -mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode -``` - -#### `pto.ppack(mask: MaskType, part: str) -> MaskType` - -**Description**: Rearranges a mask according to the requested `part` selector. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | -| `part` | `str` | Part selector such as `"PART_EVEN"` or `"PART_ODD"` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `packed` | `MaskType` | Reordered mask | - -#### `pto.punpack(mask: MaskType, part: str) -> MaskType` - -**Description**: Applies the inverse mask-part rearrangement selected by `part`. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Input mask | -| `part` | `str` | Part selector such as `"PART_EVEN"` or `"PART_ODD"` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `mask` | `MaskType` | Reordered mask | - -#### `pto.pnot(mask: MaskType, gate: MaskType) -> MaskType` - -**Description**: Predicate negation under a same-granularity mask gate. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Input mask | -| `gate` | `MaskType` | Gating mask with the same granularity | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `negated` | `MaskType` | Negated mask | - -#### `pto.psel(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` - -**Description**: Selects between two masks using a third mask as selector. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `src0` | `MaskType` | First input mask | -| `src1` | `MaskType` | Second input mask | -| `mask` | `MaskType` | Selection mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `MaskType` | Selected mask | - -### Unary Vector Operations - -Element-wise unary operations on vector registers. - -#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Absolute value of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Absolute values | - -**Constraints**: -- Mask granularity must match vector element type (e.g., `f32` requires `mask_b32`) - -**Example**: -```python -abs_vec = pto.vabs(vec_f32, mask32) -``` - -#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Exponential of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Exponential values | - -#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Natural logarithm of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Natural logarithm values | - -#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Square root of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Square root values | - -#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: ReLU activation (max(0, x)) of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | ReLU-activated values | - -#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Bitwise NOT of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise NOT values | - -#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Complex addition of vector elements (treating pairs as complex numbers). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Complex addition result | - -#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Complex maximum of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Complex maximum result | - -### Binary Vector Operations - -Element-wise binary operations on vector registers. - -#### `pto.vadd(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise addition of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Sum of vectors | - -**Example**: -```python -sum_vec = pto.vadd(vec_a, vec_b, mask32) -``` - -#### `pto.vsub(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise subtraction of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Difference of vectors | - -#### `pto.vmul(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise multiplication of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Product of vectors | - -#### `pto.vdiv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise division of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Quotient of vectors | - -#### `pto.vmax(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise maximum of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Element-wise maximum | - -#### `pto.vmin(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise minimum of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Element-wise minimum | - -#### `pto.vand(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise AND of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise AND result | - -#### `pto.vor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise OR of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise OR result | - -#### `pto.vxor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise bitwise XOR of two vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Bitwise XOR result | - -#### `pto.vshl(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise shift left (vector shift amounts). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift` | `VRegType` | Shift amounts (per element) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted values | - -#### `pto.vshr(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` - -**Description**: Element-wise shift right (vector shift amounts). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift` | `VRegType` | Shift amounts (per element) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted values | - -### Vector-Scalar Operations - -Operations between vectors and scalars. - -#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Vector multiplied by scalar (broadcast). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar multiplier | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Scaled vector | - -**Example**: -```python -scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) -``` - -#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Vector plus scalar (broadcast). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar addend | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Result vector | - -#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Element-wise maximum of vector and scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar value | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Maximum values | - -#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Element-wise minimum of vector and scalar. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Scalar value | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Minimum values | - -#### `pto.vlrelu(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Leaky ReLU activation (max(αx, x)). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `scalar` | `ScalarType` | Alpha coefficient | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Leaky ReLU activated values | - -#### `pto.vshls(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Vector shift left by scalar (uniform shift). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift` | `ScalarType` | Shift amount (same for all elements) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted values | - -#### `pto.vshrs(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` - -**Description**: Vector shift right by scalar (uniform shift). - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `shift` | `ScalarType` | Shift amount (same for all elements) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Shifted values | - -### Carry & Select Operations - -Operations with carry propagation and selection. - -Implemented current-package carry/select surface also includes: -- `pto.vcmp(vec0, vec1, seed_mask, cmp_mode) -> MaskType` -- `pto.vcmps(vec, scalar, seed_mask, cmp_mode) -> MaskType` -- `pto.vselr(vec0, vec1) -> VRegType` -- `pto.vselrv2(vec0, vec1) -> VRegType` -- `pto.vaddcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` -- `pto.vsubcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` - -#### `pto.vaddc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` - -**Description**: Vector addition with carry output. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Sum vector | -| `carry_out` | `MaskType` | Output carry mask | - -#### `pto.vsubc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` - -**Description**: Vector subtraction with borrow output. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Difference vector | -| `borrow_out` | `MaskType` | Output borrow mask | - -#### `pto.vsel(true_vec: VRegType, false_vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Vector select based on mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `true_vec` | `VRegType` | Vector selected when mask bit is 1 | -| `false_vec` | `VRegType` | Vector selected when mask bit is 0 | -| `mask` | `MaskType` | Selection mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Selected vector | - -**Example**: -```python -result = pto.vsel(scaled_vec, original_vec, mask32) -``` - -### Data Rearrangement - -Operations for rearranging data within vectors. - -#### `pto.pdintlv_b8(mask: pto.mask_b8) -> pto.mask_b8` - -**Description**: Deinterleave 8-bit mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `pto.mask_b8` | Input 8-bit mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `pto.mask_b8` | Deinterleaved mask | - -#### `pto.pintlv_b16(mask: pto.mask_b16) -> pto.mask_b16` - -**Description**: Interleave 16-bit mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `pto.mask_b16` | Input 16-bit mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `pto.mask_b16` | Interleaved mask | - -Implemented current-package rearrangement surface also includes: -- `pto.vintlvv2(vec0, vec1, part) -> VRegType` -- `pto.vdintlvv2(vec0, vec1, part) -> VRegType` - -#### `pto.vintlv(vec1: VRegType, vec2: VRegType) -> (VRegType, VRegType)` - -**Description**: Interleave two vectors and return the low/high results. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `low` | `VRegType` | Low interleaved result | -| `high` | `VRegType` | High interleaved result | - -#### `pto.vdintlv(vec0: VRegType, vec1: VRegType) -> (VRegType, VRegType)` - -**Description**: Deinterleave a pair of vectors into low/high results. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec0` | `VRegType` | First input vector | -| `vec1` | `VRegType` | Second input vector | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `vec1` | `VRegType` | First deinterleaved vector | -| `vec2` | `VRegType` | Second deinterleaved vector | - -### Conversion & Special Operations - -Type conversion and specialized operations. - -#### `pto.vtrc(vec: VRegType, mask: MaskType, rnd: str) -> VRegType` - -**Description**: Truncate/round floating-point vector elements to integer-valued -floating-point results under an explicit predicate mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask; granularity must match element width | -| `rnd` | `str` | Round mode: `R`, `A`, `F`, `C`, or `Z` | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Rounded result with the same floating-point element type | - -#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType) -> VRegType` - -**Description**: Type conversion of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `to_type` | `Type` | Target element type | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Converted vector | - -#### `pto.vbitsort(vec: VRegType, mask: MaskType) -> VRegType` - -**Description**: Bitonic sort of vector elements. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Sorted vector | - -#### `pto.vmrgsort4(vec1: VRegType, vec2: VRegType, vec3: VRegType, vec4: VRegType, mask: MaskType) -> VRegType` - -**Description**: 4-way merge sort of vectors. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `vec3` | `VRegType` | Third input vector | -| `vec4` | `VRegType` | Fourth input vector | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Merged and sorted vector | - -### Stateless Store Operations - -Operations for storing data from vector registers to memory (stateless). - -#### `pto.vsts(vec: VRegType, buf: UBRef, offset: Index, mask: MaskType) -> None` -#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType) -> None` -#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType) -> None` - -**Description**: Stateless vector store to buffer. Supports both byte-offset and element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Vector to store | -| `buf` | `UBRef` | Destination buffer or pointer (UB memory space) | -| `offset` | `Index` | Byte offset | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Vector to store | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation) - -**Constraints**: -- Buffer must be in UB memory space -- For byte-offset syntax: offset must be properly aligned based on element type -- For element-indexing syntax: the destination vector region must be within tile bounds and satisfy alignment requirements - -**Examples**: -```python -# Byte-offset syntax -pto.vsts(vec_f32, ub_ptr, lane * 256, mask32) - -# Element-indexing syntax -pto.vsts(vec, tile[i, j:], mask) # Store to row i, columns j to j+vector_lanes-1 -pto.vsts(vec, tile[k:], mask) # Store to 1D tile, elements k to k+vector_lanes-1 - -# In a generic kernel -@pto.vkernel(target="a5", op="copy", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) -def generic_store(src: pto.Tile, dst: pto.Tile): - rows, cols = src.shape - all_mask = pto.make_mask(src.element_type, PAT.ALL) - for i in range(0, rows): - for j in range(0, cols, vector_lanes): - vec = pto.vlds(src[i, j:]) - pto.vsts(vec, dst[i, j:], all_mask) # No manual offset calculation -``` - -#### `pto.psts(mask: MaskType, buf: UBRef, offset: Index) -> None` -#### `pto.psts(mask: MaskType, tile[row, col:]) -> None` -#### `pto.psts(mask: MaskType, tile[start:]) -> None` - -**Description**: Predicate store to buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Mask to store | -| `buf` | `UBRef` | Destination buffer or pointer | -| `offset` | `Index` | Byte offset | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Mask to store | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | - -**Parameters (1D element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Mask to store | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | - -**Returns**: None (side-effect operation) - -#### `pto.vsst(scalar: ScalarType, buf: UBRef, offset: Index, mask: MaskType) -> None` -#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` -#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` - -**Description**: Scalar to vector store (broadcast scalar to all lanes). Supports both traditional byte-offset syntax and new element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `buf` | `UBRef` | Destination buffer or pointer | -| `offset` | `Index` | Byte offset | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (1D element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `scalar` | `ScalarType` | Scalar value | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation) - -#### `pto.vstx2(vec1: VRegType, vec2: VRegType, buf1: UBRef, buf2: UBRef, offset: Index, mask: MaskType) -> None` -#### `pto.vstx2(vec1: VRegType, vec2: VRegType, tile1[row, col:], tile2[row, col:], mask: MaskType) -> None` -#### `pto.vstx2(vec1: VRegType, vec2: VRegType, tile1[start:], tile2[start:], mask: MaskType) -> None` - -**Description**: Dual vector store to two buffers. Supports both traditional byte-offset syntax and new element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First vector to store | -| `vec2` | `VRegType` | Second vector to store | -| `buf1` | `UBRef` | First destination buffer | -| `buf2` | `UBRef` | Second destination buffer | -| `offset` | `Index` | Byte offset (applied to both buffers) | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First vector to store | -| `vec2` | `VRegType` | Second vector to store | -| `tile1[row, col:]` | `Tile` with indexing | First 2D tile with row index and starting column (vector-width range) | -| `tile2[row, col:]` | `Tile` with indexing | Second 2D tile with row index and starting column (vector-width range) | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (1D element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec1` | `VRegType` | First vector to store | -| `vec2` | `VRegType` | Second vector to store | -| `tile1[start:]` | `Tile` with indexing | First 1D tile with starting element index (vector-width range) | -| `tile2[start:]` | `Tile` with indexing | Second 1D tile with starting element index (vector-width range) | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation) - -#### `pto.vsta(vec: VRegType, buf: UBRef, offset: Index, align: pto.align, mask: MaskType) -> None` -#### `pto.vsta(vec: VRegType, tile[row, col:], align: pto.align, mask: MaskType) -> None` -#### `pto.vsta(vec: VRegType, tile[start:], align: pto.align, mask: MaskType) -> None` - -**Description**: Aligned vector store with explicit alignment carrier. Supports both traditional byte-offset syntax and new element-indexing syntax. - -**Parameters (byte-offset syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Vector to store | -| `buf` | `UBRef` | Destination buffer or pointer | -| `offset` | `Index` | Byte offset | -| `align` | `pto.align` | Alignment specification | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Vector to store | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `align` | `pto.align` | Alignment specification | -| `mask` | `MaskType` | Predicate mask | - -**Parameters (1D element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Vector to store | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `align` | `pto.align` | Alignment specification | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation) - -### Stateful Store Operations - -Operations for storing data with stateful semantics. - -#### `pto.pstu(mask: MaskType, buf: UBRef, offset: Index) -> None` - -**Description**: Predicate stateful store. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Mask to store | -| `buf` | `UBRef` | Destination buffer or pointer | -| `offset` | `Index` | Byte offset | - -**Returns**: None (side-effect operation) - -#### `pto.vstu(vec: VRegType, buf: UBRef, offset: Index, mask: MaskType) -> None` - -**Description**: Vector stateful store. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `vec` | `VRegType` | Vector to store | -| `buf` | `UBRef` | Destination buffer or pointer | -| `offset` | `Index` | Byte offset | -| `mask` | `MaskType` | Predicate mask | - -**Returns**: None (side-effect operation) - -#### `pto.vstus(align_in: AlignType, offset: i32, vec: VRegType, buf: UBRef) -> AlignType` - -**Description**: No-post unaligned vector store with scalar offset. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align_in` | `AlignType` | Incoming unaligned-store state | -| `offset` | `i32` | Stream advance offset in elements | -| `vec` | `VRegType` | Vector to store | -| `buf` | `UBRef` | UB destination base pointer | - -**Returns**: Updated align-state token for a later flush op such as `pto.vstas`. - -#### `pto.vstur(align_in: AlignType, vec: VRegType, buf: UBRef, mode: str) -> AlignType` - -**Description**: Unaligned vector store using the SPR-AR-driven stateful form. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `align_in` | `AlignType` | Incoming unaligned-store state | -| `vec` | `VRegType` | Vector to store | -| `buf` | `UBRef` | UB destination base pointer | -| `mode` | `str` | `POST_UPDATE` or `NO_POST_UPDATE` | - -**Returns**: Updated align-state token for a later flush op such as `pto.vstar`. - -## Examples - -### Simple Vector Copy - -```python -@pto.vkernel(...) -def vector_copy(src: pto.memref(256, pto.f32, MemorySpace.UB), - dst: pto.memref(256, pto.f32, MemorySpace.UB)): - all_mask = pto.make_mask(pto.f32, PAT.ALL) - for offset in range(0, 256, 64): - vec = pto.vlds(src, offset) - pto.vsts(vec, dst, offset, all_mask) -``` - -### Conditional Computation - -```python -@pto.vkernel(...) -def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), - dst: pto.ptr(pto.f32, MemorySpace.GM), - threshold: pto.f32): - # ... setup ... - - with pto.strict_vecscope(ub_in, ub_out, threshold) as (vin, vout, thresh): - for i in range(0, 1024, 64): - vec = pto.vlds(vin, i) - - # Compare with threshold - mask = pto.pge_b32(vec, thresh) - - # Scale values above threshold - scaled = pto.vmuls(vec, pto.f32(2.0), mask) - - # Keep original values below threshold - result = pto.vsel(scaled, vec, mask) - - pto.vsts(result, vout, i, all_mask) -``` - -### Loop with Carry - -```python -@pto.vkernel(...) -def prefix_sum(src: pto.ptr(pto.i32, MemorySpace.UB), - dst: pto.ptr(pto.i32, MemorySpace.UB)): - all_mask = pto.make_mask(pto.i32, PAT.ALL) - carry = all_mask - - for i in range(0, 256, 64): - vec = pto.vlds(src, i) - result, carry = pto.vaddcs(vec, vec, carry, all_mask) - pto.vsts(result, dst, i, all_mask) -``` - -## Common Errors - -### Typed Mask Mismatch - -``` -Error: f32 vector operation cannot consume mask_b16 -``` - -**Solution:** Ensure mask granularity matches vector element size: -- `f32` vectors use `mask_b32` -- `f16` vectors use `mask_b16` -- `i8` vectors use `mask_b8` - -### Strict Scope Implicit Capture - -``` -Error: strict_vecscope body cannot capture outer value 'ub_in' implicitly -``` - -**Solution:** Pass all required values in the capture list: - -```python -# Wrong: -with pto.strict_vecscope() as (): - vec = pto.vlds(ub_in, offset) # ub_in from outer scope - -# Correct: -with pto.strict_vecscope(ub_in) as (ub): - vec = pto.vlds(ub, offset) -``` - -### Untyped Loop Carried State - -``` -Error: loop-carried value must have explicit machine type -``` - -**Solution:** Add type annotations to loop-carried variables: - -```python -# Wrong: -remaining = 1024 # Plain Python int -for i in range(0, N, step): - mask, remaining = pto.make_mask(pto.f32, remaining) - -# Correct: -remaining: pto.i32 = 1024 -# or -remaining = pto.i32(1024) -``` - -## Compatibility Notes - -The current experimental implementation in `python/pto/dialects/pto.py` differs from this specification in several ways: - -1. **Mask types**: The experimental version uses untyped `mask` instead of `mask_b8`/`mask_b16`/`mask_b32` -2. **Barrier operation**: Uses `pto.barrier()` instead of `pto.pipe_barrier()` -3. **MemRef support**: Does not yet support `pto.memref()` types -4. **Operation coverage**: Implements only a subset of operations - -When implementing new code, follow this specification. The experimental implementation will be updated to match over time. - -## Next Steps - -- Explore the ISA documentation in `docs/isa/` for detailed operation semantics -- Check `test/samples/` for example kernels -- Refer to `docs/vpto-spec.md` for the underlying VPTO instruction specification - -For compiler developers, see `docs/PTO_IR_manual.md` for MLIR-level details. diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index c20942b75..5d7758f21 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -452,8 +452,62 @@ Typical examples: - `!pto.ptr` - `!pto.ptr` +### Tensor View Metadata Query Ops + +VPTO source programs may keep GM tensor operands in logical `!pto.tensor_view` +form instead of exposing them as raw memrefs. Two metadata-query ops are used to +read shape and stride information from that logical view: + +#### `pto.get_tensor_view_dim` + +- **syntax:** `%dim = pto.get_tensor_view_dim %tv, %idx : !pto.tensor_view<...> -> index` +- **semantics:** Returns the runtime extent of dimension `%idx` from the logical tensor view. + +```c +dim = tv.shape[idx]; +``` + +Example: + +```mlir +%d2 = pto.get_tensor_view_dim %src, %c2 : !pto.tensor_view -> index +``` + +#### `pto.get_tensor_view_stride` + +- **syntax:** `%stride = pto.get_tensor_view_stride %tv, %idx : !pto.tensor_view<...> -> index` +- **semantics:** Returns the logical stride of dimension `%idx`, measured in elements rather than bytes. + +```c +stride = tv.strides[idx]; +``` + +Example: + +```mlir +%s2 = pto.get_tensor_view_stride %src, %c2 : !pto.tensor_view -> index +``` + +Notes: + +- These ops are metadata queries only and do not trigger any hardware pipeline activity. +- In authoring-form IR, they operate on `!pto.tensor_view`. +- During compiler-internal lowering, they may be rewritten to equivalent memref metadata queries such as `memref.dim` and extracted strided metadata. + ### Pointer Operations +#### `pto.tensor_view_addr` + +- **syntax:** `%result = pto.tensor_view_addr %src : !pto.tensor_view<...> -> memref<...>` +- **syntax:** `%result = pto.tensor_view_addr %src : !pto.tensor_view<...> -> !pto.ptr` +- **semantics:** Extract the underlying address view from a `tensor_view` or `partition_tensor_view`. + +```c +result = addr_of(src); +``` + +`pto.tensor_view_addr` is an address-extraction operation. It does not move data and does not by itself imply any hardware side effect. When the result type is a memref, it exposes the lowered view directly. When the result type is `!pto.ptr<..., gm>`, it exposes the same address in pointer form. After compiler-internal view lowering, the operand may already be a memref; in that case the op is folded away or rewritten to an equivalent memref-to-ptr cast. + #### `pto.castptr` - **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index e471b0f96..1b4a9c43d 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -214,8 +214,9 @@ def PartitionViewOp : PTO_Op<"partition_view", [AttrSizedOperandSegments]> { } // Helper: tensor_view or memref (after lowering tensor_view to memref). -def TensorViewOrMemRef : - AnyTypeOf<[TensorViewType, AnyMemRef], "TensorView or MemRef">; +def TensorViewLikeOrMemRef : + AnyTypeOf<[TensorViewType, PartitionTensorViewType, AnyMemRef], + "TensorView, PartitionTensorView, or MemRef">; // Get the size of a dimension of a tensor_view or its lowered memref view. // Result type: Index (use arith.index_cast if i32 is needed). @@ -232,7 +233,27 @@ def GetTensorViewDimOp : PTO_Op<"get_tensor_view_dim", [Pure]> { : memref<...>, index -> index }]; let arguments = (ins - TensorViewOrMemRef:$tensor_view, + TensorViewLikeOrMemRef:$tensor_view, + Index:$dim_index + ); + let results = (outs Index:$result); + let assemblyFormat = [{ + $tensor_view `,` $dim_index `:` qualified(type($tensor_view)) `->` qualified(type($result)) + attr-dict + }]; +} + +// Get the logical stride of a tensor_view dimension in elements. +// Result type: Index (use arith.index_cast if i32 is needed). +def GetTensorViewStrideOp : PTO_Op<"get_tensor_view_stride", [Pure]> { + let summary = "Get the stride of a dimension of a tensor_view."; + let description = [{ + Returns the stride, measured in elements, of the given dimension of a + logical tensor view. This op accepts either !pto.tensor_view or the memref + it is lowered to. + }]; + let arguments = (ins + TensorViewLikeOrMemRef:$tensor_view, Index:$dim_index ); let results = (outs Index:$result); @@ -2197,25 +2218,6 @@ def SimdTileToMemrefOp : PTO_Op<"simd.tile_to_memref", [Pure]> { // After inline, FoldTileBufIntrinsics resolves them to concrete values. // --------------------------------------------------------------------------- -def TileBufAddrOp : PTO_Op<"tile_buf_addr", [Pure]> { - let summary = "Extract memref address from a tile_buf."; - let description = [{ - Returns a memref view of the data region of a `tile_buf` value. - The result memref has the same element type, shape (from tile_buf's static - shape), and address space as the source tile_buf, with row-major strides. - - This op is emitted by TileLang DSL templates and resolved by the - FoldTileBufIntrinsics pass after inlining. - }]; - - let arguments = (ins TileBufType:$src); - let results = (outs AnyMemRef:$dst); - - let assemblyFormat = [{ - $src attr-dict `:` qualified(type($src)) `->` type($dst) - }]; -} - def TileValidRowsOp : PTO_Op<"tile_valid_rows", [Pure]> { let summary = "Extract valid row count from a tile_buf."; let description = [{ diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index dc416e868..3daf2092d 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -34,6 +34,54 @@ def PTO_BufferType : Type< def PTO_BufferLikeType : AnyTypeOf<[AnyMemRef, PTO_BufferType], "memref or pointer-like buffer type">; +def TensorViewAddrOp : PTO_Op<"tensor_view_addr", [Pure]> { + let summary = "Extract address from a tensor view."; + let description = [{ + Returns the address view carried by a `tensor_view` or + `partition_tensor_view` value. The result may be either a memref view or a + typed PTO pointer, depending on the requested destination type. + + In authoring-form IR this op preserves the descriptor-style surface; + during view-to-memref lowering it collapses to the underlying memref value + or to a memref-derived pointer. + + This op may also accept a memref operand after earlier view lowering, in + which case it behaves as an identity marker and is removed by lowering. + }]; + + let arguments = (ins AnyTypeOf<[TensorViewType, PartitionTensorViewType, AnyMemRef], + "TensorViewLike or MemRef">:$src); + let results = (outs PtrOrMemRef:$dst); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($dst) + }]; +} + +def TileBufAddrOp : PTO_Op<"tile_buf_addr", [Pure]> { + let summary = "Extract address from a tile_buf."; + let description = [{ + Returns the address view of the data region of a `tile_buf` value. + The result may be either a memref view or a typed PTO pointer, depending + on the requested destination type. Memref results use the tile's static + shape and address space. + + This op is emitted by TileLang DSL templates and resolved by the + FoldTileBufIntrinsics pass after inlining. + }]; + + let arguments = (ins TileBufType:$src); + let results = (outs PtrOrMemRef:$dst); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($dst) + }]; +} + def VecScopeOp : PTO_Op<"vecscope", [SingleBlock, NoTerminator]> { let summary = "Structured region container for one VPTO vector scope"; let description = [{ @@ -775,14 +823,15 @@ def PTO_PintlvB32Op : PTO_PredicatePairReorderOp<"pintlv_b32", class PTO_VecScalarOp : PTO_Op { let arguments = (ins PTO_VectorType:$input, - AnyType:$scalar + AnyType:$scalar, + PTO_MaskTypeConstraint:$mask ); let results = (outs PTO_VectorType:$result); let hasVerifier = 1; let assemblyFormat = [{ - $input `,` $scalar attr-dict `:` type($input) `,` type($scalar) `->` type($result) + $input `,` $scalar `,` $mask attr-dict `:` type($input) `,` type($scalar) `,` type($mask) `->` type($result) }]; } diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 52ca26798..120401299 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -66,6 +66,8 @@ std::unique_ptr createInferPTOLayoutPass(); std::unique_ptr createPTOA5NormalizeTMovPass(); std::unique_ptr createPTOVPTOExpandBridgeOpsPass(); std::unique_ptr createPTOVPTOPtrBoundaryPass(); +std::unique_ptr createVPTOPtrNormalizePass(); +std::unique_ptr createVPTOPtrCastCleanupPass(); std::unique_ptr createPTOValidateVPTOIRPass(); std::unique_ptr createPTOValidateVPTOEmissionIRPass(); std::unique_ptr createLowerPTOToVPTOPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 4ac92e7d6..0ec3877af 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -283,8 +283,8 @@ def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::Fu whose tile_buf operands are now bound to concrete values. This pass resolves them: - - pto.tile_buf_addr → replaced by pto.simd.tile_to_memref (extracts - memref address from tile_buf) + - pto.tile_buf_addr → replaced by the underlying pto.bind_tile source + memref, or pto.castptr when the requested result type is !pto.ptr - pto.tile_valid_rows → folded to arith.constant if v_row is static, or replaced with the dynamic index value from tile_buf - pto.tile_valid_cols → same as above for v_col @@ -419,6 +419,38 @@ def PTOVPTOPtrBoundary "mlir::memref::MemRefDialect"]; } +def VPTOPtrNormalize + : Pass<"vpto-ptr-normalize", "ModuleOp"> { + let summary = + "Normalize VPTO ptr-like values and users into a uniform !pto.ptr form"; + let description = [{ + Uses MLIR's conversion framework to normalize VPTO ptr-related forms before + the existing VPTO ptr-boundary canonicalization runs. This pass rewrites + supported tile-buffer and memref view producers such as `pto.tile_buf_addr` + and `memref.subview`, and updates VPTO memory ops to consume the normalized + ptr-form consistently. + }]; + let constructor = "mlir::pto::createVPTOPtrNormalizePass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect"]; +} + +def VPTOPtrCastCleanup + : Pass<"vpto-ptr-cast-cleanup", "ModuleOp"> { + let summary = "Collapse transient ptr-memref-ptr bridge casts after VPTO ptr normalization"; + let description = [{ + Eliminates bridge chains such as + `!pto.ptr -> builtin.unrealized_conversion_cast -> memref.cast -> + builtin.unrealized_conversion_cast -> !pto.ptr` + when the outer ptr types already match. + }]; + let constructor = "mlir::pto::createVPTOPtrCastCleanupPass()"; + let dependentDialects = ["mlir::pto::PTODialect", + "mlir::memref::MemRefDialect"]; +} + def PTOToVPTO : Pass<"pto-to-vpto", "ModuleOp"> { let summary = "Lower PTO tile ops to VPTO backend ops"; let description = [{ diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 02351975d..1d8853675 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1563,6 +1563,88 @@ LogicalResult VdupOp::verify() { return success(); } +LogicalResult TensorViewAddrOp::verify() { + Type srcType = getSrc().getType(); + Type dstType = getDst().getType(); + + Type elementType; + int64_t expectedRank = -1; + auto gmSpace = pto::AddressSpaceAttr::get(getContext(), pto::AddressSpace::GM); + + if (auto tvType = dyn_cast(srcType)) { + elementType = tvType.getElementType(); + expectedRank = tvType.getRank(); + } else if (auto partType = dyn_cast(srcType)) { + elementType = partType.getElementType(); + expectedRank = partType.getRank(); + } else if (auto memrefType = dyn_cast(srcType)) { + elementType = memrefType.getElementType(); + expectedRank = memrefType.getRank(); + auto srcSpace = dyn_cast_or_null(memrefType.getMemorySpace()); + if (srcSpace && srcSpace != gmSpace) + return emitOpError("memref source must stay in gm memory space"); + } else { + return emitOpError( + "source must be a tensor_view, partition_tensor_view, or memref"); + } + + if (auto dstMemRefType = dyn_cast(dstType)) { + if (dstMemRefType.getElementType() != elementType) + return emitOpError( + "memref result element type must match source element type"); + if (dstMemRefType.getRank() != expectedRank) + return emitOpError("memref result rank must match source rank"); + auto dstSpace = + dyn_cast_or_null(dstMemRefType.getMemorySpace()); + if (dstSpace && dstSpace != gmSpace) + return emitOpError("memref result must stay in gm memory space"); + return success(); + } + + auto dstPtrType = dyn_cast(dstType); + if (!dstPtrType) + return emitOpError("result must be a memref or !pto.ptr<...>"); + if (dstPtrType.getElementType() != elementType) + return emitOpError( + "pointer result element type must match source element type"); + if (dstPtrType.getMemorySpace() != gmSpace) + return emitOpError("pointer result must stay in gm memory space"); + return success(); +} + +LogicalResult TileBufAddrOp::verify() { + auto srcType = dyn_cast(getSrc().getType()); + if (!srcType) + return emitOpError("source must be a !pto.tile_buf<...>"); + + Type dstType = getDst().getType(); + Type elementType = srcType.getElementType(); + auto srcSpace = dyn_cast_or_null(srcType.getMemorySpace()); + + if (auto dstMemRefType = dyn_cast(dstType)) { + if (dstMemRefType.getElementType() != elementType) + return emitOpError( + "memref result element type must match tile element type"); + if (dstMemRefType.getRank() != static_cast(srcType.getShape().size())) + return emitOpError("memref result rank must match tile rank"); + auto dstSpace = + dyn_cast_or_null(dstMemRefType.getMemorySpace()); + if (srcSpace && dstSpace && srcSpace != dstSpace) + return emitOpError("memref result must stay within the tile memory space"); + return success(); + } + + auto dstPtrType = dyn_cast(dstType); + if (!dstPtrType) + return emitOpError("result must be a memref or !pto.ptr<...>"); + if (dstPtrType.getElementType() != elementType) + return emitOpError( + "pointer result element type must match tile element type"); + if (srcSpace && dstPtrType.getMemorySpace() != srcSpace) + return emitOpError("pointer result must stay within the tile memory space"); + return success(); +} + LogicalResult PsetB8Op::verify() { if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), "result type", "b8"))) diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 8891f2b8e..39637791b 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -15,6 +15,8 @@ add_mlir_dialect_library(PTOTransforms HIVMIntrinsicNaming.cpp VPTOLLVMEmitter.cpp VPTOLLVMEmitterHelper.cpp + VPTOPtrNormalize.cpp + VPTOPtrCastCleanup.cpp PTOVPTOExpandBridgeOps.cpp PTOVPTOPtrBoundary.cpp PTOToVPTO.cpp @@ -27,6 +29,7 @@ add_mlir_dialect_library(PTOTransforms MemrefToTileBuf.cpp ExpandTileOp.cpp FoldTileBufIntrinsics.cpp + PTOLowerToOpLibCalls.cpp PTOInstantiateAndInlineOpLib.cpp PTOToEmitC.cpp Utils.cpp diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 945d2cc2e..9257e4fa4 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -72,18 +72,26 @@ namespace { // ============================================================================ // OperandTypeInfo: captures the tile_buf type info for one operand. // ============================================================================ +enum class OperandKind { + Tile, + Scalar, +}; + struct OperandTypeInfo { + OperandKind kind = OperandKind::Tile; std::string dtype; SmallVector shape; int32_t blayout = 0; int32_t slayout = 0; int32_t fractal = 0; int32_t pad = 0; + std::string memorySpace; bool operator==(const OperandTypeInfo &rhs) const { - return dtype == rhs.dtype && shape == rhs.shape && + return kind == rhs.kind && dtype == rhs.dtype && shape == rhs.shape && blayout == rhs.blayout && slayout == rhs.slayout && - fractal == rhs.fractal && pad == rhs.pad; + fractal == rhs.fractal && pad == rhs.pad && + memorySpace == rhs.memorySpace; } }; @@ -105,8 +113,10 @@ struct SpecKeyInfo : public llvm::DenseMapInfo { static unsigned getHashValue(const SpecKey &key) { unsigned h = llvm::hash_value(key.opName); for (const auto &op : key.operands) { - h = llvm::hash_combine(h, op.dtype, op.blayout, op.slayout, + h = llvm::hash_combine(h, static_cast(op.kind), op.dtype, + op.blayout, op.slayout, op.fractal, op.pad); + h = llvm::hash_combine(h, op.memorySpace); for (int64_t d : op.shape) h = llvm::hash_combine(h, d); } @@ -121,9 +131,11 @@ struct SpecKeyInfo : public llvm::DenseMapInfo { // Helpers // ============================================================================ static std::string getDtypeString(Type elemTy) { + if (elemTy.isInteger(1)) return "i1"; if (elemTy.isF32()) return "f32"; if (elemTy.isF16()) return "f16"; if (elemTy.isBF16()) return "bf16"; + if (elemTy.isSignlessInteger(64)) return "i64"; if (elemTy.isSignlessInteger(32)) return "i32"; if (elemTy.isSignlessInteger(16)) return "i16"; if (elemTy.isSignlessInteger(8)) return "i8"; @@ -144,10 +156,12 @@ static std::string getMemorySpaceString(pto::TileBufType tbTy) { static std::optional buildOperandTypeInfo(pto::TileBufType tbTy) { OperandTypeInfo info; + info.kind = OperandKind::Tile; info.dtype = getDtypeString(tbTy.getElementType()); if (info.dtype.empty()) return std::nullopt; info.shape.assign(tbTy.getShape().begin(), tbTy.getShape().end()); + info.memorySpace = getMemorySpaceString(tbTy); if (auto config = tbTy.getConfigAttr()) { info.blayout = static_cast(config.getBLayout().getValue()); info.slayout = static_cast(config.getSLayout().getValue()); @@ -159,15 +173,24 @@ buildOperandTypeInfo(pto::TileBufType tbTy) { return info; } +static std::optional buildOperandTypeInfo(Type ty) { + if (auto tbTy = dyn_cast(ty)) + return buildOperandTypeInfo(tbTy); + + OperandTypeInfo info; + info.kind = OperandKind::Scalar; + info.dtype = getDtypeString(ty); + if (info.dtype.empty()) + return std::nullopt; + return info; +} + static std::optional buildSpecKey(Operation *op) { SpecKey key; key.opName = getTileOpName(op).str(); for (unsigned i = 0; i < op->getNumOperands(); ++i) { - auto tbTy = dyn_cast(op->getOperand(i).getType()); - if (!tbTy) - return std::nullopt; - auto info = buildOperandTypeInfo(tbTy); + auto info = buildOperandTypeInfo(op->getOperand(i).getType()); if (!info) return std::nullopt; key.operands.push_back(*info); @@ -206,6 +229,34 @@ struct ExpandTileOpPass void runOnOperation() override; }; +static std::string buildOperandSpecsJson(const SpecKey &key) { + std::string json = "["; + for (size_t i = 0; i < key.operands.size(); ++i) { + const auto &op = key.operands[i]; + if (i > 0) + json += ","; + if (op.kind == OperandKind::Tile) { + json += "{\"kind\":\"tile\",\"dtype\":\""; + json += op.dtype; + json += "\",\"shape\":["; + for (size_t dim = 0; dim < op.shape.size(); ++dim) { + if (dim > 0) + json += ","; + json += std::to_string(op.shape[dim]); + } + json += "],\"memory_space\":\""; + json += op.memorySpace; + json += "\"}"; + continue; + } + json += "{\"kind\":\"scalar\",\"dtype\":\""; + json += op.dtype; + json += "\"}"; + } + json += "]"; + return json; +} + // ============================================================================ // Invoke Python DSL helper to generate a specialized template function. // ============================================================================ @@ -224,18 +275,8 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, return nullptr; } - // 2. Build shape string from the first operand (e.g. "16,64"). - // TODO: extend expand_helper to accept per-operand shapes if needed. - const auto &firstOp = key.operands[0]; - std::string shapeStr; - for (unsigned i = 0; i < firstOp.shape.size(); ++i) { - if (i > 0) shapeStr += ","; - shapeStr += std::to_string(firstOp.shape[i]); - } - - // Get memory space from the first tile_buf operand. - auto firstTbTy = dyn_cast(tileOp->getOperand(0).getType()); - std::string memSpace = firstTbTy ? getMemorySpaceString(firstTbTy) : "ub"; + // 2. Build operand schema JSON for mixed tile/scalar specialization. + std::string operandSpecsJson = buildOperandSpecsJson(key); // 3. Create temp file for stdout redirect. SmallString<128> tmpPath; @@ -254,9 +295,7 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, *pythonPath, "-m", "tilelang_dsl.expand_helper", "--template-dir", tilelangPath, "--op", opName, - "--dtype", firstOp.dtype, - "--shape", shapeStr, - "--memory-space", memSpace, + "--operand-specs", operandSpecsJson, }; // 5. Set up environment with PYTHONPATH. @@ -293,8 +332,26 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, redirects, /*secondsToWait=*/30, /*memoryLimit=*/0, &errMsg); if (rc != 0) { + std::string cmd; + llvm::raw_string_ostream os(cmd); + bool first = true; + auto appendToken = [&](StringRef token) { + if (!first) + os << ' '; + first = false; + llvm::sys::printArg(os, token, /*Quote=*/true); + }; + if (hasPythonPath) { + appendToken("env"); + appendToken(pythonPathEnv); + } + for (StringRef arg : args) + appendToken(arg); + os.flush(); + llvm::errs() << "ExpandTileOp: tilelang DSL helper failed (rc=" << rc << "): " << errMsg << "\n"; + llvm::errs() << "ExpandTileOp: run: " << cmd << "\n"; llvm::sys::fs::remove(tmpPath); return nullptr; } @@ -338,6 +395,7 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, // Build a unique name from all operand types. std::string uniqueName = "__pto_tilelang_" + key.opName; for (const auto &op : key.operands) { + uniqueName += op.kind == OperandKind::Tile ? "_tile" : "_scalar"; uniqueName += "_" + op.dtype; for (int64_t d : op.shape) uniqueName += "_" + std::to_string(d); @@ -377,17 +435,18 @@ LogicalResult ExpandState::expandTileOpsInFunction(func::FuncOp func, for (auto *op : tileOps) { auto specKeyOpt = buildSpecKey(op); if (!specKeyOpt) { - op->emitWarning("ExpandTileOp: cannot build specialization key, skipping"); - continue; + op->emitError( + "ExpandTileOp: cannot build specialization key for this operand schema"); + return failure(); } // Invoke tilelang DSL (with caching). func::FuncOp dslFn = invokeTilelangDSL(*specKeyOpt, op, mod, ctx); if (!dslFn) { StringRef opName = getTileOpName(op); - op->emitWarning("ExpandTileOp: no tilelang template for " + opName + - ", skipping"); - continue; + op->emitError("ExpandTileOp: failed to instantiate tilelang template for " + + opName); + return failure(); } // Replace tile op with func.call, passing tile_buf operands directly. @@ -408,6 +467,10 @@ void ExpandTileOpPass::runOnOperation() { MLIRContext *ctx = &getContext(); if (tilelangPath.empty()) { + mod.emitError( + "ExpandTileOp requires a non-empty tilelang-path when " + "--enable-tile-op-expand is set"); + signalPassFailure(); return; } diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp index dd5764d66..a25a05624 100644 --- a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -38,59 +38,30 @@ namespace pto { namespace { -/// Compute the row-major strided memref type for a tile_buf. -static MemRefType computeBridgeMemrefType(pto::TileBufType tbTy, - MLIRContext *ctx) { - ArrayRef shape = tbTy.getShape(); - ArrayRef validShape = tbTy.getValidShape(); - - SmallVector memrefDims; - for (unsigned d = 0; d < shape.size(); ++d) { - if (d < validShape.size() && validShape[d] != ShapedType::kDynamic) - memrefDims.push_back(validShape[d]); - else - memrefDims.push_back(ShapedType::kDynamic); +/// Locate the `pto.bind_tile` op that produced `tileBuf`, expecting the +/// strict pattern emitted by MemrefToTileBuf: +/// +/// %bound = pto.bind_tile %src, %vrow, %vcol : memref -> memref +/// %tile = builtin.unrealized_conversion_cast %bound : memref -> !pto.tile_buf +/// +/// Returns nullptr (with an error emitted on `loc`) if the pattern does not +/// hold — the caller is expected to signal pass failure. +static pto::BindTileOp findBindTileForTileBuf(Value tileBuf, Operation *user) { + auto cast = tileBuf.getDefiningOp(); + if (!cast || cast.getNumOperands() != 1) { + user->emitError( + "FoldTileBufIntrinsics: expected tile_buf to be defined by a " + "single-operand builtin.unrealized_conversion_cast"); + return nullptr; } - - SmallVector strides(shape.size(), 1); - for (int s = static_cast(shape.size()) - 2; s >= 0; --s) - strides[s] = strides[s + 1] * shape[s + 1]; - - auto stridedLayout = StridedLayoutAttr::get(ctx, /*offset=*/0, strides); - return MemRefType::get(memrefDims, tbTy.getElementType(), stridedLayout, - tbTy.getMemorySpace()); -} - -/// Try to find the dynamic valid_row index from the tile_buf's defining op -/// chain (e.g. pto.bind_tile carries optional valid_row/valid_col operands). -static Value findDynamicValidRow(Value tileBuf) { - Value cur = tileBuf; - while (cur) { - if (auto bindOp = cur.getDefiningOp()) { - if (bindOp.getValidRow()) - return bindOp.getValidRow(); - // bind_tile may chain — trace further through its source. - cur = bindOp.getSource(); - continue; - } - break; - } - return nullptr; -} - -/// Try to find the dynamic valid_col index from the tile_buf's defining op. -static Value findDynamicValidCol(Value tileBuf) { - Value cur = tileBuf; - while (cur) { - if (auto bindOp = cur.getDefiningOp()) { - if (bindOp.getValidCol()) - return bindOp.getValidCol(); - cur = bindOp.getSource(); - continue; - } - break; + auto bindOp = cast.getOperand(0).getDefiningOp(); + if (!bindOp) { + user->emitError( + "FoldTileBufIntrinsics: expected unrealized_conversion_cast operand " + "to be defined by pto.bind_tile"); + return nullptr; } - return nullptr; + return bindOp; } struct FoldTileBufIntrinsicsPass @@ -102,6 +73,13 @@ struct FoldTileBufIntrinsicsPass MLIRContext *ctx = &getContext(); OpBuilder builder(ctx); + // Leftover TileLang template instances (private, uncalled after + // PTOInlineLibCall) still contain pto.tile_buf_addr / tile_valid_* + // ops on tile_buf function arguments — they have no bind_tile to + // fold against and will be removed by later DCE. Skip them. + if (func->hasAttr("pto.tilelang.instance")) + return; + SmallVector addrOps; SmallVector rowsOps; SmallVector colsOps; @@ -115,30 +93,51 @@ struct FoldTileBufIntrinsicsPass colsOps.push_back(cols); }); - // Fold pto.tile_buf_addr → pto.simd.tile_to_memref. + // Fold pto.tile_buf_addr → bind_tile's source memref (the static-layout + // pto.pointer_cast result), or further to pto.castptr when the requested + // result type is already !pto.ptr<...>. This bypasses the dynamic-offset + // memref produced by bind_tile itself, so downstream vlds/vsts + // canonicalization sees a clean strided<[..],offset:0> layout. for (auto addrOp : addrOps) { - builder.setInsertionPoint(addrOp); - auto tbTy = dyn_cast(addrOp.getSrc().getType()); - if (!tbTy) { - addrOp.emitError("tile_buf_addr source is not tile_buf"); + pto::BindTileOp bindOp = findBindTileForTileBuf(addrOp.getSrc(), addrOp); + if (!bindOp) + return signalPassFailure(); + + Value srcMemref = bindOp.getSource(); + if (!isa(srcMemref.getType())) { + addrOp.emitError( + "FoldTileBufIntrinsics: pto.bind_tile source is not a memref"); return signalPassFailure(); } - MemRefType bridgeMemref = computeBridgeMemrefType(tbTy, ctx); - auto bridge = builder.create( - addrOp.getLoc(), bridgeMemref, addrOp.getSrc()); + if (auto resultMemrefType = dyn_cast(addrOp.getDst().getType())) { + // The declared tile_buf_addr result type may differ from the actual + // bind_tile source layout (e.g. plain shape vs. strided layout) — the + // downstream vector ops are polymorphic over strided layouts of the + // same element type and shape, so retype the result in place. + if (srcMemref.getType() != resultMemrefType) + addrOp.getDst().setType(cast(srcMemref.getType())); + addrOp.getDst().replaceAllUsesWith(srcMemref); + addrOp.erase(); + continue; + } - Value result = bridge.getDst(); - if (result.getType() != addrOp.getDst().getType()) { - result = builder.create( - addrOp.getLoc(), addrOp.getDst().getType(), result); + auto resultPtrType = dyn_cast(addrOp.getDst().getType()); + if (!resultPtrType) { + addrOp.emitError( + "FoldTileBufIntrinsics: tile_buf_addr result must be memref or !pto.ptr"); + return signalPassFailure(); } - addrOp.getDst().replaceAllUsesWith(result); + builder.setInsertionPoint(addrOp); + Value replacement = + builder.create(addrOp.getLoc(), resultPtrType, srcMemref); + addrOp.getDst().replaceAllUsesWith(replacement); addrOp.erase(); } - // Fold pto.tile_valid_rows → arith.constant or dynamic index. + // Fold pto.tile_valid_rows → arith.constant (static) or bind_tile's + // valid_row operand (dynamic). for (auto rowsOp : rowsOps) { builder.setInsertionPoint(rowsOp); auto tbTy = dyn_cast(rowsOp.getSrc().getType()); @@ -153,19 +152,28 @@ struct FoldTileBufIntrinsicsPass replacement = builder.create(rowsOp.getLoc(), vRow); } else { - replacement = findDynamicValidRow(rowsOp.getSrc()); + pto::BindTileOp bindOp = + findBindTileForTileBuf(rowsOp.getSrc(), rowsOp); + if (!bindOp) + return signalPassFailure(); + replacement = bindOp.getValidRow(); if (!replacement) { rowsOp.emitError( - "tile_valid_rows: dynamic v_row but cannot find runtime value " - "(expected pto.bind_tile with valid_row operand)"); + "tile_valid_rows: dynamic v_row but bind_tile has no " + "valid_row operand"); return signalPassFailure(); } + // bind_tile's valid_row is `index` (matches tile_valid_rows result), + // so no type adaptation is required. + assert(replacement.getType() == rowsOp.getResult().getType() && + "tile_valid_rows fold: type mismatch with bind_tile valid_row"); } rowsOp.getResult().replaceAllUsesWith(replacement); rowsOp.erase(); } - // Fold pto.tile_valid_cols → arith.constant or dynamic index. + // Fold pto.tile_valid_cols → arith.constant (static) or bind_tile's + // valid_col operand (dynamic). for (auto colsOp : colsOps) { builder.setInsertionPoint(colsOp); auto tbTy = dyn_cast(colsOp.getSrc().getType()); @@ -180,13 +188,19 @@ struct FoldTileBufIntrinsicsPass replacement = builder.create(colsOp.getLoc(), vCol); } else { - replacement = findDynamicValidCol(colsOp.getSrc()); + pto::BindTileOp bindOp = + findBindTileForTileBuf(colsOp.getSrc(), colsOp); + if (!bindOp) + return signalPassFailure(); + replacement = bindOp.getValidCol(); if (!replacement) { colsOp.emitError( - "tile_valid_cols: dynamic v_col but cannot find runtime value " - "(expected pto.bind_tile with valid_col operand)"); + "tile_valid_cols: dynamic v_col but bind_tile has no " + "valid_col operand"); return signalPassFailure(); } + assert(replacement.getType() == colsOp.getResult().getType() && + "tile_valid_cols fold: type mismatch with bind_tile valid_col"); } colsOp.getResult().replaceAllUsesWith(replacement); colsOp.erase(); diff --git a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp index b08d5f23b..2973510e9 100644 --- a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp +++ b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp @@ -1,5 +1,6 @@ #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" +#include "PTOLowerToOpLibCalls.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -35,12 +36,20 @@ static bool isInstanceFunc(func::FuncOp fn) { return fn->hasAttr(kOpLibAttrInstVariantId); } -static bool isTilelangFunc(func::FuncOp fn) { - return fn->hasAttr("pto.tilelang.instance"); +static bool isTilelangInlineProcFunc(func::FuncOp fn) { + return fn->hasAttr("pto.tilelang.inline_proc"); +} + +static bool isTilelangTemplateFunc(func::FuncOp fn) { + return fn->hasAttr("pto.tilelang.instance") && fn.isPrivate(); } static bool isInlineableLibFunc(func::FuncOp fn) { - return isInstanceFunc(fn) || isTilelangFunc(fn); + // Keep OP-Lib behavior unchanged while force-inlining TileLang helpers + // (inline_proc + private template helper). + if (isInstanceFunc(fn) || isTilelangInlineProcFunc(fn)) + return true; + return isTilelangTemplateFunc(fn); } static Value maybeUnwrapCastToExpected(Value operand, Type expectedType) { @@ -116,14 +125,17 @@ static void eraseDeadBridgeCasts(func::FuncOp func) { } static LogicalResult inlineCall(func::CallOp call, func::FuncOp callee) { - if (call.getNumResults() != 0) - return call.emitOpError("OP-Lib inline expects call without results"); if (callee.isExternal()) return call.emitOpError("callee must have a body before inlining"); Block &entry = callee.getBody().front(); if (entry.getNumArguments() != call.getNumOperands()) return call.emitOpError("callee argument count mismatch during inlining"); + auto returnOp = dyn_cast(entry.getTerminator()); + if (!returnOp) + return call.emitOpError("callee must terminate with func.return"); + if (returnOp.getNumOperands() != call.getNumResults()) + return call.emitOpError("callee return/result arity mismatch during inlining"); OpBuilder builder(call); IRMapping mapping; @@ -132,12 +144,27 @@ static LogicalResult inlineCall(func::CallOp call, func::FuncOp callee) { mapping.map(arg, operand); for (Operation &op : entry.without_terminator()) { + FailureOr handledOr = + pto::tryCloneOpLibInlineBridgeOp(builder, op, mapping); + if (failed(handledOr)) + return call.emitOpError("failed to remap OP-Lib inline bridge op"); + if (*handledOr) + continue; + Operation *newOp = cloneOpForInlineWithFix(builder, op, mapping); for (auto [oldRes, newRes] : llvm::zip(op.getResults(), newOp->getResults())) mapping.map(oldRes, newRes); } + for (auto [callResult, returnOperand] : + llvm::zip(call.getResults(), returnOp.getOperands())) { + Value mapped = mapping.lookupOrNull(returnOperand); + if (!mapped) + mapped = returnOperand; + callResult.replaceAllUsesWith(mapped); + } + call.erase(); return success(); } @@ -156,7 +183,7 @@ struct PTOInlineLibCallPass for (func::FuncOp func : module.getOps()) { if (func.isExternal()) continue; - if (isInlineableLibFunc(func)) + if (isInstanceFunc(func)) continue; if (func.empty()) continue; @@ -209,6 +236,14 @@ struct PTOInlineLibCallPass OpBuilder builder(call); auto newCall = builder.create(call.getLoc(), callee, concreteOperands); + if (call.getNumResults() != newCall.getNumResults()) { + call.emitOpError("call result arity mismatch during inline staging"); + signalPassFailure(); + return; + } + for (auto [oldResult, newResult] : + llvm::zip(call.getResults(), newCall.getResults())) + oldResult.replaceAllUsesWith(newResult); call.erase(); if (failed(inlineCall(newCall, callee))) { @@ -235,6 +270,24 @@ struct PTOInlineLibCallPass llvm::errs() << "[op-fusion] inline-libcall touched " << touchedFuncs << " function(s), inlined " << inlinedCalls << " call(s)\n"; } + + // Drop now-dead inline-able callees (private + uncalled) so downstream + // backends never see leftover template/instance bodies. This is needed + // for TileLang templates whose tile_buf-typed parameters cannot be + // legalized once their callers have been inlined. + SymbolTable symbolTable(module); + SmallVector deadFuncs; + for (func::FuncOp func : module.getOps()) { + if (!isInlineableLibFunc(func)) + continue; + if (func.isPublic()) + continue; + auto uses = symbolTable.getSymbolUses(func, module); + if (uses && uses->empty()) + deadFuncs.push_back(func); + } + for (func::FuncOp func : deadFuncs) + func.erase(); } }; diff --git a/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp b/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp new file mode 100644 index 000000000..93eb585f1 --- /dev/null +++ b/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp @@ -0,0 +1,209 @@ +#include "PTO/IR/PTO.h" + +#include "PTOLowerToOpLibCalls.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; + +namespace { + +static int64_t getElemBytes(Type elemTy) { + if (auto intTy = dyn_cast(elemTy)) + return (intTy.getWidth() + 7) / 8; + if (auto floatTy = dyn_cast(elemTy)) + return (floatTy.getWidth() + 7) / 8; + return -1; +} + +static bool readBLayoutI32(Attribute attr, int32_t &out) { + if (auto intAttr = dyn_cast(attr)) { + out = static_cast(intAttr.getInt()); + return true; + } + return false; +} + +static bool readSLayoutI32(Attribute attr, int32_t &out) { + if (auto intAttr = dyn_cast(attr)) { + out = static_cast(intAttr.getInt()); + return true; + } + return false; +} + +static FailureOr inferSimdBridgeMemRefType(pto::TileBufType tileTy, + MLIRContext *ctx) { + if (tileTy.getRank() != 2) + return failure(); + + ArrayRef physicalShape = tileTy.getShape(); + if (physicalShape.size() != 2) + return failure(); + if (physicalShape[0] == ShapedType::kDynamic || + physicalShape[1] == ShapedType::kDynamic) + return failure(); + + SmallVector memShape(physicalShape.begin(), physicalShape.end()); + ArrayRef validShape = tileTy.getValidShape(); + if (validShape.size() == memShape.size()) { + for (unsigned i = 0; i < validShape.size(); ++i) + memShape[i] = validShape[i] < 0 ? physicalShape[i] : validShape[i]; + } + + auto cfg = tileTy.getConfigAttr(); + if (!cfg) + cfg = pto::TileBufConfigAttr::getDefault(ctx); + + int32_t bl = 0; + int32_t sl = 0; + int32_t fr = 512; + (void)readBLayoutI32(cfg.getBLayout(), bl); + (void)readSLayoutI32(cfg.getSLayout(), sl); + if (auto attr = dyn_cast(cfg.getSFractalSize())) + fr = static_cast(attr.getInt()); + + int64_t innerRows = 1; + int64_t innerCols = 1; + if (sl != 0) { + int64_t elemBytes = getElemBytes(tileTy.getElementType()); + if (elemBytes <= 0) + return failure(); + if (fr == 1024) { + innerRows = 16; + innerCols = 16; + } else if (fr == 32) { + innerRows = 16; + innerCols = 2; + } else if (fr == 512) { + if (sl == 1) { + innerRows = 16; + innerCols = 32 / elemBytes; + } else if (sl == 2) { + innerRows = 32 / elemBytes; + innerCols = 16; + } else { + return failure(); + } + } else { + return failure(); + } + } + + SmallVector strides; + if (sl == 0) { + if (bl == 1) { + strides.push_back(1); + strides.push_back(physicalShape[0]); + } else { + strides.push_back(physicalShape[1]); + strides.push_back(1); + } + } else if (bl == 1) { + if (sl != 1) + return failure(); + strides.push_back(innerCols); + strides.push_back(physicalShape[0]); + } else { + strides.push_back(physicalShape[1]); + strides.push_back(innerRows); + } + + auto layout = StridedLayoutAttr::get(ctx, /*offset=*/0, strides); + return MemRefType::get(memShape, tileTy.getElementType(), layout, + tileTy.getMemorySpace()); +} + +static bool areIntegerCarrierTypesCompatible(Type lhs, Type rhs) { + auto lhsInt = dyn_cast(lhs); + auto rhsInt = dyn_cast(rhs); + if (!lhsInt || !rhsInt) + return false; + return lhsInt.getWidth() == rhsInt.getWidth(); +} + +static bool canRemapSimdBridgeViaCarrierCast(MemRefType actualTy, + MemRefType templateTy) { + if (actualTy.getRank() != templateTy.getRank()) + return false; + if (actualTy.getMemorySpace() != templateTy.getMemorySpace()) + return false; + return areIntegerCarrierTypesCompatible(actualTy.getElementType(), + templateTy.getElementType()); +} + +static MemRefType remapMemRefToTemplateCarrier(MemRefType actualTy, + MemRefType templateTy) { + return MemRefType::get(actualTy.getShape(), templateTy.getElementType(), + actualTy.getLayout(), actualTy.getMemorySpace()); +} + +} // namespace + +FailureOr mlir::pto::tryCloneOpLibInlineBridgeOp(OpBuilder &builder, + Operation &op, + IRMapping &mapping) { + if (auto bridge = dyn_cast(&op)) { + Value mappedSrc = mapping.lookupOrNull(bridge.getSrc()); + if (!mappedSrc) + return failure(); + + auto templateMemTy = dyn_cast(bridge.getDst().getType()); + if (auto mappedTileTy = dyn_cast(mappedSrc.getType())) { + FailureOr inferredTyOr = + inferSimdBridgeMemRefType(mappedTileTy, builder.getContext()); + if (failed(inferredTyOr)) + return failure(); + + auto inferredTy = *inferredTyOr; + auto newBridge = builder.create( + bridge.getLoc(), inferredTy, mappedSrc); + if (templateMemTy && inferredTy != templateMemTy && + canRemapSimdBridgeViaCarrierCast(inferredTy, templateMemTy)) { + auto carrierTy = remapMemRefToTemplateCarrier(inferredTy, templateMemTy); + auto cast = builder.create( + bridge.getLoc(), TypeRange{carrierTy}, ValueRange{newBridge.getDst()}); + mapping.map(bridge.getDst(), cast.getResult(0)); + } else { + mapping.map(bridge.getDst(), newBridge.getDst()); + } + return true; + } + + auto mappedMemTy = dyn_cast(mappedSrc.getType()); + auto dstMemTy = templateMemTy; + if (!mappedMemTy || !dstMemTy) + return failure(); + if (mappedMemTy.getRank() != dstMemTy.getRank()) + return failure(); + + auto newBridge = builder.create( + bridge.getLoc(), mappedMemTy, mappedSrc); + if (mappedMemTy.getElementType() == dstMemTy.getElementType()) { + mapping.map(bridge.getDst(), newBridge.getDst()); + return true; + } + if (!canRemapSimdBridgeViaCarrierCast(mappedMemTy, dstMemTy)) + return failure(); + auto carrierTy = remapMemRefToTemplateCarrier(mappedMemTy, dstMemTy); + auto cast = builder.create( + bridge.getLoc(), TypeRange{carrierTy}, ValueRange{newBridge.getDst()}); + mapping.map(bridge.getDst(), cast.getResult(0)); + return true; + } + + if (auto cast = dyn_cast(&op)) { + if (cast->getNumOperands() != 1 || cast->getNumResults() != 1) + return failure(); + + Value mappedSrc = mapping.lookupOrNull(cast.getOperand(0)); + if (!mappedSrc) + return failure(); + + mapping.map(cast.getResult(0), mappedSrc); + return true; + } + + return false; +} diff --git a/lib/PTO/Transforms/PTOLowerToOpLibCalls.h b/lib/PTO/Transforms/PTOLowerToOpLibCalls.h new file mode 100644 index 000000000..50b7d93b8 --- /dev/null +++ b/lib/PTO/Transforms/PTOLowerToOpLibCalls.h @@ -0,0 +1,17 @@ +#ifndef PTO_TRANSFORMS_PTOLOWERTOOPLIBCALLS_H +#define PTO_TRANSFORMS_PTOLOWERTOOPLIBCALLS_H + +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace pto { + +FailureOr tryCloneOpLibInlineBridgeOp(OpBuilder &builder, Operation &op, + IRMapping &mapping); + +} // namespace pto +} // namespace mlir + +#endif // PTO_TRANSFORMS_PTOLOWERTOOPLIBCALLS_H diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 0b80849b1..7a7869e1d 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1488,7 +1488,50 @@ struct PTOViewToMemrefPass } // ------------------------------------------------------------------ - // Stage 1.5: Fold pto.addptr chains into load/store_scalar. + // Stage 1.5: Lower pto.get_tensor_view_stride -> strided memref metadata + // ------------------------------------------------------------------ + SmallVector tvStrides; + func.walk([&](mlir::pto::GetTensorViewStrideOp op) { tvStrides.push_back(op); }); + + for (auto op : tvStrides) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; // leave it to later passes if it hasn't been lowered yet + + int64_t dimIndex = 0; + if (!getConstIndexValue(op.getDimIndex(), dimIndex)) { + op.emitError("get_tensor_view_stride currently expects a constant dim index"); + signalPassFailure(); + return; + } + if (dimIndex < 0 || dimIndex >= mrTy.getRank()) { + op.emitError("get_tensor_view_stride dim index is out of bounds"); + signalPassFailure(); + return; + } + + SmallVector staticStrides; + int64_t offset = ShapedType::kDynamic; + if (succeeded(getStridesAndOffset(mrTy, staticStrides, offset)) && + dimIndex < (int64_t)staticStrides.size() && + staticStrides[dimIndex] != ShapedType::kDynamic) { + rewriter.replaceOpWithNewOp( + op, staticStrides[dimIndex]); + continue; + } + + auto metadata = + rewriter.create(loc, view); + rewriter.replaceOp(op, metadata.getStrides()[dimIndex]); + } + + // ------------------------------------------------------------------ + // Stage 1.6: Fold pto.addptr chains into load/store_scalar. // ------------------------------------------------------------------ SmallVector loadScalars; func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); @@ -2548,12 +2591,15 @@ struct PTOViewToMemrefPass signalPassFailure(); return; } - rewriter.replaceOpWithNewOp( - op, + auto attrs = op->getAttrs(); + auto newOp = rewriter.create( + op.getLoc(), TypeRange{}, src, scale, dst); + newOp->setAttrs(attrs); + rewriter.replaceOp(op, newOp->getResults()); } SmallVector expandsops; diff --git a/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp b/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp new file mode 100644 index 000000000..e0537b3ea --- /dev/null +++ b/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp @@ -0,0 +1,73 @@ +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTOPTRCASTCLEANUP +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +struct CollapsePtrMemRefPtrBridgePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + + auto resultPtrType = dyn_cast(op.getResult(0).getType()); + if (!resultPtrType) + return failure(); + + auto castOp = op.getOperand(0).getDefiningOp(); + if (!castOp || castOp->getNumOperands() != 1) + return failure(); + + auto innerCast = + castOp.getSource().getDefiningOp(); + if (!innerCast || innerCast->getNumOperands() != 1 || + innerCast->getNumResults() != 1) + return failure(); + + Value basePtr = innerCast.getOperand(0); + if (basePtr.getType() != resultPtrType) + return failure(); + + rewriter.replaceOp(op, basePtr); + if (castOp->use_empty()) + rewriter.eraseOp(castOp); + if (innerCast->use_empty()) + rewriter.eraseOp(innerCast); + return success(); + } +}; + +struct VPTOPtrCastCleanupPass + : public pto::impl::VPTOPtrCastCleanupBase { + using pto::impl::VPTOPtrCastCleanupBase< + VPTOPtrCastCleanupPass>::VPTOPtrCastCleanupBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTOPtrCastCleanupPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VPTOPtrNormalize.cpp b/lib/PTO/Transforms/VPTOPtrNormalize.cpp new file mode 100644 index 000000000..7230db089 --- /dev/null +++ b/lib/PTO/Transforms/VPTOPtrNormalize.cpp @@ -0,0 +1,418 @@ +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTOPTRNORMALIZE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static pto::AddressSpaceAttr getPointerMemorySpace(Attribute memorySpace, + MLIRContext *ctx) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return pto::AddressSpaceAttr::get( + ctx, static_cast(intAttr.getInt())); + return {}; +} + +static Value buildIndexValue(OpBuilder &builder, Location loc, + OpFoldResult ofr) { + if (auto value = dyn_cast(ofr)) + return value; + auto attr = cast(cast(ofr)); + return builder.create(loc, attr.getInt()); +} + +static bool needsSubviewPtrConversion(memref::SubViewOp op) { + auto resultType = dyn_cast(op.getType()); + if (!resultType) + return false; + return static_cast( + getPointerMemorySpace(resultType.getMemorySpace(), op.getContext())); +} + +static Type convertSubviewResultType(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) + return type; + + auto memorySpace = + getPointerMemorySpace(memrefType.getMemorySpace(), type.getContext()); + if (!memorySpace) + return type; + + return pto::PtrType::get(type.getContext(), memrefType.getElementType(), + memorySpace); +} + +static bool hasPtrNormalizeConvertibleType(Type type) { + if (isa(type)) + return true; + auto memrefType = dyn_cast(type); + return memrefType && + static_cast( + getPointerMemorySpace(memrefType.getMemorySpace(), type.getContext())); +} + +static bool hasPtrNormalizeConvertibleType(TypeRange types) { + return llvm::any_of(types, [](Type type) { + return hasPtrNormalizeConvertibleType(type); + }); +} + +static Value materializePtrCast(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + if (inputs.size() != 1 || !isa(resultType)) + return {}; + + Value input = inputs.front(); + if (input.getType() == resultType) + return input; + + auto inputMemrefType = dyn_cast(input.getType()); + auto resultPtrType = dyn_cast(resultType); + if (!inputMemrefType || !resultPtrType) + return {}; + + auto memorySpace = getPointerMemorySpace(inputMemrefType.getMemorySpace(), + builder.getContext()); + if (!memorySpace || memorySpace != resultPtrType.getMemorySpace() || + inputMemrefType.getElementType() != resultPtrType.getElementType()) + return {}; + + return builder.create(loc, resultPtrType, input); +} + +static Value materializeUnrealizedCast(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); +} + +static LogicalResult computeSubviewElementOffset(memref::SubViewOp op, + PatternRewriter &rewriter, + Value &offset) { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return failure(); + + SmallVector strides; + int64_t baseOffset = 0; + if (failed(getStridesAndOffset(sourceType, strides, baseOffset))) + return failure(); + // The SSA source already names the base address after bind_tile/pointer_cast + // normalization. A dynamic memref layout offset here is metadata we can + // ignore for ptr normalization and model as zero. + if (baseOffset == ShapedType::kDynamic) + baseOffset = 0; + + Location loc = op.getLoc(); + Value total = rewriter.create(loc, baseOffset); + ArrayRef mixedOffsets = op.getMixedOffsets(); + if (mixedOffsets.size() != strides.size()) + return failure(); + + for (auto [ofr, stride] : llvm::zip(mixedOffsets, strides)) { + if (stride == 0) + continue; + if (stride == ShapedType::kDynamic) + return failure(); + + Value idx = buildIndexValue(rewriter, loc, ofr); + if (!idx.getType().isIndex()) + return failure(); + + if (stride != 1) { + Value strideValue = + rewriter.create(loc, stride); + idx = rewriter.create(loc, idx, strideValue); + } + total = rewriter.create(loc, total, idx); + } + + offset = total; + return success(); +} + +static Value materializeSubviewInputPtr(Value source, PatternRewriter &rewriter, + Location loc) { + if (!source) + return {}; + if (isa(source.getType())) + return source; + + auto memrefType = dyn_cast(source.getType()); + if (!memrefType) + return {}; + + auto memorySpace = + getPointerMemorySpace(memrefType.getMemorySpace(), rewriter.getContext()); + if (!memorySpace) + return {}; + + auto ptrType = pto::PtrType::get(rewriter.getContext(), + memrefType.getElementType(), memorySpace); + return rewriter.create(loc, ptrType, source); +} + +struct ConvertTileBufAddrToPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::TileBufAddrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedType = getTypeConverter()->convertType(op.getDst().getType()); + if (!isa(convertedType)) + return failure(); + + rewriter.replaceOpWithNewOp(op, convertedType, + adaptor.getSrc()); + return success(); + } +}; + +struct ConvertPointerCastToCastPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); + auto ptrType = dyn_cast(convertedType); + if (!ptrType) + return failure(); + + if (adaptor.getAddrs().empty()) + return rewriter.notifyMatchFailure(op, "expected at least one address"); + + rewriter.replaceOpWithNewOp(op, ptrType, + adaptor.getAddrs().front()); + return success(); + } +}; + +struct ConvertBindTileToPtrPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); + auto ptrType = dyn_cast(convertedType); + if (!ptrType) + return failure(); + + Value ptr = + materializeSubviewInputPtr(adaptor.getSource(), rewriter, op.getLoc()); + if (!ptr) + return rewriter.notifyMatchFailure(op, + "failed to materialize bind_tile input ptr"); + + if (ptr.getType() != ptrType) + ptr = rewriter.create(op.getLoc(), ptrType, ptr); + + rewriter.replaceOp(op, ptr); + return success(); + } +}; + +struct ConvertSubviewToAddPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!needsSubviewPtrConversion(op)) + return failure(); + + auto ptrType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + if (!ptrType) + return rewriter.notifyMatchFailure(op, "expected ptr result type"); + + Value basePtr = + materializeSubviewInputPtr(adaptor.getSource(), rewriter, op.getLoc()); + if (!basePtr) + return rewriter.notifyMatchFailure(op, + "failed to materialize subview input ptr"); + + Value offset; + if (failed(computeSubviewElementOffset(op, rewriter, offset))) + return rewriter.notifyMatchFailure(op, + "failed to compute subview element offset"); + + rewriter.replaceOpWithNewOp(op, ptrType, basePtr, offset); + return success(); + } +}; + +struct ConvertVldsSubviewOperandPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::VldsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(adaptor.getSource().getType())) + return failure(); + + OperationState state(op.getLoc(), op->getName().getStringRef()); + state.addOperands({adaptor.getSource(), adaptor.getOffset()}); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +struct ConvertVstsSubviewOperandPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::VstsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(adaptor.getDestination().getType())) + return failure(); + + OperationState state(op.getLoc(), op->getName().getStringRef()); + state.addOperands( + {adaptor.getValue(), adaptor.getDestination(), adaptor.getOffset(), + adaptor.getMask()}); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +struct ConvertPtrNormalizeUnrealizedCastOp final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasPtrNormalizeConvertibleType(op->getOperandTypes()) && + !hasPtrNormalizeConvertibleType(op->getResultTypes())) + return failure(); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() != convertedResultType) + return failure(); + + rewriter.replaceOp(op, input); + return success(); + } +}; + +struct VPTOPtrNormalizePass + : public pto::impl::VPTOPtrNormalizeBase { + using pto::impl::VPTOPtrNormalizeBase< + VPTOPtrNormalizePass>::VPTOPtrNormalizeBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *context = module.getContext(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [](Type type) { return convertSubviewResultType(type); }); + typeConverter.addTargetMaterialization(materializeUnrealizedCast); + typeConverter.addSourceMaterialization(materializeUnrealizedCast); + typeConverter.addArgumentMaterialization(materializeUnrealizedCast); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addDynamicallyLegalDialect([](Operation *op) { + return !isa(op); + }); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](UnrealizedConversionCastOp op) { + return !hasPtrNormalizeConvertibleType(op->getOperandTypes()) && + !hasPtrNormalizeConvertibleType(op->getResultTypes()); + }); + target.addDynamicallyLegalOp([&](pto::TileBufAddrOp op) { + return op.getDst().getType() == + typeConverter.convertType(op.getDst().getType()); + }); + target.addDynamicallyLegalOp( + [&](pto::PointerCastOp op) { + return op.getResult().getType() == + typeConverter.convertType(op.getResult().getType()); + }); + target.addDynamicallyLegalOp([&](pto::BindTileOp op) { + return op.getResult().getType() == + typeConverter.convertType(op.getResult().getType()); + }); + target.addDynamicallyLegalOp( + [](pto::VldsOp op) { return isa(op.getSource().getType()); }); + target.addDynamicallyLegalOp([](pto::VstsOp op) { + return isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp( + [](memref::SubViewOp op) { return !needsSubviewPtrConversion(op); }); + + RewritePatternSet patterns(context); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + patterns.add( + typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTOPtrNormalizePass() { + return std::make_unique(); +} diff --git a/lib/TileOps/render_template_mlir.py b/lib/TileOps/render_template_mlir.py new file mode 100644 index 000000000..42ce40552 --- /dev/null +++ b/lib/TileOps/render_template_mlir.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +"""Materialize a TileLang DSL library template to authoring-form MLIR. + +Examples: + python3 lib/TileOps/render_template_mlir.py lib/TileOps/tload_template.py + python3 lib/TileOps/render_template_mlir.py lib/TileOps/tadd_template.py --tile dst=8x64@ub --tile src0=8x64@ub --tile src1=8x64@ub + python3 lib/TileOps/render_template_mlir.py lib/TileOps/tload_template.py --dtypes f16,f16 -o /tmp/tload.mlir +""" + +from __future__ import annotations + +import argparse +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +REPO_ROOT = Path(__file__).resolve().parents[2] +TILELANG_PYTHON_DIR = REPO_ROOT / "tilelang-dsl" / "python" +if str(TILELANG_PYTHON_DIR) not in sys.path: + sys.path.insert(0, str(TILELANG_PYTHON_DIR)) + +import tilelang_dsl as pto + + +_DTYPE_BY_NAME = { + "i1": pto.i1, + "i8": pto.i8, + "i16": pto.i16, + "i32": pto.i32, + "i64": pto.i64, + "f16": pto.f16, + "bf16": pto.bf16, + "f32": pto.f32, +} +_MEMORY_SPACE_BY_NAME = { + "gm": pto.MemorySpace.GM, + "ub": pto.MemorySpace.UB, +} + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Load a TileLang DSL template file and emit its corresponding MLIR text.", + ) + parser.add_argument("template", help="Path to the template Python file") + parser.add_argument( + "--kernel", + help="Descriptor symbol name inside the module when the file defines multiple @pto.vkernel templates", + ) + parser.add_argument( + "--op", + help="Concrete op to bind when the descriptor matches multiple ops; defaults to the first match op", + ) + parser.add_argument( + "--dtypes", + help="Concrete operand dtypes as a comma-separated list, for example: f32,f32 or f16,f16,f16", + ) + parser.add_argument( + "--tile", + action="append", + default=[], + metavar="PARAM=SHAPE[@SPACE][:VALID]", + help=( + "Tile specialization override, for example: dst=16x32@ub or " + "dst=16x32@ub:8x32. May be repeated." + ), + ) + parser.add_argument( + "--default-tile-shape", + default="16x32", + help="Default shape for every bare Tile parameter when no --tile override is given", + ) + parser.add_argument( + "--default-tile-space", + default="ub", + choices=sorted(_MEMORY_SPACE_BY_NAME), + help="Default memory space for every bare Tile parameter", + ) + parser.add_argument( + "-o", + "--output", + help="Optional output path; defaults to stdout", + ) + return parser.parse_args() + + +def _load_module(template_path: Path) -> ModuleType: + module_name = f"_tileops_template_{template_path.stem}" + spec = importlib.util.spec_from_file_location(module_name, template_path) + if spec is None or spec.loader is None: + raise ValueError(f"failed to load Python module from {template_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _find_descriptors(module: ModuleType) -> dict[str, pto.VKernelDescriptor]: + descriptors: dict[str, pto.VKernelDescriptor] = {} + for name, value in vars(module).items(): + if isinstance(value, pto.VKernelDescriptor): + descriptors[name] = value + return descriptors + + +def _select_descriptor( + descriptors: dict[str, pto.VKernelDescriptor], + kernel_name: str | None, +) -> tuple[str, pto.VKernelDescriptor]: + if not descriptors: + raise ValueError("no @pto.vkernel descriptor found in the template module") + if kernel_name is not None: + descriptor = descriptors.get(kernel_name) + if descriptor is None: + available = ", ".join(sorted(descriptors)) + raise ValueError( + f"kernel {kernel_name!r} was not found in the template module; available descriptors: {available}" + ) + return kernel_name, descriptor + if len(descriptors) == 1: + return next(iter(descriptors.items())) + available = ", ".join(sorted(descriptors)) + raise ValueError( + "the template module defines multiple @pto.vkernel descriptors; " + f"please pass --kernel. Available descriptors: {available}" + ) + + +def _parse_dtype_list(text: str) -> tuple[pto.ScalarType, ...]: + parts = [part.strip() for part in text.split(",") if part.strip()] + if not parts: + raise ValueError("--dtypes must contain at least one dtype") + try: + return tuple(_DTYPE_BY_NAME[part] for part in parts) + except KeyError as exc: + available = ", ".join(sorted(_DTYPE_BY_NAME)) + raise ValueError( + f"unsupported dtype {exc.args[0]!r}; available dtypes: {available}" + ) from exc + + +def _default_concrete_dtype(pattern: object) -> pto.ScalarType: + if isinstance(pattern, pto.ScalarType): + return pattern + if isinstance(pattern, pto.WildcardType): + if pattern.name in {"AnyType", "AnyFloat"}: + return pto.f32 + if pattern.name == "AnyInt": + return pto.i32 + if pattern.name == "AnyMask": + return pto.i1 + raise ValueError(f"unsupported wildcard dtype pattern {pattern!r}") + if isinstance(pattern, pto.TypeVariable): + return pto.f32 + raise ValueError(f"unsupported dtype pattern {pattern!r}") + + +def _default_operand_types(descriptor: pto.VKernelDescriptor) -> tuple[pto.ScalarType, ...]: + if not descriptor.dtypes: + raise ValueError("descriptor does not declare any dtype signatures") + prototype = descriptor.dtypes[0] + typevar_bindings: dict[str, pto.ScalarType] = {} + concrete: list[pto.ScalarType] = [] + for pattern in prototype: + if isinstance(pattern, pto.TypeVariable): + bound = typevar_bindings.get(pattern.name) + if bound is None: + bound = pto.f32 + typevar_bindings[pattern.name] = bound + concrete.append(bound) + continue + concrete.append(_default_concrete_dtype(pattern)) + return tuple(concrete) + + +def _bind_descriptor( + descriptor: pto.VKernelDescriptor, + *, + op_name: str | None, + operand_types: tuple[pto.ScalarType, ...] | None, +) -> pto.VKernelDescriptor: + concrete_op = op_name + if concrete_op is None: + if descriptor.selected_op is not None: + concrete_op = descriptor.selected_op + elif len(descriptor.match_ops) == 1: + concrete_op = descriptor.match_ops[0] + else: + available = ", ".join(descriptor.match_ops) + raise ValueError( + f"descriptor matches multiple ops; pass --op. Available ops: {available}" + ) + + concrete_operand_types = operand_types + if concrete_operand_types is None: + if descriptor._selected_dtype_signature is not None: + concrete_operand_types = descriptor._selected_dtype_signature + else: + concrete_operand_types = _default_operand_types(descriptor) + + registry = pto.KernelRegistry((descriptor,)) + return pto.select_kernel( + target=descriptor.target, + op=concrete_op, + operand_types=concrete_operand_types, + registry=registry, + ) + + +def _parse_shape(text: str) -> tuple[int, ...]: + dims = [] + for part in text.split("x"): + part = part.strip() + if not part: + raise ValueError(f"invalid shape {text!r}") + value = int(part) + if value <= 0: + raise ValueError(f"shape dimensions must be positive integers, got {text!r}") + dims.append(value) + if not dims: + raise ValueError(f"invalid shape {text!r}") + return tuple(dims) + + +def _parse_tile_override(spec_text: str) -> tuple[str, pto.TileSpecialization]: + if "=" not in spec_text: + raise ValueError( + f"invalid --tile value {spec_text!r}; expected PARAM=SHAPE[@SPACE][:VALID]" + ) + param_name, payload = spec_text.split("=", 1) + param_name = param_name.strip() + payload = payload.strip() + if not param_name: + raise ValueError(f"invalid --tile value {spec_text!r}; missing parameter name") + + valid_shape = None + if ":" in payload: + payload, valid_text = payload.split(":", 1) + valid_shape = _parse_shape(valid_text.strip()) + + memory_space = pto.MemorySpace.UB + if "@" in payload: + shape_text, memory_space_text = payload.split("@", 1) + memory_space_key = memory_space_text.strip().lower() + try: + memory_space = _MEMORY_SPACE_BY_NAME[memory_space_key] + except KeyError as exc: + available = ", ".join(sorted(_MEMORY_SPACE_BY_NAME)) + raise ValueError( + f"unsupported memory space {memory_space_text!r}; available spaces: {available}" + ) from exc + else: + shape_text = payload + + shape = _parse_shape(shape_text.strip()) + if valid_shape is not None and len(valid_shape) != len(shape): + raise ValueError( + f"valid_shape rank {len(valid_shape)} does not match shape rank {len(shape)} for {param_name!r}" + ) + return ( + param_name, + pto.TileSpecialization( + shape=shape, + memory_space=memory_space, + valid_shape=valid_shape, + ), + ) + + +def _default_tile_specialization( + *, + shape: tuple[int, ...], + memory_space: pto.MemorySpace, +) -> pto.TileSpecialization: + return pto.TileSpecialization(shape=shape, memory_space=memory_space) + + +def _specialize_tiles( + descriptor: pto.VKernelDescriptor, + *, + tile_overrides: dict[str, pto.TileSpecialization], + default_shape: tuple[int, ...], + default_memory_space: pto.MemorySpace, +) -> pto.VKernelDescriptor: + if not descriptor.tile_parameters: + return descriptor + + specializations: dict[str, pto.TileSpecialization] = {} + for param in descriptor.tile_parameters: + specializations[param.name] = tile_overrides.get( + param.name, + _default_tile_specialization( + shape=default_shape, + memory_space=default_memory_space, + ), + ) + return descriptor.specialize(**specializations) + + +def _emit_output(text: str, output_path: str | None) -> None: + if output_path is None: + sys.stdout.write(text) + if not text.endswith("\n"): + sys.stdout.write("\n") + return + path = Path(output_path) + path.write_text(text, encoding="utf-8") + + +def main() -> int: + args = _parse_args() + template_path = Path(args.template).resolve() + if not template_path.is_file(): + print(f"error: template file not found: {template_path}", file=sys.stderr) + return 1 + + try: + module = _load_module(template_path) + _, descriptor = _select_descriptor(_find_descriptors(module), args.kernel) + operand_types = None if args.dtypes is None else _parse_dtype_list(args.dtypes) + bound = _bind_descriptor( + descriptor, + op_name=args.op, + operand_types=operand_types, + ) + tile_overrides = dict(_parse_tile_override(spec_text) for spec_text in args.tile) + specialized = _specialize_tiles( + bound, + tile_overrides=tile_overrides, + default_shape=_parse_shape(args.default_tile_shape), + default_memory_space=_MEMORY_SPACE_BY_NAME[args.default_tile_space], + ) + _emit_output(specialized.mlir_text(), args.output) + return 0 + except Exception as exc: + print(f"error: {exc}", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/test/tilelang_templates/tadd_template.py b/lib/TileOps/tadd_template.py similarity index 56% rename from test/tilelang_templates/tadd_template.py rename to lib/TileOps/tadd_template.py index be561220e..ecab2be5b 100644 --- a/test/tilelang_templates/tadd_template.py +++ b/lib/TileOps/tadd_template.py @@ -1,23 +1,15 @@ -"""TileLang DSL template for pto.tadd — used by ExpandTileOp tests.""" +"""TileLang DSL template for pto.tadd""" import sys from pathlib import Path - -_repo = Path(__file__).resolve().parents[2] -_pkg = _repo / "tilelang-dsl" / "python" -if str(_pkg) not in sys.path: - sys.path.insert(0, str(_pkg)) - import tilelang_dsl as pto @pto.vkernel( - op="pto.tadd", - dtypes=[(pto.f32, pto.f32, pto.f32)], - advanced=True, - name="template_tadd", + target="a5", + op="pto.tadd" ) -def template_tadd(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): +def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): dtype = dst.element_type valid_rows, valid_cols = dst.valid_shape @@ -29,4 +21,4 @@ def template_tadd(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): rhs = pto.vlds(src1[row, col:]) summed = pto.vadd(lhs, rhs, mask) pto.vsts(summed, dst[row, col:], mask) - return None + return diff --git a/lib/TileOps/tadds_template.py b/lib/TileOps/tadds_template.py new file mode 100644 index 000000000..fd36e1f3d --- /dev/null +++ b/lib/TileOps/tadds_template.py @@ -0,0 +1,23 @@ +"""TileLang DSL template for pto.tadds""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tadds", +) +def template_tadds(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vadds(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tload_template.py b/lib/TileOps/tload_template.py new file mode 100644 index 000000000..a12ef6675 --- /dev/null +++ b/lib/TileOps/tload_template.py @@ -0,0 +1,91 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""`pto.tload` 的 TileLang DSL 模板""" + +import tilelang_dsl as pto + + +def _tload_preconditions(src, dst) -> bool: + logical_rows = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[3] + logical_cols = src.shape[4] + return ( + src.rank == 5 + and src.strides[4] == 1 + and dst.valid_shape[0] <= logical_rows + and dst.valid_shape[1] <= logical_cols + and logical_rows <= dst.shape[0] + and logical_cols <= dst.shape[1] + and dst.valid_shape[0] <= dst.shape[0] + and dst.valid_shape[1] <= dst.shape[1] + ) + + +@pto.vkernel( + target="a5", + op="pto.tload", + advanced=True, + constraints=[_tload_preconditions], +) +def template_tload(src: pto.PartitionTensorView, dst: pto.Tile): + dtype = dst.element_type + elem_bytes = pto.bytewidth(dtype) + + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + + valid_rows, valid_cols = dst.valid_shape + ub_rows, ub_cols = dst.shape + + # These preconditions are expressed through the descriptor-level constraint + # callable above, using direct `src.shape[i]` / `dst.shape[i]` syntax. + + n_burst = g3 + len_burst = g4 * elem_bytes + gm_stride = s3 * elem_bytes + ub_stride = ub_cols * elem_bytes + + dst_stride2 = g3 * ub_cols + dst_stride1 = g2 * dst_stride2 + dst_stride0 = g1 * dst_stride1 + + loop1 = g2 + loop2 = g1 + loop1_src_stride = s2 * elem_bytes + loop1_dst_stride = dst_stride2 * elem_bytes + loop2_src_stride = s1 * elem_bytes + loop2_dst_stride = dst_stride1 * elem_bytes + + gm_ptr = src.as_ptr() + ub_ptr = dst.as_ptr() + + if loop1 != 1 or loop2 != 1: + pto.set_loop2_stride_outtoub( + src_stride=loop2_src_stride, dst_stride=loop2_dst_stride + ) + pto.set_loop1_stride_outtoub( + src_stride=loop1_src_stride, dst_stride=loop1_dst_stride + ) + pto.set_loop_size_outtoub(loop1=loop1, loop2=loop2) + + for i in range(0, g0, 1): + src_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + dst_i = pto.addptr(ub_ptr, i * dst_stride0 * elem_bytes) + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) + + if loop1 != 1 or loop2 != 1: + pto.set_loop_size_outtoub(loop1=1, loop2=1) + return diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py new file mode 100644 index 000000000..a8d6cc4f6 --- /dev/null +++ b/lib/TileOps/tstore_template.py @@ -0,0 +1,88 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""`pto.tstore` 的 TileLang DSL 模板""" + +import tilelang_dsl as pto + + +def _tstore_preconditions(src, dst) -> bool: + logical_rows = dst.shape[0] * dst.shape[1] * dst.shape[2] * dst.shape[3] + logical_cols = dst.shape[4] + return ( + dst.rank == 5 + and dst.strides[4] == 1 + and src.valid_shape[0] == logical_rows + and src.valid_shape[1] == logical_cols + and src.valid_shape[0] <= src.shape[0] + and src.valid_shape[1] <= src.shape[1] + ) + + +@pto.vkernel( + target="a5", + op="pto.tstore", + advanced=True, + constraints=[_tstore_preconditions], +) +def template_tstore(src: pto.Tile, dst: pto.PartitionTensorView): + dtype = src.element_type + elem_bytes = pto.bytewidth(dtype) + + g0, g1, g2, g3, g4 = dst.shape + s0, s1, s2, s3, s4 = dst.strides + + valid_rows, valid_cols = src.valid_shape + ub_rows, ub_cols = src.shape + + # These preconditions are expressed through the descriptor-level constraint + # callable above, using direct `src.*` / `dst.*` metadata syntax. + + n_burst = g3 + len_burst = valid_cols * elem_bytes + ub_stride = ub_cols * elem_bytes + gm_stride = s3 * elem_bytes + + src_stride2 = g3 * ub_cols + src_stride1 = g2 * src_stride2 + src_stride0 = g1 * src_stride1 + + loop1 = g2 + loop2 = g1 + loop1_src_stride = src_stride2 * elem_bytes + loop1_dst_stride = s2 * elem_bytes + loop2_src_stride = src_stride1 * elem_bytes + loop2_dst_stride = s1 * elem_bytes + + ub_ptr = src.as_ptr() + gm_ptr = dst.as_ptr() + + if loop1 != 1 or loop2 != 1: + pto.set_loop2_stride_ubtoout( + src_stride=loop2_src_stride, dst_stride=loop2_dst_stride + ) + pto.set_loop1_stride_ubtoout( + src_stride=loop1_src_stride, dst_stride=loop1_dst_stride + ) + pto.set_loop_size_ubtoout(loop1=loop1, loop2=loop2) + + for i in range(0, g0, 1): + src_i = pto.addptr(ub_ptr, i * src_stride0 * elem_bytes) + dst_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + pto.copy_ubuf_to_gm( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + ) + + if loop1 != 1 or loop2 != 1: + pto.set_loop_size_ubtoout(loop1=1, loop2=1) + return diff --git a/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/.openspec.yaml b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/.openspec.yaml new file mode 100644 index 000000000..2fe001e67 --- /dev/null +++ b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-07 diff --git a/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/design.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/design.md new file mode 100644 index 000000000..31b7d32ec --- /dev/null +++ b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/design.md @@ -0,0 +1,273 @@ +## Context + +### 范围 + +本 design 覆盖两个相互配合的能力: + +1. 新增 `tilelang-dsl-template-slots` +2. 修改 `tilelang-dsl-kernel-matcher` + +目标是让多个同 family 的具体 PTO op 共享一份 TileLang DSL kernel body,同时保持 frontend 仍然是静态、受限、deterministic 的 DSL,而不是变成任意 Python 解释执行器。 + +### 当前状态 + +当前 `tilelang-dsl-kernel-matcher` 已提供: + +- `KernelRegistry` +- `pto.select_kernel(target, op, operand_types, context_attrs, registry=None)` +- 多 signature `dtypes` +- `constraints` +- `priority` + +但 descriptor 仍以单个 concrete `op` 为中心。 +与此同时,kernel body 中的 vector op surface 也仍要求直接书写真实 `pto.vadd`、`pto.vsub`、`pto.vmul`、`pto.vdiv` 等调用。 + +这使得 `tadd/tsub/tmul/tdiv` 这类共享同一 loop/mask/load-store 骨架、只在核心计算 op 上不同的 kernel 很难复用实现。 + +### 关键约束 + +- `tilelang-dsl/` 继续作为源码、测试、样例和局部文档的 source of truth。 +- `pto.select_kernel(...)` 的外部查询形态继续保持 concrete `op` 查询,不新增 family 级公共查询接口。 +- frontend 继续只接受受限 Python 子集和固定 DSL call surface;不能靠支持任意 dict/callable 来实现模板分发。 +- 新能力必须复用现有 semantic/type-check/lowering 路径,最终仍收敛到当前 authoring-form VPTO legality contract。 +- 多-op / template 相关行为必须 deterministic,不依赖注册顺序、定义顺序或运行时值。 + +### 设计拆分 + +本 change 明确拆成两层: + +1. matcher 层 + +- 负责让一个 descriptor 匹配多个 concrete PTO op,并在 `select_kernel(...)` 后绑定唯一 `selected_op` + +2. template slot 层 + +- 负责让 kernel body 使用统一的 `pto.tpl("slot", ...)` 占位调用 +- 在 frontend 编译阶段按 `selected_op` 把占位调用静态替换成真实 `pto.*` op + +这两层的拆分保证: + +- descriptor 选择仍由 matcher capability 管理 +- kernel body 模板化能力独立成一个新的 authoring capability +- semantic/type-check/lowering 不需要接入“动态模板调用”的新运行时概念 + +## Goals / Non-Goals + +**Goals:** + +- 允许 `@pto.vkernel` 用 `ops=[...]` 描述一个 descriptor 覆盖多个 concrete PTO op。 +- 提供统一模板入口 `pto.tpl("slot", ...)`,支持在 kernel body 的任意合法 DSL 位置表达“按 concrete op 替换”的核心计算。 +- 在 frontend 阶段完成模板替换,让后续 semantic/type-check/lowering 继续只面对真实 `pto.*` 调用。 +- 保证多-op descriptor 在未绑定 concrete `op` 时不会提前 materialize。 +- 保持现有单-op kernel、现有显式真实 `pto.*` 调用写法、现有 selector API 与现有 legality contract 全部兼容。 + +**Non-Goals:** + +- 不支持在 kernel body 中执行 Python dict lookup、callable value、lambda、闭包或其他 higher-order call。 +- 不把 `pto.tpl(...)` 设计成运行时 dispatch;它不是运行时 helper,而是编译期 placeholder。 +- 不把 family 变成 `select_kernel(...)` 的新公共查询轴。 +- 不在本 change 中一次性覆盖所有 family 的模板化 authoring;首版只提供通用 slot 机制和明确的合法性边界。 + +## Decisions + +### 1. 新增独立 capability `tilelang-dsl-template-slots`,而不是把模板语义塞进 `advanced` 或 `matcher` + +决策: + +- `pto.tpl("slot", ...)` 与 `templates={...}` 作为独立 capability 定义 +- `tilelang-dsl-kernel-matcher` 只负责 multi-op descriptor 与 concrete `selected_op` 绑定 +- `tilelang-dsl-advanced-surface` 不承载本次模板语义 + +原因: + +- 模板槽位本质上是 authoring sugar,不等于 advanced-family lowering。 +- 它既可以服务 `advanced=True` 的 kernel,也可以服务显式 `strict_vecscope` 的非-advanced kernel。 +- 把它独立成 capability 更容易保持语义聚焦,避免把 advanced surface 继续做胖。 + +备选方案: + +- 方案 A:把 `pto.tpl(...)` 放进 `tilelang-dsl-advanced-surface` + - 未采用,因为模板槽位不天然依赖 advanced mode。 +- 方案 B:直接修改 matcher spec,不新增 capability + - 未采用,因为 matcher 只负责“选哪个 descriptor”,不负责“kernel body 如何模板化表达”。 + +### 2. 模板映射放在 decorator 元数据中,而不是开放 kernel body 里的 Python dict/callable + +决策: + +- `@pto.vkernel` 新增 `templates={...}` 静态元数据 +- kernel body 只允许写 `pto.tpl("slot", ...)` +- 模板映射值只接受真实 `pto.*` op 名字符串,不接受 Python callable + +原因: + +- 当前 frontend 明确拒绝 arbitrary external call,也没有把 Python dict 作为正式 DSL 表达式收进 AST/semantic 契约。 +- 如果允许 `ops["core"](...)` 或 `table[name](...)` 这类写法,就必须给 DSL 增加 dict、callable value、索引后调用、甚至作用域捕获等一整套新语义,复杂度和风险都明显超出这次 change 目标。 +- 把映射放在 decorator 元数据里,可以把模板分发收敛为 compile-time static metadata,保持 deterministic。 + +备选方案: + +- 方案 A:在 kernel body 支持 Python dict literal + indexed callable + - 未采用,因为会显著扩张 DSL 的 Python 子集和 frontend 复杂度。 +- 方案 B:用 family-specific placeholder op,如 `pto.vbinary` + - 未采用,因为会引入一批专用 placeholder API,扩展到其他 family 时容易持续膨胀。 + +### 3. 公开 surface 采用单一通用入口 `pto.tpl("slot", ...)` + +决策: + +- 新公共 API 只有一个模板入口:`pto.tpl("slot", *args)` +- `slot` 必须是字符串字面量 +- `slot` 对应的真实 op 映射由 `templates={...}` 给出 + +原因: + +- 单一入口能避免 placeholder API 数量随 family 数量膨胀。 +- `slot` 命名由 kernel author 控制,能表达“core”“cmp”“select”“postprocess”等局部语义,而不是把 DSL 绑死到某一组预定义 family 名称。 +- 这仍然允许一份 kernel body 在多个位置复用同一 slot,或在同一 kernel 里声明多个 slot。 + +备选方案: + +- 方案 A:`pto.tpl.core(...)` + - 未采用,因为会额外引入 attribute-based placeholder namespace,收益不高。 +- 方案 B:`pto.tpl["core"](...)` + - 未采用,因为需要 DSL 支持 subscripted callable surface。 + +### 4. matcher 扩展为 `op` / `ops`,并在 selection 之后绑定 `selected_op` + +决策: + +- `@pto.vkernel` 接受: + - `op="tadd"`,保持现状 + - `ops=["tadd", "tsub", "tmul", "tdiv"]`,作为新增能力 +- `op` 与 `ops` 必须互斥,且至少提供其一 +- descriptor 内部统一保存 `match_ops` +- `pto.select_kernel(...)` 保持现有公共签名不变,仍使用 concrete `op` 查询 +- selector 命中 multi-op descriptor 时,必须把 query `op` 绑定成唯一 `selected_op` + +原因: + +- 对外仍按 concrete `op` 查询,能兼容现有上层集成和调用路径。 +- `selected_op` 把模板替换需要的 concrete 上下文显式传入后续 materialization 阶段,避免“后面再猜当前 op 是什么”。 +- `op/ops` 互斥可以避免 descriptor 元数据出现双重来源和优先级歧义。 + +备选方案: + +- 方案 A:把 `select_kernel(...)` 改成 family 查询 + - 未采用,因为会改动公共 API 语义,也与用户当前实际关心的 concrete PTO op 查询不一致。 +- 方案 B:`op` 和 `ops` 同时允许出现,由实现隐式 merge + - 未采用,因为容易制造含糊的匹配集合和隐藏优先级。 + +### 5. 模板替换发生在 frontend AST 构建阶段,先替换后做 semantic/type-check + +决策: + +- `build_frontend_kernel_node(...)` 在把 Python AST 投影成 frontend AST 时识别 `pto.tpl("slot", ...)` +- 若 descriptor 已绑定 `selected_op`,则直接把该调用重写成真实 `FrontendCallExpr(namespace="pto", name="", args=...)` +- semantic analyzer 和 lowering renderer 继续只消费真实 `pto.*` op + +原因: + +- 这样 semantic/type-check/lowering 不需要知道“模板调用”这个新概念,只需要沿用已有的真实 op 检查逻辑。 +- 错误可以更早暴露在 frontend,而不是等到后续阶段才发现某个模板槽位无法解析。 +- 替换点足够早,能让模板调用出现在循环、分支、inferred vecscope、strict_vecscope 内等任意合法位置,而不改变后续编译形态。 + +备选方案: + +- 方案 A:在 semantic 阶段再解析 `pto.tpl` + - 未采用,因为 semantic 需要额外承载 unresolved template call,增加中间状态复杂度。 +- 方案 B:在 lowering 阶段才解析 + - 未采用,因为会把本该 fail-fast 的模板错误延后到更晚阶段。 + +### 6. 多-op descriptor 的 materialization gate 与 polymorphic dtypes 一致 + +决策: + +- 若 descriptor 同时满足以下任一条件,则不得直接 `mlir_text()` / `verify()` / `emit()`: + - 还未绑定 concrete `dtype_signature` + - 还未绑定 concrete `selected_op` +- 只有经过 `select_kernel(...)` 绑定后,才能进入 `specialize()` 与 materialization 流程 + +原因: + +- template slot 替换依赖 `selected_op`;未绑定 concrete `op` 时无法确定最终真实 `pto.*` 调用。 +- 这与当前 polymorphic `dtypes` 的 gate 规则一致,用户心智模型也更统一。 + +备选方案: + +- 方案 A:对 multi-op descriptor 默认取 `ops[0]` 作为 materialization op + - 未采用,因为会引入不透明的隐式默认值,破坏 deterministic 语义。 + +### 7. 模板槽位必须做静态合法性约束,而不是“能替换就替换” + +决策: + +- 注册时必须校验: + - `templates` 是静态 mapping + - slot 名是非空字符串 + - 映射 key 是 descriptor 可匹配的 concrete op 子集 + - 映射 value 是受支持的真实 `pto.*` op 名字符串 +- frontend 替换时必须校验: + - `slot` 参数是字符串字面量 + - 当前 `selected_op` 在该 slot 下存在映射 + - 模板展开后的真实 op 属于当前 DSL 支持 surface + +原因: + +- 模板能力本质上是“静态展开”,因此错误必须在 frontend 明确、可重复地暴露。 +- 如果允许模糊或延迟解析,后续会把用户错误混淆成底层 type/lowering 错误。 + +备选方案: + +- 方案 A:只在模板实际被执行到时懒解析 + - 未采用,因为 kernel DSL 不是运行时解释器,不应存在执行路径依赖的语义。 + +## 测试策略 + +- matcher 正例: + - `ops=[...]` 的 descriptor 可被不同 concrete `op` 查询命中 + - single-op 与 multi-op descriptor 并存时仍按 `priority` / tie error 决策 +- materialization 正例/负例: + - multi-op descriptor 在未绑定 `selected_op` 前拒绝 materialization + - 绑定后可继续 `specialize()` / `mlir_text()` +- template slot 正例: + - 同一份 kernel body 在 `tadd/tsub/tmul/tdiv` 下分别展开成正确真实 op + - `pto.tpl("slot", ...)` 可位于 loop、if、strict_vecscope、inferred vecscope 中 +- template slot 负例: + - 未定义 slot + - slot 不覆盖当前 `selected_op` + - `slot` 不是字符串字面量 + - 映射到未知或不受支持的 `pto.*` op +- 文档/样例: + - 增加一个共享 `tadd/tsub/tmul/tdiv` 的 template-slot 示例 + - 在 guide 中明确“为什么不支持 kernel body Python dict/callable” + +## Risks / Trade-offs + +- [Risk] `op` / `ops` 双入口会增加 decorator 心智负担 + Mitigation:要求两者互斥,并保持 `op=` 旧写法完全兼容。 + +- [Risk] `pto.tpl("slot", ...)` 可能被误解为运行时 helper + Mitigation:spec、文档和诊断都明确它是 compile-time placeholder,不是 runtime dispatch。 + +- [Risk] 如果模板映射值过于自由,容易把不兼容调用形态混到一个 slot 里 + Mitigation:要求映射值必须是当前已支持的真实 `pto.*` op 名,并在 frontend 做静态合法性检查。 + +- [Risk] 多-op descriptor 若允许隐式默认 concrete `op`,会导致 materialization 结果不稳定 + Mitigation:未绑定 `selected_op` 时一律拒绝 materialization。 + +## Migration Plan + +- 该 change 为增量能力,不需要仓库级迁移或兼容层清理。 +- 现有单-op kernel 与显式真实 `pto.*` 调用保持原样可用。 +- 新功能按以下顺序接入: + 1. 扩展 descriptor/matcher 元数据为 `op` / `ops` + `selected_op` + 2. 引入 `templates={...}` 与 `pto.tpl("slot", ...)` + 3. 在 frontend AST 构建阶段接入模板替换 + 4. 补齐 unittest、example、guide 与 migration 文档 +- 若实现中发现模板替换无法在 frontend 阶段稳定落地,回退策略是保留 `ops=[...]` matcher 扩展,而暂不开放 `pto.tpl(...)` public surface。 + +## Open Questions + +- 当前没有必须阻塞实现的开放问题。 +- 若后续需要让模板槽位覆盖 compare/select 或 vector-scalar family,应沿用同一 `pto.tpl("slot", ...)` 机制扩展映射表,而不是新增另一套 placeholder API。 diff --git a/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/proposal.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/proposal.md new file mode 100644 index 000000000..99462a468 --- /dev/null +++ b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/proposal.md @@ -0,0 +1,90 @@ +# Proposal: 扩展 TileLang DSL 的模板槽位与多-op matcher + +## 概述 + +当前 TileLang DSL 的 matcher 仍以单个 `op` 为中心,kernel body 里也只能直接写具体的 `pto.vadd`、`pto.vsub`、`pto.vmul` 等真实 op。 +这让 `tadd/tsub/tmul/tdiv` 这类同一 family op 很难共享一份实现;一旦核心计算只在一两个 vector op 上不同,作者就只能复制多份几乎相同的 kernel body。 + +## 背景与动机 + +现有 `tilelang-dsl-kernel-matcher` capability 已经支持多 signature `dtypes`、`constraints` 和 `priority`,但 descriptor 仍然只匹配一个具体 `op`。 +与此同时,DSL authoring surface 仍要求用户在 kernel body 中显式写出真实 `pto.*` 调用;当前 frontend 既不支持在 kernel body 中执行任意 Python dict/callable,也没有“模板 op -> concrete op”的编译期替换能力。 + +这会带来两个直接问题: + +1. 同一 family op 的实现复用成本高 + +- 像 `tadd/tsub/tmul/tdiv` 这种共享同一 loop / mask / load-store 骨架的 kernel,作者必须复制多份函数体,只替换中间一条 `pto.v*` 调用。 + +2. 现有 DSL 没有正式的“模板化 authoring”契约 + +- 如果让用户直接在 kernel body 里写 Python 字典、索引和 callable 分发,就会把当前受限 DSL 推向“任意 Python 解释执行器”,破坏现有 frontend 的 deterministic 边界。 + +因此,需要一个新的正式 capability,把“同一份 kernel body 可被多个具体 op 复用”收敛为静态、可验证、可测试的 OpenSpec 契约。 + +## 目标 + +- 为 TileLang DSL 增加模板槽位 authoring capability,使用户可以在 kernel body 任意合法位置使用统一的模板调用,并在编译期被替换成当前 concrete `op` 对应的真实 `pto.*` op。 +- 扩展 `tilelang-dsl-kernel-matcher` capability,使一个 descriptor 可以通过 `ops=[...]` 覆盖多个具体 PTO op,同时在 `select_kernel(...)` 之后绑定唯一 concrete `op`。 +- 保持 TileLang DSL 继续是静态、受限、deterministic 的 frontend,不把 kernel body 扩展为任意 Python dict/callable 执行环境。 + +## 非目标 + +- 不支持在 kernel body 中直接执行 Python dict lookup、lambda、函数对象调用或其他 higher-order runtime dispatch。 +- 不新增运行时模板分发机制;模板替换只发生在编译期。 +- 不改变 `pto.select_kernel(target, op, operand_types, context_attrs, registry=None)` 的公共查询形态。 +- 不在本 change 中一次性引入一批 family-specific placeholder op,如 `pto.vbinary`、`pto.vbinarys`、`pto.vcmp_template`。 + +## 变更内容 + +- 新增 `tilelang-dsl-template-slots` capability,定义: + - `@pto.vkernel(..., templates={...})` + - 通用模板入口 `pto.tpl("slot", ...)` + - 基于当前 concrete `op` 的编译期静态替换 + - 模板槽位的 frontend diagnostics 与合法性边界 +- 修改 `tilelang-dsl-kernel-matcher` capability,允许 descriptor 通过 `ops=[...]` 匹配多个具体 op,并要求 selector 在返回 descriptor 前绑定 concrete `op`。 +- 要求模板替换继续复用现有 semantic/type-check/lowering 路径,最终仍输出当前合法的 authoring-form VPTO,而不是发明新的公开中间 IR。 + +## Capabilities + +### New Capabilities + +- `tilelang-dsl-template-slots`: 定义 `pto.tpl("slot", ...)`、`templates={...}` 静态映射、模板槽位的编译期替换规则,以及相关 frontend diagnostics 与 materialization 边界。 + +### Modified Capabilities + +- `tilelang-dsl-kernel-matcher`: 从单 `op` descriptor 扩展到 `op` / `ops` matcher 元数据,并要求 `select_kernel(...)` 为多-op descriptor 绑定唯一 concrete `op` 后再进入 materialization。 + +## 预期结果 + +- 用户可以为 `tadd/tsub/tmul/tdiv` 这类同 family op 注册一份共享 kernel body,而不再复制多份只差一条核心 vector op 的实现。 +- kernel body 里的模板调用在 frontend 阶段被静态展开成真实 `pto.*` 调用,后续 semantic/type-check/lowering 继续走现有路径。 +- TileLang DSL 仍然保持“受限 Python 子集 + 固定 DSL call surface”的边界,不因为引入模板能力而变成任意 Python 执行器。 + +## 成功标准 + +- OpenSpec 中新增 `tilelang-dsl-template-slots` capability,并明确: + - `templates={...}` 的静态结构 + - `pto.tpl("slot", ...)` 的编译期替换语义 + - 非法模板映射、未绑定 concrete `op`、未知 slot 等 frontend 失败路径 +- OpenSpec 中修改 `tilelang-dsl-kernel-matcher` capability,并明确: + - `ops=[...]` 与现有 `op=` 的关系 + - selector 对多-op descriptor 的 concrete `op` 绑定语义 + - 多-op descriptor 在未绑定 concrete `op` 时不得 materialize +- change 落地后,`tilelang-dsl/tests/` 能新增覆盖: + - 多-op matcher 正例 + - `pto.tpl("slot", ...)` 对 `tadd/tsub/tmul/tdiv` 的展开正例 + - 模板映射与 materialization 的负例诊断 + +## 影响 + +- 受影响目录: + - `tilelang-dsl/python/` + - `tilelang-dsl/tests/` + - `tilelang-dsl/examples/` + - `tilelang-dsl/docs/` + - `openspec/specs/` +- 受影响 public API: + - `@pto.vkernel(..., op=..., ops=..., templates=...)` + - `pto.select_kernel(...)` + - `pto.tpl("slot", ...)` diff --git a/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/specs/tilelang-dsl-kernel-matcher/spec.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/specs/tilelang-dsl-kernel-matcher/spec.md new file mode 100644 index 000000000..1e29dc76c --- /dev/null +++ b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/specs/tilelang-dsl-kernel-matcher/spec.md @@ -0,0 +1,85 @@ +## MODIFIED Requirements + +### Requirement: TileLang DSL MUST provide an explicit kernel registry and selection API + +当同一 `target/op` 下存在多个 `@pto.vkernel` descriptor 时,TileLang DSL MUST 将它们注册到显式、可查询的 `KernelRegistry`。 +默认 registry MUST 是 module-level 对象;调用方 MAY 传入自定义 registry 以获得隔离的候选集合。 +系统 MUST 提供显式 selection API `pto.select_kernel(target, op, operand_types, context_attrs, registry=None)`,用于在给定 `target`、concrete `op`、operand type 信息和上下文属性时选择唯一 kernel。 +descriptor MUST 支持两种互斥的 matcher 元数据: + +- `op=""` +- `ops=["", "", ...]` + +descriptor MUST 至少提供其中一种,且实现 MUST NOT 同时接受两者。 +当 selector 命中一个 `ops=[...]` descriptor 时,返回结果 MUST 绑定当前 query 对应的唯一 concrete `selected_op`,再进入后续 `specialize()` / `mlir_text()` / `verify()` 流程。 +实现 MUST NOT 依赖扫描 Python globals、locals 或导入顺序来隐式发现候选。 + +#### Scenario: selector returns the unique best kernel + +- **WHEN** registry 中存在多个针对同一 `target/op` 的 kernel descriptor,且其中一个在全部匹配步骤后成为唯一最佳候选 +- **THEN** `pto.select_kernel(...)` MUST 返回该 descriptor +- **AND** 返回结果 MUST 可继续走 `specialize()` / `mlir_text()` / `verify()` 流程 + +#### Scenario: custom registry restricts the candidate set explicitly + +- **WHEN** 调用方显式传入一个只含局部 kernel 的 `KernelRegistry` +- **THEN** selector MUST 只在该 registry 的候选集合内做匹配和决策 +- **AND** MUST NOT 回退去查询 module-level 默认 registry + +#### Scenario: selector binds the concrete op for a multi-op descriptor + +- **WHEN** 一个 descriptor 通过 `ops=["tadd", "tsub", "tmul", "tdiv"]` 注册,且调用方以 `pto.select_kernel(..., op="tmul", ...)` 查询命中该 descriptor +- **THEN** selector MUST 返回已经绑定 `selected_op="tmul"` 的 descriptor +- **AND** 后续 materialization MUST 基于该 concrete `selected_op` 而不是未绑定的原始 matcher 集合 + +### Requirement: selection order MUST be target -> op -> dtype signature -> constraints -> priority -> tie error + +对一个 registry 中的候选集合,selector MUST 按以下固定顺序求值: + +1. `target` +2. `op` +3. `dtypes` signature 的 concrete / wildcard / type-variable 匹配 +4. `constraints` +5. `priority` +6. highest-priority tie error + +其中第 2 步的 `op` 匹配 MUST 使用调用方给出的 concrete query `op`,并按以下规则求值: + +- 对 `op=""` descriptor,要求 exact match +- 对 `ops=[...]` descriptor,要求 query `op` 属于该 matcher 集合 + +实现 MUST 保持该顺序 deterministic。 +系统 MUST NOT 依赖注册顺序、定义顺序、导入顺序或其他隐式规则来打破同一阶段的歧义。 +系统 MUST NOT 因为候选是 single-op descriptor 或 multi-op descriptor 而引入额外隐式优先级。 + +#### Scenario: type match happens before constraints and priority + +- **WHEN** 一个候选在 `target/op` 上匹配,但没有任何 `dtypes` signature 能通过 concrete / wildcard / `TypeVar` 规则 +- **THEN** 该候选 MUST 在进入 `constraints` 评估前被移除 +- **AND** 其 `priority` MUST NOT 参与后续决策 + +#### Scenario: multi-op descriptor participates in selection without hidden specificity bonus + +- **WHEN** 同一个 concrete query `op` 同时命中 single-op descriptor 与 multi-op descriptor +- **THEN** selector MUST 继续按既有的 `dtypes -> constraints -> priority -> tie error` 顺序求值 +- **AND** MUST NOT 仅因为 single-op descriptor 更“具体”就隐式优先选择它 + +## ADDED Requirements + +### Requirement: multi-op descriptors MUST require concrete op binding before IR materialization + +当 descriptor 使用 `ops=[...]` 覆盖多个 concrete PTO op 时,系统 MUST 在 materialization 前先绑定唯一 `selected_op`。 +在未绑定 concrete `selected_op` 之前,descriptor MUST NOT 允许执行 `mlir_text()`、`mlir_module()`、`verify()` 或 `emit(path)`。 +一旦 selector 已绑定 concrete `selected_op`,该 descriptor MUST 与已绑定 concrete dtype signature 的其他 descriptor 一样继续参与 specialization 和 materialization。 + +#### Scenario: unresolved multi-op descriptor is rejected before materialization + +- **WHEN** 用户直接对一个通过 `ops=[...]` 注册、但尚未经过 `pto.select_kernel(...)` 绑定 concrete `selected_op` 的 descriptor 调用 `mlir_text()` +- **THEN** frontend MUST 直接报错 +- **AND** 诊断 MUST 明确指出该 descriptor 需要先绑定 concrete `op` + +#### Scenario: selected multi-op descriptor can materialize normally + +- **WHEN** 一个 `ops=[...]` descriptor 已经通过 `pto.select_kernel(...)` 绑定了 concrete `selected_op` +- **THEN** 调用方 MUST 可以继续执行 `specialize()`、`mlir_text()`、`verify()` 和 `emit(path)` +- **AND** materialization 结果 MUST 使用已绑定的 concrete `selected_op` diff --git a/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/specs/tilelang-dsl-template-slots/spec.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/specs/tilelang-dsl-template-slots/spec.md new file mode 100644 index 000000000..f31da78ef --- /dev/null +++ b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/specs/tilelang-dsl-template-slots/spec.md @@ -0,0 +1,68 @@ +## ADDED Requirements + +### Requirement: TileLang DSL MUST provide static template-slot metadata and a compile-time placeholder call + +TileLang DSL MUST 提供基于 descriptor 元数据的模板槽位机制,用于让多个 concrete PTO op 共享同一份 kernel body。 +`@pto.vkernel` MAY 声明 `templates={...}` 静态映射;kernel body MAY 使用统一模板入口 `pto.tpl("slot", ...)`。 +`templates` MUST 是静态 mapping,slot 名 MUST 是非空字符串,映射 value MUST 是真实 `pto.*` op 名字符串。 +系统 MUST NOT 要求用户在 kernel body 中执行 Python dict lookup、callable value 调用或其他 runtime dispatch 来实现该能力。 + +#### Scenario: kernel declares a template slot and uses the placeholder call + +- **WHEN** 一个 kernel descriptor 声明 `templates={"core": {"tadd": "vadd", "tsub": "vsub"}}`,并在 kernel body 中使用 `pto.tpl("core", lhs, rhs, mask)` +- **THEN** frontend MUST 接受该模板槽位写法 +- **AND** 该模板调用 MUST 被视为 compile-time placeholder,而不是 runtime helper + +### Requirement: template-slot substitution MUST resolve from the selected concrete op before semantic checking and lowering + +对使用模板槽位的 kernel,frontend MUST 在 semantic checking 和 lowering 之前,根据 descriptor 已绑定的 concrete `selected_op` 把 `pto.tpl("slot", ...)` 静态替换成真实 `pto.(...)` 调用。 +替换后的真实调用 MUST 继续沿用现有 semantic/type-check/lowering 路径,并满足当前 authoring-form VPTO legality contract。 +模板调用 MAY 出现在 loop、`if`、显式 `strict_vecscope` 或 inferred `pto.vecscope` 等任意合法 DSL 位置;其替换结果 MUST 与用户直接书写真实 `pto.*` 调用等价。 + +#### Scenario: one shared kernel body expands to different real ops for different selected op values + +- **WHEN** 同一个 descriptor 通过不同的 concrete query `op` 分别绑定到 `selected_op="tadd"` 和 `selected_op="tsub"` +- **THEN** `pto.tpl("core", lhs, rhs, mask)` MUST 分别静态展开成 `pto.vadd(lhs, rhs, mask)` 和 `pto.vsub(lhs, rhs, mask)` +- **AND** 后续 semantic/type-check/lowering MUST 只看到展开后的真实 `pto.*` 调用 + +#### Scenario: template placeholder remains valid inside legal control-flow and vecscope contexts + +- **WHEN** `pto.tpl("core", ...)` 出现在合法的 `for`、`if`、`strict_vecscope` 或 inferred `pto.vecscope` 上下文中 +- **THEN** frontend MUST 先完成模板替换 +- **AND** 替换后的编译结果 MUST 与同位置直接书写真实 `pto.*` 调用保持等价 + +### Requirement: template slots MUST fail fast on unresolved or invalid static mappings + +模板槽位是 compile-time static surface,因此 frontend MUST 对以下情况 fail-fast: + +- `pto.tpl(...)` 的 slot 不是字符串字面量 +- 使用了未声明的 slot +- 当前 `selected_op` 在该 slot 下没有映射 +- 映射 value 不是已支持的真实 `pto.*` op 名 +- descriptor 尚未绑定 concrete `selected_op` 就尝试解析模板槽位 + +这些错误 MUST 在生成任何 VPTO IR 之前报出,且诊断 MUST 明确指出失败的 slot 或 concrete `op` 绑定原因。 + +#### Scenario: unknown slot or missing op mapping is rejected before IR generation + +- **WHEN** kernel body 中使用 `pto.tpl("core", ...)`,但 descriptor 没有声明 `core` slot,或该 slot 未覆盖当前 `selected_op` +- **THEN** frontend MUST 在生成任何 VPTO IR 之前报错 +- **AND** 诊断 MUST 指出缺失的 slot 或缺失的 concrete `op` 映射 + +#### Scenario: non-literal slot name is rejected as unsupported template syntax + +- **WHEN** 用户写出 `pto.tpl(slot_name, lhs, rhs, mask)`,其中 `slot_name` 不是字符串字面量 +- **THEN** frontend MUST 直接报错 +- **AND** MUST NOT 把该写法当作运行时字符串分发处理 + +### Requirement: template slots MUST NOT introduce arbitrary Python callable semantics into the DSL + +TileLang DSL MUST 继续保持受限 Python 子集。 +实现 MUST NOT 因为模板槽位能力而接受 kernel body 中的 dict-lookup callable、lambda、闭包函数对象调用或其他 higher-order dispatch。 +模板化 authoring 的正式路径 MUST 是 descriptor 元数据中的 `templates={...}` 加上 kernel body 中的 `pto.tpl("slot", ...)`。 + +#### Scenario: callable-based runtime template dispatch remains rejected + +- **WHEN** 用户尝试在 kernel body 中通过 `table["core"](lhs, rhs, mask)`、`resolver(lhs, rhs, mask)` 或等价 callable-dispatch 写法实现模板分发 +- **THEN** frontend MUST 继续按 unsupported Python / unsupported call surface 拒绝该写法 +- **AND** MUST NOT 把它解释成合法的 template-slot surface diff --git a/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/tasks.md b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/tasks.md new file mode 100644 index 000000000..12bc2f20f --- /dev/null +++ b/openspec/changes/archive/2026-04-07-extend-tilelang-dsl-template-op-slots/tasks.md @@ -0,0 +1,24 @@ +## 1. Descriptor 与 matcher 元数据扩展 + +- [x] 1.1 扩展 `tilelang-dsl/python/tilelang_dsl/kernel.py` 的 `@pto.vkernel` / descriptor 校验逻辑,支持互斥的 `op=` 与 `ops=[...]`,并为 multi-op descriptor 保存 `match_ops` 与 `selected_op` +- [x] 1.2 更新 `pto.select_kernel(...)` 的候选过滤与绑定逻辑,使其能匹配 `ops=[...]` descriptor,并在返回前绑定 concrete `selected_op` +- [x] 1.3 为 multi-op descriptor 增加 materialization gate,确保未绑定 concrete `selected_op` 时拒绝 `mlir_text()`、`mlir_module()`、`verify()` 和 `emit(path)` + +## 2. 模板槽位 frontend 能力 + +- [x] 2.1 在 descriptor 元数据中增加 `templates={...}` 静态映射的解析与校验,限制 slot 名、映射 key/value 和与 `match_ops` 的一致性 +- [x] 2.2 在 frontend AST 构建路径中增加 `pto.tpl("slot", ...)` 识别与 compile-time 替换,把模板调用重写成真实 `pto.*` call +- [x] 2.3 补齐模板槽位的 fail-fast diagnostics,覆盖非字面量 slot、未知 slot、缺失 op 映射、非法真实 op 名以及未绑定 `selected_op` 的错误路径 + +## 3. 回归测试与验证 + +- [x] 3.1 在 `tilelang-dsl/tests/test_tilelang_dsl_v1.py` 增加 multi-op matcher 回归,覆盖 `ops=[...]` 命中、single-op/multi-op 竞争、priority 与 tie error 行为 +- [x] 3.2 在 `tilelang-dsl/tests/test_tilelang_dsl_v1.py` 增加 template-slot 正例回归,验证同一份 kernel body 在 `tadd/tsub/tmul/tdiv` 下分别展开成正确真实 `pto.*` op +- [x] 3.3 在 `tilelang-dsl/tests/test_tilelang_dsl_v1.py` 增加 template-slot 负例回归,覆盖未绑定 `selected_op`、未知 slot、非字面量 slot、非法映射值和 callable-based runtime dispatch reject +- [x] 3.4 运行最小验证集,至少包括 `PYTHONPATH=$PWD/tilelang-dsl/python python3 -m unittest $PWD/tilelang-dsl/tests/test_tilelang_dsl_v1.py` + +## 4. 样例与文档 + +- [x] 4.1 新增或更新 `tilelang-dsl/examples/` 中的共享 kernel body 样例,展示 `ops=[...] + templates={...} + pto.tpl("slot", ...)` 的推荐写法 +- [x] 4.2 更新 `docs/tilelang-dsl-guide.md`,补充 template-slot surface、`op`/`ops` 语义、编译期替换模型和不支持 kernel body Python dict/callable 的原因 +- [x] 4.3 更新 `tilelang-dsl/docs/matcher-and-advanced-surface-migration.md` 或相邻文档,说明从显式真实 `pto.*` 调用迁移到模板槽位写法的适用场景与边界 diff --git a/openspec/specs/tilelang-dsl-advanced-surface/spec.md b/openspec/specs/tilelang-dsl-advanced-surface/spec.md index 4a7f37752..c81b02081 100644 --- a/openspec/specs/tilelang-dsl-advanced-surface/spec.md +++ b/openspec/specs/tilelang-dsl-advanced-surface/spec.md @@ -2,11 +2,11 @@ ## ADDED Requirements -### Requirement: advanced mode MUST infer `pto.vecscope` for eligible vector chains while preserving `strict_vecscope` boundaries +### Requirement: advanced mode MUST preserve the base inferred-vecscope contract while adding explicit `strict_vecscope` boundaries -在 advanced mode 下,当用户省略显式 scope 且书写连续的 supported vector chain 时,frontend MUST 自动推断 dedicated `pto.vecscope`。 +在 advanced mode 下,当用户省略显式 scope 且书写连续的 supported vector chain 时,frontend MUST 继续沿用 base surface 的 inferred `pto.vecscope` 契约。 scalar op、控制流边界、外部 call 和显式 `strict_vecscope` MUST 切断该推断。 -`strict_vecscope` 继续作为硬边界,inference MUST NOT 穿越其边界。 +同时,explicit `strict_vecscope` MUST 只在 advanced mode 下可用,并继续作为硬边界,inference MUST NOT 穿越其边界。 inference 结果 MUST 继续满足当前 authoring-form VPTO legality contract,不得因为自动推断而放宽 typed-mask、capture operand、地址形态或 vecscope carrier 约束。 #### Scenario: contiguous vector chain becomes one inferred `pto.vecscope` @@ -27,6 +27,12 @@ inference 结果 MUST 继续满足当前 authoring-form VPTO legality contract - **THEN** frontend MUST 保留该 `strict_vecscope` 原样语义 - **AND** scope inference MUST NOT 跨越该显式边界去并合前后 vector chain +#### Scenario: explicit `strict_vecscope` stays unavailable outside advanced mode + +- **WHEN** 用户在未启用 `advanced=True` 的 kernel 中书写 `strict_vecscope` +- **THEN** frontend MUST 报 requires-advanced 的 surface 诊断 +- **AND** MUST NOT 因为 base surface 支持 inferred vecscope 而放开 explicit `strict_vecscope` + ### Requirement: advanced mode MUST support raw pointer, UBRef, low-level DMA, and `copy_ubuf_to_ubuf` authoring advanced mode MUST 将以下 surface 纳入正式契约: diff --git a/openspec/specs/tilelang-dsl-kernel-matcher/spec.md b/openspec/specs/tilelang-dsl-kernel-matcher/spec.md index fde2f1717..e024ae38a 100644 --- a/openspec/specs/tilelang-dsl-kernel-matcher/spec.md +++ b/openspec/specs/tilelang-dsl-kernel-matcher/spec.md @@ -1,12 +1,19 @@ # tilelang-dsl-kernel-matcher Specification -## ADDED Requirements +## MODIFIED Requirements ### Requirement: TileLang DSL MUST provide an explicit kernel registry and selection API 当同一 `target/op` 下存在多个 `@pto.vkernel` descriptor 时,TileLang DSL MUST 将它们注册到显式、可查询的 `KernelRegistry`。 默认 registry MUST 是 module-level 对象;调用方 MAY 传入自定义 registry 以获得隔离的候选集合。 -系统 MUST 提供显式 selection API `pto.select_kernel(target, op, operand_types, context_attrs, registry=None)`,用于在给定 `target`、`op`、operand type 信息和上下文属性时选择唯一 kernel。 +系统 MUST 提供显式 selection API `pto.select_kernel(target, op, operand_types, context_attrs, registry=None)`,用于在给定 `target`、concrete `op`、operand type 信息和上下文属性时选择唯一 kernel。 +descriptor MUST 支持两种互斥的 matcher 元数据: + +- `op=""` +- `ops=["", "", ...]` + +descriptor MUST 至少提供其中一种,且实现 MUST NOT 同时接受两者。 +当 selector 命中一个 `ops=[...]` descriptor 时,返回结果 MUST 绑定当前 query 对应的唯一 concrete `selected_op`,再进入后续 `specialize()` / `mlir_text()` / `verify()` 流程。 实现 MUST NOT 依赖扫描 Python globals、locals 或导入顺序来隐式发现候选。 #### Scenario: selector returns the unique best kernel @@ -21,6 +28,12 @@ - **THEN** selector MUST 只在该 registry 的候选集合内做匹配和决策 - **AND** MUST NOT 回退去查询 module-level 默认 registry +#### Scenario: selector binds the concrete op for a multi-op descriptor + +- **WHEN** 一个 descriptor 通过 `ops=["tadd", "tsub", "tmul", "tdiv"]` 注册,且调用方以 `pto.select_kernel(..., op="tmul", ...)` 查询命中该 descriptor +- **THEN** selector MUST 返回已经绑定 `selected_op="tmul"` 的 descriptor +- **AND** 后续 materialization MUST 基于该 concrete `selected_op` 而不是未绑定的原始 matcher 集合 + ### Requirement: matcher MUST support concrete types, `Any*`, and `TypeVar` across multiple signatures matcher MUST 支持: @@ -52,8 +65,14 @@ matcher MUST 支持: 5. `priority` 6. highest-priority tie error +其中第 2 步的 `op` 匹配 MUST 使用调用方给出的 concrete query `op`,并按以下规则求值: + +- 对 `op=""` descriptor,要求 exact match +- 对 `ops=[...]` descriptor,要求 query `op` 属于该 matcher 集合 + 实现 MUST 保持该顺序 deterministic。 -系统 MUST NOT 依赖注册顺序、定义顺序、导入顺序或其他隐式规则来打破同一阶段的歧义。 +系统 MUST NOT 依赖注册顺序、定义顺序、导入顺序或其他隐式规则来打破同一阶段的歧义。 +系统 MUST NOT 因为候选是 single-op descriptor 或 multi-op descriptor 而引入额外隐式优先级。 #### Scenario: type match happens before constraints and priority @@ -61,6 +80,12 @@ matcher MUST 支持: - **THEN** 该候选 MUST 在进入 `constraints` 评估前被移除 - **AND** 其 `priority` MUST NOT 参与后续决策 +#### Scenario: multi-op descriptor participates in selection without hidden specificity bonus + +- **WHEN** 同一个 concrete query `op` 同时命中 single-op descriptor 与 multi-op descriptor +- **THEN** selector MUST 继续按既有的 `dtypes -> constraints -> priority -> tie error` 顺序求值 +- **AND** MUST NOT 仅因为 single-op descriptor 更“具体”就隐式优先选择它 + ### Requirement: constraint evaluation MUST happen after type matching and before priority resolution 对同一 `target/op` 的候选集合,matcher MUST 先完成 dtype matching,再评估 `constraints`。 @@ -83,3 +108,23 @@ matcher MUST 支持: - **THEN** selector MUST 报错 - **AND** 错误消息 MUST 指出发生 tie 的 kernel 集合 - **AND** MUST NOT 静默选择第一个已注册 kernel + +## ADDED Requirements + +### Requirement: multi-op descriptors MUST require concrete op binding before IR materialization + +当 descriptor 使用 `ops=[...]` 覆盖多个 concrete PTO op 时,系统 MUST 在 materialization 前先绑定唯一 `selected_op`。 +在未绑定 concrete `selected_op` 之前,descriptor MUST NOT 允许执行 `mlir_text()`、`mlir_module()`、`verify()` 或 `emit(path)`。 +一旦 selector 已绑定 concrete `selected_op`,该 descriptor MUST 与已绑定 concrete dtype signature 的其他 descriptor 一样继续参与 specialization 和 materialization。 + +#### Scenario: unresolved multi-op descriptor is rejected before materialization + +- **WHEN** 用户直接对一个通过 `ops=[...]` 注册、但尚未经过 `pto.select_kernel(...)` 绑定 concrete `selected_op` 的 descriptor 调用 `mlir_text()` +- **THEN** frontend MUST 直接报错 +- **AND** 诊断 MUST 明确指出该 descriptor 需要先绑定 concrete `op` + +#### Scenario: selected multi-op descriptor can materialize normally + +- **WHEN** 一个 `ops=[...]` descriptor 已经通过 `pto.select_kernel(...)` 绑定了 concrete `selected_op` +- **THEN** 调用方 MUST 可以继续执行 `specialize()`、`mlir_text()`、`verify()` 和 `emit(path)` +- **AND** materialization 结果 MUST 使用已绑定的 concrete `selected_op` diff --git a/openspec/specs/tilelang-dsl-template-slots/spec.md b/openspec/specs/tilelang-dsl-template-slots/spec.md new file mode 100644 index 000000000..48dff2e7c --- /dev/null +++ b/openspec/specs/tilelang-dsl-template-slots/spec.md @@ -0,0 +1,70 @@ +# tilelang-dsl-template-slots Specification + +## ADDED Requirements + +### Requirement: TileLang DSL MUST provide static template-slot metadata and a compile-time placeholder call + +TileLang DSL MUST 提供基于 descriptor 元数据的模板槽位机制,用于让多个 concrete PTO op 共享同一份 kernel body。 +`@pto.vkernel` MAY 声明 `templates={...}` 静态映射;kernel body MAY 使用统一模板入口 `pto.tpl("slot", ...)`。 +`templates` MUST 是静态 mapping,slot 名 MUST 是非空字符串,映射 value MUST 是真实 `pto.*` op 名字符串。 +系统 MUST NOT 要求用户在 kernel body 中执行 Python dict lookup、callable value 调用或其他 runtime dispatch 来实现该能力。 + +#### Scenario: kernel declares a template slot and uses the placeholder call + +- **WHEN** 一个 kernel descriptor 声明 `templates={"core": {"tadd": "vadd", "tsub": "vsub"}}`,并在 kernel body 中使用 `pto.tpl("core", lhs, rhs, mask)` +- **THEN** frontend MUST 接受该模板槽位写法 +- **AND** 该模板调用 MUST 被视为 compile-time placeholder,而不是 runtime helper + +### Requirement: template-slot substitution MUST resolve from the selected concrete op before semantic checking and lowering + +对使用模板槽位的 kernel,frontend MUST 在 semantic checking 和 lowering 之前,根据 descriptor 已绑定的 concrete `selected_op` 把 `pto.tpl("slot", ...)` 静态替换成真实 `pto.(...)` 调用。 +替换后的真实调用 MUST 继续沿用现有 semantic/type-check/lowering 路径,并满足当前 authoring-form VPTO legality contract。 +模板调用 MAY 出现在 loop、`if`、显式 `strict_vecscope` 或 inferred `pto.vecscope` 等任意合法 DSL 位置;其替换结果 MUST 与用户直接书写真实 `pto.*` 调用等价。 + +#### Scenario: one shared kernel body expands to different real ops for different selected op values + +- **WHEN** 同一个 descriptor 通过不同的 concrete query `op` 分别绑定到 `selected_op="tadd"` 和 `selected_op="tsub"` +- **THEN** `pto.tpl("core", lhs, rhs, mask)` MUST 分别静态展开成 `pto.vadd(lhs, rhs, mask)` 和 `pto.vsub(lhs, rhs, mask)` +- **AND** 后续 semantic/type-check/lowering MUST 只看到展开后的真实 `pto.*` 调用 + +#### Scenario: template placeholder remains valid inside legal control-flow and vecscope contexts + +- **WHEN** `pto.tpl("core", ...)` 出现在合法的 `for`、`if`、`strict_vecscope` 或 inferred `pto.vecscope` 上下文中 +- **THEN** frontend MUST 先完成模板替换 +- **AND** 替换后的编译结果 MUST 与同位置直接书写真实 `pto.*` 调用保持等价 + +### Requirement: template slots MUST fail fast on unresolved or invalid static mappings + +模板槽位是 compile-time static surface,因此 frontend MUST 对以下情况 fail-fast: + +- `pto.tpl(...)` 的 slot 不是字符串字面量 +- 使用了未声明的 slot +- 当前 `selected_op` 在该 slot 下没有映射 +- 映射 value 不是已支持的真实 `pto.*` op 名 +- descriptor 尚未绑定 concrete `selected_op` 就尝试解析模板槽位 + +这些错误 MUST 在生成任何 VPTO IR 之前报出,且诊断 MUST 明确指出失败的 slot 或 concrete `op` 绑定原因。 + +#### Scenario: unknown slot or missing op mapping is rejected before IR generation + +- **WHEN** kernel body 中使用 `pto.tpl("core", ...)`,但 descriptor 没有声明 `core` slot,或该 slot 未覆盖当前 `selected_op` +- **THEN** frontend MUST 在生成任何 VPTO IR 之前报错 +- **AND** 诊断 MUST 指出缺失的 slot 或缺失的 concrete `op` 映射 + +#### Scenario: non-literal slot name is rejected as unsupported template syntax + +- **WHEN** 用户写出 `pto.tpl(slot_name, lhs, rhs, mask)`,其中 `slot_name` 不是字符串字面量 +- **THEN** frontend MUST 直接报错 +- **AND** MUST NOT 把该写法当作运行时字符串分发处理 + +### Requirement: template slots MUST NOT introduce arbitrary Python callable semantics into the DSL + +TileLang DSL MUST 继续保持受限 Python 子集。 +实现 MUST NOT 因为模板槽位能力而接受 kernel body 中的 dict-lookup callable、lambda、闭包函数对象调用或其他 higher-order dispatch。 +模板化 authoring 的正式路径 MUST 是 descriptor 元数据中的 `templates={...}` 加上 kernel body 中的 `pto.tpl("slot", ...)`。 + +#### Scenario: callable-based runtime template dispatch remains rejected + +- **WHEN** 用户尝试在 kernel body 中通过 `table["core"](lhs, rhs, mask)`、`resolver(lhs, rhs, mask)` 或等价 callable-dispatch 写法实现模板分发 +- **THEN** frontend MUST 继续按 unsupported Python / unsupported call surface 拒绝该写法 +- **AND** MUST NOT 把它解释成合法的 template-slot surface diff --git a/test/basic/expand_tile_op_tilelang.pto b/test/basic/expand_tile_op_tilelang.pto index e0733908b..4ae45a4e9 100644 --- a/test/basic/expand_tile_op_tilelang.pto +++ b/test/basic/expand_tile_op_tilelang.pto @@ -7,43 +7,41 @@ // See LICENSE in the root of the software repository for the full text of the License. // Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline -// expands pto.tadd via TileLang Python DSL templates. +// expands pto.tadd via the default TileLang Python DSL template +// lib/TileOps/tadd_template.py. // // Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics // -// REQUIRES: tilelang-dsl-tile-buf-params -// RUN: ptoas --pto-arch=a5 --enable-tile-to-vector --tilelang-path=%S/../tilelang_templates --tilelang-pkg-path=%S/../../tilelang-dsl/python %s 2>/dev/null | FileCheck %s -// -// NOTE: This test requires the TileLang DSL to generate functions with -// tile_buf parameters (not memref). The test is gated by REQUIRES until -// the DSL frontend is updated. +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s -// After the full pipeline, tile_buf_addr/tile_valid_rows/tile_valid_cols -// should be folded away, and the vector loop body should be inlined. +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tadd should be lowered to vector-style VPTO IR. // CHECK: func.func @TADD // CHECK-NOT: pto.tadd ins -// CHECK: pto.simd.tile_to_memref // CHECK: pto.vecscope +// CHECK: pto.castptr // CHECK: pto.vlds // CHECK: pto.vadd // CHECK: pto.vsts module { - func.func @TADD( - %a: !pto.tile_buf, - %b: !pto.tile_buf, - %c: !pto.tile_buf) - attributes { pto.tile_function = "pto.tadd" } { + func.func @TADD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) - outs(%c : !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) return } } diff --git a/test/basic/fold_tile_buf_intrinsics.pto b/test/basic/fold_tile_buf_intrinsics.pto index b6c565a83..aaa8c06c1 100644 --- a/test/basic/fold_tile_buf_intrinsics.pto +++ b/test/basic/fold_tile_buf_intrinsics.pto @@ -75,12 +75,15 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { scf.for %i = %c0 to %v_rows step %c1 { %tmp = arith.index_cast %v_cols : index to i32 + %rowDst = memref.subview %mDst[%i, %c0] [1, 64] [1, 1] : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> to memref<64xf32, strided<[1], offset: ?>, #pto.address_space> + %rowSrc0 = memref.subview %mSrc0[%i, %c0] [1, 64] [1, 1] : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> to memref<64xf32, strided<[1], offset: ?>, #pto.address_space> + %rowSrc1 = memref.subview %mSrc1[%i, %c0] [1, 64] [1, 1] : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> to memref<64xf32, strided<[1], offset: ?>, #pto.address_space> %inner = scf.for %j = %c0 to %v_cols step %c64 iter_args(%remain = %tmp) -> (i32) { %mask, %next = pto.plt_b32 %remain : i32 -> !pto.mask, i32 - %va = pto.vlds %mSrc0[%i, %j] : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> -> !pto.vreg<64xf32> - %vb = pto.vlds %mSrc1[%i, %j] : memref<16x64xf32, strided<[64, 1]>, #pto.address_space> -> !pto.vreg<64xf32> + %va = pto.vlds %rowSrc0[%j] : memref<64xf32, strided<[1], offset: ?>, #pto.address_space> -> !pto.vreg<64xf32> + %vb = pto.vlds %rowSrc1[%j] : memref<64xf32, strided<[1], offset: ?>, #pto.address_space> -> !pto.vreg<64xf32> %vc = pto.vadd %va, %vb, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - pto.vsts %vc, %mDst[%i, %j], %mask : !pto.vreg<64xf32>, memref<16x64xf32, strided<[64, 1]>, #pto.address_space>, !pto.mask + pto.vsts %vc, %rowDst[%j], %mask : !pto.vreg<64xf32>, memref<64xf32, strided<[1], offset: ?>, #pto.address_space>, !pto.mask scf.yield %next : i32 } } diff --git a/test/basic/inline_libcall_filter_tilelang_scope.pto b/test/basic/inline_libcall_filter_tilelang_scope.pto new file mode 100644 index 000000000..c3e1cb7ac --- /dev/null +++ b/test/basic/inline_libcall_filter_tilelang_scope.pto @@ -0,0 +1,38 @@ +// RUN: ptoas --pto-arch=a5 --pass-pipeline="builtin.module(pto-inline-libcall)" %s | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%arg0: i32) { + %v0 = func.call @__tl_inline_add1_i32(%arg0) : (i32) -> i32 + %v1 = func.call @__tilelang_template_passthrough_i32(%v0) : (i32) -> i32 + %v2 = func.call @__regular_passthrough_i32(%v1) : (i32) -> i32 + return + } + + func.func private @__tl_inline_add1_i32(%x: i32) -> i32 attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %y = arith.addi %x, %c1 : i32 + return %y : i32 + } + + // This is intentionally NOT an inline_proc helper. The inline pass should + // still inline it because template helper inlining is enabled in VPTO + // mainline regardless of --enable-tile-op-expand. + func.func private @__tilelang_template_passthrough_i32(%x: i32) -> i32 attributes { pto.tilelang.instance } { + return %x : i32 + } + + // Regular private helper without TileLang/OP-Lib attrs must not be inlined + // by pto-inline-libcall. + func.func private @__regular_passthrough_i32(%x: i32) -> i32 { + return %x : i32 + } +} + +// CHECK-LABEL: func.func @kernel +// CHECK: arith.constant 1 : i32 +// CHECK-NOT: func.call @__tl_inline_add1_i32 +// CHECK-NOT: call @__tilelang_template_passthrough_i32 +// CHECK: call @__regular_passthrough_i32 +// CHECK-NOT: func.func private @__tilelang_template_passthrough_i32 +// CHECK: func.func private @__regular_passthrough_i32 +// CHECK-NOT: func.func private @__tl_inline_add1_i32 diff --git a/test/basic/inline_libcall_result_rewrite.pto b/test/basic/inline_libcall_result_rewrite.pto new file mode 100644 index 000000000..48af8109d --- /dev/null +++ b/test/basic/inline_libcall_result_rewrite.pto @@ -0,0 +1,36 @@ +// RUN: ptoas --pto-arch=a5 --pass-pipeline="builtin.module(pto-inline-libcall)" %s | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%x: i32) { + %a, %b = func.call @__tl_inline_pair_i32(%x) : (i32) -> (i32, i32) + %sum = arith.addi %a, %b : i32 + func.call @__tl_inline_sink_i32(%sum) : (i32) -> () + return + } + + func.func private @__tl_inline_pair_i32(%arg0: i32) -> (i32, i32) attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %v1 = arith.addi %arg0, %c1 : i32 + %v2 = arith.addi %arg0, %c2 : i32 + return %v1, %v2 : i32, i32 + } + + func.func private @__tl_inline_sink_i32(%arg0: i32) attributes { pto.tilelang.inline_proc } { + %c0 = arith.constant 0 : i32 + %_ = arith.addi %arg0, %c0 : i32 + return + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK: %[[C1:.+]] = arith.constant 1 : i32 +// CHECK: %[[C2:.+]] = arith.constant 2 : i32 +// CHECK: %[[V1:.+]] = arith.addi %{{[^,]+}}, %[[C1]] : i32 +// CHECK: %[[V2:.+]] = arith.addi %{{[^,]+}}, %[[C2]] : i32 +// CHECK: %[[SUM:.+]] = arith.addi %[[V1]], %[[V2]] : i32 +// CHECK: arith.constant 0 : i32 +// CHECK: arith.addi %[[SUM]] +// CHECK-NOT: func.call @__tl_inline_ +// CHECK-NOT: func.func private @__tl_inline_pair_i32 +// CHECK-NOT: func.func private @__tl_inline_sink_i32 diff --git a/test/basic/tilelang_inline_proc_backend_inline.pto b/test/basic/tilelang_inline_proc_backend_inline.pto new file mode 100644 index 000000000..5695a2c46 --- /dev/null +++ b/test/basic/tilelang_inline_proc_backend_inline.pto @@ -0,0 +1,30 @@ +// RUN: ptoas --pto-arch=a5 --pass-pipeline="builtin.module(pto-inline-libcall)" %s | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%arg0: i32) attributes { pto.tilelang.instance } { + %0 = func.call @__tl_inline_add1_i32(%arg0) : (i32) -> i32 + func.call @__tl_inline_sink_i32(%0) : (i32) -> () + return + } + + func.func private @__tl_inline_add1_i32(%x: i32) -> i32 attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %v = arith.addi %x, %c1 : i32 + return %v : i32 + } + + func.func private @__tl_inline_sink_i32(%x: i32) attributes { pto.tilelang.inline_proc } { + %c2 = arith.constant 2 : i32 + %t = arith.addi %x, %c2 : i32 + return + } +} + +// CHECK-LABEL: func.func @kernel +// CHECK: %[[C1:.+]] = arith.constant 1 : i32 +// CHECK: %[[ADD1:.+]] = arith.addi %arg0, %[[C1]] : i32 +// CHECK: %[[C2:.+]] = arith.constant 2 : i32 +// CHECK: arith.addi %[[ADD1]], %[[C2]] : i32 +// CHECK-NOT: func.call @__tl_inline_ +// CHECK-NOT: func.func private @__tl_inline_add1_i32 +// CHECK-NOT: func.func private @__tl_inline_sink_i32 diff --git a/test/basic/vpto_mainline_inline_proc_cleanup.pto b/test/basic/vpto_mainline_inline_proc_cleanup.pto new file mode 100644 index 000000000..fd775fa39 --- /dev/null +++ b/test/basic/vpto_mainline_inline_proc_cleanup.pto @@ -0,0 +1,28 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%arg0: i32) attributes { pto.tilelang.instance } { + %0 = func.call @__tl_inline_add1_i32(%arg0) : (i32) -> i32 + func.call @__tl_inline_sink_i32(%0) : (i32) -> () + return + } + + func.func private @__tl_inline_add1_i32(%x: i32) -> i32 attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %v = arith.addi %x, %c1 : i32 + return %v : i32 + } + + func.func private @__tl_inline_sink_i32(%x: i32) attributes { pto.tilelang.inline_proc } { + %c2 = arith.constant 2 : i32 + %t = arith.addi %x, %c2 : i32 + return + } +} + +// CHECK-LABEL: func.func @kernel +// CHECK: return +// CHECK-NOT: func.call @__tl_inline_ +// CHECK-NOT: call @__tl_inline_ +// CHECK-NOT: pto.tilelang.inline_proc +// CHECK-NOT: func.func private @__tl_inline_ diff --git a/test/dsl/expand_tile_op_tilelang_tadds.pto b/test/dsl/expand_tile_op_tilelang_tadds.pto new file mode 100644 index 000000000..e7360c376 --- /dev/null +++ b/test/dsl/expand_tile_op_tilelang_tadds.pto @@ -0,0 +1,36 @@ +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test ExpandTileOp expansion for pto.tadds in the VPTO pipeline. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --emit-vpto %s -o - | FileCheck %s + +// CHECK: func.func @TADDS() +// CHECK: pto.vecscope +// CHECK: pto.addptr +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts +// CHECK-NOT: memref.cast +// CHECK-NOT: builtin.unrealized_conversion_cast + +module { + func.func @TADDS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tadds ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/tilelang-dsl/docs/README.md b/tilelang-dsl/docs/README.md index b93493489..abfb405dd 100644 --- a/tilelang-dsl/docs/README.md +++ b/tilelang-dsl/docs/README.md @@ -1,18 +1,49 @@ -TileLang DSL local documentation lives here. - -Current docs: -- `v1-surface.md`: the TileLang DSL v1 contract implemented by - `add-tilelang-dsl-core-foundation` -- `v1-lowering.md`: the TileLang DSL v1 authoring-form VPTO lowering contract - implemented by `add-tilelang-dsl-authoring-vpto-lowering` -- `matcher-and-advanced-surface-migration.md`: migration notes from the - original v1 core/lowering boundary to the matcher and advanced-surface - capability implemented by - `extend-tilelang-dsl-matcher-and-advanced-surface` - -Documentation boundary: -- `tilelang-dsl/docs/` is the local documentation source of truth for the new - `tilelang_dsl` frontend -- repository-level docs may link here, but should not redefine this package's - implemented v1 boundary -- `python/pto/dialects/pto.py` is not the source of truth for TileLang DSL v1 +# TileLang DSL 文档 + +TileLang Python DSL 为面向 Ascend NPU 硬件的向量计算内核提供高级的 Pythonic 接口。本指南适用于需要编写高效、硬件感知内核的库开发人员和性能工程师。 + +## 文档结构 + +### 入门指南 +- [简介](user_guide/01-introduction.md) - 语言概述、层级、基本vs高级模式 +- [快速开始](user_guide/02-quick-start.md) - 快速入门示例 + +### 核心概念 +- [内核声明](user_guide/03-kernel-declaration.md) - 内核声明、装饰器参数、约束系统 +- [模板内核](user_guide/04-template-kernels.md) - 模板内核、多操作内核、编译时代换 + +### 类型系统 +- [类型系统](user_guide/05-type-system.md) - 标量类型、向量类型、指针类型 +- [TensorView](user_guide/06-tensorview.md) - TensorView类型、属性、切片语法 +- [Tile类型](user_guide/07-tile-types.md) - Tile类型、属性、配置、操作 + +### 控制流 +- [控制流](user_guide/08-control-flow.md) - 向量作用域、循环、条件语句 + +### 操作参考 +- [前端操作](user_guide/09-frontend-operations.md) - 前端操作、类型查询、指针构造 +- [同步和DMA操作](user_guide/10-sync-dma-operations.md) - 同步和DMA操作 +- [向量内存操作](user_guide/11-vector-memory-operations.md) - 向量加载和存储操作 +- [谓词操作](user_guide/12-predicate-operations.md) - 谓词操作 +- [向量算术操作](user_guide/13-vector-arithmetic-operations.md) - 向量算术操作 + +### 示例和错误处理 +- [示例](user_guide/15-examples.md) - 各种内核示例 +- [常见错误](user_guide/16-common-errors.md) - 常见错误和解决方案 + +### 附录 +- [兼容性说明](user_guide/17-compatibility-notes.md) - 与实验实现的差异 +- [后续步骤](user_guide/18-next-steps.md) - 相关资源链接 + +## 相关文档 +- [v1-surface.md](v1-surface.md) - TileLang DSL v1 合约 +- [v1-lowering.md](v1-lowering.md) - TileLang DSL v1 降低合约 +- [matcher-and-advanced-surface-migration.md](matcher-and-advanced-surface-migration.md) - 迁移说明 +- [unsupported-features.md](unsupported-features.md) - 不支持的功能 + +--- + +**原始文档边界说明**: +- `tilelang-dsl/docs/` 是新的 `tilelang_dsl` 前端本地文档的真实来源 +- 仓库级文档可以链接到这里,但不应重新定义此包实现的 v1 边界 +- `python/pto/dialects/pto.py` 不是 TileLang DSL v1 的真实来源 diff --git a/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md b/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md index c28240ea0..4117cf1b2 100644 --- a/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md +++ b/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md @@ -6,15 +6,38 @@ This document explains how to move from the original v1 core contract (`add-tilelang-dsl-core-foundation` + `add-tilelang-dsl-authoring-vpto-lowering`) to the matcher and advanced-surface capability implemented by -`extend-tilelang-dsl-matcher-and-advanced-surface`. +`extend-tilelang-dsl-matcher-and-advanced-surface`, and how to adopt the +template-slot authoring model added by +`extend-tilelang-dsl-template-op-slots`. It focuses on: - matcher-driven kernel selection +- migration from explicit real `pto.*` calls to template-slot authoring - implicit vecscope inference - raw pointer / low-level DMA authoring - advanced vector-family coverage that is implemented today - the remaining deferred boundary +## Current Tier Snapshot + +This migration note lives at the boundary between the basic starter path and +the broader expert surface. The public-surface groups discussed across the +guide, this migration note, and the support matrix currently map to tiers as +follows: + +| Surface Family | Tier | Migration Meaning | +|----------------|------|-------------------| +| `TensorView` | `basic` | Keep as the default GM-facing operand model. | +| `Tile` | `basic` | Keep as the default UB-facing compute tile model. | +| `dma_load` / `dma_store` | `basic` | Keep as the preferred high-level GM <-> UB path. | +| Base vector ops such as `make_mask`, `vlds`, `vsts`, `vadd`, `vmuls` | `basic` | Keep as the default compute skeleton before dropping to expert surfaces. | +| Raw pointer family such as `ptr(...)`, `castptr`, `addptr` | `advanced` | Use when moving from the starter path to expert pointer-form authoring. | +| Low-level DMA family such as `copy_*` and `set_loop*_stride_*` / `set_loop_size_*` | `advanced` | Use only when the high-level DMA surface is not sufficient. | +| Tile helper family such as `tile.slice(...)`, `tile.reshape(...)`, `tile.as_ptr()`, `tile_from_ptr(...)`, `tile_with_strides(...)`, `tile_config(...)` | `advanced` | Treat as partial or evolving surface rather than part of the basic starter path. | + +For the exact tier source of truth, see +`tilelang-dsl/python/tilelang_dsl/support_matrix.py`. + ## What Changed The original v1 core profile assumed: @@ -28,10 +51,14 @@ The current package now adds: - `KernelRegistry` - `pto.select_kernel(...)` - multi-signature `dtypes` +- multi-op descriptors via `op=` / `ops=[...]` - `AnyFloat`, `AnyInt`, `AnyType`, `AnyMask` - `TypeVar(...)` - `constraints=[...]` - `priority=` +- descriptor-bound `selected_op` for multi-op matches +- `templates={...}` +- `pto.tpl("slot", ...)` - implicit vecscope inference in `advanced=True` kernels - `ptr(...)` / `PointerType` - `castptr`, `addptr` @@ -62,7 +89,7 @@ the concrete specialization: (pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt), ], - constraints=[lambda attrs: attrs.get("enabled", True)], + constraints=[lambda enabled=True: enabled], priority=10, ) def kernel(inp: pto.TensorView, out: pto.Tile): @@ -81,6 +108,27 @@ Matcher rules in the implemented package: - selection order is `target -> op -> dtypes -> constraints -> priority` - highest-priority ties raise an explicit error - `TypeVar` only binds within one signature +- `op=` and `ops=[...]` are mutually exclusive +- `ops=[...]` only widens the descriptor's matcher set; callers still query + `pto.select_kernel(...)` with one concrete op +- when a multi-op descriptor matches, the returned descriptor is already bound + to one concrete `selected_op` + +For explicit single-op kernels that already map 1:1 to one real PTO op, you +do not need to migrate anything. Keep `op="..."` and keep authoring explicit +real `pto.*` calls in the kernel body. + +For shared-family kernels, the matcher migration usually comes first: +- change one descriptor from `op="..."` to `ops=[...]` +- continue selecting with concrete query ops +- rely on `selected_op` only as internal compile-time context for later + template-slot expansion + +Materialization boundary for multi-op descriptors: +- a descriptor registered with `ops=[...]` cannot directly `mlir_text()`, + `mlir_module()`, `verify()`, or `emit(path)` before selection +- call `pto.select_kernel(...)` first so the returned descriptor carries one + concrete `selected_op` ## Vecscope Migration @@ -118,6 +166,117 @@ Inference boundaries in the implemented package: Use `pto.strict_vecscope` when you need a deterministic region ABI or do not want inference to merge adjacent vector chains. +## Template-Slot Migration + +Template slots are the migration path for kernels whose control-flow, +load/store pattern, masks, and surrounding vector scaffolding stay the same +while one or a few real `pto.*` ops differ by concrete matcher op. + +### When To Keep Explicit Real `pto.*` Calls + +Keep the original style when: +- the kernel only serves one concrete op +- different ops need structurally different loops, masks, DMA scheduling, or + control flow +- the body is clearer when the real op is written directly +- there is no duplication pressure worth introducing `ops=[...]` and + `templates={...}` + +Example: + +```python +@pto.vkernel(op="tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) +def add_kernel(lhs: pto.TensorView, rhs: pto.TensorView, out: pto.Tile): + with pto.strict_vecscope(out, lhs, 0, 256, 64) as (_, _, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + lhs_v = pto.vlds(lhs, lane) + rhs_v = pto.vlds(rhs, lane) + out_v = pto.vadd(lhs_v, rhs_v, mask) + pto.vsts(out_v, out, lane, mask) +``` + +### When To Migrate To Template Slots + +Migrate when: +- several concrete ops share the same loop skeleton +- only the core vector op or a small number of real `pto.*` calls differ +- you want one descriptor and one kernel body to cover a whole op family +- you still want deterministic compile-time expansion, not runtime dispatch + +Recommended pattern: + +```python +@pto.vkernel( + ops=["tadd", "tsub", "tmul", "tdiv"], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + }, +) +def arithmetic_kernel(lhs: pto.TensorView, rhs: pto.TensorView, out: pto.Tile): + with pto.strict_vecscope(out, lhs, 0, 256, 64) as (_, _, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + lhs_v = pto.vlds(lhs, lane) + rhs_v = pto.vlds(rhs, lane) + out_v = pto.tpl("core", lhs_v, rhs_v, mask) + pto.vsts(out_v, out, lane, mask) + +selected = pto.select_kernel( + "a5", + "tmul", + (pto.f32, pto.f32, pto.f32), +) +``` + +In this model: +- `ops=[...]` defines which concrete ops the descriptor may match +- `pto.select_kernel(...)` still receives one concrete op such as `"tmul"` +- the selected descriptor carries `selected_op="tmul"` +- frontend expansion rewrites `pto.tpl("core", ...)` to the real call for + that selected concrete op, such as `pto.vmul(...)` + +The example in +`tilelang-dsl/examples/v1_template_slot_multiop_demo.py` shows this shared +kernel-body migration pattern end to end. + +### Migration Checklist + +When converting an existing family of explicit kernels to template slots: +1. Confirm the kernels only differ in a few real `pto.*` calls. +2. Keep one shared body and move the op differences into + `templates={...}` slot mappings. +3. Replace the differing real calls with `pto.tpl("slot", ...)`. +4. Switch the descriptor from `op="..."` to `ops=[...]`. +5. Ensure all materialization goes through `pto.select_kernel(...)` so the + descriptor is bound to one concrete `selected_op`. + +### Boundaries And Non-Goals + +Template-slot migration is intentionally narrow: +- `pto.tpl("slot", ...)` is a compile-time placeholder, not a runtime helper +- the first argument must be a string literal slot name +- template mappings live in descriptor metadata, not in kernel-body Python + dictionaries +- callable-based dispatch such as `table["core"](...)` or `resolver(...)` + remains outside the DSL contract +- unresolved multi-op descriptors must not materialize before + `pto.select_kernel(...)` binds one concrete `selected_op` + +Template slots are not the right abstraction when: +- the kernels differ in control-flow structure, not just in a few ops +- one op variant needs extra DMA, sync, or pointer logic that the others do + not share +- you need arbitrary Python-level dispatch or dynamic selection inside the + kernel body + ## Pointer And DMA Migration ### New Pointer Surface diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md new file mode 100644 index 000000000..5405e1453 --- /dev/null +++ b/tilelang-dsl/docs/unsupported-features.md @@ -0,0 +1,249 @@ +# TileLang DSL Unsupported And Partial Features + +## Scope + +This document records the gap between the broad language surface described in +`tilelang-dsl-guide.md` and what the current standalone `tilelang_dsl` package +actually implements under: + +- `tilelang-dsl/python/tilelang_dsl/` +- `tilelang-dsl/tests/` + +Use this file as a quick "what is still missing" index. For the implemented +contract, treat these as the source-of-truth companion documents: + +- `v1-surface.md` +- `v1-lowering.md` +- `matcher-and-advanced-surface-migration.md` + +## Status Labels + +- `Unsupported`: the public surface is documented but not exported or not + accepted by the frontend at all. +- `Partial`: the concept exists, but only a narrower subset works in the + current implementation. + +## Unsupported Features + +### Missing Public Type Constructors And Aliases + +The guide documents a richer type-construction surface that is not exported by +the current package: + +- `pto.tile(...)` +- `BLayout`, `SLayout`, `PadValue` +- `SyncOpType` + +Today, the public package exports annotation markers (`TensorView`, `Tile`), +scalar dtypes, `ptr(...)`, `PadMode`, `TileConfig`, matcher APIs, and a small +set of enums. The list above covers the remaining missing public constructors +and aliases from the guide. + +### Missing Tile/Tensor Utility Methods + +The following guide surfaces are not implemented as public APIs: + +- `tile.slice(...)` +- `tile.reshape(...)` +- `pto.tile_from_ptr(...)` +- `pto.tile_with_strides(...)` +- `pto.tile_config(...)` + +### Missing Sync/Buffer Control Ops + +These documented surfaces are not accepted by the current frontend: + +- `pto.get_buf(...)` +- `pto.rls_buf(...)` + + +### Missing Vector Load/Store Families + +Only `pto.vlds(...)` and `pto.vsts(...)` are implemented from the guide's +load/store families. The following documented ops are still unsupported: + +- `pto.vldas(...)` +- `pto.vldus(...)` +- `pto.vldx2(...)` +- `pto.vsld(...)` +- `pto.psts(...)` +- `pto.vsst(...)` +- `pto.vstx2(...)` +- `pto.vsta(...)` +- `pto.pstu(...)` +- `pto.vstu(...)` +- `pto.vstus(...)` +- `pto.vstur(...)` + +### Missing Direct Predicate Constructor/Compare APIs + +The implementation expects users to go through `pto.make_mask(...)` rather than +call the underlying mask ops directly. These guide-documented APIs are not part +of the supported authoring surface: + +- `pto.pset_b8(...)`, `pto.pset_b16(...)`, `pto.pset_b32(...)` +- `pto.pge_b8(...)`, `pto.pge_b16(...)`, `pto.pge_b32(...)` +- `pto.plt_b8(...)`, `pto.plt_b16(...)`, `pto.plt_b32(...)` + +### Missing Extended Vector Arithmetic Families + +The previously missing `13-vector-arithmetic-operations.md` gap list is now +implemented in the current package surface (including fused ops, broadcast/index +generation, reduction-flavored ops, and rearrangement/sort groups). + +### Missing Predicate Rearrangement Shorthands + +The guide documents mask-specific rearrangement helpers that are not currently +implemented: + +- `pto.pdintlv_b8(...)` +- `pto.pintlv_b16(...)` + +### Deferred Surface + +`pto.vreduce(...)` is still explicitly deferred and remains rejected even in +`advanced=True` kernels. + +## Partial Features + +### Scalar Constants And Literal Typing + +The guide describes automatic `float -> pto.f32` literal typing. + +Literal support currently includes: + +- `bool` +- `int` +- `str` +- `None` + +### TensorView Attribute Model + +`TensorView` currently supports only a narrow attribute subset: + +- `shape` +- `strides` +- `element_type` +- `valid_shape` + +The following documented attributes are not implemented: + +- `offset` + +In practice, `TensorView` is now modeled as a fixed 5D GM view in the current +profile, but the DMA-oriented slicing/lowering path remains narrower than the +full guide: + +- `shape` / `valid_shape` exposure follows the 5D descriptor +- `strides` lower through hidden stride parameters carried alongside TensorView shape +- fewer written slice axes are right-aligned onto the trailing physical axes +- DMA-oriented slicing/lowering still only accepts rank-2 TensorView slices + +### Tile Attribute Model + +`Tile` currently supports only a narrow attribute subset in semantic analysis: + +- `shape` +- `element_type` +- `valid_shape` + +The guide documents additional properties that are not currently supported: + +- `memory_space` +- `config` +- `rank` +- `num_elements` +- `valid_elements` +- `layout_descriptor` +- `strides` +- `byte_strides` +- `offset` + +### Tile Config Semantics + +`TileConfig` can be attached during specialization, but lowering does not yet +honor the rich layout/padding semantics described in the guide. The rendered +tile type is effectively fixed to a hard-coded baseline: + +- `blayout=row_major` +- `slayout=none_box` +- `fractal=512` +- `pad=0` + +So this is currently metadata storage rather than full behavioral support. + +### TensorView Slicing + +The guide presents general Python slicing with dynamic starts and strides. The +current stable DMA-oriented implementation is still a narrower 2D profile: + +- slice `stop` must be explicit on all dimensions +- slice `start` may be a compile-time constant or runtime index expression +- slice `step` must be a static positive integer +- dimension 0 may use `step > 1` +- dimension 1 must keep `step == 1` (current DMA restriction) + +Dynamic bounds are supported within those constraints. + + +### Tile Indexing Sugar + +Tile indexing sugar is partially implemented on the stable authoring path. + +Currently supported: + +- rank-1: `tile[start:]` +- rank-2: `tile[row, col:]` +- only for `pto.vlds(...)` and `pto.vsts(...)` + +Not currently supported from the guide's broader indexing model: + +- column-major syntax such as `tile[row_start:, col_index]` +- single-element syntax such as `tile[row, col]` and `tile[pos]` +- explicit slice `stop` +- stepped tile vector slices +- the guide's wider indexed op family (`vldas`, `vldus`, `vldx2`, + `vsld`, `psts`, `vsst`, `vstx2`, `vsta`) + +### Control-Flow Result Merging + +The frontend does analyze loop-carried values and merged `if` results, but +lowering still has a hard limit: + +- at most one loop-carried binding per loop +- at most one merged `if`/`else` binding per conditional + +So the language feature exists conceptually, but multi-value merge cases are +not fully lowered yet. + +### Tile Profile Breadth + +The guide discusses Tile memory spaces in more general terms, but bare Tile +specialization still only accepts: + +- rank-1 or rank-2 Tiles +- static physical shape +- `MemorySpace.UB` + +So GM Tiles and more general profiles are not supported yet. + +## Currently Implemented Core Surface + +For quick orientation, the current package head is strongest in these areas: + +- matcher-driven kernel selection +- `templates={...}` and `pto.tpl(...)` +- `ptr(...)`, `pto.castptr(...)`, `pto.addptr(...)` +- low-level DMA config/copy ops +- `pto.make_mask(...)` +- `pto.vlds(...)` and `pto.vsts(...)` +- base unary/binary/vector-scalar vector ops +- advanced compare/select/carry/rearrangement families + +If you need the exact supported boundary for implementation work, prefer the +source files and tests over the broader guide: + +- `tilelang-dsl/python/tilelang_dsl/support_matrix.py` +- `tilelang-dsl/python/tilelang_dsl/semantic.py` +- `tilelang-dsl/python/tilelang_dsl/lowering.py` +- `tilelang-dsl/tests/test_tilelang_dsl_v1.py` diff --git a/tilelang-dsl/docs/user_guide/01-introduction.md b/tilelang-dsl/docs/user_guide/01-introduction.md new file mode 100644 index 000000000..04f845a11 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/01-introduction.md @@ -0,0 +1,48 @@ +# TileLang Python DSL Guide + +The TileLang Python DSL provides a high-level, Pythonic interface for authoring vector compute kernels targeting the Ascend NPU hardware. This guide is intended for library developers and performance engineers who need to write efficient, hardware-aware kernels using the PTO micro instruction set. + +The DSL is designed to generate MLIR function libraries rather than direct binary executables. These MLIR libraries are intended to be consumed by other compilation frameworks that transform high-level tile semantics into low-level vector operations. This enables library developers to focus on hardware-aware kernel authoring while relying on upstream compilers for tile-level optimizations and code generation. + +## Language Tier + +The DSL surface is organized into multiple maturity tiers, reflecting the stability and intended use of different language features. As the design evolves, the basic authoring path is being explicitly separated from more advanced surfaces. Refer to the following table when reading this guide: + +| Surface Family | Tier | Usage Guidance | +|----------------|------|----------------| +| `TensorView` | `basic` | Default GM-facing data model for starter kernels. | +| `Tile` | `basic` | Default UB-facing compute tile for starter kernels. | +| Base vector ops (`make_mask`, `vlds`, `vsts`, `vadd`, `vmuls`, etc.) | `basic` | Default compute skeleton for starter kernels. | +| `strict_vecscope` | `advanced` | Explicit vector-scope management for expert authoring. | +| Raw pointer family (`ptr(...)`, `castptr`, `addptr`) | `advanced` | For expert authoring and migration; not required for Quick Start. | +| DMA family (`copy_*`, `set_loop*_stride_*`, `set_loop_size_*`) | `advanced` | Direct DMA engine control for expert authoring. | +| Tile helper family (`tile.slice(...)`, `tile.reshape(...)`, `tile.as_ptr()`, `tile_from_ptr(...)`, `tile_with_strides(...)`, `tile_config(...)`) | `advanced` | Partial or evolving surface; not the default entry point. | + +For the authoritative tier classification, consult `tilelang-dsl/python/tilelang_dsl/support_matrix.py`. For known implementation gaps, refer to `tilelang-dsl/docs/unsupported-features.md`. + +### Basic vs Advanced Authoring Modes + +The TileLang DSL provides two distinct authoring modes: + +**Basic Mode (default)** +- Uses **Tile element/slice semantics** for buffer access +- Direct tile indexing syntax: `tile[start:]`, `tile[row, col:]` +- Vector operations use element-indexing syntax: `pto.vlds(tile[row, col:])`, `pto.vsts(vec, tile[start:], mask)` +- No pointer arithmetic or explicit offset calculations +- Suitable for most kernel authoring with high-level abstractions + +**Advanced Mode (`advanced=True` in `@pto.vkernel`)** +- Uses **raw pointer semantics** for explicit memory management +- Direct pointer operations correspond to `pto.ptr` types in MLIR +- Explicit pointer arithmetic: `ptr(...)`, `castptr`, `addptr` +- Manual DMA engine control with low-level copy operations +- Requires explicit buffer management and pointer arithmetic +- Intended for expert users and performance-critical optimizations + +**Key Differences** +- **Basic mode**: Uses tile element-indexing syntax (`tile[row, col:]`, `tile[start:]`) for vector operations +- **Advanced mode**: Uses pointer byte-offset syntax (`pto.vlds(buf: ptr, offset)`) for vector operations +- Tile slices in basic mode correspond to MLIR `memref` types +- Raw pointers in advanced mode correspond to MLIR `pto.ptr` types +- No automatic conversion between tile and pointer semantics - choose the appropriate syntax for your authoring mode + diff --git a/tilelang-dsl/docs/user_guide/02-quick-start.md b/tilelang-dsl/docs/user_guide/02-quick-start.md new file mode 100644 index 000000000..26b0ba58b --- /dev/null +++ b/tilelang-dsl/docs/user_guide/02-quick-start.md @@ -0,0 +1,78 @@ +## Quick Start + +**Note on mask pattern enums**: For brevity, examples in this guide use `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). You can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +TileLang DSL provides the following core constructs for kernel authoring: + +- `TensorView` – Access global memory (GM) tensors +- `Tile` – Local computation buffers in unified buffer (UB) +- Base vector operations (`make_mask`, `vlds`, `vmuls`, `vadd`, `vsts`) – Perform vector computations + +A typical kernel follows the GM → UB → vector compute → GM pattern: + +```python +import tilelang_dsl as pto + +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32)]) +def tile_scale( + input_tensor: pto.TensorView, + output_tensor: pto.TensorView, + work_tile: pto.Tile, + scale_factor: pto.f32, +): + dim0 = 4 + dim1 = 16 + + # Stage one GM tile into UB. + # GM -> UB data movement (implementation detail) + + # Run vector compute over the UB tile using tile indexing sugar. + for i in range(0, dim0): + mask = pto.make_mask(pto.f32, PAT.ALL) + vec = pto.vlds(work_tile[i, 0:]) + scaled = pto.vmuls(vec, scale_factor, mask) + pto.vsts(scaled, work_tile[i, 0:], mask) + + # Write the UB result back to GM. + # UB -> GM data movement (implementation detail) +``` + +The example illustrates the key components of a TileLang kernel: + +1. **`TensorView` parameters** – Access global memory tensors +2. **`Tile` parameters** – Local computation buffers in unified buffer (UB) +3. **Base vector operations** (`make_mask`, `vlds`, `vmuls`, `vadd`, `vsts`) – Perform vector computations + +Here is a second example with two inputs and one output: + +```python +@pto.vkernel( + target="a5", + op="elementwise_add", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)], +) +def elementwise_add( + lhs_gm: pto.TensorView, + rhs_gm: pto.TensorView, + out_gm: pto.TensorView, + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, +): + dim0 = 4 + dim1 = 16 + + # GM -> UB data movement (implementation detail) + + for lane in range(0, 256, 64): + mask = pto.make_mask(pto.f32, PAT.ALL) + lhs_vec = pto.vlds(lhs_tile, lane) + rhs_vec = pto.vlds(rhs_tile, lane) + summed = pto.vadd(lhs_vec, rhs_vec, mask) + pto.vsts(summed, dst_tile, lane, mask) + + # UB -> GM data movement (implementation detail) +``` + +Both examples follow the same fundamental pattern: load data from global memory into local tiles, perform vector operations, and store results back. The compiler automatically infers vector-scope boundaries for the base vector operations. The `Tile` parameters are specialized to concrete shapes during compilation. Later sections cover advanced features such as matchers, template slots, raw pointer operations, and explicit scope management with `strict_vecscope`. + diff --git a/tilelang-dsl/docs/user_guide/03-kernel-declaration.md b/tilelang-dsl/docs/user_guide/03-kernel-declaration.md new file mode 100644 index 000000000..3044ef005 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/03-kernel-declaration.md @@ -0,0 +1,313 @@ +## Core Concepts + +### Kernel Declaration + +Kernels are defined using the `@pto.vkernel` decorator with enhanced matching capabilities for PTO operations. The decorator specifies matching criteria for target architecture, operation type, data types, and additional constraints, along with a priority for disambiguation when multiple kernels match. + +#### Basic Syntax + +```python +@pto.vkernel( + target="a5", # Target architecture + op="pto.matmul ins(a, b) -> outs(c)", # PTO op + operand schema + dtypes=[(pto.f16, pto.f16, pto.f32)], # Type signatures + constraints=[ # Additional constraints + lambda a, b: a.shape[1] == b.shape[0], + lambda batch=1: batch >= 1, + ], + priority=100 # Priority for selection +) +def matmul_fallback(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # kernel implementation +``` + +#### Decorator Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | +| `op` | `str` | No* | PTO operation matcher. Preferred form is schema mode: `"pto.op_name ins(in0, in1, ...) -> outs(out0, out1, ...)"`. Legacy bare-op form (`"pto.op_name"`) is still accepted for compatibility. **Mutually exclusive with `ops`**. | +| `ops` | `List[str]` | No* | List of PTO operation names to match. **Mutually exclusive with `op`**. Use this when one descriptor should match multiple concrete ops (schema mode is currently only supported in `op`). | +| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands (inputs and outputs) in order. | +| `templates` | `Dict[str, Dict[str, str]]` | No | Static template-slot mappings. Each slot maps concrete matcher ops to real `pto.*` op names. Required when the kernel body uses `pto.tpl(...)`. | +| `constraints` | `List[Callable[..., bool]]` | No | Additional selection-time predicates. Constraint arguments bind by name to kernel parameter proxy objects or `context_attrs` keys. Default: empty list. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Higher values have higher priority. Default: `0`. | +| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | +| `advanced` | `bool` | No | Enable advanced-tier DSL surfaces (for example `strict_vecscope`, raw pointer family, and low-level DMA family). Implicit vecscope inference is available in both modes and runs only when no explicit `with pto.vecscope():` is present. Default: `False`. | + +#### Operation Schema in `op` (ins/outs) + +`op` supports a schema string that declares how kernel parameter names map to PTO op operands: + +```python +op="pto.tadds ins(src, scalar) -> outs(dst)" +``` + +Schema form: + +```text + ins(, , ...) -> outs(, , ...) +``` + +Rules: + +1. `ins(...)` and `outs(...)` are both required in schema mode. +2. Names in `ins` and `outs` must be valid, unique Python identifiers. +3. The decorated function parameter list must exactly match `ins + outs` by both count and name. +4. MLIR function argument ordering is defined by schema order (`ins` first, then `outs`). +5. Constraint binding keeps using parameter names; schema mode makes these names explicit and stable. +6. Schema mode applies to `op=...` (single matcher op). `ops=[...]` remains bare-op matching. + +Example: + +```python +@pto.vkernel( + target="a5", + op="pto.tadds ins(src, scalar) -> outs(dst)", + dtypes=[(pto.f32, pto.f32, pto.f32)], +) +def template_tadds(src: pto.Tile, scalar: pto.f32, dst: pto.Tile): + return None +``` + +If names or order do not match, descriptor construction fails early with a schema mismatch error. + + +#### Type Matching Rules + +The `dtypes` parameter supports flexible type matching: + +1. **Concrete Types**: Exact type matches using DSL scalar types: + - `pto.f16`, `pto.f32`, `pto.bf16` + - `pto.i8`, `pto.i16`, `pto.i32`, `pto.i64` + - `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` + +2. **Type Wildcards**: Generic type patterns: + - `pto.AnyFloat`: Matches any floating-point type (`f16`, `bf16`, `f32`) + - `pto.AnyInt`: Matches any integer type (`i8`, `i16`, `i32`, `i64`) + - `pto.AnyType`: Matches any scalar type + - `pto.AnyMask`: Matches any mask type (`mask_b8`, `mask_b16`, `mask_b32`) + +3. **Type Variables**: Named type variables that enforce consistency within a signature: + ```python + T = pto.TypeVar('T') # Define a type variable + + @pto.vkernel( + target="a5", + op="elementwise", + dtypes=[(T, T, T)], # All three operands must have the same type + constraints=[] + ) + def elementwise_same_type(x: pto.Tile, y: pto.Tile, out: pto.Tile) -> None: + # x, y, and out must have identical element types + pass + ``` + +4. **Mixed Signatures**: Multiple type signatures for the same operation: + ```python + @pto.vkernel( + target="a5", + op="add", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), # Float addition + (pto.AnyInt, pto.AnyInt, pto.AnyInt) # Integer addition + ] + ) + def generic_add(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Supports both float and integer types + pass + ``` + +#### Constraint System + +Constraints are compile-time predicates that refine kernel selection. In the current implementation, each entry in `constraints=[...]` is a Python callable returning `True` or `False`. + +##### Predefined Constraints + +| Constraint | Description | +|------------|-------------| +| `k_dim_aligned_64` | K dimension is aligned to 64 elements (for matmul kernels). | +| `continuous_memory` | Operands reside in contiguous memory regions. | +| `requires_ub_memory` | Operation requires Unified Buffer memory (vs. Global Memory). | +| `tensor_rank(rank)` | Operand tensor has specified rank (e.g., `tensor_rank(2)` for 2D tensors). | +| `broadcastable` | Operands are broadcastable according to NumPy-style broadcasting rules. | +| `static_shape` | All tensor dimensions are known at compile time (no dynamic shapes). | + +##### Logical Constraint Combinators + +| Combinator | Description | Example | +|------------|-------------|---------| +| `AnyOf(c1, c2, ...)` | At least one of the constraints must be satisfied. | `AnyOf(k_dim_aligned_64, continuous_memory)` | +| `AllOf(c1, c2, ...)` | All constraints must be satisfied. | `AllOf(tensor_rank(2), static_shape)` | +| `Not(c)` | The constraint must not be satisfied. | `Not(requires_ub_memory)` | + +##### Custom Constraints + +Users can define custom constraints using predicate functions: + +```python +# Define a custom constraint that consumes one context attr by name. +def large_batch(min_batch: int): + return lambda batch=0: batch >= min_batch + +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[large_batch(1024)] +) +def large_batch_matmul(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized for large batch sizes + pass +``` + +Constraint callables bind by parameter name. + +- Kernel parameter names such as `src`, `dst`, `a`, `b` receive lightweight proxy objects, so constraints can use direct expressions like `src.shape[0] <= dst.shape[0]`. +- Extra `context_attrs` passed to `pto.select_kernel(...)` bind by key name, for example `batch`, `enabled`, or `expected_rows`. + +##### Parameter Proxy Objects + +When a constraint argument name matches a kernel parameter name, the callable receives a lightweight proxy object rather than raw Python data. + +- For `TensorView` parameters, the proxy exposes `rank`, `shape`, `strides`, `dtype`, and `memory_space`. +- For `Tile` parameters, the proxy exposes `rank`, `shape`, `valid_shape`, `dtype`, `memory_space`, and `config`. +- `shape`, `strides`, and `valid_shape` support index access such as `src.shape[0]` or `dst.valid_shape[1]`. +- Missing or not-yet-known metadata evaluates as "unknown", so comparisons conservatively pass rather than failing early. + +Example: + +```python +def tload_preconditions(src, dst): + logical_rows = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[3] + logical_cols = src.shape[4] + return ( + src.rank == 5 + and src.strides[4] == 1 + and dst.valid_shape[0] <= logical_rows + and dst.valid_shape[1] <= logical_cols + and logical_rows <= dst.shape[0] + and logical_cols <= dst.shape[1] + ) + +@pto.vkernel( + target="a5", + op="pto.tload", + dtypes=[(pto.f32, pto.f32)], + constraints=[tload_preconditions], +) +def template_tload(src: pto.TensorView, dst: pto.Tile): + return None +``` + +This is the recommended constraint style for current TileLang DSL head. + +#### Kernel Selection Mechanism + +When a PTO operation needs implementation, the system performs the following matching process: + +1. **Target Filtering**: Select kernels with matching `target` architecture. +2. **Operation Filtering**: Select kernels whose matcher metadata covers the concrete query op: + - `op="foo"` requires exact match + - `op="foo ins(...) -> outs(...)"` still matches by op name `foo`; `ins/outs` additionally defines parameter naming/order contract for descriptor validation and materialization + - `ops=[...]` requires the concrete query op to appear in that list +3. **Type Matching**: For each kernel's `dtypes` list, check if any signature matches the operation's operand types: + - Concrete types must match exactly. + - Wildcard types match according to their category. + - Type variables must be consistent within the signature. +4. **Constraint Validation**: For each matching kernel, evaluate all `constraints`. If any constraint fails, the kernel is rejected. +5. **Priority Selection**: From the remaining kernels, select the one with the highest `priority` value. +6. **Fallback**: If no kernel matches, compilation fails with an error. + +For multi-op descriptors selected through `ops=[...]`, `pto.select_kernel(...)` +also binds the concrete query op before materialization. This bound +`selected_op` is what template-slot expansion uses later. + +The package also exposes explicit selection utilities: + +```python +registry = pto.KernelRegistry() +registry.register(my_kernel) + +selected = pto.select_kernel( + "a5", + "matmul", + (pto.f16, pto.f16, pto.f32), + context_attrs={"k_aligned": True}, + registry=registry, +) +``` + +#### Examples + +##### Matmul with Multiple Implementations + +```python +# High-performance kernel for aligned K dimension +def k_aligned_64(k=0): + return k % 64 == 0 + +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[k_aligned_64], + priority=200 +) +def matmul_aligned_k(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized implementation for aligned K + pass + +# General-purpose fallback +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], + constraints=[], + priority=100 +) +def matmul_general(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Generic implementation + pass +``` + +##### Elementwise Operation with Type Polymorphism + +```python +def same_shape(a, b, out): + return a.shape[0] == out.shape[0] and b.shape[0] == out.shape[0] + +@pto.vkernel( + target="a5", + op="pto.add ins(a, b) -> outs(out)", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), + (pto.AnyInt, pto.AnyInt, pto.AnyInt) + ], + constraints=[same_shape] +) +def polymorphic_add(a: pto.Tile, b: pto.Tile, out: pto.Tile) -> None: + # Single implementation handles both float and integer types + dtype = a.element_type + all_mask = pto.make_mask(dtype, PAT.ALL) + # ... implementation using generic vector operations + pass +``` + +##### Constrained Convolution Kernel + +```python +def prefer_static_nhwc(src, weight): + return src.rank == 4 and weight.rank == 4 + +@pto.vkernel( + target="a5", + op="pto.conv2d ins(input, filter) -> outs(output)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[prefer_static_nhwc], + priority=150 +) +def conv2d_nhwc_f16_f32(input: pto.Tile, filter: pto.Tile, output: pto.Tile) -> None: + # Optimized for NHWC layout with static shapes + pass +``` diff --git a/tilelang-dsl/docs/user_guide/04-template-kernels.md b/tilelang-dsl/docs/user_guide/04-template-kernels.md new file mode 100644 index 000000000..a9bffd3ab --- /dev/null +++ b/tilelang-dsl/docs/user_guide/04-template-kernels.md @@ -0,0 +1,332 @@ +### Template-based Kernel Authoring + +For operations that share similar computation patterns but differ in their core vector operations, the DSL supports template-based kernel authoring. This allows a single kernel implementation to serve multiple related operations through parameterized templates. + +#### Multi-operation Kernels with `ops` Parameter + +Instead of specifying a single `op` parameter, you can provide an `ops` list to match multiple operations: + +```python +@pto.vkernel( + target="a5", + ops=["tadd", "tsub", "tmul", "tdiv"], # List of operations + dtypes=[(T, T, T)], # Type signature using type variable + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + } +) +def elementwise_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + elems_per_vreg = pto.elements_per_vreg(dtype) # Number of elements per vector register + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, elems_per_vreg): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) # Template dispatch + pto.vsts(out, dst[row, col:], mask) +``` + +`op` and `ops` are mutually exclusive, and exactly one of them must be +provided. `ops=[...]` only widens the matcher set; callers still use +`pto.select_kernel(target, concrete_op, operand_types, ...)` with a concrete +PTO op such as `"tadd"` or `"tmul"`. + +#### Template System + +The template system consists of three components: + +1. **`templates` parameter**: A dictionary mapping template names to operation-specific implementations +2. **`pto.tpl()` function**: A compile-time placeholder that resolves to the appropriate implementation for the currently selected concrete op +3. **`ops` parameter**: Replaces the singular `op` parameter for multi-operation kernels + +##### Template Definition + +Templates are defined in the `templates` parameter of `@pto.vkernel`. Each template is a dictionary mapping operation names to implementation strings: + +```python +templates={ + "template_name": { + "op1": "implementation_for_op1", + "op2": "implementation_for_op2", + # ... + }, + "another_template": { + "op1": "different_implementation_for_op1", + # ... + } +} +``` + +Template-slot metadata is static and validated when the descriptor is +registered: + +- slot names must be non-empty strings +- mapping keys must be concrete ops covered by the descriptor matcher set +- mapping values must be supported real `pto.*` op names + +The implementation strings are typically vector operation names such as +`"vadd"`, `"vsub"`, `"vmul"`, and `"vdiv"`, which are resolved during kernel +expansion. + +##### Template Usage with `pto.tpl()` + +The `pto.tpl()` operation enables template dispatch for multi-operation kernels, allowing code reuse across related operations through compile-time substitution. + +#### `pto.tpl(template_name: str, *args) -> Any` + +**Description**: Template dispatch operation for multi-operation kernels. Resolves to different implementations based on the current operation being expanded. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `template_name` | `str` | Name of the template to dispatch | +| `*args` | `Any` | Positional arguments passed unchanged to the resolved real implementation | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `Any` | Result of the template implementation | + +**Behavior**: +- Only valid inside kernels decorated with `@pto.vkernel` that have a `templates` parameter +- The first argument must be a string literal template-slot name +- During kernel expansion for a specific operation `op_name`, `pto.tpl("template_name", ...)` is replaced with the implementation specified in `templates["template_name"]["op_name"]` +- The replacement is a direct compile-time substitution; positional arguments are passed unchanged +- Template implementations are typically string names of vector operations (e.g., `"vadd"`, `"vsub"`) +- `pto.select_kernel(...)` must bind a concrete op before template expansion can happen +- Python dict lookup, callable values, lambdas, and other runtime dispatch patterns are not part of the supported kernel-body surface + +**Example**: +```python +@pto.vkernel( + ops=["tadd", "tsub"], + dtypes=[(T, T, T)], + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + } + } +) +def elementwise_kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # ... load vectors + result = pto.tpl("core", lhs, rhs, mask) # Expands to vadd for tadd, vsub for tsub + # ... store result +``` + +**Constraints**: +- Template names must be defined in the `templates` parameter of the `@pto.vkernel` decorator +- When a kernel body uses `pto.tpl("slot", ...)`, that slot must define an implementation for the currently selected concrete op +- Template implementations must be valid operation names in the DSL + +#### Decorator Parameters Update + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | +| `op` | `str` | No* | Name of the PTO operation to match. **Mutually exclusive with `ops`**. | +| `ops` | `List[str]` | No* | List of PTO operation names to match. **Mutually exclusive with `op`**. | +| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands. | +| `templates` | `Dict[str, Dict[str, str]]` | No | Static slot mappings from concrete matcher ops to real `pto.*` op names. Required when the kernel body uses `pto.tpl(...)`. | +| `constraints` | `List[Constraint]` | No | Additional constraints that must be satisfied for kernel selection. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Default: `0`. | +| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | +| `advanced` | `bool` | No | Enable advanced-tier DSL surfaces (for example `strict_vecscope`, raw pointer family, and low-level DMA family). Implicit vecscope inference is mode-independent and runs only when no explicit `with pto.vecscope():` is present. Default: `False`. | + +**Note**: +- Either `op` or `ops` must be provided, but not both. +- `templates` is only needed when the kernel body uses `pto.tpl(...)`. +- `pto.select_kernel(...)` still queries with a concrete op even for `ops=[...]` descriptors. + +#### Advanced Template Patterns + +##### Multiple Templates per Kernel + +A kernel can define multiple templates for different aspects of the computation: + +```python +@pto.vkernel( + target="a5", + ops=["tadd_relu", "tsub_relu", "tadd_abs", "tsub_abs"], + dtypes=[(T, T, T)], + templates={ + "arithmetic": { + "tadd_relu": "vadd", + "tsub_relu": "vsub", + "tadd_abs": "vadd", + "tsub_abs": "vsub", + }, + "postprocess": { + "tadd_relu": "vrelu", + "tsub_relu": "vrelu", # Same activation for both + "tadd_abs": "vabs", + "tsub_abs": "vabs", + } + } +) +def elementwise_with_postprocess(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # ... load vectors + arith_result = pto.tpl("arithmetic", lhs, rhs, mask) + postprocessed = pto.tpl("postprocess", arith_result, mask) + # ... store result +``` + +##### Compile-time Substitution Model + +Template-slot expansion happens before semantic checking and lowering: + +- `pto.select_kernel(...)` first binds a concrete op such as `"tadd"` +- the frontend then resolves `pto.tpl("core", ...)` using `templates["core"]["tadd"]` +- the placeholder is rewritten to a real `pto.*` call before semantic analysis +- diagnostics for unknown slots, missing mappings, or unsupported resolved surfaces are raised before any VPTO IR is generated + +#### Type Variables in Template Kernels + +Template kernels often use type variables to enforce type consistency: + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + target="a5", + ops=["tadd", "tsub"], + dtypes=[(T, T, T)], # All three operands share type T + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + } + } +) +def typed_elementwise(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # Type variable T ensures all tiles have same element type + dtype = dst.element_type # This is type T + # ... implementation +``` + +#### Selection Mechanism for Template Kernels + +When a PTO operation matches a template kernel: +1. The system selects the descriptor based on `op` exact match or `ops` list inclusion. +2. `pto.select_kernel(...)` binds the concrete query op as the descriptor's `selected_op`. +3. During frontend expansion, `pto.tpl()` calls are resolved using that bound concrete op. +4. For operation `"op_name"`, template `"template_name"` resolves to `templates["template_name"]["op_name"]`. +5. The resolved string (e.g., `"vadd"`) is replaced with the corresponding real DSL operation before semantic analysis and lowering. + +#### Example: Unified Arithmetic Kernel + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + ops=["tadd", "tsub", "tmul", "tdiv", "tmax", "tmin"], + dtypes=[(T, T, T)], + advanced=True, + templates={ + "arithmetic": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + "tmax": "vmax", + "tmin": "vmin", + } + } +) +def unified_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + """Single implementation for six arithmetic operations.""" + dtype = dst.element_type + rows, cols = dst.valid_shape + elems_per_vreg = pto.elements_per_vreg(dtype) # Number of elements per vector register + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, elems_per_vreg): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("arithmetic", lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) +``` + +#### Compile-time Specialization with `pto.constexpr` + +The `pto.constexpr` construct enables compile-time branching for kernel specialization, allowing different code paths to be selected based on static compile-time information. Unlike runtime conditionals that generate control flow, `pto.constexpr` branches are resolved during kernel descriptor materialization, with only the selected branch retained for lowering. + +**Syntax and Usage**: +```python +if pto.constexpr(condition): + # Branch taken if condition evaluates to True at compile time + ... +else: + # Branch taken if condition evaluates to False at compile time + ... +``` + +**Semantics**: +- The `condition` must be evaluable at compile time during kernel descriptor materialization. +- Only the selected branch is analyzed, semantically checked, and lowered to VPTO IR. +- The non-selected branch is discarded entirely and does not contribute to runtime control flow or value merging. +- If the condition cannot be proven static, descriptor materialization fails with a frontend diagnostic. + +**Comparison with Runtime Conditionals**: + +| Aspect | Runtime `if` | `pto.constexpr` | +|--------|--------------|-----------------| +| **Evaluation time** | Runtime | Compile-time (descriptor materialization) | +| **Control flow** | Generates `scf.if` with merge logic | No runtime control flow; branch eliminated | +| **Value merging** | Both branches must produce compatible values for merge | No value merging; only one branch exists after elimination | +| **Use case** | Dynamic decision making based on runtime values | Code generation specialization based on static parameters | + +**Typical Static Inputs**: +- Literal integers, booleans, and strings +- Data type symbols (`src.element_type`, `dst.element_type`) and comparisons derived from them +- Statically specialized `Tile.shape` and `Tile.valid_shape` values +- Frontend query helpers such as `pto.bytewidth(dtype)` and `pto.elements_per_vreg(dtype)` (which computes elements per vector register) + +**Constraints and Notes**: +- `TensorView.shape` and `TensorView.strides` may be represented by hidden kernel parameters rather than descriptor-time constants. They should not be assumed constexpr unless separately bound through specialization or other compile-time context. +- `pto.constexpr` is a frontend-only authoring construct; it does not correspond to any runtime VPTO instruction. + +**Guidelines**: +- Use `constraints=[...]` and `pto.select_kernel(...)` when specialization requires selecting an entirely different kernel descriptor. +- Use `pto.constexpr` when the kernel remains the same but internal regions require specialization based on compile-time parameters. + +**Example**: +```python +@pto.vkernel(target="a5", op="pto.trowsum") +def template_trowsum(dst: pto.Tile, src: pto.Tile, tmp: pto.Tile): + acc_dtype = tmp.element_type + dst_dtype = dst.element_type + dst_mask_1, _ = pto.make_mask(dst_dtype, 1) + + if pto.constexpr(acc_dtype != dst_dtype): + # Type conversion required + v_acc_casted = pto.vcvt(v_acc, dst_mask_1, dst_dtype) + pto.vsts(v_acc_casted, dst[row, 0:], dst_mask_1) + else: + # No conversion needed + pto.vsts(v_acc, dst[row, 0:], dst_mask_1) +``` + +### Value Model + +The DSL operates on symbolic values, not Python runtime values: +- **Constants**: Python literals that are typed to machine types +- **Operation results**: Values produced by DSL operations +- **Block arguments**: Values introduced by control flow structures + +### Memory Spaces + +The DSL supports different memory spaces: +- `MemorySpace.GM`: Global Memory +- `MemorySpace.UB`: Unified Buffer (local storage for vector computation) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md new file mode 100644 index 000000000..d7df0e2b8 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -0,0 +1,142 @@ +## Type System + +### Scalar Types + +| DSL Type | Description | Bit Width | +|----------|-------------|-----------| +| `pto.i1` | Boolean | 1 | +| `pto.i8` | 8-bit integer | 8 | +| `pto.i16` | 16-bit integer | 16 | +| `pto.i32` | 32-bit integer | 32 | +| `pto.i64` | 64-bit integer | 64 | +| `pto.f16` | Half precision float | 16 | +| `pto.bf16` | Brain float 16 | 16 | +| `pto.f32` | Single precision float | 32 | + +Python literals are automatically typed: +- `bool` → `pto.i1` +- `int` → Context-dependent (typically `pto.i32` or `pto.i64`) +- `float` → `pto.f32` + +For explicit typing, use type constructors: +```python +x = pto.i32(1024) # Explicit i32 constant +y: pto.i32 = 1024 # Type annotation +``` + +### Floating-Point Literal Forms + +`pto.f16(...)`, `pto.bf16(...)`, and `pto.f32(...)` accept multiple literal forms. + +```python +# Signed numeric literals +a = pto.f16(-1.5) +b = pto.bf16(+2.5) +c = pto.f32(-3.5) + +# Special floating-point values +pos_inf = pto.f32("inf") +neg_inf = pto.f32("-inf") +qnan = pto.f32("nan") + +# Bit-pattern form (hex string, interpreted by target dtype) +f16_neg_inf = pto.f16("0xFC00") +bf16_neg_inf = pto.bf16("0xFF80") +f32_neg_inf = pto.f32("0xFF800000") +``` + +Notes: +- Prefer dtype constructors for reduction seeds and boundary values (for example rowmax initialization). +- For float bit-pattern constants, pass a **string** hex literal to the matching dtype constructor. +- Avoid passing raw integer bit-patterns directly into vector broadcast/dup APIs when a floating vector is expected. +- `float(...)` function calls are not part of the TileLang DSL public call surface; use constructor forms above. + +### Vector Register Type + +Vector registers have fixed 256-byte width: + +```python +v_f32 = pto.vreg(pto.f32) # !pto.vreg<64xf32> +v_f16 = pto.vreg(pto.f16) # !pto.vreg<128xf16> +v_i8 = pto.vreg(pto.i8) # !pto.vreg<256xi8> +``` + +`pto.vreg(dtype)` only takes the element type. The frontend infers the element count automatically from the fixed 256-byte register width: + +- `pto.f32` → `!pto.vreg<64xf32>` +- `pto.f16` → `!pto.vreg<128xf16>` +- `pto.bf16` → `!pto.vreg<128xbf16>` +- `pto.i32` → `!pto.vreg<64xi32>` +- `pto.i16` → `!pto.vreg<128xi16>` +- `pto.i8` → `!pto.vreg<256xi8>` + +Constraint: `element_count × bitwidth(element_type) = 2048` + +Use `pto.elements_per_vreg(dtype)` when you need the inferred element count explicitly: + +```python +v_dtype = pto.vreg(pto.f32) +lanes0 = v_dtype.elements_per_vreg # 64 +lanes1 = pto.elements_per_vreg(pto.f32) # 64 +``` + +Current TileLang DSL v1 vector lowering supports `i8`, `i16`, `i32`, `f16`, `bf16`, and `f32` element types. + +### Typed Masks + +Masks are typed by their bit granularity: + +| DSL Type | VPTO Type | Description | +|----------|-----------|-------------| +| `pto.mask_b8` | `!pto.mask` | 8-bit granularity mask | +| `pto.mask_b16` | `!pto.mask` | 16-bit granularity mask | +| `pto.mask_b32` | `!pto.mask` | 32-bit granularity mask | + +```python +mask_ty = pto.mask_b32 +mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) +``` + +Mask operations must match the vector element family: +- `f32` vectors use `mask_b32` +- `f16` vectors use `mask_b16` +- `i8` vectors use `mask_b8` + +```python +# Correct: f32 vector with b32 mask +mask32 = pto.make_mask(pto.f32, PAT.ALL) +vec_f32 = pto.vlds(ptr, offset) +out = pto.vabs(vec_f32, mask32) + +# Error: mismatched mask granularity +mask16 = pto.make_mask(pto.f16, PAT.ALL) +out = pto.vabs(vec_f32, mask16) # Type error! +``` + +### Pointer Types [Advanced Tier] + +Pointers combine element type and memory space: + +```python +from pto import MemorySpace + +ptr_gm = pto.ptr(pto.f32, MemorySpace.GM) # GM pointer to f32 +ptr_ub = pto.ptr(pto.f16, MemorySpace.UB) # UB pointer to f16 +``` + +The `MemorySpace` enum provides type-safe memory space specification: + +| Enum Value | Description | +|------------|-------------| +| `MemorySpace.GM` | Global Memory (off-chip HBM/DDR) | +| `MemorySpace.UB` | Unified Buffer (on-chip SRAM, 256KB) | + +This replaces string literals (`MemorySpace.GM`/`MemorySpace.UB`) with compile-time checked enums. + +### Pointer Type Aliases [Advanced Tier] + +For clarity in API documentation, the following type alias is used: + +| Alias | Equivalent Type | Description | +|-------|----------------|-------------| +| `Tile` | `pto.tile(...)` | Tile buffer with layout and configuration | diff --git a/tilelang-dsl/docs/user_guide/06-tensorview.md b/tilelang-dsl/docs/user_guide/06-tensorview.md new file mode 100644 index 000000000..ad9a169d4 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/06-tensorview.md @@ -0,0 +1,97 @@ +### TensorView Types + +TensorView types represent multi‑dimensional (up to 5D) views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. + +### TensorView Type Definition + +TensorView types are parameterized by shape (a tuple of up to 5 dimensions) and element type: + +```python +# Kernel parameter using TensorView +@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tensor: pto.TensorView, # GM tensor view + output_tensor: pto.TensorView, # GM tensor view + tile_buf: pto.Tile # UB tile +): + # Access tensor view properties + shape = input_tensor.shape # tuple of dimensions (dynamic or static, up to 5D) + dtype = input_tensor.element_type # e.g., pto.f32 + strides = input_tensor.strides # stride in elements +``` + +**Important Notes:** +- TensorView is a **read-only descriptor** for GM data (though DMA store operations can write to it) +- Shape can be **static** (compile-time constants) or **dynamic** (determined at runtime) +- Strides are expressed in elements, not bytes +- Memory space is always GM (Global Memory) +- Maximum rank is 5 (PTO ISA right‑aligns lower‑rank shapes to 5D) +- When higher dimensions are 1, a 5D TensorView can be abbreviated to lower‑rank forms. For example, shape `(1,1,64,32,16)` can be written as `(64,32,16)` (3D), and shape `(1,1,1,32,16)` can be written as `(32,16)` (2D). + +### TensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Tensor dimensions (supports up to 5 dimensions, right-aligned to 5D in PTO ISA) | +| `element_type` | `Type` | Element data type (e.g., `pto.f32`, `pto.f16`) | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from base pointer (internal) | + +### Padding Mode Enum + +Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: + +| Enum Value | Description | +|------------|-------------| +| `PadMode.PadNull` | No padding (out-of-bounds access is invalid) | +| `PadMode.PadFirstElem` | Pad using the first element of the source | +| `PadMode.PadValue` | Pad using a specified value (requires `pad_value` parameter) | + +### Slicing Syntax + +TensorView supports Python slicing syntax to create logical partitions: + +```python +# Create a partition from a tensor view +partition = tensor_view[dim0_start:dim0_end, dim1_start:dim1_end] + +# Example: extract a 16x16 tile from a larger tensor +tile_view = large_tensor[0:16, 0:16] + +# Dynamic offsets and sizes +dim0_start = tensor_view.shape[0] // 2 +dynamic_partition = tensor_view[dim0_start:tensor_view.shape[0], 4:20] + +# Static positive step on dimension 0 +stepped_partition = tensor_view[0:32:2, 0:16] + +# Right-aligned shorthand on a 5D descriptor: +# if the leading 2 axes are logical singleton dimensions, a 3D-style slice +# maps to the trailing 3 physical axes. +partition_3d = tensor_view[d2_start:d2_end, d3_start:d3_end, d4_start:d4_end] + +# Full 5D spelling remains available when needed +partition_5d = tensor_view[ + d0_start:d0_end, + d1_start:d1_end, + d2_start:d2_end, + d3_start:d3_end, + d4_start:d4_end, +] +``` + +**Constraints:** +- Slicing returns a new TensorView representing the logical partition +- The partition must be within the original tensor bounds +- When fewer than 5 slice axes are written, they are right-aligned to the + trailing physical axes of the 5D descriptor +- `stop` must be explicit on all dimensions +- `start` may be static or dynamic +- `step` must be a static positive integer +- Dimension 0 may use `step > 1` +- Dimension 1 must keep `step == 1` (current implementation restriction for DMA operations) + +### Alignment Type + +The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. + diff --git a/tilelang-dsl/docs/user_guide/07-tile-types.md b/tilelang-dsl/docs/user_guide/07-tile-types.md new file mode 100644 index 000000000..b901e20f9 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/07-tile-types.md @@ -0,0 +1,187 @@ +### Tile Types + +Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. + +#### Tile Type Definition + +```python +# Create a tile with shape, element type, and memory space +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB) + +# With explicit configuration +config = pto.tile_config( + b_layout=pto.BLayout.ROW_MAJOR, + s_layout=pto.SLayout.NONE_BOX, + s_fractal_size=pto.i32(16), + pad_value=pto.PadValue.ZERO +) +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, config=config) + +# With valid shape (actual data dimensions within tile) +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, valid_shape=(240, 120)) +``` + +**Important Notes on Shape and Valid Shape:** +- **Static Shape Requirement**: The `shape` parameter must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. +- **Valid Shape Constraints**: The `valid_shape` parameter can be either static (compile-time constant) or dynamic (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. This allows for variable-sized data within a fixed tile allocation. +- **Default Behavior**: When `valid_shape` is not specified, it defaults to the full `shape`. + +#### Tile Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | **Static** full tile dimensions (compile-time constant) | +| `element_type` | `Type` | Element data type (e.g., `pto.f32`) | +| `memory_space` | `MemorySpace` | Memory space (GM, UB, etc.) | +| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within tile (can be static/compile-time or dynamic/runtime). Must be ≤ shape in each dimension. | +| `config` | `TileConfig` | Layout and padding configuration | + +#### Tile Configuration + +The tile configuration includes layout and padding information: + +```python +# Layout enums +pto.BLayout.ROW_MAJOR # 0: row-major base layout +pto.BLayout.COL_MAJOR # 1: column-major base layout + +pto.SLayout.NONE_BOX # 0: no secondary layout +pto.SLayout.ROW_MAJOR # 1: row-major secondary layout +pto.SLayout.COL_MAJOR # 2: column-major secondary layout + +pto.PadValue.NULL # 0: no padding +pto.PadValue.ZERO # 1: zero padding +pto.PadValue.MAX # 2: maximum value padding +pto.PadValue.MIN # 3: minimum value padding +``` + +#### Tile Shape Concepts + +- **Static Physical Shape**: The `shape` parameter represents the **static physical dimensions** of the tile allocated in memory. This must be a **compile-time constant** because tile memory allocation is fixed during compilation. The shape determines the total memory footprint and cannot change at runtime. + +- **Valid Shape**: The `valid_shape` parameter represents the logical dimensions of actual data within the tile. It can be either **static** (compile-time constant) or **dynamic** (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. When `valid_shape` is not specified, it defaults to the full `shape`. + +- **Key Distinction**: + - `shape`: **Static, compile-time** - Fixed tile allocation + - `valid_shape`: **Static or Dynamic** - Actual data region (must be ≤ shape) + +- **Constraints**: + - `valid_shape[i] ≤ shape[i]` for each dimension i + - `shape` must be compile-time constants + - `valid_shape` can be compile-time constants or runtime values + +- **Use Cases**: + - Fixed-size tile buffers with variable data (e.g., batch processing with different input sizes) + - Padding scenarios where physical allocation is larger than actual data + - Partial tile utilization in tiled algorithms + +- **Fractal Layout**: The `s_fractal_size` in tile configuration specifies the size of fractal blocks for secondary layout. This is used for optimized memory access patterns in matrix operations. + +- **Padding Behavior**: The `pad_value` determines how out-of-bounds accesses are handled when reading beyond `valid_shape` but within `shape`. Padding values are used for accesses in the padded region (between valid_shape and shape). + +> **⚠️ Important: Shape Constraints** +> +> The tile `shape` must be **compile-time constants**. `valid_shape` can be compile-time constants or determined at runtime, but must satisfy `valid_shape[i] ≤ shape[i]` for all dimensions i. + +### Tile Operations + +#### Basic Access Operations + +```python +# Get tile properties +shape = tile.shape # (256, 128) +elem_type = tile.element_type # pto.f32 +mem_space = tile.memory_space # MemorySpace.UB +valid_shape = tile.valid_shape # (240, 120) or same as shape + +# Get configuration properties +config = tile.config +b_layout = config.b_layout # pto.BLayout.ROW_MAJOR +s_layout = config.s_layout # pto.SLayout.NONE_BOX +s_fractal = config.s_fractal_size # pto.i32(16) +pad = config.pad_value # pto.PadValue.ZERO + +# Dynamic properties +rank = tile.rank # 2 +num_elements = tile.num_elements # 32768 (256 * 128) +valid_elements = tile.valid_elements # 28800 (240 * 120) +``` + +#### Layout and Stride Queries + +```python +# Get layout descriptors +layout_desc = tile.layout_descriptor # Returns layout description object + +# Get strides (in elements) +strides = tile.strides # (128, 1) for row-major 256x128 + +# Get byte strides +byte_strides = tile.byte_strides # (512, 4) for f32 row-major + +# Get base offset (in bytes) +offset = tile.offset # pto.i64(0) or specified offset +``` + +#### Conversion Operations + +**Basic Mode Syntax**: Use tile element-indexing directly in vector operations: +```python +# 2D tile indexing +vec = pto.vlds(tile[row, col:]) +pto.vsts(vec, tile[row, col:], mask) + +# 1D tile indexing +vec = pto.vlds(tile[start:]) +pto.vsts(vec, tile[start:], mask) +``` + +**Advanced Mode Syntax**: Convert tiles to typed pointers for byte-offset operations: +```python +# Convert tile to pointer +ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) + +# Use pointer with byte offset +vec = pto.vlds(ptr, offset) +pto.vsts(vec, ptr, offset, mask) +``` + +**Tile Manipulation Operations**: +```python +# Extract slice of tile +slice_tile = tile.slice((0, 0), (64, 128)) # 64x128 slice from top-left corner + +# Reshape tile (logical reshape, no data movement) +reshaped = tile.reshape((32768,)) # 1D reshape of 256x128 tile +``` + +#### Kernel Parameter Usage + +```python +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tile: pto.Tile, # Tile parameter + output_tile: pto.Tile, # Another tile parameter + scale: pto.f32 +): + # Tiles can be used directly in vector operations (no explicit conversion needed) + all_mask = pto.make_mask(pto.f32, PAT.ALL) + for i in range(0, 256, 64): + # tile element-indexing syntax for basic mode vector operations + vec = pto.vlds(input_tile[i, 0:]) # Load from row i, columns 0 to vector_lanes-1 + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, output_tile[i, 0:], all_mask) # Store to same position +``` + +#### Tile Creation from Existing Buffers + +```python +# Create tile from existing pointer with shape +ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +tile = pto.tile_from_ptr(ptr, (256, 128), pto.f32) + +# Create tile with explicit stride +tile = pto.tile_with_strides((256, 128), pto.f32, MemorySpace.UB, + strides=(256, 1)) # Column-major strides +``` + diff --git a/tilelang-dsl/docs/user_guide/08-control-flow.md b/tilelang-dsl/docs/user_guide/08-control-flow.md new file mode 100644 index 000000000..1b1944600 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/08-control-flow.md @@ -0,0 +1,142 @@ +## Control Flow + +### Vector Scopes + +The TileLang DSL supports implicit vector scope inference, allowing developers to write vector operations directly without explicit `pto.vecscope()` blocks. The compiler automatically groups consecutive, data-dependent vector operations into implicit vector scopes during lowering. + +#### Implicit Scope Inference + +**Note:** `pto.vecscope()` is supported. Automatic scope inference runs only when the kernel does **not** contain explicit `with pto.vecscope():` blocks. + +When you write vector operations like `pto.vlds`, `pto.vadd`, `pto.vsts` directly in your code, the compiler's **Scope Inference Pass** analyzes the control flow graph and automatically creates vector scopes: + +```python +# No explicit vecscope needed - compiler infers scope boundaries +vec = pto.vlds(outer_ptr, offset) +result = pto.vadd(vec, vec, all_mask) +pto.vsts(result, dst_ptr, offset, all_mask) +``` + +The compiler automatically groups these three operations into a single implicit vector scope because they form a data-dependent chain (when no explicit `pto.vecscope()` appears in the kernel). + +**Scope boundary rules:** +1. **Control flow boundaries**: Branches (`if`/`else`), loops (`for`/`while`), and function calls create implicit scope boundaries +2. **Scalar operations**: Non-vector operations (e.g., scalar arithmetic, pointer arithmetic) create boundaries +3. **Explicit scope blocks**: User-defined `vecscope` and `strict_vecscope` blocks create hard boundaries + +#### Explicit Scope Boundaries with `strict_vecscope` [Advanced Tier] + +##### `pto.strict_vecscope(*captures: AnyType) -> ContextManager[Tuple[AnyType, ...]]` + +**Description**: Creates an explicit vector scope boundary with explicit value captures. Values used inside the scope must be passed as arguments; implicit capture from outer scope is rejected. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `*captures` | `AnyType` | Variable number of values to be captured and passed into the scope | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `context_manager` | `ContextManager[Tuple[AnyType, ...]]` | Context manager that yields a tuple of captured values when entered | + +**Constraints**: +- The scope body cannot implicitly capture values from the surrounding scope; all used values must be passed as `captures`. +- Creates a hard boundary that prevents the compiler from merging vector operations across the scope boundary. +- Useful for performance optimization, debugging, resource management, and hardware compatibility. + +For precise control over scope boundaries, use explicit `strict_vecscope` blocks. These create hard boundaries that prevent the compiler from merging operations across the block boundary: + +```python +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + # Operations inside this block are isolated from outside + # Compiler will not merge operations across this boundary + for i in range(lb, ub, 64): + vec = pto.vlds(s, i) + pto.vsts(vec, d, i, all_mask) +``` + +**Use cases for strict_vecscope:** +- Performance optimization: Isolate critical vector computation regions +- Debugging: Create explicit boundaries to isolate vector operations +- Resource management: Control vector register allocation boundaries +- Compatibility: Ensure deterministic scope placement for hardware constraints + +#### Explicit Scope Blocks with `vecscope` + +`pto.vecscope` provides an explicit vector-scope boundary without strict capture ABI constraints: + +```python +with pto.vecscope(): + vec = pto.vlds(src, 0) + vec = pto.vadd(vec, vec, mask) + pto.vsts(vec, dst, 0, mask) +``` + +**Rules**: +- `pto.vecscope()` takes no positional/keyword arguments. +- `pto.vecscope()` does not support `as (...)` bindings. +- When any explicit `pto.vecscope()` is present in a kernel body, automatic vecscope inference is disabled for that kernel. + +### Inline Procedures (`@pto.inline_proc`) + +TileLang DSL supports reusable top-level procedures decorated with `@pto.inline_proc`. +`inline_proc` follows function-call semantics in frontend IR and is force-inlined +later by the VPTO backend mainline in `ptoas`. + +```python +@pto.inline_proc +def store_row(dst: pto.Tile, src: pto.Tile, row: pto.i32): + vec = pto.vlds(src[row, 0:]) + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + pto.vsts(vec, dst[row, 0:], mask) + return None + +@pto.vkernel(op="pto.row_copy", dtypes=[(pto.f32, pto.f32, pto.i32)]) +def row_copy(dst: pto.Tile, src: pto.Tile, row: pto.i32): + store_row(dst, src, row) + return None +``` + +Important semantics: + +- Frontend preserves helper `func.func` and `func.call` in `mlir_text()` output. +- VPTO backend mainline force-inlines helper calls before downstream lowering. +- Helper definitions support default parameter values. +- Helper calls support positional arguments and keyword arguments. +- Helper calls can appear in statement and expression positions. +- Helper definitions can use trailing `return ` to return values. +- Implicit capture is rejected; pass required values as explicit arguments. +- Recursive/mutually-recursive helper call graphs are rejected. +- `*args`, `**kwargs`, and keyword-only parameters are unsupported in current version. + +### Loops + +Counted loops use Python's `range` syntax: + +```python +for i in range(lb, ub, step): + # Loop body + mask, rem = pto.make_mask(pto.f32, remaining) + # ... +``` + +Loop-carried state is automatically handled through variable updates within the loop. + +### Conditionals + +`if` statements support value merging: + +```python +flag: pto.i1 = some_condition +step: pto.i32 = 0 + +if flag: + step = pto.i32(64) +else: + step = pto.i32(128) + +# 'step' here is the merged result from both branches +``` + +Variables defined in only one branch are local to that branch. diff --git a/tilelang-dsl/docs/user_guide/09-frontend-operations.md b/tilelang-dsl/docs/user_guide/09-frontend-operations.md new file mode 100644 index 000000000..5f8bd2893 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/09-frontend-operations.md @@ -0,0 +1,172 @@ + +### Frontend-only Authoring Operations + +Operations in this family affect descriptor construction and code generation +shape. They are consumed by the frontend and do not correspond to runtime VPTO +instructions by themselves. + +#### `pto.constexpr(value: bool) -> bool` + +**Description**: Compile-time conditional construct for kernel specialization. Marks a boolean expression for evaluation during descriptor materialization, enabling branch elimination based on static compile-time information. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `bool` | Boolean expression that must be evaluable at compile time. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `bool` | A frontend-only compile-time boolean used to guard `if` statements. | + +**Behavior**: +- Evaluated during kernel descriptor materialization before semantic analysis and lowering. +- When used in `if pto.constexpr(...):` statements, only the selected branch is retained; the other branch is discarded entirely. +- If the condition cannot be proven static, descriptor materialization fails with a frontend diagnostic. +- Does not generate runtime control flow or value merging logic. + +**Examples**: +```python +# Specialize based on element size +dtype = dst.element_type +elem_bytes = pto.bytewidth(dtype) + +if pto.constexpr(elem_bytes == 2): + # Specialized path for 16-bit types (f16/bf16) + ... +else: + # Fallback path for other types + ... +``` + +```python +# Specialize based on tile shape +rows, cols = dst.shape + +if pto.constexpr(rows == 1 and cols == 16): + # Fast path for specific tile configuration + ... +``` + +**Constraints**: +- `pto.constexpr` is a frontend-only authoring construct with no runtime representation. +- The condition must be statically evaluable from descriptor-time information (data types, tile shapes, literals, etc.). +- For kernel-level specialization, prefer `constraints=[...]` and `pto.select_kernel(...)`. +- See [Compile-time Specialization with `pto.constexpr`](04-template-kernels.md#compile-time-specialization-with-ptoconstexpr) for detailed usage guidelines. + +### Type Query Operations + +Operations for querying type properties. + +#### `pto.bytewidth(dtype: Type) -> pto.i32` + +**Description**: Returns the size in bytes of a single element of the given data type. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `size` | `pto.i32` | Element size in bytes | + +**Example**: +```python +f32_size = pto.bytewidth(pto.f32) # Returns 4 +f16_size = pto.bytewidth(pto.f16) # Returns 2 +i8_size = pto.bytewidth(pto.i8) # Returns 1 +``` + +**Common Use Case**: Calculate byte offsets for memory access: +```python +element_type = pto.f32 +byte_offset = index * pto.bytewidth(element_type) +``` + +#### `pto.elements_per_vreg(dtype: Type) -> pto.i32` + +**Description**: Returns the number of elements per vector register for a given element type, based on the hardware vector register size (256 bytes). This function computes `256 // bytewidth(dtype)`, which represents the maximum number of elements of the given type that can fit in a single vector register. Useful for determining vector width and loop stride calculations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `elems` | `pto.i32` | Number of elements per vector register for the given element type | + +**Example**: +```python +f32_elems_per_vreg = pto.elements_per_vreg(pto.f32) # Returns 64 (256 / 4) +f16_elems_per_vreg = pto.elements_per_vreg(pto.f16) # Returns 128 (256 / 2) +i8_elems_per_vreg = pto.elements_per_vreg(pto.i8) # Returns 256 (256 / 1) +``` + +**Common Use Case**: Loop stride calculation for vector operations: +```python +dtype = pto.f32 +elems_per_vreg = pto.elements_per_vreg(dtype) # Returns 64 for f32 +for col in range(0, cols, elems_per_vreg): + # Load/store vectors of 'elems_per_vreg' elements + pass +``` + +**Relationship with `pto.bytewidth`**: +```python +# The relationship between bytewidth and elements per vector register: +elems = 256 // pto.bytewidth(dtype) +# This is equivalent to: +elems = pto.elements_per_vreg(dtype) +``` + +### Pointer Construction [Advanced Tier] + +Operations for creating and manipulating typed pointers. + +#### `pto.castptr(offset: pto.i64, ptr_type: Type) -> PtrType` + +**Description**: Creates a typed pointer from an integer address, a memref-backed address value, or another typed pointer in the same memory space. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `offset` | `pto.i64` / address-like value | Integer address, memref-backed address value, or existing pointer | +| `ptr_type` | `Type` | Target pointer type (e.g., `pto.ptr(pto.f32, MemorySpace.GM)`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `ptr` | `PtrType` | Typed pointer value | + +**Example**: +```python +ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +``` + +`TensorView.as_ptr()` and `Tile.as_ptr()` remain the preferred high-level APIs. They lower directly to address-extraction intrinsics (`pto.tensor_view_addr` / `pto.tile_buf_addr`) with pointer result types, while tile slice / buffer-view authoring paths continue to materialize memref results from the same intrinsics. + +#### `pto.addptr(ptr: PtrType, offset: pto.i64) -> PtrType` + +**Description**: Adds an element offset to an existing pointer. The offset is counted in elements, not bytes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Source pointer | +| `offset` | `pto.i64` | Element offset to add (counted in elements, not bytes) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `new_ptr` | `PtrType` | Pointer with element offset applied | + +**Example**: +```python +# Advance pointer by 1024 f32 elements (not bytes) +next_ptr = pto.addptr(ub_ptr, 1024) +``` + diff --git a/tilelang-dsl/docs/user_guide/10-sync-dma-operations.md b/tilelang-dsl/docs/user_guide/10-sync-dma-operations.md new file mode 100644 index 000000000..fc8982da0 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/10-sync-dma-operations.md @@ -0,0 +1,460 @@ +### Synchronization & Buffer Control + +Operations for pipeline synchronization and buffer management. + +#### Enum Types for Synchronization + +The following enum types provide type-safe parameter specification for synchronization operations: + +- **`BarrierType`**: Memory barrier types for `pto.mem_bar` + - `VV_ALL`: All prior vector ops complete before subsequent + - `VST_VLD`: All prior vector stores visible before subsequent loads + - `VLD_VST`: All prior vector loads complete before subsequent stores + +- **`Pipe`**: Hardware pipeline identifiers + - `MTE2`: Memory Transfer Engine 2 pipeline + - `V`: Vector pipeline + - `MTE3`: Memory Transfer Engine 3 pipeline + - `ALL`: All pipelines (for barrier operations) + +- **`Event`**: Event identifiers for synchronization + - `ID0`, `ID1`, `ID2`, `ID3`, ..., `ID31`: Event IDs 0-31 (A5 supports 32 event IDs, 0-15 for subblock 0, 16-31 for subblock 1) + +#### `pto.set_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Sets a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.wait_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Waits for a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.pipe_barrier(pipes: PIPE) -> None` + +**Description**: Executes a barrier across specified pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipes` | `PIPE` | Pipeline specification (e.g., `PIPE.ALL`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE + +pto.pipe_barrier(PIPE.ALL) +``` + +#### `pto.get_buf(pipe: Pipe, buf_id: pto.i64, mode: pto.i64) -> None` + +**Description**: Acquire buffer slot for inter-pipeline double-buffering coordination. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier (e.g., `Pipe.MTE2`, `Pipe.V`, `Pipe.MTE3`) | +| `buf_id` | `pto.i64` | Buffer identifier | +| `mode` | `pto.i64` | Acquisition mode | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Pipe + +# Acquire buffer for MTE2 pipeline +pto.get_buf(Pipe.MTE2, 0, 0) +``` + +#### `pto.rls_buf(pipe: Pipe, buf_id: pto.i64, mode: pto.i64) -> None` + +**Description**: Release buffer slot to allow other pipeline to proceed. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier (e.g., `Pipe.MTE2`, `Pipe.V`, `Pipe.MTE3`) | +| `buf_id` | `pto.i64` | Buffer identifier | +| `mode` | `pto.i64` | Release mode | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Pipe + +# Release buffer for MTE2 pipeline +pto.rls_buf(Pipe.MTE2, 0, 0) +``` + +#### `pto.mem_bar(barrier_type: BarrierType) -> None` + +**Description**: Memory barrier for pipeline synchronization within vector scope. Required when UB addresses alias between vector load/store operations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `barrier_type` | `BarrierType` | Barrier type: `BarrierType.VV_ALL` (all prior vector ops complete before subsequent), `BarrierType.VST_VLD` (all prior vector stores visible before subsequent loads), `BarrierType.VLD_VST` (all prior vector loads complete before subsequent stores) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import BarrierType + +# Ensure stores are visible before loads to same UB region +pto.mem_bar(BarrierType.VST_VLD) +``` + +#### `pto.set_cross_core(core_id: pto.i64, event_id: Event) -> None` + +**Description**: Signal event to another core (cross-core synchronization). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `core_id` | `pto.i64` | Target/source core identifier (platform-specific mapping) | +| `event_id` | `Event` | Cross-core event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Signal event ID0 to core 0 +pto.set_cross_core(0, Event.ID0) +``` + +#### `pto.set_intra_block(block_id: pto.i64, event_id: Event) -> None` + +**Description**: Signal event within a block (A5). Specifies trigger pipe. 1:1 per subblock. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_id` | `pto.i64` | Block/pipeline identifier specifying trigger pipe | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Signal event ID0 on block/pipeline 0 +pto.set_intra_block(0, Event.ID0) +``` + +#### `pto.set_intra_core(config: pto.i32) -> None` + +**Description**: Configures intra-core synchronization settings. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `config` | `pto.i32` | Configuration value for intra-core synchronization | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_intra_core(3) +``` + +#### `pto.wait_flag_dev(core_id: pto.i64, event_id: Event) -> None` + +**Description**: Wait for event from another core. SU-level blocking — entire core stalls. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `core_id` | `pto.i64` | Core identifier | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Wait for event ID0 from core 0 +pto.wait_flag_dev(0, Event.ID0) +``` + +#### `pto.wait_intra_core(block_id: pto.i64, event_id: Event) -> None` + +**Description**: Wait for event within block (A5). Specifies which pipeline should wait — only that pipe stalls, SU and other pipes continue. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_id` | `pto.i64` | Block/pipeline identifier specifying which pipeline should wait | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Wait for event ID0 on block/pipeline 0 +pto.wait_intra_core(0, Event.ID0) +``` + +### DMA Programming [Advanced Tier] + +This section contains both DMA configuration operations (setting loop strides and sizes) and DMA execution operations (copying data). + +#### Manual Configuration Example + +```python +# DMA configuration example (requires careful parameter tuning) +pto.set_loop2_stride_outtoub(src_stride=32, dst_stride=128) # Outer loop strides +pto.set_loop1_stride_outtoub(src_stride=1, dst_stride=32) # Inner loop strides +pto.set_loop_size_outtoub(loop1=16, loop2=16) # Transfer size +pto.copy_gm_to_ubuf(src=gm_ptr, dst=ub_ptr, n_burst=16, len_burst=128, gm_stride=128, ub_stride=128) + +``` + +#### `pto.set_loop2_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_outtoub(loop1: pto.i64, loop2: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for GM → UB transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop1` | `pto.i64` | Inner loop trip count | +| `loop2` | `pto.i64` | Outer loop trip count | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop_size_outtoub(loop1=1, loop2=1) +``` + +#### `pto.set_loop2_stride_ubtoout(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_ubtoout(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_ubtoout(loop1: pto.i64, loop2: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for UB → GM transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop1` | `pto.i64` | Inner loop trip count | +| `loop2` | `pto.i64` | Outer loop trip count | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop(loop_id: pto.i32, src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for a generic loop. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop_id` | `pto.i32` | Loop identifier (e.g., 1 for inner loop, 2 for outer loop) | +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop(1, src_stride=32, dst_stride=64) +``` + +#### `pto.set_loop_size(loop_id: pto.i32, size: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for a generic loop. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop_id` | `pto.i32` | Loop identifier (e.g., 1 for inner loop, 2 for outer loop) | +| `size` | `pto.i64` | Loop trip count | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop_size(1, 16) +``` + +#### DMA Execution Operations + +**Note**: These operations execute DMA transfers but require manual configuration of DMA parameters (loop strides, loop sizes) using the `set_loop*_stride_*` and `set_loop_size_*` operations described above. + +The following operations provide direct control over DMA transfers but require manual stride and size configuration. + +#### `pto.copy_gm_to_ubuf(src: GMPtr, dst: UBPtr, sid: pto.i64 = 0, n_burst: pto.i64, len_burst: pto.i64, left_padding_count: pto.i64 = 0, right_padding_count: pto.i64 = 0, enable_ub_pad: pto.i1 = False, l2_cache_ctl: pto.i64 = 0, gm_stride: pto.i64, ub_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data from Global Memory (GM) to Unified Buffer (UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `GMPtr` | Source GM pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `sid` | `pto.i64` | DMA stream/control operand, defaults to `0` | +| `n_burst` | `pto.i64` | Number of bursts | +| `len_burst` | `pto.i64` | Bytes copied by each burst | +| `left_padding_count` | `pto.i64` | Left padding count, defaults to `0` | +| `right_padding_count` | `pto.i64` | Right padding count, defaults to `0` | +| `enable_ub_pad` | `pto.i1` | Convenience alias for `data_select_bit`, defaults to `False` | +| `l2_cache_ctl` | `pto.i64` | L2 cache control operand, defaults to `0` | +| `gm_stride` | `pto.i64` | GM-side stride in bytes | +| `ub_stride` | `pto.i64` | UB-side stride in bytes | + +**Returns**: None (side-effect operation) + +**Notes**: +- In TileLang DSL, the keyword form above is the recommended public surface. +- The lowering still maps to the underlying low-level PTO operand ABI in positional order. + +**Example**: +```python +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=128, + gm_stride=128, + ub_stride=128, + enable_ub_pad=False, +) +``` + +#### `pto.copy_ubuf_to_ubuf(src: UBPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data within Unified Buffer (UB → UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | + +**Returns**: None (side-effect operation) + +#### `pto.copy_ubuf_to_gm(src: UBPtr, dst: GMPtr, sid: pto.i64 = 0, n_burst: pto.i64, len_burst: pto.i64, reserved: pto.i64 = 0, gm_stride: pto.i64, ub_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data from Unified Buffer (UB) to Global Memory (GM). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `GMPtr` | Destination GM pointer | +| `sid` | `pto.i64` | DMA stream/control operand, defaults to `0` | +| `n_burst` | `pto.i64` | Number of bursts | +| `len_burst` | `pto.i64` | Bytes copied by each burst | +| `reserved` | `pto.i64` | Reserved operand, defaults to `0` | +| `gm_stride` | `pto.i64` | GM-side stride in bytes | +| `ub_stride` | `pto.i64` | UB-side stride in bytes | + +**Returns**: None (side-effect operation) + +**Notes**: +- In TileLang DSL, the keyword form above is the recommended public surface. +- `gm_stride`/`ub_stride` are ergonomic aliases for the low-level `burst_dst_stride`/`burst_src_stride` operands. +- The lowering still maps to the underlying low-level PTO operand ABI in positional order. + +**Example**: +```python +pto.copy_ubuf_to_gm( + src=ub_ptr, + dst=gm_ptr, + n_burst=32, + len_burst=128, + gm_stride=128, + ub_stride=128, +) +``` diff --git a/tilelang-dsl/docs/user_guide/11-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/11-vector-memory-operations.md new file mode 100644 index 000000000..f48cf2780 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/11-vector-memory-operations.md @@ -0,0 +1,986 @@ +### Enum Types for Vector Memory Operations + +The following enum types provide type-safe parameter specification for vector memory operations: + +- **`DeinterleaveDist`**: Deinterleave distribution modes for `pto.vldx2` + - `B8`: 8-bit element deinterleave (for i8) + - `B16`: 16-bit element deinterleave (for i16, f16, bf16) + - `B32`: 32-bit element deinterleave (for i32, f32) + - `BD`: Broadcast deinterleave mode + +- **`InterleaveDist`**: Interleave distribution modes for `pto.vstx2` + - `B8`: 8-bit element interleave (for i8) + - `B16`: 16-bit element interleave (for i16, f16, bf16) + - `B32`: 32-bit element interleave (for i32, f32) + +- **`StrideMode`**: Stride modes for `pto.vsld` + - `S3_B16`: Stride 3, block size 16 + - `S4_B64`: Stride 4, block size 64 + - `S8_B32`: Stride 8, block size 32 + - `S2_B64`: Stride 2, block size 64 + +### Address Generation Syntax Sugar + +To simplify address calculation and reduce manual byte offset computation errors, TileLang DSL provides syntactic sugar for vector load/store operations using element-based indexing. This syntax automatically computes the byte offset based on tile shape, element type, and layout. + +#### Indexing Syntax + +The syntax supports two indexing modes for different operations: + +1. **Vector-range indexing** (for vector load/store operations): + - **Row-major layout (default)**: `tile[row_index, col_start:]` + - `row_index`: Row index (0-based) + - `col_start:`: Starting column index followed by colon, indicating a vector-width contiguous region starting from this column + - The colon (`:`) indicates an implicit vector-width range determined by hardware vector size (256 bytes) and element type + + - **Column-major layout**: `tile[row_start:, col_index]` + - `row_start:`: Starting row index followed by colon, indicating a vector-width contiguous region starting from this row + - `col_index`: Column index (0-based) + - Used for column-major tiles (`BLayout.COL_MAJOR`) where elements are stored column-wise + + - **1D tile indexing**: `tile[start:]` (or equivalently `tile[0, start:]` for row-major or `tile[start:, 0]` for column-major) + - `start:`: Starting element index followed by colon + +2. **Single-element indexing** (for scalar load operations like `pto.vsld`): + - **Row-major layout (default)**: `tile[row_index, col_index]` + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + + - **Column-major layout**: `tile[row_index, col_index]` (same syntax) + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Same syntax as row-major; the layout determines how the offset is computed + + - **1D tile indexing**: `tile[pos]` + - `pos`: Element index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + +#### Vector Width Calculation + +The number of elements loaded/stored in a single vector operation is determined by: + +``` +vector_lanes = 256 // element_size_bytes(element_type) +``` + +**Convenience API**: Use `pto.elements_per_vreg(dtype)` to compute the number of elements per vector register for a given element type (e.g., `pto.elements_per_vreg(pto.f32)` returns 64, `pto.elements_per_vreg(pto.f16)` returns 128). See [Type Query Operations](09-frontend-operations.md#type-query-operations) for full documentation. + +Where `element_size_bytes` is: +- 1 byte for `i8` +- 2 bytes for `i16`, `f16`, `bf16` +- 4 bytes for `i32`, `f32` +- 8 bytes for `i64` + +#### Offset Computation + +The byte offset is automatically computed based on tile layout: + +- **Row-major layout** (`BLayout.ROW_MAJOR`): + ``` + offset = (row_index * stride_row + col_start) * element_size_bytes + ``` + where `stride_row` is the row stride in elements (typically `tile.shape[1]` for contiguous tiles). + +- **Column-major layout** (`BLayout.COL_MAJOR`): + - For syntax `tile[row_start:, col_index]`: + ``` + offset = (col_index * stride_col + row_start) * element_size_bytes + ``` + - For backward compatibility with traditional offset calculation: + ``` + offset = (col_start * stride_col + row_index) * element_size_bytes + ``` + where `stride_col` is the column stride in elements (typically `tile.shape[0]` for contiguous tiles), `row_start` is the starting row index, and `col_index` is the column index. + +**Note**: +- For single-element indexing (`tile[row, col]` or `tile[pos]`), the same offset formulas apply with `col_start` replaced by `col_index` (or `start` replaced by `pos` for 1D tiles). +- For column-major vector-range indexing (`tile[row_start:, col_index]`), the offset formula uses `row_start` as the starting position along the contiguous dimension. +- The compiler automatically handles the appropriate substitution based on the indexing syntax and tile layout. + +#### Constraints + +1. **Boundary checks**: The requested region must be within tile bounds: + - **For vector-range indexing** (`:` syntax): + - **Row-major layout** (`tile[row_index, col_start:]`): + - `row_index < tile.shape[0]` and `col_start + vector_lanes <= tile.shape[1]` + - **Column-major layout** (`tile[row_start:, col_index]`): + - `row_start + vector_lanes <= tile.shape[0]` and `col_index < tile.shape[1]` + - **1D tile indexing**: `tile[start:]` + - `start + vector_lanes <= tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + - **For single-element indexing** (no `:` syntax): + - 2D: `row_index < tile.shape[0]` and `col_index < tile.shape[1]` (same for both layouts) + - 1D: `pos < tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + +2. **Alignment**: The computed offset must satisfy hardware alignment requirements for the operation. + +3. **Full vectors only**: The `:` syntax always loads/stores a full vector width. For partial vectors, use the traditional byte offset approach with explicit mask handling. + +4. **Single-element operations**: The single-element indexing syntax (`tile[row, col]` or `tile[pos]`) is only supported for scalar load operations like `pto.vsld`. For other operations, use vector-range indexing with `:` syntax. + +#### Supported Operations + +The indexing syntax is supported for all vector load and store operations with the following syntax mapping: + +- **Vector-range indexing** (`tile[row, col:]` or `tile[start:]`): + - Load operations: `vlds`, `vldas`, `vldus`, `vldx2` + - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstx2` + +- **Single-element indexing** (`tile[row, col]` or `tile[pos]`): + - Load operations: `vsld` (scalar load with broadcast) + +#### Examples + +The following examples use row-major layout syntax. For column-major tiles, use `tile[row_start:, col_index]` syntax instead of `tile[row_index, col_start:]`. + +```python +# 2D tile indexing (row-major layout) +vec = pto.vlds(tile[i, j:]) # Load vector from row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[i, j:], mask) # Store vector with mask + +# 1D tile indexing +vec = pto.vlds(tile[k:]) # Load vector from elements k to k+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store vector with mask + +# Dual load with deinterleave +low, high = pto.vldx2(tile[i, j:], DeinterleaveDist.B32) + +# Aligned load with indexing +vec = pto.vldas(tile[i, j:], align) + +# Scalar load (broadcast) +vec = pto.vsld(tile[i, j]) # Load scalar at tile[i,j] and broadcast to vector +``` + +#### Comparison with Manual Offset Calculation + +**Traditional approach (error-prone):** +```python +# Manual byte offset calculation for f32 tile +rows, cols = tile.shape +row_offset = i * cols * 4 # Hard-coded 4 bytes for f32 +col_offset = j * 4 +offset = row_offset + col_offset +vec = pto.vlds(tile, offset) +``` + +**New syntax (type-safe):** +```python +# Automatic offset calculation +vec = pto.vlds(tile[i, j:]) # Compiler computes correct offset for any element type +``` + +The syntax sugar eliminates manual byte calculations, reduces errors, and makes code generic across different element types (e.g., the same kernel works for both `f16` and `f32` without modification). + +### Vector Load Operations + +Operations for loading data from memory into vector registers. + +#### `pto.vlds(buf: ptr, offset: Index) -> VRegType` [Advanced Tier] +#### `pto.vlds(tile[row, col:]) -> VRegType` [Basic Tier] +#### `pto.vlds(tile[start:]) -> VRegType` [Basic Tier] + +**Description**: Stateless vector load from buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the requested vector region must be within tile bounds and satisfy alignment requirements + +**Examples**: +```python +# Traditional byte-offset syntax +vec = pto.vlds(ub_ptr, lane * 256) + +# New element-indexing syntax +vec = pto.vlds(tile[i, j:]) # Load from row i, columns j to j+vector_lanes-1 +vec = pto.vlds(tile[k:]) # Load from 1D tile, elements k to k+vector_lanes-1 + +# Generic kernel that works for both f16 and f32 +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_scale(src: pto.Tile, dst: pto.Tile, scale: pto.f32): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): # vector_lanes computed from element type + # No manual byte calculation needed! + vec = pto.vlds(src[i, j:]) + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, dst[i, j:], all_mask) +``` + +#### `pto.vldas(buf: ptr) -> pto.align` [Advanced Tier] +#### `pto.vldas(tile[row, col:]) -> pto.align` [Basic Tier] +#### `pto.vldas(tile[start:]) -> pto.align` [Basic Tier] + +**Description**: Prime alignment buffer for subsequent unaligned load. Returns alignment state for use with `pto.vldus`. Supports both pointer syntax and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align` | `pto.align` | Alignment state for use with `pto.vldus` | + +**Examples**: +```python +# Pointer syntax +align = pto.vldas(ub_ptr) + +# Element-indexing syntax +align = pto.vldas(tile[i, j:]) +align = pto.vldas(tile[k:]) +``` + +#### `pto.vldus(buf: ptr, align: pto.align) -> (VRegType, pto.align, ptr)` [Advanced Tier] +#### `pto.vldus(tile[row, col:], align: pto.align) -> (VRegType, pto.align, ptr)` [Basic Tier] +#### `pto.vldus(tile[start:], align: pto.align) -> (VRegType, pto.align, ptr)` [Basic Tier] + +**Description**: Unaligned load using primed align state. Requires alignment state from `pto.vldas` or previous `pto.vldus`. Updates alignment state and base pointer for subsequent loads. Supports both pointer syntax and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | +| _or_ | | | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Assembled vector value | +| `align_out` | `pto.align` | Updated alignment state for next load | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- A matching `pto.vldas` must appear before the first dependent `pto.vldus` stream in the same vector loop +- Both alignment state and base address advance across the stream +- If DSL authoring uses explicit byte/element offsets, the frontend first rewrites them into pointer/index expressions before lowering to this VPTO form. + +**Examples**: +```python +# Pointer syntax - requires alignment state priming +align = pto.vldas(ub_ptr) +vec, align_out, base_out = pto.vldus(ub_ptr, align) + +# Element-indexing syntax +align = pto.vldas(tile[i, j:]) +vec, align_out, base_out = pto.vldus(tile[i, j:], align) + +# Multiple unaligned loads in a stream +align = pto.vldas(tile[k:]) +for n in range(4): + vec, align, base = pto.vldus(tile[k:], align) # alignment state updates +``` + + +#### `pto.vldx2(buf: ptr, offset: Index, dist: DeinterleaveDist) -> (VRegType, VRegType)` [Advanced Tier] +#### `pto.vldx2(tile[row, col:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] +#### `pto.vldx2(tile[start:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] + +**Description**: Dual vector load with deinterleave (AoS → SoA conversion). Loads interleaved data from a single buffer and deinterleaves into two vectors. Supports both byte-offset and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `dist` | `DeinterleaveDist` | Deinterleave distribution mode: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, `DeinterleaveDist.B32`, `DeinterleaveDist.BD`. Determines element size and interleave pattern. | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `dist` | `DeinterleaveDist` | Deinterleave distribution mode: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, `DeinterleaveDist.B32`, `DeinterleaveDist.BD`. Determines element size and interleave pattern. | +| _or_ | | | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `DeinterleaveDist` | Deinterleave distribution mode: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, `DeinterleaveDist.B32`, `DeinterleaveDist.BD`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | + +**Constraints**: +- Source buffer must be in UB memory space +- Offset must satisfy alignment requirements for the selected distribution mode +- The requested vector region must be within tile bounds (for element-indexing syntax) +- Distribution mode must match element type (e.g., `DeinterleaveDist.B32` for 32-bit elements) + +**Examples**: +```python +from pto import DeinterleaveDist + +# Byte-offset syntax +low, high = pto.vldx2(ub_ptr, offset, DeinterleaveDist.B32) + +# Element-indexing syntax +low, high = pto.vldx2(tile[i, j:], DeinterleaveDist.B32) +low, high = pto.vldx2(tile[k:], DeinterleaveDist.B16) + +# Example: Load interleaved XY pairs into separate X/Y vectors +x_vec, y_vec = pto.vldx2(xy_tile[i, j:], DeinterleaveDist.B32) +``` + +#### `pto.vsld(buf: ptr, offset: Index, stride: StrideMode) -> VRegType` [Advanced Tier] +#### `pto.vsld(tile[row, col], stride: StrideMode) -> VRegType` [Basic Tier] +#### `pto.vsld(tile[pos], stride: StrideMode) -> VRegType` [Basic Tier] + +**Description**: Strided load with fixed stride pattern. Loads elements from memory with regular stride pattern. The offset parameter encodes displacement with selected stride mode. This is a deprecated compatibility family. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte displacement encoded with selected stride mode | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. Determines which sub-elements are read from each source block. | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. | +| _or_ | | | +| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector with strided pattern | + +**Constraints**: +- The selected stride token determines which sub-elements are read from each source block +- This operation family is deprecated; prefer other load patterns when possible + +**Examples**: +```python +from pto import StrideMode + +# Byte-offset syntax +vec = pto.vsld(ub_ptr, offset, StrideMode.S4_B64) + +# Element-indexing syntax +vec = pto.vsld(tile[i, j], StrideMode.S3_B16) +vec = pto.vsld(tile[k], StrideMode.S8_B32) +``` + +#### `pto.vgather2(buf: ptr, offsets: Index, active_lanes: Index) -> VRegType` [Advanced Tier] + +**Description**: Indexed gather from UB. Gathers elements from a single buffer using per-lane offsets, with participation bounded by active lanes count. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `active_lanes` | `Index` | Number of lanes that participate (bounds gather operation) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Constraints**: +- Only the first `active_lanes` offsets participate in the gather +- Index element width and interpretation must match selected gather form +- Each effective address must satisfy the gather form's alignment rules + +**Example**: +```python +vec = pto.vgather2(buf, offsets, active_lanes) +``` + +#### `pto.vgather2_bc(buf: ptr, offsets: Index, mask: MaskType) -> VRegType` [Advanced Tier] + +**Description**: Gather with broadcast, conditioned by mask. Gathers elements from a single buffer using per-lane offsets, with mask gating lane participation. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `mask` | `MaskType` | Mask gating which lanes participate | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Constraints**: +- Masked-off lanes do not participate in address coalescing and do not trigger address overflow exceptions +- Destination lanes for masked-off lanes are zero-filled +- This is a backward-compatible operation family + +**Example**: +```python +vec = pto.vgather2_bc(buf, offsets, mask) +``` + +#### `pto.vgatherb(buf: ptr, offsets: Index) -> VRegType` [Advanced Tier] + +**Description**: Byte‑granularity gather load. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer | +| `offsets` | `Index` | Byte offsets | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Example**: +```python +vec = pto.vgatherb(buf, offsets) +``` + +#### `pto.vsldb(buf: ptr, offset: Index, mask: MaskType) -> VRegType` [Advanced Tier] +#### `pto.vsldb(tile[row, col], offset: Index, mask: MaskType) -> VRegType` [Basic Tier] +#### `pto.vsldb(tile[pos], offset: Index, mask: MaskType) -> VRegType` [Basic Tier] + +**Description**: Block-strided load for 2D tile access. Loads elements with block stride pattern controlled by packed offset word and mask. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | +| _or_ | | | +| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector with block-strided pattern | + +**Constraints**: +- The offset encodes block stride and repeat pattern, not a plain byte displacement +- If a block is masked off, the corresponding destination block is zeroed +- Masked-off blocks must not raise address overflow exceptions + +**Example**: +```python +# Byte-offset syntax +vec = pto.vsldb(ub_ptr, control_word, mask) + +# Element-indexing syntax +vec = pto.vsldb(tile[i, j], control_word, mask) +vec = pto.vsldb(tile[k], control_word, mask) +``` + +### Vector Store Operations + +Operations for storing data from vector registers to memory. + +#### `pto.vsts(vec: VRegType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType) -> None` [Basic Tier] +#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType) -> None` [Basic Tier] + +**Description**: Stateless vector store to buffer. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the destination vector region must be within tile bounds and satisfy alignment requirements + +**Examples**: +```python +# Byte-offset syntax +pto.vsts(vec_f32, ub_ptr, lane * 256, mask32) + +# Element-indexing syntax +pto.vsts(vec, tile[i, j:], mask) # Store to row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store to 1D tile, elements k to k+vector_lanes-1 + +# In a generic kernel +@pto.vkernel(target="a5", op="copy", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_store(src: pto.Tile, dst: pto.Tile): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): + vec = pto.vlds(src[i, j:]) + pto.vsts(vec, dst[i, j:], all_mask) # No manual offset calculation +``` + +#### `pto.psts(mask: MaskType, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.psts(mask: MaskType, tile[row, col:]) -> None` +#### `pto.psts(mask: MaskType, tile[start:]) -> None` + +**Description**: Predicate store to buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `buf` | `ptr` | Pointer to destination buffer (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +#### `pto.vsst(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` + +**Description**: Scalar to vector store (broadcast scalar to all lanes). Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `buf` | `ptr` | Pointer to destination buffer (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vstx2(low: VRegType, high: VRegType, buf: ptr, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vstx2(low: VRegType, high: VRegType, tile[row, col:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` + +**Description**: Dual interleaved store (SoA → AoS conversion). Stores two vectors interleaved into a single buffer. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `dist` | `InterleaveDist` | Interleave distribution mode: `InterleaveDist.B8`, `InterleaveDist.B16`, `InterleaveDist.B32`. Determines element size and interleave pattern. | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `dist` | `InterleaveDist` | Interleave distribution mode: `InterleaveDist.B8`, `InterleaveDist.B16`, `InterleaveDist.B32`. | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `InterleaveDist` | Interleave distribution mode: `InterleaveDist.B8`, `InterleaveDist.B16`, `InterleaveDist.B32`. | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Destination buffer must be in UB memory space +- Offset must satisfy alignment requirements for the selected distribution mode +- The destination vector region must be within tile bounds (for element-indexing syntax) +- Distribution mode must match element type (e.g., `InterleaveDist.B32` for 32-bit elements) +- The two source vectors form an ordered pair; interleave semantics must be preserved + +**Examples**: +```python +from pto import InterleaveDist + +# Byte-offset syntax +pto.vstx2(x_vec, y_vec, ub_ptr, offset, InterleaveDist.B32, mask) + +# Element-indexing syntax +pto.vstx2(x_vec, y_vec, tile[i, j:], InterleaveDist.B32, mask) +pto.vstx2(x_vec, y_vec, tile[k:], InterleaveDist.B16, mask) + +# Example: Store separate X/Y vectors as interleaved XY pairs +pto.vstx2(x_vec, y_vec, xy_tile[i, j:], InterleaveDist.B32, all_mask) +``` + +#### `pto.vsta(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.vsta(align: pto.align, tile[row, col:]) -> None` [Basic Tier] +#### `pto.vsta(align: pto.align, tile[start:]) -> None` [Basic Tier] + +**Description**: Flush alignment state to memory. Writes buffered tail bytes from alignment state to UB memory. Consumes the alignment state after flush. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Flush displacement | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- The flush address must match the post-updated address expected by the preceding unaligned-store stream +- After the flush, the corresponding store alignment state is consumed +- A final flush operation is required to commit buffered bytes after unaligned-store sequences +- The `align` input should come from the latest `vstu`/`vstus`/`vstur` in the same stream + +**Example**: +```python +# Byte-offset syntax +pto.vsta(align, ub_ptr, offset) + +# Element-indexing syntax +pto.vsta(align, tile[i, j:]) +pto.vsta(align, tile[k:]) +``` + +#### `pto.vscatter(vec: VRegType, buf: ptr, offsets: Index, active_lanes: Index) -> None` [Advanced Tier] + +**Description**: Indexed scatter to UB. Stores vector elements to irregular locations using per-lane offsets, with participation bounded by active lanes count. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Source vector to scatter | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `active_lanes` | `Index` | Number of lanes that participate (bounds scatter operation) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Only `b8`, `b16`, and `b32` element sizes are supported +- Index vector must use a supported integer element type and layout +- Each computed address must be element-aligned +- If indices alias, only one write is guaranteed (winning lane is implementation-defined) +- Only the first `active_lanes` offsets participate in the scatter + +**Example**: +```python +pto.vscatter(vec, buf, offsets, active_lanes) +``` + +#### `pto.vsstb(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vsstb(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` [Basic Tier] +#### `pto.vsstb(scalar: ScalarType, tile[start:], mask: MaskType) -> None` [Basic Tier] + +**Description**: Scalar to vector store with broadcast (enhanced version of `vsst`). Supports both byte‑offset and element‑indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `buf` | `ptr` | Pointer to destination buffer | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Example**: +```python +# Byte-offset syntax +pto.vsstb(pto.f32(0.0), ub_ptr, offset, mask) + +# Element-indexing syntax +pto.vsstb(pto.f32(1.0), tile[i, j:], mask) +``` + +#### `pto.vstar(align: pto.align, buf: ptr) -> None` [Advanced Tier] +#### `pto.vstar(align: pto.align, tile[row, col:]) -> None` [Basic Tier] +#### `pto.vstar(align: pto.align, tile[start:]) -> None` [Basic Tier] + +**Description**: Flush alignment state using the register-update form. Writes buffered tail bytes from alignment state to UB memory. The implicit update state must correspond to the same store stream that produced the alignment state. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- The implicit update state consumed by this flush must correspond to the same store stream that produced the alignment state +- A final flush operation is required to commit buffered bytes after unaligned-store sequences +- The `align` input should come from the latest `vstu`/`vstus`/`vstur` in the same stream + +**Example**: +```python +# Byte-offset syntax +pto.vstar(align, ub_ptr) + +# Element-indexing syntax +pto.vstar(align, tile[i, j:]) +pto.vstar(align, tile[k:]) +``` + +#### `pto.vstas(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.vstas(align: pto.align, tile[row, col:], offset: Index) -> None` [Basic Tier] +#### `pto.vstas(align: pto.align, tile[start:], offset: Index) -> None` [Basic Tier] + +**Description**: Scalar-register-offset form of alignment-state flush. Writes buffered tail bytes from alignment state to UB memory with explicit scalar offset. Uses same buffered-tail semantics as `pto.vsta`. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | +| `offset` | `Index` | Scalar-register style displacement | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `offset` | `Index` | Scalar-register style displacement | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `offset` | `Index` | Scalar-register style displacement | + +**Returns**: None (side-effect operation) + +**Example**: +```python +# Byte-offset syntax +pto.vstas(align, ub_ptr, offset) + +# Element-indexing syntax +pto.vstas(align, tile[i, j:], offset) +pto.vstas(align, tile[k:], offset) +``` + +### Stateful Store Operations + +Operations for storing data with stateful semantics. + +#### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Predicate unaligned store with align state update. Stores predicate mask with alignment state threading. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated alignment state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Part of stateful unaligned-store sequence with alignment state threading + +#### `pto.vstu(align_in: pto.align, base_in: ptr, vec: VRegType, buf: ptr, mode: Index) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Unaligned store with explicit threaded alignment/base state. Models a stateful unaligned-store sequence in SSA form. Requires a final `pto.vsta`/`pto.vstas`/`pto.vstar` to flush trailing buffered bytes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `base_in` | `ptr` | Current stream base pointer | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | +| `mode` | `Index` | Mode selecting post-update behavior | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Models stateful unaligned-store sequence in SSA form +- Final flush operation required to commit buffered bytes + +**Example**: +```python +# Stateful unaligned store + final flush (vsta form) +align1, base1 = pto.vstu(align0, base0, vec0, ub_ptr, mode) +align2, base2 = pto.vstu(align1, base1, vec1, ub_ptr, mode) +pto.vsta(align2, ub_ptr, tail_offset) +``` + +#### `pto.vstus(align_in: pto.align, base_in: ptr, vec: VRegType, buf: ptr, offset: Index) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Scalar-offset unaligned store with threaded state. Same roles as `pto.vstu` but with explicit scalar displacement. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `base_in` | `ptr` | Current stream base pointer | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | +| `offset` | `Index` | Scalar displacement | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Same final flush requirement and state-threading constraints as `pto.vstu` + +**Example**: +```python +# Scalar-offset threaded form + final flush (vstas form) +align1, base1 = pto.vstus(align0, base0, vec0, ub_ptr, offset0) +align2, base2 = pto.vstus(align1, base1, vec1, ub_ptr, offset1) +pto.vstas(align2, ub_ptr, flush_offset) +``` + +#### `pto.vstur(align_in: pto.align, vec: VRegType, buf: ptr) -> pto.align` [Advanced Tier] + +**Description**: Register-update unaligned store form. Updates only the residual alignment state without base pointer update. Requires matching flush operation to emit trailing bytes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | + +**Constraints**: +- Updates only residual alignment state (no base pointer update) +- Matching flush operation still required to emit trailing bytes + +**Example**: +```python +# Residual-state form + final flush (vstar form) +align1 = pto.vstur(align0, vec0, ub_ptr) +align2 = pto.vstur(align1, vec1, ub_ptr) +pto.vstar(align2, ub_ptr) +``` + +#### Align-State Store Closed Loop + +For unaligned store families, the state must form a closed loop: + +1. Start from an incoming `align` state. +2. Thread state through one or more `vstu` / `vstus` / `vstur` operations. +3. Terminate the stream with exactly one flush op: `vsta` or `vstas` or `vstar`. +4. Do not reuse a flushed `align` state in another stream. diff --git a/tilelang-dsl/docs/user_guide/12-predicate-operations.md b/tilelang-dsl/docs/user_guide/12-predicate-operations.md new file mode 100644 index 000000000..f76d2f8cb --- /dev/null +++ b/tilelang-dsl/docs/user_guide/12-predicate-operations.md @@ -0,0 +1,498 @@ +### Predicate Operations + +Operations for creating and manipulating typed masks. + +**Recommended API**: For most use cases, prefer the unified `pto.make_mask()` function which automatically selects the appropriate mask granularity based on element type and supports both tail processing (remaining element count) and pattern-based mask generation. This eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` (tail processing) and `pset_b8`/`pset_b16`/`pset_b32` (pattern generation) operations. + +**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +**Part Mode Enum**: The `PartMode` enum provides type-safe part selection for `pto.ppack` and `pto.punpack` operations. It includes the following values: `EVEN` (selects even-indexed elements) and `ODD` (selects odd-indexed elements). + +**Predicate Dist Enum**: The `PredicateDist` enum provides type-safe distribution mode selection for predicate load/store families. Common values include `NORM`, `US`, and `DS`. + +#### `pto.pset_b8(pattern: pto.MaskPattern) -> pto.mask_b8` + +**Description**: Creates an 8-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +**Constraints**: +- Used with `i8` vector operations + +**Example**: +```python +mask8 = pto.make_mask(pto.i8, PAT.ALL) +``` + +#### `pto.pset_b16(pattern: pto.MaskPattern) -> pto.mask_b16` + +**Description**: Creates a 16-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations + +**Example**: +```python +mask16 = pto.make_mask(pto.f16, PAT.ALL) +``` + +#### `pto.pset_b32(pattern: pto.MaskPattern) -> pto.mask_b32` + +**Description**: Creates a 32-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | + +**Constraints**: +- Used with `f32`/`i32` vector operations + +**Example**: +```python +mask32 = pto.make_mask(pto.f32, PAT.ALL) +``` + +#### `pto.pge_b8(pattern: pto.MaskPattern) -> pto.mask_b8` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates an 8-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum (e.g., `pto.MaskPattern.PAT_VL8`, `pto.MaskPattern.PAT_VL16`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity tail mask | + +**Constraints**: +- Used with `i8` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask for first 8 lanes +tail_mask = pto.pge_b8(PAT.VL8) +``` + +#### `pto.pge_b16(pattern: pto.MaskPattern) -> pto.mask_b16` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates a 16-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum (e.g., `pto.MaskPattern.PAT_VL8`, `pto.MaskPattern.PAT_VL16`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity tail mask | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask for first 16 lanes +tail_mask = pto.pge_b16(PAT.VL16) +``` + +#### `pto.pge_b32(pattern: pto.MaskPattern) -> pto.mask_b32` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates a 32-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum (e.g., `pto.MaskPattern.PAT_VL8`, `pto.MaskPattern.PAT_VL16`, `pto.MaskPattern.PAT_VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity tail mask | + +**Constraints**: +- Used with `f32`/`i32` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask for first 32 lanes +tail_mask = pto.pge_b32(PAT.VL32) +``` + +#### `pto.plt_b8(scalar: pto.i32) -> (pto.mask_b8, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates an 8-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `i8` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b8(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.plt_b16(scalar: pto.i32) -> (pto.mask_b16, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates a 16-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b16(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.plt_b32(scalar: pto.i32) -> (pto.mask_b32, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates a 32-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `f32`/`i32` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b32(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.make_mask(element_type: Type, value: pto.i32 | pto.MaskPattern) -> MaskType | (MaskType, pto.i32)` + +**Description**: Creates a mask with appropriate bitwidth (8, 16, or 32) based on element type, automatically inferring whether to perform tail processing or pattern-based mask generation based on the `value` parameter type. This convenience function eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` and `pset_b8`/`pset_b16`/`pset_b32` operations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `element_type` | `Type` | Element type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | +| `value` | `pto.i32` \| `pto.MaskPattern` | Either:
- Remaining element count (as `pto.i32`) for tail processing
- Mask pattern enum value for fixed mask generation (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Generated mask with appropriate granularity | +| `remaining` | `pto.i32` | Updated remaining element count (only returned when `value` is a `pto.i32` for tail processing) | + +**Constraints**: +- The `element_type` must be one of: `f32`, `i32`, `f16`, `bf16`, `i16`, `i8` +- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`, 16-bit for `f16`/`bf16`/`i16`, 8-bit for `i8` +- The function infers the operation mode from the `value` parameter type at compile time: + - `pto.i32` value → tail processing mode (returns `(mask, updated_remaining)`) + - `pto.MaskPattern` enum value → pattern mode (returns `mask` only) + +**Implementation Note**: This function is a DSL macro that performs type-based dispatch at compile time: +- When `value` is a `pto.i32` expression: expands to corresponding `plt_b` instruction (`plt_b32`, `plt_b16`, or `plt_b8`) +- When `value` is a `pto.MaskPattern` enum value: expands to corresponding `pset_b` instruction (`pset_b32`, `pset_b16`, or `pset_b8`) + +**Example**: +```python +# Tail processing with f32 vectors: value is pto.i32 → expands to plt_b32 +mask_f32, remaining_f32 = pto.make_mask(pto.f32, remaining_elements) + +# Tail processing with f16 vectors: value is pto.i32 → expands to plt_b16 +mask_f16, remaining_f16 = pto.make_mask(pto.f16, remaining_elements) + +# Tail processing with i8 vectors: value is pto.i32 → expands to plt_b8 +mask_i8, remaining_i8 = pto.make_mask(pto.i8, remaining_elements) + +# Pattern-based mask with f32 vectors: value is MaskPattern enum → expands to pset_b32 +mask_all_f32 = pto.make_mask(pto.f32, PAT.ALL) + +# Pattern-based mask with f16 vectors: value is MaskPattern enum → expands to pset_b16 +mask_even_f16 = pto.make_mask(pto.f16, PAT.EVEN) + +# Pattern-based mask with i8 vectors: value is MaskPattern enum → expands to pset_b8 +mask_all_i8 = pto.make_mask(pto.i8, PAT.ALL) + +# Type annotations help clarify expected parameter types +remaining: pto.i32 = 1024 +mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing +mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode +``` + +#### `pto.ppack(mask: MaskType, part: PartMode) -> MaskType` + +**Description**: Rearranges a mask according to the requested `part` selector. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | +| `part` | `PartMode` | Part selector enum: `PartMode.EVEN` or `PartMode.ODD`. Determines which half of the mask to pack (even-indexed or odd-indexed elements). | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `packed` | `MaskType` | Reordered mask | + +#### `pto.punpack(mask: MaskType, part: PartMode) -> MaskType` + +**Description**: Applies the inverse mask-part rearrangement selected by `part`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask | +| `part` | `PartMode` | Part selector enum: `PartMode.EVEN` or `PartMode.ODD`. Determines which half of the mask to unpack (even-indexed or odd-indexed elements). | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Reordered mask | + +#### `pto.pnot(mask: MaskType, gate: MaskType) -> MaskType` + +**Description**: Predicate negation under a same-granularity mask gate. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask | +| `gate` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `negated` | `MaskType` | Negated mask | + +#### `pto.psel(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Selects between two masks using a third mask as selector. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Selection mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Selected mask | + +#### `pto.plds(buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with scalar-index style offset form. This is the default DSL surface for loading predicate masks from UB memory. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `offset` | `Index` | Scalar/index-style offset | +| `dist` | `PredicateDist` | Distribution mode (default: `PredicateDist.NORM`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.plds(buf, offset, PredicateDist.NORM) +``` + +#### `pto.pld(buf: ptr, offset: Index, dist: PredicateDist) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with areg/index register style offset encoding. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `offset` | `Index` | Areg/index-style offset | +| `dist` | `PredicateDist` | Distribution mode | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.pld(buf, offset, PredicateDist.NORM) +``` + +#### `pto.pldi(buf: ptr, imm_offset: pto.i32, dist: PredicateDist) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with immediate-offset encoding form. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `imm_offset` | `pto.i32` | Immediate-offset operand | +| `dist` | `PredicateDist` | Distribution mode | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.pldi(buf, 0, PredicateDist.NORM) +``` + +#### `pto.pst(mask: MaskType, buf: ptr, offset: Index) -> None` [Advanced Tier] + +**Description**: Stores a predicate mask to buffer. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination buffer | +| `offset` | `Index` | Byte offset | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.pst(mask, buf, offset) +``` + +#### `pto.psti(mask: MaskType, imm: pto.i32) -> None` + +**Description**: Stores a predicate mask to immediate destination. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `imm` | `pto.i32` | Immediate destination identifier | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.psti(mask, 1) +``` + +#### `pto.pand(src0: MaskType, src1: MaskType) -> MaskType` + +**Description**: Bitwise AND of two predicate masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise AND of input masks | + +**Example**: +```python +result = pto.pand(mask1, mask2) +``` + +#### `pto.por(src0: MaskType, src1: MaskType) -> MaskType` + +**Description**: Bitwise OR of two predicate masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise OR of input masks | + +**Example**: +```python +result = pto.por(mask1, mask2) +``` + +#### `pto.pxor(src0: MaskType, src1: MaskType) -> MaskType` + +**Description**: Bitwise XOR of two predicate masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise XOR of input masks | + +**Example**: +```python +result = pto.pxor(mask1, mask2) +``` + +**Note**: Prefer `pto.make_mask()` for automatic bitwidth selection and unified tail/pattern mask generation. diff --git a/tilelang-dsl/docs/user_guide/13-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/13-vector-arithmetic-operations.md new file mode 100644 index 000000000..d2f05d6f7 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/13-vector-arithmetic-operations.md @@ -0,0 +1,1414 @@ +### Unary Vector Operations + +Element-wise unary operations on vector registers. + +#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Absolute value of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Absolute values | + +**Constraints**: +- Mask granularity must match vector element type (e.g., `f32` requires `mask_b32`) + +**Example**: +```python +abs_vec = pto.vabs(vec_f32, mask32) +``` + +#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Exponential of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Exponential values | + +#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Natural logarithm of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Natural logarithm values | + +#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Square root of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Square root values | + +#### `pto.vrec(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Reciprocal of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reciprocal values | + +#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: ReLU activation (max(0, x)) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated values | + +#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bitwise NOT of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise NOT values | + +#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex addition of vector elements (treating pairs as complex numbers). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex addition result | + +#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex maximum of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex maximum result | + +#### `pto.vbcnt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bit count (population count) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bit count values | + +#### `pto.vneg(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Negation of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Negated values | + +**Constraints**: +- Mask granularity must match vector element type + +**Example**: +```python +neg_vec = pto.vneg(vec_f32, mask32) +``` + +#### `pto.vcls(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Count leading sign bits of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Count of leading sign bits | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vcmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex minimum of vector elements (treating pairs as complex numbers). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex minimum result | + +#### `pto.vrsqrt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Reciprocal square root of vector elements (1/√x). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reciprocal square root values | + +**Constraints**: +- For floating-point vector types only + +#### `pto.vprelu(vec: VRegType, alpha: VRegType, mask: MaskType) -> VRegType` + +**Description**: Parametric ReLU activation of vector elements: `x if x >= 0 else alpha * x`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `alpha` | `VRegType` | Slope parameter for negative values | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Parametric ReLU activated values | + +#### `pto.vmov(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector move (data movement). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Copied vector | + +#### `pto.vsunpack(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Signed unpack of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Unpacked signed values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vzunpack(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Zero-extended unpack of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Unpacked zero-extended values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vusqz(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Unsigned squeeze (compression) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Compressed unsigned values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vsqz(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Signed squeeze (compression) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Compressed signed values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vexpdiff(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Exponential difference of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Exponential difference values | + +**Constraints**: +- For floating-point vector types only + +### Binary Vector Operations + +Element-wise binary operations on vector registers. + +#### `pto.vadd(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise addition of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum of vectors | + +**Example**: +```python +sum_vec = pto.vadd(vec_a, vec_b, mask32) +``` + +#### `pto.vsub(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise subtraction of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference of vectors | + +#### `pto.vmul(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise multiplication of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Product of vectors | + +#### `pto.vdiv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise division of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Quotient of vectors | + +#### `pto.vmax(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise maximum | + +#### `pto.vmin(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise minimum | + +#### `pto.vand(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise AND of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise AND result | + +#### `pto.vor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise OR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise OR result | + +#### `pto.vxor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise XOR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise XOR result | + +#### `pto.vshl(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift left (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshr(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift right (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vaddrelu(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Addition with ReLU activation (max(0, vec1 + vec2)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated sum of vectors | + +#### `pto.vaddreluconv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Convolution addition with ReLU activation (convolution-specific fused operation). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated convolution sum | + +**Constraints**: +- Optimized for convolution-specific patterns + +#### `pto.vsubrelu(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Subtraction with ReLU activation (max(0, vec1 - vec2)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated difference of vectors | + +#### `pto.vaxpy(alpha: VRegType, x: VRegType, y: VRegType, mask: MaskType) -> VRegType` + +**Description**: BLAS AXPY operation (αx + y). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `alpha` | `VRegType` | Scaling factor | +| `x` | `VRegType` | Input vector x | +| `y` | `VRegType` | Input vector y | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result of αx + y | + +#### `pto.vmulconv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Convolution multiplication (convolution-specific multiplication). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Convolution product | + +**Constraints**: +- Optimized for convolution-specific patterns + +#### `pto.vmull(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, VRegType)` + +**Description**: Widening multiply with split low/high results (extended arithmetic). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | Low part of widened product (`r & 0xFFFFFFFF`) | +| `high` | `VRegType` | High part of widened product (`r >> 32`) | + +**Constraints**: +- Current A5 documented form is native `i32/u32` 32x32->64 widening multiply +- Result is split into two vector outputs instead of a single widened vector + +**Example**: +```python +low, high = pto.vmull(lhs_i32, rhs_i32, mask32) +``` + +#### `pto.vmula(vec1: VRegType, vec2: VRegType, vec3: VRegType, mask: MaskType) -> VRegType` + +**Description**: Fused multiply-add (vec1 * vec2 + vec3). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector (multiplier) | +| `vec2` | `VRegType` | Second input vector (multiplicand) | +| `vec3` | `VRegType` | Third input vector (addend) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result of vec1 * vec2 + vec3 | + +### Vector-Scalar Operations + +Operations between vectors and scalars. + +#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector multiplied by scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar multiplier | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Scaled vector | + +**Example**: +```python +scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) +``` + +#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector plus scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar addend | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Maximum values | + +#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Minimum values | + +#### `pto.vlrelu(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Leaky ReLU activation (max(αx, x)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Alpha coefficient | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Leaky ReLU activated values | + +#### `pto.vshls(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector shift left by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshrs(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector shift right by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vands(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise AND of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise AND result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vors(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise OR of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise OR result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vxors(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise XOR of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise XOR result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vsubs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector minus scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar subtrahend | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference vector | + +#### `pto.vbr(value: ScalarType) -> VRegType` + +**Description**: Broadcast scalar to all vector lanes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `ScalarType` | Scalar source | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Vector whose active lanes all carry `value` | + +**Constraints**: +- Supported scalar types are `i8`, `i16`, `i32`, `f16`, `bf16`, `f32`. +- For integer types, only the low bits of the scalar source are consumed according to the bit width (8, 16, or 32 bits). + +**Example**: +```python +# Broadcast scalar constant to vector +zero_vec = pto.vbr(0.0) +one_vec = pto.vbr(1.0) + +# Reduction seed with explicit floating dtype +rowmax_seed_f32 = pto.vbr(pto.f32("-inf")) +rowmax_seed_f16 = pto.vbr(pto.f16("0xFC00")) +``` + +**Position Mode Enum**: The `PositionMode` enum provides type-safe position selection for `pto.vdup` operations. Currently only `LOWEST` (selects the lowest-index element) is supported, with more position options planned for future releases. + +#### `pto.vdup(input: ScalarType | VRegType, position: PositionMode = PositionMode.LOWEST) -> VRegType` + +**Description**: Duplicate scalar or vector element to all lanes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `input` | `ScalarType` or `VRegType` | Input scalar or source vector | +| `position` | `PositionMode` | Optional enum selecting which source element to duplicate (default: `PositionMode.LOWEST`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Vector with duplicated value in all lanes | + +**Constraints**: +- When `input` is a scalar, it is broadcast to all lanes (similar to `pto.vbr` but with `position` attribute). +- When `input` is a vector, the element selected by `position` is duplicated to all lanes. +- Supported scalar types are `i8`, `i16`, `i32`, `f16`, `bf16`, `f32`. +- The `position` enum selects which source element or scalar position is duplicated. Currently only `PositionMode.LOWEST` is supported, which selects the lowest-index element. + +**Example**: +```python +# Broadcast scalar to vector (similar to pto.vbr) +broadcast = pto.vdup(3.14) # position defaults to "POS_LOWEST" + +# Use dtype constructor when the semantic value is floating-point special value +seed = pto.vdup(pto.f32("-inf")) +seed_f16 = pto.vdup(pto.f16("0xFC00")) + +# Duplicate lowest element of vector to all lanes +vec = pto.vreg_f32(64) # 64-element vector +dup_lowest = pto.vdup(vec) # position defaults to "POS_LOWEST" + +# Explicit position specification +dup_explicit = pto.vdup(vec, position=PositionMode.LOWEST) +``` + +**Type Safety Note**: +- For floating-point seeds, prefer `pto.f16(...)` / `pto.bf16(...)` / `pto.f32(...)` constructors. +- Do not pass integer bit-pattern literals directly (for example `0xFF800000`) when a floating vector type is intended. + +### Carry & Select Operations + +Operations with carry propagation and selection. + +**Comparison Mode Enum**: The `CmpMode` enum provides type-safe comparison mode specification for `pto.vcmp` and `pto.vcmps` operations. It includes the following values: `EQ` (equal), `NE` (not equal), `LT` (less than), `LE` (less than or equal), `GT` (greater than), `GE` (greater than or equal). + +Implemented current-package carry/select surface also includes: +- `pto.vselr(vec0, vec1) -> VRegType` +- `pto.vselrv2(vec0, vec1) -> VRegType` +- `pto.vaddcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` +- `pto.vsubcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` + +#### `pto.vcmp(vec0: VRegType, vec1: VRegType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Element-wise vector comparison with seed mask. Compares two vectors element-wise and generates a predicate mask based on the specified comparison mode. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec0` | `VRegType` | First input vector | +| `vec1` | `VRegType` | Second input vector | +| `seed_mask` | `MaskType` | Seed mask that determines which lanes participate in the comparison | +| `cmp_mode` | `CmpMode` | Comparison mode enum: `CmpMode.EQ` (equal), `CmpMode.NE` (not equal), `CmpMode.LT` (less than), `CmpMode.LE` (less than or equal), `CmpMode.GT` (greater than), `CmpMode.GE` (greater than or equal) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Generated predicate mask based on element-wise comparison | + +**Constraints**: +- Only lanes enabled by `seed_mask` participate in the comparison +- The two input vectors must have the same element type and vector length +- The output mask granularity matches the input vector element type + +**Example**: +```python +# Compare two vectors for less-than relation +all_mask = pto.make_mask(pto.f32, PAT.ALL) +lt_mask = pto.vcmp(vec_a, vec_b, all_mask, CmpMode.LT) +``` + +#### `pto.vcmps(vec: VRegType, scalar: ScalarType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Vector-scalar comparison with seed mask. Compares each element of a vector against a scalar value and generates a predicate mask based on the specified comparison mode. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value to compare against (must match vector element type) | +| `seed_mask` | `MaskType` | Seed mask that determines which lanes participate in the comparison | +| `cmp_mode` | `CmpMode` | Comparison mode enum: `CmpMode.EQ`, `CmpMode.NE`, `CmpMode.LT`, `CmpMode.LE`, `CmpMode.GT`, `CmpMode.GE` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Generated predicate mask based on vector-scalar comparison | + +**Constraints**: +- Only lanes enabled by `seed_mask` participate in the comparison +- The scalar type must match the vector element type +- The output mask granularity matches the input vector element type + +**Example**: +```python +# Check which elements are greater than zero +all_mask = pto.make_mask(pto.f32, PAT.ALL) +positive_mask = pto.vcmps(values, pto.f32(0.0), all_mask, CmpMode.GT) +``` + +#### `pto.vaddc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` + +**Description**: Vector addition with carry output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum vector | +| `carry_out` | `MaskType` | Output carry mask | + +#### `pto.vsubc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` + +**Description**: Vector subtraction with borrow output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference vector | +| `borrow_out` | `MaskType` | Output borrow mask | + +#### `pto.vsel(true_vec: VRegType, false_vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector select based on mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `true_vec` | `VRegType` | Vector selected when mask bit is 1 | +| `false_vec` | `VRegType` | Vector selected when mask bit is 0 | +| `mask` | `MaskType` | Selection mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Selected vector | + +**Example**: +```python +result = pto.vsel(scaled_vec, original_vec, mask32) +``` + +### Reduction Operations + +Reduction operations across vector lanes or channels. + +#### `pto.vcgadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group addition reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced sum across groups | + +#### `pto.vcgmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group maximum reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced maximum across groups | + +#### `pto.vcgmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group minimum reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced minimum across groups | + +#### `pto.vcpadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-channel addition reduction (reduction across channels). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced sum across channels | + +### Data Rearrangement + +Operations for rearranging data within vectors. + +#### `pto.pdintlv_b8(mask: pto.mask_b8) -> pto.mask_b8` + +**Description**: Deinterleave 8-bit mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.mask_b8` | Input 8-bit mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `pto.mask_b8` | Deinterleaved mask | + +#### `pto.pintlv_b16(mask: pto.mask_b16) -> pto.mask_b16` + +**Description**: Interleave 16-bit mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.mask_b16` | Input 16-bit mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `pto.mask_b16` | Interleaved mask | + +Implemented current-package rearrangement surface also includes: +- `pto.vintlvv2(vec0, vec1, part) -> VRegType` +- `pto.vdintlvv2(vec0, vec1, part) -> VRegType` + +#### `pto.vintlv(vec1: VRegType, vec2: VRegType) -> (VRegType, VRegType)` + +**Description**: Interleave two vectors and return the low/high results. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | Low interleaved result | +| `high` | `VRegType` | High interleaved result | + +#### `pto.vdintlv(vec0: VRegType, vec1: VRegType) -> (VRegType, VRegType)` + +**Description**: Deinterleave a pair of vectors into low/high results. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec0` | `VRegType` | First input vector | +| `vec1` | `VRegType` | Second input vector | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec1` | `VRegType` | First deinterleaved vector | +| `vec2` | `VRegType` | Second deinterleaved vector | + +#### `pto.vpack(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector packing (combine elements from two vectors). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Packed vector | + +#### `pto.vperm(vec: VRegType, indices: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector permutation (reorder elements according to index vector). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `indices` | `VRegType` | Permutation indices | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Permuted vector | + +#### `pto.vshift(vec: VRegType, shift_amount: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Generic vector shift (shift all elements by same amount). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift_amount` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted vector | + +#### `pto.vslide(vec: VRegType, window_size: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector sliding window (create overlapping windows). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `window_size` | `ScalarType` | Size of sliding window | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sliding window result | + +#### `pto.vsort32(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: 32-element sorting of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (32 elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sorted vector | + +**Constraints**: +- Input vector must have exactly 32 elements + +#### `pto.vmrgsort(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Merge sort of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Merged and sorted vector | + +#### `pto.vtranspose(dest: ptr, src: ptr, config: pto.i64) -> None` [Advanced Tier] + +**Description**: UB-to-UB transpose operation. This op works on UB memory directly (not `vreg -> vreg`). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src` | `ptr` | Source pointer in UB memory space | +| `config` | `pto.i64` | ISA control/config operand that encodes transpose layout behavior | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `None` | `None` | Side-effect operation that writes transposed data to `dest` | + +**Constraints**: +- `dest` and `src` must be UB pointers +- Correctness depends on the `config` encoding and UB layout contract + +**Example**: +```python +pto.vtranspose(dst_ub_ptr, src_ub_ptr, config_word) +``` + +### Conversion & Special Operations + +Type conversion and specialized operations. + +#### `pto.vtrc(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Truncate vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Truncated vector | + +#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType) -> VRegType` + +**Description**: Type conversion of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `to_type` | `Type` | Target element type | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Converted vector | + +#### `pto.vbitsort(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bitonic sort of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sorted vector | + +#### `pto.vmrgsort4(vec1: VRegType, vec2: VRegType, vec3: VRegType, vec4: VRegType, mask: MaskType) -> VRegType` + +**Description**: 4-way merge sort of vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `vec3` | `VRegType` | Third input vector | +| `vec4` | `VRegType` | Fourth input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Merged and sorted vector | + +**Order Mode Enum**: The `OrderMode` enum provides type-safe order selection for `pto.vci` operations. Currently only `ASC` (ascending order) is supported, with more order options planned for future releases. + +#### `pto.vci(index: ScalarType, order: OrderMode = OrderMode.ASC) -> VRegType` + +**Description**: Generate a lane-index vector from a scalar seed/index value (DSA/SFU operation). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `index` | `ScalarType` | Scalar seed or base index value | +| `order` | `OrderMode` | Order mode enum (default: `OrderMode.ASC` for ascending order) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Generated index vector | + +**Constraints**: +- This is an index-generation family, not a numeric conversion +- The `order` parameter and result element type together determine how indices are generated +- Currently only ascending order (`OrderMode.ASC`) is supported + +**Example**: +```python +# Generate ascending indices starting from 0 +indices = pto.vci(pto.i32(0), OrderMode.ASC) +``` diff --git a/tilelang-dsl/docs/user_guide/14-examples.md b/tilelang-dsl/docs/user_guide/14-examples.md new file mode 100644 index 000000000..89331015d --- /dev/null +++ b/tilelang-dsl/docs/user_guide/14-examples.md @@ -0,0 +1,154 @@ +## Examples + +### Template-based Kernel Examples + +#### Unified Arithmetic Operations + +A single kernel implementing multiple arithmetic operations using templates: + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + target="a5", + ops=["tadd", "tsub", "tmul", "tdiv"], + dtypes=[(T, T, T)], + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + } +) +def elementwise_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + """Single implementation for four arithmetic operations.""" + dtype = dst.element_type + rows, cols = dst.valid_shape + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, pto.elements_per_vreg(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) +``` + +#### Multiple Templates with Postprocess + +Kernel using separate templates for arithmetic and postprocess operations: + +```python +@pto.vkernel( + target="a5", + ops=["add_relu", "sub_relu", "add_abs", "sub_abs"], + dtypes=[(T, T, T)], + templates={ + "arithmetic": { + "add_relu": "vadd", + "sub_relu": "vsub", + "add_abs": "vadd", + "sub_abs": "vsub", + }, + "postprocess": { + "add_relu": "vrelu", + "sub_relu": "vrelu", + "add_abs": "vabs", + "sub_abs": "vabs", + } + } +) +def elementwise_with_postprocess(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, pto.elements_per_vreg(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + + # Use arithmetic template + arith_result = pto.tpl("arithmetic", lhs, rhs, mask) + + # Apply postprocess template + activated = pto.tpl("postprocess", arith_result, mask) + + pto.vsts(activated, dst[row, col:], mask) +``` + +#### Compile-time Substitution + +Template substitution happens before semantic analysis and lowering: + +```python +selected = pto.select_kernel("a5", "tadd", (ptype, ptype, ptype)) +# frontend resolves: +# pto.tpl("core", lhs, rhs, mask) +# into: +# pto.vadd(lhs, rhs, mask) +``` + +#### Benefits of Template-based Authoring + +1. **Code Reuse**: Single implementation serves multiple operations +2. **Maintenance**: Bug fixes and optimizations apply to all related operations +3. **Consistency**: Ensures uniform behavior across operation families +4. **Reduced Boilerplate**: Eliminates duplicate control flow and data movement code +5. **Type Safety**: Type variables ensure consistent operand types + +### Simple Vector Copy + +```python +@pto.vkernel(...) +def vector_copy(src: pto.Tile, dst: pto.Tile): + all_mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) + for offset in range(0, 256, 64): + vec = pto.vlds(src, offset) + pto.vsts(vec, dst, offset, all_mask) +``` + +### Conditional Computation + +```python +@pto.vkernel(...) +def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), + dst: pto.ptr(pto.f32, MemorySpace.GM), + threshold: pto.f32): + # ... setup ... + + with pto.strict_vecscope(ub_in, ub_out, threshold) as (vin, vout, thresh): + for i in range(0, 1024, 64): + vec = pto.vlds(vin, i) + + # Compare with threshold + mask = pto.pge_b32(vec, thresh) + + # Scale values above threshold + scaled = pto.vmuls(vec, pto.f32(2.0), mask) + + # Keep original values below threshold + result = pto.vsel(scaled, vec, mask) + + pto.vsts(result, vout, i, all_mask) +``` + +### Loop with Carry + +```python +@pto.vkernel(...) +def prefix_sum(src: pto.ptr(pto.i32, MemorySpace.UB), + dst: pto.ptr(pto.i32, MemorySpace.UB)): + all_mask = pto.make_mask(pto.i32, PAT.ALL) + carry = all_mask + + for i in range(0, 256, 64): + vec = pto.vlds(src, i) + result, carry = pto.vaddcs(vec, vec, carry, all_mask) + pto.vsts(result, dst, i, all_mask) +``` diff --git a/tilelang-dsl/docs/user_guide/15-common-errors.md b/tilelang-dsl/docs/user_guide/15-common-errors.md new file mode 100644 index 000000000..46abe09b9 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/15-common-errors.md @@ -0,0 +1,51 @@ +## Common Errors + +### Typed Mask Mismatch + +``` +Error: f32 vector operation cannot consume mask_b16 +``` + +**Solution:** Ensure mask granularity matches vector element size: +- `f32` vectors use `mask_b32` +- `f16` vectors use `mask_b16` +- `i8` vectors use `mask_b8` + +### Strict Scope Implicit Capture + +``` +Error: strict_vecscope body cannot capture outer value 'ub_in' implicitly +``` + +**Solution:** Pass all required values in the capture list: + +```python +# Wrong: +with pto.strict_vecscope() as (): + vec = pto.vlds(ub_in, offset) # ub_in from outer scope + +# Correct: +with pto.strict_vecscope(ub_in) as (ub): + vec = pto.vlds(ub, offset) +``` + +### Untyped Loop Carried State + +``` +Error: loop-carried value must have explicit machine type +``` + +**Solution:** Add type annotations to loop-carried variables: + +```python +# Wrong: +remaining = 1024 # Plain Python int +for i in range(0, N, step): + mask, remaining = pto.make_mask(pto.f32, remaining) + +# Correct: +remaining: pto.i32 = 1024 +# or +remaining = pto.i32(1024) +``` + diff --git a/tilelang-dsl/docs/user_guide/16-compatibility-notes.md b/tilelang-dsl/docs/user_guide/16-compatibility-notes.md new file mode 100644 index 000000000..defcf704c --- /dev/null +++ b/tilelang-dsl/docs/user_guide/16-compatibility-notes.md @@ -0,0 +1,9 @@ +## Compatibility Notes + +The current experimental implementation in `python/pto/dialects/pto.py` differs from this specification in several ways: + +1. **Mask types**: The experimental version uses untyped `mask` instead of `mask_b8`/`mask_b16`/`mask_b32` +2. **Barrier operation**: Uses `pto.barrier()` instead of `pto.pipe_barrier()` +3. **Operation coverage**: Implements only a subset of operations + +When implementing new code, follow this specification. The experimental implementation will be updated to match over time. diff --git a/tilelang-dsl/docs/user_guide/17-next-steps.md b/tilelang-dsl/docs/user_guide/17-next-steps.md new file mode 100644 index 000000000..2fe63b9a4 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/17-next-steps.md @@ -0,0 +1,7 @@ +## Next Steps + +- Explore the ISA documentation in `docs/isa/` for detailed operation semantics +- Check `test/samples/` for example kernels +- Refer to `docs/vpto-spec.md` for the underlying VPTO instruction specification + +For compiler developers, see `docs/PTO_IR_manual.md` for MLIR-level details. diff --git a/tilelang-dsl/docs/v1-lowering.md b/tilelang-dsl/docs/v1-lowering.md index 9c0c6c8f4..c8ae1d82d 100644 --- a/tilelang-dsl/docs/v1-lowering.md +++ b/tilelang-dsl/docs/v1-lowering.md @@ -13,7 +13,6 @@ It covers: It does not define: - matcher-driven dispatch -- implicit vecscope inference - raw pointer authoring surface - advanced vector-family lowering beyond the fixed v1 matrix @@ -35,7 +34,7 @@ OpenSpec source of truth for this capability: ## Implemented v1 Support Matrix The current v1 lowering contract supports: -- 2D `TensorView` +- fixed-rank 5D `TensorView` descriptors - 1D/2D `Tile` - `dma_load` - `dma_store` @@ -52,15 +51,23 @@ The current v1 lowering contract supports: Current lowering shape: - emits stable `func.func + arith/scf + pto.*` authoring-form VPTO modules -- requires explicit `pto.strict_vecscope` +- defaults to memref-first function/tile authoring when the target VPTO family supports memref operands +- keeps `copy_*` family on typed `!pto.ptr` +- infers dedicated `pto.vecscope` for stable vector-active runs +- lowers `pto.strict_vecscope` buffer captures through ptr-form region ABI so the current emission-boundary ptr rewrite stays legal +- only accepts explicit `pto.strict_vecscope` in `advanced=True` kernels - rejects support-matrix-external surface in the frontend ## Dynamic-Bound Profile The implemented shape profile is: - Tile physical shape must stay static -- TensorView shape access may lower through hidden shape arguments +- TensorView parameters stay in authoring IR as `!pto.tensor_view<...>` +- TensorView shape access lowers through `pto.get_tensor_view_dim` +- TensorView stride access lowers through `pto.get_tensor_view_stride` - TensorView slice bounds may be dynamic +- TensorView slice spelling may omit leading axes; written axes are right-aligned + onto the trailing physical axes of the 5D descriptor - loop bounds may be dynamic - tail `remaining` values may be dynamic @@ -69,12 +76,15 @@ TensorView slice extent is dynamic. This keeps v1 inside the current authoring-form contract without introducing fully dynamic Tile allocation or tail-DMA semantics. +Although the descriptor rank is 5D, the current DMA-oriented slicing/lowering +path still only supports rank-2 TensorView slices. + ## Examples Examples aligned with the implemented surface: - `tilelang-dsl/examples/v1_elementwise_tail_demo.py` - emits a guide-style elementwise authoring kernel - - covers DMA, explicit `strict_vecscope`, dynamic loop bound, and typed tail mask + - covers DMA, advanced-only explicit `strict_vecscope`, dynamic loop bound, and typed tail mask - `tilelang-dsl/examples/v1_verify_smoke.py` - emits a minimal module that is expected to pass the current repo `ptoas --pto-backend=vpto` legality path diff --git a/tilelang-dsl/docs/v1-surface.md b/tilelang-dsl/docs/v1-surface.md index 53e3b2dab..dcf169cd1 100644 --- a/tilelang-dsl/docs/v1-surface.md +++ b/tilelang-dsl/docs/v1-surface.md @@ -55,8 +55,13 @@ The package currently exports: - `TileLangFrontendError` - `TensorView` - `Tile` +- `VRegType` +- `MaskType` - scalar dtypes such as `f16`, `bf16`, `f32`, `i8`, `i16`, `i32`, `i64` -- Tile specialization helpers: `MemorySpace`, `TileConfig`, `TileSpecialization` +- type helpers such as `vreg(...)`, `ptr(...)`, `mask_b8`, `mask_b16`, `mask_b32`, `MemorySpace`, `TileConfig`, `TileSpecialization` + +The package does not expose a DSL-level `pto.memref(...)` constructor. MemRef +only appears in generated/lowered IR, not in the public authoring type surface. ## v1 Decorator Surface @@ -93,16 +98,22 @@ The descriptor keeps these metadata fields: v1 accepts these parameter categories: - bare `TensorView` - bare `Tile` -- explicit scalar annotations such as `pto.i32`, `pto.f16`, `pto.f32` +- scalar annotations such as `pto.i32`, `pto.f16`, `pto.f32`, `pto.AnyType`, or `pto.TypeVar("T")` Binding rules: - the single `dtypes` signature binds parameter element types positionally - `TensorView` parameters get their element dtype from the same position in `dtypes` - `Tile` parameters get their element dtype from the same position in `dtypes` -- scalar parameters must use an explicit scalar annotation -- scalar annotations must exactly match the dtype at the same position in +- scalar parameters must use a TileLang scalar-style annotation +- scalar annotations may be concrete scalar dtypes, wildcard dtypes, or + `TypeVar(...)` +- concrete scalar annotations must exactly match the dtype at the same position + in `dtypes` +- wildcard scalar annotations must accept the dtype at the same position in `dtypes` +- `TypeVar(...)` scalar annotations bind to the selected dtype at the same + position in `dtypes` Example: diff --git a/tilelang-dsl/examples/README.md b/tilelang-dsl/examples/README.md index 116c2439a..970b382d5 100644 --- a/tilelang-dsl/examples/README.md +++ b/tilelang-dsl/examples/README.md @@ -8,8 +8,13 @@ Current examples: - `v1_elementwise_tail_demo.py`: guide-aligned elementwise authoring demo that covers DMA, explicit `strict_vecscope`, dynamic loop bound, and typed tail mask lowering +- `v1_template_slot_multiop_demo.py`: shared kernel-body demo for + `tadd`/`tsub`/`tmul`/`tdiv` using `ops=[...]`, `templates={...}`, and + `pto.tpl("core", ...)` - `v1_tadd_implicit_vecscope_demo.py`: advanced-mode flattened `TADD` example - with implicit `pto.vecscope` inference and `vlds`/`vsts` tile indexing sugar + with implicit `pto.vecscope` inference, dynamic Tile `valid_shape`, generic + dtype selection, partial-dynamic `valid_shape` modes, and `vlds`/`vsts` + tile indexing sugar - `v1_tbinop_2d_nopostupdate_demo.py`: a representative TileLang DSL v1 expansion of `pto::TBinOps_2D_NoPostUpdate` using `vadd` - `v1_verify_smoke.py`: minimal verify smoke that is expected to pass the repo @@ -21,7 +26,13 @@ Typical usage from the repository root: python3 tilelang-dsl/examples/v1_emit_mlir_demo.py python3 tilelang-dsl/examples/v1_emit_mlir_demo.py /tmp/tilelang_demo.mlir PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_template_slot_multiop_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_template_slot_multiop_demo.py tsub +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_template_slot_multiop_demo.py tmul f16 PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py f16 +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py f16 rows +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py f16 cols PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_verify_smoke.py ``` diff --git a/tilelang-dsl/examples/tadd_demo.py b/tilelang-dsl/examples/tadd_demo.py index d1843f450..8bf8071f1 100644 --- a/tilelang-dsl/examples/tadd_demo.py +++ b/tilelang-dsl/examples/tadd_demo.py @@ -1,8 +1,8 @@ """TileLang DSL v1 demo: pto.tadd (element-wise add) using Tile parameters. -Note: v1 surface only supports 1D vectorized iteration within strict_vecscope. -The canonical 2D row×col loop with dynamic masking requires v2 features. -This demo demonstrates a 1D inner-loop pattern over the tile's column extent. +This example intentionally enables `advanced=True` because it demonstrates +explicit `strict_vecscope`. Stable kernels can rely on inferred `pto.vecscope` +and `tile[row, col:]` indexing sugar without opting into advanced mode. """ import sys @@ -27,6 +27,7 @@ def _import_tilelang_dsl(): op="pto.tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], name="template_tadd", + advanced=True, ) def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): # v1 strict_vecscope: all referenced values must be passed in explicitly, diff --git a/tilelang-dsl/examples/v1_elementwise_tail_demo.py b/tilelang-dsl/examples/v1_elementwise_tail_demo.py index 75b3ec4dd..28dc7a5d4 100644 --- a/tilelang-dsl/examples/v1_elementwise_tail_demo.py +++ b/tilelang-dsl/examples/v1_elementwise_tail_demo.py @@ -25,6 +25,7 @@ def _import_tilelang_dsl(): op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32, pto.i32)], name="tilelang_v1_elementwise_tail_demo", + advanced=True, ) def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: pto.i32): rows = inp.shape[0] diff --git a/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py b/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py index 429de8a33..76b16d0f7 100644 --- a/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py +++ b/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py @@ -5,7 +5,14 @@ 2D row-major vector body directly in Python: - top-level interface uses `dst, src0, src1` Tile parameters like `TADD` -- `advanced=True` enables implicit `pto.vecscope` inference +- Tile specializations keep a static physical tile shape while exposing a + dynamic `valid_shape` input at materialization time; the demo can model + fully dynamic or partially dynamic `(valid_rows, valid_cols)` profiles +- the kernel surface is dtype-polymorphic and can be selected for any supported + vector dtype with `pto.select_kernel(...)` +- implicit `pto.vecscope` inference and tile indexing sugar cover the base + vector authoring path; this demo also keeps `advanced=True` enabled because it + lives alongside the matcher/advanced-surface examples - `pto.vlds(tile[row, col:])` / `pto.vsts(vec, tile[row, col:], mask)` use tile indexing sugar instead of manual offset arithmetic """ @@ -29,11 +36,22 @@ def _import_tilelang_dsl(): pto = _import_tilelang_dsl() +T = pto.TypeVar("T") +SUPPORTED_DTYPES = { + "i8": pto.i8, + "i16": pto.i16, + "i32": pto.i32, + "f16": pto.f16, + "bf16": pto.bf16, + "f32": pto.f32, +} +VALID_SHAPE_MODES = ("both", "rows", "cols", "static") +TILE_SHAPE = (8, 64) @pto.vkernel( op="tadd", - dtypes=[(pto.f32, pto.f32, pto.f32)], + dtypes=[(T, T, T)], advanced=True, name="tilelang_advanced_tadd_demo", ) @@ -53,23 +71,76 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): return None -def build_specialized_kernel(): - return kernel.specialize( - dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), +def _resolve_valid_shape_profile(mode: str) -> tuple[object, object]: + rows, cols = TILE_SHAPE + if mode == "both": + return ("valid_rows", "valid_cols") + if mode == "rows": + return ("valid_rows", cols) + if mode == "cols": + return (rows, "valid_cols") + if mode == "static": + return TILE_SHAPE + raise ValueError(f"unsupported valid_shape mode '{mode}'") + + +def build_specialized_kernel(dtype=pto.f32, valid_shape_mode="both"): + selected = pto.select_kernel("a5", "tadd", (dtype, dtype, dtype)) + valid_shape = _resolve_valid_shape_profile(valid_shape_mode) + return selected.specialize( + src0=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=valid_shape, + ), + src1=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=valid_shape, + ), + dst=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=valid_shape, + ), ) -def main(argv) -> int: - specialized = build_specialized_kernel() +def _parse_cli(argv): + if len(argv) > 4: + return None, None, None + + dtype = pto.f32 + valid_shape_mode = "both" + output_path = None + args = list(argv[1:]) + for arg in args: + if arg in SUPPORTED_DTYPES: + dtype = SUPPORTED_DTYPES[arg] + continue + if arg in VALID_SHAPE_MODES: + valid_shape_mode = arg + continue + if output_path is None: + output_path = Path(arg) + continue + return None, None, None + return dtype, valid_shape_mode, output_path + - if len(argv) > 2: - print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) +def main(argv) -> int: + dtype, valid_shape_mode, output_path = _parse_cli(argv) + if dtype is None: + supported = ", ".join(SUPPORTED_DTYPES) + valid_shape_modes = ", ".join(VALID_SHAPE_MODES) + print( + f"usage: {Path(argv[0]).name} [{supported}] [{valid_shape_modes}] [output.mlir]", + file=sys.stderr, + ) return 2 + specialized = build_specialized_kernel(dtype=dtype, valid_shape_mode=valid_shape_mode) - if len(argv) == 2: - output_path = Path(argv[1]) + if output_path is not None: specialized.emit(output_path) print(f"wrote MLIR to {output_path}") return 0 diff --git a/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py b/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py index 5809bbe3a..96f879640 100644 --- a/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py +++ b/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py @@ -36,6 +36,7 @@ def _import_tilelang_dsl(): op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)], name="tilelang_v1_tbinop_2d_nopostupdate_demo", + advanced=True, ) def kernel( lhs_gm: pto.TensorView, diff --git a/tilelang-dsl/examples/v1_template_slot_multiop_demo.py b/tilelang-dsl/examples/v1_template_slot_multiop_demo.py new file mode 100644 index 000000000..68d2a67f9 --- /dev/null +++ b/tilelang-dsl/examples/v1_template_slot_multiop_demo.py @@ -0,0 +1,146 @@ +"""Shared-kernel-body TileLang DSL v1 demo using template slots. + +This example shows the recommended authoring pattern for a small family of +binary elementwise ops that share the same traversal, mask, load, and store +structure: + +- one `@pto.vkernel` descriptor matches multiple concrete ops via `ops=[...]` +- `templates={"core": ...}` maps each concrete op to its real `pto.*` vector op +- the kernel body uses a single `pto.tpl("core", ...)` placeholder call +- `pto.select_kernel(...)` binds the concrete op before materialization +""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() +T = pto.TypeVar("T") +SUPPORTED_DTYPES = { + "i8": pto.i8, + "i16": pto.i16, + "i32": pto.i32, + "f16": pto.f16, + "bf16": pto.bf16, + "f32": pto.f32, +} +SUPPORTED_OPS = ( + "tadd", + "tsub", + "tmul", + "tdiv", +) +TILE_SHAPE = (8, 64) + + +@pto.vkernel( + ops=list(SUPPORTED_OPS), + dtypes=[(T, T, T)], + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + }, + name="tilelang_template_slot_multiop_demo", +) +def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) + return None + + +def build_specialized_kernel(op_name="tadd", dtype=pto.f32): + if op_name not in SUPPORTED_OPS: + raise ValueError(f"unsupported op '{op_name}'") + selected = pto.select_kernel("a5", op_name, (dtype, dtype, dtype)) + return selected.specialize( + src0=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + src1=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + dst=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + ) + + +def _parse_cli(argv): + if len(argv) > 4: + return None, None, None + + op_name = "tadd" + dtype = pto.f32 + output_path = None + for arg in argv[1:]: + if arg in SUPPORTED_OPS: + op_name = arg + continue + if arg in SUPPORTED_DTYPES: + dtype = SUPPORTED_DTYPES[arg] + continue + if output_path is None: + output_path = Path(arg) + continue + return None, None, None + return op_name, dtype, output_path + + +def main(argv) -> int: + op_name, dtype, output_path = _parse_cli(argv) + if op_name is None: + supported_ops = ", ".join(SUPPORTED_OPS) + supported_dtypes = ", ".join(SUPPORTED_DTYPES) + print( + f"usage: {Path(argv[0]).name} [{supported_ops}] [{supported_dtypes}] [output.mlir]", + file=sys.stderr, + ) + return 2 + + specialized = build_specialized_kernel(op_name=op_name, dtype=dtype) + + if output_path is not None: + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 73c2694c8..acd27f302 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -1,11 +1,21 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL v1 package.""" from .kernel import ( BoundKernelParameter, + InlineProcDescriptor, KernelRegistry, MaterializedMLIRModule, TileLangFrontendError, VKernelDescriptor, + inline_proc, select_kernel, vkernel, ) @@ -17,20 +27,29 @@ EVENT, PIPE, Event, + MaskType, MemorySpace, MaskPattern, PAT, + PadMode, + PositionMode, + OrderMode, PointerType, Pipe, ScalarType, TensorView, + PartitionTensorView, Tile, TileConfig, TileSpecialization, TypeVar, TypeVariable, + VRegType, WildcardType, bf16, + constexpr, + bytewidth, + elements_per_vreg, f16, f32, get_lanes, @@ -39,15 +58,21 @@ i16, i32, i64, + mask_b8, + mask_b16, + mask_b32, ptr, + vreg, ) __all__ = [ "BoundKernelParameter", + "InlineProcDescriptor", "KernelRegistry", "MaterializedMLIRModule", "TileLangFrontendError", "VKernelDescriptor", + "inline_proc", "select_kernel", "vkernel", "ScalarType", @@ -55,9 +80,13 @@ "TypeVariable", "TypeVar", "TensorView", + "PartitionTensorView", "Tile", "PointerType", + "VRegType", + "MaskType", "ptr", + "vreg", "MemorySpace", "Pipe", "Event", @@ -65,6 +94,9 @@ "EVENT", "MaskPattern", "PAT", + "PadMode", + "PositionMode", + "OrderMode", "TileConfig", "TileSpecialization", "i1", @@ -79,5 +111,11 @@ "AnyInt", "AnyType", "AnyMask", + "mask_b8", + "mask_b16", + "mask_b32", + "constexpr", + "bytewidth", "get_lanes", + "elements_per_vreg", ] diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index 0a93b4818..3db9cf225 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -17,10 +17,11 @@ import argparse import importlib.util +import json import sys from pathlib import Path -from .kernel import VKernelDescriptor +from .kernel import VKernelDescriptor, _match_descriptor_dtype_signature from .types import MemorySpace, ScalarType, TileSpecialization @@ -71,30 +72,76 @@ def _import_py_file(path: Path): def _match_descriptor( descriptors: list[VKernelDescriptor], op_name: str, - dtype_name: str, + operand_types: tuple[ScalarType, ...], ) -> VKernelDescriptor | None: - """Find the first descriptor matching (op, dtype).""" - target_dtype = _DTYPE_MAP.get(dtype_name) - if target_dtype is None: - return None - + """Find and bind the first descriptor matching (op, dtype).""" for desc in descriptors: - if desc.op != op_name: + if op_name not in desc.match_ops: continue - # Check dtype signature: all entries must match the target dtype. - sig = desc.dtype_signature - if all(d == target_dtype for d in sig): - return desc + op_bound = desc._bind_selected_op(op_name) + matched_signature = _match_descriptor_dtype_signature(op_bound, operand_types) + if matched_signature is None: + continue + if op_bound._selected_dtype_signature == matched_signature: + return op_bound + return op_bound._bind_selected_dtype_signature(matched_signature) return None +def _parse_operand_specs(spec_text: str) -> list[dict]: + try: + raw_specs = json.loads(spec_text) + except json.JSONDecodeError as exc: + raise ValueError(f"invalid operand-specs JSON: {exc}") from exc + + if not isinstance(raw_specs, list) or not raw_specs: + raise ValueError("operand-specs must be a non-empty JSON array") + + specs: list[dict] = [] + for index, raw in enumerate(raw_specs): + if not isinstance(raw, dict): + raise ValueError(f"operand-specs[{index}] must be an object") + kind = raw.get("kind") + dtype_name = raw.get("dtype") + dtype = _DTYPE_MAP.get(dtype_name) + if dtype is None: + raise ValueError(f"operand-specs[{index}] has unsupported dtype {dtype_name!r}") + if kind == "scalar": + specs.append({"kind": "scalar", "dtype": dtype}) + continue + if kind == "tile": + shape = raw.get("shape") + if not isinstance(shape, list) or not shape: + raise ValueError(f"operand-specs[{index}] tile shape must be a non-empty list") + memory_space = _MEMSPACE_MAP.get(raw.get("memory_space")) + if memory_space is None: + raise ValueError( + f"operand-specs[{index}] has unknown memory-space {raw.get('memory_space')!r}" + ) + specs.append( + { + "kind": "tile", + "dtype": dtype, + "shape": tuple(int(dim) for dim in shape), + "memory_space": memory_space, + } + ) + continue + raise ValueError(f"operand-specs[{index}] has unknown kind {kind!r}") + return specs + + def main(argv: list[str] | None = None) -> int: parser = argparse.ArgumentParser(description="TileLang DSL expand helper") parser.add_argument("--template-dir", required=True, help="Directory of .py templates") parser.add_argument("--op", required=True, help="Tile op name, e.g. pto.tadd") - parser.add_argument("--dtype", required=True, help="Element dtype, e.g. f32") - parser.add_argument("--shape", required=True, help="Tile shape, e.g. 16,64") + parser.add_argument("--dtype", help="Element dtype, e.g. f32") + parser.add_argument("--shape", help="Tile shape, e.g. 16,64") parser.add_argument("--memory-space", default="ub", help="Memory space (ub or gm)") + parser.add_argument( + "--operand-specs", + help="JSON array describing each operand (tile/scalar schema)", + ) args = parser.parse_args(argv) template_dir = Path(args.template_dir) @@ -102,11 +149,32 @@ def main(argv: list[str] | None = None) -> int: print(f"expand_helper: error: {template_dir} is not a directory", file=sys.stderr) return 1 - shape = tuple(int(d) for d in args.shape.split(",")) - mem_space = _MEMSPACE_MAP.get(args.memory_space) - if mem_space is None: - print(f"expand_helper: error: unknown memory-space '{args.memory_space}'", file=sys.stderr) - return 1 + operand_specs: list[dict] | None = None + if args.operand_specs: + try: + operand_specs = _parse_operand_specs(args.operand_specs) + except ValueError as exc: + print(f"expand_helper: error: {exc}", file=sys.stderr) + return 1 + else: + if args.dtype is None or args.shape is None: + print( + "expand_helper: error: either --operand-specs or both --dtype/--shape are required", + file=sys.stderr, + ) + return 1 + shape = tuple(int(d) for d in args.shape.split(",")) + mem_space = _MEMSPACE_MAP.get(args.memory_space) + if mem_space is None: + print(f"expand_helper: error: unknown memory-space '{args.memory_space}'", file=sys.stderr) + return 1 + target_dtype = _DTYPE_MAP.get(args.dtype) + if target_dtype is None: + print(f"expand_helper: error: unknown dtype '{args.dtype}'", file=sys.stderr) + return 1 + operand_specs = [ + {"kind": "tile", "dtype": target_dtype, "shape": shape, "memory_space": mem_space} + ] # Scan all .py files for descriptors. all_descriptors: list[VKernelDescriptor] = [] @@ -121,21 +189,43 @@ def main(argv: list[str] | None = None) -> int: return 1 # Match. - desc = _match_descriptor(all_descriptors, args.op, args.dtype) + operand_types = tuple(spec["dtype"] for spec in operand_specs) + desc = _match_descriptor(all_descriptors, args.op, operand_types) if desc is None: print( - f"expand_helper: error: no template matches op={args.op} dtype={args.dtype}", + f"expand_helper: error: no template matches op={args.op} operand_types={operand_types!r}", file=sys.stderr, ) return 1 - # Specialize all Tile parameters with the same shape/memory_space. - tile_specs = {} - for param in desc.tile_parameters: - tile_specs[param.name] = TileSpecialization( - shape=shape, - memory_space=mem_space, + if len(desc.parameters) != len(operand_specs): + print( + "expand_helper: error: descriptor parameter count does not match operand-specs", + file=sys.stderr, ) + return 1 + + # Specialize Tile parameters positionally from operand-specs. + tile_specs = {} + for param, operand_spec in zip(desc.parameters, operand_specs): + if param.kind == "tile": + if operand_spec["kind"] != "tile": + print( + "expand_helper: error: descriptor tile parameter does not match operand-specs", + file=sys.stderr, + ) + return 1 + tile_specs[param.name] = TileSpecialization( + shape=operand_spec["shape"], + memory_space=operand_spec["memory_space"], + ) + continue + if param.kind == "scalar" and operand_spec["kind"] != "scalar": + print( + "expand_helper: error: descriptor scalar parameter does not match operand-specs", + file=sys.stderr, + ) + return 1 specialized = desc.specialize(**tile_specs) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index ca7a9c361..a48927d5a 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -3,9 +3,21 @@ from __future__ import annotations import ast +import inspect from dataclasses import dataclass from typing import Any +from .support_matrix import ( + ADVANCED_EXPR_PTO_CALLS, + ADVANCED_TOPLEVEL_PTO_CALLS, + ADVANCED_VECSCOPE_PTO_CALLS, + DEFERRED_PTO_SURFACES, + SUPPORTED_TOPLEVEL_PTO_CALLS, + SUPPORTED_VECSCOPE_PTO_CALLS, + advanced_mode_message, + deferred_surface_message, +) + @dataclass(frozen=True) class FrontendParameterNode: @@ -21,6 +33,7 @@ class FrontendTileSpecializationNode: shape: tuple[int, ...] memory_space: str config: Any + valid_shape: tuple[int | None, ...] | None class FrontendExprNode: @@ -79,6 +92,7 @@ class FrontendCallExpr(FrontendExprNode): namespace: str | None name: str args: tuple[FrontendExprNode, ...] + keywords: tuple[tuple[str, FrontendExprNode], ...] = () class FrontendTargetNode: @@ -130,6 +144,12 @@ class FrontendIfStmt(FrontendStmtNode): condition: FrontendExprNode then_body: tuple[FrontendStmtNode, ...] else_body: tuple[FrontendStmtNode, ...] + is_constexpr: bool = False + + +@dataclass(frozen=True) +class FrontendVecscopeStmt(FrontendStmtNode): + body: tuple[FrontendStmtNode, ...] @dataclass(frozen=True) @@ -139,6 +159,20 @@ class FrontendStrictVecscopeStmt(FrontendStmtNode): body: tuple[FrontendStmtNode, ...] +@dataclass(frozen=True) +class FrontendInlineProcParameterNode: + name: str + annotation: Any + default: FrontendExprNode | None + + +@dataclass(frozen=True) +class FrontendInlineProcNode: + name: str + parameters: tuple[FrontendInlineProcParameterNode, ...] + body: tuple[FrontendStmtNode, ...] + + @dataclass(frozen=True) class FrontendKernelNode: target: str @@ -150,6 +184,409 @@ class FrontendKernelNode: parameters: tuple[FrontendParameterNode, ...] tile_specializations: tuple[FrontendTileSpecializationNode, ...] body: tuple[FrontendStmtNode, ...] + inline_procs: tuple[FrontendInlineProcNode, ...] = () + + +@dataclass(frozen=True) +class _FrontendInlineProc: + name: str + source_info: Any + signature: inspect.Signature + + +@dataclass(frozen=True) +class _FrontendBuildContext: + source_info: Any + templates: dict[str, dict[str, str]] + selected_op: str | None + advanced_enabled: bool + inline_procs: dict[str, _FrontendInlineProc] + active_inline_proc_stack: tuple[str, ...] = () + vecscope_depth: int = 0 + + def error(self, node: ast.AST, message: str) -> Exception: + if self.source_info is not None: + return self.source_info.error(node, message) + return ValueError(message) + + def nested_vecscope(self) -> "_FrontendBuildContext": + return _FrontendBuildContext( + source_info=self.source_info, + templates=self.templates, + selected_op=self.selected_op, + advanced_enabled=self.advanced_enabled, + inline_procs=self.inline_procs, + active_inline_proc_stack=self.active_inline_proc_stack, + vecscope_depth=self.vecscope_depth + 1, + ) + + def enter_inline_proc(self, name: str, source_info: Any) -> "_FrontendBuildContext": + return _FrontendBuildContext( + source_info=source_info, + templates=self.templates, + selected_op=self.selected_op, + advanced_enabled=self.advanced_enabled, + inline_procs=self.inline_procs, + active_inline_proc_stack=(*self.active_inline_proc_stack, name), + vecscope_depth=self.vecscope_depth, + ) + + +def _inline_proc_param_specs(inline_proc: _FrontendInlineProc) -> tuple[tuple[str, ast.expr | None], ...]: + function_def = inline_proc.source_info.function_def + params = function_def.args.args + defaults = function_def.args.defaults + first_default = len(params) - len(defaults) + specs: list[tuple[str, ast.expr | None]] = [] + for index, param in enumerate(params): + default_node: ast.expr | None = None + if index >= first_default: + default_node = defaults[index - first_default] + specs.append((param.arg, default_node)) + return tuple(specs) + + +def _bind_inline_proc_call( + node: ast.Call, + inline_proc: _FrontendInlineProc, + context: _FrontendBuildContext, +) -> tuple[FrontendExprNode, ...]: + if any(keyword.arg is None for keyword in node.keywords): + raise context.error( + node, + "keyword unpacking via `**` is not supported in TileLang DSL v1", + ) + + param_specs = _inline_proc_param_specs(inline_proc) + param_names = tuple(param_name for param_name, _ in param_specs) + bound: dict[str, FrontendExprNode] = {} + + if len(node.args) > len(param_specs): + raise context.error( + node, + f"inline_proc `{inline_proc.name}` accepts at most {len(param_specs)} positional arguments in TileLang DSL v1", + ) + + for index, arg_node in enumerate(node.args): + param_name = param_names[index] + bound[param_name] = _build_expr(arg_node, context) + + seen_keywords: set[str] = set() + for keyword in node.keywords: + assert keyword.arg is not None + if keyword.arg in seen_keywords: + raise context.error( + keyword.value, + f"duplicate keyword `{keyword.arg}` for inline_proc `{inline_proc.name}` in TileLang DSL v1", + ) + if keyword.arg not in param_names: + raise context.error( + keyword.value, + f"inline_proc `{inline_proc.name}` does not define keyword `{keyword.arg}` in TileLang DSL v1", + ) + if keyword.arg in bound: + raise context.error( + keyword.value, + f"inline_proc `{inline_proc.name}` got multiple values for argument `{keyword.arg}` in TileLang DSL v1", + ) + seen_keywords.add(keyword.arg) + bound[keyword.arg] = _build_expr(keyword.value, context) + + ordered_args: list[FrontendExprNode] = [] + for param_name, default_node in param_specs: + value = bound.get(param_name) + if value is None: + if default_node is None: + raise context.error( + node, + f"inline_proc `{inline_proc.name}` is missing required argument `{param_name}` in TileLang DSL v1", + ) + value = _build_expr(default_node, context) + ordered_args.append(value) + return tuple(ordered_args) + + +def _collect_name_reads(expr: FrontendExprNode) -> set[str]: + if isinstance(expr, FrontendNameExpr): + return {expr.name} + if isinstance(expr, (FrontendConstantExpr, FrontendSymbolExpr)): + return set() + if isinstance(expr, FrontendSliceExpr): + names: set[str] = set() + if expr.start is not None: + names |= _collect_name_reads(expr.start) + if expr.stop is not None: + names |= _collect_name_reads(expr.stop) + if expr.step is not None: + names |= _collect_name_reads(expr.step) + return names + if isinstance(expr, FrontendTupleExpr): + names: set[str] = set() + for element in expr.elements: + names |= _collect_name_reads(element) + return names + if isinstance(expr, FrontendAttributeExpr): + return _collect_name_reads(expr.base) + if isinstance(expr, FrontendSubscriptExpr): + return _collect_name_reads(expr.base) | _collect_name_reads(expr.index) + if isinstance(expr, FrontendBinaryExpr): + return _collect_name_reads(expr.lhs) | _collect_name_reads(expr.rhs) + if isinstance(expr, FrontendCallExpr): + names: set[str] = set() + for arg in expr.args: + names |= _collect_name_reads(arg) + for _, keyword_value in expr.keywords: + names |= _collect_name_reads(keyword_value) + return names + return set() + + +def _extract_target_names(target: FrontendTargetNode) -> set[str]: + if isinstance(target, FrontendNameTarget): + return {target.name} + if isinstance(target, FrontendTupleTarget): + return {element.name for element in target.elements} + return set() + +def _validate_inline_capture( + stmt: FrontendStmtNode, + param_names: set[str], + assigned_names: set[str], + *, + context: _FrontendBuildContext, +) -> None: + allowed = param_names | assigned_names + if isinstance(stmt, FrontendAssignStmt): + missing = _collect_name_reads(stmt.value) - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + assigned_names |= _extract_target_names(stmt.target) + return + if isinstance(stmt, FrontendExprStmt): + missing = _collect_name_reads(stmt.expr) - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + return + if isinstance(stmt, FrontendReturnStmt): + if stmt.value is None: + return + missing = _collect_name_reads(stmt.value) - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + return + if isinstance(stmt, FrontendForStmt): + header_reads = ( + _collect_name_reads(stmt.lower_bound) + | _collect_name_reads(stmt.upper_bound) + | _collect_name_reads(stmt.step) + ) + missing = header_reads - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + + loop_assigned = set(assigned_names) + loop_assigned.add(stmt.target) + for child in stmt.body: + _validate_inline_capture(child, param_names, loop_assigned, context=context) + assigned_names.add(stmt.target) + return + if isinstance(stmt, FrontendIfStmt): + missing = _collect_name_reads(stmt.condition) - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + then_assigned = set(assigned_names) + else_assigned = set(assigned_names) + for child in stmt.then_body: + _validate_inline_capture(child, param_names, then_assigned, context=context) + for child in stmt.else_body: + _validate_inline_capture(child, param_names, else_assigned, context=context) + assigned_names |= then_assigned | else_assigned + return + if isinstance(stmt, FrontendVecscopeStmt): + scope_assigned = set(assigned_names) + for child in stmt.body: + _validate_inline_capture(child, param_names, scope_assigned, context=context) + assigned_names |= scope_assigned + return + if isinstance(stmt, FrontendStrictVecscopeStmt): + captures_missing = set().union(*(_collect_name_reads(capture) for capture in stmt.captures)) - allowed + if captures_missing: + name = sorted(captures_missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + scope_assigned = set(assigned_names) | set(stmt.block_arguments) + for child in stmt.body: + _validate_inline_capture(child, param_names, scope_assigned, context=context) + assigned_names |= scope_assigned + + +def _collect_inline_proc_calls_expr( + expr: FrontendExprNode, + inline_proc_names: set[str], + into: set[str], +) -> None: + if isinstance(expr, FrontendCallExpr): + if expr.namespace is None and expr.name in inline_proc_names: + into.add(expr.name) + for arg in expr.args: + _collect_inline_proc_calls_expr(arg, inline_proc_names, into) + for _, keyword_value in expr.keywords: + _collect_inline_proc_calls_expr(keyword_value, inline_proc_names, into) + return + if isinstance(expr, FrontendBinaryExpr): + _collect_inline_proc_calls_expr(expr.lhs, inline_proc_names, into) + _collect_inline_proc_calls_expr(expr.rhs, inline_proc_names, into) + return + if isinstance(expr, FrontendTupleExpr): + for element in expr.elements: + _collect_inline_proc_calls_expr(element, inline_proc_names, into) + return + if isinstance(expr, FrontendSliceExpr): + if expr.start is not None: + _collect_inline_proc_calls_expr(expr.start, inline_proc_names, into) + if expr.stop is not None: + _collect_inline_proc_calls_expr(expr.stop, inline_proc_names, into) + if expr.step is not None: + _collect_inline_proc_calls_expr(expr.step, inline_proc_names, into) + return + if isinstance(expr, FrontendAttributeExpr): + _collect_inline_proc_calls_expr(expr.base, inline_proc_names, into) + return + if isinstance(expr, FrontendSubscriptExpr): + _collect_inline_proc_calls_expr(expr.base, inline_proc_names, into) + _collect_inline_proc_calls_expr(expr.index, inline_proc_names, into) + + +def _collect_inline_proc_calls_stmt( + stmt: FrontendStmtNode, + inline_proc_names: set[str], + into: set[str], +) -> None: + if isinstance(stmt, FrontendAssignStmt): + _collect_inline_proc_calls_expr(stmt.value, inline_proc_names, into) + return + if isinstance(stmt, FrontendExprStmt): + _collect_inline_proc_calls_expr(stmt.expr, inline_proc_names, into) + return + if isinstance(stmt, FrontendReturnStmt): + if stmt.value is not None: + _collect_inline_proc_calls_expr(stmt.value, inline_proc_names, into) + return + if isinstance(stmt, FrontendForStmt): + _collect_inline_proc_calls_expr(stmt.lower_bound, inline_proc_names, into) + _collect_inline_proc_calls_expr(stmt.upper_bound, inline_proc_names, into) + _collect_inline_proc_calls_expr(stmt.step, inline_proc_names, into) + for child in stmt.body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + return + if isinstance(stmt, FrontendIfStmt): + _collect_inline_proc_calls_expr(stmt.condition, inline_proc_names, into) + for child in stmt.then_body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + for child in stmt.else_body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + return + if isinstance(stmt, FrontendVecscopeStmt): + for child in stmt.body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + return + if isinstance(stmt, FrontendStrictVecscopeStmt): + for capture in stmt.captures: + _collect_inline_proc_calls_expr(capture, inline_proc_names, into) + for child in stmt.body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + + +def _validate_inline_proc_call_graph( + kernel_body: tuple[FrontendStmtNode, ...], + inline_proc_nodes: tuple[FrontendInlineProcNode, ...], + inline_proc_source_infos: dict[str, Any], +) -> None: + inline_proc_names = {node.name for node in inline_proc_nodes} + if not inline_proc_names: + return + + edges: dict[str, set[str]] = {node.name: set() for node in inline_proc_nodes} + for inline_proc_node in inline_proc_nodes: + callees = edges[inline_proc_node.name] + for stmt in inline_proc_node.body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, callees) + + root_callees: set[str] = set() + for stmt in kernel_body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, root_callees) + + color: dict[str, int] = {} + + def dfs(name: str) -> None: + state = color.get(name, 0) + if state == 1: + source_info = inline_proc_source_infos.get(name) + if source_info is not None: + raise source_info.error( + source_info.function_def, + f"recursive inline_proc call `{name}` is not supported in TileLang DSL v1", + ) + raise ValueError(f"recursive inline_proc call `{name}` is not supported in TileLang DSL v1") + if state == 2: + return + color[name] = 1 + for callee in edges.get(name, ()): + dfs(callee) + color[name] = 2 + + for callee in sorted(root_callees): + dfs(callee) + + +def _collect_reachable_inline_procs( + kernel_body: tuple[FrontendStmtNode, ...], + inline_proc_nodes: tuple[FrontendInlineProcNode, ...], +) -> set[str]: + inline_proc_names = {node.name for node in inline_proc_nodes} + if not inline_proc_names: + return set() + + edges: dict[str, set[str]] = {node.name: set() for node in inline_proc_nodes} + for inline_proc_node in inline_proc_nodes: + for stmt in inline_proc_node.body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, edges[inline_proc_node.name]) + + roots: set[str] = set() + for stmt in kernel_body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, roots) + + reachable: set[str] = set() + stack = list(roots) + while stack: + name = stack.pop() + if name in reachable: + continue + reachable.add(name) + stack.extend(edges.get(name, ())) + return reachable _BINARY_OP_NAMES = { @@ -158,6 +595,57 @@ class FrontendKernelNode: ast.Mult: "mul", ast.FloorDiv: "floordiv", } +_COMPARE_OP_NAMES = { + ast.Eq: "eq", + ast.NotEq: "ne", + ast.Gt: "gt", + ast.Lt: "lt", + ast.GtE: "ge", + ast.LtE: "le", +} +_BOOL_OP_NAMES = { + ast.And: "and", + ast.Or: "or", +} + +_DMA_CALL_KEYWORDS: dict[str, frozenset[str]] = { + "set_loop2_stride_outtoub": frozenset({"src_stride", "dst_stride"}), + "set_loop1_stride_outtoub": frozenset({"src_stride", "dst_stride"}), + "set_loop_size_outtoub": frozenset({"loop1", "loop2"}), + "set_loop2_stride_ubtoout": frozenset({"src_stride", "dst_stride"}), + "set_loop1_stride_ubtoout": frozenset({"src_stride", "dst_stride"}), + "set_loop_size_ubtoout": frozenset({"loop1", "loop2"}), + "copy_gm_to_ubuf": frozenset( + { + "src", + "dst", + "sid", + "n_burst", + "len_burst", + "left_padding_count", + "right_padding_count", + "data_select_bit", + "enable_ub_pad", + "l2_cache_ctl", + "gm_stride", + "ub_stride", + } + ), + "copy_ubuf_to_gm": frozenset( + { + "src", + "dst", + "sid", + "n_burst", + "len_burst", + "reserved", + "burst_dst_stride", + "burst_src_stride", + "gm_stride", + "ub_stride", + } + ), +} def _attribute_path(node: ast.AST) -> tuple[str, ...] | None: @@ -171,147 +659,412 @@ def _attribute_path(node: ast.AST) -> tuple[str, ...] | None: return None -def _build_expr(node: ast.AST, source_info: Any) -> FrontendExprNode: +def _validate_resolved_template_op_surface( + op_name: str, + node: ast.AST, + context: _FrontendBuildContext, +) -> None: + if op_name in SUPPORTED_TOPLEVEL_PTO_CALLS: + return + if op_name in SUPPORTED_VECSCOPE_PTO_CALLS: + return + if op_name in ADVANCED_VECSCOPE_PTO_CALLS: + if context.advanced_enabled: + return + raise context.error( + node, + advanced_mode_message(op_name), + ) + if op_name in ADVANCED_EXPR_PTO_CALLS or op_name in ADVANCED_TOPLEVEL_PTO_CALLS: + if context.advanced_enabled: + return + raise context.error( + node, + advanced_mode_message(op_name), + ) + if op_name in DEFERRED_PTO_SURFACES: + raise context.error( + node, + deferred_surface_message(op_name), + ) + raise context.error( + node, + f"unsupported op surface `pto.{op_name}` in TileLang DSL v1", + ) + + +def _build_call_keywords( + node: ast.Call, + *, + namespace: str | None, + name: str, + context: _FrontendBuildContext, +) -> tuple[tuple[str, FrontendExprNode], ...]: + if not node.keywords: + return () + + for keyword in node.keywords: + if keyword.arg is None: + raise context.error( + keyword.value, + "keyword unpacking via `**` is not supported in TileLang DSL v1", + ) + + allowed_keywords = _DMA_CALL_KEYWORDS.get(name) if namespace == "pto" else None + if allowed_keywords is None: + call_name = f"{namespace + '.' if namespace else ''}{name}" + raise context.error( + node, + f"`{call_name}` does not support keyword arguments in TileLang DSL v1; " + "no public call surface currently accepts them", + ) + + seen: set[str] = set() + built_keywords: list[tuple[str, FrontendExprNode]] = [] + for keyword in node.keywords: + assert keyword.arg is not None + if keyword.arg in seen: + raise context.error( + keyword.value, + f"duplicate keyword `{keyword.arg}` for `pto.{name}` in TileLang DSL v1", + ) + if keyword.arg not in allowed_keywords: + raise context.error( + keyword.value, + f"unsupported keyword `{keyword.arg}` for `pto.{name}` in TileLang DSL v1", + ) + seen.add(keyword.arg) + built_keywords.append((keyword.arg, _build_expr(keyword.value, context))) + return tuple(built_keywords) + + +def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNode: if isinstance(node, ast.Name): return FrontendNameExpr(name=node.id) if isinstance(node, ast.Constant): return FrontendConstantExpr(value=node.value) + if isinstance(node, ast.UnaryOp): + if isinstance(node.op, ast.UAdd): + sign = 1 + elif isinstance(node.op, ast.USub): + sign = -1 + else: + raise context.error( + node, + f"unsupported unary operator `{type(node.op).__name__}` in TileLang DSL v1", + ) + if not isinstance(node.operand, ast.Constant) or isinstance(node.operand.value, bool): + raise context.error( + node, + "unary +/- currently only supports numeric literals in TileLang DSL v1", + ) + literal = node.operand.value + if not isinstance(literal, (int, float)): + raise context.error( + node, + "unary +/- currently only supports numeric literals in TileLang DSL v1", + ) + return FrontendConstantExpr(value=literal if sign > 0 else -literal) if isinstance(node, ast.Slice): - start = None if node.lower is None else _build_expr(node.lower, source_info) - stop = None if node.upper is None else _build_expr(node.upper, source_info) - step = None if node.step is None else _build_expr(node.step, source_info) + start = None if node.lower is None else _build_expr(node.lower, context) + stop = None if node.upper is None else _build_expr(node.upper, context) + step = None if node.step is None else _build_expr(node.step, context) return FrontendSliceExpr(start=start, stop=stop, step=step) if isinstance(node, ast.Tuple): return FrontendTupleExpr( - elements=tuple(_build_expr(elt, source_info) for elt in node.elts) + elements=tuple(_build_expr(elt, context) for elt in node.elts) ) if isinstance(node, ast.Attribute): path = _attribute_path(node) if path is not None and path[0] in {"pto", "PAT", "PIPE", "EVENT"} and len(path) >= 2: return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) - return FrontendAttributeExpr(base=_build_expr(node.value, source_info), attr=node.attr) + return FrontendAttributeExpr(base=_build_expr(node.value, context), attr=node.attr) if isinstance(node, ast.Subscript): return FrontendSubscriptExpr( - base=_build_expr(node.value, source_info), - index=_build_expr(node.slice, source_info), + base=_build_expr(node.value, context), + index=_build_expr(node.slice, context), ) if isinstance(node, ast.BinOp): op_name = _BINARY_OP_NAMES.get(type(node.op)) if op_name is None: - raise source_info.error( + raise context.error( node, f"unsupported binary operator `{type(node.op).__name__}` in TileLang DSL v1", ) return FrontendBinaryExpr( - lhs=_build_expr(node.left, source_info), + lhs=_build_expr(node.left, context), op=op_name, - rhs=_build_expr(node.right, source_info), + rhs=_build_expr(node.right, context), ) + if isinstance(node, ast.Compare): + if len(node.ops) != 1 or len(node.comparators) != 1: + raise context.error( + node, + "chained comparisons are not supported in TileLang DSL v1", + ) + op_name = _COMPARE_OP_NAMES.get(type(node.ops[0])) + if op_name is None: + raise context.error( + node, + f"unsupported comparison operator `{type(node.ops[0]).__name__}` in TileLang DSL v1", + ) + return FrontendBinaryExpr( + lhs=_build_expr(node.left, context), + op=op_name, + rhs=_build_expr(node.comparators[0], context), + ) + if isinstance(node, ast.BoolOp): + op_name = _BOOL_OP_NAMES.get(type(node.op)) + if op_name is None: + raise context.error( + node, + f"unsupported boolean operator `{type(node.op).__name__}` in TileLang DSL v1", + ) + if len(node.values) < 2: + raise context.error( + node, + "boolean expressions must contain at least two operands in TileLang DSL v1", + ) + expr = _build_expr(node.values[0], context) + for value in node.values[1:]: + expr = FrontendBinaryExpr( + lhs=expr, + op=op_name, + rhs=_build_expr(value, context), + ) + return expr if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name) and node.func.id in context.inline_procs: + inline_proc = context.inline_procs[node.func.id] + if node.func.id in context.active_inline_proc_stack: + raise context.error( + node, + f"recursive inline_proc call `{node.func.id}` is not supported in TileLang DSL v1", + ) + return FrontendCallExpr( + namespace=None, + name=node.func.id, + args=_bind_inline_proc_call(node, inline_proc, context), + keywords=(), + ) + if ( + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "pto" + and node.func.attr == "tpl" + ): + if not node.args: + raise context.error( + node, + "pto.tpl() requires a non-empty string literal slot name as the first argument", + ) + slot_expr = node.args[0] + if not ( + isinstance(slot_expr, ast.Constant) + and isinstance(slot_expr.value, str) + and slot_expr.value + ): + raise context.error( + slot_expr, + "pto.tpl() requires a non-empty string literal slot name", + ) + slot_name = slot_expr.value + slot_bindings = context.templates.get(slot_name) + if slot_bindings is None: + raise context.error( + slot_expr, + f"unknown template slot {slot_name!r} in TileLang DSL v1", + ) + if context.selected_op is None: + raise context.error( + node, + "pto.tpl() requires pto.select_kernel(...) to bind a concrete op before expansion", + ) + resolved_op = slot_bindings.get(context.selected_op) + if resolved_op is None: + raise context.error( + slot_expr, + f"template slot {slot_name!r} does not define an implementation for " + f"selected op {context.selected_op!r}", + ) + _validate_resolved_template_op_surface(resolved_op, node, context) + return FrontendCallExpr( + namespace="pto", + name=resolved_op, + args=tuple(_build_expr(arg, context) for arg in node.args[1:]), + keywords=_build_call_keywords( + node, + namespace="pto", + name=resolved_op, + context=context, + ), + ) if isinstance(node.func, ast.Name): return FrontendCallExpr( namespace=None, name=node.func.id, - args=tuple(_build_expr(arg, source_info) for arg in node.args), + args=tuple(_build_expr(arg, context) for arg in node.args), + keywords=_build_call_keywords( + node, + namespace=None, + name=node.func.id, + context=context, + ), ) if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): return FrontendCallExpr( namespace=node.func.value.id, name=node.func.attr, - args=tuple(_build_expr(arg, source_info) for arg in node.args), + args=tuple(_build_expr(arg, context) for arg in node.args), + keywords=_build_call_keywords( + node, + namespace=node.func.value.id, + name=node.func.attr, + context=context, + ), ) - raise source_info.error( + raise context.error( node, f"unsupported expression `{type(node).__name__}` in TileLang DSL v1", ) -def _build_target(node: ast.AST, source_info: Any) -> FrontendTargetNode: +def _build_target(node: ast.AST, context: _FrontendBuildContext) -> FrontendTargetNode: if isinstance(node, ast.Name): return FrontendNameTarget(name=node.id) if isinstance(node, ast.Tuple): elements = [] for elt in node.elts: if not isinstance(elt, ast.Name): - raise source_info.error(elt, "tuple assignment only supports names in TileLang DSL v1") + raise context.error(elt, "tuple assignment only supports names in TileLang DSL v1") elements.append(FrontendNameTarget(name=elt.id)) return FrontendTupleTarget(elements=tuple(elements)) - raise source_info.error( + raise context.error( node, f"unsupported assignment target `{type(node).__name__}` in TileLang DSL v1", ) -def _build_stmt(node: ast.stmt, source_info: Any) -> FrontendStmtNode: +def _build_stmt_list(nodes: list[ast.stmt] | tuple[ast.stmt, ...], context: _FrontendBuildContext) -> tuple[FrontendStmtNode, ...]: + return tuple(_build_stmt(node, context) for node in nodes) + + +def _build_stmt(node: ast.stmt, context: _FrontendBuildContext) -> FrontendStmtNode: if isinstance(node, ast.Assign): if len(node.targets) != 1: - raise source_info.error(node, "multiple assignment targets are not supported in TileLang DSL v1") + raise context.error(node, "multiple assignment targets are not supported in TileLang DSL v1") return FrontendAssignStmt( - target=_build_target(node.targets[0], source_info), - value=_build_expr(node.value, source_info), + target=_build_target(node.targets[0], context), + value=_build_expr(node.value, context), ) if isinstance(node, ast.AnnAssign): if node.value is None: - raise source_info.error(node, "annotation-only assignments are not supported in TileLang DSL v1") + raise context.error(node, "annotation-only assignments are not supported in TileLang DSL v1") return FrontendAssignStmt( - target=_build_target(node.target, source_info), - value=_build_expr(node.value, source_info), + target=_build_target(node.target, context), + value=_build_expr(node.value, context), annotation=node.annotation, ) if isinstance(node, ast.Expr): - return FrontendExprStmt(expr=_build_expr(node.value, source_info)) + return FrontendExprStmt(expr=_build_expr(node.value, context)) if isinstance(node, ast.Return): value = None if node.value is not None: if not (isinstance(node.value, ast.Constant) and node.value.value is None): - value = _build_expr(node.value, source_info) + value = _build_expr(node.value, context) return FrontendReturnStmt(value=value) if isinstance(node, ast.For): if not isinstance(node.target, ast.Name): - raise source_info.error(node.target, "for target must be a single name") + raise context.error(node.target, "for target must be a single name") if not isinstance(node.iter, ast.Call) or not isinstance(node.iter.func, ast.Name) or node.iter.func.id != "range": - raise source_info.error(node.iter, "only Python range(lb, ub, step) loops are supported") + raise context.error(node.iter, "only Python range(lb, ub, step) loops are supported") if len(node.iter.args) != 3: - raise source_info.error(node.iter, "range() expects exactly 3 arguments in TileLang DSL v1") + raise context.error(node.iter, "range() expects exactly 3 arguments in TileLang DSL v1") return FrontendForStmt( target=node.target.id, - lower_bound=_build_expr(node.iter.args[0], source_info), - upper_bound=_build_expr(node.iter.args[1], source_info), - step=_build_expr(node.iter.args[2], source_info), - body=tuple(_build_stmt(stmt, source_info) for stmt in node.body), + lower_bound=_build_expr(node.iter.args[0], context), + upper_bound=_build_expr(node.iter.args[1], context), + step=_build_expr(node.iter.args[2], context), + body=_build_stmt_list(node.body, context), ) if isinstance(node, ast.If): + is_constexpr = False + condition_node: ast.AST = node.test + if ( + isinstance(node.test, ast.Call) + and isinstance(node.test.func, ast.Attribute) + and isinstance(node.test.func.value, ast.Name) + and node.test.func.value.id == "pto" + and node.test.func.attr == "constexpr" + ): + if node.test.keywords: + raise context.error( + node.test, + "pto.constexpr() does not support keyword arguments in TileLang DSL v1", + ) + if len(node.test.args) != 1: + raise context.error( + node.test, + "pto.constexpr() expects exactly 1 positional argument in TileLang DSL v1", + ) + is_constexpr = True + condition_node = node.test.args[0] return FrontendIfStmt( - condition=_build_expr(node.test, source_info), - then_body=tuple(_build_stmt(stmt, source_info) for stmt in node.body), - else_body=tuple(_build_stmt(stmt, source_info) for stmt in node.orelse), + condition=_build_expr(condition_node, context), + then_body=_build_stmt_list(node.body, context), + else_body=_build_stmt_list(node.orelse, context), + is_constexpr=is_constexpr, ) if isinstance(node, ast.With): if len(node.items) != 1: - raise source_info.error(node, "only a single with-item is supported in TileLang DSL v1") + raise context.error(node, "only a single with-item is supported in TileLang DSL v1") item = node.items[0] if not isinstance(item.context_expr, ast.Call): - raise source_info.error(item.context_expr, "with context must be a call in TileLang DSL v1") + raise context.error(item.context_expr, "with context must be a call in TileLang DSL v1") if not ( isinstance(item.context_expr.func, ast.Attribute) and isinstance(item.context_expr.func.value, ast.Name) and item.context_expr.func.value.id == "pto" - and item.context_expr.func.attr == "strict_vecscope" ): - raise source_info.error(item.context_expr, "only pto.strict_vecscope is supported in TileLang DSL v1") + raise context.error( + item.context_expr, + "only pto.vecscope/pto.strict_vecscope are supported in TileLang DSL v1", + ) + with_name = item.context_expr.func.attr + if with_name == "vecscope": + if item.context_expr.args or item.context_expr.keywords: + raise context.error( + item.context_expr, + "pto.vecscope() does not accept positional or keyword arguments in TileLang DSL v1", + ) + if item.optional_vars is not None: + raise context.error(item, "pto.vecscope() does not support `as` bindings in TileLang DSL v1") + return FrontendVecscopeStmt( + body=_build_stmt_list(node.body, context.nested_vecscope()), + ) + if with_name != "strict_vecscope": + raise context.error( + item.context_expr, + "only pto.vecscope/pto.strict_vecscope are supported in TileLang DSL v1", + ) + if not context.advanced_enabled: + raise context.error( + item.context_expr, + advanced_mode_message("strict_vecscope"), + ) if not isinstance(item.optional_vars, ast.Tuple): - raise source_info.error(item, "pto.strict_vecscope requires tuple binding in 'as'") + raise context.error(item, "pto.strict_vecscope requires tuple binding in 'as'") block_arguments = [] for elt in item.optional_vars.elts: if not isinstance(elt, ast.Name): - raise source_info.error(elt, "pto.strict_vecscope bindings must be names") + raise context.error(elt, "pto.strict_vecscope bindings must be names") block_arguments.append(elt.id) return FrontendStrictVecscopeStmt( - captures=tuple(_build_expr(arg, source_info) for arg in item.context_expr.args), + captures=tuple(_build_expr(arg, context) for arg in item.context_expr.args), block_arguments=tuple(block_arguments), - body=tuple(_build_stmt(stmt, source_info) for stmt in node.body), + body=_build_stmt_list(node.body, context.nested_vecscope()), ) - raise source_info.error( + raise context.error( node, f"unsupported statement `{type(node).__name__}` in TileLang DSL v1", ) @@ -335,13 +1088,112 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: shape=spec.shape, memory_space=spec.memory_space.value, config=spec.config, + valid_shape=spec.valid_shape, ) for name, spec in descriptor.specializations ) source_info = descriptor._source_info + sorted_inline_procs = tuple(sorted(descriptor.inline_procs.items(), key=lambda item: item[0])) + context = _FrontendBuildContext( + source_info=source_info, + templates=descriptor.templates, + selected_op=descriptor.selected_op, + advanced_enabled=descriptor.advanced_enabled, + inline_procs={ + name: _FrontendInlineProc( + name=name, + source_info=proc.source_info, + signature=proc.signature, + ) + for name, proc in sorted_inline_procs + }, + ) body = () if source_info is not None: - body = tuple(_build_stmt(stmt, source_info) for stmt in source_info.function_def.body) + body = _build_stmt_list(source_info.function_def.body, context) + + inline_proc_descriptors = {name: descriptor for name, descriptor in sorted_inline_procs} + inline_proc_names = set(inline_proc_descriptors) + root_inline_calls: set[str] = set() + for stmt in body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, root_inline_calls) + + inline_proc_nodes_by_name: dict[str, FrontendInlineProcNode] = {} + inline_proc_source_infos: dict[str, Any] = {} + pending = list(sorted(root_inline_calls)) + while pending: + name = pending.pop() + if name in inline_proc_nodes_by_name: + continue + inline_proc_descriptor = inline_proc_descriptors.get(name) + if inline_proc_descriptor is None: + continue + inline_source = inline_proc_descriptor.source_info + if inline_source is None: + if source_info is not None: + raise context.error( + source_info.function_def, + f"inline_proc `{name}` requires source-visible Python functions", + ) + raise ValueError( + f"inline_proc `{name}` requires source-visible Python functions" + ) + inline_proc_source_infos[name] = inline_source + helper_context = context.enter_inline_proc(name, inline_source) + helper_body = _build_stmt_list(inline_source.function_def.body, helper_context) + parameter_specs = _inline_proc_param_specs( + _FrontendInlineProc( + name=name, + source_info=inline_source, + signature=inline_proc_descriptor.signature, + ) + ) + inline_proc_node = FrontendInlineProcNode( + name=name, + parameters=tuple( + FrontendInlineProcParameterNode( + name=param_name, + annotation=arg.annotation, + default=None + if default_node is None + else _build_expr(default_node, helper_context), + ) + for (param_name, default_node), arg in zip(parameter_specs, inline_source.function_def.args.args) + ), + body=helper_body, + ) + inline_proc_nodes_by_name[name] = inline_proc_node + nested_calls: set[str] = set() + for stmt in helper_body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, nested_calls) + for nested in sorted(nested_calls): + if nested not in inline_proc_nodes_by_name: + pending.append(nested) + + reachable_inline_proc_nodes = tuple( + inline_proc_nodes_by_name[name] + for name, _ in sorted_inline_procs + if name in inline_proc_nodes_by_name + ) + for inline_proc_node in reachable_inline_proc_nodes: + source = inline_proc_source_infos[inline_proc_node.name] + helper_context = context.enter_inline_proc(inline_proc_node.name, source) + assigned_names: set[str] = set() + param_names = {parameter.name for parameter in inline_proc_node.parameters} + for stmt in inline_proc_node.body: + _validate_inline_capture( + stmt, + param_names, + assigned_names, + context=helper_context, + ) + + _validate_inline_proc_call_graph( + body, + reachable_inline_proc_nodes, + inline_proc_source_infos, + ) + return FrontendKernelNode( target=descriptor.target, op=descriptor.op, @@ -352,6 +1204,7 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: parameters=parameters, tile_specializations=tile_specializations, body=body, + inline_procs=reachable_inline_proc_nodes, ) @@ -365,12 +1218,15 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: "FrontendExprStmt", "FrontendForStmt", "FrontendIfStmt", + "FrontendInlineProcNode", + "FrontendInlineProcParameterNode", "FrontendKernelNode", "FrontendNameExpr", "FrontendNameTarget", "FrontendParameterNode", "FrontendReturnStmt", "FrontendSliceExpr", + "FrontendVecscopeStmt", "FrontendStrictVecscopeStmt", "FrontendStmtNode", "FrontendSubscriptExpr", diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 69b305318..572271484 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """Kernel descriptor surface for TileLang DSL v1.""" from __future__ import annotations @@ -13,7 +21,11 @@ from typing import Any, Callable, Mapping from .types import ( + AnyMask, + AnyType, + MaskType, MemorySpace, + PartitionTensorView, PointerType, ScalarType, TensorView, @@ -23,7 +35,7 @@ TypeVariable, WildcardType, ) -from .frontend_ast import build_frontend_kernel_node +from .frontend_ast import _DMA_CALL_KEYWORDS, build_frontend_kernel_node from .lowering import lower_semantic_kernel from .semantic import analyze_frontend_kernel from .support_matrix import ( @@ -40,10 +52,152 @@ _UNSET = object() _PTOAS_BIN_ENV = "PTOAS_BIN" +_SUPPORTED_TEMPLATE_PTO_CALLS = frozenset( + SUPPORTED_TOPLEVEL_PTO_CALLS + | SUPPORTED_VECSCOPE_PTO_CALLS + | ADVANCED_VECSCOPE_PTO_CALLS + | ADVANCED_EXPR_PTO_CALLS + | ADVANCED_TOPLEVEL_PTO_CALLS +) + + +_INLINE_PROC_REGISTRY: dict[tuple[str, str], "InlineProcDescriptor"] = {} + + +@dataclass(frozen=True) +class InlineProcDescriptor: + """Descriptor returned by @tilelang_dsl.inline_proc.""" + + name: str + py_fn: Callable[..., Any] = field(repr=False) + signature: inspect.Signature = field(repr=False) + source_info: "_FunctionSourceInfo | None" = field(repr=False, default=None) + + +class _InlineProcValidator(ast.NodeVisitor): + def __init__(self, source_info: "_FunctionSourceInfo"): + self.source_info = source_info + + def validate(self) -> None: + fn = self.source_info.function_def + args = fn.args + if args.posonlyargs: + raise self.source_info.error(args.posonlyargs[0], "inline_proc does not support positional-only parameters in TileLang DSL v1") + if args.vararg is not None: + raise self.source_info.error(args.vararg, "inline_proc does not support *args in TileLang DSL v1") + if args.kwarg is not None: + raise self.source_info.error(args.kwarg, "inline_proc does not support **kwargs in TileLang DSL v1") + if args.kwonlyargs: + raise self.source_info.error(args.kwonlyargs[0], "inline_proc does not support keyword-only parameters in TileLang DSL v1") + tail_return: ast.Return | None = fn.body[-1] if fn.body and isinstance(fn.body[-1], ast.Return) else None + for node in ast.walk(fn): + if not isinstance(node, ast.Return): + continue + if node is tail_return: + continue + raise self.source_info.error( + node, + "inline_proc only supports an optional trailing `return` in TileLang DSL v1", + ) + + for stmt in fn.body: + self.visit(stmt) + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if node is self.source_info.function_def: + for stmt in node.body: + self.visit(stmt) + return + raise self.source_info.error(node, "nested function definitions are not supported inside inline_proc in TileLang DSL v1") + + +def _inline_proc_registry_key(fn: Callable[..., Any]) -> tuple[str, str]: + return (fn.__module__, fn.__name__) + + +def _find_inline_proc(name: str, *, module_name: str | None) -> InlineProcDescriptor | None: + if module_name is None: + return None + return _INLINE_PROC_REGISTRY.get((module_name, name)) + + +def _validate_inline_proc_call_surface( + source_info: _FunctionSourceInfo, + node: ast.Call, + inline_proc: InlineProcDescriptor, +) -> None: + if any(keyword.arg is None for keyword in node.keywords): + keyword = next(keyword for keyword in node.keywords if keyword.arg is None) + raise source_info.error( + keyword.value, + "keyword unpacking via `**` is not supported in TileLang DSL v1", + ) + seen_keywords: set[str] = set() + for keyword in node.keywords: + assert keyword.arg is not None + if keyword.arg in seen_keywords: + raise source_info.error( + keyword.value, + f"duplicate keyword `{keyword.arg}` for inline_proc `{inline_proc.name}` in TileLang DSL v1", + ) + seen_keywords.add(keyword.arg) + positional_placeholders = [object() for _ in node.args] + keyword_placeholders = {keyword.arg: object() for keyword in node.keywords if keyword.arg is not None} + try: + inline_proc.signature.bind(*positional_placeholders, **keyword_placeholders) + except TypeError as exc: + raise source_info.error( + node, + f"invalid inline_proc call `{inline_proc.name}` in TileLang DSL v1: {exc}", + ) from exc + + +def _collect_inline_procs(module_name: str) -> tuple[tuple[str, InlineProcDescriptor], ...]: + return tuple( + sorted( + ( + (symbol, descriptor) + for (registered_module, symbol), descriptor in _INLINE_PROC_REGISTRY.items() + if registered_module == module_name + ), + key=lambda item: item[0], + ) + ) -def _validate_dtype_pattern(dtype: Any) -> ScalarType | WildcardType | TypeVariable: - if isinstance(dtype, (ScalarType, WildcardType, TypeVariable)): + +def _register_inline_proc(descriptor: InlineProcDescriptor) -> InlineProcDescriptor: + _INLINE_PROC_REGISTRY[_inline_proc_registry_key(descriptor.py_fn)] = descriptor + return descriptor + + +def inline_proc( + py_fn: Callable[..., Any] | None = None, +) -> InlineProcDescriptor | Callable[[Callable[..., Any]], InlineProcDescriptor]: + """Register a top-level compile-time inline procedure for TileLang DSL kernels.""" + + def wrap(fn: Callable[..., Any]) -> InlineProcDescriptor: + if not callable(fn): + raise TypeError("@inline_proc can only decorate callables") + source_info = _load_function_source_info(fn) + if source_info is None: + raise TypeError("@inline_proc requires source-visible Python functions") + _InlineProcValidator(source_info).validate() + return _register_inline_proc( + InlineProcDescriptor( + name=fn.__name__, + py_fn=fn, + source_info=source_info, + signature=inspect.signature(fn), + ) + ) + + if py_fn is None: + return wrap + return wrap(py_fn) + + +def _validate_dtype_pattern(dtype: Any) -> ScalarType | MaskType | WildcardType | TypeVariable: + if isinstance(dtype, (ScalarType, MaskType, WildcardType, TypeVariable)): return dtype raise TypeError(f"unsupported dtype pattern {dtype!r}") @@ -82,9 +236,10 @@ def parameter_node(self, param_name: str) -> ast.AST | None: class _KernelBodyValidator(ast.NodeVisitor): - def __init__(self, source_info: _FunctionSourceInfo, *, advanced_enabled: bool): + def __init__(self, source_info: _FunctionSourceInfo, *, advanced_enabled: bool, module_name: str | None): self.source_info = source_info self.advanced_enabled = advanced_enabled + self.module_name = module_name self._vecscope_depth = 0 def validate(self) -> None: @@ -121,6 +276,11 @@ def visit_For(self, node: ast.For) -> None: raise self.source_info.error(node.iter, "only Python range(lb, ub, step) loops are supported") if node.iter.func.id != "range": raise self.source_info.error(node.iter, "only Python range(lb, ub, step) loops are supported") + if node.iter.keywords: + raise self.source_info.error( + node.iter, + "range() does not support keyword arguments in TileLang DSL v1", + ) if len(node.iter.args) != 3: raise self.source_info.error(node.iter, "range() expects exactly 3 arguments in TileLang DSL v1") for stmt in node.body: @@ -144,17 +304,39 @@ def visit_With(self, node: ast.With) -> None: isinstance(item.context_expr.func, ast.Attribute) and isinstance(item.context_expr.func.value, ast.Name) and item.context_expr.func.value.id == "pto" - and item.context_expr.func.attr == "strict_vecscope" ): raise self.source_info.error( item.context_expr, - "only pto.strict_vecscope is supported as a with-context in TileLang DSL v1", + "only pto.vecscope/pto.strict_vecscope are supported as with-contexts in TileLang DSL v1", + ) + with_name = item.context_expr.func.attr + if with_name == "vecscope": + if item.context_expr.args or item.context_expr.keywords: + raise self.source_info.error( + item.context_expr, + "pto.vecscope() does not accept positional or keyword arguments in TileLang DSL v1", + ) + if item.optional_vars is not None: + raise self.source_info.error( + item, + "pto.vecscope() does not support `as` bindings in TileLang DSL v1", + ) + elif with_name == "strict_vecscope": + if not self.advanced_enabled: + raise self.source_info.error( + item.context_expr, + advanced_mode_message("strict_vecscope"), + ) + if not isinstance(item.optional_vars, ast.Tuple): + raise self.source_info.error(item, "pto.strict_vecscope requires tuple binding in 'as'") + for elt in item.optional_vars.elts: + if not isinstance(elt, ast.Name): + raise self.source_info.error(elt, "pto.strict_vecscope bindings must be names") + else: + raise self.source_info.error( + item.context_expr, + "only pto.vecscope/pto.strict_vecscope are supported as with-contexts in TileLang DSL v1", ) - if not isinstance(item.optional_vars, ast.Tuple): - raise self.source_info.error(item, "pto.strict_vecscope requires tuple binding in 'as'") - for elt in item.optional_vars.elts: - if not isinstance(elt, ast.Name): - raise self.source_info.error(elt, "pto.strict_vecscope bindings must be names") self._vecscope_depth += 1 try: for stmt in node.body: @@ -162,21 +344,83 @@ def visit_With(self, node: ast.With) -> None: finally: self._vecscope_depth -= 1 + def _validate_call_keywords(self, node: ast.Call) -> None: + if not node.keywords: + return + for keyword in node.keywords: + if keyword.arg is None: + raise self.source_info.error( + keyword.value, + "keyword unpacking via `**` is not supported in TileLang DSL v1", + ) + + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + namespace = node.func.value.id + name = node.func.attr + elif isinstance(node.func, ast.Name): + namespace = None + name = node.func.id + else: + raise self.source_info.error( + node, + "unsupported call surface in TileLang DSL v1", + ) + + allowed_keywords = _DMA_CALL_KEYWORDS.get(name) if namespace == "pto" else None + if allowed_keywords is None: + call_name = f"{namespace + '.' if namespace else ''}{name}" + raise self.source_info.error( + node, + f"`{call_name}` does not support keyword arguments in TileLang DSL v1; " + "no public call surface currently accepts them", + ) + + seen: set[str] = set() + for keyword in node.keywords: + assert keyword.arg is not None + if keyword.arg in seen: + raise self.source_info.error( + keyword.value, + f"duplicate keyword `{keyword.arg}` for `pto.{name}` in TileLang DSL v1", + ) + if keyword.arg not in allowed_keywords: + raise self.source_info.error( + keyword.value, + f"unsupported keyword `{keyword.arg}` for `pto.{name}` in TileLang DSL v1", + ) + seen.add(keyword.arg) + def visit_Call(self, node: ast.Call) -> None: if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): - if node.func.value.id == "pto" and node.func.attr in SUPPORTED_TOPLEVEL_PTO_CALLS: - return - if node.func.value.id == "pto" and node.func.attr in SUPPORTED_VECSCOPE_PTO_CALLS: - if self.advanced_enabled: - return - if self._vecscope_depth <= 0: + if node.func.attr == "as_ptr": + if node.keywords: raise self.source_info.error( node, - f"vector op surface `pto.{node.func.attr}` requires explicit pto.strict_vecscope in TileLang DSL v1", + "`as_ptr` does not support keyword arguments in TileLang DSL v1", ) + if node.args: + raise self.source_info.error( + node, + "`as_ptr()` does not accept positional arguments in TileLang DSL v1", + ) + if self.advanced_enabled: + return + raise self.source_info.error( + node, + "surface `as_ptr` requires advanced=True in TileLang DSL v1", + ) + if node.func.value.id == "pto" and node.func.attr == "tpl": + self._validate_call_keywords(node) + return + if node.func.value.id == "pto" and node.func.attr in SUPPORTED_TOPLEVEL_PTO_CALLS: + self._validate_call_keywords(node) + return + if node.func.value.id == "pto" and node.func.attr in SUPPORTED_VECSCOPE_PTO_CALLS: + self._validate_call_keywords(node) return if node.func.value.id == "pto" and node.func.attr in ADVANCED_VECSCOPE_PTO_CALLS: if self.advanced_enabled: + self._validate_call_keywords(node) return raise self.source_info.error( node, @@ -187,6 +431,7 @@ def visit_Call(self, node: ast.Call) -> None: or node.func.attr in ADVANCED_TOPLEVEL_PTO_CALLS ): if self.advanced_enabled: + self._validate_call_keywords(node) return raise self.source_info.error( node, @@ -210,6 +455,11 @@ def visit_Call(self, node: ast.Call) -> None: if isinstance(node.func, ast.Name): if node.func.id == "range": + self._validate_call_keywords(node) + return + inline_proc = _find_inline_proc(node.func.id, module_name=self.module_name) + if inline_proc is not None: + _validate_inline_proc_call_surface(self.source_info, node, inline_proc) return raise self.source_info.error( node, @@ -241,12 +491,14 @@ def _validate_function_body( source_info: _FunctionSourceInfo | None, *, advanced_enabled: bool, + module_name: str | None, ) -> None: if source_info is None: return _KernelBodyValidator( source_info, advanced_enabled=advanced_enabled, + module_name=module_name, ).validate() @@ -289,11 +541,11 @@ class BoundKernelParameter: name: str kind: str annotation: Any - dtype: ScalarType + dtype: Any @property def element_dtype(self) -> ScalarType | None: - if self.kind in ("tensorview", "tile", "ptr"): + if self.kind in ("tensorview", "partition_tensor_view", "tile", "ptr"): return self.dtype return None @@ -307,12 +559,150 @@ class KernelParameterSpec: annotation: Any +@dataclass(frozen=True) +class _ConstraintValue: + value: Any | None + + def _coerce_other(self, other: Any) -> Any | None: + if isinstance(other, _ConstraintValue): + return other.value + return other + + def _arith(self, other: Any, fn: Callable[[Any, Any], Any]) -> "_ConstraintValue": + other_value = self._coerce_other(other) + if self.value is None or other_value is None: + return _ConstraintValue(None) + return _ConstraintValue(fn(self.value, other_value)) + + def _compare(self, other: Any, fn: Callable[[Any, Any], bool]) -> bool: + other_value = self._coerce_other(other) + if self.value is None or other_value is None: + return True + return fn(self.value, other_value) + + def __add__(self, other: Any) -> "_ConstraintValue": + return self._arith(other, lambda lhs, rhs: lhs + rhs) + + def __radd__(self, other: Any) -> "_ConstraintValue": + return _ConstraintValue(self._coerce_other(other)).__add__(self) + + def __sub__(self, other: Any) -> "_ConstraintValue": + return self._arith(other, lambda lhs, rhs: lhs - rhs) + + def __rsub__(self, other: Any) -> "_ConstraintValue": + return _ConstraintValue(self._coerce_other(other)).__sub__(self) + + def __mul__(self, other: Any) -> "_ConstraintValue": + return self._arith(other, lambda lhs, rhs: lhs * rhs) + + def __rmul__(self, other: Any) -> "_ConstraintValue": + return _ConstraintValue(self._coerce_other(other)).__mul__(self) + + def __floordiv__(self, other: Any) -> "_ConstraintValue": + return self._arith(other, lambda lhs, rhs: lhs // rhs) + + def __rfloordiv__(self, other: Any) -> "_ConstraintValue": + return _ConstraintValue(self._coerce_other(other)).__floordiv__(self) + + def __eq__(self, other: Any) -> bool: # type: ignore[override] + return self._compare(other, lambda lhs, rhs: lhs == rhs) + + def __ne__(self, other: Any) -> bool: # type: ignore[override] + return self._compare(other, lambda lhs, rhs: lhs != rhs) + + def __le__(self, other: Any) -> bool: + return self._compare(other, lambda lhs, rhs: lhs <= rhs) + + def __lt__(self, other: Any) -> bool: + return self._compare(other, lambda lhs, rhs: lhs < rhs) + + def __ge__(self, other: Any) -> bool: + return self._compare(other, lambda lhs, rhs: lhs >= rhs) + + def __gt__(self, other: Any) -> bool: + return self._compare(other, lambda lhs, rhs: lhs > rhs) + + def __bool__(self) -> bool: + if self.value is None: + return True + return bool(self.value) + + def __repr__(self) -> str: + return "?" if self.value is None else repr(self.value) + + +class _ConstraintSequenceView: + def __init__(self, values: tuple[Any | None, ...]): + self._values = tuple(_ConstraintValue(value) for value in values) + + def __getitem__(self, index: int) -> _ConstraintValue: + if -len(self._values) <= index < len(self._values): + return self._values[index] + return _ConstraintValue(None) + + def __len__(self) -> int: + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def __repr__(self) -> str: + return repr(tuple(self._values)) + + +class _ConstraintParamView: + def __init__(self, name: str, attrs: Mapping[str, Any]): + self._name = name + self._attrs = dict(attrs) + + def _sequence_attr(self, attr_name: str) -> _ConstraintSequenceView: + values = self._attrs.get(attr_name) + if values is None: + rank = self._attrs.get("rank") + if isinstance(rank, int) and rank > 0: + values = (None,) * rank + else: + values = () + return _ConstraintSequenceView(tuple(values)) + + @property + def shape(self) -> _ConstraintSequenceView: + return self._sequence_attr("shape") + + @property + def valid_shape(self) -> _ConstraintSequenceView: + return self._sequence_attr("valid_shape") + + @property + def strides(self) -> _ConstraintSequenceView: + return self._sequence_attr("strides") + + @property + def rank(self) -> _ConstraintValue: + return _ConstraintValue(self._attrs.get("rank")) + + @property + def dtype(self) -> Any: + return self._attrs.get("dtype") + + @property + def memory_space(self) -> Any: + return self._attrs.get("memory_space") + + @property + def config(self) -> Any: + return self._attrs.get("config") + + def __repr__(self) -> str: + return f"{self._name}<{self._attrs!r}>" + + @dataclass(frozen=True) class VKernelDescriptor: """Descriptor returned by `@tilelang_dsl.vkernel`.""" target: str - op: str + match_ops: tuple[str, ...] dtypes: tuple[tuple[Any, ...], ...] name: str verify_enabled: bool @@ -323,15 +713,43 @@ class VKernelDescriptor: specializations: tuple[tuple[str, TileSpecialization], ...] = () constraints: tuple[Callable[[Mapping[str, Any]], Any], ...] = field(default=(), repr=False) priority: int = 0 - _selected_dtype_signature: tuple[ScalarType, ...] | None = None + _templates: tuple[tuple[str, tuple[tuple[str, str], ...]], ...] = field(default=(), repr=False) + _inline_procs: tuple[tuple[str, InlineProcDescriptor], ...] = field(default=(), repr=False) + _selected_op: str | None = None + _selected_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None _parameters: tuple[BoundKernelParameter, ...] | None = field(default=None, repr=False) + _constraint_context_attrs: tuple[tuple[str, Any], ...] = field(default=(), repr=False) @property def py_fn(self) -> Callable[..., Any]: return self._py_fn @property - def dtype_signature(self) -> tuple[ScalarType, ...]: + def op(self) -> str: + if self._selected_op is None: + raise ValueError( + "descriptor requires pto.select_kernel(...) to bind a concrete op " + "before reading descriptor.op" + ) + return self._selected_op + + @property + def selected_op(self) -> str | None: + return self._selected_op + + @property + def templates(self) -> dict[str, dict[str, str]]: + return { + slot: dict(op_bindings) + for slot, op_bindings in self._templates + } + + @property + def inline_procs(self) -> dict[str, InlineProcDescriptor]: + return {name: descriptor for name, descriptor in self._inline_procs} + + @property + def dtype_signature(self) -> tuple[ScalarType | MaskType, ...]: if self._selected_dtype_signature is None: raise ValueError( "descriptor requires pto.select_kernel(...) to choose a concrete dtype signature " @@ -352,13 +770,17 @@ def parameters(self) -> tuple[BoundKernelParameter, ...]: def metadata(self) -> dict[str, Any]: return { "target": self.target, - "op": self.op, + "op": self._selected_op, + "match_ops": self.match_ops, + "selected_op": self._selected_op, "dtypes": self.dtypes, "name": self.name, "verify": self.verify_enabled, "advanced": self.advanced_enabled, "constraints": self.constraints, "priority": self.priority, + "templates": self.templates, + "inline_procs": tuple(sorted(self.inline_procs.keys())), } @property @@ -369,17 +791,51 @@ def tile_parameters(self) -> tuple[BoundKernelParameter, ...]: def specializations_by_name(self) -> dict[str, TileSpecialization]: return dict(self.specializations) + @property + def constraint_context_attrs(self) -> dict[str, Any]: + return dict(self._constraint_context_attrs) + def _tile_parameter_names(self) -> tuple[str, ...]: return tuple(param.name for param in self._parameter_specs if param.kind == "tile") + def _bind_constraint_context_attrs( + self, + context_attrs: Mapping[str, Any], + ) -> "VKernelDescriptor": + frozen_context_attrs = tuple( + sorted(dict(context_attrs).items(), key=lambda item: item[0]) + ) + if self._constraint_context_attrs == frozen_context_attrs: + return self + return VKernelDescriptor( + target=self.target, + match_ops=self.match_ops, + dtypes=self.dtypes, + name=self.name, + verify_enabled=self.verify_enabled, + advanced_enabled=self.advanced_enabled, + _parameter_specs=self._parameter_specs, + _py_fn=self._py_fn, + _source_info=self._source_info, + specializations=self.specializations, + constraints=self.constraints, + priority=self.priority, + _templates=self._templates, + _inline_procs=self._inline_procs, + _selected_op=self._selected_op, + _selected_dtype_signature=self._selected_dtype_signature, + _parameters=self._parameters, + _constraint_context_attrs=frozen_context_attrs, + ) + def _bind_selected_dtype_signature( self, - dtype_signature: tuple[ScalarType, ...], + dtype_signature: tuple[ScalarType | MaskType, ...], ) -> "VKernelDescriptor": bound_parameters = _bind_parameters(self._parameter_specs, dtype_signature) return VKernelDescriptor( target=self.target, - op=self.op, + match_ops=self.match_ops, dtypes=self.dtypes, name=self.name, verify_enabled=self.verify_enabled, @@ -390,8 +846,41 @@ def _bind_selected_dtype_signature( specializations=self.specializations, constraints=self.constraints, priority=self.priority, + _templates=self._templates, + _inline_procs=self._inline_procs, + _selected_op=self._selected_op, _selected_dtype_signature=dtype_signature, _parameters=bound_parameters, + _constraint_context_attrs=self._constraint_context_attrs, + ) + + def _bind_selected_op(self, op: str) -> "VKernelDescriptor": + normalized_op = _validate_op(op) + if normalized_op not in self.match_ops: + raise ValueError( + f"selected op {normalized_op!r} is not in descriptor matcher set {self.match_ops!r}" + ) + if self._selected_op == normalized_op: + return self + return VKernelDescriptor( + target=self.target, + match_ops=self.match_ops, + dtypes=self.dtypes, + name=self.name, + verify_enabled=self.verify_enabled, + advanced_enabled=self.advanced_enabled, + _parameter_specs=self._parameter_specs, + _py_fn=self._py_fn, + _source_info=self._source_info, + specializations=self.specializations, + constraints=self.constraints, + priority=self.priority, + _templates=self._templates, + _inline_procs=self._inline_procs, + _selected_op=normalized_op, + _selected_dtype_signature=self._selected_dtype_signature, + _parameters=self._parameters, + _constraint_context_attrs=self._constraint_context_attrs, ) def specialize(self, **bindings: Any) -> "VKernelDescriptor": @@ -417,7 +906,7 @@ def specialize(self, **bindings: Any) -> "VKernelDescriptor": return VKernelDescriptor( target=self.target, - op=self.op, + match_ops=self.match_ops, dtypes=self.dtypes, name=self.name, verify_enabled=self.verify_enabled, @@ -427,9 +916,13 @@ def specialize(self, **bindings: Any) -> "VKernelDescriptor": specializations=tuple(sorted(updated.items())), constraints=self.constraints, priority=self.priority, + _templates=self._templates, + _inline_procs=self._inline_procs, + _selected_op=self._selected_op, _selected_dtype_signature=self._selected_dtype_signature, _parameters=self._parameters, _py_fn=self._py_fn, + _constraint_context_attrs=self._constraint_context_attrs, ) def _require_specialized_tiles(self, api_name: str) -> None: @@ -448,6 +941,100 @@ def _require_specialized_tiles(self, api_name: str) -> None: f"{missing_names}", ) + def _require_materialization_binding(self, api_name: str) -> None: + self.parameters + if len(self.match_ops) > 1 and self._selected_op is None: + raise ValueError( + f"{api_name}() requires pto.select_kernel(...) to bind a concrete op " + "before materialization" + ) + + def _constraint_context_for_evaluation( + self, + extra_context_attrs: Mapping[str, Any] | None = None, + ) -> dict[str, Any]: + attrs = dict(self._constraint_context_attrs) + if extra_context_attrs is not None: + attrs.update(extra_context_attrs) + attrs.setdefault("target", self.target) + if self._selected_op is not None: + attrs.setdefault("op", self._selected_op) + attrs.setdefault("selected_op", self._selected_op) + + for spec in self._parameter_specs: + existing = attrs.get(spec.name) + param_attrs = {} if not isinstance(existing, dict) else dict(existing) + param_attrs.setdefault("kind", spec.kind) + attrs.setdefault(f"{spec.name}_kind", spec.kind) + if f"{spec.name}_shape" in attrs: + param_attrs.setdefault("shape", tuple(attrs[f"{spec.name}_shape"])) + if f"{spec.name}_valid_shape" in attrs: + param_attrs.setdefault("valid_shape", tuple(attrs[f"{spec.name}_valid_shape"])) + if f"{spec.name}_strides" in attrs: + param_attrs.setdefault("strides", tuple(attrs[f"{spec.name}_strides"])) + if f"{spec.name}_rank" in attrs: + param_attrs.setdefault("rank", attrs[f"{spec.name}_rank"]) + if f"{spec.name}_memory_space" in attrs: + param_attrs.setdefault("memory_space", attrs[f"{spec.name}_memory_space"]) + + if spec.kind in ("tensorview", "partition_tensor_view"): + # TensorView authoring form is normalized to 5D in the current DSL spec. + param_attrs.setdefault("rank", 5) + param_attrs.setdefault("memory_space", "gm") + attrs.setdefault(f"{spec.name}_rank", 5) + attrs.setdefault(f"{spec.name}_memory_space", "gm") + attrs[spec.name] = param_attrs + + if self._parameters is not None: + for param in self._parameters: + param_attrs = attrs.get(param.name) + if not isinstance(param_attrs, dict): + param_attrs = {"kind": param.kind} + param_attrs.setdefault("dtype", param.dtype) + attrs[param.name] = param_attrs + attrs.setdefault(f"{param.name}_dtype", param.dtype) + + for name, spec in self.specializations_by_name.items(): + effective_valid_shape = spec.shape if spec.valid_shape is None else spec.valid_shape + param_attrs = attrs.get(name) + if not isinstance(param_attrs, dict): + param_attrs = {"kind": "tile"} + config_mapping = None if spec.config is None else dict(spec.config.fields) + param_attrs.update( + { + "shape": spec.shape, + "rank": len(spec.shape), + "memory_space": spec.memory_space.value, + "valid_shape": effective_valid_shape, + "config": config_mapping, + } + ) + attrs[name] = param_attrs + attrs[f"{name}_shape"] = spec.shape + attrs[f"{name}_rank"] = len(spec.shape) + attrs[f"{name}_memory_space"] = spec.memory_space.value + attrs[f"{name}_valid_shape"] = effective_valid_shape + if len(spec.shape) == 1: + attrs[f"{name}_extent"] = spec.shape[0] + attrs[f"{name}_valid_extent"] = effective_valid_shape[0] + elif len(spec.shape) == 2: + attrs[f"{name}_rows"] = spec.shape[0] + attrs[f"{name}_cols"] = spec.shape[1] + attrs[f"{name}_valid_rows"] = effective_valid_shape[0] + attrs[f"{name}_valid_cols"] = effective_valid_shape[1] + return attrs + + def _validate_materialization_constraints(self, api_name: str) -> None: + if not self.constraints: + return + context_attrs = self._constraint_context_for_evaluation() + if _evaluate_constraints(self, context_attrs): + return + raise LookupError( + f"{api_name}() constraint evaluation rejected kernel {self.name!r} " + "for the current specialization/context attributes" + ) + def _build_authoring_module(self): self.parameters frontend_kernel = build_frontend_kernel_node(self) @@ -455,19 +1042,26 @@ def _build_authoring_module(self): return lower_semantic_kernel(semantic_kernel) def mlir_text(self) -> str: + self._require_materialization_binding("mlir_text") self._require_specialized_tiles("mlir_text") + self._validate_materialization_constraints("mlir_text") return self._build_authoring_module().render() def mlir_module(self) -> "MaterializedMLIRModule": + self._require_materialization_binding("mlir_module") self._require_specialized_tiles("mlir_module") return MaterializedMLIRModule(text=self.mlir_text(), target=self.target) def verify(self, *, ptoas_bin: str | Path | None = None) -> "VerificationResult": + self._require_materialization_binding("verify") self._require_specialized_tiles("verify") + self._validate_materialization_constraints("verify") return self.mlir_module().verify(ptoas_bin=ptoas_bin) def emit(self, path: str | Path) -> None: + self._require_materialization_binding("emit") self._require_specialized_tiles("emit") + self._validate_materialization_constraints("emit") output_path = Path(path) output_path.write_text(self.mlir_text(), encoding="utf-8") @@ -688,6 +1282,82 @@ def _validate_op(op: Any) -> str: return op +def _freeze_match_ops(*, op: Any, ops: Any) -> tuple[str, ...]: + if op is not None and ops is not None: + raise ValueError("vkernel() accepts either op= or ops=, but not both") + if op is None and ops is None: + raise ValueError("vkernel() requires exactly one of op= or ops=") + if op is not None: + return (_validate_op(op),) + if not isinstance(ops, (list, tuple)): + raise TypeError("ops must be a sequence of non-empty strings") + if not ops: + raise ValueError("ops must contain at least one op") + normalized_ops = tuple(_validate_op(candidate) for candidate in ops) + if len(set(normalized_ops)) != len(normalized_ops): + raise ValueError("ops must not contain duplicates") + return normalized_ops + + +def _validate_template_slot_name(slot: Any) -> str: + if not isinstance(slot, str) or not slot: + raise TypeError("template slot names must be non-empty strings") + return slot + + +def _validate_template_value(slot: str, op_name: str, value: Any) -> str: + if not isinstance(value, str) or not value: + raise TypeError( + f"templates[{slot!r}][{op_name!r}] must be a non-empty pto op name string" + ) + if value not in _SUPPORTED_TEMPLATE_PTO_CALLS: + raise ValueError( + f"templates[{slot!r}][{op_name!r}] maps to unsupported pto op {value!r}" + ) + return value + + +def _freeze_templates( + templates: Any, + *, + match_ops: tuple[str, ...], +) -> tuple[tuple[str, tuple[tuple[str, str], ...]], ...]: + if templates in (_UNSET, None): + return () + if not isinstance(templates, Mapping): + raise TypeError("templates must be a mapping of slot names to per-op mappings") + + frozen_templates = [] + for slot, op_bindings in templates.items(): + normalized_slot = _validate_template_slot_name(slot) + if not isinstance(op_bindings, Mapping): + raise TypeError( + f"templates[{normalized_slot!r}] must be a mapping of concrete ops to pto op names" + ) + if not op_bindings: + raise ValueError( + f"templates[{normalized_slot!r}] must contain at least one concrete-op mapping" + ) + + frozen_bindings = [] + for concrete_op, real_op in op_bindings.items(): + normalized_concrete_op = _validate_op(concrete_op) + if normalized_concrete_op not in match_ops: + raise ValueError( + f"templates[{normalized_slot!r}] references op {normalized_concrete_op!r} " + f"outside descriptor matcher set {match_ops!r}" + ) + frozen_bindings.append( + ( + normalized_concrete_op, + _validate_template_value(normalized_slot, normalized_concrete_op, real_op), + ) + ) + frozen_templates.append((normalized_slot, tuple(frozen_bindings))) + + return tuple(frozen_templates) + + def _validate_name(py_fn: Callable[..., Any], name: Any) -> str: if name is None: return py_fn.__name__ @@ -758,6 +1428,67 @@ def _coerce_tile_config(value: Any, param_name: str) -> TileConfig | None: ) +def _coerce_tile_valid_shape( + shape: tuple[int, ...], + value: Any, + param_name: str, + source_info: _FunctionSourceInfo | None, +) -> tuple[int | None, ...] | None: + if value is None: + return None + if not isinstance(value, (list, tuple)): + _raise_tile_param_error( + source_info, + param_name, + f"specialization for '{param_name}' must provide valid_shape as a tuple/list", + TypeError, + ) + valid_shape = tuple(value) + if len(valid_shape) != len(shape): + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape rank must match shape rank", + ) + + normalized: list[int | None] = [] + for axis, (valid_dim, shape_dim) in enumerate(zip(valid_shape, shape)): + if isinstance(valid_dim, bool): + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape axis {axis} must not be bool", + TypeError, + ) + if isinstance(valid_dim, int): + if valid_dim <= 0: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape axis {axis} must be positive", + ) + if valid_dim > shape_dim: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape axis {axis}={valid_dim} " + f"must be <= shape axis {axis}={shape_dim}", + ) + normalized.append(valid_dim) + continue + if valid_dim is None or isinstance(valid_dim, str): + normalized.append(None) + continue + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape axis {axis} must be " + "a positive int, string symbol, or None", + TypeError, + ) + return tuple(normalized) + + def _coerce_tile_specialization( param_name: str, binding: Any, @@ -784,6 +1515,7 @@ def _coerce_tile_specialization( shape=tuple(binding["shape"]), memory_space=_coerce_memory_space(binding["memory_space"], param_name), config=_coerce_tile_config(binding.get("config"), param_name), + valid_shape=binding.get("valid_shape"), ) else: _raise_tile_param_error( @@ -825,45 +1557,64 @@ def _coerce_tile_specialization( param_name, f"illegal Tile profile for '{param_name}': v1 only supports MemorySpace.UB", ) - return spec + valid_shape = _coerce_tile_valid_shape(spec.shape, spec.valid_shape, param_name, source_info) + return TileSpecialization( + shape=spec.shape, + memory_space=spec.memory_space, + config=spec.config, + valid_shape=valid_shape, + ) -def _validate_scalar_dtype(dtype: Any, param_name: str) -> ScalarType: - if not isinstance(dtype, ScalarType): +def _validate_leaf_dtype(dtype: Any, param_name: str) -> ScalarType | MaskType: + if not isinstance(dtype, (ScalarType, MaskType)): raise TypeError( - f"dtypes entry for parameter '{param_name}' must be a TileLang scalar dtype" + f"dtypes entry for parameter '{param_name}' must be a TileLang scalar or mask dtype" ) return dtype -def _freeze_operand_types(operand_types: Any) -> tuple[ScalarType, ...]: +def _freeze_operand_types(operand_types: Any) -> tuple[ScalarType | MaskType, ...]: if not isinstance(operand_types, (list, tuple)): - raise TypeError("operand_types must be a sequence of TileLang scalar dtypes") - return tuple(_validate_scalar_dtype(dtype, f"operand_types[{index}]") for index, dtype in enumerate(operand_types)) + raise TypeError("operand_types must be a sequence of TileLang scalar or mask dtypes") + return tuple(_validate_leaf_dtype(dtype, f"operand_types[{index}]") for index, dtype in enumerate(operand_types)) -def _matches_wildcard(pattern: WildcardType, actual: ScalarType) -> bool: +def _matches_wildcard(pattern: WildcardType, actual: ScalarType | MaskType) -> bool: if pattern.name == "AnyType": - return True + return isinstance(actual, ScalarType) if pattern.name == "AnyFloat": - return actual.name in {"f16", "bf16", "f32"} + return isinstance(actual, ScalarType) and actual.name in {"f16", "bf16", "f32"} if pattern.name == "AnyInt": - return actual.name.startswith("i") + return isinstance(actual, ScalarType) and actual.name.startswith("i") if pattern.name == "AnyMask": - return actual.name == "i1" + return isinstance(actual, MaskType) raise TypeError(f"unsupported wildcard matcher {pattern.name!r}") +def _matches_scalar_annotation( + annotation: ScalarType | MaskType | WildcardType | TypeVariable, + actual: ScalarType | MaskType, +) -> bool: + if isinstance(annotation, (ScalarType, MaskType)): + return annotation == actual + if isinstance(annotation, WildcardType): + return _matches_wildcard(annotation, actual) + if isinstance(annotation, TypeVariable): + return True + raise TypeError(f"unsupported scalar annotation {annotation!r}") + + def _match_dtype_signature( dtype_signature: tuple[Any, ...], - operand_types: tuple[ScalarType, ...], -) -> tuple[ScalarType, ...] | None: + operand_types: tuple[ScalarType | MaskType, ...], +) -> tuple[ScalarType | MaskType, ...] | None: if len(dtype_signature) != len(operand_types): return None - typevar_bindings: dict[str, ScalarType] = {} + typevar_bindings: dict[str, ScalarType | MaskType] = {} for pattern, actual in zip(dtype_signature, operand_types): - if isinstance(pattern, ScalarType): + if isinstance(pattern, (ScalarType, MaskType)): if pattern != actual: return None continue @@ -885,8 +1636,8 @@ def _match_dtype_signature( def _match_descriptor_dtype_signature( descriptor: VKernelDescriptor, - operand_types: tuple[ScalarType, ...], -) -> tuple[ScalarType, ...] | None: + operand_types: tuple[ScalarType | MaskType, ...], +) -> tuple[ScalarType | MaskType, ...] | None: for dtype_signature in descriptor.dtypes: matched = _match_dtype_signature(dtype_signature, operand_types) if matched is not None: @@ -918,6 +1669,12 @@ def _validate_parameter_spec(param: inspect.Parameter) -> KernelParameterSpec: kind="tensorview", annotation=annotation, ) + if annotation is PartitionTensorView: + return KernelParameterSpec( + name=param.name, + kind="partition_tensor_view", + annotation=annotation, + ) if annotation is Tile: return KernelParameterSpec( name=param.name, @@ -930,7 +1687,19 @@ def _validate_parameter_spec(param: inspect.Parameter) -> KernelParameterSpec: kind="ptr", annotation=annotation, ) - if isinstance(annotation, ScalarType): + if isinstance(annotation, MaskType): + return KernelParameterSpec( + name=param.name, + kind="mask", + annotation=annotation, + ) + if isinstance(annotation, WildcardType) and annotation.name == "AnyMask": + return KernelParameterSpec( + name=param.name, + kind="mask", + annotation=annotation, + ) + if isinstance(annotation, (ScalarType, WildcardType, TypeVariable)): return KernelParameterSpec( name=param.name, kind="scalar", @@ -947,6 +1716,27 @@ def _collect_parameter_specs(py_fn: Callable[..., Any]) -> tuple[KernelParameter return tuple(_validate_parameter_spec(param) for param in signature.parameters.values()) +def _default_dtype_signature( + parameter_specs: tuple[KernelParameterSpec, ...], +) -> tuple[Any, ...]: + defaults: list[Any] = [] + for param_spec in parameter_specs: + if param_spec.kind in {"tensorview", "partition_tensor_view", "tile"}: + defaults.append(AnyType) + continue + if param_spec.kind == "ptr": + defaults.append(param_spec.annotation.element_dtype) + continue + if param_spec.kind == "mask": + defaults.append(param_spec.annotation if isinstance(param_spec.annotation, MaskType) else AnyMask) + continue + if isinstance(param_spec.annotation, (WildcardType, TypeVariable)): + defaults.append(AnyType) + continue + defaults.append(param_spec.annotation) + return tuple(defaults) + + def _validate_dtype_arity( parameter_specs: tuple[KernelParameterSpec, ...], dtypes: tuple[tuple[Any, ...], ...], @@ -962,49 +1752,71 @@ def _bind_parameter( param_spec: KernelParameterSpec, dtype: Any, ) -> BoundKernelParameter: - scalar_dtype = _validate_scalar_dtype(dtype, param_spec.name) - if param_spec.kind == "tensorview": + bound_dtype = _validate_leaf_dtype(dtype, param_spec.name) + if param_spec.kind in {"tensorview", "partition_tensor_view"}: return BoundKernelParameter( name=param_spec.name, kind=param_spec.kind, annotation=param_spec.annotation, - dtype=scalar_dtype, + dtype=bound_dtype, ) if param_spec.kind == "tile": return BoundKernelParameter( name=param_spec.name, kind=param_spec.kind, annotation=param_spec.annotation, - dtype=scalar_dtype, + dtype=bound_dtype, ) if param_spec.kind == "ptr": - if param_spec.annotation.element_dtype != scalar_dtype: + if param_spec.annotation.element_dtype != bound_dtype: raise TypeError( f"pointer parameter '{param_spec.name}' annotation {param_spec.annotation!r} " - f"does not match selected dtype {scalar_dtype!r}" + f"does not match selected dtype {bound_dtype!r}" + ) + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=bound_dtype, + ) + if param_spec.kind == "mask": + if not isinstance(bound_dtype, MaskType): + raise TypeError( + f"mask parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" + ) + if isinstance(param_spec.annotation, MaskType) and param_spec.annotation != bound_dtype: + raise TypeError( + f"mask parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" + ) + if isinstance(param_spec.annotation, WildcardType) and not _matches_wildcard(param_spec.annotation, bound_dtype): + raise TypeError( + f"mask parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" ) return BoundKernelParameter( name=param_spec.name, kind=param_spec.kind, annotation=param_spec.annotation, - dtype=scalar_dtype, + dtype=bound_dtype, ) - if param_spec.annotation != scalar_dtype: + if not _matches_scalar_annotation(param_spec.annotation, bound_dtype): raise TypeError( f"scalar parameter '{param_spec.name}' annotation {param_spec.annotation!r} " - f"does not match selected dtype {scalar_dtype!r}" + f"does not match selected dtype {bound_dtype!r}" ) return BoundKernelParameter( name=param_spec.name, kind=param_spec.kind, annotation=param_spec.annotation, - dtype=scalar_dtype, + dtype=bound_dtype, ) def _bind_parameters( parameter_specs: tuple[KernelParameterSpec, ...], - dtype_signature: tuple[ScalarType, ...], + dtype_signature: tuple[ScalarType | MaskType, ...], ) -> tuple[BoundKernelParameter, ...]: if len(dtype_signature) != len(parameter_specs): raise ValueError( @@ -1021,6 +1833,8 @@ def _build_descriptor( *, target: str, op: Any, + ops: Any, + templates: Any, dtypes: Any, name: Any, verify: Any, @@ -1033,20 +1847,32 @@ def _build_descriptor( source_info = _load_function_source_info(py_fn) advanced_enabled = _validate_advanced(advanced) - _validate_function_body(source_info, advanced_enabled=advanced_enabled) - frozen_dtypes = _freeze_dtypes(dtypes) + inline_procs = _collect_inline_procs(py_fn.__module__) + _validate_function_body( + source_info, + advanced_enabled=advanced_enabled, + module_name=py_fn.__module__, + ) + match_ops = _freeze_match_ops(op=op, ops=ops) + frozen_templates = _freeze_templates(templates, match_ops=match_ops) parameter_specs = _collect_parameter_specs(py_fn) + if dtypes is None: + dtypes = (_default_dtype_signature(parameter_specs),) + frozen_dtypes = _freeze_dtypes(dtypes) _validate_dtype_arity(parameter_specs, frozen_dtypes) - selected_dtype_signature: tuple[ScalarType, ...] | None = None + selected_op: str | None = None + selected_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None bound_parameters: tuple[BoundKernelParameter, ...] | None = None - if len(frozen_dtypes) == 1 and all(isinstance(dtype, ScalarType) for dtype in frozen_dtypes[0]): + if len(match_ops) == 1: + selected_op = match_ops[0] + if len(frozen_dtypes) == 1 and all(isinstance(dtype, (ScalarType, MaskType)) for dtype in frozen_dtypes[0]): selected_dtype_signature = tuple(frozen_dtypes[0]) bound_parameters = _bind_parameters(parameter_specs, selected_dtype_signature) return VKernelDescriptor( target=_validate_target(target), - op=_validate_op(op), + match_ops=match_ops, dtypes=frozen_dtypes, name=_validate_name(py_fn, name), verify_enabled=_validate_verify(verify), @@ -1056,8 +1882,12 @@ def _build_descriptor( _source_info=source_info, constraints=_validate_constraints(constraints), priority=_validate_priority(priority), + _templates=frozen_templates, + _inline_procs=inline_procs, + _selected_op=selected_op, _selected_dtype_signature=selected_dtype_signature, _parameters=bound_parameters, + _constraint_context_attrs=(), ) @@ -1065,9 +1895,44 @@ def _evaluate_constraints( descriptor: VKernelDescriptor, context_attrs: Mapping[str, Any], ) -> bool: + named_context: dict[str, Any] = { + "target": context_attrs.get("target"), + "op": context_attrs.get("op"), + "selected_op": context_attrs.get("selected_op"), + } + for spec in descriptor._parameter_specs: + param_attrs = context_attrs.get(spec.name) + if not isinstance(param_attrs, Mapping): + param_attrs = {} + named_context[spec.name] = _ConstraintParamView(spec.name, param_attrs) + for index, constraint in enumerate(descriptor.constraints): try: - result = constraint(context_attrs) + signature = inspect.signature(constraint) + parameters = list(signature.parameters.values()) + kwargs: dict[str, Any] = {} + for parameter in parameters: + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: + raise TypeError("constraint callables with *args are not supported") + if parameter.kind == inspect.Parameter.VAR_KEYWORD: + for key, value in named_context.items(): + kwargs.setdefault(key, value) + for key, value in context_attrs.items(): + kwargs.setdefault(key, value) + continue + if parameter.name in named_context: + kwargs[parameter.name] = named_context[parameter.name] + continue + if parameter.name in context_attrs: + kwargs[parameter.name] = context_attrs[parameter.name] + continue + if parameter.default is not inspect._empty: + continue + raise TypeError( + f"constraint {index} for kernel {descriptor.name!r} requires unsupported parameter " + f"{parameter.name!r}" + ) + result = constraint(**kwargs) except Exception as exc: raise TypeError( f"constraint {index} for kernel {descriptor.name!r} raised {type(exc).__name__}: {exc}" @@ -1084,6 +1949,27 @@ def _format_descriptor_identity(descriptor: VKernelDescriptor) -> str: return f"{descriptor.name}(priority={descriptor.priority}, dtypes={dtype_signature!r})" +def _match_descriptor_query( + descriptor: VKernelDescriptor, + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], +) -> VKernelDescriptor | None: + if descriptor.target != target: + return None + if op not in descriptor.match_ops: + return None + + op_bound_descriptor = descriptor._bind_selected_op(op) + matched_signature = _match_descriptor_dtype_signature(op_bound_descriptor, operand_types) + if matched_signature is None: + return None + if op_bound_descriptor._selected_dtype_signature == matched_signature: + return op_bound_descriptor + return op_bound_descriptor._bind_selected_dtype_signature(matched_signature) + + def select_kernel( target: str, op: str, @@ -1109,14 +1995,17 @@ def select_kernel( raise TypeError("registry must be a KernelRegistry or None") type_matched_candidates = [ - descriptor._bind_selected_dtype_signature(matched_signature) - if descriptor._selected_dtype_signature != matched_signature - else descriptor + matched_descriptor for descriptor in active_registry - if descriptor.target == normalized_target - and descriptor.op == normalized_op - for matched_signature in (_match_descriptor_dtype_signature(descriptor, normalized_operand_types),) - if matched_signature is not None + for matched_descriptor in ( + _match_descriptor_query( + descriptor, + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + ), + ) + if matched_descriptor is not None ] if not type_matched_candidates: @@ -1128,7 +2017,10 @@ def select_kernel( constrained_candidates = [ descriptor for descriptor in type_matched_candidates - if _evaluate_constraints(descriptor, normalized_context_attrs) + if _evaluate_constraints( + descriptor, + descriptor._constraint_context_for_evaluation(normalized_context_attrs), + ) ] if not constrained_candidates: raise LookupError( @@ -1149,7 +2041,7 @@ def select_kernel( f"target={normalized_target!r}, op={normalized_op!r}, operand_types={normalized_operand_types!r}: " f"{winner_set}" ) - return winners[0] + return winners[0]._bind_constraint_context_attrs(normalized_context_attrs) def vkernel( @@ -1157,6 +2049,8 @@ def vkernel( *, target: str = "a5", op: str | None = None, + ops: tuple[str, ...] | list[str] | None = None, + templates: Any = _UNSET, dtypes: Any = None, name: str | None = None, verify: bool = True, @@ -1167,8 +2061,8 @@ def vkernel( """Create a TileLang DSL v1 kernel descriptor. v1 keeps only the minimal descriptor metadata surface: - `target`, `op`, `dtypes`, `constraints`, `priority`, `name`, `verify`, - and opt-in `advanced`. + `target`, `op`/`ops`, `templates`, `dtypes`, `constraints`, `priority`, `name`, + `verify`, and opt-in `advanced`. """ def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: @@ -1176,6 +2070,8 @@ def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: fn, target=target, op=op, + ops=ops, + templates=templates, dtypes=dtypes, name=name, verify=verify, @@ -1192,10 +2088,12 @@ def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: __all__ = [ "BoundKernelParameter", + "InlineProcDescriptor", "KernelRegistry", "MaterializedMLIRModule", "TileLangFrontendError", "VKernelDescriptor", + "inline_proc", "select_kernel", "vkernel", ] diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index e75e84094..bd587df07 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """Authoring-form VPTO lowering skeleton for TileLang DSL v1.""" from __future__ import annotations @@ -30,6 +38,7 @@ SemanticReturnStmt, SemanticScalarType, SemanticSetFlagStmt, + SemanticShapeType, SemanticStmt, SemanticVecscopeStmt, SemanticStrictVecscopeStmt, @@ -37,6 +46,7 @@ SemanticSymbolExpr, SemanticTensorSliceExpr, SemanticTensorViewType, + SemanticPartitionTensorViewType, SemanticTileType, SemanticType, SemanticTupleExpr, @@ -45,7 +55,7 @@ SemanticVectorStoreStmt, SemanticWaitFlagStmt, ) -from .types import MaskPattern, ScalarType +from .types import MaskPattern, ScalarType, get_lanes _I1_TYPE = SemanticScalarType(dtype=ScalarType("i1")) @@ -67,7 +77,61 @@ class AuthoringModule: kernel: SemanticKernel def render(self) -> str: - return _AuthoringRenderer(self.kernel).render() + kernel_text = _AuthoringRenderer(self.kernel).render() + if not self.kernel.inline_helpers: + return kernel_text + + base_lines = kernel_text.splitlines() + module_close_index = max( + (index for index, line in enumerate(base_lines) if line == "}"), + default=-1, + ) + if module_close_index < 0: + return kernel_text + + merged_lines = base_lines[:module_close_index] + for helper in self.kernel.inline_helpers: + helper_lines = _extract_single_function_lines( + _AuthoringRenderer(helper).render() + ) + if not helper_lines: + continue + helper_lines[0] = _rewrite_inline_helper_attrs(helper_lines[0]) + merged_lines.extend(helper_lines) + + merged_lines.append("}") + merged_lines.append("") + return "\n".join(merged_lines) + + +def _extract_single_function_lines(rendered_text: str) -> list[str]: + lines = rendered_text.splitlines() + try: + function_start = next( + index for index, line in enumerate(lines) if line.lstrip().startswith("func.func ") + ) + except StopIteration: + return [] + module_close_index = max( + (index for index, line in enumerate(lines) if line == "}"), + default=-1, + ) + if module_close_index <= function_start: + return [] + return lines[function_start:module_close_index] + + +def _rewrite_inline_helper_attrs(function_line: str) -> str: + kernel_attr = "attributes { pto.tilelang.instance }" + helper_attr = 'attributes { sym_visibility = "private", pto.tilelang.inline_proc }' + if kernel_attr in function_line: + return function_line.replace(kernel_attr, helper_attr) + if "attributes {" in function_line: + return function_line + if function_line.rstrip().endswith("{"): + stripped = function_line.rstrip() + return stripped[:-1] + f" {helper_attr} {{" + return function_line @dataclass(frozen=True) @@ -76,11 +140,45 @@ class _RenderedValue: type: SemanticType +@dataclass(frozen=True) +class _RenderedTextualType(SemanticType): + text: str + + +@dataclass(frozen=True) +class _DmaTransferConfig: + n_burst: _RenderedValue + len_burst: _RenderedValue + copy_src_stride: _RenderedValue + copy_dst_stride: _RenderedValue + loop_src_stride: _RenderedValue + loop_dst_stride: _RenderedValue + + +@dataclass(frozen=True) +class _DmaLoadPaddingProfile: + pad_mode_name: str + left_padding: int + right_padding: int + init_out_buffer: bool + pad_value: SemanticExpr | None + + +@dataclass(frozen=True) +class _DmaStoreTrimProfile: + left_padding: int + right_padding: int + + class _AuthoringRenderer: def __init__(self, kernel: SemanticKernel): self.kernel = kernel self._constant_lines: list[str] = [] self._constant_cache: dict[tuple[str, object], str] = {} + self._castptr_cache: dict[tuple[str, str], str] = {} + self._tile_memref_cache: dict[str, _RenderedValue] = {} + self._tile_valid_dim_cache: dict[tuple[str, int], _RenderedValue] = {} + self._used_tile_buffers = self._collect_used_tile_buffers(kernel.body) self._temp_counter = 0 self._loop_counter = 0 @@ -88,11 +186,23 @@ def render(self) -> str: parameter_list = ", ".join( f"{param.ssa_name}: {self._render_type(param.type)}" for param in self.kernel.parameters + if param.kind != "tile_valid_shape" ) env = { param.name: _RenderedValue(name=param.ssa_name, type=param.type) for param in self.kernel.parameters + if param.kind != "tile_valid_shape" } + entry_lines: list[str] = [] + for param in self.kernel.parameters: + if param.kind != "tile": + continue + if param.name in self._used_tile_buffers: + self._materialize_tile_memref( + env[param.name], + indent=4, + into=entry_lines, + ) body_lines = self._render_block(self.kernel.body, env, indent=4) lines = [ @@ -103,22 +213,136 @@ def render(self) -> str: f"// tilelang.advanced = {self.kernel.advanced_enabled}", ] for binding in self.kernel.tile_bindings: + valid_shape = "" + if binding.valid_shape is not None: + valid_shape = f" valid_shape={self._format_shape_tuple(binding.valid_shape)}" lines.append( "// tilelang.specialize " f"{binding.name} shape={binding.shape} memory_space={binding.memory_space} " - f"config={binding.config}" + f"config={binding.config}{valid_shape}" ) lines.append(f'module attributes {{pto.target_arch = "{self.kernel.target}"}} {{') lines.append( - f" func.func {_format_symbol_name(self.kernel.symbol_name)}({parameter_list}) {{" + " func.func " + f"{_format_symbol_name(self.kernel.symbol_name)}({parameter_list}) " + "attributes { pto.tilelang.instance } {" ) lines.extend(self._constant_lines) + lines.extend(entry_lines) lines.extend(body_lines) lines.append(" }") lines.append("}") lines.append("") return "\n".join(lines) + def _collect_used_tile_buffers( + self, + statements: tuple[SemanticStmt, ...], + ) -> set[str]: + used: set[str] = set() + for stmt in statements: + self._collect_used_tile_buffers_from_stmt(stmt, used) + return used + + def _collect_used_tile_buffers_from_stmt( + self, + stmt: SemanticStmt, + used: set[str], + ) -> None: + if isinstance(stmt, SemanticAssignStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + return + if isinstance(stmt, SemanticExprStmt): + self._collect_used_tile_buffers_from_expr(stmt.expr, used) + return + if isinstance(stmt, SemanticDmaLoadStmt): + self._record_tile_buffer_use(stmt.dst, used) + self._collect_used_tile_buffers_from_expr(stmt.src, used) + return + if isinstance(stmt, SemanticDmaStoreStmt): + self._record_tile_buffer_use(stmt.src, used) + self._collect_used_tile_buffers_from_expr(stmt.dst, used) + return + if isinstance(stmt, SemanticVectorStoreStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + self._record_tile_buffer_use(stmt.destination, used) + for index in stmt.indices: + self._collect_used_tile_buffers_from_expr(index, used) + self._collect_used_tile_buffers_from_expr(stmt.mask, used) + return + if isinstance(stmt, SemanticVecscopeStmt): + for nested in stmt.body: + self._collect_used_tile_buffers_from_stmt(nested, used) + return + if isinstance(stmt, SemanticStrictVecscopeStmt): + for capture in stmt.captures: + self._record_tile_buffer_use(capture, used) + self._collect_used_tile_buffers_from_expr(capture, used) + for nested in stmt.body: + self._collect_used_tile_buffers_from_stmt(nested, used) + return + if isinstance(stmt, SemanticForStmt): + self._collect_used_tile_buffers_from_expr(stmt.lower_bound, used) + self._collect_used_tile_buffers_from_expr(stmt.upper_bound, used) + self._collect_used_tile_buffers_from_expr(stmt.step, used) + for nested in stmt.body: + self._collect_used_tile_buffers_from_stmt(nested, used) + return + if isinstance(stmt, SemanticIfStmt): + self._collect_used_tile_buffers_from_expr(stmt.condition, used) + for nested in stmt.then_body: + self._collect_used_tile_buffers_from_stmt(nested, used) + for nested in stmt.else_body: + self._collect_used_tile_buffers_from_stmt(nested, used) + return + if isinstance(stmt, SemanticReturnStmt) and stmt.value is not None: + self._collect_used_tile_buffers_from_expr(stmt.value, used) + + def _collect_used_tile_buffers_from_expr( + self, + expr: SemanticExpr, + used: set[str], + ) -> None: + if isinstance(expr, SemanticCallExpr): + if expr.namespace == "pto" and expr.name == "vlds" and expr.args: + self._record_tile_buffer_use(expr.args[0], used) + for arg in expr.args: + self._collect_used_tile_buffers_from_expr(arg, used) + return + if isinstance(expr, SemanticBinaryExpr): + self._collect_used_tile_buffers_from_expr(expr.lhs, used) + self._collect_used_tile_buffers_from_expr(expr.rhs, used) + return + if isinstance(expr, SemanticTupleExpr): + for element in expr.elements: + self._collect_used_tile_buffers_from_expr(element, used) + return + if isinstance(expr, SemanticTensorSliceExpr): + self._collect_used_tile_buffers_from_expr(expr.base, used) + for slice_expr in expr.slices: + if slice_expr.start is not None: + self._collect_used_tile_buffers_from_expr(slice_expr.start, used) + if slice_expr.stop is not None: + self._collect_used_tile_buffers_from_expr(slice_expr.stop, used) + if slice_expr.step is not None: + self._collect_used_tile_buffers_from_expr(slice_expr.step, used) + return + if isinstance(expr, SemanticAttributeAccess): + if expr.attr not in {"shape", "valid_shape", "strides", "element_type"}: + self._collect_used_tile_buffers_from_expr(expr.base, used) + return + if isinstance(expr, SemanticSubscriptAccess): + self._collect_used_tile_buffers_from_expr(expr.base, used) + self._collect_used_tile_buffers_from_expr(expr.index, used) + + def _record_tile_buffer_use( + self, + expr: SemanticExpr, + used: set[str], + ) -> None: + if isinstance(expr, SemanticBindingRef) and isinstance(expr.type, SemanticTileType): + used.add(expr.binding.name) + def _render_block( self, statements: tuple[SemanticStmt, ...], @@ -141,8 +365,9 @@ def _render_stmt( if isinstance(stmt, SemanticAssignStmt): return self._render_assign(stmt, env, indent=indent) if isinstance(stmt, SemanticExprStmt): - self._lower_expr(stmt.expr, env, indent=indent) - return [] + lines: list[str] = [] + self._lower_expr(stmt.expr, env, indent=indent, into=lines) + return lines if isinstance(stmt, SemanticDmaLoadStmt): return self._render_dma_load(stmt, env, indent=indent) if isinstance(stmt, SemanticDmaStoreStmt): @@ -166,10 +391,12 @@ def _render_stmt( if isinstance(stmt, SemanticLowLevelCopyStmt): return self._render_low_level_copy(stmt, env, indent=indent) if isinstance(stmt, SemanticReturnStmt): + lines: list[str] = [] if stmt.value is None: return [self._indent(indent) + "return"] - value = self._lower_expr(stmt.value, env, indent=indent) - return [self._indent(indent) + f"return {value.name} : {self._render_type(value.type)}"] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + lines.append(self._indent(indent) + f"return {value.name} : {self._render_type(value.type)}") + return lines if isinstance(stmt, SemanticVecscopeStmt): return self._render_vecscope(stmt, env, indent=indent) if isinstance(stmt, SemanticStrictVecscopeStmt): @@ -235,7 +462,10 @@ def _render_assign( indent: int, ) -> list[str]: if len(stmt.targets) != 1: - if isinstance(stmt.value, SemanticTupleExpr): + if isinstance(stmt.value, SemanticTupleExpr) or ( + isinstance(stmt.value, SemanticAttributeAccess) + and isinstance(stmt.value.type, SemanticShapeType) + ): return self._render_tuple_expr_assign(stmt, env, indent=indent) return self._render_multi_result_assign(stmt, env, indent=indent) target = stmt.targets[0] @@ -260,13 +490,26 @@ def _render_tuple_expr_assign( *, indent: int, ) -> list[str]: - if not isinstance(stmt.value, SemanticTupleExpr): - raise NotImplementedError("tuple expression assignment expects a SemanticTupleExpr") - if len(stmt.targets) != len(stmt.value.elements): + if isinstance(stmt.value, SemanticTupleExpr): + elements = stmt.value.elements + elif isinstance(stmt.value, SemanticAttributeAccess) and isinstance(stmt.value.type, SemanticShapeType): + elements = tuple( + SemanticSubscriptAccess( + base=stmt.value, + index=SemanticLiteralExpr(value=axis, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + for axis in range(stmt.value.type.rank) + ) + else: + raise NotImplementedError( + "tuple expression assignment expects a SemanticTupleExpr or shape-like attribute value" + ) + if len(stmt.targets) != len(elements): raise NotImplementedError("tuple expression assignment arity mismatch") lines: list[str] = [] - for target, element in zip(stmt.targets, stmt.value.elements): + for target, element in zip(stmt.targets, elements): lowered = self._lower_expr( element, env, @@ -368,6 +611,24 @@ def _render_multi_result_assign( env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) return lines + if stmt.value.name == "vmull": + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + mask = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + low_target, high_target = stmt.targets + low_type, high_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{low_target.ssa_name}, {high_target.ssa_name} = pto.vmull " + + f"{lhs.name}, {rhs.name}, {mask.name} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(low_type)}, {self._render_type(high_type)}" + ) + env[low_target.name] = _RenderedValue(name=low_target.ssa_name, type=low_type) + env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) + return lines + raise NotImplementedError( f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" ) @@ -380,30 +641,48 @@ def _render_dma_load( indent: int, ) -> list[str]: lines: list[str] = [] + profile = self._resolve_dma_load_padding_profile(stmt.options) src = self._lower_expr(stmt.src.base, env, indent=indent, into=lines) dst = self._lower_expr(stmt.dst, env, indent=indent, into=lines) - row_count, col_count = self._dma_transfer_extents(stmt.src, stmt.dst.type) - element_bytes = self._dtype_byte_width(stmt.src.type.element_dtype) - burst_bytes = col_count * element_bytes - - c0_i64 = self._materialize_constant(0, _I64_TYPE) - c1_i64 = self._materialize_constant(1, _I64_TYPE) - n_burst = self._materialize_constant(row_count, _I64_TYPE) - len_burst = self._materialize_constant(burst_bytes, _I64_TYPE) - false_bit = self._materialize_constant(False, _I1_TYPE) - - lines.extend( - [ - self._indent(indent) - + f"pto.set_loop_size_outtoub {c1_i64}, {c1_i64} : i64, i64", - self._indent(indent) - + "pto.copy_gm_to_ubuf " - + f"{src.name}, {dst.name}, {c0_i64}, {n_burst}, {len_burst}, {c0_i64}, {c0_i64}, " - + f"{false_bit}, {c0_i64}, {len_burst}, {len_burst} : " - + f"{self._render_type(src.type)}, {self._render_type(dst.type)}, " - + "i64, i64, i64, i64, i64, i1, i64, i64, i64", - ] + src_name, src_type = self._materialize_tensor_slice_ptr( + stmt.src, + src, + env, + indent=indent, + into=lines, ) + dst_name, dst_type = self._materialize_tile_window_ptr( + dst, + col_offset=profile.left_padding, + indent=indent, + into=lines, + ) + transfer = self._infer_dma_load_transfer(stmt.src, stmt.dst.type, src, env, indent=indent, into=lines) + + copy_lines = self._render_dma_load_copy_ops( + src_name, + src_type, + dst_name, + dst_type, + transfer, + indent=indent, + ) + prefill_lines = self._render_dma_load_prefill( + stmt.dst, + dst, + env, + profile, + indent=indent, + ) + if profile.pad_mode_name == "PadFirstElem": + lines.extend(copy_lines) + lines.extend(prefill_lines) + if profile.init_out_buffer: + lines.extend(copy_lines) + return lines + + lines.extend(prefill_lines) + lines.extend(copy_lines) return lines def _render_dma_store( @@ -414,25 +693,39 @@ def _render_dma_store( indent: int, ) -> list[str]: lines: list[str] = [] + profile = self._resolve_dma_store_trim_profile(stmt.options) src = self._lower_expr(stmt.src, env, indent=indent, into=lines) dst = self._lower_expr(stmt.dst.base, env, indent=indent, into=lines) - row_count, col_count = self._dma_transfer_extents(stmt.dst, stmt.src.type) - element_bytes = self._dtype_byte_width(stmt.dst.type.element_dtype) - burst_bytes = col_count * element_bytes + src_name, src_type = self._materialize_tile_window_ptr( + src, + col_offset=profile.left_padding, + indent=indent, + into=lines, + ) + dst_name, dst_type = self._materialize_tensor_slice_ptr( + stmt.dst, + dst, + env, + indent=indent, + into=lines, + ) + transfer = self._infer_dma_store_transfer(stmt.dst, stmt.src.type, dst, env, indent=indent, into=lines) c0_i64 = self._materialize_constant(0, _I64_TYPE) c1_i64 = self._materialize_constant(1, _I64_TYPE) - n_burst = self._materialize_constant(row_count, _I64_TYPE) - len_burst = self._materialize_constant(burst_bytes, _I64_TYPE) lines.extend( [ self._indent(indent) + f"pto.set_loop_size_ubtoout {c1_i64}, {c1_i64} : i64, i64", self._indent(indent) + + f"pto.set_loop1_stride_ubtoout {transfer.loop_src_stride.name}, {transfer.loop_dst_stride.name} : i64, i64", + self._indent(indent) + + f"pto.set_loop2_stride_ubtoout {transfer.loop_src_stride.name}, {transfer.loop_dst_stride.name} : i64, i64", + self._indent(indent) + "pto.copy_ubuf_to_gm " - + f"{src.name}, {dst.name}, {c0_i64}, {n_burst}, {len_burst}, {c0_i64}, " - + f"{len_burst}, {len_burst} : {self._render_type(src.type)}, {self._render_type(dst.type)}, " + + f"{src_name}, {dst_name}, {c0_i64}, {transfer.n_burst.name}, {transfer.len_burst.name}, {c0_i64}, " + + f"{transfer.copy_dst_stride.name}, {transfer.copy_src_stride.name} : {src_type}, {dst_type}, " + "i64, i64, i64, i64, i64, i64", ] ) @@ -448,21 +741,912 @@ def _render_vector_store( lines: list[str] = [] value = self._lower_expr(stmt.value, env, indent=indent, into=lines) destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) - offset = self._lower_expr(stmt.offset, env, indent=indent, into=lines) + if isinstance(destination.type, SemanticTileType): + destination = self._materialize_tile_memref(destination, indent=indent, into=lines) + if ( + isinstance(stmt.destination.type, SemanticTileType) + and stmt.destination.type.rank == 2 + and len(stmt.indices) == 2 + ): + destination = self._materialize_rank2_tile_subview( + destination, + stmt.destination.type, + stmt.indices, + env, + indent=indent, + into=lines, + ) + rendered_indices = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_indices = self._render_index_list(stmt.indices, env, indent=indent, into=lines) mask = self._lower_expr(stmt.mask, env, indent=indent, into=lines) lines.append( self._indent(indent) + "pto.vsts " - + f"{value.name}, {destination.name}[{offset.name}], {mask.name} : " + + f"{value.name}, {destination.name}[{rendered_indices}], {mask.name} : " + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(mask.type)}" ) return lines + def _render_index_list( + self, + indices: tuple[SemanticExpr, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> str: + rendered = [ + self._lower_expr(index, env, indent=indent, into=into).name for index in indices + ] + return ", ".join(rendered) + + def _render_rank2_subview_result_type( + self, + *, + element_dtype: str, + memory_space: str, + ) -> _RenderedTextualType: + return _RenderedTextualType( + f"memref, " + f"{self._render_memref_memory_space(memory_space)}>" + ) + + def _materialize_rank2_tile_subview( + self, + base: _RenderedValue, + tile_type: SemanticTileType, + indices: tuple[SemanticExpr, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + row_index, col_index = indices + row_value = self._lower_expr(row_index, env, indent=indent, into=into) + col_value = self._lower_expr(col_index, env, indent=indent, into=into) + one = self._materialize_constant(1, SemanticIndexType()) + total_cols = self._materialize_constant(tile_type.shape[1], SemanticIndexType()) + remaining_cols = self._new_temp() + into.append( + self._indent(indent) + + f"{remaining_cols} = arith.subi {total_cols}, {col_value.name} : index" + ) + subview_type = self._render_rank2_subview_result_type( + element_dtype=tile_type.element_dtype.name, + memory_space=tile_type.memory_space or "ub", + ) + subview_name = self._new_temp() + into.append( + self._indent(indent) + + f"{subview_name} = memref.subview {base.name}[{row_value.name}, {col_value.name}] " + + f"[{one}, {remaining_cols}] [{one}, {one}] : " + + f"{self._render_type(base.type)} to {self._render_type(subview_type)}" + ) + return _RenderedValue(name=subview_name, type=subview_type) + def _tensor_slice_extents(self, expr: SemanticTensorSliceExpr) -> tuple[int, int]: if expr.type.rank != 2 or len(expr.type.extents) != 2: raise NotImplementedError("TileLang DSL v1 DMA lowering currently only supports rank-2 TensorView slices") return expr.type.extents + def _resolve_dma_load_padding_profile(self, options: object) -> _DmaLoadPaddingProfile: + pad_mode_name = self._static_pad_mode_name(getattr(options, "pad_mode", None)) or "PadNull" + left_padding = self._static_expr_value(getattr(options, "left_padding", None), default=0) + right_padding = self._static_expr_value(getattr(options, "right_padding", None), default=0) + init_out_buffer = self._static_expr_value(getattr(options, "init_out_buffer", None), default=False) + if not isinstance(left_padding, int) or left_padding < 0: + raise NotImplementedError( + "pto.dma_load lowering currently expects `left_padding` to be a static non-negative index" + ) + if not isinstance(right_padding, int) or right_padding < 0: + raise NotImplementedError( + "pto.dma_load lowering currently expects `right_padding` to be a static non-negative index" + ) + if not isinstance(init_out_buffer, bool): + raise NotImplementedError( + "pto.dma_load lowering currently expects `init_out_buffer` to be a compile-time bool" + ) + if pad_mode_name not in {"PadNull", "PadFirstElem", "PadValue"}: + raise NotImplementedError( + f"pto.dma_load lowering does not recognize pad_mode `{pad_mode_name}` in TileLang DSL v1" + ) + if pad_mode_name == "PadNull" and init_out_buffer: + raise NotImplementedError( + "pto.dma_load lowering does not support `init_out_buffer=True` with `pad_mode=PadMode.PadNull`; " + "the stable frontend-only path has no explicit fill value for that combination" + ) + return _DmaLoadPaddingProfile( + pad_mode_name=pad_mode_name, + left_padding=left_padding, + right_padding=right_padding, + init_out_buffer=init_out_buffer, + pad_value=getattr(options, "pad_value", None), + ) + + def _resolve_dma_store_trim_profile(self, options: object) -> _DmaStoreTrimProfile: + pad_mode_name = self._static_pad_mode_name(getattr(options, "pad_mode", None)) or "PadNull" + left_padding = self._static_expr_value(getattr(options, "left_padding", None), default=0) + right_padding = self._static_expr_value(getattr(options, "right_padding", None), default=0) + if pad_mode_name != "PadNull": + raise NotImplementedError( + "pto.dma_store lowering only supports `pad_mode=PadMode.PadNull`; " + "non-PadNull store padding would require GM-side fill in the stable frontend-only path" + ) + if self._static_expr_value(getattr(options, "pad_value", None)) is not None: + raise NotImplementedError( + "pto.dma_store lowering does not support `pad_value`; GM-side fill is unsupported" + ) + if not isinstance(left_padding, int) or left_padding < 0: + raise NotImplementedError( + "pto.dma_store lowering currently expects `left_padding` to be a static non-negative index" + ) + if not isinstance(right_padding, int) or right_padding < 0: + raise NotImplementedError( + "pto.dma_store lowering currently expects `right_padding` to be a static non-negative index" + ) + return _DmaStoreTrimProfile( + left_padding=left_padding, + right_padding=right_padding, + ) + + def _require_default_dma_lowering_profile(self, options: object, op_name: str) -> None: + if not self._is_default_dma_lowering_profile(options): + raise NotImplementedError( + f"{op_name} lowering for padding/trim/init options is not implemented yet in TileLang DSL v1; " + "this stable frontend-only DMA path only lowers the default no-padding profile today" + ) + + def _is_default_dma_lowering_profile(self, options: object) -> bool: + return ( + self._static_pad_mode_name(getattr(options, "pad_mode", None)) in {None, "PadNull"} + and self._static_expr_value(getattr(options, "pad_value", None)) is None + and self._static_expr_value(getattr(options, "left_padding", None), default=0) == 0 + and self._static_expr_value(getattr(options, "right_padding", None), default=0) == 0 + and self._static_expr_value(getattr(options, "init_out_buffer", None), default=False) is False + ) + + def _static_pad_mode_name(self, expr: SemanticExpr | None) -> str | None: + value = self._static_expr_value(expr) + return None if value is None else getattr(value, "name", None) + + def _static_expr_value(self, expr: SemanticExpr | None, *, default: object = None) -> object: + if expr is None: + return default + if isinstance(expr, SemanticLiteralExpr): + return expr.value + if isinstance(expr, SemanticSymbolExpr): + return expr.value + if isinstance(expr, SemanticBindingRef): + return expr.binding.value + return None + + def _infer_dma_load_transfer( + self, + slice_expr: SemanticTensorSliceExpr, + tile_type: SemanticTileType, + tensor_base: _RenderedValue, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _DmaTransferConfig: + element_bytes = self._dtype_byte_width(slice_expr.type.element_dtype) + row_count = self._materialize_dma_axis_extent(slice_expr, 0, env, indent=indent, into=into) + col_count = self._materialize_dma_axis_extent(slice_expr, 1, env, indent=indent, into=into) + gm_row_stride = self._materialize_tensor_row_stride_bytes( + slice_expr, + tensor_base, + element_bytes, + indent=indent, + into=into, + ) + row_step = self._materialize_dma_row_step(slice_expr, env, indent=indent, into=into) + copy_src_stride = self._emit_binary_value( + "mul", + gm_row_stride, + row_step, + _I64_TYPE, + indent=indent, + into=into, + ) + copy_dst_stride = self._materialize_tile_row_stride_bytes( + tile_type, + element_bytes, + indent=indent, + into=into, + ) + len_burst = self._materialize_dma_len_burst( + col_count, + element_bytes, + indent=indent, + into=into, + ) + loop_src_stride = self._emit_binary_value( + "mul", + row_count, + copy_src_stride, + _I64_TYPE, + indent=indent, + into=into, + ) + loop_dst_stride = self._emit_binary_value( + "mul", + row_count, + copy_dst_stride, + _I64_TYPE, + indent=indent, + into=into, + ) + return _DmaTransferConfig( + n_burst=row_count, + len_burst=len_burst, + copy_src_stride=copy_src_stride, + copy_dst_stride=copy_dst_stride, + loop_src_stride=loop_src_stride, + loop_dst_stride=loop_dst_stride, + ) + + def _infer_dma_store_transfer( + self, + slice_expr: SemanticTensorSliceExpr, + tile_type: SemanticTileType, + tensor_base: _RenderedValue, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _DmaTransferConfig: + element_bytes = self._dtype_byte_width(slice_expr.type.element_dtype) + row_count = self._materialize_dma_axis_extent(slice_expr, 0, env, indent=indent, into=into) + col_count = self._materialize_dma_axis_extent(slice_expr, 1, env, indent=indent, into=into) + copy_src_stride = self._materialize_tile_row_stride_bytes( + tile_type, + element_bytes, + indent=indent, + into=into, + ) + gm_row_stride = self._materialize_tensor_row_stride_bytes( + slice_expr, + tensor_base, + element_bytes, + indent=indent, + into=into, + ) + row_step = self._materialize_dma_row_step(slice_expr, env, indent=indent, into=into) + copy_dst_stride = self._emit_binary_value( + "mul", + gm_row_stride, + row_step, + _I64_TYPE, + indent=indent, + into=into, + ) + len_burst = self._materialize_dma_len_burst( + col_count, + element_bytes, + indent=indent, + into=into, + ) + loop_src_stride = self._emit_binary_value( + "mul", + row_count, + copy_src_stride, + _I64_TYPE, + indent=indent, + into=into, + ) + loop_dst_stride = self._emit_binary_value( + "mul", + row_count, + copy_dst_stride, + _I64_TYPE, + indent=indent, + into=into, + ) + return _DmaTransferConfig( + n_burst=row_count, + len_burst=len_burst, + copy_src_stride=copy_src_stride, + copy_dst_stride=copy_dst_stride, + loop_src_stride=loop_src_stride, + loop_dst_stride=loop_dst_stride, + ) + + def _render_dma_load_copy_ops( + self, + src_name: str, + src_type: str, + dst_name: str, + dst_type: str, + transfer: _DmaTransferConfig, + *, + indent: int, + ) -> list[str]: + c0_i64 = self._materialize_constant(0, _I64_TYPE) + c1_i64 = self._materialize_constant(1, _I64_TYPE) + false_bit = self._materialize_constant(False, _I1_TYPE) + return [ + self._indent(indent) + + f"pto.set_loop2_stride_outtoub {transfer.loop_src_stride.name}, {transfer.loop_dst_stride.name} : i64, i64", + self._indent(indent) + + f"pto.set_loop1_stride_outtoub {transfer.loop_src_stride.name}, {transfer.loop_dst_stride.name} : i64, i64", + self._indent(indent) + + f"pto.set_loop_size_outtoub {c1_i64}, {c1_i64} : i64, i64", + self._indent(indent) + + "pto.copy_gm_to_ubuf " + + f"{src_name}, {dst_name}, {c0_i64}, {transfer.n_burst.name}, {transfer.len_burst.name}, {c0_i64}, {c0_i64}, " + + f"{false_bit}, {c0_i64}, {transfer.copy_src_stride.name}, {transfer.copy_dst_stride.name} : " + + f"{src_type}, {dst_type}, " + + "i64, i64, i64, i64, i64, i1, i64, i64, i64", + ] + + def _render_dma_load_prefill( + self, + tile_expr: SemanticExpr, + tile_value: _RenderedValue, + env: dict[str, _RenderedValue], + profile: _DmaLoadPaddingProfile, + *, + indent: int, + ) -> list[str]: + fill_bands = profile.left_padding > 0 or profile.right_padding > 0 + if profile.pad_mode_name == "PadNull" and not profile.init_out_buffer: + return [] + if profile.pad_mode_name in {"PadValue", "PadFirstElem"} and not (profile.init_out_buffer or fill_bands): + return [] + + lines: list[str] = [] + tile_memref = self._materialize_tile_memref(tile_value, indent=indent, into=lines) + rows_upper = self._materialize_tile_window_extent( + tile_expr, + tile_value, + axis=0, + indent=indent, + into=lines, + ) + cols_upper = self._materialize_tile_window_extent( + tile_expr, + tile_value, + axis=1, + indent=indent, + into=lines, + ) + fill_vec = self._materialize_dma_load_prefill_vector( + tile_memref, + tile_value.type.element_dtype, + env, + profile, + indent=indent, + into=lines, + ) + + windows: list[tuple[_RenderedValue, _RenderedValue]] = [] + c0_index = _RenderedValue( + name=self._materialize_constant(0, SemanticIndexType()), + type=SemanticIndexType(), + ) + if profile.init_out_buffer: + windows.append((c0_index, cols_upper)) + else: + if profile.left_padding > 0: + windows.append( + ( + c0_index, + _RenderedValue( + name=self._materialize_constant(profile.left_padding, SemanticIndexType()), + type=SemanticIndexType(), + ), + ) + ) + if profile.right_padding > 0: + right_width = _RenderedValue( + name=self._materialize_constant(profile.right_padding, SemanticIndexType()), + type=SemanticIndexType(), + ) + right_start = self._emit_binary_value( + "sub", + cols_upper, + right_width, + SemanticIndexType(), + indent=indent, + into=lines, + ) + windows.append((right_start, cols_upper)) + + if not windows: + return [] + lines.extend( + self._render_tile_fill_windows( + tile_memref, + tile_value.type.element_dtype, + fill_vec, + rows_upper, + windows, + indent=indent, + ) + ) + return lines + + def _render_tile_fill_windows( + self, + tile_memref: _RenderedValue, + element_dtype: ScalarType, + fill_vec: _RenderedValue, + rows_upper: _RenderedValue, + windows: list[tuple[_RenderedValue, _RenderedValue]], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + c0 = self._materialize_constant(0, SemanticIndexType()) + c1 = self._materialize_constant(1, SemanticIndexType()) + vector_step = self._materialize_constant(get_lanes(element_dtype), SemanticIndexType()) + mask_type = SemanticMaskType(granularity=self._mask_granularity_for_dtype(element_dtype)) + lines.append(self._indent(indent) + "pto.vecscope {") + for start, stop in windows: + row_iv = self._new_temp() + lines.append( + self._indent(indent + 2) + + f"scf.for {row_iv} = {c0} to {rows_upper.name} step {c1} {{" + ) + col_iv = self._new_temp() + lines.append( + self._indent(indent + 4) + + f"scf.for {col_iv} = {start.name} to {stop.name} step {vector_step} {{" + ) + remaining = self._emit_binary_value( + "sub", + stop, + _RenderedValue(name=col_iv, type=SemanticIndexType()), + SemanticIndexType(), + indent=indent + 6, + into=lines, + ) + remaining_i32 = self._coerce_rendered_value( + remaining, + _I32_TYPE, + indent=indent + 6, + into=lines, + ) + mask_name = self._new_temp() + next_name = self._new_temp() + lines.append( + self._indent(indent + 6) + + f"{mask_name}, {next_name} = pto.plt_{mask_type.granularity} {remaining_i32.name} : " + + f"i32 -> {self._render_type(mask_type)}, i32" + ) + lines.append( + self._indent(indent + 6) + + f"pto.vsts {fill_vec.name}, {tile_memref.name}[{row_iv}, {col_iv}], {mask_name} : " + + f"{self._render_type(fill_vec.type)}, {self._render_type(tile_memref.type)}, {self._render_type(mask_type)}" + ) + lines.append(self._indent(indent + 4) + "}") + lines.append(self._indent(indent + 2) + "}") + lines.append(self._indent(indent) + "}") + return lines + + def _materialize_dma_load_prefill_vector( + self, + tile_memref: _RenderedValue, + element_dtype: ScalarType, + env: dict[str, _RenderedValue], + profile: _DmaLoadPaddingProfile, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + vec_type = SemanticVRegType(element_dtype=element_dtype, lanes=get_lanes(element_dtype)) + result_name = self._new_temp() + if profile.pad_mode_name == "PadValue": + scalar = self._materialize_dma_pad_value_scalar( + profile.pad_value, + element_dtype, + env, + indent=indent, + into=into, + ) + into.append( + self._indent(indent) + + f"{result_name} = pto.vbr {scalar.name} : {self._render_type(scalar.type)} -> {self._render_type(vec_type)}" + ) + return _RenderedValue(name=result_name, type=vec_type) + if profile.pad_mode_name == "PadFirstElem": + c0 = self._materialize_constant(0, SemanticIndexType()) + first_col = self._materialize_constant(profile.left_padding, SemanticIndexType()) + into.append( + self._indent(indent) + + f'{result_name} = pto.vlds {tile_memref.name}[{c0}, {first_col}] {{dist = "{self._broadcast_dist_for_dtype(element_dtype)}"}} : ' + + f"{self._render_type(tile_memref.type)} -> {self._render_type(vec_type)}" + ) + return _RenderedValue(name=result_name, type=vec_type) + raise NotImplementedError( + f"pto.dma_load lowering does not produce a prefill vector for pad_mode `{profile.pad_mode_name}`" + ) + + def _materialize_dma_pad_value_scalar( + self, + expr: SemanticExpr | None, + element_dtype: ScalarType, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + scalar_type = SemanticScalarType(dtype=element_dtype) + static_value = self._static_expr_value(expr) + if isinstance(static_value, (int, float)): + return _RenderedValue( + name=self._materialize_constant(static_value, scalar_type), + type=scalar_type, + ) + if expr is None: + raise NotImplementedError("pto.dma_load PadValue lowering requires a concrete `pad_value` expression") + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticScalarType) and value.type.dtype == element_dtype: + return value + raise NotImplementedError( + "pto.dma_load PadValue lowering currently expects `pad_value` to be a compile-time numeric literal " + "or a scalar value whose dtype matches the destination Tile element dtype" + ) + + def _materialize_tile_window_extent( + self, + tile_expr: SemanticExpr, + tile_value: _RenderedValue, + *, + axis: int, + indent: int, + into: list[str], + ) -> _RenderedValue: + if ( + isinstance(tile_expr, SemanticBindingRef) + and isinstance(tile_expr.type, SemanticTileType) + and tile_expr.type.valid_shape is not None + and tile_expr.type.valid_shape[axis] is None + ): + return self._materialize_tile_valid_dim( + tile_expr.binding, + axis, + indent=indent, + into=into, + ) + if not isinstance(tile_value.type, SemanticTileType): + raise NotImplementedError("DMA load prefill expects a Tile destination") + valid_shape = tile_value.type.valid_shape or tile_value.type.shape + if valid_shape is None: + raise NotImplementedError("DMA load prefill expects a statically known Tile shape or valid_shape") + extent = valid_shape[axis] + if extent is None: + raise NotImplementedError("DMA load prefill does not support dynamic Tile valid_shape on non-binding values") + return _RenderedValue( + name=self._materialize_constant(extent, SemanticIndexType()), + type=SemanticIndexType(), + ) + + def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: + if dtype.name in {"f32", "i32"}: + return "b32" + if dtype.name in {"f16", "bf16", "i16"}: + return "b16" + if dtype.name == "i8": + return "b8" + raise NotImplementedError(f"dtype `{dtype.name}` is not supported by DMA load prefill lowering") + + def _broadcast_dist_for_dtype(self, dtype: ScalarType) -> str: + if dtype.name in {"f32", "i32"}: + return "BRC_B32" + if dtype.name in {"f16", "bf16", "i16"}: + return "BRC_B16" + if dtype.name == "i8": + return "BRC_B8" + raise NotImplementedError(f"dtype `{dtype.name}` is not supported by DMA load broadcast lowering") + + def _materialize_tile_window_ptr( + self, + tile_value: _RenderedValue, + *, + col_offset: int, + indent: int, + into: list[str], + ) -> tuple[str, str]: + base_ptr_name, base_ptr_type = self._materialize_copy_buffer_ptr( + tile_value, + indent=indent, + into=into, + ) + if col_offset == 0: + return base_ptr_name, base_ptr_type + byte_ptr_type = "!pto.ptr" + byte_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{byte_ptr_name} = pto.castptr {base_ptr_name} : {base_ptr_type} -> {byte_ptr_type}" + ) + offset_bytes = self._materialize_constant( + col_offset * self._dtype_byte_width(tile_value.type.element_dtype), + SemanticIndexType(), + ) + offset_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{offset_ptr_name} = pto.addptr {byte_ptr_name}, {offset_bytes} : {byte_ptr_type} -> {byte_ptr_type}" + ) + typed_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{typed_ptr_name} = pto.castptr {offset_ptr_name} : {byte_ptr_type} -> {base_ptr_type}" + ) + return typed_ptr_name, base_ptr_type + + def _materialize_tensor_slice_ptr( + self, + slice_expr: SemanticTensorSliceExpr, + tensor_base: _RenderedValue, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> tuple[str, str]: + base_ptr_name, base_ptr_type = self._materialize_copy_buffer_ptr( + tensor_base, + indent=indent, + into=into, + ) + if self._is_zero_index_expr(slice_expr.slices[0].start) and self._is_zero_index_expr(slice_expr.slices[1].start): + return base_ptr_name, base_ptr_type + + byte_ptr_type = "!pto.ptr" + byte_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{byte_ptr_name} = pto.castptr {base_ptr_name} : {base_ptr_type} -> {byte_ptr_type}" + ) + offset = self._materialize_tensor_slice_offset_bytes( + slice_expr, + tensor_base, + env, + indent=indent, + into=into, + ) + offset_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{offset_ptr_name} = pto.addptr {byte_ptr_name}, {offset.name} : " + + f"{byte_ptr_type} -> {byte_ptr_type}" + ) + typed_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{typed_ptr_name} = pto.castptr {offset_ptr_name} : {byte_ptr_type} -> {base_ptr_type}" + ) + return typed_ptr_name, base_ptr_type + + def _materialize_tensor_slice_offset_bytes( + self, + slice_expr: SemanticTensorSliceExpr, + tensor_base: _RenderedValue, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + offset_elems = _RenderedValue( + name=self._materialize_constant(0, SemanticIndexType()), + type=SemanticIndexType(), + ) + for axis_index, slice_axis in enumerate(slice_expr.slices): + axis_start = self._lower_expr(slice_axis.start, env, indent=indent, into=into) + axis_stride_elems = self._materialize_tensor_axis_stride_elems( + tensor_base, + axis=slice_expr.type.physical_axes[axis_index], + indent=indent, + into=into, + ) + axis_offset_elems = self._emit_binary_value( + "mul", + axis_start, + axis_stride_elems, + SemanticIndexType(), + indent=indent, + into=into, + ) + offset_elems = self._emit_binary_value( + "add", + offset_elems, + axis_offset_elems, + SemanticIndexType(), + indent=indent, + into=into, + ) + return self._emit_binary_value( + "mul", + offset_elems, + _RenderedValue( + name=self._materialize_constant( + self._dtype_byte_width(slice_expr.type.element_dtype), + SemanticIndexType(), + ), + type=SemanticIndexType(), + ), + SemanticIndexType(), + indent=indent, + into=into, + ) + + def _is_zero_index_expr(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticLiteralExpr): + return isinstance(expr.value, int) and expr.value == 0 + if isinstance(expr, SemanticBindingRef): + return isinstance(expr.binding.value, int) and expr.binding.value == 0 + return False + + def _materialize_tensor_dim( + self, + tensor_base: _RenderedValue, + *, + axis: int, + indent: int, + into: list[str], + ) -> _RenderedValue: + dim_index = self._new_temp() + axis_value = self._materialize_constant(axis, SemanticIndexType()) + if isinstance(tensor_base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + into.append( + self._indent(indent) + + f"{dim_index} = pto.get_tensor_view_dim {tensor_base.name}, {axis_value} : " + + f"{self._render_type(tensor_base.type)} -> index" + ) + else: + into.append( + self._indent(indent) + + f"{dim_index} = memref.dim {tensor_base.name}, {axis_value} : {self._render_type(tensor_base.type)}" + ) + return _RenderedValue(name=dim_index, type=SemanticIndexType()) + + def _materialize_dma_axis_extent( + self, + slice_expr: SemanticTensorSliceExpr, + axis: int, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + axis_slice = slice_expr.slices[axis] + if axis_slice.extent is not None: + return _RenderedValue( + name=self._materialize_constant(axis_slice.extent, _I64_TYPE), + type=_I64_TYPE, + ) + + distance_expr = SemanticBinaryExpr( + lhs=axis_slice.stop, + op="sub", + rhs=axis_slice.start, + type=SemanticIndexType(), + ) + extent_expr: SemanticExpr = distance_expr + step_value = self._static_expr_value(axis_slice.step) + if not isinstance(step_value, int): + raise NotImplementedError("DMA lowering currently expects a static slice step") + if step_value != 1: + extent_expr = SemanticBinaryExpr( + lhs=SemanticBinaryExpr( + lhs=distance_expr, + op="add", + rhs=SemanticLiteralExpr(value=step_value - 1, type=SemanticIndexType()), + type=SemanticIndexType(), + ), + op="floordiv", + rhs=SemanticLiteralExpr(value=step_value, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + return self._lower_to_i64(extent_expr, env, indent=indent, into=into) + + def _materialize_dma_row_step( + self, + slice_expr: SemanticTensorSliceExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + return self._lower_to_i64(slice_expr.slices[0].step, env, indent=indent, into=into) + + def _materialize_tensor_axis_stride_elems( + self, + tensor_base: _RenderedValue, + axis: int, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + stride = _RenderedValue( + name=self._materialize_constant(1, SemanticIndexType()), + type=SemanticIndexType(), + ) + for dim_axis in range(axis + 1, tensor_base.type.rank): + dim_value = self._materialize_tensor_dim( + tensor_base, + axis=dim_axis, + indent=indent, + into=into, + ) + stride = self._emit_binary_value( + "mul", + stride, + dim_value, + SemanticIndexType(), + indent=indent, + into=into, + ) + return stride + + def _materialize_tensor_row_stride_bytes( + self, + slice_expr: SemanticTensorSliceExpr, + tensor_base: _RenderedValue, + element_bytes: int, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + stride_elems = self._materialize_tensor_axis_stride_elems( + tensor_base, + axis=slice_expr.type.physical_axes[0], + indent=indent, + into=into, + ) + dim_bytes = self._emit_binary_value( + "mul", + stride_elems, + _RenderedValue( + name=self._materialize_constant(element_bytes, SemanticIndexType()), + type=SemanticIndexType(), + ), + SemanticIndexType(), + indent=indent, + into=into, + ) + return self._coerce_rendered_to_i64(dim_bytes, indent=indent, into=into) + + def _materialize_tile_row_stride_bytes( + self, + tile_type: SemanticTileType, + element_bytes: int, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if tile_type.shape is None or len(tile_type.shape) != 2: + raise NotImplementedError("DMA lowering requires a statically specialized rank-2 Tile shape") + row_bytes = tile_type.shape[1] * element_bytes + return _RenderedValue( + name=self._materialize_constant(row_bytes, _I64_TYPE), + type=_I64_TYPE, + ) + + def _materialize_dma_len_burst( + self, + col_count: _RenderedValue, + element_bytes: int, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + return self._emit_binary_value( + "mul", + col_count, + _RenderedValue( + name=self._materialize_constant(element_bytes, _I64_TYPE), + type=_I64_TYPE, + ), + _I64_TYPE, + indent=indent, + into=into, + ) + def _dma_transfer_extents( self, slice_expr: SemanticTensorSliceExpr, @@ -475,6 +1659,24 @@ def _dma_transfer_extents( raise NotImplementedError("DMA lowering requires a statically specialized rank-2 Tile shape") return tile_type.shape + def _emit_binary_value( + self, + op: str, + lhs: _RenderedValue, + rhs: _RenderedValue, + result_type: SemanticType, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + result_name = self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = {self._render_binary_op(op, result_type)} " + f"{lhs.name}, {rhs.name} : {self._render_type(result_type)}" + ) + return _RenderedValue(name=result_name, type=result_type) + def _render_strict_vecscope( self, stmt: SemanticStrictVecscopeStmt, @@ -483,22 +1685,30 @@ def _render_strict_vecscope( indent: int, ) -> list[str]: lines: list[str] = [] - capture_values = [ - self._lower_expr(expr, env, indent=indent, into=lines) - for expr in stmt.captures - ] + capture_values = [] + block_argument_values = [] + for expr, binding in zip(stmt.captures, stmt.block_arguments): + capture = self._lower_expr(expr, env, indent=indent, into=lines) + capture, block_arg = self._materialize_strict_vecscope_capture( + capture, + binding, + indent=indent, + into=lines, + ) + capture_values.append(capture) + block_argument_values.append(block_arg) capture_names = ", ".join(value.name for value in capture_values) block_args = ", ".join( - f"{binding.ssa_name}: {self._render_type(binding.type)}" - for binding in stmt.block_arguments + f"{binding.ssa_name}: {self._render_type(value.type)}" + for binding, value in zip(stmt.block_arguments, block_argument_values) ) function_type = ", ".join( - self._render_type(binding.type) for binding in stmt.block_arguments + self._render_type(value.type) for value in block_argument_values ) scope_env = { - binding.name: _RenderedValue(name=binding.ssa_name, type=binding.type) - for binding in stmt.block_arguments + binding.name: _RenderedValue(name=binding.ssa_name, type=value.type) + for binding, value in zip(stmt.block_arguments, block_argument_values) } lines.append(self._indent(indent) + f"pto.strict_vecscope({capture_names}) {{") @@ -548,38 +1758,102 @@ def _render_for( lines.append(self._indent(indent) + "}") return lines - if len(stmt.loop_carried) != 1: - raise NotImplementedError( - "TileLang DSL v1 lowering currently supports at most one loop-carried binding" + carried_bindings = tuple(stmt.loop_carried) + if len(carried_bindings) == 1: + carried_binding = carried_bindings[0] + initial_value = self._coerce_rendered_value( + env[carried_binding.name], + carried_binding.type, + indent=indent, + into=lines, + ) + iter_arg_name = f"%{carried_binding.name}_iter_{self._loop_counter}" + self._loop_counter += 1 + body_env[carried_binding.name] = _RenderedValue( + name=iter_arg_name, + type=carried_binding.type, + ) + + lines.append( + self._indent(indent) + + f"{carried_binding.ssa_name}:1 = scf.for {stmt.induction_variable.ssa_name} = " + f"{lower_bound.name} to {upper_bound.name} step {step.name} " + f"iter_args({iter_arg_name} = {initial_value.name}) -> " + f"({self._render_type(carried_binding.type)}) {{" ) + lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) + yielded_value = self._coerce_rendered_value( + body_env[carried_binding.name], + carried_binding.type, + indent=indent + 2, + into=lines, + ) + lines.append( + self._indent(indent + 2) + + f"scf.yield {yielded_value.name} : {self._render_type(yielded_value.type)}" + ) + lines.append(self._indent(indent) + "}") + env[carried_binding.name] = _RenderedValue( + name=carried_binding.ssa_name, + type=carried_binding.type, + ) + return lines - carried_binding = stmt.loop_carried[0] - initial_value = env[carried_binding.name] - iter_arg_name = f"%{carried_binding.name}_iter_{self._loop_counter}" + loop_id = self._loop_counter self._loop_counter += 1 - body_env[carried_binding.name] = _RenderedValue( - name=iter_arg_name, - type=carried_binding.type, + + initial_values: list[_RenderedValue] = [] + iter_arg_names: list[str] = [] + for index, binding in enumerate(carried_bindings): + initial_values.append( + self._coerce_rendered_value( + env[binding.name], + binding.type, + indent=indent, + into=lines, + ) + ) + iter_arg_names.append(f"%{binding.name}_iter_{loop_id}_{index}") + body_env[binding.name] = _RenderedValue( + name=iter_arg_names[-1], + type=binding.type, + ) + + result_names = ", ".join(binding.ssa_name for binding in carried_bindings) + iter_args = ", ".join( + f"{iter_name} = {initial.name}" + for iter_name, initial in zip(iter_arg_names, initial_values) ) + result_types = ", ".join(self._render_type(binding.type) for binding in carried_bindings) lines.append( self._indent(indent) - + f"{carried_binding.ssa_name}:1 = scf.for {stmt.induction_variable.ssa_name} = " + + f"{result_names} = scf.for {stmt.induction_variable.ssa_name} = " f"{lower_bound.name} to {upper_bound.name} step {step.name} " - f"iter_args({iter_arg_name} = {initial_value.name}) -> " - f"({self._render_type(carried_binding.type)}) {{" + f"iter_args({iter_args}) -> ({result_types}) {{" ) lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) - yielded_value = body_env[carried_binding.name] + yielded_values = [ + self._coerce_rendered_value( + body_env[binding.name], + binding.type, + indent=indent + 2, + into=lines, + ) + for binding in carried_bindings + ] + yielded_names = ", ".join(value.name for value in yielded_values) + yielded_types = ", ".join(self._render_type(value.type) for value in yielded_values) lines.append( self._indent(indent + 2) - + f"scf.yield {yielded_value.name} : {self._render_type(yielded_value.type)}" + + f"scf.yield {yielded_names} : {yielded_types}" ) lines.append(self._indent(indent) + "}") - env[carried_binding.name] = _RenderedValue( - name=carried_binding.ssa_name, - type=carried_binding.type, - ) + for binding in carried_bindings: + env[binding.name] = _RenderedValue( + name=binding.ssa_name, + type=binding.type, + ) return lines def _render_if( @@ -708,8 +1982,27 @@ def _lower_expr( if isinstance(expr, SemanticBinaryExpr): if into is None: into = [] + if expr.op in {"and", "or"}: + return self._lower_bool_expr( + expr.op, + expr.lhs, + expr.rhs, + env, + indent=indent, + desired_name=desired_name, + into=into, + ) lhs = self._lower_expr(expr.lhs, env, indent=indent, into=into) rhs = self._lower_expr(expr.rhs, env, indent=indent, into=into) + if expr.op in {"eq", "ne", "gt", "lt", "ge", "le"}: + return self._lower_compare_expr( + expr.op, + lhs, + rhs, + indent=indent, + desired_name=desired_name, + into=into, + ) result_name = desired_name or self._new_temp() into.append( self._indent(indent) @@ -736,6 +2029,32 @@ def _lower_call_expr( desired_name: str | None, into: list[str] | None, ) -> _RenderedValue: + if expr.namespace is None: + if into is None: + into = [] + rendered_args = [ + self._lower_expr(arg, env, indent=indent, into=into) + for arg in expr.args + ] + rendered_arg_names = ", ".join(arg.name for arg in rendered_args) + rendered_arg_types = ", ".join(self._render_type(arg.type) for arg in rendered_args) + if not rendered_arg_types: + rendered_arg_types = "" + if expr.type is None: + into.append( + self._indent(indent) + + f"func.call {_format_symbol_name(expr.name)}({rendered_arg_names}) : " + + f"({rendered_arg_types}) -> ()" + ) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + result_name = desired_name or self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = func.call {_format_symbol_name(expr.name)}({rendered_arg_names}) : " + + f"({rendered_arg_types}) -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.namespace != "pto": raise NotImplementedError(f"unsupported call namespace {expr.namespace!r}") if isinstance(expr.type, SemanticTupleType): @@ -759,10 +2078,74 @@ def _lower_call_expr( if expr.name == "vlds": source = self._lower_expr(expr.args[0], env, indent=indent, into=into) - offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_memref(source, indent=indent, into=into) + if ( + isinstance(expr.args[0].type, SemanticTileType) + and expr.args[0].type.rank == 2 + and len(expr.args[1:]) == 2 + ): + source = self._materialize_rank2_tile_subview( + source, + expr.args[0].type, + expr.args[1:], + env, + indent=indent, + into=into, + ) + rendered_indices = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_indices = self._render_index_list(expr.args[1:], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vlds {source.name}[{rendered_indices}] : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vbr": + scalar = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vbr {scalar.name} : " + + f"{self._render_type(scalar.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vdup": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + position = self._render_string_literal(expr.args[1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vdup {value.name} {{position = {position}}} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vci": + index = self._lower_expr(expr.args[0], env, indent=indent, into=into) + order = self._render_string_literal(expr.args[1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vci {index.name} {{order = {order}}} : " + + f"{self._render_type(index.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "tensor_view_as_ptr": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.tensor_view_addr {source.name} : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "tile_as_ptr": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) into.append( self._indent(indent) - + f"{result_name} = pto.vlds {source.name}[{offset.name}] : " + + f"{result_name} = pto.tile_buf_addr {source.name} : " + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) @@ -778,6 +2161,10 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name in {"i1", "i8", "i16", "i32", "i64", "f16", "bf16", "f32"}: + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + return self._coerce_rendered_value(value, expr.type, indent=indent, into=into) + if expr.name == "addptr": pointer = self._lower_expr(expr.args[0], env, indent=indent, into=into) offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) @@ -879,7 +2266,60 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name in {"vabs", "vrelu", "vexp", "vnot"}: + if expr.name == "vcvt": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + target_dtype = self._render_dtype_symbol(expr.args[1], context="pto.vcvt to_type") + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vcvt {value.name}, {target_dtype}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vmrgsort4": + vec0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + vec1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + vec2 = self._lower_expr(expr.args[2], env, indent=indent, into=into) + vec3 = self._lower_expr(expr.args[3], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[4], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vmrgsort4 {vec0.name}, {vec1.name}, {vec2.name}, {vec3.name}, {mask.name} : " + + f"{self._render_type(vec0.type)}, {self._render_type(vec1.type)}, {self._render_type(vec2.type)}, " + + f"{self._render_type(vec3.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in { + "vabs", + "vrelu", + "vexp", + "vln", + "vsqrt", + "vrec", + "vnot", + "vcadd", + "vcmax", + "vbcnt", + "vneg", + "vcls", + "vcmin", + "vrsqrt", + "vmov", + "vsunpack", + "vzunpack", + "vusqz", + "vsqz", + "vexpdiff", + "vtrc", + "vbitsort", + "vcgadd", + "vcgmax", + "vcgmin", + "vcpadd", + "vsort32", + }: value = self._lower_expr(expr.args[0], env, indent=indent, into=into) mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) into.append( @@ -889,7 +2329,27 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name in {"vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", "vand", "vor", "vxor"}: + if expr.name in { + "vadd", + "vsub", + "vmul", + "vdiv", + "vmax", + "vmin", + "vand", + "vor", + "vxor", + "vaddrelu", + "vaddreluconv", + "vsubrelu", + "vmulconv", + "vshl", + "vshr", + "vprelu", + "vpack", + "vperm", + "vmrgsort", + }: lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) @@ -901,7 +2361,19 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name in {"vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins"}: + if expr.name in {"vshift", "vslide"}: + vector = self._lower_expr(expr.args[0], env, indent=indent, into=into) + immediate = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {vector.name}, {immediate.name}, {mask.name} : " + + f"{self._render_type(vector.type)}, {self._render_type(immediate.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins", "vlrelu", "vshls", "vshrs", "vands", "vors", "vxors"}: value = self._lower_expr(expr.args[0], env, indent=indent, into=into) scalar = self._lower_expr(expr.args[1], env, indent=indent, into=into) mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) @@ -913,8 +2385,103 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name in {"vaxpy", "vmula"}: + vec0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + vec1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + vec2 = self._lower_expr(expr.args[2], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[3], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {vec0.name}, {vec1.name}, {vec2.name}, {mask.name} : " + + f"{self._render_type(vec0.type)}, {self._render_type(vec1.type)}, {self._render_type(vec2.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + raise NotImplementedError(f"unsupported pto call `{expr.name}` in lowering") + def _lower_compare_expr( + self, + op: str, + lhs: _RenderedValue, + rhs: _RenderedValue, + *, + indent: int, + desired_name: str | None, + into: list[str], + ) -> _RenderedValue: + result_name = desired_name or self._new_temp() + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + index_predicates = { + "eq": "eq", + "ne": "ne", + "gt": "sgt", + "lt": "slt", + "ge": "sge", + "le": "sle", + } + predicate = index_predicates[op] + elif isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: + if lhs.type.dtype.name in {"f16", "bf16", "f32"}: + float_predicates = { + "eq": "oeq", + "ne": "une", + "gt": "ogt", + "lt": "olt", + "ge": "oge", + "le": "ole", + } + predicate = float_predicates[op] + cmp_name = "arith.cmpf" + else: + int_predicates = { + "eq": "eq", + "ne": "ne", + "gt": "sgt", + "lt": "slt", + "ge": "sge", + "le": "sle", + } + predicate = int_predicates[op] + cmp_name = "arith.cmpi" + into.append( + self._indent(indent) + + f"{result_name} = {cmp_name} {predicate}, {lhs.name}, {rhs.name} : " + f"{self._render_type(lhs.type)}" + ) + return _RenderedValue(name=result_name, type=_I1_TYPE) + else: + raise NotImplementedError( + f"comparison lowering requires matching scalar types or index operands, got {lhs.type!r} and {rhs.type!r}" + ) + + into.append( + self._indent(indent) + + f"{result_name} = arith.cmpi {predicate}, {lhs.name}, {rhs.name} : {self._render_type(lhs.type)}" + ) + return _RenderedValue(name=result_name, type=_I1_TYPE) + + def _lower_bool_expr( + self, + op: str, + lhs_expr: SemanticExpr, + rhs_expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None, + into: list[str], + ) -> _RenderedValue: + lhs = self._lower_condition(lhs_expr, env, indent=indent, into=into) + rhs = self._lower_condition(rhs_expr, env, indent=indent, into=into) + result_name = desired_name or self._new_temp() + arith_op = "arith.andi" if op == "and" else "arith.ori" + into.append( + self._indent(indent) + + f"{result_name} = {arith_op} {lhs.name}, {rhs.name} : i1" + ) + return _RenderedValue(name=result_name, type=_I1_TYPE) + def _render_string_literal(self, expr: SemanticExpr) -> str: if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.value, str): escaped = expr.value.replace("\\", "\\\\").replace('"', '\\"') @@ -924,6 +2491,18 @@ def _render_string_literal(self, expr: SemanticExpr) -> str: return f'"{escaped}"' raise NotImplementedError("expected a string literal for TileLang DSL advanced-family lowering") + def _render_dtype_symbol(self, expr: SemanticExpr, *, context: str) -> str: + if isinstance(expr, SemanticSymbolExpr) and isinstance(expr.value, ScalarType): + return expr.value.name + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and isinstance(expr.binding.value, ScalarType) + ): + return expr.binding.value.name + raise NotImplementedError(f"{context} expects a dtype symbol in TileLang DSL v1 lowering") + def _lower_to_i1( self, expr: SemanticExpr, @@ -986,6 +2565,125 @@ def _lower_remaining_to_i32( return _RenderedValue(name=cast_name, type=_I32_TYPE) raise NotImplementedError("tail make_mask lowering expects an i32 or index remaining operand") + def _materialize_copy_buffer_ptr( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> tuple[str, str]: + ptr_type = self._render_copy_buffer_type(value.type) + cache_key = (value.name, ptr_type) + existing = self._castptr_cache.get(cache_key) + if existing is not None: + return existing, ptr_type + + if isinstance(value.type, SemanticTileType): + value = self._materialize_tile_memref(value, indent=indent, into=into) + + if self._is_memref_like_type(value.type): + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = pto.castptr {value.name} : {self._render_type(value.type)} -> {ptr_type}" + ) + self._castptr_cache[cache_key] = cast_name + return cast_name, ptr_type + + return value.name, ptr_type + + def _coerce_rendered_value( + self, + value: _RenderedValue, + target_type: SemanticType, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if type(value.type) is type(target_type) and value.type == target_type: + return value + if isinstance(value.type, SemanticIndexType) and isinstance(target_type, SemanticScalarType): + if target_type.dtype.name == "i32": + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.index_cast {value.name} : index to i32" + ) + return _RenderedValue(name=cast_name, type=target_type) + if target_type.dtype.name == "i64": + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.index_castui {value.name} : index to i64" + ) + return _RenderedValue(name=cast_name, type=target_type) + if target_type.dtype.name in {"f16", "bf16", "f32"}: + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.uitofp {value.name} : index to {target_type.dtype.name}" + ) + return _RenderedValue(name=cast_name, type=target_type) + if isinstance(value.type, SemanticScalarType) and isinstance(target_type, SemanticScalarType): + src = value.type.dtype.name + dst = target_type.dtype.name + if src == dst: + return value + cast_name = self._new_temp() + if src.startswith("i") and dst.startswith("i"): + src_bits = int(src[1:]) + dst_bits = int(dst[1:]) + op = "arith.extsi" if src_bits < dst_bits else "arith.trunci" + into.append( + self._indent(indent) + + f"{cast_name} = {op} {value.name} : {src} to {dst}" + ) + return _RenderedValue(name=cast_name, type=target_type) + if src.startswith("i") and dst in {"f16", "bf16", "f32"}: + into.append( + self._indent(indent) + + f"{cast_name} = arith.sitofp {value.name} : {src} to {dst}" + ) + return _RenderedValue(name=cast_name, type=target_type) + if src in {"f16", "bf16", "f32"} and dst.startswith("i"): + into.append( + self._indent(indent) + + f"{cast_name} = arith.fptosi {value.name} : {src} to {dst}" + ) + return _RenderedValue(name=cast_name, type=target_type) + if src in {"f16", "bf16", "f32"} and dst in {"f16", "bf16", "f32"}: + op = "arith.extf" if src in {"f16", "bf16"} and dst == "f32" else "arith.truncf" + into.append( + self._indent(indent) + + f"{cast_name} = {op} {value.name} : {src} to {dst}" + ) + return _RenderedValue(name=cast_name, type=target_type) + raise NotImplementedError( + f"unsupported value coercion from {value.type!r} to {target_type!r} in TileLang DSL v1 lowering" + ) + + def _materialize_strict_vecscope_capture( + self, + capture: _RenderedValue, + binding: SemanticBinding, + *, + indent: int, + into: list[str], + ) -> tuple[_RenderedValue, _RenderedValue]: + if not self._is_memref_like_type(capture.type): + return capture, _RenderedValue(name=binding.ssa_name, type=binding.type) + + ptr_name, ptr_type = self._materialize_copy_buffer_ptr( + capture, + indent=indent, + into=into, + ) + rendered_ptr_type = _RenderedTextualType(ptr_type) + return ( + _RenderedValue(name=ptr_name, type=rendered_ptr_type), + _RenderedValue(name=binding.ssa_name, type=rendered_ptr_type), + ) + def _mask_suffix(self, ty: SemanticType) -> str: if not isinstance(ty, SemanticMaskType): raise NotImplementedError("tail make_mask lowering expects a mask result type") @@ -1011,6 +2709,51 @@ def _lower_subscript_access( desired_name: str | None, into: list[str] | None, ) -> _RenderedValue: + if ( + into is not None + and isinstance(expr.base, SemanticAttributeAccess) + and expr.base.attr == "valid_shape" + and isinstance(expr.base.base, SemanticBindingRef) + and isinstance(expr.base.base.type, SemanticTileType) + and isinstance(expr.index, SemanticLiteralExpr) + and isinstance(expr.index.value, int) + ): + return self._materialize_tile_valid_dim( + expr.base.base.binding, + expr.index.value, + indent=indent, + into=into, + desired_name=desired_name, + ) + if ( + into is not None + and isinstance(expr.base, SemanticAttributeAccess) + and expr.base.attr in {"shape", "valid_shape", "strides"} + and isinstance(expr.base.base, SemanticBindingRef) + and isinstance( + expr.base.base.type, + (SemanticTensorViewType, SemanticPartitionTensorViewType), + ) + and isinstance(expr.index, SemanticLiteralExpr) + and isinstance(expr.index.value, int) + ): + tensor_value = env.get( + expr.base.base.binding.name, + _RenderedValue(expr.base.base.binding.ssa_name, expr.base.base.type), + ) + result_name = desired_name or self._new_temp() + axis_value = self._materialize_constant(expr.index.value, SemanticIndexType()) + op_name = ( + "pto.get_tensor_view_stride" + if expr.base.attr == "strides" + else "pto.get_tensor_view_dim" + ) + into.append( + self._indent(indent) + + f"{result_name} = {op_name} {tensor_value.name}, {axis_value} : " + + f"{self._render_type(tensor_value.type)} -> index" + ) + return _RenderedValue(name=result_name, type=SemanticIndexType()) value = self._extract_shape_subscript_value(expr, env) if isinstance(value, _RenderedValue): return value @@ -1029,19 +2772,78 @@ def _lower_subscript_access( def _tensor_shape_binding_name(self, tensor_name: str, axis: int) -> str: return f"__shape_{tensor_name}_{axis}" + def _tensor_stride_binding_name(self, tensor_name: str, axis: int) -> str: + return f"__stride_{tensor_name}_{axis}" + + def _materialize_tile_memref( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + existing = self._tile_memref_cache.get(value.name) + if existing is not None: + return existing + if not isinstance(value.type, SemanticTileType): + return value + memref_type = _RenderedTextualType( + self._render_memref_type( + element_dtype=value.type.element_dtype.name, + shape=value.type.shape if value.type.shape is not None else ("?",) * value.type.rank, + memory_space=value.type.memory_space or "ub", + ) + ) + memref_name = self._new_temp() + into.append( + self._indent(indent) + + f"{memref_name} = pto.tile_buf_addr {value.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(memref_type)}" + ) + rendered = _RenderedValue(name=memref_name, type=memref_type) + self._tile_memref_cache[value.name] = rendered + return rendered + + def _materialize_tile_valid_dim( + self, + binding: object, + axis: int, + *, + indent: int, + into: list[str], + desired_name: str | None = None, + ) -> _RenderedValue: + cache_key = (binding.name, axis) + existing = self._tile_valid_dim_cache.get(cache_key) + if existing is not None: + return existing + source = _RenderedValue(name=binding.ssa_name, type=binding.type) + op_name = "pto.tile_valid_rows" if axis == 0 else "pto.tile_valid_cols" + result_name = desired_name or self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = {op_name} {source.name} : " + + f"{self._render_type(source.type)} -> index" + ) + rendered = _RenderedValue(name=result_name, type=SemanticIndexType()) + self._tile_valid_dim_cache[cache_key] = rendered + return rendered + def _extract_shape_subscript_value( self, expr: SemanticSubscriptAccess, env: dict[str, _RenderedValue], ) -> int | _RenderedValue: if not isinstance(expr.base, SemanticAttributeAccess): - raise NotImplementedError("only shape indexing is supported in TileLang DSL v1 lowering") - if expr.base.attr != "shape": - raise NotImplementedError("only `.shape[...]` indexing is supported in TileLang DSL v1 lowering") + raise NotImplementedError("only shape/stride indexing is supported in TileLang DSL v1 lowering") + if expr.base.attr not in {"shape", "valid_shape", "strides"}: + raise NotImplementedError( + "only `.shape[...]`, `.valid_shape[...]`, and `.strides[...]` indexing are supported in TileLang DSL v1 lowering" + ) if not isinstance(expr.index, SemanticLiteralExpr) or not isinstance(expr.index.value, int): - raise NotImplementedError("shape indices must be integer literals in TileLang DSL v1 lowering") + raise NotImplementedError("shape/stride indices must be integer literals in TileLang DSL v1 lowering") if not isinstance(expr.base.base, SemanticBindingRef): - raise NotImplementedError("shape indexing expects a bound TensorView or Tile value") + raise NotImplementedError("shape/stride indexing expects a bound TensorView or Tile value") base_binding = expr.base.base.binding base_value = env.get(base_binding.name, _RenderedValue(base_binding.ssa_name, base_binding.type)) @@ -1049,20 +2851,33 @@ def _extract_shape_subscript_value( index = expr.index.value if isinstance(base_type, SemanticTileType): - if base_type.shape is None: + if expr.base.attr == "shape": + if base_type.shape is None: + raise NotImplementedError("dynamic Tile shapes are not supported in TileLang DSL v1 lowering") + return base_type.shape[index] + if base_type.valid_shape is None: raise NotImplementedError("dynamic Tile shapes are not supported in TileLang DSL v1 lowering") - return base_type.shape[index] - - if isinstance(base_type, SemanticTensorViewType): - hidden_name = self._tensor_shape_binding_name(base_binding.name, index) + valid_dim = base_type.valid_shape[index] + if valid_dim is not None: + return valid_dim + return _RenderedValue(name=base_binding.ssa_name, type=base_type) + + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + if expr.base.attr == "strides": + hidden_name = self._tensor_stride_binding_name(base_binding.name, index) + else: + hidden_name = self._tensor_shape_binding_name(base_binding.name, index) hidden_value = env.get(hidden_name) if hidden_value is None: raise NotImplementedError( - f"missing TensorView shape binding for '{base_binding.name}.shape[{index}]'" + f"missing TensorView/PartitionTensorView {expr.base.attr} binding for '{base_binding.name}.{expr.base.attr}[{index}]'" ) return hidden_value - raise NotImplementedError("shape indexing expects a Tile or TensorView operand") + raise NotImplementedError("shape/stride indexing expects a Tile, TensorView, or PartitionTensorView operand") + + def _format_shape_tuple(self, shape: tuple[int | None, ...]) -> str: + return "(" + ", ".join("?" if dim is None else str(dim) for dim in shape) + ")" def _materialize_constant(self, value: object, ty: SemanticType) -> str: cache_key = (self._render_type(ty), value) @@ -1101,7 +2916,7 @@ def _format_constant(self, value: object, ty: SemanticType) -> str: return str(value) if isinstance(ty, SemanticScalarType): if ty.dtype.name == "i1" and isinstance(value, bool): - return "true" if value else "false" + return "1" if value else "0" return str(value) raise NotImplementedError(f"unsupported constant type {ty!r}") @@ -1118,6 +2933,8 @@ def _render_binary_op(self, op: str, ty: SemanticType) -> str: raise NotImplementedError(f"unsupported binary op '{op}' for type {ty!r}") def _render_type(self, ty: SemanticType) -> str: + if isinstance(ty, _RenderedTextualType): + return ty.text if isinstance(ty, SemanticIndexType): return "index" if isinstance(ty, SemanticScalarType): @@ -1125,16 +2942,100 @@ def _render_type(self, ty: SemanticType) -> str: if isinstance(ty, SemanticPtrType): return f"!pto.ptr<{ty.element_dtype.name}, {ty.memory_space}>" if isinstance(ty, SemanticTensorViewType): - return f"!pto.ptr<{ty.element_dtype.name}, gm>" + return self._render_tensor_view_type( + element_dtype=ty.element_dtype.name, + shape=("?",) * ty.rank, + ) + if isinstance(ty, SemanticPartitionTensorViewType): + return self._render_partition_tensor_view_type( + element_dtype=ty.element_dtype.name, + shape=("?",) * ty.rank, + ) if isinstance(ty, SemanticTileType): - memory_space = ty.memory_space or "ub" - return f"!pto.ptr<{ty.element_dtype.name}, {memory_space}>" + return self._render_tile_buf_type(ty) if isinstance(ty, SemanticMaskType): return f"!pto.mask<{ty.granularity}>" if isinstance(ty, SemanticVRegType): return f"!pto.vreg<{ty.lanes}x{ty.element_dtype.name}>" raise NotImplementedError(f"unsupported semantic type {ty!r}") + def _is_memref_like_type(self, ty: SemanticType) -> bool: + return isinstance(ty, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)) or ( + isinstance(ty, _RenderedTextualType) and ty.text.startswith("memref<") + ) + + def _render_copy_buffer_type(self, ty: SemanticType) -> str: + if isinstance(ty, SemanticPtrType): + return self._render_type(ty) + if isinstance(ty, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + return f"!pto.ptr<{ty.element_dtype.name}, gm>" + if isinstance(ty, SemanticTileType): + memory_space = ty.memory_space or "ub" + return f"!pto.ptr<{ty.element_dtype.name}, {memory_space}>" + return self._render_type(ty) + + def _render_memref_type( + self, + *, + element_dtype: str, + shape: tuple[int | str, ...], + memory_space: str, + ) -> str: + dims = "x".join(str(dim) for dim in shape) + return f"memref<{dims}x{element_dtype}, {self._render_memref_memory_space(memory_space)}>" + + def _render_tensor_view_type( + self, + *, + element_dtype: str, + shape: tuple[int | str, ...], + ) -> str: + dims = "x".join(str(dim) for dim in shape) + return f"!pto.tensor_view<{dims}x{element_dtype}>" + + def _render_partition_tensor_view_type( + self, + *, + element_dtype: str, + shape: tuple[int | str, ...], + ) -> str: + dims = "x".join(str(dim) for dim in shape) + return f"!pto.partition_tensor_view<{dims}x{element_dtype}>" + + def _render_memref_memory_space(self, memory_space: str) -> str: + if memory_space == "gm": + return "#pto.address_space" + if memory_space == "ub": + return "#pto.address_space" + raise NotImplementedError(f"unsupported memref memory space '{memory_space}' in TileLang DSL v1 lowering") + + def _render_tile_buf_type(self, ty: SemanticTileType) -> str: + if ty.shape is None: + raise NotImplementedError("tile_buf lowering requires statically specialized Tile shape") + if ty.rank not in (1, 2): + raise NotImplementedError("tile_buf lowering only supports rank-1 or rank-2 Tile values") + rows = ty.shape[0] + cols = 1 if ty.rank == 1 else ty.shape[1] + valid_shape = ty.valid_shape or ty.shape + v_row = valid_shape[0] + v_col = 1 if ty.rank == 1 else valid_shape[1] + return ( + f"!pto.tile_buf" + ) + + def _render_tile_buf_loc(self, memory_space: str) -> str: + if memory_space == "ub": + return "vec" + if memory_space == "gm": + return "gm" + raise NotImplementedError(f"unsupported tile_buf memory space '{memory_space}'") + + def _render_tile_buf_dim(self, dim: int | None) -> str: + return "?" if dim is None else str(dim) + def _dtype_byte_width(self, dtype: ScalarType) -> int: widths = { "i8": 1, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 497b1557c..d3ca7679f 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -1,8 +1,17 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """Semantic model for TileLang DSL descriptor lowering.""" from __future__ import annotations import ast +import struct from dataclasses import dataclass from typing import Any @@ -16,6 +25,7 @@ FrontendExprStmt, FrontendForStmt, FrontendIfStmt, + FrontendInlineProcNode, FrontendKernelNode, FrontendNameExpr, FrontendNameTarget, @@ -28,20 +38,28 @@ FrontendTargetNode, FrontendTupleExpr, FrontendTupleTarget, + FrontendVecscopeStmt, ) from .support_matrix import ( DEFERRED_PTO_SURFACES, + advanced_mode_message, deferred_surface_message, unsupported_feature_message, ) from .types import ( Event, + MaskType, MaskPattern, MemorySpace, + OrderMode, + PadMode, Pipe, + PositionMode, PointerType, ScalarType, + VRegType, bf16, + bytewidth, f16, f32, i1, @@ -62,13 +80,86 @@ "bf16": bf16, "f32": f32, } +_MASK_TYPE_SYMBOLS = { + "mask_b8": MaskType("b8"), + "mask_b16": MaskType("b16"), + "mask_b32": MaskType("b32"), +} _PATTERN_SYMBOLS = {pattern.name: pattern for pattern in MaskPattern} _PIPE_SYMBOLS = {pipe.name: pipe for pipe in Pipe} _EVENT_SYMBOLS = {event.name: event for event in Event} _MEMORY_SPACE_SYMBOLS = {memory_space.name: memory_space for memory_space in MemorySpace} -_UNARY_VECTOR_OPS = {"vabs", "vrelu", "vexp", "vnot"} -_BINARY_VECTOR_OPS = {"vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", "vand", "vor", "vxor"} -_VECTOR_SCALAR_OPS = {"vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins"} +_PAD_MODE_SYMBOLS = {pad_mode.name: pad_mode for pad_mode in PadMode} +_POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} +_ORDER_MODE_SYMBOLS = {order_mode.name: order_mode for order_mode in OrderMode} +_UNARY_VECTOR_OPS = { + "vabs", + "vrelu", + "vexp", + "vln", + "vsqrt", + "vrec", + "vnot", + "vcadd", + "vcmax", + "vbcnt", + "vneg", + "vcls", + "vcmin", + "vrsqrt", + "vmov", + "vsunpack", + "vzunpack", + "vusqz", + "vsqz", + "vexpdiff", + "vtrc", + "vbitsort", + "vcgadd", + "vcgmax", + "vcgmin", + "vcpadd", + "vsort32", +} +_BINARY_VECTOR_OPS = { + "vadd", + "vsub", + "vmul", + "vdiv", + "vmax", + "vmin", + "vand", + "vor", + "vxor", + "vaddrelu", + "vaddreluconv", + "vsubrelu", + "vmulconv", + "vshl", + "vshr", + "vprelu", + "vpack", + "vperm", + "vmrgsort", +} +_VECTOR_SCALAR_OPS = { + "vadds", + "vsubs", + "vmuls", + "vdivs", + "vmaxs", + "vmins", + "vlrelu", + "vshls", + "vshrs", + "vands", + "vors", + "vxors", +} +_VECTOR_IMMEDIATE_OPS = {"vshift", "vslide"} +_TERNARY_VECTOR_OPS = {"vaxpy", "vmula"} +_MULTI_RESULT_VECTOR_OPS = {"vmull"} +_BROADCAST_VECTOR_OPS = {"vbr", "vdup", "vci"} _LOW_LEVEL_DMA_CONFIG_OPS = { "set_loop2_stride_outtoub", "set_loop1_stride_outtoub", @@ -91,7 +182,9 @@ | _PREDICATE_MOVEMENT_OPS | _CARRY_OPS | _REARRANGEMENT_OPS + | {"vcvt", "vmrgsort4"} ) +_TENSORVIEW_RANK = 5 class SemanticType: @@ -101,7 +194,13 @@ class SemanticType: @dataclass(frozen=True) class SemanticTensorViewType(SemanticType): element_dtype: ScalarType - rank: int = 2 + rank: int = _TENSORVIEW_RANK + + +@dataclass(frozen=True) +class SemanticPartitionTensorViewType(SemanticType): + element_dtype: ScalarType + rank: int = _TENSORVIEW_RANK @dataclass(frozen=True) @@ -109,6 +208,7 @@ class SemanticTensorSliceType(SemanticType): element_dtype: ScalarType rank: int extents: tuple[int | None, ...] + physical_axes: tuple[int, ...] @dataclass(frozen=True) @@ -116,6 +216,7 @@ class SemanticTileType(SemanticType): element_dtype: ScalarType rank: int shape: tuple[int, ...] | None + valid_shape: tuple[int | None, ...] | None memory_space: str | None @@ -182,6 +283,7 @@ class SemanticBinding: class SemanticTileBinding: name: str shape: tuple[int, ...] + valid_shape: tuple[int | None, ...] | None memory_space: str config: Any @@ -218,6 +320,14 @@ class SemanticSliceExpr(SemanticExpr): type: SemanticSliceType +@dataclass(frozen=True) +class SemanticTensorSliceAxis: + start: SemanticExpr + stop: SemanticExpr + step: SemanticExpr + extent: int | None + + @dataclass(frozen=True) class SemanticTupleExpr(SemanticExpr): elements: tuple[SemanticExpr, ...] @@ -241,7 +351,7 @@ class SemanticSubscriptAccess(SemanticExpr): @dataclass(frozen=True) class SemanticTensorSliceExpr(SemanticExpr): base: SemanticExpr - slices: tuple[SemanticSliceExpr, ...] + slices: tuple[SemanticTensorSliceAxis, ...] type: SemanticTensorSliceType @@ -277,23 +387,34 @@ class SemanticExprStmt(SemanticStmt): expr: SemanticExpr +@dataclass(frozen=True) +class SemanticDmaOptions: + pad_mode: SemanticExpr | None = None + pad_value: SemanticExpr | None = None + left_padding: SemanticExpr | None = None + right_padding: SemanticExpr | None = None + init_out_buffer: SemanticExpr | None = None + + @dataclass(frozen=True) class SemanticDmaLoadStmt(SemanticStmt): src: SemanticTensorSliceExpr dst: SemanticExpr + options: SemanticDmaOptions = SemanticDmaOptions() @dataclass(frozen=True) class SemanticDmaStoreStmt(SemanticStmt): src: SemanticExpr dst: SemanticTensorSliceExpr + options: SemanticDmaOptions = SemanticDmaOptions() @dataclass(frozen=True) class SemanticVectorStoreStmt(SemanticStmt): value: SemanticExpr destination: SemanticExpr - offset: SemanticExpr + indices: tuple[SemanticExpr, ...] mask: SemanticExpr @@ -405,6 +526,7 @@ class SemanticKernel: parameters: tuple[SemanticParameter, ...] tile_bindings: tuple[SemanticTileBinding, ...] body: tuple[SemanticStmt, ...] + inline_helpers: tuple["SemanticKernel", ...] = () class _SemanticAnalyzer: @@ -412,10 +534,18 @@ def __init__(self, node: FrontendKernelNode): self.node = node self._counter = 0 self._disable_inference_depth = 0 + self._has_explicit_vecscope = self._contains_explicit_vecscope(node.body) self._tile_specializations = { spec.name: spec for spec in node.tile_specializations } - self._tensor_shape_parameters: list[SemanticParameter] = [] + self._hidden_parameters: list[SemanticParameter] = [] + self._inline_proc_nodes: dict[str, FrontendInlineProcNode] = { + inline_proc.name: inline_proc for inline_proc in node.inline_procs + } + self._inline_proc_specializations: dict[tuple[str, tuple[SemanticType, ...]], SemanticKernel] = {} + self._inline_proc_return_types: dict[tuple[str, tuple[SemanticType, ...]], SemanticType | None] = {} + self._inline_proc_order: list[tuple[str, tuple[SemanticType, ...]]] = [] + self._inline_proc_active_stack: list[tuple[str, tuple[SemanticType, ...]]] = [] def analyze(self) -> SemanticKernel: env: dict[str, SemanticBinding] = {} @@ -430,11 +560,12 @@ def analyze(self) -> SemanticKernel: env[param.name] = binding parameters.append(SemanticParameter(binding=binding)) body, _ = self._analyze_kernel_body(env) - parameters.extend(self._tensor_shape_parameters) + parameters.extend(self._hidden_parameters) tile_bindings = tuple( SemanticTileBinding( name=spec.name, shape=spec.shape, + valid_shape=spec.valid_shape, memory_space=spec.memory_space, config=spec.config, ) @@ -450,6 +581,10 @@ def analyze(self) -> SemanticKernel: parameters=tuple(parameters), tile_bindings=tile_bindings, body=body, + inline_helpers=tuple( + self._inline_proc_specializations[key] + for key in self._inline_proc_order + ), ) def _analyze_kernel_body( @@ -460,16 +595,28 @@ def _analyze_kernel_body( def _parameter_type(self, param: Any) -> SemanticType: if param.kind == "tensorview": - return SemanticTensorViewType(element_dtype=param.dtype) + return SemanticTensorViewType( + element_dtype=param.dtype, + rank=_TENSORVIEW_RANK, + ) + if param.kind == "partition_tensor_view": + return SemanticPartitionTensorViewType( + element_dtype=param.dtype, + rank=_TENSORVIEW_RANK, + ) if param.kind == "tile": spec = self._tile_specializations.get(param.name) rank = 2 if spec is None else len(spec.shape) shape = None if spec is None else spec.shape + valid_shape = None if spec is None else ( + spec.shape if spec.valid_shape is None else spec.valid_shape + ) memory_space = None if spec is None else spec.memory_space return SemanticTileType( element_dtype=param.dtype, rank=rank, shape=shape, + valid_shape=valid_shape, memory_space=memory_space, ) if param.kind == "ptr": @@ -478,6 +625,8 @@ def _parameter_type(self, param: Any) -> SemanticType: element_dtype=param.dtype, memory_space=memory_space, ) + if param.kind == "mask": + return SemanticMaskType(granularity=param.dtype.granularity) if param.kind == "scalar": return SemanticScalarType(dtype=param.dtype) raise ValueError(f"unsupported parameter kind {param.kind!r}") @@ -490,24 +639,53 @@ def _new_ssa_name(self, stem: str) -> str: def _tensor_shape_binding_name(self, tensor_name: str, axis: int) -> str: return f"__shape_{tensor_name}_{axis}" - def _ensure_tensor_shape_parameter( + def _tensor_stride_binding_name(self, tensor_name: str, axis: int) -> str: + return f"__stride_{tensor_name}_{axis}" + + def _tile_valid_shape_binding_name(self, tile_name: str, axis: int) -> str: + return f"__valid_shape_{tile_name}_{axis}" + + def _ensure_hidden_parameter( self, - tensor_binding: SemanticBinding, - axis: int, + hidden_name: str, + origin: str, ) -> SemanticBinding: - hidden_name = self._tensor_shape_binding_name(tensor_binding.name, axis) - for parameter in self._tensor_shape_parameters: + for parameter in self._hidden_parameters: if parameter.name == hidden_name: return parameter.binding binding = SemanticBinding( name=hidden_name, - ssa_name=f"%arg{len(self.node.parameters) + len(self._tensor_shape_parameters)}", + ssa_name=f"%arg{len(self.node.parameters) + len(self._hidden_parameters)}", type=SemanticIndexType(), - origin="tensorview_shape", + origin=origin, ) - self._tensor_shape_parameters.append(SemanticParameter(binding=binding)) + self._hidden_parameters.append(SemanticParameter(binding=binding)) return binding + def _ensure_tensor_shape_parameter( + self, + tensor_binding: SemanticBinding, + axis: int, + ) -> SemanticBinding: + hidden_name = self._tensor_shape_binding_name(tensor_binding.name, axis) + return self._ensure_hidden_parameter(hidden_name, "tensorview_shape") + + def _ensure_tensor_stride_parameter( + self, + tensor_binding: SemanticBinding, + axis: int, + ) -> SemanticBinding: + hidden_name = self._tensor_stride_binding_name(tensor_binding.name, axis) + return self._ensure_hidden_parameter(hidden_name, "tensorview_stride") + + def _ensure_tile_valid_shape_parameter( + self, + tile_binding: SemanticBinding, + axis: int, + ) -> SemanticBinding: + hidden_name = self._tile_valid_shape_binding_name(tile_binding.name, axis) + return self._ensure_hidden_parameter(hidden_name, "tile_valid_shape") + def _make_binding( self, name: str, @@ -545,42 +723,88 @@ def _analyze_block( end += 1 run = statements[index:end] if self._run_contains_vector_op(run): + vecscope_stmt, current_env = self._analyze_inferred_vecscope( + run, + current_env, + allow_outer_lookup=allow_outer_lookup, + ) semantic_statements.append( - self._analyze_inferred_vecscope( - run, - current_env, - allow_outer_lookup=allow_outer_lookup, - ) + vecscope_stmt ) else: for stmt in run: - semantic_stmt, current_env = self._analyze_stmt( + emitted_stmts, current_env = self._analyze_stmt_or_inline( stmt, current_env, allow_outer_lookup=allow_outer_lookup, ) - semantic_statements.append(semantic_stmt) + semantic_statements.extend(emitted_stmts) index = end continue - semantic_stmt, current_env = self._analyze_stmt( + emitted_stmts, current_env = self._analyze_stmt_or_inline( statements[index], current_env, allow_outer_lookup=allow_outer_lookup, ) - semantic_statements.append(semantic_stmt) + semantic_statements.extend(emitted_stmts) index += 1 return tuple(semantic_statements), current_env + def _analyze_stmt_or_inline( + self, + stmt: FrontendStmtNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + if ( + isinstance(stmt, FrontendExprStmt) + and isinstance(stmt.expr, FrontendConstantExpr) + and isinstance(stmt.expr.value, str) + ): + # Treat Python docstring-style string expression statements as no-op. + return tuple(), dict(env) + if isinstance(stmt, FrontendIfStmt) and stmt.is_constexpr: + return self._analyze_constexpr_if( + stmt, + env, + allow_outer_lookup=allow_outer_lookup, + ) + semantic_stmt, updated_env = self._analyze_stmt( + stmt, + env, + allow_outer_lookup=allow_outer_lookup, + ) + return (semantic_stmt,), updated_env + + def _wrap_kernel_body_in_inferred_vecscope( + self, + statements: tuple[SemanticStmt, ...], + ) -> tuple[SemanticStmt, ...]: + if not statements or not self._semantic_block_contains_vector_activity(statements): + return statements + + body_end = len(statements) + while body_end > 0 and isinstance(statements[body_end - 1], SemanticReturnStmt): + body_end -= 1 + if body_end == 0: + return statements + + wrapped_body = SemanticVecscopeStmt(body=statements[:body_end]) + return (wrapped_body, *statements[body_end:]) + def _should_infer_vecscope( self, stmt: FrontendStmtNode, *, allow_outer_lookup: bool, ) -> bool: + if self._has_explicit_vecscope: + return False if self._disable_inference_depth > 0: return False - if not self.node.advanced_enabled or not allow_outer_lookup: + if not allow_outer_lookup: return False if isinstance(stmt, FrontendForStmt): return self._block_can_live_in_inferred_vecscope(stmt.body) @@ -590,6 +814,10 @@ def _should_infer_vecscope( | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS + | _VECTOR_IMMEDIATE_OPS + | _TERNARY_VECTOR_OPS + | _MULTI_RESULT_VECTOR_OPS + | _BROADCAST_VECTOR_OPS | _ADVANCED_VECTOR_ACTIVITY_OPS ) @@ -599,25 +827,51 @@ def _block_can_live_in_inferred_vecscope( ) -> bool: saw_vector_activity = False for stmt in statements: - if isinstance(stmt, FrontendStrictVecscopeStmt): - return False - if isinstance(stmt, FrontendIfStmt): + if self._frontend_stmt_is_vecscope_boundary(stmt): return False - if isinstance(stmt, FrontendExprStmt) and ( - self._is_dma_call(stmt.expr) or self._is_sync_call(stmt.expr) - ): - return False - if isinstance(stmt, FrontendForStmt): - if not self._block_can_live_in_inferred_vecscope(stmt.body): - return False + if self._frontend_stmt_can_live_in_inferred_vecscope(stmt): saw_vector_activity = True continue - if self._frontend_stmt_contains_vector_activity(stmt): - saw_vector_activity = True + if self._frontend_stmt_is_scalar_vecscope_stmt(stmt): continue return False return saw_vector_activity + def _frontend_stmt_is_vecscope_boundary(self, stmt: FrontendStmtNode) -> bool: + if isinstance(stmt, FrontendStrictVecscopeStmt): + return True + if isinstance(stmt, FrontendVecscopeStmt): + return True + if isinstance(stmt, FrontendIfStmt): + return not stmt.is_constexpr + return ( + isinstance(stmt, FrontendExprStmt) + and (self._is_dma_call(stmt.expr) or self._is_sync_call(stmt.expr)) + ) + + def _constexpr_if_contains_vector_activity(self, stmt: FrontendIfStmt) -> bool: + if not stmt.is_constexpr: + return False + return self._run_contains_vector_op(stmt.then_body) or self._run_contains_vector_op(stmt.else_body) + + def _frontend_stmt_can_live_in_inferred_vecscope( + self, + stmt: FrontendStmtNode, + ) -> bool: + if isinstance(stmt, FrontendForStmt): + return self._block_can_live_in_inferred_vecscope(stmt.body) + if isinstance(stmt, FrontendIfStmt): + return self._constexpr_if_contains_vector_activity(stmt) + return self._frontend_stmt_contains_vector_activity(stmt) + + def _frontend_stmt_is_scalar_vecscope_stmt( + self, + stmt: FrontendStmtNode, + ) -> bool: + return isinstance(stmt, FrontendAssignStmt) or ( + isinstance(stmt, FrontendIfStmt) and stmt.is_constexpr + ) + def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> bool: expr: FrontendExprNode | None = None if isinstance(stmt, FrontendAssignStmt): @@ -633,6 +887,10 @@ def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> boo | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS + | _VECTOR_IMMEDIATE_OPS + | _TERNARY_VECTOR_OPS + | _MULTI_RESULT_VECTOR_OPS + | _BROADCAST_VECTOR_OPS | _ADVANCED_VECTOR_ACTIVITY_OPS ) ) @@ -641,6 +899,14 @@ def _run_contains_vector_op(self, statements: tuple[FrontendStmtNode, ...]) -> b for stmt in statements: if isinstance(stmt, FrontendForStmt) and self._block_can_live_in_inferred_vecscope(stmt.body): return True + if isinstance(stmt, FrontendVecscopeStmt): + if self._run_contains_vector_op(stmt.body): + return True + continue + if isinstance(stmt, FrontendIfStmt): + if self._constexpr_if_contains_vector_activity(stmt): + return True + continue name = self._frontend_vector_call_name(stmt) if name is None or name == "make_mask": continue @@ -666,17 +932,17 @@ def _analyze_inferred_vecscope( env: dict[str, SemanticBinding], *, allow_outer_lookup: bool, - ) -> SemanticVecscopeStmt: + ) -> tuple[SemanticVecscopeStmt, dict[str, SemanticBinding]]: self._disable_inference_depth += 1 try: - body, _ = self._analyze_block_without_inference( + body, updated_env = self._analyze_block_without_inference( statements, env, allow_outer_lookup=allow_outer_lookup, ) finally: self._disable_inference_depth -= 1 - return SemanticVecscopeStmt(body=body) + return SemanticVecscopeStmt(body=body), updated_env def _analyze_block_without_inference( self, @@ -688,12 +954,12 @@ def _analyze_block_without_inference( current_env = dict(env) semantic_statements = [] for stmt in statements: - semantic_stmt, current_env = self._analyze_stmt( + emitted_stmts, current_env = self._analyze_stmt_or_inline( stmt, current_env, allow_outer_lookup=allow_outer_lookup, ) - semantic_statements.append(semantic_stmt) + semantic_statements.extend(emitted_stmts) return tuple(semantic_statements), current_env def _semantic_block_contains_vector_activity( @@ -727,6 +993,10 @@ def _expr_contains_vector_activity(self, expr: SemanticExpr) -> bool: | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS + | _VECTOR_IMMEDIATE_OPS + | _TERNARY_VECTOR_OPS + | _MULTI_RESULT_VECTOR_OPS + | _BROADCAST_VECTOR_OPS | _ADVANCED_VECTOR_ACTIVITY_OPS ): return True @@ -787,10 +1057,171 @@ def _analyze_stmt( return self._analyze_for(stmt, env, allow_outer_lookup=allow_outer_lookup) if isinstance(stmt, FrontendIfStmt): return self._analyze_if(stmt, env, allow_outer_lookup=allow_outer_lookup) + if isinstance(stmt, FrontendVecscopeStmt): + return self._analyze_explicit_vecscope(stmt, env, allow_outer_lookup=allow_outer_lookup) if isinstance(stmt, FrontendStrictVecscopeStmt): return self._analyze_strict_vecscope(stmt, env) raise ValueError(f"unsupported frontend statement {type(stmt).__name__}") + def _inline_proc_specialization_key( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> tuple[str, tuple[SemanticType, ...]]: + return (name, tuple(arg.type for arg in args)) + + def _inline_proc_symbol_name( + self, + name: str, + index: int, + ) -> str: + sanitized = "".join(char if char.isalnum() else "_" for char in name) + return f"__tl_inline_{sanitized}_{index}" + + def _collect_inline_helper_tile_bindings( + self, + parameters: tuple[SemanticParameter, ...], + ) -> tuple[SemanticTileBinding, ...]: + tile_bindings: list[SemanticTileBinding] = [] + for parameter in parameters: + if not isinstance(parameter.type, SemanticTileType): + continue + if parameter.type.shape is None: + continue + tile_bindings.append( + SemanticTileBinding( + name=parameter.name, + shape=parameter.type.shape, + valid_shape=parameter.type.valid_shape, + memory_space=parameter.type.memory_space or "ub", + config=None, + ) + ) + return tuple(tile_bindings) + + def _materialize_inline_proc_specialization( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticKernel: + inline_proc_node = self._inline_proc_nodes.get(name) + if inline_proc_node is None: + raise TypeError(f"inline_proc `{name}` is not registered in the current TileLang module") + + key = self._inline_proc_specialization_key(name, args) + existing = self._inline_proc_specializations.get(key) + if existing is not None: + return existing + if key in self._inline_proc_active_stack: + raise TypeError( + f"recursive inline_proc call `{name}` is not supported in TileLang DSL v1" + ) + + if len(inline_proc_node.parameters) != len(args): + raise TypeError( + f"inline_proc `{name}` expects {len(inline_proc_node.parameters)} arguments in TileLang DSL v1" + ) + + helper_env: dict[str, SemanticBinding] = {} + helper_parameters: list[SemanticParameter] = [] + for index, (param, arg_expr) in enumerate(zip(inline_proc_node.parameters, args)): + binding = SemanticBinding( + name=param.name, + ssa_name=f"%arg{index}", + type=arg_expr.type, + origin="inline_param", + ) + helper_env[param.name] = binding + helper_parameters.append(SemanticParameter(binding=binding)) + + saved_hidden_parameters = self._hidden_parameters + self._hidden_parameters = [] + self._inline_proc_active_stack.append(key) + try: + body, _ = self._analyze_block( + inline_proc_node.body, + helper_env, + allow_outer_lookup=False, + ) + finally: + self._inline_proc_active_stack.pop() + helper_hidden_parameters = tuple(self._hidden_parameters) + self._hidden_parameters = saved_hidden_parameters + + if helper_hidden_parameters: + raise TypeError( + f"inline_proc `{name}` currently does not support dynamic shape metadata captures in TileLang DSL v1" + ) + + return_type: SemanticType | None = None + if body and isinstance(body[-1], SemanticReturnStmt): + return_type = None if body[-1].value is None else body[-1].value.type + + helper_index = len(self._inline_proc_order) + helper_kernel = SemanticKernel( + target=self.node.target, + op=self.node.op, + symbol_name=self._inline_proc_symbol_name(name, helper_index), + verify_enabled=False, + advanced_enabled=self.node.advanced_enabled, + dtype_signature=self.node.dtype_signature, + parameters=tuple(helper_parameters), + tile_bindings=self._collect_inline_helper_tile_bindings(tuple(helper_parameters)), + body=body, + inline_helpers=(), + ) + self._inline_proc_specializations[key] = helper_kernel + self._inline_proc_return_types[key] = return_type + self._inline_proc_order.append(key) + return helper_kernel + + def _analyze_inline_proc_call_expr( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + helper_kernel = self._materialize_inline_proc_specialization(name, args) + key = self._inline_proc_specialization_key(name, args) + return SemanticCallExpr( + namespace=None, + name=helper_kernel.symbol_name, + args=args, + type=self._inline_proc_return_types.get(key), + ) + + def _contains_explicit_vecscope(self, statements: tuple[FrontendStmtNode, ...]) -> bool: + for stmt in statements: + if isinstance(stmt, FrontendVecscopeStmt): + return True + if isinstance(stmt, FrontendForStmt): + if self._contains_explicit_vecscope(stmt.body): + return True + continue + if isinstance(stmt, FrontendIfStmt): + if self._contains_explicit_vecscope(stmt.then_body): + return True + if self._contains_explicit_vecscope(stmt.else_body): + return True + continue + if isinstance(stmt, FrontendStrictVecscopeStmt): + if self._contains_explicit_vecscope(stmt.body): + return True + return False + + def _analyze_explicit_vecscope( + self, + stmt: FrontendVecscopeStmt, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + body, updated_env = self._analyze_block( + stmt.body, + dict(env), + allow_outer_lookup=allow_outer_lookup, + ) + return SemanticVecscopeStmt(body=body), updated_env + def _is_dma_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) @@ -835,17 +1266,69 @@ def _analyze_dma_stmt( raise TypeError("pto.dma_load expects exactly 2 positional arguments in TileLang DSL v1") src = self._require_tensor_slice(args[0], "pto.dma_load source") dst = self._require_tile_expr(args[1], "pto.dma_load destination") - self._validate_dma_shape_match(src.type, dst.type, "pto.dma_load") - return SemanticDmaLoadStmt(src=src, dst=dst), dict(env) + options = self._analyze_dma_options( + expr.keywords, + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.dma_load", + ) + self._validate_dma_load_profile(src, dst, options) + return SemanticDmaLoadStmt(src=src, dst=dst, options=options), dict(env) if expr.name == "dma_store": if len(args) != 2: raise TypeError("pto.dma_store expects exactly 2 positional arguments in TileLang DSL v1") src = self._require_tile_expr(args[0], "pto.dma_store source") dst = self._require_tensor_slice(args[1], "pto.dma_store destination") - self._validate_dma_shape_match(dst.type, src.type, "pto.dma_store") - return SemanticDmaStoreStmt(src=src, dst=dst), dict(env) + options = self._analyze_dma_options( + expr.keywords, + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.dma_store", + ) + self._validate_dma_store_profile(src, dst, options) + return SemanticDmaStoreStmt(src=src, dst=dst, options=options), dict(env) raise ValueError(f"unsupported DMA stmt pto.{expr.name}") + def _analyze_dma_options( + self, + keywords: tuple[tuple[str, FrontendExprNode], ...], + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + ) -> SemanticDmaOptions: + analyzed: dict[str, SemanticExpr] = {} + for name, keyword_expr in keywords: + analyzed[name] = self._analyze_expr( + keyword_expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + + pad_mode = analyzed.get("pad_mode") + if pad_mode is not None: + self._pad_mode_value(pad_mode, default=PadMode.PadNull) + + left_padding = analyzed.get("left_padding") + if left_padding is not None: + self._require_index_typed_expr(left_padding) + + right_padding = analyzed.get("right_padding") + if right_padding is not None: + self._require_index_typed_expr(right_padding) + + init_out_buffer = analyzed.get("init_out_buffer") + if init_out_buffer is not None: + self._require_i1_expr(init_out_buffer, f"{context} init_out_buffer") + + return SemanticDmaOptions( + pad_mode=pad_mode, + pad_value=analyzed.get("pad_value"), + left_padding=left_padding, + right_padding=right_padding, + init_out_buffer=init_out_buffer, + ) + def _analyze_vector_store_stmt( self, expr: FrontendCallExpr, @@ -855,7 +1338,7 @@ def _analyze_vector_store_stmt( ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: if len(expr.args) == 3: value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) - destination, offset = self._analyze_tile_vector_access( + destination, indices = self._analyze_tile_vector_access( expr.args[1], env, allow_outer_lookup=allow_outer_lookup, @@ -870,16 +1353,18 @@ def _analyze_vector_store_stmt( if len(args) != 4: raise TypeError("pto.vsts expects 3 or 4 positional arguments in TileLang DSL v1") value, destination, offset, mask = args + indices = (offset,) self._require_vreg_expr(value, "pto.vsts value") self._require_vector_pointer_expr(destination, "pto.vsts destination") - self._require_index_typed_expr(offset) + for index in indices: + self._require_index_typed_expr(index) self._require_mask_for_vreg(mask, value.type, "pto.vsts") self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") return ( SemanticVectorStoreStmt( value=value, destination=destination, - offset=offset, + indices=indices, mask=mask, ), dict(env), @@ -919,9 +1404,10 @@ def _analyze_low_level_dma_stmt( *, allow_outer_lookup: bool, ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: - args = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args + args = self._analyze_low_level_dma_operands( + expr, + env, + allow_outer_lookup=allow_outer_lookup, ) if expr.name in _LOW_LEVEL_DMA_CONFIG_OPS: if len(args) != 2: @@ -1012,6 +1498,99 @@ def _analyze_low_level_dma_stmt( ) raise ValueError(f"unsupported low-level DMA stmt pto.{expr.name}") + def _analyze_low_level_dma_operands( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticExpr, ...]: + if expr.args and expr.keywords: + raise TypeError( + f"pto.{expr.name} does not support mixing positional and keyword operands in TileLang DSL v1" + ) + if not expr.keywords: + return tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + + analyzed_keywords: dict[str, SemanticExpr] = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + + def index_literal(value: int) -> SemanticLiteralExpr: + return SemanticLiteralExpr(value=value, type=SemanticIndexType()) + + def bool_literal(value: bool) -> SemanticLiteralExpr: + return SemanticLiteralExpr(value=value, type=SemanticScalarType(dtype=i1)) + + if expr.name in { + "set_loop2_stride_outtoub", + "set_loop1_stride_outtoub", + "set_loop2_stride_ubtoout", + "set_loop1_stride_ubtoout", + }: + return ( + analyzed_keywords["src_stride"], + analyzed_keywords["dst_stride"], + ) + if expr.name in {"set_loop_size_outtoub", "set_loop_size_ubtoout"}: + return ( + analyzed_keywords["loop1"], + analyzed_keywords["loop2"], + ) + if expr.name == "copy_gm_to_ubuf": + if "data_select_bit" in analyzed_keywords and "enable_ub_pad" in analyzed_keywords: + raise TypeError( + "pto.copy_gm_to_ubuf keyword form accepts either `data_select_bit` or `enable_ub_pad`, not both" + ) + return ( + analyzed_keywords["src"], + analyzed_keywords["dst"], + analyzed_keywords.get("sid", index_literal(0)), + analyzed_keywords["n_burst"], + analyzed_keywords["len_burst"], + analyzed_keywords.get("left_padding_count", index_literal(0)), + analyzed_keywords.get("right_padding_count", index_literal(0)), + analyzed_keywords.get( + "data_select_bit", + analyzed_keywords.get("enable_ub_pad", bool_literal(False)), + ), + analyzed_keywords.get("l2_cache_ctl", index_literal(0)), + analyzed_keywords["gm_stride"], + analyzed_keywords["ub_stride"], + ) + if expr.name == "copy_ubuf_to_gm": + if "burst_dst_stride" in analyzed_keywords and "gm_stride" in analyzed_keywords: + raise TypeError( + "pto.copy_ubuf_to_gm keyword form accepts either `burst_dst_stride` or `gm_stride`, not both" + ) + if "burst_src_stride" in analyzed_keywords and "ub_stride" in analyzed_keywords: + raise TypeError( + "pto.copy_ubuf_to_gm keyword form accepts either `burst_src_stride` or `ub_stride`, not both" + ) + return ( + analyzed_keywords["src"], + analyzed_keywords["dst"], + analyzed_keywords.get("sid", index_literal(0)), + analyzed_keywords["n_burst"], + analyzed_keywords["len_burst"], + analyzed_keywords.get("reserved", index_literal(0)), + analyzed_keywords.get( + "burst_dst_stride", + analyzed_keywords["gm_stride"], + ), + analyzed_keywords.get( + "burst_src_stride", + analyzed_keywords["ub_stride"], + ), + ) + raise TypeError( + f"pto.{expr.name} keyword form is not implemented in TileLang DSL v1" + ) + def _require_tensor_slice( self, expr: SemanticExpr, @@ -1050,7 +1629,7 @@ def _require_vector_pointer_expr(self, expr: SemanticExpr, context: str) -> Sema return self._require_tile_expr(expr, context) return self._require_pointer_expr(expr, context, memory_space="ub") - def _validate_dma_shape_match( + def _validate_dma_common_types( self, tensor_slice_type: SemanticTensorSliceType, tile_type: SemanticTileType, @@ -1062,48 +1641,268 @@ def _validate_dma_shape_match( raise TypeError(f"{op_name} requires a statically specialized rank-2 Tile in TileLang DSL v1") if tensor_slice_type.element_dtype != tile_type.element_dtype: raise TypeError(f"{op_name} requires matching TensorView/Tile element dtypes in TileLang DSL v1") - for axis, (extent, tile_dim) in enumerate(zip(tensor_slice_type.extents, tile_type.shape)): - if extent is not None and extent != tile_dim: - raise TypeError( - f"{op_name} requires TensorView slice extent axis {axis}={extent!r} " - f"to match Tile shape axis {axis}={tile_dim!r}" - ) - def _bind_assignment_target( + def _validate_dma_load_profile( self, - target: FrontendTargetNode, - value: SemanticExpr, - env: dict[str, SemanticBinding], - annotation: Any | None, - ) -> tuple[SemanticBinding, ...]: - if isinstance(target, FrontendNameTarget): - if isinstance(value.type, SemanticTupleType): - raise ValueError("multi-result call assignment requires tuple binding in TileLang DSL v1") - annotated_type = self._annotation_type(annotation, value.type) - binding = self._make_binding( - target.name, - annotated_type if annotated_type is not None else value.type, - "ssa", - value=self._binding_value_for_expr(value), - ) + src: SemanticTensorSliceExpr, + dst: SemanticExpr, + options: SemanticDmaOptions, + ) -> None: + assert isinstance(dst.type, SemanticTileType) + self._validate_dma_common_types(src.type, dst.type, "pto.dma_load") + self._validate_dma_slice_profile(src, "pto.dma_load") + + pad_mode = self._pad_mode_value(options.pad_mode, default=PadMode.PadNull) + left_padding = self._require_static_non_negative_index_value( + options.left_padding, + context="pto.dma_load left_padding", + default=0, + ) + right_padding = self._require_static_non_negative_index_value( + options.right_padding, + context="pto.dma_load right_padding", + default=0, + ) + self._require_static_bool_value( + options.init_out_buffer, + context="pto.dma_load init_out_buffer", + default=False, + ) + self._validate_dma_load_option_profile(options, pad_mode) + + valid_shape = self._resolved_tile_valid_shape(dst.type) + expected_extents = ( + valid_shape[0], + self._trimmed_tile_axis_extent( + valid_shape[1], + left_padding, + right_padding, + op_name="pto.dma_load", + axis=1, + window_label="destination Tile valid window", + ), + ) + self._validate_dma_extent_match( + actual_extents=src.type.extents, + expected_extents=expected_extents, + op_name="pto.dma_load", + actual_label="source slice", + expected_label="destination Tile valid window", + left_padding=left_padding, + right_padding=right_padding, + ) + + def _validate_dma_store_profile( + self, + src: SemanticExpr, + dst: SemanticTensorSliceExpr, + options: SemanticDmaOptions, + ) -> None: + assert isinstance(src.type, SemanticTileType) + self._validate_dma_common_types(dst.type, src.type, "pto.dma_store") + self._validate_dma_slice_profile(dst, "pto.dma_store") + + pad_mode = self._pad_mode_value(options.pad_mode, default=PadMode.PadNull) + left_padding = self._require_static_non_negative_index_value( + options.left_padding, + context="pto.dma_store left_padding", + default=0, + ) + right_padding = self._require_static_non_negative_index_value( + options.right_padding, + context="pto.dma_store right_padding", + default=0, + ) + self._validate_dma_store_option_profile(options, pad_mode) + + valid_shape = self._resolved_tile_valid_shape(src.type) + expected_extents = ( + valid_shape[0], + self._trimmed_tile_axis_extent( + valid_shape[1], + left_padding, + right_padding, + op_name="pto.dma_store", + axis=1, + window_label="source Tile interior window", + ), + ) + self._validate_dma_extent_match( + actual_extents=dst.type.extents, + expected_extents=expected_extents, + op_name="pto.dma_store", + actual_label="destination slice", + expected_label="source Tile interior window", + left_padding=left_padding, + right_padding=right_padding, + ) + + def _validate_dma_slice_profile( + self, + tensor_slice: SemanticTensorSliceExpr, + op_name: str, + ) -> None: + for axis, slice_axis in enumerate(tensor_slice.slices): + step = self._static_index_value(slice_axis.step, default=1) + if step is None: + raise TypeError( + f"{op_name} stable frontend-only DMA profile requires a static positive " + f"slice step on axis {axis}" + ) + if step <= 0: + raise TypeError( + f"{op_name} stable frontend-only DMA profile requires a positive " + f"slice step on axis {axis}, got {step!r}" + ) + if axis == 1 and step != 1: + raise TypeError( + f"{op_name} stable frontend-only DMA profile only supports step == 1 " + "on TensorView slice axis 1" + ) + + def _validate_dma_load_option_profile( + self, + options: SemanticDmaOptions, + pad_mode: PadMode, + ) -> None: + if pad_mode == PadMode.PadValue and options.pad_value is None: + raise TypeError( + "pto.dma_load stable frontend-only DMA profile requires `pad_value` when " + "`pad_mode=PadMode.PadValue`" + ) + if pad_mode != PadMode.PadValue and options.pad_value is not None: + raise TypeError( + "pto.dma_load stable frontend-only DMA profile only accepts `pad_value` " + "when `pad_mode=PadMode.PadValue`" + ) + + def _validate_dma_store_option_profile( + self, + options: SemanticDmaOptions, + pad_mode: PadMode, + ) -> None: + if options.pad_value is not None: + raise TypeError( + "pto.dma_store stable frontend-only DMA profile does not support `pad_value`; " + "GM-side fill is unsupported" + ) + if pad_mode != PadMode.PadNull: + raise TypeError( + "pto.dma_store stable frontend-only DMA profile only supports " + "`pad_mode=PadMode.PadNull`; non-PadNull store padding would require GM-side fill" + ) + + def _resolved_tile_valid_shape( + self, + tile_type: SemanticTileType, + ) -> tuple[int | None, ...]: + assert tile_type.shape is not None + return tile_type.shape if tile_type.valid_shape is None else tile_type.valid_shape + + def _trimmed_tile_axis_extent( + self, + base_extent: int | None, + left_padding: int, + right_padding: int, + *, + op_name: str, + axis: int, + window_label: str, + ) -> int | None: + if base_extent is None: + return None + trimmed_extent = base_extent - left_padding - right_padding + if trimmed_extent <= 0: + raise TypeError( + f"{op_name} stable frontend-only DMA profile requires {window_label} axis {axis}=" + f"{base_extent!r} to remain positive after left_padding={left_padding} " + f"and right_padding={right_padding}" + ) + return trimmed_extent + + def _validate_dma_extent_match( + self, + *, + actual_extents: tuple[int | None, ...], + expected_extents: tuple[int | None, ...], + op_name: str, + actual_label: str, + expected_label: str, + left_padding: int, + right_padding: int, + ) -> None: + for axis, (actual_extent, expected_extent) in enumerate(zip(actual_extents, expected_extents)): + if actual_extent is None or expected_extent is None: + continue + if actual_extent != expected_extent: + padding_suffix = "" + if axis == 1 and (left_padding != 0 or right_padding != 0): + padding_suffix = ( + f" after left_padding={left_padding} and right_padding={right_padding}" + ) + raise TypeError( + f"{op_name} stable frontend-only DMA profile requires {actual_label} extent " + f"axis {axis}={actual_extent!r} to match {expected_label} axis {axis}=" + f"{expected_extent!r}{padding_suffix}" + ) + + def _bind_assignment_target( + self, + target: FrontendTargetNode, + value: SemanticExpr, + env: dict[str, SemanticBinding], + annotation: Any | None, + ) -> tuple[SemanticBinding, ...]: + if isinstance(target, FrontendNameTarget): + if isinstance(value.type, SemanticTupleType): + raise ValueError("multi-result call assignment requires tuple binding in TileLang DSL v1") + annotated_type = self._annotation_type(annotation, value.type, env) + binding = self._make_binding( + target.name, + annotated_type if annotated_type is not None else value.type, + "ssa", + value=self._binding_value_for_expr(value), + ) env[target.name] = binding return (binding,) if isinstance(target, FrontendTupleTarget): - if not isinstance(value.type, SemanticTupleType): + if isinstance(value.type, SemanticTupleType): + element_types = value.type.elements + elif isinstance(value.type, SemanticShapeType): + element_types = tuple(SemanticIndexType() for _ in range(value.type.rank)) + else: raise ValueError("tuple assignment expects a tuple-typed value") if annotation is not None: raise TypeError("annotated tuple assignment is not supported in TileLang DSL v1") - if len(target.elements) != len(value.type.elements): + if len(target.elements) != len(element_types): raise ValueError("tuple assignment arity must match the tuple value") tuple_values: tuple[SemanticExpr, ...] if isinstance(value, SemanticTupleExpr): tuple_values = value.elements + elif isinstance(value, SemanticAttributeAccess) and isinstance(value.type, SemanticShapeType): + if isinstance(value.base, SemanticBindingRef): + if isinstance(value.base.type, SemanticTileType) and value.attr == "valid_shape": + valid_shape = value.base.type.valid_shape + if valid_shape is not None: + for axis, dim in enumerate(valid_shape): + if dim is None: + self._ensure_tile_valid_shape_parameter(value.base.binding, axis) + tuple_values = tuple( + SemanticSubscriptAccess( + base=value, + index=SemanticLiteralExpr(value=axis, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + for axis in range(value.type.rank) + ) elif isinstance(value, SemanticCallExpr): tuple_values = value.args else: - tuple_values = tuple(SemanticLiteralExpr(value=None, type=element_type) for element_type in value.type.elements) + tuple_values = tuple( + SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types + ) bindings = [] - for element, element_type, element_value in zip(target.elements, value.type.elements, tuple_values): + for element, element_type, element_value in zip(target.elements, element_types, tuple_values): binding = self._make_binding( element.name, element_type, @@ -1116,30 +1915,104 @@ def _bind_assignment_target( raise ValueError(f"unsupported frontend assignment target {type(target).__name__}") def _binding_value_for_expr(self, expr: SemanticExpr) -> Any | None: - if isinstance(expr, SemanticSymbolExpr): - return expr.value - if isinstance(expr, SemanticLiteralExpr): - return expr.value - if isinstance(expr, SemanticBindingRef): - return expr.binding.value - return None + return self._try_static_value(expr) def _annotation_type( self, annotation: Any | None, inferred_type: SemanticType | None, + env: dict[str, SemanticBinding], ) -> SemanticType | None: if annotation is None: return inferred_type - if isinstance(annotation, ast.Attribute) and isinstance(annotation.value, ast.Name): - if annotation.value.id == "pto" and isinstance(inferred_type, SemanticScalarType): - if inferred_type.dtype.name != annotation.attr: + annotation_expr = self._analyze_annotation_expr(annotation, env) + if isinstance(annotation_expr.type, SemanticMetaType): + if annotation_expr.type.kind == "dtype" and isinstance(inferred_type, SemanticScalarType): + dtype = self._require_dtype_symbol(annotation_expr, "annotated scalar type") + if inferred_type.dtype != dtype: + raise TypeError( + f"annotated scalar type `{dtype!r}` does not match inferred {inferred_type.dtype!r}" + ) + return inferred_type + if annotation_expr.type.kind == "ptr_type" and isinstance(inferred_type, SemanticPtrType): + ptr_type = self._require_ptr_type_expr(annotation_expr, "annotated pointer type") + if inferred_type.element_dtype != ptr_type.element_dtype: + raise TypeError( + f"annotated pointer type `{ptr_type!r}` does not match inferred pointer element type {inferred_type.element_dtype!r}" + ) + if inferred_type.memory_space != ptr_type.memory_space.value: + raise TypeError( + f"annotated pointer type `{ptr_type!r}` does not match inferred pointer memory space `{inferred_type.memory_space}`" + ) + return inferred_type + if annotation_expr.type.kind == "vreg_type" and isinstance(inferred_type, SemanticVRegType): + vreg_type = self._require_vreg_type_expr(annotation_expr, "annotated vector type") + if inferred_type.element_dtype != vreg_type.element_dtype or inferred_type.lanes != vreg_type.lanes: + raise TypeError( + f"annotated vector type `{vreg_type!r}` does not match inferred !pto.vreg<{inferred_type.lanes}x{inferred_type.element_dtype.name}>" + ) + return inferred_type + if annotation_expr.type.kind == "mask_type" and isinstance(inferred_type, SemanticMaskType): + mask_type = self._require_mask_type_expr(annotation_expr, "annotated mask type") + if inferred_type.granularity != mask_type.granularity: raise TypeError( - f"annotated scalar type `pto.{annotation.attr}` does not match inferred {inferred_type.dtype!r}" + f"annotated mask type `{mask_type!r}` does not match inferred !pto.mask<{inferred_type.granularity}>" ) return inferred_type raise TypeError("unsupported annotated assignment type in TileLang DSL v1") + def _analyze_annotation_expr( + self, + annotation: ast.AST, + env: dict[str, SemanticBinding], + ) -> SemanticExpr: + frontend_expr = self._build_frontend_annotation_expr(annotation) + return self._analyze_expr(frontend_expr, env, allow_outer_lookup=True) + + def _build_frontend_annotation_expr(self, node: ast.AST) -> FrontendExprNode: + if isinstance(node, ast.Name): + return FrontendNameExpr(name=node.id) + if isinstance(node, ast.Constant): + return FrontendConstantExpr(value=node.value) + if isinstance(node, ast.Attribute): + path = self._annotation_attribute_path(node) + if path is not None and path[0] in {"pto", "PAT", "PIPE", "EVENT"} and len(path) >= 2: + return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) + return FrontendAttributeExpr( + base=self._build_frontend_annotation_expr(node.value), + attr=node.attr, + ) + if isinstance(node, ast.Call): + if any(keyword.arg is None for keyword in node.keywords): + raise TypeError("annotated assignment type does not support keyword unpacking in TileLang DSL v1") + if node.keywords: + raise TypeError("annotated assignment type does not support keyword arguments in TileLang DSL v1") + if isinstance(node.func, ast.Name): + return FrontendCallExpr( + namespace=None, + name=node.func.id, + args=tuple(self._build_frontend_annotation_expr(arg) for arg in node.args), + keywords=(), + ) + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + return FrontendCallExpr( + namespace=node.func.value.id, + name=node.func.attr, + args=tuple(self._build_frontend_annotation_expr(arg) for arg in node.args), + keywords=(), + ) + raise TypeError("unsupported annotated assignment type in TileLang DSL v1") + + def _annotation_attribute_path(self, node: ast.AST) -> tuple[str, ...] | None: + if isinstance(node, ast.Name): + return (node.id,) + if isinstance(node, ast.Attribute): + base_path = self._annotation_attribute_path(node.value) + if base_path is None: + return None + return base_path + (node.attr,) + return None + def _analyze_for( self, stmt: FrontendForStmt, @@ -1198,6 +2071,11 @@ def _analyze_if( ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: condition = self._analyze_expr(stmt.condition, env, allow_outer_lookup=allow_outer_lookup) self._require_condition_type(condition.type) + if self._contains_meta_condition_operand(condition): + raise TypeError( + "if condition comparing meta values requires wrapping the condition with pto.constexpr(...) " + "in TileLang DSL v1" + ) then_body, then_env = self._analyze_block( stmt.then_body, @@ -1242,11 +2120,43 @@ def _analyze_if( updated_env, ) + def _contains_meta_condition_operand(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticBinaryExpr): + if expr.op in {"eq", "ne"} and ( + isinstance(expr.lhs.type, SemanticMetaType) or isinstance(expr.rhs.type, SemanticMetaType) + ): + return True + if expr.op in {"and", "or"}: + return self._contains_meta_condition_operand(expr.lhs) or self._contains_meta_condition_operand(expr.rhs) + return False + + def _analyze_constexpr_if( + self, + stmt: FrontendIfStmt, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + condition = self._analyze_expr(stmt.condition, env, allow_outer_lookup=allow_outer_lookup) + self._require_condition_type(condition.type) + static_value = self._require_constexpr_condition_bool( + condition, + context="if pto.constexpr(...) condition", + ) + selected_body = stmt.then_body if static_value else stmt.else_body + return self._analyze_block( + selected_body, + dict(env), + allow_outer_lookup=allow_outer_lookup, + ) + def _analyze_strict_vecscope( self, stmt: FrontendStrictVecscopeStmt, env: dict[str, SemanticBinding], ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + if not self.node.advanced_enabled: + raise TypeError(advanced_mode_message("strict_vecscope")) if len(stmt.captures) != len(stmt.block_arguments): raise ValueError("strict_vecscope capture arity must match block arguments") @@ -1299,6 +2209,11 @@ def _analyze_expr( return SemanticLiteralExpr(value=expr.value, type=SemanticScalarType(dtype=i1)) if isinstance(expr.value, int): return SemanticLiteralExpr(value=expr.value, type=SemanticIndexType()) + if isinstance(expr.value, float): + return SemanticLiteralExpr( + value=expr.value, + type=SemanticScalarType(dtype=f32), + ) if isinstance(expr.value, str): return SemanticLiteralExpr( value=expr.value, @@ -1337,6 +2252,8 @@ def _analyze_expr( return self._element_type_expr(base) if expr.attr == "valid_shape": return self._valid_shape_expr(base) + if expr.attr == "strides": + return self._strides_expr(base) attr_type = self._attribute_type(base, expr.attr) return SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type) if isinstance(expr, FrontendSubscriptExpr): @@ -1353,14 +2270,42 @@ def _analyze_expr( result_type = self._binary_type(lhs, rhs, expr.op) return SemanticBinaryExpr(lhs=lhs, op=expr.op, rhs=rhs, type=result_type) if isinstance(expr, FrontendCallExpr): + if expr.namespace is None and expr.name in self._inline_proc_nodes: + if expr.keywords: + raise TypeError( + f"inline_proc call `{expr.name}` reached semantic analysis with unresolved keywords in TileLang DSL v1" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_inline_proc_call_expr(expr.name, args) + if expr.namespace not in {None, "pto"} and expr.name == "as_ptr": + if expr.keywords: + raise TypeError("method call `as_ptr` does not support keyword arguments in TileLang DSL v1") + binding = env.get(expr.namespace) + if binding is None: + if allow_outer_lookup: + raise ValueError(f"unknown name '{expr.namespace}'") + raise ValueError( + f"implicit capture of '{expr.namespace}' is not allowed in pto.strict_vecscope" + ) + base = SemanticBindingRef(binding=binding, type=binding.type) + return self._analyze_as_ptr_method(base) if expr.namespace == "pto" and expr.name == "vlds" and len(expr.args) == 1: - base, offset = self._analyze_tile_vector_access( + base, indices = self._analyze_tile_vector_access( expr.args[0], env, allow_outer_lookup=allow_outer_lookup, context="pto.vlds source", ) - return self._analyze_vlds((base, offset)) + return self._analyze_vlds((base, *indices)) + if expr.keywords: + raise TypeError( + f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " + "carries keyword arguments, but semantic keyword handling is not implemented " + "in TileLang DSL v1 yet" + ) args = tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) for arg in expr.args @@ -1378,6 +2323,14 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=dtype, type=SemanticMetaType(kind="dtype"), ) + mask_type = _MASK_TYPE_SYMBOLS.get(expr.name) + if mask_type is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=mask_type, + type=SemanticMetaType(kind="mask_type"), + ) if expr.namespace in {"PAT", "pto.PAT", "pto.MaskPattern"}: pattern = _PATTERN_SYMBOLS.get(expr.name) if pattern is None and expr.name.startswith("PAT_"): @@ -1417,21 +2370,52 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=memory_space, type=SemanticMetaType(kind="memory_space"), ) + if expr.namespace in {"pto.PadMode"}: + pad_mode = _PAD_MODE_SYMBOLS.get(expr.name) + if pad_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pad_mode, + type=SemanticMetaType(kind="pad_mode"), + ) + if expr.namespace in {"pto.PositionMode"}: + position_mode = _POSITION_MODE_SYMBOLS.get(expr.name) + if position_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=position_mode, + type=SemanticMetaType(kind="position_mode"), + ) + if expr.namespace in {"pto.OrderMode"}: + order_mode = _ORDER_MODE_SYMBOLS.get(expr.name) + if order_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=order_mode, + type=SemanticMetaType(kind="order_mode"), + ) raise TypeError( f"symbol `{expr.namespace}.{expr.name}` is not supported in TileLang DSL v1" ) def _attribute_type(self, base: SemanticExpr, attr: str) -> SemanticType: base_type = base.type - if isinstance(base_type, SemanticTensorViewType) and attr == "shape": + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)) and attr == "shape": + return SemanticShapeType(rank=base_type.rank) + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)) and attr == "strides": return SemanticShapeType(rank=base_type.rank) if isinstance(base_type, SemanticTileType) and attr == "shape": return SemanticShapeType(rank=base_type.rank) + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)) and attr == "valid_shape": + return SemanticShapeType(rank=base_type.rank) raise TypeError(f"unsupported attribute access '{attr}' in TileLang DSL v1") def _element_type_expr(self, base: SemanticExpr) -> SemanticExpr: base_type = base.type - if isinstance(base_type, (SemanticTensorViewType, SemanticTileType)): + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): return SemanticSymbolExpr( namespace="pto", name=base_type.element_dtype.name, @@ -1440,19 +2424,48 @@ def _element_type_expr(self, base: SemanticExpr) -> SemanticExpr: ) raise TypeError("unsupported attribute access 'element_type' in TileLang DSL v1") + def _analyze_as_ptr_method(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + return SemanticCallExpr( + namespace="pto", + name="tensor_view_as_ptr", + args=(base,), + type=SemanticPtrType( + element_dtype=base_type.element_dtype, + memory_space="gm", + ), + ) + if isinstance(base_type, SemanticTileType): + return SemanticCallExpr( + namespace="pto", + name="tile_as_ptr", + args=(base,), + type=SemanticPtrType( + element_dtype=base_type.element_dtype, + memory_space=base_type.memory_space or "ub", + ), + ) + raise TypeError("`as_ptr()` expects a TensorView/PartitionTensorView or Tile value in TileLang DSL v1") + def _valid_shape_expr(self, base: SemanticExpr) -> SemanticExpr: base_type = base.type - if not isinstance(base_type, (SemanticTensorViewType, SemanticTileType)): + if not isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): raise TypeError("unsupported attribute access 'valid_shape' in TileLang DSL v1") shape_access = SemanticAttributeAccess( base=base, - attr="shape", + attr="valid_shape", type=SemanticShapeType(rank=base_type.rank), ) elements = [] for axis in range(base_type.rank): - if isinstance(base, SemanticBindingRef) and isinstance(base.type, SemanticTensorViewType): - self._ensure_tensor_shape_parameter(base.binding, axis) + if ( + isinstance(base, SemanticBindingRef) + and isinstance(base.type, SemanticTileType) + and base.type.valid_shape is not None + and base.type.valid_shape[axis] is None + ): + self._ensure_tile_valid_shape_parameter(base.binding, axis) elements.append( SemanticSubscriptAccess( base=shape_access, @@ -1465,6 +2478,29 @@ def _valid_shape_expr(self, base: SemanticExpr) -> SemanticExpr: type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in elements)), ) + def _strides_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if not isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + raise TypeError("unsupported attribute access 'strides' in TileLang DSL v1") + stride_access = SemanticAttributeAccess( + base=base, + attr="strides", + type=SemanticShapeType(rank=base_type.rank), + ) + elements = [] + for axis in range(base_type.rank): + elements.append( + SemanticSubscriptAccess( + base=stride_access, + index=SemanticLiteralExpr(value=axis, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + ) + return SemanticTupleExpr( + elements=tuple(elements), + type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in elements)), + ) + def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticType: if isinstance(base.type, SemanticShapeType): if not isinstance(index.type, SemanticIndexType): @@ -1479,10 +2515,8 @@ def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticTy raise TypeError( f"shape subscript index {index.value} is out of bounds for rank {base.type.rank}" ) - if isinstance(base.base.type, SemanticTensorViewType): - self._ensure_tensor_shape_parameter(base.base.binding, index.value) return SemanticIndexType() - if isinstance(base.type, SemanticTensorViewType): + if isinstance(base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): if not isinstance(index, SemanticTupleExpr): raise TypeError("TensorView slicing expects a tuple of slices in TileLang DSL v1") return self._tensor_slice_type(base.type, index) @@ -1495,25 +2529,23 @@ def _analyze_tile_vector_access( *, allow_outer_lookup: bool, context: str, - ) -> tuple[SemanticExpr, SemanticExpr]: - if not self.node.advanced_enabled: - raise TypeError(unsupported_feature_message(f"{context} tile indexing sugar")) + ) -> tuple[SemanticExpr, tuple[SemanticExpr, ...]]: if not isinstance(expr, FrontendSubscriptExpr): raise TypeError( - f"{context} expects Tile element-indexing syntax in advanced TileLang DSL mode" + f"{context} expects Tile element-indexing syntax in TileLang DSL v1" ) base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) tile = self._require_tile_expr(base, context) - offset = self._tile_vector_offset_expr( + indices = self._tile_vector_indices( expr.index, tile.type, env, allow_outer_lookup=allow_outer_lookup, context=context, ) - return base, offset + return base, indices - def _tile_vector_offset_expr( + def _tile_vector_indices( self, index_expr: FrontendExprNode, tile_type: SemanticTileType, @@ -1521,7 +2553,7 @@ def _tile_vector_offset_expr( *, allow_outer_lookup: bool, context: str, - ) -> SemanticExpr: + ) -> tuple[SemanticExpr, ...]: if tile_type.rank == 1: if not isinstance(index_expr, FrontendSliceExpr): raise TypeError(f"{context} expects Tile[start:] syntax for rank-1 Tile values") @@ -1530,10 +2562,10 @@ def _tile_vector_offset_expr( if index_expr.step is not None: raise TypeError(f"{context} does not support stepped Tile vector slices in TileLang DSL advanced mode") if index_expr.start is None: - return SemanticLiteralExpr(value=0, type=SemanticIndexType()) + return (SemanticLiteralExpr(value=0, type=SemanticIndexType()),) start = self._analyze_expr(index_expr.start, env, allow_outer_lookup=allow_outer_lookup) self._require_index_typed_expr(start) - return start + return (start,) if tile_type.rank != 2 or tile_type.shape is None: raise TypeError(f"{context} currently only supports statically specialized rank-1 or rank-2 Tiles") @@ -1555,22 +2587,19 @@ def _tile_vector_offset_expr( else: col = self._analyze_expr(col_expr.start, env, allow_outer_lookup=allow_outer_lookup) self._require_index_typed_expr(col) - - stride = SemanticLiteralExpr(value=tile_type.shape[1], type=SemanticIndexType()) - row_offset = SemanticBinaryExpr(lhs=row, op="mul", rhs=stride, type=SemanticIndexType()) - if isinstance(col, SemanticLiteralExpr) and col.value == 0: - return row_offset - return SemanticBinaryExpr(lhs=row_offset, op="add", rhs=col, type=SemanticIndexType()) + return (row, col) def _tensor_slice_type( self, - tensor_type: SemanticTensorViewType, + tensor_type: SemanticTensorViewType | SemanticPartitionTensorViewType, index: SemanticTupleExpr, ) -> SemanticTensorSliceType: - if len(index.elements) != tensor_type.rank: + if not 1 <= len(index.elements) <= tensor_type.rank: raise TypeError( - f"TensorView slice rank {len(index.elements)} does not match TensorView rank {tensor_type.rank}" + f"TensorView slice rank {len(index.elements)} must be between 1 and " + f"{tensor_type.rank} in TileLang DSL v1" ) + axis_offset = tensor_type.rank - len(index.elements) extents = [] for axis, element in enumerate(index.elements): if not isinstance(element, SemanticSliceExpr): @@ -1581,44 +2610,44 @@ def _tensor_slice_type( self._require_optional_index_typed_expr(element.stop) self._require_optional_index_typed_expr(element.step) - start = self._static_index_value(element.start, default=0) - stop = self._static_index_value(element.stop, default=None) - step = self._static_index_value(element.step, default=1) if element.stop is None: raise TypeError("TensorView slicing requires explicit stop bounds in TileLang DSL v1") - if start != 0: - raise TypeError("TensorView slicing currently only supports zero-based starts in TileLang DSL v1") - if element.step is not None and step is None: - raise TypeError("TensorView slicing currently only supports unit stride in TileLang DSL v1") - if step != 1: - raise TypeError("TensorView slicing currently only supports unit stride in TileLang DSL v1") - if stop is None: - extent = None - else: - extent = stop - start - if extent <= 0: - raise TypeError("TensorView slicing requires positive extents in TileLang DSL v1") - extents.append(extent) + extents.append(self._normalized_tensor_slice_extent(element)) return SemanticTensorSliceType( element_dtype=tensor_type.element_dtype, - rank=tensor_type.rank, + rank=len(index.elements), extents=tuple(extents), + physical_axes=tuple(range(axis_offset, tensor_type.rank)), ) def _normalize_tensor_slice( self, index: SemanticExpr, rank: int, - ) -> tuple[SemanticSliceExpr, ...]: + ) -> tuple[SemanticTensorSliceAxis, ...]: if not isinstance(index, SemanticTupleExpr): raise TypeError("TensorView slicing expects a tuple index in TileLang DSL v1") - if len(index.elements) != rank: - raise TypeError(f"TensorView slicing expects {rank} slice elements in TileLang DSL v1") + if not 1 <= len(index.elements) <= rank: + raise TypeError( + f"TensorView slicing expects between 1 and {rank} slice elements in TileLang DSL v1" + ) slices = [] for element in index.elements: if not isinstance(element, SemanticSliceExpr): raise TypeError("TensorView slicing only supports slice syntax in TileLang DSL v1") - slices.append(element) + if element.stop is None: + raise TypeError("TensorView slicing requires explicit stop bounds in TileLang DSL v1") + start = self._normalize_optional_index_expr(element.start, default=0) + stop = element.stop + step = self._normalize_optional_index_expr(element.step, default=1) + slices.append( + SemanticTensorSliceAxis( + start=start, + stop=stop, + step=step, + extent=self._normalized_tensor_slice_extent(element), + ) + ) return tuple(slices) def _binary_type( @@ -1627,11 +2656,33 @@ def _binary_type( rhs: SemanticExpr, op: str, ) -> SemanticType: - if op not in {"add", "sub", "mul", "floordiv"}: - raise TypeError(f"unsupported binary operator '{op}' in TileLang DSL v1") - if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): - return SemanticIndexType() - raise TypeError("binary expressions currently only support index-typed operands") + if op in {"add", "sub", "mul", "floordiv"}: + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + return SemanticIndexType() + raise TypeError("binary expressions currently only support index-typed operands") + if op in {"eq", "ne"}: + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + return SemanticScalarType(dtype=i1) + if isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: + return SemanticScalarType(dtype=i1) + if isinstance(lhs.type, SemanticMetaType) and lhs.type == rhs.type: + return SemanticScalarType(dtype=i1) + raise TypeError( + "comparison expressions currently require matching scalar/meta types or index-typed operands" + ) + if op in {"gt", "lt", "ge", "le"}: + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + return SemanticScalarType(dtype=i1) + if isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: + return SemanticScalarType(dtype=i1) + raise TypeError( + "ordered comparison expressions currently require matching scalar types or index-typed operands" + ) + if op in {"and", "or"}: + self._require_condition_type(lhs.type) + self._require_condition_type(rhs.type) + return SemanticScalarType(dtype=i1) + raise TypeError(f"unsupported binary operator '{op}' in TileLang DSL v1") def _analyze_call_expr( self, @@ -1641,20 +2692,34 @@ def _analyze_call_expr( ) -> SemanticExpr: if namespace is None and name == "range": return SemanticCallExpr(namespace=namespace, name=name, args=args, type=None) + if namespace is None: + raise TypeError( + f"call surface `{name}` is not supported in TileLang DSL v1" + ) if namespace != "pto": raise TypeError( f"call surface `{namespace + '.' if namespace else ''}{name}` is not supported in TileLang DSL v1 yet" ) if name in DEFERRED_PTO_SURFACES: raise TypeError(deferred_surface_message(name)) + if name in _DTYPE_SYMBOLS: + return self._analyze_scalar_constructor(name, args) if name == "ptr": return self._analyze_ptr_type(args) + if name == "vreg": + return self._analyze_vreg_type(args) if name == "castptr": return self._analyze_castptr(args) if name == "addptr": return self._analyze_addptr(args) - if name == "get_lanes": - return self._analyze_get_lanes(args) + if name == "bytewidth": + return self._analyze_bytewidth(args) + if name in {"get_lanes", "elements_per_vreg"}: + return self._analyze_get_lanes(args, call_name=name) + if name == "constexpr": + raise TypeError( + "pto.constexpr(...) is only supported as an if-condition wrapper in TileLang DSL v1" + ) if name == "make_mask": return self._analyze_make_mask(args) if name == "vlds": @@ -1671,12 +2736,24 @@ def _analyze_call_expr( return self._analyze_carry_op(name, args) if name in {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"}: return self._analyze_rearrangement_op(name, args) + if name == "vcvt": + return self._analyze_vcvt(args) + if name == "vmrgsort4": + return self._analyze_vmrgsort4(args) + if name in _BROADCAST_VECTOR_OPS: + return self._analyze_broadcast_vector_op(name, args) + if name in _MULTI_RESULT_VECTOR_OPS: + return self._analyze_multi_result_vector_op(name, args) if name in _UNARY_VECTOR_OPS: return self._analyze_unary_vector_op(name, args) if name in _BINARY_VECTOR_OPS: return self._analyze_binary_vector_op(name, args) if name in _VECTOR_SCALAR_OPS: return self._analyze_vector_scalar_op(name, args) + if name in _VECTOR_IMMEDIATE_OPS: + return self._analyze_vector_immediate_op(name, args) + if name in _TERNARY_VECTOR_OPS: + return self._analyze_ternary_vector_op(name, args) raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: @@ -1704,6 +2781,123 @@ def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: ), ) + def _analyze_scalar_constructor( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") + + target_dtype = _DTYPE_SYMBOLS[name] + if ( + target_dtype.name in {"f16", "bf16", "f32"} + and isinstance(args[0], SemanticLiteralExpr) + and isinstance(args[0].type, SemanticMetaType) + and args[0].type.kind == "string" + ): + parsed = self._parse_float_literal_string(args[0].value, target_dtype, f"pto.{name} value") + return SemanticLiteralExpr( + value=parsed, + type=SemanticScalarType(dtype=target_dtype), + ) + + value = self._require_scalar_or_index_expr(args[0], f"pto.{name} value") + + if isinstance(value.type, SemanticScalarType) and value.type.dtype == target_dtype: + return value + + if isinstance(value, SemanticLiteralExpr): + literal_value = value.value + if target_dtype == i1: + if isinstance(literal_value, bool): + return SemanticLiteralExpr(value=literal_value, type=SemanticScalarType(dtype=i1)) + if isinstance(literal_value, int): + return SemanticLiteralExpr(value=bool(literal_value), type=SemanticScalarType(dtype=i1)) + if isinstance(literal_value, float): + return SemanticLiteralExpr(value=bool(literal_value), type=SemanticScalarType(dtype=i1)) + elif target_dtype.name.startswith("i"): + if isinstance(literal_value, bool): + casted = int(literal_value) + elif isinstance(literal_value, (int, float)): + casted = int(literal_value) + else: + casted = None + if casted is not None: + bits = int(target_dtype.name[1:]) + min_value = -(1 << (bits - 1)) + max_value = (1 << (bits - 1)) - 1 + if casted < min_value or casted > max_value: + raise TypeError( + f"pto.{name} value {casted} is out of range for {target_dtype.name} in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=casted, type=SemanticScalarType(dtype=target_dtype)) + else: + if isinstance(literal_value, (bool, int, float)): + return SemanticLiteralExpr( + value=float(literal_value), + type=SemanticScalarType(dtype=target_dtype), + ) + + return SemanticCallExpr( + namespace="pto", + name=name, + args=(value,), + type=SemanticScalarType(dtype=target_dtype), + ) + + def _parse_float_literal_string( + self, + literal: str, + target_dtype: ScalarType, + context: str, + ) -> float: + text = literal.strip().lower() + if text in {"inf", "+inf", "infinity", "+infinity"}: + return float("inf") + if text in {"-inf", "-infinity"}: + return float("-inf") + if text in {"nan", "+nan", "-nan"}: + return float("nan") + + if text.startswith("0x"): + try: + bit_pattern = int(text, 16) + except ValueError as exc: + raise TypeError( + f"{context} string literal {literal!r} is not a valid hex bit-pattern" + ) from exc + return self._float_from_bit_pattern(bit_pattern, target_dtype, context=context) + + try: + return float(text) + except ValueError as exc: + raise TypeError( + f"{context} string literal {literal!r} is not a valid float literal" + ) from exc + + def _float_from_bit_pattern( + self, + bit_pattern: int, + target_dtype: ScalarType, + *, + context: str, + ) -> float: + if target_dtype.name == "f16": + if bit_pattern < 0 or bit_pattern > 0xFFFF: + raise TypeError(f"{context} f16 bit-pattern must be in [0x0, 0xFFFF]") + return float(struct.unpack(">e", struct.pack(">H", bit_pattern))[0]) + if target_dtype.name == "bf16": + if bit_pattern < 0 or bit_pattern > 0xFFFF: + raise TypeError(f"{context} bf16 bit-pattern must be in [0x0, 0xFFFF]") + widened = bit_pattern << 16 + return float(struct.unpack(">f", struct.pack(">I", widened))[0]) + if target_dtype.name == "f32": + if bit_pattern < 0 or bit_pattern > 0xFFFFFFFF: + raise TypeError(f"{context} f32 bit-pattern must be in [0x0, 0xFFFFFFFF]") + return float(struct.unpack(">f", struct.pack(">I", bit_pattern))[0]) + raise TypeError(f"{context} bit-pattern literals are not supported for dtype {target_dtype.name}") + def _analyze_ptr_type(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 2: raise TypeError("pto.ptr expects exactly 2 positional arguments in TileLang DSL") @@ -1714,6 +2908,16 @@ def _analyze_ptr_type(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: type=SemanticMetaType(kind="ptr_type"), ) + def _analyze_vreg_type(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 1: + raise TypeError("pto.vreg expects exactly 1 positional argument in TileLang DSL v1") + dtype = self._require_dtype_symbol(args[0], "pto.vreg element type") + vreg_type = self._vreg_type_for_dtype(dtype) + return SemanticLiteralExpr( + value=VRegType(element_dtype=dtype, lanes=vreg_type.lanes), + type=SemanticMetaType(kind="vreg_type"), + ) + def _analyze_castptr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 2: raise TypeError("pto.castptr expects exactly 2 positional arguments in TileLang DSL") @@ -1733,28 +2937,90 @@ def _analyze_addptr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: self._require_index_typed_expr(offset) return SemanticCallExpr(namespace="pto", name="addptr", args=(ptr, offset), type=ptr.type) - def _analyze_get_lanes(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + def _analyze_get_lanes( + self, + args: tuple[SemanticExpr, ...], + *, + call_name: str = "get_lanes", + ) -> SemanticExpr: if len(args) != 1: - raise TypeError("pto.get_lanes expects exactly 1 positional argument in TileLang DSL v1") - dtype = self._require_dtype_symbol(args[0], "pto.get_lanes dtype") + raise TypeError( + f"pto.{call_name} expects exactly 1 positional argument in TileLang DSL v1" + ) + dtype = self._require_dtype_symbol(args[0], f"pto.{call_name} dtype") return SemanticLiteralExpr(value=self._vreg_type_for_dtype(dtype).lanes, type=SemanticIndexType()) + def _analyze_bytewidth(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 1: + raise TypeError("pto.bytewidth expects exactly 1 positional argument in TileLang DSL v1") + dtype = self._require_dtype_symbol(args[0], "pto.bytewidth dtype") + return SemanticLiteralExpr(value=bytewidth(dtype), type=SemanticIndexType()) + def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if len(args) != 2: - raise TypeError("pto.vlds expects exactly 2 positional arguments in TileLang DSL v1") - source, offset = args + if len(args) < 2: + raise TypeError("pto.vlds expects at least 2 positional arguments in TileLang DSL v1") + source, *indices = args + source_type = source.type if isinstance(source_type, SemanticTileType): source = self._require_tile_expr(source, "pto.vlds source") else: source = self._require_pointer_expr(source, "pto.vlds source", memory_space="ub") - self._require_index_typed_expr(offset) + for index in indices: + self._require_index_typed_expr(index) return SemanticCallExpr( namespace="pto", name="vlds", - args=args, + args=(source, *indices), type=self._vreg_type_for_dtype(source.type.element_dtype), ) + def _analyze_broadcast_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "vbr": + if len(args) != 1: + raise TypeError("pto.vbr expects exactly 1 positional argument in TileLang DSL v1") + value = args[0] + vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vbr value") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vec_type) + + if name == "vdup": + if len(args) not in {1, 2}: + raise TypeError("pto.vdup expects 1 or 2 positional arguments in TileLang DSL v1") + value = args[0] + if isinstance(value.type, SemanticVRegType): + vec_type = value.type + else: + vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vdup input") + position_arg = args[1] if len(args) == 2 else None + position = self._normalize_position_mode(position_arg, "pto.vdup position") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(value, position), + type=vec_type, + ) + + if name == "vci": + if len(args) not in {1, 2}: + raise TypeError("pto.vci expects 1 or 2 positional arguments in TileLang DSL v1") + index = self._require_scalar_or_index_expr(args[0], "pto.vci index") + index_dtype = i32 if isinstance(index.type, SemanticIndexType) else index.type.dtype + if index_dtype.name not in {"i8", "i16", "i32"}: + raise TypeError("pto.vci index only supports i8/i16/i32 in TileLang DSL v1") + order_arg = args[1] if len(args) == 2 else None + order = self._normalize_order_mode(order_arg, "pto.vci order") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(index, order), + type=self._vreg_type_for_dtype(index_dtype), + ) + + raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") + def _analyze_unary_vector_op( self, name: str, @@ -1778,7 +3044,12 @@ def _analyze_binary_vector_op( lhs_expr, rhs_expr, mask = args lhs = self._require_vreg_expr(lhs_expr, f"pto.{name} lhs") rhs = self._require_vreg_expr(rhs_expr, f"pto.{name} rhs") - if lhs != rhs: + if name == "vperm": + if rhs.element_dtype.name not in {"i8", "i16", "i32"}: + raise TypeError("pto.vperm indices vector only supports integer vector dtypes in TileLang DSL v1") + if lhs.lanes != rhs.lanes: + raise TypeError("pto.vperm requires data/indices vectors to use the same lane width") + elif lhs != rhs: raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") self._require_mask_for_vreg(mask, lhs, f"pto.{name}") self._validate_binary_dtype(name, lhs.element_dtype) @@ -1800,6 +3071,59 @@ def _analyze_vector_scalar_op( self._validate_vector_scalar_dtype(name, vreg.element_dtype) return SemanticCallExpr(namespace="pto", name=name, args=args, type=vreg) + def _analyze_vector_immediate_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") + vector = self._require_vreg_expr(args[0], f"pto.{name} vector") + immediate = self._require_scalar_or_index_expr(args[1], f"pto.{name} immediate") + if isinstance(immediate.type, SemanticScalarType) and immediate.type.dtype.name not in {"i8", "i16", "i32"}: + raise TypeError(f"pto.{name} immediate only supports i8/i16/i32 in TileLang DSL v1") + self._require_mask_for_vreg(args[2], vector, f"pto.{name}") + self._validate_vector_immediate_dtype(name, vector.element_dtype) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vector) + + def _analyze_ternary_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 4: + raise TypeError(f"pto.{name} expects exactly 4 positional arguments in TileLang DSL v1") + vec0 = self._require_vreg_expr(args[0], f"pto.{name} vec0") + vec1 = self._require_vreg_expr(args[1], f"pto.{name} vec1") + vec2 = self._require_vreg_expr(args[2], f"pto.{name} vec2") + if not (vec0 == vec1 == vec2): + raise TypeError(f"pto.{name} requires all vector operands to use the same vector type") + self._require_mask_for_vreg(args[3], vec0, f"pto.{name}") + self._validate_ternary_vector_dtype(name, vec0.element_dtype) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vec0) + + def _analyze_multi_result_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name != "vmull": + raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") + if len(args) != 3: + raise TypeError("pto.vmull expects exactly 3 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], "pto.vmull lhs") + rhs = self._require_vreg_expr(args[1], "pto.vmull rhs") + if lhs != rhs: + raise TypeError("pto.vmull requires lhs/rhs vector types to match") + self._require_mask_for_vreg(args[2], lhs, "pto.vmull") + self._validate_multi_result_vector_dtype(name, lhs.element_dtype) + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, lhs)), + ) + def _analyze_mask_part_op( self, name: str, @@ -1959,6 +3283,31 @@ def _analyze_rearrangement_op( self._require_string_expr(args[2], f"pto.{name} part") return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) + def _analyze_vcvt(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 3: + raise TypeError("pto.vcvt expects exactly 3 positional arguments in TileLang DSL") + vector = self._require_vreg_expr(args[0], "pto.vcvt vector") + target_dtype = self._require_dtype_symbol(args[1], "pto.vcvt to_type") + self._require_mask_for_vreg(args[2], vector, "pto.vcvt") + return SemanticCallExpr( + namespace="pto", + name="vcvt", + args=args, + type=self._vreg_type_for_dtype(target_dtype), + ) + + def _analyze_vmrgsort4(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 5: + raise TypeError("pto.vmrgsort4 expects exactly 5 positional arguments in TileLang DSL") + vec0 = self._require_vreg_expr(args[0], "pto.vmrgsort4 vec0") + vec1 = self._require_vreg_expr(args[1], "pto.vmrgsort4 vec1") + vec2 = self._require_vreg_expr(args[2], "pto.vmrgsort4 vec2") + vec3 = self._require_vreg_expr(args[3], "pto.vmrgsort4 vec3") + if not (vec0 == vec1 == vec2 == vec3): + raise TypeError("pto.vmrgsort4 requires all vector operands to use the same vector type") + self._require_mask_for_vreg(args[4], vec0, "pto.vmrgsort4") + return SemanticCallExpr(namespace="pto", name="vmrgsort4", args=args, type=vec0) + def _require_dtype_symbol(self, expr: SemanticExpr, context: str) -> ScalarType: if not ( isinstance(expr, SemanticSymbolExpr) @@ -2008,6 +3357,40 @@ def _require_ptr_type_expr(self, expr: SemanticExpr, context: str) -> PointerTyp return expr.binding.value raise TypeError(f"{context} must be a pointer type constructed with pto.ptr(...)") + def _require_vreg_type_expr(self, expr: SemanticExpr, context: str) -> VRegType: + if ( + isinstance(expr, SemanticLiteralExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vreg_type" + and isinstance(expr.value, VRegType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vreg_type" + and isinstance(expr.binding.value, VRegType) + ): + return expr.binding.value + raise TypeError(f"{context} must be a vector type constructed with pto.vreg(...)") + + def _require_mask_type_expr(self, expr: SemanticExpr, context: str) -> MaskType: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "mask_type" + and isinstance(expr.value, MaskType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "mask_type" + and isinstance(expr.binding.value, MaskType) + ): + return expr.binding.value + raise TypeError(f"{context} must be a mask type such as pto.mask_b32") + def _require_cast_target_type(self, expr: SemanticExpr) -> SemanticType: if self._is_i64_dtype_expr(expr): return SemanticScalarType(dtype=i64) @@ -2026,16 +3409,7 @@ def _require_castptr_input(self, expr: SemanticExpr, target_type: SemanticPtrTyp if expr.type.memory_space != target_type.memory_space: raise TypeError("pto.castptr pointer-to-pointer casts must stay within one PTO memory space") return - if isinstance(expr.type, SemanticTensorViewType): - if target_type.memory_space != "gm": - raise TypeError("pto.castptr TensorView casts require a GM pointer target") - return - if isinstance(expr.type, SemanticTileType): - tile_memory_space = expr.type.memory_space or "ub" - if tile_memory_space != target_type.memory_space: - raise TypeError("pto.castptr Tile casts must preserve the Tile memory space") - return - raise TypeError("pto.castptr input must be an index/i64, pointer, TensorView, or Tile value") + raise TypeError("pto.castptr input must be an index/i64, pointer, or memref-backed address value") def _is_i64_dtype_expr(self, expr: SemanticExpr) -> bool: if isinstance(expr, SemanticSymbolExpr): @@ -2058,6 +3432,71 @@ def _require_scalar_expr(self, expr: SemanticExpr, context: str) -> SemanticScal raise TypeError(f"{context} must be a scalar value in TileLang DSL v1") return expr.type + def _require_scalar_or_index_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if isinstance(expr.type, (SemanticScalarType, SemanticIndexType)): + return expr + raise TypeError(f"{context} must be a scalar or index value in TileLang DSL v1") + + def _vreg_type_for_scalar_or_index(self, expr: SemanticExpr, context: str) -> SemanticVRegType: + value = self._require_scalar_or_index_expr(expr, context) + if isinstance(value.type, SemanticScalarType): + return self._vreg_type_for_dtype(value.type.dtype) + return self._vreg_type_for_dtype(i32) + + def _normalize_position_mode( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value=PositionMode.LOWEST.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "position_mode" + and isinstance(expr.value, PositionMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "position_mode" + and isinstance(expr.binding.value, PositionMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + position = self._require_string_expr(expr, context) + if position != PositionMode.LOWEST.value: + raise TypeError( + "pto.vdup currently only supports position `PositionMode.LOWEST` in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=position, type=SemanticMetaType(kind="string")) + + def _normalize_order_mode( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value=OrderMode.ASC.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "order_mode" + and isinstance(expr.value, OrderMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "order_mode" + and isinstance(expr.binding.value, OrderMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + order = self._require_string_expr(expr, context) + if order != OrderMode.ASC.value: + raise TypeError("pto.vci currently only supports order `OrderMode.ASC` in TileLang DSL v1") + return SemanticLiteralExpr(value=order, type=SemanticMetaType(kind="string")) + def _require_mask_expr(self, expr: SemanticExpr, context: str) -> SemanticMaskType: if not isinstance(expr.type, SemanticMaskType): raise TypeError(f"{context} must be a mask value in TileLang DSL") @@ -2159,31 +3598,61 @@ def _vreg_type_for_dtype(self, dtype: ScalarType) -> SemanticVRegType: return SemanticVRegType(element_dtype=dtype, lanes=256 // width) def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: - if name == "vexp" and dtype.name not in {"f16", "f32"}: - raise TypeError("pto.vexp only supports f16/f32 in TileLang DSL v1") + if name in {"vexp", "vln", "vsqrt", "vrec", "vrsqrt", "vexpdiff"} and dtype.name not in {"f16", "f32"}: + raise TypeError(f"pto.{name} only supports f16/f32 in TileLang DSL v1") if name == "vrelu" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vrelu only supports f16/f32 in TileLang DSL v1") - if name == "vnot" and dtype.name not in {"i8", "i16", "i32"}: - raise TypeError("pto.vnot only supports integer vector dtypes in TileLang DSL v1") - if name == "vabs" and dtype.name not in {"i8", "i16", "i32", "f16", "f32"}: - raise TypeError("pto.vabs does not support this dtype in TileLang DSL v1") + if name in {"vnot", "vbcnt", "vcls", "vsunpack", "vzunpack", "vusqz", "vsqz"} and dtype.name not in {"i8", "i16", "i32"}: + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name in {"vabs", "vneg", "vmov", "vtrc", "vbitsort", "vcadd", "vcmax", "vcmin"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: if name == "vdiv" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vdiv only supports f16/f32 in TileLang DSL v1") + if name == "vprelu" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vprelu only supports f16/f32 in TileLang DSL v1") + if name in {"vaddreluconv", "vmulconv"} and dtype.name not in {"f16", "bf16", "f32"}: + raise TypeError(f"pto.{name} only supports f16/bf16/f32 in TileLang DSL v1") if name in {"vand", "vor", "vxor"} and dtype.name not in {"i8", "i16", "i32"}: raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name in {"vshl", "vshr"} and dtype.name not in {"i8", "i16", "i32"}: + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") if name == "vmul" and dtype.name not in {"i16", "i32", "f16", "f32"}: raise TypeError("pto.vmul only supports i16/i32/f16/f32 in TileLang DSL v1") - if name in {"vadd", "vsub", "vmax", "vmin"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + if name == "vperm" and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + raise TypeError("pto.vperm does not support this data vector dtype in TileLang DSL v1") + if name in {"vadd", "vsub", "vmax", "vmin", "vaddrelu", "vsubrelu"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + if name in {"vpack", "vmrgsort"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") def _validate_vector_scalar_dtype(self, name: str, dtype: ScalarType) -> None: if name == "vdivs" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vdivs only supports f16/f32 in TileLang DSL v1") + if name == "vlrelu" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vlrelu only supports f16/f32 in TileLang DSL v1") + if name in {"vshls", "vshrs"} and dtype.name not in {"i8", "i16", "i32"}: + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name in {"vands", "vors", "vxors"} and dtype.name not in {"i8", "i16", "i32"}: + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") if name in {"vadds", "vsubs", "vmuls", "vmaxs", "vmins"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + def _validate_vector_immediate_dtype(self, name: str, dtype: ScalarType) -> None: + if name in {"vshift", "vslide"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + raise TypeError(f"pto.{name} does not support this vector dtype in TileLang DSL v1") + + def _validate_ternary_vector_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vaxpy" and dtype.name not in {"i16", "i32", "f16", "f32"}: + raise TypeError("pto.vaxpy only supports i16/i32/f16/f32 in TileLang DSL v1") + if name == "vmula" and dtype.name not in {"i16", "i32", "f16", "f32"}: + raise TypeError("pto.vmula only supports i16/i32/f16/f32 in TileLang DSL v1") + + def _validate_multi_result_vector_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vmull" and dtype.name != "i32": + raise TypeError("pto.vmull only supports i32 vectors in TileLang DSL v1") + def _require_sync_pipe(self, expr: SemanticExpr, context: str) -> str: if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "pipe": return expr.value.value @@ -2198,6 +3667,25 @@ def _require_sync_event(self, expr: SemanticExpr, context: str) -> str: return expr.value raise TypeError(f"{context} must be an EVENT symbol or event string in TileLang DSL v1") + def _pad_mode_value( + self, + expr: SemanticExpr | None, + *, + default: PadMode, + ) -> PadMode: + if expr is None: + return default + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "pad_mode": + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "pad_mode" + and isinstance(expr.binding.value, PadMode) + ): + return expr.binding.value + raise TypeError("DMA pad_mode must be a PadMode symbol in TileLang DSL v1") + def _require_loop_bound_type(self, ty: SemanticType) -> None: if isinstance(ty, (SemanticIndexType, SemanticScalarType)): return @@ -2235,18 +3723,252 @@ def _require_index_typed_expr(self, expr: SemanticExpr) -> None: if not isinstance(expr.type, SemanticIndexType): raise TypeError("slice bounds and vector offsets must be index-typed in TileLang DSL v1") + def _try_static_dtype(self, expr: SemanticExpr) -> ScalarType | None: + if ( + isinstance(expr, SemanticSymbolExpr) + and expr.type.kind == "dtype" + and isinstance(expr.value, ScalarType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and isinstance(expr.binding.value, ScalarType) + ): + return expr.binding.value + return None + + def _try_static_subscript_value(self, expr: SemanticSubscriptAccess) -> Any | None: + index_value = self._try_static_value(expr.index) + if not isinstance(index_value, int): + return None + + base = expr.base + if isinstance(base, SemanticAttributeAccess) and isinstance(base.base, SemanticBindingRef): + binding_ref = base.base + binding_type = binding_ref.type + if isinstance(binding_type, SemanticTileType): + if base.attr == "shape" and binding_type.shape is not None: + if 0 <= index_value < len(binding_type.shape): + return binding_type.shape[index_value] + if base.attr == "valid_shape" and binding_type.valid_shape is not None: + if 0 <= index_value < len(binding_type.valid_shape): + return binding_type.valid_shape[index_value] + return None + if isinstance(binding_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + return None + + base_value = self._try_static_value(base) + if isinstance(base_value, (tuple, list)): + if 0 <= index_value < len(base_value): + return base_value[index_value] + return None + return None + + def _try_static_value(self, expr: SemanticExpr | None) -> Any | None: + if expr is None: + return None + if isinstance(expr, SemanticSymbolExpr): + return expr.value + if isinstance(expr, SemanticLiteralExpr): + return expr.value + if isinstance(expr, SemanticBindingRef): + return expr.binding.value + if isinstance(expr, SemanticTupleExpr): + elements = [] + for element in expr.elements: + static_element = self._try_static_value(element) + if static_element is None: + return None + elements.append(static_element) + return tuple(elements) + if isinstance(expr, SemanticSubscriptAccess): + return self._try_static_subscript_value(expr) + if isinstance(expr, SemanticBinaryExpr): + if expr.op in {"and", "or"}: + lhs_bool = self._try_static_condition_bool(expr.lhs) + rhs_bool = self._try_static_condition_bool(expr.rhs) + if lhs_bool is None or rhs_bool is None: + return None + if expr.op == "and": + return lhs_bool and rhs_bool + return lhs_bool or rhs_bool + lhs = self._try_static_value(expr.lhs) + rhs = self._try_static_value(expr.rhs) + if lhs is None or rhs is None: + return None + if expr.op == "add": + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs + rhs + return None + if expr.op == "sub": + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs - rhs + return None + if expr.op == "mul": + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs * rhs + return None + if expr.op == "floordiv": + if isinstance(lhs, int) and isinstance(rhs, int): + if rhs == 0: + return None + return lhs // rhs + return None + if expr.op == "eq": + return lhs == rhs + if expr.op == "ne": + return lhs != rhs + if expr.op == "gt": + try: + return lhs > rhs + except TypeError: + return None + if expr.op == "lt": + try: + return lhs < rhs + except TypeError: + return None + if expr.op == "ge": + try: + return lhs >= rhs + except TypeError: + return None + if expr.op == "le": + try: + return lhs <= rhs + except TypeError: + return None + return None + if isinstance(expr, SemanticCallExpr): + if expr.namespace != "pto": + return None + if expr.name == "bytewidth": + if len(expr.args) != 1: + return None + dtype = self._try_static_dtype(expr.args[0]) + if dtype is None: + return None + return bytewidth(dtype) + if expr.name in {"get_lanes", "elements_per_vreg"}: + if len(expr.args) != 1: + return None + dtype = self._try_static_dtype(expr.args[0]) + if dtype is None: + return None + return self._vreg_type_for_dtype(dtype).lanes + return None + + def _try_static_condition_bool(self, expr: SemanticExpr | None) -> bool | None: + value = self._try_static_value(expr) + if isinstance(value, bool): + return value + if isinstance(value, int): + return value != 0 + return None + + def _require_constexpr_condition_bool( + self, + expr: SemanticExpr, + *, + context: str, + ) -> bool: + value = self._try_static_condition_bool(expr) + if value is None: + raise TypeError( + f"{context} must be a compile-time bool in TileLang DSL v1" + ) + return value + def _static_index_value(self, expr: SemanticExpr | None, *, default: int | None) -> int | None: if expr is None: return default - if not isinstance(expr, SemanticLiteralExpr) or not isinstance(expr.value, int): - return None - return expr.value + value = self._try_static_value(expr) + if isinstance(value, int) and not isinstance(value, bool): + return value + return None def _require_optional_index_typed_expr(self, expr: SemanticExpr | None) -> None: if expr is None: return self._require_index_typed_expr(expr) + def _static_bool_value(self, expr: SemanticExpr | None, *, default: bool | None) -> bool | None: + if expr is None: + return default + if isinstance(expr, SemanticLiteralExpr): + if ( + isinstance(expr.type, SemanticScalarType) + and expr.type.dtype == i1 + and isinstance(expr.value, bool) + ): + return expr.value + return None + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticScalarType) + and expr.type.dtype == i1 + and isinstance(expr.binding.value, bool) + ): + return expr.binding.value + return None + + def _require_static_bool_value( + self, + expr: SemanticExpr | None, + *, + context: str, + default: bool, + ) -> bool: + value = self._static_bool_value(expr, default=default) + if value is None: + raise TypeError( + f"{context} must be a compile-time bool in the stable frontend-only DMA profile" + ) + return value + + def _require_static_non_negative_index_value( + self, + expr: SemanticExpr | None, + *, + context: str, + default: int, + ) -> int: + value = self._static_index_value(expr, default=default) + if value is None: + raise TypeError( + f"{context} must be a static non-negative index in the stable frontend-only DMA profile" + ) + if value < 0: + raise TypeError( + f"{context} must be a non-negative index in the stable frontend-only DMA profile" + ) + return value + + def _normalize_optional_index_expr( + self, + expr: SemanticExpr | None, + *, + default: int, + ) -> SemanticExpr: + if expr is not None: + return expr + return SemanticLiteralExpr(value=default, type=SemanticIndexType()) + + def _normalized_tensor_slice_extent(self, expr: SemanticSliceExpr) -> int | None: + start = self._static_index_value(expr.start, default=0) + stop = self._static_index_value(expr.stop, default=None) + step = self._static_index_value(expr.step, default=1) + if stop is None or start is None or step is None: + return None + if step <= 0: + raise TypeError("TensorView slicing requires a positive static step in TileLang DSL v1") + distance = stop - start + if distance <= 0: + raise TypeError("TensorView slicing requires positive extents in TileLang DSL v1") + return (distance + step - 1) // step + def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: """Normalize descriptor-owned AST into a lowering semantic model.""" @@ -2261,6 +3983,7 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticBinding", "SemanticBindingRef", "SemanticCallExpr", + "SemanticDmaOptions", "SemanticDmaLoadStmt", "SemanticDmaStoreStmt", "SemanticExpr", @@ -2285,9 +4008,11 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticStrictVecscopeStmt", "SemanticSubscriptAccess", "SemanticSymbolExpr", + "SemanticTensorSliceAxis", "SemanticTensorSliceExpr", "SemanticTensorSliceType", "SemanticTensorViewType", + "SemanticPartitionTensorViewType", "SemanticTileBinding", "SemanticTileType", "SemanticTupleExpr", diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index ea8984f22..bf59af107 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -4,11 +4,32 @@ FOLLOW_UP_CHANGE = "extend-tilelang-dsl-matcher-and-advanced-surface" +# Tier definitions for TileLang DSL surface classification +# These tiers represent the user-facing support level of language features: +# - BASIC: Core surface that is fully supported and recommended for general use +# - ADVANCED: Features requiring advanced=True, suitable for expert users + +BASIC_TIER = "basic" +ADVANCED_TIER = "advanced" + +# Tier metadata for PTO calls and language constructs +# This provides a unified source of truth for documentation and testing + SUPPORTED_TOPLEVEL_PTO_CALLS = frozenset( { - "strict_vecscope", - "dma_load", - "dma_store", + "constexpr", + "bytewidth", + "get_lanes", + "elements_per_vreg", + "vreg", + "i1", + "i8", + "i16", + "i32", + "i64", + "f16", + "bf16", + "f32", "set_flag", "wait_flag", "pipe_barrier", @@ -24,7 +45,27 @@ "vabs", "vrelu", "vexp", + "vln", + "vsqrt", + "vrec", "vnot", + "vcadd", + "vcmax", + "vbcnt", + "vneg", + "vcls", + "vcmin", + "vrsqrt", + "vmov", + "vsunpack", + "vzunpack", + "vusqz", + "vsqz", + "vexpdiff", + "vtrc", + "vbitsort", + "vbr", + "vdup", "vadd", "vsub", "vmul", @@ -34,12 +75,41 @@ "vand", "vor", "vxor", + "vaddrelu", + "vaddreluconv", + "vsubrelu", + "vaxpy", + "vmulconv", + "vmull", + "vmula", + "vshl", + "vshr", + "vprelu", "vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins", + "vlrelu", + "vshls", + "vshrs", + "vands", + "vors", + "vxors", + "vcgadd", + "vcgmax", + "vcgmin", + "vcpadd", + "vpack", + "vperm", + "vshift", + "vslide", + "vsort32", + "vmrgsort", + "vcvt", + "vmrgsort4", + "vci", } ) @@ -75,6 +145,7 @@ ADVANCED_TOPLEVEL_PTO_CALLS = frozenset( { + "strict_vecscope", "copy_gm_to_ubuf", "copy_ubuf_to_gm", "copy_ubuf_to_ubuf", @@ -93,6 +164,80 @@ } ) +# Public surface groupings used by the guide, migration notes, and tests. +# These groupings intentionally mirror the user-facing authoring tiers rather +# than the internal lowering organization. + +BASIC_TENSORVIEW_SURFACES = frozenset({"TensorView"}) +BASIC_TILE_SURFACES = frozenset({"Tile"}) +BASIC_HIGH_LEVEL_DMA_SURFACES = frozenset() +BASIC_BASE_VECTOR_SURFACES = frozenset( + f"pto.{name}" for name in sorted(SUPPORTED_VECSCOPE_PTO_CALLS) +) + +ADVANCED_RAW_POINTER_SURFACES = frozenset( + { + "ptr", + "pto.ptr", + "PointerType", + "pto.castptr", + "pto.addptr", + } +) +ADVANCED_LOW_LEVEL_DMA_SURFACES = frozenset( + { + "pto.copy_gm_to_ubuf", + "pto.copy_ubuf_to_gm", + "pto.copy_ubuf_to_ubuf", + "pto.set_loop2_stride_outtoub", + "pto.set_loop1_stride_outtoub", + "pto.set_loop_size_outtoub", + "pto.set_loop2_stride_ubtoout", + "pto.set_loop1_stride_ubtoout", + "pto.set_loop_size_ubtoout", + } +) +ADVANCED_EXPLICIT_VECSCOPE_SURFACES = frozenset({"pto.strict_vecscope"}) +ADVANCED_TILE_HELPER_SURFACES = frozenset( + { + "tile.slice", + "tile.reshape", + "tile.as_ptr", + "tensorview.as_ptr", + "pto.tile_from_ptr", + "pto.tile_with_strides", + "pto.tile_config", + } +) +BASIC_TILE_INDEXING_SURFACES = frozenset( + { + "tile[start:]", + "tile[row, col:]", + } +) + +AUTHORING_TIER_SURFACE_GROUPS = { + "TensorView": BASIC_TENSORVIEW_SURFACES, + "Tile": BASIC_TILE_SURFACES, + "base_vector_ops": BASIC_BASE_VECTOR_SURFACES, + "tile_indexing_sugar": BASIC_TILE_INDEXING_SURFACES, + "strict_vecscope": ADVANCED_EXPLICIT_VECSCOPE_SURFACES, + "raw_pointer_family": ADVANCED_RAW_POINTER_SURFACES, + "low_level_dma_family": ADVANCED_LOW_LEVEL_DMA_SURFACES, + "tile_helper_family": ADVANCED_TILE_HELPER_SURFACES, +} + +AUTHORING_TIER_GROUP_TIERS = { + "TensorView": BASIC_TIER, + "Tile": BASIC_TIER, + "base_vector_ops": BASIC_TIER, + "tile_indexing_sugar": BASIC_TIER, + "strict_vecscope": ADVANCED_TIER, + "raw_pointer_family": ADVANCED_TIER, + "low_level_dma_family": ADVANCED_TIER, + "tile_helper_family": ADVANCED_TIER, +} + def unsupported_feature_message(feature: str) -> str: return ( @@ -109,6 +254,118 @@ def advanced_mode_message(name: str) -> str: return f"surface `pto.{name}` requires advanced=True in TileLang DSL" +# Tier mapping for PTO calls +def get_pto_call_tier(call_name: str) -> str: + """Return the tier of a PTO call. + + Args: + call_name: Name of the PTO call (without 'pto.' prefix) + + Returns: + One of BASIC_TIER or ADVANCED_TIER + + Raises: + KeyError: If the PTO call is not part of the supported DSL surface + """ + if call_name in SUPPORTED_TOPLEVEL_PTO_CALLS: + return BASIC_TIER + if call_name in SUPPORTED_VECSCOPE_PTO_CALLS: + return BASIC_TIER + if call_name in ADVANCED_VECSCOPE_PTO_CALLS: + return ADVANCED_TIER + if call_name in ADVANCED_EXPR_PTO_CALLS: + return ADVANCED_TIER + if call_name in ADVANCED_TOPLEVEL_PTO_CALLS: + return ADVANCED_TIER + raise KeyError(unsupported_feature_message(f"pto.{call_name}")) + + +UNSUPPORTED_LANGUAGE_CONSTRUCTS = frozenset( + { + "dma_load", + "dma_store", + "pto.dma_load", + "pto.dma_store", + "pto.get_buf", + "pto.rls_buf", + "pto.dma_copy", + "pto.vreduce", + "pto.tile", + "BLayout", + "SLayout", + "PadValue", + "SyncOpType", + } +) + + +# Tier mapping for language constructs (non-PTO-call features) +# These are higher-level abstractions in the TileLang DSL +LANGUAGE_CONSTRUCT_TIERS = { + # Basic tier constructs + "TensorView": BASIC_TIER, + "Tile": BASIC_TIER, + "VRegType": BASIC_TIER, + "MaskType": BASIC_TIER, + "pto.vreg": BASIC_TIER, + "pto.mask_b8": BASIC_TIER, + "pto.mask_b16": BASIC_TIER, + "pto.mask_b32": BASIC_TIER, + "PadMode": BASIC_TIER, + "constexpr": BASIC_TIER, + "pto.constexpr": BASIC_TIER, + "tile[start:]": BASIC_TIER, + "tile[row, col:]": BASIC_TIER, + # Advanced tier constructs + "ptr": ADVANCED_TIER, # raw pointer constructor + "strict_vecscope": ADVANCED_TIER, # explicit vecscope management + "pto.strict_vecscope": ADVANCED_TIER, + "tile.slice": ADVANCED_TIER, + "tile.reshape": ADVANCED_TIER, + "tile.as_ptr": ADVANCED_TIER, + "tensorview.as_ptr": ADVANCED_TIER, + "pto.tile_from_ptr": ADVANCED_TIER, + "pto.tile_with_strides": ADVANCED_TIER, + "pto.tile_config": ADVANCED_TIER, +} + + +def get_feature_tier(feature_name: str) -> str: + """Return the tier of a TileLang DSL feature. + + Args: + feature_name: Name of the feature, which can be: + - A PTO call name (e.g., 'vadd', 'ptr') + - A language construct (e.g., 'TensorView', 'dma_load') + - A qualified construct (e.g., 'tile.slice', 'pto.tile_from_ptr') + + Returns: + One of BASIC_TIER or ADVANCED_TIER + + Raises: + KeyError: If the feature is documented but not part of the supported DSL surface + """ + # First check if it's a known language construct + if feature_name in LANGUAGE_CONSTRUCT_TIERS: + return LANGUAGE_CONSTRUCT_TIERS[feature_name] + if feature_name in UNSUPPORTED_LANGUAGE_CONSTRUCTS: + raise KeyError(unsupported_feature_message(feature_name)) + + # Check if it's a PTO call (might be qualified with 'pto.' prefix) + call_name = feature_name + if feature_name.startswith("pto."): + call_name = feature_name[4:] + + # Check PTO call tier + return get_pto_call_tier(call_name) + + +def get_surface_group_tier(group_name: str) -> str: + """Return the authoring tier for a documented public-surface group.""" + + return AUTHORING_TIER_GROUP_TIERS[group_name] + + __all__ = [ "DEFERRED_PTO_SURFACES", "FOLLOW_UP_CHANGE", @@ -117,7 +374,25 @@ def advanced_mode_message(name: str) -> str: "ADVANCED_VECSCOPE_PTO_CALLS", "SUPPORTED_TOPLEVEL_PTO_CALLS", "SUPPORTED_VECSCOPE_PTO_CALLS", + "BASIC_TIER", + "ADVANCED_TIER", + "BASIC_TENSORVIEW_SURFACES", + "BASIC_TILE_SURFACES", + "BASIC_HIGH_LEVEL_DMA_SURFACES", + "BASIC_BASE_VECTOR_SURFACES", + "BASIC_TILE_INDEXING_SURFACES", + "ADVANCED_EXPLICIT_VECSCOPE_SURFACES", + "ADVANCED_RAW_POINTER_SURFACES", + "ADVANCED_LOW_LEVEL_DMA_SURFACES", + "ADVANCED_TILE_HELPER_SURFACES", + "AUTHORING_TIER_SURFACE_GROUPS", + "AUTHORING_TIER_GROUP_TIERS", + "UNSUPPORTED_LANGUAGE_CONSTRUCTS", + "LANGUAGE_CONSTRUCT_TIERS", "advanced_mode_message", "deferred_surface_message", "unsupported_feature_message", + "get_pto_call_tier", + "get_feature_tier", + "get_surface_group_tier", ] diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index ca94ba895..405ece12f 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """Public type markers for the TileLang DSL v1 surface.""" from __future__ import annotations @@ -19,6 +27,10 @@ class TensorView: """Bare TensorView annotation marker for TileLang DSL v1.""" +class PartitionTensorView: + """Bare PartitionTensorView annotation marker for TileLang DSL v1.""" + + class Tile: """Bare Tile annotation marker for TileLang DSL v1.""" @@ -32,6 +44,23 @@ def __repr__(self) -> str: return f"ptr({self.element_dtype!r}, {self.memory_space!r})" +@dataclass(frozen=True) +class VRegType: + element_dtype: ScalarType + lanes: int + + def __repr__(self) -> str: + return f"vreg({self.element_dtype!r})" + + +@dataclass(frozen=True) +class MaskType: + granularity: str + + def __repr__(self) -> str: + return f"mask_{self.granularity}" + + @dataclass(frozen=True) class WildcardType: name: str @@ -81,6 +110,20 @@ class MaskPattern(str, Enum): VL32 = "PAT_VL32" +class PadMode(str, Enum): + PadNull = "PadNull" + PadFirstElem = "PadFirstElem" + PadValue = "PadValue" + + +class PositionMode(str, Enum): + LOWEST = "POS_LOWEST" + + +class OrderMode(str, Enum): + ASC = "ORDER_ASC" + + @dataclass(frozen=True) class TileConfig: fields: tuple[tuple[str, Any], ...] = () @@ -95,6 +138,7 @@ class TileSpecialization: shape: tuple[int, ...] memory_space: MemorySpace config: TileConfig | None = None + valid_shape: tuple[int | None, ...] | None = None i8 = ScalarType("i8") @@ -112,6 +156,9 @@ class TileSpecialization: AnyInt = WildcardType("AnyInt") AnyType = WildcardType("AnyType") AnyMask = WildcardType("AnyMask") +mask_b8 = MaskType("b8") +mask_b16 = MaskType("b16") +mask_b32 = MaskType("b32") def TypeVar(name: str) -> TypeVariable: @@ -128,9 +175,15 @@ def ptr(dtype: ScalarType, memory_space: MemorySpace) -> PointerType: return PointerType(element_dtype=dtype, memory_space=memory_space) -def get_lanes(dtype: ScalarType) -> int: +def vreg(dtype: ScalarType) -> VRegType: + if not isinstance(dtype, ScalarType): + raise TypeError("vreg() expects a TileLang scalar dtype") + return VRegType(element_dtype=dtype, lanes=get_lanes(dtype)) + + +def bytewidth(dtype: ScalarType) -> int: if not isinstance(dtype, ScalarType): - raise TypeError("get_lanes expects a TileLang scalar dtype") + raise TypeError("bytewidth expects a TileLang scalar dtype") byte_widths = { "i8": 1, "i16": 2, @@ -141,8 +194,20 @@ def get_lanes(dtype: ScalarType) -> int: } width = byte_widths.get(dtype.name) if width is None: - raise TypeError(f"dtype `{dtype.name}` is not supported by get_lanes") - return 256 // width + raise TypeError(f"dtype `{dtype.name}` is not supported by bytewidth") + return width + + +def get_lanes(dtype: ScalarType) -> int: + return 256 // bytewidth(dtype) + + +def elements_per_vreg(dtype: ScalarType) -> int: + return get_lanes(dtype) + + +def constexpr(value: bool) -> bool: + return value __all__ = [ @@ -151,9 +216,13 @@ def get_lanes(dtype: ScalarType) -> int: "TypeVariable", "TypeVar", "TensorView", + "PartitionTensorView", "Tile", "PointerType", + "VRegType", + "MaskType", "ptr", + "vreg", "MemorySpace", "Pipe", "Event", @@ -161,6 +230,9 @@ def get_lanes(dtype: ScalarType) -> int: "EVENT", "MaskPattern", "PAT", + "PadMode", + "PositionMode", + "OrderMode", "TileConfig", "TileSpecialization", "i1", @@ -175,5 +247,11 @@ def get_lanes(dtype: ScalarType) -> int: "AnyInt", "AnyType", "AnyMask", + "mask_b8", + "mask_b16", + "mask_b32", + "constexpr", + "bytewidth", "get_lanes", + "elements_per_vreg", ] diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index f6c6de2ff..639641af5 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -6,14 +6,34 @@ import tilelang_dsl as pto import tilelang_dsl.kernel as kernel_impl -from tilelang_dsl.frontend_ast import build_frontend_kernel_node +from tilelang_dsl.support_matrix import ( + ADVANCED_EXPLICIT_VECSCOPE_SURFACES, + ADVANCED_LOW_LEVEL_DMA_SURFACES, + ADVANCED_RAW_POINTER_SURFACES, + ADVANCED_TILE_HELPER_SURFACES, + ADVANCED_TIER, + AUTHORING_TIER_SURFACE_GROUPS, + BASIC_TIER, + BASIC_TILE_INDEXING_SURFACES, + get_feature_tier, + get_surface_group_tier, +) +from tilelang_dsl.frontend_ast import ( + FrontendAssignStmt, + FrontendCallExpr, + FrontendExprStmt, + FrontendForStmt, + FrontendStrictVecscopeStmt, + FrontendVecscopeStmt, + build_frontend_kernel_node, +) from tilelang_dsl.lowering import AuthoringModule, lower_semantic_kernel from tilelang_dsl.semantic import ( SemanticAssignStmt, + SemanticBinaryExpr, SemanticCallExpr, SemanticDmaConfigStmt, - SemanticDmaLoadStmt, - SemanticDmaStoreStmt, + SemanticExprStmt, SemanticForStmt, SemanticIfStmt, SemanticIndexType, @@ -24,10 +44,12 @@ SemanticScalarType, SemanticSetFlagStmt, SemanticStrictVecscopeStmt, + SemanticSymbolExpr, SemanticTensorViewType, SemanticTileType, SemanticVecscopeStmt, SemanticVectorStoreStmt, + SemanticVRegType, SemanticWaitFlagStmt, analyze_frontend_kernel, ) @@ -43,11 +65,107 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "Tile")) self.assertTrue(hasattr(pto, "TileSpecialization")) self.assertTrue(hasattr(pto, "PointerType")) + self.assertTrue(hasattr(pto, "VRegType")) + self.assertTrue(hasattr(pto, "MaskType")) self.assertTrue(hasattr(pto, "ptr")) + self.assertTrue(hasattr(pto, "vreg")) + self.assertTrue(hasattr(pto, "mask_b8")) + self.assertTrue(hasattr(pto, "mask_b16")) + self.assertTrue(hasattr(pto, "mask_b32")) + self.assertTrue(hasattr(pto, "constexpr")) + self.assertTrue(hasattr(pto, "bytewidth")) self.assertTrue(hasattr(pto, "get_lanes")) + self.assertTrue(hasattr(pto, "elements_per_vreg")) self.assertTrue(hasattr(pto, "PAT")) + self.assertTrue(hasattr(pto, "PadMode")) + self.assertTrue(hasattr(pto, "PositionMode")) + self.assertTrue(hasattr(pto, "OrderMode")) self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) + self.assertEqual(pto.PadMode.PadNull.value, "PadNull") + self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") + self.assertEqual(pto.PadMode.PadValue.value, "PadValue") + self.assertEqual(pto.PositionMode.LOWEST.value, "POS_LOWEST") + self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") + + +class TileLangDSLSupportMatrixTests(unittest.TestCase): + def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: + self.assertEqual(get_surface_group_tier("TensorView"), BASIC_TIER) + self.assertEqual(get_surface_group_tier("Tile"), BASIC_TIER) + self.assertEqual(get_surface_group_tier("base_vector_ops"), BASIC_TIER) + self.assertEqual(get_surface_group_tier("tile_indexing_sugar"), BASIC_TIER) + + self.assertIn("TensorView", AUTHORING_TIER_SURFACE_GROUPS["TensorView"]) + self.assertIn("Tile", AUTHORING_TIER_SURFACE_GROUPS["Tile"]) + self.assertNotIn("dma_load/store", AUTHORING_TIER_SURFACE_GROUPS) + self.assertIn("pto.vlds", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vsts", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vadd", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vmuls", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("tile[start:]", BASIC_TILE_INDEXING_SURFACES) + self.assertIn("tile[row, col:]", BASIC_TILE_INDEXING_SURFACES) + + self.assertEqual(get_feature_tier("TensorView"), BASIC_TIER) + self.assertEqual(get_feature_tier("Tile"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vlds"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vsts"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vadd"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vmuls"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vaddrelu"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vaxpy"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vmull"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vands"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vbr"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vdup"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vci"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vpack"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vsort32"), BASIC_TIER) + self.assertEqual(get_feature_tier("PadMode"), BASIC_TIER) + self.assertEqual(get_feature_tier("VRegType"), BASIC_TIER) + self.assertEqual(get_feature_tier("MaskType"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vreg"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.mask_b8"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.mask_b16"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.mask_b32"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.bytewidth"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.get_lanes"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.elements_per_vreg"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.constexpr"), BASIC_TIER) + self.assertEqual(get_feature_tier("constexpr"), BASIC_TIER) + self.assertEqual(get_feature_tier("tile[start:]"), BASIC_TIER) + self.assertEqual(get_feature_tier("tile[row, col:]"), BASIC_TIER) + + def test_non_stable_surface_groups_keep_advanced_boundaries(self) -> None: + self.assertEqual(get_surface_group_tier("strict_vecscope"), ADVANCED_TIER) + self.assertEqual(get_surface_group_tier("raw_pointer_family"), ADVANCED_TIER) + self.assertEqual(get_surface_group_tier("low_level_dma_family"), ADVANCED_TIER) + self.assertEqual(get_surface_group_tier("tile_helper_family"), ADVANCED_TIER) + + self.assertIn("pto.strict_vecscope", ADVANCED_EXPLICIT_VECSCOPE_SURFACES) + self.assertIn("pto.ptr", ADVANCED_RAW_POINTER_SURFACES) + self.assertIn("pto.castptr", ADVANCED_RAW_POINTER_SURFACES) + self.assertIn("pto.copy_ubuf_to_ubuf", ADVANCED_LOW_LEVEL_DMA_SURFACES) + self.assertIn("pto.tile_with_strides", ADVANCED_TILE_HELPER_SURFACES) + + self.assertEqual(get_feature_tier("strict_vecscope"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.strict_vecscope"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.ptr"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.castptr"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.copy_ubuf_to_ubuf"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.tile_with_strides"), ADVANCED_TIER) + + def test_unsupported_features_do_not_report_legacy_tiers(self) -> None: + with self.assertRaises(KeyError): + get_surface_group_tier("dma_load/store") + with self.assertRaises(KeyError): + get_feature_tier("pto.dma_load") + with self.assertRaises(KeyError): + get_feature_tier("pto.dma_store") + with self.assertRaises(KeyError): + get_feature_tier("pto.dma_copy") + with self.assertRaises(KeyError): + get_feature_tier("pto.vreduce") class TileLangDSLMatcherEntryTests(unittest.TestCase): @@ -116,7 +234,55 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): specialized = selected.specialize( tile=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) ) - self.assertIn("memref<8x16xf32", specialized.mlir_text()) + self.assertIn( + "!pto.tile_buf None: + @pto.vkernel(op="matcher_default_dtypes_unique") + def kernel(inp: pto.Tile, out: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_default_dtypes_unique", + (pto.f16, pto.f16), + ) + + self.assertEqual(selected.dtype_signature, (pto.f16, pto.f16)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [("inp", "tile", pto.f16), ("out", "tile", pto.f16)], + ) + specialized = selected.specialize( + inp=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + out=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + self.assertIn( + "!pto.tile_buf None: + @pto.vkernel(op="matcher_default_dtypes_scalar_guard_unique") + def kernel(inp: pto.TensorView, scale: pto.i32): + return None + + selected = pto.select_kernel( + "a5", + "matcher_default_dtypes_scalar_guard_unique", + (pto.f32, pto.i32), + ) + self.assertEqual(selected.dtype_signature, (pto.f32, pto.i32)) + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_default_dtypes_scalar_guard_unique", + (pto.f32, pto.f16), + ) + self.assertIn("found no registered kernel", str(ctx.exception)) def test_select_kernel_matches_wildcards_deterministically(self) -> None: @pto.vkernel( @@ -162,6 +328,48 @@ def kernel(lhs: pto.TensorView, rhs: pto.Tile): ) self.assertIn("found no registered kernel", str(ctx.exception)) + def test_scalar_typevar_annotation_tracks_selected_dtype(self) -> None: + elem = pto.TypeVar("Elem") + + @pto.vkernel( + op="scalar_typevar_binding_unique", + dtypes=[(elem, elem, elem)], + ) + def kernel(inp: pto.Tile, scale: elem, out: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "scalar_typevar_binding_unique", + (pto.bf16, pto.bf16, pto.bf16), + ) + + self.assertEqual(selected.dtype_signature, (pto.bf16, pto.bf16, pto.bf16)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [("inp", "tile", pto.bf16), ("scale", "scalar", pto.bf16), ("out", "tile", pto.bf16)], + ) + + def test_scalar_wildcard_annotation_accepts_selected_dtype(self) -> None: + @pto.vkernel( + op="scalar_wildcard_binding_unique", + dtypes=[(pto.AnyType, pto.AnyType, pto.AnyType)], + ) + def kernel(inp: pto.Tile, scale: pto.AnyType, out: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "scalar_wildcard_binding_unique", + (pto.i16, pto.i16, pto.i16), + ) + + self.assertEqual(selected.dtype_signature, (pto.i16, pto.i16, pto.i16)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [("inp", "tile", pto.i16), ("scale", "scalar", pto.i16), ("out", "tile", pto.i16)], + ) + def test_polymorphic_descriptor_requires_select_kernel_before_materialization(self) -> None: @pto.vkernel( op="matcher_materialization_gate_unique", @@ -175,8 +383,8 @@ def kernel(inp: pto.TensorView, out: pto.TensorView): self.assertIn("requires pto.select_kernel(...)", str(ctx.exception)) def test_select_kernel_evaluates_constraints_before_priority(self) -> None: - def requires_large_batch(context_attrs): - return context_attrs.get("batch", 0) >= 1024 + def requires_large_batch(batch=0): + return batch >= 1024 @pto.vkernel( op="matcher_constraint_priority_unique", @@ -245,7 +453,7 @@ def test_select_kernel_reports_no_candidate_after_constraint_evaluation(self) -> @pto.vkernel( op="matcher_constraint_empty_unique", dtypes=[(pto.AnyFloat, pto.AnyFloat)], - constraints=[lambda context_attrs: context_attrs.get("enabled", False)], + constraints=[lambda enabled=False: enabled], priority=1, ) def kernel(inp: pto.TensorView, out: pto.TensorView): @@ -260,6 +468,214 @@ def kernel(inp: pto.TensorView, out: pto.TensorView): ) self.assertIn("after constraint evaluation", str(ctx.exception)) + def test_materialization_constraints_can_see_specializations_and_selected_context_attrs(self) -> None: + @pto.vkernel( + op="matcher_materialization_constraint_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda src: src.rank == 5, + lambda dst, expected_rows=None: dst.shape[0] == expected_rows, + lambda src, dst: dst.valid_shape[1] <= src.shape[4], + ], + ) + def kernel(src: pto.TensorView, dst: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_materialization_constraint_unique", + (pto.f32, pto.f32), + context_attrs={"expected_rows": 8, "src_shape": (2, 2, 1, 1, 16), "src_strides": (32, 16, 16, 16, 1)}, + ).specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB, valid_shape=(4, 16)), + ) + text = selected.mlir_text() + self.assertIn("!pto.tensor_view", text) + self.assertIn("!pto.tile_buf None: + @pto.vkernel( + op="matcher_parameter_style_constraints_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda src, dst: src.rank == 5, + lambda src: src.strides[4] == 1, + lambda src, dst: src.shape[0] <= dst.shape[0], + ], + ) + def kernel(src: pto.TensorView, dst: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_parameter_style_constraints_unique", + (pto.f32, pto.f32), + context_attrs={"src_shape": (4, 1, 1, 1, 16), "src_strides": (16, 16, 16, 16, 1)}, + ).specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + self.assertIn("!pto.tile_buf None: + @pto.vkernel( + ops=["matcher_multi_op_bind_add_unique", "matcher_multi_op_bind_sub_unique"], + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_multi_op_bind_sub_unique", + (pto.f32, pto.f32), + ) + + self.assertIs(selected.py_fn, kernel.py_fn) + self.assertEqual(selected.match_ops, ("matcher_multi_op_bind_add_unique", "matcher_multi_op_bind_sub_unique")) + self.assertEqual(selected.selected_op, "matcher_multi_op_bind_sub_unique") + self.assertEqual(selected.op, "matcher_multi_op_bind_sub_unique") + self.assertEqual(selected.dtype_signature, (pto.f32, pto.f32)) + + def test_select_kernel_hits_same_multi_op_descriptor_for_multiple_query_ops(self) -> None: + @pto.vkernel( + ops=[ + "matcher_multi_hit_add_unique", + "matcher_multi_hit_mul_unique", + "matcher_multi_hit_div_unique", + ], + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + add_selected = pto.select_kernel( + "a5", + "matcher_multi_hit_add_unique", + (pto.f32, pto.f32), + ) + mul_selected = pto.select_kernel( + "a5", + "matcher_multi_hit_mul_unique", + (pto.f32, pto.f32), + ) + + self.assertIs(add_selected.py_fn, kernel.py_fn) + self.assertIs(mul_selected.py_fn, kernel.py_fn) + self.assertEqual(add_selected.match_ops, kernel.match_ops) + self.assertEqual(mul_selected.match_ops, kernel.match_ops) + self.assertEqual(add_selected.selected_op, "matcher_multi_hit_add_unique") + self.assertEqual(mul_selected.selected_op, "matcher_multi_hit_mul_unique") + self.assertEqual(add_selected.op, "matcher_multi_hit_add_unique") + self.assertEqual(mul_selected.op, "matcher_multi_hit_mul_unique") + + def test_select_kernel_prefers_higher_priority_single_op_over_multi_op(self) -> None: + @pto.vkernel( + op="matcher_single_beats_multi_priority_unique", + dtypes=[(pto.f32, pto.f32)], + priority=12, + ) + def single(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + ops=[ + "matcher_single_beats_multi_priority_unique", + "matcher_single_beats_multi_priority_alt_unique", + ], + dtypes=[(pto.f32, pto.f32)], + priority=4, + ) + def multi(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_single_beats_multi_priority_unique", + (pto.f32, pto.f32), + ) + + self.assertIs(selected.py_fn, single.py_fn) + self.assertEqual(selected.selected_op, "matcher_single_beats_multi_priority_unique") + self.assertEqual(selected.priority, 12) + + def test_select_kernel_prefers_priority_over_single_op_specificity(self) -> None: + @pto.vkernel( + op="matcher_single_vs_multi_priority_unique", + dtypes=[(pto.f32, pto.f32)], + priority=5, + ) + def single(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + ops=["matcher_single_vs_multi_priority_unique", "matcher_single_vs_multi_priority_alt_unique"], + dtypes=[(pto.f32, pto.f32)], + priority=9, + ) + def multi(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_single_vs_multi_priority_unique", + (pto.f32, pto.f32), + ) + + self.assertIs(selected.py_fn, multi.py_fn) + self.assertEqual(selected.selected_op, "matcher_single_vs_multi_priority_unique") + self.assertEqual(selected.priority, 9) + + def test_select_kernel_raises_tie_error_when_single_and_multi_op_candidates_tie(self) -> None: + @pto.vkernel( + op="matcher_single_multi_tie_unique", + dtypes=[(pto.f32, pto.f32)], + priority=17, + ) + def single(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + ops=["matcher_single_multi_tie_unique", "matcher_single_multi_tie_alt_unique"], + dtypes=[(pto.f32, pto.f32)], + priority=17, + ) + def multi(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_single_multi_tie_unique", + (pto.f32, pto.f32), + ) + + self.assertIn("multiple highest-priority kernels", str(ctx.exception)) + self.assertIn("single(priority=17", str(ctx.exception)) + self.assertIn("multi(priority=17", str(ctx.exception)) + class TileLangDSLDescriptorTests(unittest.TestCase): def test_descriptor_metadata_and_parameter_binding(self) -> None: @@ -283,6 +699,146 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): self.assertEqual(kernel.parameters[1].element_dtype, pto.f16) self.assertIsNone(kernel.parameters[2].element_dtype) + def test_descriptor_accepts_multi_op_matcher_metadata(self) -> None: + @pto.vkernel(ops=["tadd", "tsub"], dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + self.assertEqual(kernel.match_ops, ("tadd", "tsub")) + self.assertIsNone(kernel.selected_op) + self.assertIsNone(kernel.metadata["op"]) + self.assertEqual(kernel.metadata["match_ops"], ("tadd", "tsub")) + self.assertIsNone(kernel.metadata["selected_op"]) + self.assertEqual(kernel.dtype_signature, (pto.f32, pto.f32)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in kernel.parameters], + [("inp", "tensorview", pto.f32), ("out", "tensorview", pto.f32)], + ) + with self.assertRaises(ValueError) as ctx: + _ = kernel.op + self.assertIn("bind a concrete op", str(ctx.exception)) + + def test_descriptor_defaults_dtypes_for_beginner_tile_kernels(self) -> None: + @pto.vkernel(op="default_dtypes_unique") + def kernel(inp: pto.Tile, out: pto.Tile): + return None + + self.assertEqual(kernel.match_ops, ("default_dtypes_unique",)) + self.assertEqual(kernel.dtypes, ((pto.AnyType, pto.AnyType),)) + self.assertEqual(kernel.metadata["dtypes"], ((pto.AnyType, pto.AnyType),)) + with self.assertRaises(ValueError) as ctx: + _ = kernel.dtype_signature + self.assertIn("choose a concrete dtype signature", str(ctx.exception)) + + def test_descriptor_defaults_scalar_typevar_to_anytype(self) -> None: + elem = pto.TypeVar("Elem") + + @pto.vkernel(op="default_scalar_typevar_unique") + def kernel(inp: pto.Tile, scale: elem, out: pto.Tile): + return None + + self.assertEqual(kernel.match_ops, ("default_scalar_typevar_unique",)) + self.assertEqual(kernel.dtypes, ((pto.AnyType, pto.AnyType, pto.AnyType),)) + self.assertEqual(kernel.metadata["dtypes"], ((pto.AnyType, pto.AnyType, pto.AnyType),)) + with self.assertRaises(ValueError) as ctx: + _ = kernel.dtype_signature + self.assertIn("choose a concrete dtype signature", str(ctx.exception)) + + def test_descriptor_defaults_scalar_wildcard_to_anytype(self) -> None: + @pto.vkernel(op="default_scalar_wildcard_unique") + def kernel(inp: pto.Tile, scale: pto.AnyType, out: pto.Tile): + return None + + self.assertEqual(kernel.match_ops, ("default_scalar_wildcard_unique",)) + self.assertEqual(kernel.dtypes, ((pto.AnyType, pto.AnyType, pto.AnyType),)) + self.assertEqual(kernel.metadata["dtypes"], ((pto.AnyType, pto.AnyType, pto.AnyType),)) + with self.assertRaises(ValueError) as ctx: + _ = kernel.dtype_signature + self.assertIn("choose a concrete dtype signature", str(ctx.exception)) + + def test_descriptor_accepts_templates_metadata(self) -> None: + @pto.vkernel( + ops=["tadd", "tsub", "tmul"], + dtypes=[(pto.f32, pto.f32)], + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + }, + "post": { + "tmul": "vrelu", + }, + }, + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + self.assertEqual( + kernel.templates, + { + "core": { + "tadd": "vadd", + "tsub": "vsub", + }, + "post": { + "tmul": "vrelu", + }, + }, + ) + self.assertEqual(kernel.metadata["templates"], kernel.templates) + + def test_descriptor_rejects_op_and_ops_together(self) -> None: + with self.assertRaises(ValueError) as ctx: + @pto.vkernel(op="tadd", ops=["tsub"], dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + return None + + self.assertIn("either op= or ops=", str(ctx.exception)) + + def test_descriptor_requires_one_of_op_or_ops(self) -> None: + with self.assertRaises(ValueError) as ctx: + @pto.vkernel(dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + return None + + self.assertIn("exactly one of op= or ops=", str(ctx.exception)) + + def test_descriptor_rejects_template_slot_with_non_string_name(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + ops=["tadd"], + dtypes=[(pto.f32,)], + templates={1: {"tadd": "vadd"}}, + ) + def kernel(inp: pto.TensorView): + return None + + self.assertIn("template slot names must be non-empty strings", str(ctx.exception)) + + def test_descriptor_rejects_template_op_outside_matcher_set(self) -> None: + with self.assertRaises(ValueError) as ctx: + @pto.vkernel( + ops=["tadd", "tsub"], + dtypes=[(pto.f32, pto.f32)], + templates={"core": {"tmul": "vmul"}}, + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + self.assertIn("outside descriptor matcher set", str(ctx.exception)) + + def test_descriptor_rejects_template_mapping_to_unknown_pto_op(self) -> None: + with self.assertRaises(ValueError) as ctx: + @pto.vkernel( + ops=["tadd"], + dtypes=[(pto.f32,)], + templates={"core": {"tadd": "vunknown"}}, + ) + def kernel(inp: pto.TensorView): + return None + + self.assertIn("maps to unsupported pto op", str(ctx.exception)) + def test_pointer_parameter_annotation_binds_as_ptr_kind(self) -> None: @pto.vkernel(op="ptr_surface", dtypes=[(pto.f32, pto.i64)], advanced=True) def kernel(src: pto.ptr(pto.f32, pto.MemorySpace.UB), addr: pto.i64): @@ -293,6 +849,32 @@ def kernel(src: pto.ptr(pto.f32, pto.MemorySpace.UB), addr: pto.i64): self.assertEqual(kernel.parameters[0].annotation, pto.ptr(pto.f32, pto.MemorySpace.UB)) self.assertEqual(kernel.parameters[0].element_dtype, pto.f32) + def test_vreg_type_constructor_exposes_inferred_lane_count(self) -> None: + vec_type = pto.vreg(pto.f32) + self.assertIsInstance(vec_type, pto.VRegType) + self.assertEqual(vec_type.element_dtype, pto.f32) + self.assertEqual(vec_type.lanes, 64) + self.assertEqual(repr(vec_type), "vreg(f32)") + + def test_mask_type_constants_expose_granularity(self) -> None: + self.assertIsInstance(pto.mask_b8, pto.MaskType) + self.assertIsInstance(pto.mask_b16, pto.MaskType) + self.assertIsInstance(pto.mask_b32, pto.MaskType) + self.assertEqual(pto.mask_b8.granularity, "b8") + self.assertEqual(pto.mask_b16.granularity, "b16") + self.assertEqual(pto.mask_b32.granularity, "b32") + self.assertEqual(repr(pto.mask_b32), "mask_b32") + + def test_mask_parameter_annotation_binds_as_mask_kind(self) -> None: + @pto.vkernel(op="mask_surface", dtypes=[(pto.mask_b32, pto.f32)], advanced=True) + def kernel(mask: pto.mask_b32, dst: pto.Tile): + return None + + self.assertEqual(kernel.parameters[0].kind, "mask") + self.assertEqual(kernel.parameters[0].dtype, pto.mask_b32) + self.assertEqual(kernel.parameters[0].annotation, pto.mask_b32) + self.assertIsNone(kernel.parameters[0].element_dtype) + def test_specialization_enables_materialization_apis(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16)]) def kernel(inp: pto.TensorView, tile: pto.Tile): @@ -311,7 +893,10 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): self.assertIn("// tilelang.target = a5", text) self.assertIn("// tilelang.specialize tile shape=(16, 32) memory_space=ub", text) self.assertIn('module attributes {pto.target_arch = "a5"} {', text) - self.assertIn("func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) {", text) + self.assertIn( + "func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance } {", + text, + ) module = specialized.mlir_module() self.assertEqual(type(module).__name__, "MaterializedMLIRModule") mocked_result = kernel_impl.VerificationResult( @@ -333,6 +918,57 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): specialized.emit(out) self.assertEqual(out.read_text(encoding="utf-8"), text) + def test_multi_op_descriptor_requires_select_kernel_before_materialization_apis(self) -> None: + @pto.vkernel( + ops=["multi_op_gate_add_unique", "multi_op_gate_sub_unique"], + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(ValueError) as text_ctx: + kernel.mlir_text() + self.assertIn("mlir_text() requires pto.select_kernel(...) to bind a concrete op", str(text_ctx.exception)) + + with self.assertRaises(ValueError) as module_ctx: + kernel.mlir_module() + self.assertIn( + "mlir_module() requires pto.select_kernel(...) to bind a concrete op", + str(module_ctx.exception), + ) + + with self.assertRaises(ValueError) as verify_ctx: + kernel.verify() + self.assertIn("verify() requires pto.select_kernel(...) to bind a concrete op", str(verify_ctx.exception)) + + with tempfile.TemporaryDirectory() as tmpdir: + out = Path(tmpdir) / "kernel.mlir" + with self.assertRaises(ValueError) as emit_ctx: + kernel.emit(out) + self.assertIn("emit() requires pto.select_kernel(...) to bind a concrete op", str(emit_ctx.exception)) + + def test_selected_multi_op_descriptor_can_materialize_normally(self) -> None: + @pto.vkernel( + ops=["multi_op_materialize_add_unique", "multi_op_materialize_sub_unique"], + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "multi_op_materialize_sub_unique", + (pto.f32, pto.f32), + ) + + text = selected.mlir_text() + self.assertIn("// tilelang.target = a5", text) + self.assertIn("// tilelang.op = multi_op_materialize_sub_unique", text) + self.assertIn( + 'func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tensor_view) attributes { pto.tilelang.instance } {', + text, + ) + def test_verify_reports_structured_unavailable_when_ptoas_is_missing(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16)]) def kernel(inp: pto.TensorView, tile: pto.Tile): @@ -381,20 +1017,368 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): self.assertEqual(authoring_module.render(), specialized.mlir_text()) self.assertIn("return", authoring_module.render()) - def test_semantic_pipeline_binds_parameter_loop_and_strict_vecscope_types(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)]) - def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): - rows = tile.shape[0] - step = rows - with pto.strict_vecscope(inp, tile, scale, 0, rows, step) as ( - vin, - vtmp, - factor, - lb, - ub, - vec_step, - ): - for lane in range(lb, ub, vec_step): + def test_descriptor_pipeline_ignores_kernel_docstring_expression(self) -> None: + @pto.vkernel(op="docstring_passthrough_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + """This docstring should be ignored as a no-op expression statement.""" + return None + + frontend_kernel = build_frontend_kernel_node(kernel) + self.assertEqual(len(frontend_kernel.body), 2) + self.assertIsInstance(frontend_kernel.body[0], FrontendExprStmt) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + self.assertEqual(len(semantic_kernel.body), 1) + + text = lower_semantic_kernel(semantic_kernel).render() + self.assertIn("// tilelang.op = docstring_passthrough_unique", text) + self.assertIn("func.func @kernel", text) + self.assertIn("return", text) + + def test_frontend_rejects_hidden_dma_load_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="dma_load_hidden", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + pto.dma_load(inp[0:16, 0:16], tile) + return None + + self.assertIn("unsupported op surface `pto.dma_load`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_frontend_rejects_hidden_dma_store_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="dma_store_hidden", dtypes=[(pto.f32, pto.f32)]) + def kernel(out: pto.TensorView, tile: pto.Tile): + pto.dma_store(tile, out[0:16, 0:16]) + return None + + self.assertIn("unsupported op surface `pto.dma_store`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_frontend_rejects_hidden_dma_copy_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="dma_copy_hidden", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + pto.dma_copy(src, dst) + return None + + self.assertIn("unsupported op surface `pto.dma_copy`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_frontend_rejects_keyword_arguments_on_public_surfaces(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="dma_kw_wrong_surface", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + pto.vlds(tile, offset=0) + return None + + self.assertIn("no public call surface currently accepts them", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_frontend_rewrites_template_slot_to_selected_real_op(self) -> None: + @pto.vkernel( + ops=["template_slot_add_unique", "template_slot_sub_unique"], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={ + "core": { + "template_slot_add_unique": "vadd", + "template_slot_sub_unique": "vsub", + } + }, + ) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as ( + out_tile, + lhs_tile, + rhs_tile, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + lhs = pto.vlds(lhs_tile, lane) + rhs = pto.vlds(rhs_tile, lane) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, out_tile, lane, mask) + return None + + add_selected = pto.select_kernel( + "a5", + "template_slot_add_unique", + (pto.f32, pto.f32, pto.f32), + ).specialize( + dst=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + ) + sub_selected = pto.select_kernel( + "a5", + "template_slot_sub_unique", + (pto.f32, pto.f32, pto.f32), + ).specialize( + dst=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + ) + + add_frontend = build_frontend_kernel_node(add_selected) + sub_frontend = build_frontend_kernel_node(sub_selected) + + add_vecscope = add_frontend.body[0] + sub_vecscope = sub_frontend.body[0] + self.assertIsInstance(add_vecscope, FrontendStrictVecscopeStmt) + self.assertIsInstance(sub_vecscope, FrontendStrictVecscopeStmt) + + add_loop = add_vecscope.body[0] + sub_loop = sub_vecscope.body[0] + self.assertIsInstance(add_loop, FrontendForStmt) + self.assertIsInstance(sub_loop, FrontendForStmt) + + add_out_assign = add_loop.body[3] + sub_out_assign = sub_loop.body[3] + self.assertIsInstance(add_out_assign, FrontendAssignStmt) + self.assertIsInstance(sub_out_assign, FrontendAssignStmt) + self.assertIsInstance(add_out_assign.value, FrontendCallExpr) + self.assertIsInstance(sub_out_assign.value, FrontendCallExpr) + self.assertEqual(add_out_assign.value.namespace, "pto") + self.assertEqual(sub_out_assign.value.namespace, "pto") + self.assertEqual(add_out_assign.value.name, "vadd") + self.assertEqual(sub_out_assign.value.name, "vsub") + + add_text = add_selected.mlir_text() + sub_text = sub_selected.mlir_text() + self.assertIn("pto.vadd", add_text) + self.assertNotIn("pto.vsub", add_text) + self.assertIn("pto.vsub", sub_text) + self.assertNotIn("pto.vadd", sub_text) + + def test_template_slot_shared_kernel_body_expands_for_four_ops(self) -> None: + @pto.vkernel( + ops=[ + "template_slot_tadd_unique", + "template_slot_tsub_unique", + "template_slot_tmul_unique", + "template_slot_tdiv_unique", + ], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={ + "core": { + "template_slot_tadd_unique": "vadd", + "template_slot_tsub_unique": "vsub", + "template_slot_tmul_unique": "vmul", + "template_slot_tdiv_unique": "vdiv", + } + }, + ) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as ( + out_tile, + lhs_tile, + rhs_tile, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + lhs = pto.vlds(lhs_tile, lane) + rhs = pto.vlds(rhs_tile, lane) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, out_tile, lane, mask) + return None + + isolated_registry = pto.KernelRegistry((kernel,)) + expected_ops = { + "template_slot_tadd_unique": "vadd", + "template_slot_tsub_unique": "vsub", + "template_slot_tmul_unique": "vmul", + "template_slot_tdiv_unique": "vdiv", + } + + for query_op, real_op in expected_ops.items(): + selected = pto.select_kernel( + "a5", + query_op, + (pto.f32, pto.f32, pto.f32), + registry=isolated_registry, + ).specialize( + dst=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + ) + + frontend_kernel = build_frontend_kernel_node(selected) + vecscope = frontend_kernel.body[0] + self.assertIsInstance(vecscope, FrontendStrictVecscopeStmt) + loop_stmt = vecscope.body[0] + self.assertIsInstance(loop_stmt, FrontendForStmt) + out_assign = loop_stmt.body[3] + self.assertIsInstance(out_assign, FrontendAssignStmt) + self.assertIsInstance(out_assign.value, FrontendCallExpr) + self.assertEqual(out_assign.value.name, real_op) + + text = selected.mlir_text() + self.assertIn(f"pto.{real_op}", text) + self.assertNotIn("pto.tpl(", text) + + def test_template_slot_rejects_non_literal_slot_name(self) -> None: + slot_name = "core" + + @pto.vkernel( + op="template_slot_non_literal_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={"core": {"template_slot_non_literal_unique": "vadd"}}, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as (out_tile, lhs_tile, rhs_tile, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = pto.tpl(slot_name, lhs_tile, rhs_tile, mask) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("pto.tpl() requires a non-empty string literal slot name", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_template_slot_rejects_unknown_slot_before_ir_generation(self) -> None: + @pto.vkernel( + op="template_slot_unknown_slot_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={"core": {"template_slot_unknown_slot_unique": "vadd"}}, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as (out_tile, lhs_tile, rhs_tile, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = pto.tpl("missing", lhs_tile, rhs_tile, mask) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("unknown template slot 'missing'", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_template_slot_rejects_missing_selected_op_mapping(self) -> None: + @pto.vkernel( + ops=["template_slot_missing_map_add_unique", "template_slot_missing_map_sub_unique"], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={"core": {"template_slot_missing_map_add_unique": "vadd"}}, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as (out_tile, lhs_tile, rhs_tile, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = pto.tpl("core", lhs_tile, rhs_tile, mask) + return None + + selected = pto.select_kernel( + "a5", + "template_slot_missing_map_sub_unique", + (pto.f32, pto.f32, pto.f32), + ) + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(selected) + + self.assertIn("template slot 'core' does not define an implementation for selected op", str(ctx.exception)) + self.assertIn("template_slot_missing_map_sub_unique", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_template_slot_requires_selected_op_before_expansion(self) -> None: + @pto.vkernel( + ops=["template_slot_unbound_add_unique", "template_slot_unbound_sub_unique"], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={ + "core": { + "template_slot_unbound_add_unique": "vadd", + "template_slot_unbound_sub_unique": "vsub", + } + }, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as (out_tile, lhs_tile, rhs_tile, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = pto.tpl("core", lhs_tile, rhs_tile, mask) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("pto.tpl() requires pto.select_kernel(...) to bind a concrete op before expansion", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_template_slot_respects_resolved_op_surface_rules(self) -> None: + @pto.vkernel( + op="template_slot_advanced_surface_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + templates={"cmp": {"template_slot_advanced_surface_unique": "vcmp"}}, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + out = pto.tpl("cmp", dst, src0, mask, "lt") + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("surface `pto.vcmp` requires advanced=True", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_callable_based_runtime_template_dispatch_remains_rejected(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel( + op="template_slot_callable_dispatch_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + table = {"core": pto.vadd} + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as ( + out_tile, + lhs_tile, + rhs_tile, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = table["core"](lhs_tile, rhs_tile, mask) + return None + + self.assertIn("unsupported call surface in TileLang DSL v1", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_semantic_pipeline_binds_parameter_loop_and_strict_vecscope_types(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + rows = tile.shape[0] + step = rows + with pto.strict_vecscope(inp, tile, scale, 0, rows, step) as ( + vin, + vtmp, + factor, + lb, + ub, + vec_step, + ): + for lane in range(lb, ub, vec_step): current = factor return None @@ -410,6 +1394,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): semantic_kernel = analyze_frontend_kernel(frontend_kernel) self.assertIsInstance(semantic_kernel.parameters[0].type, SemanticTensorViewType) + self.assertEqual(semantic_kernel.parameters[0].type.rank, 5) self.assertIsInstance(semantic_kernel.parameters[1].type, SemanticTileType) self.assertEqual(semantic_kernel.parameters[1].type.shape, (8, 16)) self.assertIsInstance(semantic_kernel.parameters[2].type, SemanticScalarType) @@ -443,50 +1428,102 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): text = specialized.mlir_text() self.assertIn("%rows_", text) self.assertIn("= arith.constant 8 : index", text) - self.assertIn("pto.strict_vecscope(%arg0, %arg1, %arg2, %c0, %rows_", text) + self.assertRegex( + text, + r"pto\.strict_vecscope\(%tmp_\d+, %tmp_\d+, %arg2, %c0, %rows_\d+, %rows_\d+\)", + ) self.assertIn("^bb0(", text) self.assertIn("scf.for %lane_", text) self.assertIn("to %ub_6 step %vec_step_7 {", text) - def test_dma_load_and_store_lower_to_dma_programming_and_copy_ops(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32)]) - def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile): - pto.dma_load(inp[0:16, 0:16], tile) - pto.dma_store(tile, out[0:16, 0:16]) + def test_tensorview_defaults_to_5d_shape_profile(self) -> None: + @pto.vkernel(op="tensorview_5d_shape_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + d0, d1, d2, d3, d4 = inp.valid_shape return None - specialized = kernel.specialize( - tile=pto.TileSpecialization( - shape=(16, 16), - memory_space=pto.MemorySpace.UB, - ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIsInstance(semantic_kernel.parameters[0].type, SemanticTensorViewType) + self.assertEqual(semantic_kernel.parameters[0].type.rank, 5) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("inp", "tensorview")], ) - semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - self.assertIsInstance(semantic_kernel.body[0], SemanticDmaLoadStmt) - self.assertIsInstance(semantic_kernel.body[1], SemanticDmaStoreStmt) - - text = specialized.mlir_text() + text = kernel.mlir_text() self.assertIn( - "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) {", + "func.func @kernel(%arg0: !pto.tensor_view) " + "attributes { pto.tilelang.instance } {", text, ) - self.assertIn("pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64", text) - self.assertIn( - "pto.copy_gm_to_ubuf %arg0, %arg2, %c0_i64, %c16_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64", - text, + self.assertEqual(text.count("pto.get_tensor_view_dim"), 5) + + def test_tensorview_strides_profile_lowers_through_explicit_stride_queries(self) -> None: + @pto.vkernel(op="tensorview_5d_stride_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + s0, s1, s2, s3, s4 = inp.strides + for lane in range(0, s4, 1): + current = lane + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("inp", "tensorview")], ) - self.assertIn("pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64", text) + + text = kernel.mlir_text() self.assertIn( - "pto.copy_ubuf_to_gm %arg2, %arg1, %c0_i64, %c16_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64", + "func.func @kernel(%arg0: !pto.tensor_view) " + "attributes { pto.tilelang.instance } {", text, ) + self.assertEqual(text.count("pto.get_tensor_view_stride"), 5) + self.assertRegex(text, r"scf\.for %lane_\d+ = %c0 to %s4_\d+ step %c1 \{") + + def test_tensorview_accepts_full_5d_slice_profile(self) -> None: + @pto.vkernel(op="tensorview_5d_slice_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + view = inp[0:1, 0:2, 0:3, 0:4, 0:5] + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertEqual(slice_assign.value.type.rank, 5) + self.assertEqual(slice_assign.value.type.extents, (1, 2, 3, 4, 5)) + self.assertEqual(slice_assign.value.type.physical_axes, (0, 1, 2, 3, 4)) + + def test_tensorview_3d_slice_profile_right_aligns_into_5d_descriptor(self) -> None: + @pto.vkernel(op="tensorview_3d_slice_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + view = inp[0:8, 0:16, 0:32] + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertEqual(slice_assign.value.type.rank, 3) + self.assertEqual(slice_assign.value.type.extents, (8, 16, 32)) + self.assertEqual(slice_assign.value.type.physical_axes, (2, 3, 4)) + + def test_tensorview_2d_slice_profile_right_aligns_into_5d_descriptor(self) -> None: + @pto.vkernel(op="tensorview_2d_slice_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + view = inp[0:16, 0:32] + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertEqual(slice_assign.value.type.rank, 2) + self.assertEqual(slice_assign.value.type.extents, (16, 32)) + self.assertEqual(slice_assign.value.type.physical_axes, (3, 4)) - def test_dynamic_tensorview_shape_profile_supports_runtime_bound_and_slice(self) -> None: + def test_dynamic_tensorview_shape_profile_supports_runtime_bound_without_high_level_dma(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)]) def kernel(inp: pto.TensorView, tile: pto.Tile): rows = inp.shape[0] - pto.dma_load(inp[0:rows, 0:16], tile) for lane in range(0, rows, 1): current = lane return None @@ -501,36 +1538,48 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) self.assertEqual( [(param.name, param.kind) for param in semantic_kernel.parameters], - [("inp", "tensorview"), ("tile", "tile"), ("__shape_inp_0", "tensorview_shape")], + [("inp", "tensorview"), ("tile", "tile")], ) rows_assign = semantic_kernel.body[0] self.assertIsInstance(rows_assign, SemanticAssignStmt) self.assertIsInstance(rows_assign.targets[0].type, SemanticIndexType) - dma_stmt = semantic_kernel.body[1] - self.assertIsInstance(dma_stmt, SemanticDmaLoadStmt) - self.assertEqual(dma_stmt.src.type.extents, (None, 16)) - - loop_stmt = semantic_kernel.body[2] + loop_stmt = semantic_kernel.body[1] self.assertIsInstance(loop_stmt, SemanticForStmt) text = specialized.mlir_text() self.assertIn( - "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: index) {", - text, - ) - self.assertIn( - "pto.copy_gm_to_ubuf %arg0, %arg1, %c0_i64, %c16_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64", + "func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance } {", text, ) self.assertIn("scf.for %lane_", text) - self.assertIn("to %arg2 step %c1 {", text) + self.assertIn("pto.get_tensor_view_dim", text) + + def test_semantic_recognizes_padmode_symbol(self) -> None: + @pto.vkernel(op="pad_mode_symbol", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + mode = pto.PadMode.PadFirstElem + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + assign_stmt = semantic_kernel.body[0] + self.assertIsInstance(assign_stmt, SemanticAssignStmt) + self.assertIsInstance(assign_stmt.value, SemanticSymbolExpr) + self.assertEqual(assign_stmt.value.value, pto.PadMode.PadFirstElem) + self.assertEqual(assign_stmt.value.type.kind, "pad_mode") + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32)]) - def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.f32): - pto.dma_load(inp[0:16, 0:16], tile) + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(tile: pto.Tile, scale: pto.f32): with pto.strict_vecscope(tile, tile, scale, 0, 256, 64) as ( src, dst, @@ -556,7 +1605,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.f32): ) semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - vecscope = semantic_kernel.body[1] + vecscope = semantic_kernel.body[0] self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) loop_stmt = vecscope.body[0] self.assertIsInstance(loop_stmt, SemanticForStmt) @@ -568,29 +1617,16 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.f32): self.assertIsInstance(loop_stmt.body[-1], SemanticVectorStoreStmt) text = specialized.mlir_text() - self.assertIn('%mask_7 = pto.pset_b32 "PAT_ALL" : !pto.mask', text) - self.assertIn("%vec_8 = pto.vlds %src_0[%lane_6] : !pto.ptr -> !pto.vreg<64xf32>", text) - self.assertIn( - "%biased_9 = pto.vadds %vec_8, %factor_2, %mask_7 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32>", - text, - ) - self.assertIn( - "%summed_10 = pto.vadd %biased_9, %vec_8, %mask_7 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>", - text, - ) - self.assertIn( - "%activated_11 = pto.vrelu %summed_10, %mask_7 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>", - text, - ) - self.assertIn( - "pto.vsts %activated_11, %dst_1[%lane_6], %mask_7 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask", - text, - ) + self.assertRegex(text, r'%mask_\d+ = pto\.pset_b32 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r"%vec_\d+ = pto\.vlds %src_\d+\[%lane_\d+\] : !pto\.ptr -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"%biased_\d+ = pto\.vadds %vec_\d+, %factor_\d+, %mask_\d+ : !pto\.vreg<64xf32>, f32, !pto\.mask -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"%summed_\d+ = pto\.vadd %biased_\d+, %vec_\d+, %mask_\d+ : !pto\.vreg<64xf32>, !pto\.vreg<64xf32>, !pto\.mask -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"%activated_\d+ = pto\.vrelu %summed_\d+, %mask_\d+ : !pto\.vreg<64xf32>, !pto\.mask -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"pto\.vsts %activated_\d+, %dst_\d+\[%lane_\d+\], %mask_\d+ : !pto\.vreg<64xf32>, !pto\.ptr, !pto\.mask") def test_tail_make_mask_lowers_to_typed_plt_and_updates_remaining(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)]) - def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): - pto.dma_load(inp[0:16, 0:16], tile) + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.i32)], advanced=True) + def kernel(tile: pto.Tile, remaining: pto.i32): with pto.strict_vecscope(tile, tile, remaining, 0, 64, 64) as (src, dst, rem_in, lb, ub, step): mask, next_remaining = pto.make_mask(pto.f32, rem_in) vec = pto.vlds(src, lb) @@ -605,7 +1641,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): ) semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - vecscope = semantic_kernel.body[1] + vecscope = semantic_kernel.body[0] self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) mask_assign = vecscope.body[0] self.assertIsInstance(mask_assign, SemanticAssignStmt) @@ -626,21 +1662,20 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): ) def test_nested_index_arithmetic_lowers_before_vector_accesses(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)]) + @pto.vkernel( + op="eltwise", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + ) def kernel( - lhs_gm: pto.TensorView, - rhs_gm: pto.TensorView, - out_gm: pto.TensorView, lhs_tile: pto.Tile, rhs_tile: pto.Tile, dst_tile: pto.Tile, ): - rows = lhs_gm.shape[0] - cols = lhs_gm.shape[1] + rows = lhs_tile.shape[0] + cols = lhs_tile.shape[1] row_stride = lhs_tile.shape[1] - pto.dma_load(lhs_gm[0:rows, 0:cols], lhs_tile) - pto.dma_load(rhs_gm[0:rows, 0:cols], rhs_tile) with pto.strict_vecscope( lhs_tile, rhs_tile, @@ -658,7 +1693,6 @@ def kernel( mask, next_remaining = pto.make_mask(pto.f32, valid_cols - lane) summed = pto.vadd(pto.vlds(lhs, offset), pto.vlds(rhs, offset), mask) pto.vsts(summed, dst, offset, mask) - pto.dma_store(dst_tile, out_gm[0:rows, 0:cols]) return None specialized = kernel.specialize( @@ -675,6 +1709,37 @@ def kernel( self.assertIn("pto.plt_b32", text) self.assertIn("pto.vadd", text) + def test_stable_mode_infers_vecscope_and_lowers_tile_vector_sugar(self) -> None: + @pto.vkernel(op="tadd_stable", dtypes=[(pto.f32, pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + all_mask = pto.make_mask(dtype, pto.PAT.ALL) + for row in range(0, rows, 1): + for col in range(0, cols, pto.get_lanes(dtype)): + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, all_mask) + pto.vsts(summed, dst[row, col:], all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertIn("pto.vecscope {", text) + self.assertNotIn("pto.strict_vecscope(", text) + self.assertRegex(text, r"memref\.subview %tmp_\d+\[%row_\d+, %col_\d+\] \[%c1, %tmp_\d+\] \[%c1, %c1\]") + self.assertRegex(text, r"pto\.vlds %tmp_\d+\[%c0\]") + self.assertRegex(text, r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %(?:all_mask|mask)_\d+") + def test_advanced_mode_infers_vecscope_and_lowers_tile_vector_sugar(self) -> None: @pto.vkernel(op="tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): @@ -714,13 +1779,26 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): self.assertNotIn("pto.strict_vecscope(", text) self.assertRegex(text, r"pto\.vecscope \{\n(?:.|\n)*scf\.for %row_") self.assertEqual(text.count("pto.vecscope {"), 1) - self.assertLess(text.index("%rows_1 = arith.constant 8 : index"), text.index("pto.vecscope {")) - self.assertLess(text.index("%cols_2 = arith.constant 64 : index"), text.index("pto.vecscope {")) - self.assertRegex(text, r"%tmp_\d+ = arith\.muli %row_\d+, %c64 : index") - self.assertRegex(text, r"%tmp_\d+ = arith\.addi %tmp_\d+, %col_\d+ : index") - self.assertIn("pto.vlds %arg1[", text) - self.assertIn("pto.vlds %arg2[", text) - self.assertIn("pto.vsts %summed_", text) + self.assertIn("!pto.tile_buf> to memref<\?x\?xf32, strided<\[\?, \?\], offset: \?>, #pto\.address_space>") + self.assertRegex(text, r"pto\.vlds %tmp_\d+\[%c0\] : memref<\?x\?xf32, strided<\[\?, \?\], offset: \?>, #pto\.address_space> -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %(?:all_mask|mask)_\d+ : !pto\.vreg<64xf32>, memref<\?x\?xf32, strided<\[\?, \?\], offset: \?>, #pto\.address_space>, !pto\.mask") + self.assertNotRegex(text, r"arith\.muli %row_\d+, %c64 : index") + self.assertNotRegex(text, r"arith\.addi %tmp_\d+, %col_\d+ : index") + self.assertLess(text.index("pto.tile_buf_addr %arg1"), text.index("pto.vecscope {")) + self.assertLess(text.index("pto.tile_buf_addr %arg2"), text.index("pto.vecscope {")) + self.assertLess(text.index("pto.tile_buf_addr %arg0"), text.index("pto.vecscope {")) + self.assertLess(text.index("pto.tile_valid_rows %arg0"), text.index("pto.vecscope {")) + self.assertLess(text.index("pto.tile_valid_cols %arg0"), text.index("pto.vecscope {")) + self.assertLess(text.index("pto.vecscope {"), text.index("scf.for %row_")) + self.assertLess(text.rindex("pto.vecscope {"), text.index("return")) def test_element_type_valid_shape_and_get_lanes_surface_lower_in_advanced_mode(self) -> None: @pto.vkernel(op="tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) @@ -746,33 +1824,875 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): self.assertRegex(text, r"%mask_\d+, %remained_\d+ = pto\.plt_b32 %remained_iter_\d+ : i32 -> !pto\.mask, i32") self.assertIn("pto.vadd", text) self.assertIn("pto.vsts", text) - - def test_advanced_mode_scalar_boundary_cuts_inferred_vecscope(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) - def kernel(src: pto.Tile, dst: pto.Tile): - dtype = src.element_type - first_mask = pto.make_mask(dtype, pto.PAT.ALL) - first = pto.vlds(src[0, 0:]) - pto.vsts(first, dst[0, 0:], first_mask) - boundary = 1 - second_mask = pto.make_mask(dtype, pto.PAT.ALL) - second = pto.vlds(src[1, 0:]) - pto.vsts(second, dst[1, 0:], second_mask) + self.assertIn("pto.tile_valid_rows %arg0", text) + self.assertIn("pto.tile_valid_cols %arg0", text) + self.assertRegex(text, r"memref\.subview %tmp_\d+\[%row_\d+, %col_\d+\] \[%c1, %tmp_\d+\] \[%c1, %c1\]") + self.assertRegex(text, r"pto\.vlds %tmp_\d+\[%c0\]") + self.assertRegex(text, r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %mask_\d+") + + def test_bytewidth_surface_lowers_to_constant_index(self) -> None: + @pto.vkernel(op="bytewidth_query_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + elem_bytes = pto.bytewidth(dst.element_type) + rows, cols = dst.valid_shape + for col in range(0, cols, elem_bytes): + current = col return None specialized = kernel.specialize( - src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), ) text = specialized.mlir_text() - self.assertEqual(text.count("pto.vecscope {"), 2) - self.assertLess(text.index("%boundary_"), text.rindex("pto.vecscope {")) + self.assertIn("= arith.constant 4 : index", text) + self.assertRegex(text, r"scf\.for %col_\d+ = %c0 to %cols_\d+ step %elem_bytes_\d+") + self.assertIn("pto.tile_valid_cols %arg0", text) + + def test_elements_per_vreg_alias_surface_lowers_to_constant_index(self) -> None: + @pto.vkernel(op="elements_per_vreg_query_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + lanes = pto.elements_per_vreg(dst.element_type) + rows, cols = dst.valid_shape + for col in range(0, cols, lanes): + current = col + return None - def test_advanced_mode_control_flow_boundary_cuts_inferred_vecscope(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) - def kernel(src: pto.Tile, dst: pto.Tile, flag: pto.i32): - dtype = src.element_type + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("= arith.constant 64 : index", text) + self.assertRegex(text, r"scf\.for %col_\d+ = %c0 to %cols_\d+ step %lanes_\d+") + self.assertIn("pto.tile_valid_cols %arg0", text) + + def test_vreg_type_constructor_and_annotation_match_vector_value(self) -> None: + @pto.vkernel(op="vreg_type_annotation_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + dtype = dst.element_type + vec_ty = pto.vreg(dtype) + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec: pto.vreg(dtype) = pto.vlds(dst, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = next( + stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt) + ) + self.assertIsInstance(vecscope, SemanticVecscopeStmt) + vec_assign = next( + stmt + for stmt in vecscope.body + if isinstance(stmt, SemanticAssignStmt) + and stmt.targets[0].name == "vec" + ) + self.assertIsInstance(vec_assign.targets[0].type, SemanticVRegType) + self.assertEqual(vec_assign.targets[0].type.element_dtype, pto.f32) + self.assertEqual(vec_assign.targets[0].type.lanes, 64) + self.assertTrue( + any( + isinstance(stmt, SemanticAssignStmt) + and stmt.targets[0].name == "vec_ty" + for stmt in semantic_kernel.body + ) + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"%vec_\d+ = pto\.vlds %tmp_\d+\[%c0\] : memref<8x64xf32, #pto\.address_space> -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"pto\.vsts %vec_\d+, %tmp_\d+\[%c0\], %mask_\d+ : !pto\.vreg<64xf32>, memref<8x64xf32, #pto\.address_space>, !pto\.mask") + + def test_mask_type_annotation_matches_make_mask_result(self) -> None: + @pto.vkernel(op="mask_type_annotation_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + mask_ty = pto.mask_b32 + mask: pto.mask_b32 = pto.make_mask(pto.f32, pto.PAT.ALL) + alias_mask: mask_ty = mask + vec: pto.vreg(pto.f32) = pto.vlds(dst, 0) + pto.vsts(vec, dst, 0, alias_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r'%mask_\d+ = pto\.pset_b32 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r"pto\.vsts %vec_\d+, %tmp_\d+\[%c0\], %\w+ : !pto\.vreg<64xf32>, memref<8x64xf32, #pto\.address_space>, !pto\.mask") + + def test_extended_float_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="extended_float_vector_ops_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + vec2 = pto.vlds(src, 128) + vec3 = pto.vlds(src, 192) + + out = pto.vln(vec0, all_mask) + out = pto.vsqrt(out, all_mask) + out = pto.vrec(out, all_mask) + out = pto.vrsqrt(out, all_mask) + out = pto.vexpdiff(out, all_mask) + out = pto.vcadd(out, all_mask) + out = pto.vcmax(out, all_mask) + out = pto.vcmin(out, all_mask) + out = pto.vmov(out, all_mask) + out = pto.vtrc(out, all_mask) + out = pto.vbitsort(out, all_mask) + out = pto.vprelu(out, vec1, all_mask) + out = pto.vlrelu(out, alpha, all_mask) + out = pto.vcvt(out, pto.f32, all_mask) + out = pto.vmrgsort4(out, vec1, vec2, vec3, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vln", text) + self.assertIn("pto.vsqrt", text) + self.assertIn("pto.vrec", text) + self.assertIn("pto.vrsqrt", text) + self.assertIn("pto.vexpdiff", text) + self.assertIn("pto.vcadd", text) + self.assertIn("pto.vcmax", text) + self.assertIn("pto.vcmin", text) + self.assertIn("pto.vmov", text) + self.assertIn("pto.vtrc", text) + self.assertIn("pto.vbitsort", text) + self.assertIn("pto.vprelu", text) + self.assertIn("pto.vlrelu", text) + self.assertIn("pto.vcvt", text) + self.assertIn("pto.vmrgsort4", text) + + def test_extended_integer_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="extended_integer_vector_ops_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, shift: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + + out = pto.vbcnt(vec0, all_mask) + out = pto.vneg(out, all_mask) + out = pto.vcls(out, all_mask) + out = pto.vsunpack(out, all_mask) + out = pto.vzunpack(out, all_mask) + out = pto.vusqz(out, all_mask) + out = pto.vsqz(out, all_mask) + out = pto.vshl(out, vec1, all_mask) + out = pto.vshr(out, vec1, all_mask) + out = pto.vshls(out, shift, all_mask) + out = pto.vshrs(out, shift, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbcnt", text) + self.assertIn("pto.vneg", text) + self.assertIn("pto.vcls", text) + self.assertIn("pto.vsunpack", text) + self.assertIn("pto.vzunpack", text) + self.assertIn("pto.vusqz", text) + self.assertIn("pto.vsqz", text) + self.assertIn("pto.vshl", text) + self.assertIn("pto.vshr", text) + self.assertIn("pto.vshls", text) + self.assertIn("pto.vshrs", text) + + def test_fused_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="fused_vector_ops_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + vec2 = pto.vlds(src, 128) + vec3 = pto.vlds(src, 192) + + out = pto.vaddrelu(vec0, vec1, all_mask) + out = pto.vaddreluconv(out, vec2, all_mask) + out = pto.vsubrelu(out, vec3, all_mask) + out = pto.vmulconv(out, vec1, all_mask) + out = pto.vaxpy(vec1, out, vec2, all_mask) + out = pto.vmula(vec1, vec2, out, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vaddrelu", text) + self.assertIn("pto.vaddreluconv", text) + self.assertIn("pto.vsubrelu", text) + self.assertIn("pto.vmulconv", text) + self.assertIn("pto.vaxpy", text) + self.assertIn("pto.vmula", text) + + def test_vmull_and_vector_scalar_bitwise_surface_lowers(self) -> None: + @pto.vkernel( + op="vmull_and_scalar_bitwise_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, scalar: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + + low, high = pto.vmull(vec0, vec1, all_mask) + out = pto.vadd(low, high, all_mask) + out = pto.vands(out, scalar, all_mask) + out = pto.vors(out, scalar, all_mask) + out = pto.vxors(out, scalar, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vmull", text) + self.assertIn("pto.vands", text) + self.assertIn("pto.vors", text) + self.assertIn("pto.vxors", text) + + def test_broadcast_and_index_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="broadcast_and_index_vector_ops_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + + broadcast = pto.vbr(seed) + dup_from_vec = pto.vdup(vec0) + dup_from_scalar = pto.vdup(seed, pto.PositionMode.LOWEST) + idx0 = pto.vci(seed) + idx1 = pto.vci(seed, pto.OrderMode.ASC) + + out = pto.vadd(broadcast, dup_from_vec, all_mask) + out = pto.vadd(out, dup_from_scalar, all_mask) + out = pto.vadd(out, idx0, all_mask) + out = pto.vadd(out, idx1, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbr", text) + self.assertIn("pto.vdup", text) + self.assertIn("pto.vci", text) + self.assertRegex( + text, + r'pto\.vdup\s+%[^\s]+\s+\{position = "POS_LOWEST"\}\s+:', + ) + self.assertNotRegex( + text, + r'pto\.vdup\s+%[^\s]+,\s*"POS_LOWEST"\s+:', + ) + self.assertRegex( + text, + r'pto\.vci\s+%[^\s]+\s+\{order = "ORDER_ASC"\}\s+:', + ) + self.assertNotRegex( + text, + r'pto\.vci\s+%[^\s]+,\s*"ORDER_ASC"\s+:', + ) + + def test_vbr_accepts_float_literal_constant(self) -> None: + @pto.vkernel( + op="broadcast_float_literal_constant_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + bias = pto.vbr(0.0) + out = pto.vadd(vec0, bias, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("= arith.constant 0.0 : f32", text) + self.assertIn("pto.vbr", text) + + def test_scalar_constructor_call_surfaces_lower(self) -> None: + @pto.vkernel( + op="scalar_constructor_call_surfaces_unique", + dtypes=[(pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + base = pto.i32(1) + idx = pto.i16(base) + idx = pto.i8(idx) + idx = pto.i64(idx) + flt = pto.f16(idx) + flt = pto.bf16(flt) + flt = pto.f32(flt) + gate = pto.i1(flt) + scalar = pto.i32(gate) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, scalar, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("arith.trunci", text) + self.assertIn("arith.extsi", text) + self.assertIn("arith.sitofp", text) + self.assertIn("arith.fptosi", text) + self.assertIn("arith.extf", text) + self.assertIn("arith.truncf", text) + + def test_scalar_constructor_accepts_signed_float_literals(self) -> None: + @pto.vkernel(op="scalar_constructor_signed_float_literals_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + a = pto.f16(-1.5) + b = pto.bf16(+2.5) + c = pto.f32(-3.5) + return None + + text = kernel.mlir_text() + self.assertIn("= arith.constant -1.5 : f16", text) + self.assertIn("= arith.constant 2.5 : bf16", text) + self.assertIn("= arith.constant -3.5 : f32", text) + + def test_scalar_constructor_accepts_special_float_string_literals(self) -> None: + @pto.vkernel(op="scalar_constructor_special_float_literals_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + a = pto.f16("-inf") + b = pto.bf16("inf") + c = pto.f32("nan") + d = pto.f16("0xFC00") + e = pto.bf16("0xFF80") + f = pto.f32("0xFF800000") + return None + + text = kernel.mlir_text() + self.assertIn("= arith.constant -inf : f16", text) + self.assertIn("= arith.constant inf : bf16", text) + self.assertIn("= arith.constant nan : f32", text) + self.assertIn("= arith.constant -inf : bf16", text) + self.assertIn("= arith.constant -inf : f32", text) + + def test_scalar_constructor_rejects_bad_arity(self) -> None: + @pto.vkernel(op="scalar_constructor_bad_arity_no_arg_unique", dtypes=[(pto.f32,)]) + def kernel_no_arg(inp: pto.TensorView): + x = pto.i32() + return None + + with self.assertRaises(TypeError) as no_arg_ctx: + kernel_no_arg.mlir_text() + + self.assertIn("pto.i32 expects exactly 1 positional argument", str(no_arg_ctx.exception)) + + @pto.vkernel(op="scalar_constructor_bad_arity_two_arg_unique", dtypes=[(pto.f32,)]) + def kernel_two_arg(inp: pto.TensorView): + x = pto.f32(1.0, 2.0) + return None + + with self.assertRaises(TypeError) as two_arg_ctx: + kernel_two_arg.mlir_text() + + self.assertIn("pto.f32 expects exactly 1 positional argument", str(two_arg_ctx.exception)) + + def test_scalar_constructor_rejects_non_scalar_operand(self) -> None: + @pto.vkernel(op="scalar_constructor_bad_operand_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i32(inp) + return None + + with self.assertRaises(TypeError) as ctx: + kernel.mlir_text() + + self.assertIn("pto.i32 value must be a scalar or index value", str(ctx.exception)) + + def test_scalar_constructor_rejects_out_of_range_integer_literal(self) -> None: + @pto.vkernel(op="scalar_constructor_oob_int_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i8(1024) + return None + + with self.assertRaises(TypeError) as ctx: + kernel.mlir_text() + + self.assertIn("out of range for i8", str(ctx.exception)) + + def test_inferred_vecscope_propagates_bindings_to_constexpr_if(self) -> None: + @pto.vkernel( + op="inferred_vecscope_binding_propagation_unique", + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + acc = pto.vbr(0.0) + vec = pto.vlds(src, 0) + acc = pto.vadd(acc, vec, mask) + if pto.constexpr(True): + pto.vsts(acc, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vadd", text) + self.assertIn("pto.vsts", text) + self.assertIn("= arith.constant 0.0 : f32", text) + + def test_loop_lowering_supports_multiple_loop_carried_bindings(self) -> None: + @pto.vkernel( + op="loop_multi_carried_bindings_unique", + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + remained = 64 + acc = pto.vbr(0.0) + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + for col in range(0, 64, 64): + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src, col) + acc = pto.vadd(acc, vec, mask) + pto.vsts(acc, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"%remained_\d+, %acc_\d+ = scf\.for") + self.assertRegex(text, r"iter_args\(%remained_iter_\d+_0 = [^,]+, %acc_iter_\d+_1 = [^)]+\)") + self.assertRegex(text, r"scf\.yield %remained_\d+, %acc_\d+ : i32, !pto\.vreg<64xf32>") + + def test_reduction_and_rearrangement_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="reduction_and_rearrangement_vector_ops_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, shift: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + indices = pto.vci(shift, pto.OrderMode.ASC) + + out = pto.vcgadd(vec0, all_mask) + out = pto.vcgmax(out, all_mask) + out = pto.vcgmin(out, all_mask) + out = pto.vcpadd(out, all_mask) + out = pto.vpack(out, vec1, all_mask) + out = pto.vperm(out, indices, all_mask) + out = pto.vshift(out, shift, all_mask) + out = pto.vslide(out, shift, all_mask) + out = pto.vsort32(out, all_mask) + out = pto.vmrgsort(out, vec1, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcgadd", text) + self.assertIn("pto.vcgmax", text) + self.assertIn("pto.vcgmin", text) + self.assertIn("pto.vcpadd", text) + self.assertIn("pto.vpack", text) + self.assertIn("pto.vperm", text) + self.assertIn("pto.vshift", text) + self.assertIn("pto.vslide", text) + self.assertIn("pto.vsort32", text) + self.assertIn("pto.vmrgsort", text) + + def test_scalar_loop_prologue_does_not_force_vecscope_into_inner_loop(self) -> None: + @pto.vkernel(op="tadd_outer_scope_unique", dtypes=[(pto.f32, pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + outer_loop = vecscope_stmts[0].body[0] + self.assertIsInstance(outer_loop, SemanticForStmt) + self.assertIsInstance(outer_loop.body[0], SemanticAssignStmt) + self.assertIsInstance(outer_loop.body[1], SemanticForStmt) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertRegex(text, r"pto\.vecscope \{\n\s+scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") + self.assertNotRegex(text, r"scf\.for %row_\d+ = [^\n]+\{\n\s+pto\.vecscope \{") + + def test_unused_tile_does_not_hoist_tile_buf_addr_or_valid_shape_intrinsics(self) -> None: + @pto.vkernel(op="tile_usage_scan_unique", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) + def kernel(dst: pto.Tile, src: pto.Tile, scratch: pto.Tile): + rows, cols = dst.valid_shape + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + for row in range(0, rows, 1): + for col in range(0, cols, pto.get_lanes(dst.element_type)): + value = pto.vlds(src[row, col:]) + pto.vsts(value, dst[row, col:], mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + scratch=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.tile_buf_addr %arg0", text) + self.assertIn("pto.tile_buf_addr %arg1", text) + self.assertNotIn("pto.tile_buf_addr %arg2", text) + self.assertIn("pto.tile_valid_rows %arg0", text) + self.assertIn("pto.tile_valid_cols %arg0", text) + self.assertNotIn("pto.tile_valid_rows %arg1", text) + self.assertNotIn("pto.tile_valid_cols %arg1", text) + self.assertNotIn("pto.tile_valid_rows %arg2", text) + self.assertNotIn("pto.tile_valid_cols %arg2", text) + + def test_tile_dynamic_valid_shape_profile_lowers_to_runtime_bounds_in_advanced_mode(self) -> None: + elem = pto.TypeVar("Elem") + + @pto.vkernel(op="tadd_dynamic_valid_shape_unique", dtypes=[(elem, elem, elem)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + remained = valid_cols + for row in range(0, valid_rows, 1): + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + summed = pto.vadd(pto.vlds(src0[row, col:]), pto.vlds(src1[row, col:]), mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + selected = pto.select_kernel( + "a5", + "tadd_dynamic_valid_shape_unique", + (pto.f16, pto.f16, pto.f16), + ) + specialized = selected.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + src0=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [ + ("dst", "tile"), + ("src0", "tile"), + ("src1", "tile"), + ("__valid_shape_dst_0", "tile_valid_shape"), + ("__valid_shape_dst_1", "tile_valid_shape"), + ], + ) + self.assertEqual(semantic_kernel.tile_bindings[0].valid_shape, (None, None)) + + text = specialized.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.tile_buf, %arg1: !pto.tile_buf, %arg2: !pto.tile_buf) attributes { pto.tilelang.instance } {", + text, + ) + self.assertIn("valid_shape=(?, ?)", text) + self.assertIn("pto.vecscope {", text) + self.assertIn("step %c128", text) + self.assertIn("pto.tile_valid_rows %arg0", text) + self.assertIn("pto.tile_valid_cols %arg0", text) + self.assertNotIn("pto.tile_valid_rows %arg1", text) + self.assertNotIn("pto.tile_valid_cols %arg1", text) + self.assertNotIn("pto.tile_valid_rows %arg2", text) + self.assertNotIn("pto.tile_valid_cols %arg2", text) + self.assertLess(text.index("pto.tile_valid_rows %arg0"), text.index("pto.vecscope {")) + self.assertLess(text.index("pto.tile_valid_cols %arg0"), text.index("pto.vecscope {")) + self.assertRegex(text, r"scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") + self.assertRegex(text, r"scf\.for %col_\d+ = %c0 to %valid_cols_\d+ step %c128") + self.assertRegex(text, r"%tmp_\d+ = arith\.index_cast %valid_cols_\d+ : index to i32") + self.assertRegex( + text, + r"pto\.tile_buf_addr %arg1 : !pto\.tile_buf> to memref<\?x\?xf16, strided<\[\?, \?\], offset: \?>, #pto\.address_space>", + ) + self.assertRegex( + text, + r"pto\.vlds %tmp_\d+\[%c0\] : memref<\?x\?xf16, strided<\[\?, \?\], offset: \?>, #pto\.address_space> -> !pto\.vreg<128xf16>", + ) + self.assertRegex( + text, + r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %mask_\d+ : !pto\.vreg<128xf16>, memref<\?x\?xf16, strided<\[\?, \?\], offset: \?>, #pto\.address_space>, !pto\.mask", + ) + + def test_tile_partial_dynamic_valid_shape_profile_tracks_dynamic_axes_only(self) -> None: + elem = pto.TypeVar("Elem") + + @pto.vkernel(op="tadd_partial_dynamic_valid_shape_unique", dtypes=[(elem, elem, elem)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + remained = valid_cols + for row in range(0, valid_rows, 1): + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + summed = pto.vadd(pto.vlds(src0[row, col:]), pto.vlds(src1[row, col:]), mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + selected = pto.select_kernel( + "a5", + "tadd_partial_dynamic_valid_shape_unique", + (pto.f16, pto.f16, pto.f16), + ) + + rows_dynamic = selected.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", 128), + ), + src0=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + rows_dynamic_semantic = analyze_frontend_kernel(build_frontend_kernel_node(rows_dynamic)) + self.assertEqual( + [(param.name, param.kind) for param in rows_dynamic_semantic.parameters], + [ + ("dst", "tile"), + ("src0", "tile"), + ("src1", "tile"), + ("__valid_shape_dst_0", "tile_valid_shape"), + ], + ) + rows_dynamic_text = rows_dynamic.mlir_text() + self.assertIn("valid_shape=(?, 128)", rows_dynamic_text) + self.assertIn("pto.tile_valid_rows %arg0", rows_dynamic_text) + self.assertIn("pto.tile_valid_cols %arg0", rows_dynamic_text) + self.assertRegex(rows_dynamic_text, r"scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") + self.assertRegex(rows_dynamic_text, r"scf\.for %col_\d+ = %c0 to %valid_cols_\d+ step %c128") + + cols_dynamic = selected.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=(8, "valid_cols"), + ), + src0=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + cols_dynamic_semantic = analyze_frontend_kernel(build_frontend_kernel_node(cols_dynamic)) + self.assertEqual( + [(param.name, param.kind) for param in cols_dynamic_semantic.parameters], + [ + ("dst", "tile"), + ("src0", "tile"), + ("src1", "tile"), + ("__valid_shape_dst_1", "tile_valid_shape"), + ], + ) + cols_dynamic_text = cols_dynamic.mlir_text() + self.assertIn("valid_shape=(8, ?)", cols_dynamic_text) + self.assertIn("pto.tile_valid_rows %arg0", cols_dynamic_text) + self.assertIn("pto.tile_valid_cols %arg0", cols_dynamic_text) + self.assertRegex(cols_dynamic_text, r"scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") + self.assertRegex(cols_dynamic_text, r"scf\.for %col_\d+ = %c0 to %valid_cols_\d+ step %c128") + + def test_advanced_mode_scalar_boundaries_split_inferred_vecscope_runs(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + dtype = src.element_type + first_mask = pto.make_mask(dtype, pto.PAT.ALL) + first = pto.vlds(src[0, 0:]) + pto.vsts(first, dst[0, 0:], first_mask) + boundary = 1 + second_mask = pto.make_mask(dtype, pto.PAT.ALL) + second = pto.vlds(src[1, 0:]) + pto.vsts(second, dst[1, 0:], second_mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 2) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 2) + self.assertLess(text.index("pto.vecscope {"), text.index("%boundary_")) + self.assertLess(text.index("%boundary_"), text.index("return")) + self.assertLess(text.index("%boundary_"), text.rindex("pto.vecscope {")) + + def test_explicit_vecscope_is_supported_in_stable_mode(self) -> None: + @pto.vkernel(op="explicit_vecscope_stable_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + with pto.vecscope(): + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertIsInstance(frontend_kernel.body[1], FrontendVecscopeStmt) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertIn("pto.vlds", text) + self.assertIn("pto.vsts", text) + + def test_explicit_vecscope_disables_automatic_inference(self) -> None: + @pto.vkernel(op="explicit_vecscope_disables_infer_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + with pto.vecscope(): + first = pto.vlds(src, 0) + pto.vsts(first, dst, 0, mask) + second = pto.vlds(src, 64) + pto.vsts(second, dst, 64, mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertIn("pto.vlds", text) + self.assertIn("pto.vsts", text) + + def test_constexpr_if_tail_store_does_not_split_inferred_vecscope(self) -> None: + @pto.vkernel(op="trowsum_like_vecscope_unique", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) + def kernel(dst: pto.Tile, src: pto.Tile, tmp: pto.Tile): + src_dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + acc = pto.vbr(0.0) + for col in range(0, valid_cols, pto.get_lanes(src_dtype)): + mask, remained = pto.make_mask(src_dtype, remained) + vec = pto.vlds(src[row, col:]) + reduced = pto.vcadd(vec, mask) + one_mask, _ = pto.make_mask(src_dtype, 1) + acc = pto.vadd(acc, reduced, one_mask) + out_mask, _ = pto.make_mask(src_dtype, 1) + if pto.constexpr(src_dtype != dst.element_type): + casted = pto.vcvt(acc, out_mask, dst.element_type) + pto.vsts(casted, dst[row, 0:], out_mask) + else: + pto.vsts(acc, dst[row, 0:], out_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + tmp=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertRegex(text, r"pto\.vecscope \{\n(?:.|\n)*scf\.for %row_\d+") + self.assertIn("pto.vsts", text) + + def test_advanced_mode_control_flow_infers_vecscope_per_branch(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile, flag: pto.i32): + dtype = src.element_type all_mask = pto.make_mask(dtype, pto.PAT.ALL) if flag: first = pto.vlds(src[0, 0:]) @@ -787,10 +2707,24 @@ def kernel(src: pto.Tile, dst: pto.Tile, flag: pto.i32): dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual([type(stmt).__name__ for stmt in semantic_kernel.body[:-1]], [ + "SemanticAssignStmt", + "SemanticAssignStmt", + "SemanticIfStmt", + ]) + if_stmt = semantic_kernel.body[2] + self.assertIsInstance(if_stmt, SemanticIfStmt) + self.assertEqual(len(if_stmt.then_body), 1) + self.assertEqual(len(if_stmt.else_body), 1) + self.assertIsInstance(if_stmt.then_body[0], SemanticVecscopeStmt) + self.assertIsInstance(if_stmt.else_body[0], SemanticVecscopeStmt) + text = specialized.mlir_text() self.assertIn("scf.if", text) self.assertEqual(text.count("pto.vecscope {"), 2) self.assertLess(text.index("scf.if"), text.index("pto.vecscope {")) + self.assertLess(text.index("scf.if"), text.index("return")) def test_advanced_mode_keeps_strict_vecscope_as_hard_boundary(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) @@ -852,13 +2786,14 @@ def kernel( self.assertEqual(semantic_kernel.parameters[0].type.memory_space, "gm") self.assertIsInstance(semantic_kernel.parameters[1].type, SemanticPtrType) self.assertEqual(semantic_kernel.parameters[1].type.memory_space, "gm") - self.assertTrue(any(isinstance(stmt, SemanticVecscopeStmt) for stmt in semantic_kernel.body)) self.assertTrue(any(isinstance(stmt, SemanticDmaConfigStmt) for stmt in semantic_kernel.body)) self.assertTrue(any(isinstance(stmt, SemanticLowLevelCopyStmt) for stmt in semantic_kernel.body)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) text = kernel.mlir_text() self.assertIn( - "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: i64) {", + "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: i64) attributes { pto.tilelang.instance } {", text, ) self.assertRegex( @@ -900,45 +2835,299 @@ def kernel( ) self.assertRegex( text, - r"pto\.set_loop1_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + r"pto\.set_loop1_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop_size_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.copy_gm_to_ubuf %typed_src_\d+, %ub_src_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %false, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + self.assertIn( + ": !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64", + text, + ) + self.assertRegex( + text, + r"pto\.set_loop2_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop1_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop_size_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.copy_ubuf_to_ubuf %ub_src_\d+, %ub_dst_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + self.assertIn( + ": !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64", + text, ) self.assertRegex( text, - r"pto\.set_loop_size_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + r"pto\.copy_ubuf_to_gm %ub_dst_\d+, %typed_dst_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + + def test_as_ptr_method_and_keyword_low_level_dma_surface_lower_in_advanced_mode(self) -> None: + @pto.vkernel(op="tensorview_tile_as_ptr_dma_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(inp: pto.TensorView, dst: pto.Tile): + gm_ptr = inp.as_ptr() + ub_ptr = dst.as_ptr() + + pto.set_loop2_stride_outtoub(src_stride=4096, dst_stride=2048) + pto.set_loop1_stride_outtoub(src_stride=1024, dst_stride=512) + pto.set_loop_size_outtoub(loop1=1, loop2=1) + pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=1, + len_burst=64, + gm_stride=128, + ub_stride=128, + enable_ub_pad=False, + ) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertTrue(any(isinstance(stmt, SemanticDmaConfigStmt) for stmt in semantic_kernel.body)) + self.assertTrue(any(isinstance(stmt, SemanticLowLevelCopyStmt) for stmt in semantic_kernel.body)) + + text = specialized.mlir_text() self.assertRegex( text, - r"pto\.copy_gm_to_ubuf %typed_src_\d+, %ub_src_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %false, %tmp_\d+, %tmp_\d+, %tmp_\d+", + r"%gm_ptr_\d+ = pto\.tensor_view_addr %arg0 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> !pto\.ptr", ) - self.assertIn( - ": !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64", + self.assertRegex( text, + r"%ub_ptr_\d+ = pto\.tile_buf_addr %arg1 : !pto\.tile_buf -> !pto\.ptr", ) + self.assertRegex(text, r"pto\.set_loop2_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex(text, r"pto\.set_loop1_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex(text, r"pto\.set_loop_size_outtoub %tmp_\d+, %tmp_\d+ : i64, i64") self.assertRegex( text, - r"pto\.set_loop2_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + r"pto\.copy_gm_to_ubuf %gm_ptr_\d+, %ub_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %false, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + + def test_copy_ubuf_to_gm_keyword_surface_lowers_in_advanced_mode(self) -> None: + @pto.vkernel(op="tile_to_tensorview_dma_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.TensorView): + ub_ptr = src.as_ptr() + gm_ptr = dst.as_ptr() + + pto.set_loop2_stride_ubtoout(src_stride=4096, dst_stride=2048) + pto.set_loop1_stride_ubtoout(src_stride=1024, dst_stride=512) + pto.set_loop_size_ubtoout(loop1=1, loop2=1) + pto.copy_ubuf_to_gm( + src=ub_ptr, + dst=gm_ptr, + n_burst=1, + len_burst=64, + gm_stride=128, + ub_stride=128, + ) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertTrue(any(isinstance(stmt, SemanticDmaConfigStmt) for stmt in semantic_kernel.body)) + self.assertTrue(any(isinstance(stmt, SemanticLowLevelCopyStmt) for stmt in semantic_kernel.body)) + + text = specialized.mlir_text() self.assertRegex( text, - r"pto\.set_loop1_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + r"%ub_ptr_\d+ = pto\.tile_buf_addr %arg0 : !pto\.tile_buf -> !pto\.ptr", ) self.assertRegex( text, - r"pto\.set_loop_size_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + r"%gm_ptr_\d+ = pto\.tensor_view_addr %arg1 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> !pto\.ptr", ) + self.assertRegex(text, r"pto\.set_loop2_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex(text, r"pto\.set_loop1_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex(text, r"pto\.set_loop_size_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64") self.assertRegex( text, - r"pto\.copy_ubuf_to_ubuf %ub_src_\d+, %ub_dst_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + r"pto\.copy_ubuf_to_gm %ub_ptr_\d+, %gm_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + + def test_castptr_rejects_tensorview_or_tile_inputs_in_advanced_mode(self) -> None: + @pto.vkernel(op="castptr_tensorview_reject_unique", dtypes=[(pto.f32,)], advanced=True) + def tensorview_kernel(inp: pto.TensorView): + tmp = pto.castptr(inp, pto.ptr(pto.f32, pto.MemorySpace.GM)) + return None + + with self.assertRaises(TypeError) as tensorview_ctx: + analyze_frontend_kernel(build_frontend_kernel_node(tensorview_kernel)) + self.assertIn("pto.castptr input must be an index/i64, pointer, or memref-backed address value", str(tensorview_ctx.exception)) + + @pto.vkernel(op="castptr_tile_reject_unique", dtypes=[(pto.f32,)], advanced=True) + def tile_kernel(inp: pto.Tile): + tmp = pto.castptr(inp, pto.ptr(pto.f32, pto.MemorySpace.UB)) + return None + + specialized = tile_kernel.specialize( + inp=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + with self.assertRaises(TypeError) as tile_ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn("pto.castptr input must be an index/i64, pointer, or memref-backed address value", str(tile_ctx.exception)) + + def test_constexpr_if_folds_static_dtype_condition_without_scf_if(self) -> None: + @pto.vkernel(op="constexpr_if_dtype_fold", dtypes=[(pto.f16, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + step = 64 + if pto.constexpr(dst.element_type != src.element_type): + step = 128 + else: + step = 64 + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertFalse(any(isinstance(stmt, SemanticIfStmt) for stmt in semantic_kernel.body)) + + text = specialized.mlir_text() + self.assertNotIn("scf.if", text) + self.assertNotIn("arith.cmpi ne", text) + self.assertRegex(text, r"%step_\d+ = arith\.constant 128 : index") + + def test_constexpr_if_rejects_non_static_condition(self) -> None: + @pto.vkernel(op="constexpr_if_dynamic_reject", dtypes=[(pto.f32,)]) + def kernel(src: pto.TensorView): + step = 64 + if pto.constexpr(src.shape[0] != 1): + step = 128 + return None + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(kernel)) self.assertIn( - ": !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64", - text, + "if pto.constexpr(...) condition must be a compile-time bool", + str(ctx.exception), ) - self.assertRegex( - text, - r"pto\.copy_ubuf_to_gm %ub_dst_\d+, %typed_dst_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + + def test_if_compare_or_condition_lowers_to_cmp_and_bool_ops(self) -> None: + @pto.vkernel(op="if_compare_or", dtypes=[(pto.f32,)]) + def kernel(src: pto.TensorView): + loop1 = src.shape[3] + loop2 = src.shape[4] + step = 64 + if loop1 != 1 or loop2 != 1: + step = 128 + else: + step = 64 + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("src", "tensorview")], + ) + self.assertIsInstance(semantic_kernel.body[3], SemanticIfStmt) + condition = semantic_kernel.body[3].condition + self.assertIsInstance(condition, SemanticBinaryExpr) + self.assertEqual(condition.op, "or") + self.assertIsInstance(condition.lhs, SemanticBinaryExpr) + self.assertEqual(condition.lhs.op, "ne") + self.assertIsInstance(condition.rhs, SemanticBinaryExpr) + self.assertEqual(condition.rhs.op, "ne") + + text = kernel.mlir_text() + self.assertEqual(text.count("arith.cmpi ne"), 2) + self.assertRegex(text, r"%loop1_\d+ = pto\.get_tensor_view_dim %arg0, %c3 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> index") + self.assertRegex(text, r"%loop2_\d+ = pto\.get_tensor_view_dim %arg0, %c4 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> index") + self.assertRegex(text, r"arith\.cmpi ne, %loop1_\d+, %c1 : index") + self.assertRegex(text, r"arith\.cmpi ne, %loop2_\d+, %c1 : index") + self.assertRegex(text, r"arith\.ori %tmp_\d+, %tmp_\d+ : i1") + self.assertRegex(text, r"%step_\d+ = scf\.if %tmp_\d+ -> \(index\) \{") + + def test_if_ordered_index_comparisons_lower_to_signed_cmp_predicates(self) -> None: + @pto.vkernel(op="if_compare_ordered_index", dtypes=[(pto.f32,)]) + def kernel(src: pto.TensorView): + dim0 = src.shape[0] + dim1 = src.shape[1] + step = 64 + if dim0 > 1 and dim1 <= 8: + step = 128 + else: + step = 32 + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIsInstance(semantic_kernel.body[3], SemanticIfStmt) + condition = semantic_kernel.body[3].condition + self.assertIsInstance(condition, SemanticBinaryExpr) + self.assertEqual(condition.op, "and") + + text = kernel.mlir_text() + self.assertRegex(text, r"arith\.cmpi sgt, %dim0_\d+, %c1 : index") + self.assertRegex(text, r"arith\.cmpi sle, %dim1_\d+, %c8 : index") + self.assertRegex(text, r"arith\.andi %tmp_\d+, %tmp_\d+ : i1") + self.assertRegex(text, r"%step_\d+ = scf\.if %tmp_\d+ -> \(index\) \{") + + def test_if_ordered_float_comparison_lowers_to_cmpf_predicate(self) -> None: + @pto.vkernel(op="if_compare_ordered_float", dtypes=[(pto.f32, pto.f32, pto.f32)]) + def kernel(src: pto.TensorView, lhs: pto.f32, rhs: pto.f32): + step = 64 + if lhs > rhs: + step = 128 + else: + step = 64 + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIsInstance(semantic_kernel.body[1], SemanticIfStmt) + + text = kernel.mlir_text() + self.assertRegex(text, r"arith\.cmpf ogt, %arg1, %arg2 : f32") + self.assertRegex(text, r"%step_\d+ = scf\.if %tmp_\d+ -> \(index\) \{") + + def test_shape_and_stride_tuple_unpacking_lower_cleanly(self) -> None: + @pto.vkernel(op="shape_stride_unpack", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.TensorView, dst: pto.Tile): + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + ub_rows, ub_cols = dst.shape + total = g0 + g1 + g2 + g3 + g4 + stride_total = s0 + s1 + s2 + s3 + s4 + area = ub_rows * ub_cols + if total != 0 or stride_total != area: + total = area + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("src", "tensorview"), ("dst", "tile")], ) + text = specialized.mlir_text() + self.assertEqual(text.count("pto.get_tensor_view_dim"), 5) + self.assertEqual(text.count("pto.get_tensor_view_stride"), 5) + self.assertRegex(text, r"%ub_rows_\d+ = arith\.constant 8 : index") + self.assertRegex(text, r"%ub_cols_\d+ = arith\.constant 64 : index") + def test_advanced_mode_lowers_compare_predicate_carry_and_rearrangement_families(self) -> None: @pto.vkernel(op="advanced_family", dtypes=[(pto.i32, pto.i32, pto.i32, pto.i32)], advanced=True) def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): @@ -1018,11 +3207,10 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): self.assertIn(" = pto.vselrv2 ", text) self.assertIn("pto.vsts ", text) - def test_elementwise_kernel_positive_regression_covers_dma_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.f32, pto.i32)]) - def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: pto.i32): + def test_elementwise_kernel_positive_regression_covers_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) + def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): rows = inp.shape[0] - pto.dma_load(inp[0:rows, 0:16], tile) with pto.strict_vecscope(tile, tile, remaining, 0, rows, 64) as ( src, dst, @@ -1035,7 +3223,6 @@ def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: mask, rem = pto.make_mask(pto.f32, rem) vec = pto.vlds(src, lane) pto.vsts(vec, dst, lane, mask) - pto.dma_store(tile, out[0:rows, 0:16]) return None specialized = kernel.specialize( @@ -1046,12 +3233,10 @@ def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: ) semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - self.assertEqual(len(semantic_kernel.body), 5) - self.assertIsInstance(semantic_kernel.body[1], SemanticDmaLoadStmt) - self.assertIsInstance(semantic_kernel.body[2], SemanticStrictVecscopeStmt) - self.assertIsInstance(semantic_kernel.body[3], SemanticDmaStoreStmt) + self.assertEqual(len(semantic_kernel.body), 3) + self.assertIsInstance(semantic_kernel.body[1], SemanticStrictVecscopeStmt) - vecscope = semantic_kernel.body[2] + vecscope = semantic_kernel.body[1] self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) loop_stmt = vecscope.body[0] self.assertIsInstance(loop_stmt, SemanticForStmt) @@ -1060,16 +3245,16 @@ def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: text = specialized.mlir_text() self.assertIn( - "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: index) {", + "func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tile_buf, %arg2: i32) attributes { pto.tilelang.instance } {", text, ) - self.assertIn( - "pto.copy_gm_to_ubuf %arg0, %arg2, %c0_i64, %c16_i64, %c64_i64", + self.assertRegex( text, + r"%rows_\d+ = pto\.get_tensor_view_dim %arg0, %c0 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> index", ) - self.assertIn( - "pto.strict_vecscope(%arg2, %arg2, %arg3, %c0, %arg4, %c64)", + self.assertRegex( text, + r"pto\.strict_vecscope\(%tmp_\d+, %tmp_\d+, %arg2, %c0, %rows_\d+, %c64\)", ) self.assertRegex( text, @@ -1079,13 +3264,9 @@ def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: text, r"%mask_\d+, %rem_\d+ = pto\.plt_b32 %rem_iter_\d+ : i32 -> !pto\.mask, i32", ) - self.assertIn( - "pto.copy_ubuf_to_gm %arg2, %arg1, %c0_i64, %c16_i64, %c64_i64", - text, - ) def test_if_else_and_sync_ops_lower_to_scf_if_and_authoring_sync_ops(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)]) + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, flag: pto.i32): pto.set_flag(pto.PIPE.MTE2, pto.PIPE.V, pto.EVENT.ID0) pto.wait_flag(pto.PIPE.MTE2, pto.PIPE.V, pto.EVENT.ID0) @@ -1121,17 +3302,20 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, flag: pto.i32): self.assertIn('pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]', text) self.assertIn('pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]', text) self.assertIn("= arith.cmpi ne, %arg2, %c0_i32 : i32", text) - self.assertIn("%step_3 = scf.if %tmp_0 -> (index) {", text) + self.assertRegex(text, r"%step_\d+ = scf\.if %tmp_\d+ -> \(index\) \{") self.assertIn('pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]', text) self.assertIn('pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]', text) self.assertRegex(text, r"scf\.yield %step_\d+ : index") self.assertIn("%step_2 = arith.constant 128 : index", text) - self.assertIn("pto.strict_vecscope(%arg1, %arg1, %c0, %c256, %step_3)", text) + self.assertRegex( + text, + r"pto\.strict_vecscope\(%tmp_\d+, %tmp_\d+, %c0, %c256, %step_\d+\)", + ) self.assertIn("scf.for %lane_", text) self.assertIn("pto.barrier #pto.pipe", text) def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)]) + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): with pto.strict_vecscope(inp, tile) as (vin, vtmp): leaked = scale @@ -1149,6 +3333,307 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): self.assertIn("implicit capture of 'scale' is not allowed", str(ctx.exception)) +class TileLangDSLInlineProcTests(unittest.TestCase): + @pto.inline_proc + def _inline_copy_row(dst: pto.Tile, src: pto.Tile, lane: pto.i32): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + return None + + @pto.inline_proc + def _inline_recur(dst: pto.Tile): + _inline_recur(dst) + return None + + @pto.inline_proc + def _inline_capture(dst: pto.Tile): + pto.vlds(dst, lane) + return None + + def test_inline_proc_exports_from_package_surface(self) -> None: + self.assertTrue(hasattr(pto, "inline_proc")) + self.assertTrue(hasattr(pto, "InlineProcDescriptor")) + + def test_inline_proc_call_keeps_call_in_frontend_and_mlir_text(self) -> None: + @pto.vkernel(op="inline_proc_backend_call_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + _inline_copy_row(dst, src, 0) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertEqual(len(frontend_kernel.body), 2) + self.assertIsInstance(frontend_kernel.body[0], FrontendExprStmt) + self.assertIsInstance(frontend_kernel.body[0].expr, FrontendCallExpr) + self.assertEqual(frontend_kernel.body[0].expr.name, "_inline_copy_row") + self.assertGreaterEqual(len(frontend_kernel.inline_procs), 1) + self.assertIn("_inline_copy_row", {proc.name for proc in frontend_kernel.inline_procs}) + + text = specialized.mlir_text() + self.assertIn("func.call", text) + self.assertIn("pto.tilelang.inline_proc", text) + self.assertRegex(text, r"func\.call @__tl_inline_") + + def test_inline_proc_supports_default_parameters_and_keyword_call(self) -> None: + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile, lane: pto.i32 = 0): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + return None + + @pto.vkernel(op="inline_proc_keyword_default_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + inline_store(dst=dst, src=src) + return None + + text = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ).mlir_text() + self.assertIn("func.call", text) + self.assertIn("pto.tilelang.inline_proc", text) + + def test_inline_proc_supports_return_expression_in_expression_position(self) -> None: + @pto.inline_proc + def inline_load(src: pto.Tile, lane: pto.i32 = 0): + return pto.vlds(src, lane) + + @pto.vkernel(op="inline_proc_expr_return_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = inline_load(src) + pto.vsts(vec, dst, 0, mask) + return None + + text = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ).mlir_text() + self.assertIn("func.call", text) + self.assertRegex(text, r"= func\.call @__tl_inline_") + self.assertIn("pto.vsts", text) + + def test_inline_proc_rejects_non_trailing_return(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.inline_proc + def bad_inline(flag: pto.i32): + if flag: + return flag + return flag + + self.assertIn("optional trailing `return`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_recursive_calls(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_recursive_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + _inline_recur(dst) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ).mlir_text() + + self.assertIn("recursive inline_proc call `_inline_recur` is not supported", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_implicit_capture(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_capture_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + lane = pto.i32(0) + _inline_capture(dst) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ).mlir_text() + + self.assertIn("implicit capture of 'lane' is not allowed in inline_proc", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_kw_only_vararg_and_kwargs(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as kw_only_ctx: + + @pto.inline_proc + def bad_kw_only(dst: pto.Tile, *, lane: pto.i32): + return None + + self.assertIn("keyword-only parameters", str(kw_only_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as vararg_ctx: + + @pto.inline_proc + def bad_vararg(dst: pto.Tile, *lanes: pto.i32): + return None + + self.assertIn("does not support *args", str(vararg_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as kwargs_ctx: + + @pto.inline_proc + def bad_kwargs(dst: pto.Tile, **opts: pto.i32): + return None + + self.assertIn("does not support **kwargs", str(kwargs_ctx.exception)) + + def test_inline_proc_rejects_invalid_keyword_binding(self) -> None: + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_invalid_keyword_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + inline_store(dst=dst, src=src, lane=0) + return None + + self.assertIn("unexpected keyword argument 'lane'", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_missing_required_argument(self) -> None: + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_missing_required_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + inline_store(dst=dst) + return None + + self.assertIn("missing a required argument: 'src'", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_multiple_values_for_single_parameter(self) -> None: + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_multiple_values_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + inline_store(dst, src, src=src) + return None + + self.assertIn("multiple values for argument 'src'", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_semantic_emits_controlled_namespace_none_call(self) -> None: + @pto.vkernel(op="inline_proc_semantic_call_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + _inline_copy_row(dst, src, 0) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + call_stmts = [ + stmt + for stmt in semantic_kernel.body + if isinstance(stmt, SemanticExprStmt) and isinstance(stmt.expr, SemanticCallExpr) + ] + self.assertGreaterEqual(len(call_stmts), 1) + inline_call = call_stmts[0].expr + self.assertIsNone(inline_call.namespace) + self.assertRegex(inline_call.name, r"^__tl_inline_") + self.assertGreaterEqual(len(semantic_kernel.inline_helpers), 1) + self.assertRegex(semantic_kernel.inline_helpers[0].symbol_name, r"^__tl_inline_") + + def test_inline_proc_semantic_keeps_expression_call_return_type(self) -> None: + @pto.inline_proc + def inline_const_i32(): + return pto.i32(1) + + @pto.vkernel(op="inline_proc_semantic_expr_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + lane = inline_const_i32() + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + assign_stmt = next( + stmt + for stmt in semantic_kernel.body + if isinstance(stmt, SemanticAssignStmt) and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_stmt.value.namespace) + self.assertRegex(assign_stmt.value.name, r"^__tl_inline_") + self.assertIsInstance(assign_stmt.value.type, SemanticScalarType) + + def test_inline_proc_lowering_renders_private_helpers_and_call_bindings(self) -> None: + @pto.inline_proc + def inline_const_i32(): + return 1 + + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile): + lane = inline_const_i32() + _inline_copy_row(dst, src, lane) + return None + + @pto.vkernel(op="inline_proc_lowering_helpers_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + lane = inline_const_i32() + inline_store(dst, src) + _inline_copy_row(dst, src, lane) + return None + + text = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ).mlir_text() + self.assertIn('sym_visibility = "private", pto.tilelang.inline_proc', text) + self.assertGreaterEqual(text.count("func.func"), 3) + self.assertGreaterEqual(text.count("pto.tilelang.inline_proc"), 2) + self.assertRegex(text, r"= func\.call @__tl_inline_[A-Za-z0-9_]+\(.*\) : \([^\)]*\) -> index") + self.assertRegex(text, r"func\.call @__tl_inline_[A-Za-z0-9_]+\(.*\) : \([^\)]*\) -> \(\)") + + def test_inline_proc_rejects_mutual_recursion(self) -> None: + @pto.inline_proc + def inline_a(dst: pto.Tile): + inline_b(dst) + return None + + @pto.inline_proc + def inline_b(dst: pto.Tile): + inline_a(dst) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_mutual_recursion_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + inline_a(dst) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ).mlir_text() + + self.assertIn("recursive inline_proc call", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + class TileLangDSLDiagnosticsTests(unittest.TestCase): def test_matcher_feature_validation_rejects_invalid_constraints_and_priority(self) -> None: def kernel(x: pto.TensorView): @@ -1183,6 +3668,34 @@ def kernel(x: pto.TensorView): self.assertIn("unsupported Python syntax `while`", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) + def test_vreg_annotated_assignment_rejects_mismatched_dtype(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel(op="vreg_annotation_mismatch_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + vec: pto.vreg(pto.f16) = pto.vlds(dst, 0) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ).mlir_text() + + self.assertIn("annotated vector type `vreg(f16)` does not match inferred !pto.vreg<64xf32>", str(ctx.exception)) + + def test_mask_annotated_assignment_rejects_mismatched_granularity(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel(op="mask_annotation_mismatch_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + mask: pto.mask_b16 = pto.make_mask(pto.f32, pto.PAT.ALL) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ).mlir_text() + + self.assertIn("annotated mask type `mask_b16` does not match inferred !pto.mask", str(ctx.exception)) + def test_arbitrary_external_call_reports_source_location(self) -> None: def helper(): return None @@ -1202,20 +3715,31 @@ def test_unsupported_pto_surface_reports_source_location(self) -> None: @pto.vkernel(op="x", dtypes=[(pto.f32,)]) def kernel(x: pto.TensorView): - pto.vadd(x) + pto.not_a_real_surface(x) return None - self.assertIn("vector op surface `pto.vadd` requires explicit pto.strict_vecscope", str(ctx.exception)) + self.assertIn("unsupported op surface `pto.not_a_real_surface`", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) - def test_advanced_family_requires_advanced_mode(self) -> None: + def test_strict_vecscope_requires_advanced_mode(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as ctx: @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f32)]) def kernel(x: pto.TensorView, tile: pto.Tile): with pto.strict_vecscope(tile, tile, 0, 256, 64) as (lhs, rhs, lb, ub, step): - mask = pto.make_mask(pto.f32, pto.PAT.ALL) - pto.vcmp(lhs, rhs, mask, "lt") + pass + return None + + self.assertIn("surface `pto.strict_vecscope` requires advanced=True", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_advanced_family_requires_advanced_mode(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f32)]) + def kernel(x: pto.TensorView, tile: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vcmp(tile, tile, mask, "lt") return None self.assertIn("surface `pto.vcmp` requires advanced=True", str(ctx.exception)) @@ -1252,6 +3776,11 @@ def kernel(x: pto.TensorView, tile: pto.Tile): self.assertIn("v1 only supports MemorySpace.UB", str(space_ctx.exception)) self.assertIn(f"{__file__}:", str(space_ctx.exception)) + with self.assertRaises(pto.TileLangFrontendError) as valid_shape_ctx: + kernel.specialize(tile={"shape": (4, 4), "memory_space": "ub", "valid_shape": (5, 4)}) + self.assertIn("valid_shape axis 0=5 must be <= shape axis 0=4", str(valid_shape_ctx.exception)) + self.assertIn(f"{__file__}:", str(valid_shape_ctx.exception)) + if __name__ == "__main__": unittest.main() diff --git a/tools/ptoas/CMakeLists.txt b/tools/ptoas/CMakeLists.txt index 96d841ca6..611c9c615 100644 --- a/tools/ptoas/CMakeLists.txt +++ b/tools/ptoas/CMakeLists.txt @@ -27,6 +27,12 @@ add_llvm_executable(pto-opt set_target_properties(pto-opt PROPERTIES OUTPUT_NAME "ptoas") target_compile_definitions(pto-opt PRIVATE PTOAS_RELEASE_VERSION="${PTOAS_CLI_VERSION}" + # Source-tree defaults for TileLang DSL expansion. These let ptoas run + # directly from the build tree without passing --tilelang-path / + # --tilelang-pkg-path. Installed layouts that move these directories + # still need to override the flags explicitly. + PTOAS_DEFAULT_TILELANG_PATH="${CMAKE_SOURCE_DIR}/lib/TileOps" + PTOAS_DEFAULT_TILELANG_PKG_PATH="${CMAKE_SOURCE_DIR}/tilelang-dsl/python" ) # [修改 2] 更新链接库名称 # 原因:In-tree 时你的库叫 MLIRPTODialect,但现在 Out-of-tree 它们是你自己定义的 diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 5daacf40b..bacdd6f52 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -187,21 +187,30 @@ static llvm::cl::opt enableInsertSync("enable-insert-sync", llvm::cl::desc("Enable automatic synchronization insertion pass"), llvm::cl::init(false)); -static llvm::cl::opt enableTileToVector( - "enable-tile-to-vector", +static llvm::cl::opt enableTileOpExpand( + "enable-tile-op-expand", llvm::cl::desc( "Enable Tile-to-Vector lowering path (memref->tile_buf recovery)"), llvm::cl::init(false)); +#ifndef PTOAS_DEFAULT_TILELANG_PATH +#define PTOAS_DEFAULT_TILELANG_PATH "" +#endif +#ifndef PTOAS_DEFAULT_TILELANG_PKG_PATH +#define PTOAS_DEFAULT_TILELANG_PKG_PATH "" +#endif + static llvm::cl::opt tilelangPath( "tilelang-path", - llvm::cl::desc("Path to directory of .py tilelang DSL template files"), - llvm::cl::init("")); + llvm::cl::desc("Path to directory of .py tilelang DSL template files " + "(default: /lib/TileOps, baked in at build time)"), + llvm::cl::init(PTOAS_DEFAULT_TILELANG_PATH)); static llvm::cl::opt tilelangPkgPath( "tilelang-pkg-path", - llvm::cl::desc("PYTHONPATH for tilelang_dsl package"), - llvm::cl::init("")); + llvm::cl::desc("PYTHONPATH for tilelang_dsl package " + "(default: /tilelang-dsl/python, baked in at build time)"), + llvm::cl::init(PTOAS_DEFAULT_TILELANG_PKG_PATH)); static llvm::cl::opt disableInferLayout( "disable-infer-layout", @@ -1078,8 +1087,22 @@ static bool shouldDeclareVariablesAtTop(ModuleOp module) { } static LogicalResult prepareVPTOForEmission(ModuleOp module) { - if (failed(convertVPTOEmissionBoundaryToPtr(module, &llvm::errs()))) { - llvm::errs() << "Error: VPTO emission boundary canonicalization failed.\n"; + PassManager cleanupPM(module->getContext()); + cleanupPM.enableVerifier(); + cleanupPM.addPass(createCanonicalizerPass()); + cleanupPM.addPass(createCSEPass()); + if (failed(cleanupPM.run(module))) { + llvm::errs() << "Error: VPTO pre-emission cleanup failed.\n"; + return failure(); + } + + PassManager boundaryPM(module->getContext()); + boundaryPM.enableVerifier(); + boundaryPM.addPass(pto::createVPTOPtrNormalizePass()); + boundaryPM.addPass(pto::createVPTOPtrCastCleanupPass()); + boundaryPM.addPass(createReconcileUnrealizedCastsPass()); + if (failed(boundaryPM.run(module))) { + llvm::errs() << "Error: VPTO ptr normalization failed.\n"; return failure(); } @@ -1098,8 +1121,25 @@ static LogicalResult prepareVPTOForEmission(ModuleOp module) { static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { PassManager backendPM(module.getContext()); - backendPM.addPass(pto::createLowerPTOToVPTOPass()); - backendPM.addPass(mlir::createCSEPass()); + if (enableTileOpExpand) { + // TileOp Expand path: + // 1. MemrefToTileBuf: recover tile_buf from memref + // 2. ExpandTileOp: instantiate TileLang DSL templates, replace tile ops + // with func.call to template functions (tile_buf params) + // 3. InlineLibCall: inline template function bodies + // 4. FoldTileBufIntrinsics: fold tile_buf_addr / tile_valid_rows / + // tile_valid_cols to concrete memref/constant values + backendPM.addPass(pto::createMemrefToTileBufPass()); + + pto::ExpandTileOpOptions expandOpts; + expandOpts.tilelangPath = tilelangPath; + expandOpts.tilelangPkgPath = tilelangPkgPath; + backendPM.addPass(pto::createExpandTileOpPass(expandOpts)); + + backendPM.addPass(pto::createPTOInlineLibCallPass()); + backendPM.addNestedPass( + pto::createFoldTileBufIntrinsicsPass()); + } if (failed(backendPM.run(module))) { llvm::errs() << "Error: backend lowering pass execution failed.\n"; return failure(); @@ -1481,39 +1521,11 @@ int main(int argc, char **argv) { pm.addPass(emitc::createFormExpressionsPass()); pm.addPass(mlir::createCSEPass()); - if (enableTileToVector) { - // Tile→Vector path: - // 1. MemrefToTileBuf: recover tile_buf from memref - // 2. ExpandTileOp: instantiate TileLang DSL templates, replace tile ops - // with func.call to template functions (tile_buf params) - // 3. InlineLibCall: inline template function bodies - // 4. FoldTileBufIntrinsics: fold tile_buf_addr / tile_valid_rows / - // tile_valid_cols to concrete memref/constant values - pm.addPass(pto::createMemrefToTileBufPass()); - - pto::ExpandTileOpOptions expandOpts; - expandOpts.tilelangPath = tilelangPath; - expandOpts.tilelangPkgPath = tilelangPkgPath; - pm.addPass(pto::createExpandTileOpPass(expandOpts)); - - pm.addPass(pto::createPTOInlineLibCallPass()); - pm.addNestedPass( - pto::createFoldTileBufIntrinsicsPass()); - } - if (failed(pm.run(*module))) { llvm::errs() << "Error: Pass execution failed.\n"; return 1; } - // Tile→Vector path: print MLIR IR and exit (no C++ emission). - if (enableTileToVector) { - module->print(outputFile.os()); - outputFile.os() << "\n"; - outputFile.keep(); - return 0; - } - dropEmptyEmitCExpressions(module.get()); materializeControlFlowOperands(module.get()); if (failed(reorderEmitCFunctions(module.get()))) { From 203ea9150978d543d82460e8ac9c129fae2b7ec7 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 17:49:41 +0800 Subject: [PATCH 028/192] Update DSL user guide --- tilelang-dsl/docs/unsupported-features.md | 2 +- .../docs/user_guide/01-introduction.md | 3 +- .../docs/user_guide/05-type-system.md | 214 +++++++++++++++++- ...{08-control-flow.md => 06-control-flow.md} | 0 tilelang-dsl/docs/user_guide/06-tensorview.md | 97 -------- ...perations.md => 07-frontend-operations.md} | 0 tilelang-dsl/docs/user_guide/07-tile-types.md | 187 --------------- ...perations.md => 08-sync-dma-operations.md} | 0 ...ions.md => 09-vector-memory-operations.md} | 2 +- ...erations.md => 10-predicate-operations.md} | 0 ....md => 11-vector-arithmetic-operations.md} | 0 .../{14-examples.md => 12-examples.md} | 0 ...5-common-errors.md => 13-common-errors.md} | 0 ...ity-notes.md => 14-compatibility-notes.md} | 0 .../{17-next-steps.md => 15-next-steps.md} | 0 15 files changed, 212 insertions(+), 293 deletions(-) rename tilelang-dsl/docs/user_guide/{08-control-flow.md => 06-control-flow.md} (100%) delete mode 100644 tilelang-dsl/docs/user_guide/06-tensorview.md rename tilelang-dsl/docs/user_guide/{09-frontend-operations.md => 07-frontend-operations.md} (100%) delete mode 100644 tilelang-dsl/docs/user_guide/07-tile-types.md rename tilelang-dsl/docs/user_guide/{10-sync-dma-operations.md => 08-sync-dma-operations.md} (100%) rename tilelang-dsl/docs/user_guide/{11-vector-memory-operations.md => 09-vector-memory-operations.md} (99%) rename tilelang-dsl/docs/user_guide/{12-predicate-operations.md => 10-predicate-operations.md} (100%) rename tilelang-dsl/docs/user_guide/{13-vector-arithmetic-operations.md => 11-vector-arithmetic-operations.md} (100%) rename tilelang-dsl/docs/user_guide/{14-examples.md => 12-examples.md} (100%) rename tilelang-dsl/docs/user_guide/{15-common-errors.md => 13-common-errors.md} (100%) rename tilelang-dsl/docs/user_guide/{16-compatibility-notes.md => 14-compatibility-notes.md} (100%) rename tilelang-dsl/docs/user_guide/{17-next-steps.md => 15-next-steps.md} (100%) diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md index 5405e1453..7a815f4dc 100644 --- a/tilelang-dsl/docs/unsupported-features.md +++ b/tilelang-dsl/docs/unsupported-features.md @@ -87,7 +87,7 @@ of the supported authoring surface: ### Missing Extended Vector Arithmetic Families -The previously missing `13-vector-arithmetic-operations.md` gap list is now +The previously missing `11-vector-arithmetic-operations.md` gap list is now implemented in the current package surface (including fused ops, broadcast/index generation, reduction-flavored ops, and rearrangement/sort groups). diff --git a/tilelang-dsl/docs/user_guide/01-introduction.md b/tilelang-dsl/docs/user_guide/01-introduction.md index 04f845a11..1c51cb8ca 100644 --- a/tilelang-dsl/docs/user_guide/01-introduction.md +++ b/tilelang-dsl/docs/user_guide/01-introduction.md @@ -16,7 +16,7 @@ The DSL surface is organized into multiple maturity tiers, reflecting the stabil | `strict_vecscope` | `advanced` | Explicit vector-scope management for expert authoring. | | Raw pointer family (`ptr(...)`, `castptr`, `addptr`) | `advanced` | For expert authoring and migration; not required for Quick Start. | | DMA family (`copy_*`, `set_loop*_stride_*`, `set_loop_size_*`) | `advanced` | Direct DMA engine control for expert authoring. | -| Tile helper family (`tile.slice(...)`, `tile.reshape(...)`, `tile.as_ptr()`, `tile_from_ptr(...)`, `tile_with_strides(...)`, `tile_config(...)`) | `advanced` | Partial or evolving surface; not the default entry point. | +| Tile pointer helper (`tile.as_ptr()`) | `advanced` | Expert-only helper when advanced authoring needs explicit typed pointers. | For the authoritative tier classification, consult `tilelang-dsl/python/tilelang_dsl/support_matrix.py`. For known implementation gaps, refer to `tilelang-dsl/docs/unsupported-features.md`. @@ -45,4 +45,3 @@ The TileLang DSL provides two distinct authoring modes: - Tile slices in basic mode correspond to MLIR `memref` types - Raw pointers in advanced mode correspond to MLIR `pto.ptr` types - No automatic conversion between tile and pointer semantics - choose the appropriate syntax for your authoring mode - diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index d7df0e2b8..a467dc05a 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -133,10 +133,214 @@ The `MemorySpace` enum provides type-safe memory space specification: This replaces string literals (`MemorySpace.GM`/`MemorySpace.UB`) with compile-time checked enums. -### Pointer Type Aliases [Advanced Tier] +### Public Buffer Types -For clarity in API documentation, the following type alias is used: +TileLang uses two public buffer-facing type names in kernel signatures: -| Alias | Equivalent Type | Description | -|-------|----------------|-------------| -| `Tile` | `pto.tile(...)` | Tile buffer with layout and configuration | +| Public Type | Description | +|-------------|-------------| +| `pto.TensorView` | GM-facing tensor view descriptor used for DMA-oriented data access | +| `pto.Tile` | UB-facing tile buffer value used for tiled compute | + +### TensorView Types + +TensorView types represent multi-dimensional (up to 5D) views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. + +#### TensorView Type Definition + +TensorView types are parameterized by shape (a tuple of up to 5 dimensions) and element type: + +```python +# Kernel parameter using TensorView +@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tensor: pto.TensorView, # GM tensor view + output_tensor: pto.TensorView, # GM tensor view + tile_buf: pto.Tile # UB tile +): + # Access tensor view properties + shape = input_tensor.shape # tuple of dimensions (dynamic or static, up to 5D) + dtype = input_tensor.element_type # e.g., pto.f32 + strides = input_tensor.strides # stride in elements +``` + +Important notes: +- TensorView is a read-only descriptor for GM data, though DMA store operations can write through it. +- Shape can be static (compile-time constants) or dynamic (determined at runtime). +- Strides are expressed in elements, not bytes. +- Memory space is always GM (Global Memory). +- Maximum rank is 5. PTO ISA right-aligns lower-rank shapes to 5D. +- When higher dimensions are 1, a 5D TensorView can be abbreviated to lower-rank forms. For example, shape `(1, 1, 64, 32, 16)` can be written as `(64, 32, 16)`, and shape `(1, 1, 1, 32, 16)` can be written as `(32, 16)`. + +#### TensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Tensor dimensions (supports up to 5 dimensions, right-aligned to 5D in PTO ISA) | +| `element_type` | `Type` | Element data type (for example `pto.f32`, `pto.f16`) | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from base pointer (internal) | + +#### Padding Mode Enum + +Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: + +| Enum Value | Description | +|------------|-------------| +| `PadMode.PadNull` | No padding. Out-of-bounds access is invalid | +| `PadMode.PadFirstElem` | Pad using the first element of the source | +| `PadMode.PadValue` | Pad using a specified value and requires `pad_value` | + +#### Slicing Syntax + +TensorView supports Python slicing syntax to create logical partitions: + +```python +# Create a partition from a tensor view +partition = tensor_view[dim0_start:dim0_end, dim1_start:dim1_end] + +# Example: extract a 16x16 tile from a larger tensor +tile_view = large_tensor[0:16, 0:16] + +# Dynamic offsets and sizes +dim0_start = tensor_view.shape[0] // 2 +dynamic_partition = tensor_view[dim0_start:tensor_view.shape[0], 4:20] + +# Static positive step on dimension 0 +stepped_partition = tensor_view[0:32:2, 0:16] + +# Right-aligned shorthand on a 5D descriptor +partition_3d = tensor_view[d2_start:d2_end, d3_start:d3_end, d4_start:d4_end] + +# Full 5D spelling remains available when needed +partition_5d = tensor_view[ + d0_start:d0_end, + d1_start:d1_end, + d2_start:d2_end, + d3_start:d3_end, + d4_start:d4_end, +] +``` + +Constraints: +- Slicing returns a new TensorView representing the logical partition. +- The partition must be within the original tensor bounds. +- When fewer than 5 slice axes are written, they are right-aligned to the trailing physical axes of the 5D descriptor. +- `stop` must be explicit on all dimensions. +- `start` may be static or dynamic. +- `step` must be a static positive integer. +- Dimension 0 may use `step > 1`. +- Dimension 1 must keep `step == 1` in the current DMA-oriented implementation. + +### Tile Types + +Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. + +#### Tile Type Definition + +`pto.Tile` is the public tile type used in kernel signatures and vectorized UB compute. User code does not construct tiles with standalone helper APIs in the stable guide surface. + +Important notes on shape and valid shape: +- `shape` must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. +- `valid_shape` can be either static or dynamic and must be less than or equal to `shape` in each dimension. +- When `valid_shape` is not specified, it defaults to the full `shape`. + +#### Tile Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Full tile dimensions. These are compile-time constants | +| `element_type` | `Type` | Element data type (for example `pto.f32`) | +| `memory_space` | `MemorySpace` | Memory space such as GM or UB | +| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within the tile. Must be less than or equal to `shape` in each dimension | +| `config` | `TileConfig` | Layout and padding configuration | + +#### Tile Shape Concepts + +- `shape` is the static physical allocation size of the tile buffer. +- `valid_shape` is the logical data region and may be static or dynamic. +- `valid_shape[i] <= shape[i]` must hold for each dimension. +- Fixed-size tiles with smaller valid regions are useful for padding and partial-tile cases. + +#### Basic Access Operations + +```python +# Get tile properties +shape = tile.shape # (256, 128) +elem_type = tile.element_type # pto.f32 +mem_space = tile.memory_space # MemorySpace.UB +valid_shape = tile.valid_shape # (240, 120) or same as shape + +# Get configuration properties +config = tile.config +b_layout = config.b_layout # pto.BLayout.ROW_MAJOR +s_layout = config.s_layout # pto.SLayout.NONE_BOX +s_fractal = config.s_fractal_size # pto.i32(16) +pad = config.pad_value # pto.PadValue.ZERO + +# Dynamic properties +rank = tile.rank # 2 +num_elements = tile.num_elements # 32768 (256 * 128) +valid_elements = tile.valid_elements # 28800 (240 * 120) +``` + +#### Layout and Stride Queries + +```python +# Get layout descriptors +layout_desc = tile.layout_descriptor # Returns layout description object + +# Get strides (in elements) +strides = tile.strides # (128, 1) for row-major 256x128 + +# Get byte strides +byte_strides = tile.byte_strides # (512, 4) for f32 row-major + +# Get base offset (in bytes) +offset = tile.offset # pto.i64(0) or specified offset +``` + +#### Conversion Operations + +Basic mode syntax uses tile element-indexing directly in vector operations: + +```python +# 2D tile indexing +vec = pto.vlds(tile[row, col:]) +pto.vsts(vec, tile[row, col:], mask) + +# 1D tile indexing +vec = pto.vlds(tile[start:]) +pto.vsts(vec, tile[start:], mask) +``` + +Advanced mode syntax converts tiles to typed pointers for byte-offset operations: + +```python +# Convert tile to pointer +ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) + +# Use pointer with byte offset +vec = pto.vlds(ptr, offset) +pto.vsts(vec, ptr, offset, mask) +``` + +#### Kernel Parameter Usage + +```python +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tile: pto.Tile, + output_tile: pto.Tile, + scale: pto.f32 +): + all_mask = pto.make_mask(pto.f32, PAT.ALL) + for i in range(0, 256, 64): + vec = pto.vlds(input_tile[i, 0:]) + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, output_tile[i, 0:], all_mask) +``` + +### Alignment Type + +The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. diff --git a/tilelang-dsl/docs/user_guide/08-control-flow.md b/tilelang-dsl/docs/user_guide/06-control-flow.md similarity index 100% rename from tilelang-dsl/docs/user_guide/08-control-flow.md rename to tilelang-dsl/docs/user_guide/06-control-flow.md diff --git a/tilelang-dsl/docs/user_guide/06-tensorview.md b/tilelang-dsl/docs/user_guide/06-tensorview.md deleted file mode 100644 index ad9a169d4..000000000 --- a/tilelang-dsl/docs/user_guide/06-tensorview.md +++ /dev/null @@ -1,97 +0,0 @@ -### TensorView Types - -TensorView types represent multi‑dimensional (up to 5D) views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. - -### TensorView Type Definition - -TensorView types are parameterized by shape (a tuple of up to 5 dimensions) and element type: - -```python -# Kernel parameter using TensorView -@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) -def tiled_kernel( - input_tensor: pto.TensorView, # GM tensor view - output_tensor: pto.TensorView, # GM tensor view - tile_buf: pto.Tile # UB tile -): - # Access tensor view properties - shape = input_tensor.shape # tuple of dimensions (dynamic or static, up to 5D) - dtype = input_tensor.element_type # e.g., pto.f32 - strides = input_tensor.strides # stride in elements -``` - -**Important Notes:** -- TensorView is a **read-only descriptor** for GM data (though DMA store operations can write to it) -- Shape can be **static** (compile-time constants) or **dynamic** (determined at runtime) -- Strides are expressed in elements, not bytes -- Memory space is always GM (Global Memory) -- Maximum rank is 5 (PTO ISA right‑aligns lower‑rank shapes to 5D) -- When higher dimensions are 1, a 5D TensorView can be abbreviated to lower‑rank forms. For example, shape `(1,1,64,32,16)` can be written as `(64,32,16)` (3D), and shape `(1,1,1,32,16)` can be written as `(32,16)` (2D). - -### TensorView Attributes - -| Attribute | Type | Description | -|-----------|------|-------------| -| `shape` | `tuple[int, ...]` | Tensor dimensions (supports up to 5 dimensions, right-aligned to 5D in PTO ISA) | -| `element_type` | `Type` | Element data type (e.g., `pto.f32`, `pto.f16`) | -| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | -| `offset` | `pto.i64` | Byte offset from base pointer (internal) | - -### Padding Mode Enum - -Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: - -| Enum Value | Description | -|------------|-------------| -| `PadMode.PadNull` | No padding (out-of-bounds access is invalid) | -| `PadMode.PadFirstElem` | Pad using the first element of the source | -| `PadMode.PadValue` | Pad using a specified value (requires `pad_value` parameter) | - -### Slicing Syntax - -TensorView supports Python slicing syntax to create logical partitions: - -```python -# Create a partition from a tensor view -partition = tensor_view[dim0_start:dim0_end, dim1_start:dim1_end] - -# Example: extract a 16x16 tile from a larger tensor -tile_view = large_tensor[0:16, 0:16] - -# Dynamic offsets and sizes -dim0_start = tensor_view.shape[0] // 2 -dynamic_partition = tensor_view[dim0_start:tensor_view.shape[0], 4:20] - -# Static positive step on dimension 0 -stepped_partition = tensor_view[0:32:2, 0:16] - -# Right-aligned shorthand on a 5D descriptor: -# if the leading 2 axes are logical singleton dimensions, a 3D-style slice -# maps to the trailing 3 physical axes. -partition_3d = tensor_view[d2_start:d2_end, d3_start:d3_end, d4_start:d4_end] - -# Full 5D spelling remains available when needed -partition_5d = tensor_view[ - d0_start:d0_end, - d1_start:d1_end, - d2_start:d2_end, - d3_start:d3_end, - d4_start:d4_end, -] -``` - -**Constraints:** -- Slicing returns a new TensorView representing the logical partition -- The partition must be within the original tensor bounds -- When fewer than 5 slice axes are written, they are right-aligned to the - trailing physical axes of the 5D descriptor -- `stop` must be explicit on all dimensions -- `start` may be static or dynamic -- `step` must be a static positive integer -- Dimension 0 may use `step > 1` -- Dimension 1 must keep `step == 1` (current implementation restriction for DMA operations) - -### Alignment Type - -The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. - diff --git a/tilelang-dsl/docs/user_guide/09-frontend-operations.md b/tilelang-dsl/docs/user_guide/07-frontend-operations.md similarity index 100% rename from tilelang-dsl/docs/user_guide/09-frontend-operations.md rename to tilelang-dsl/docs/user_guide/07-frontend-operations.md diff --git a/tilelang-dsl/docs/user_guide/07-tile-types.md b/tilelang-dsl/docs/user_guide/07-tile-types.md deleted file mode 100644 index b901e20f9..000000000 --- a/tilelang-dsl/docs/user_guide/07-tile-types.md +++ /dev/null @@ -1,187 +0,0 @@ -### Tile Types - -Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. - -#### Tile Type Definition - -```python -# Create a tile with shape, element type, and memory space -tile = pto.tile((256, 128), pto.f32, MemorySpace.UB) - -# With explicit configuration -config = pto.tile_config( - b_layout=pto.BLayout.ROW_MAJOR, - s_layout=pto.SLayout.NONE_BOX, - s_fractal_size=pto.i32(16), - pad_value=pto.PadValue.ZERO -) -tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, config=config) - -# With valid shape (actual data dimensions within tile) -tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, valid_shape=(240, 120)) -``` - -**Important Notes on Shape and Valid Shape:** -- **Static Shape Requirement**: The `shape` parameter must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. -- **Valid Shape Constraints**: The `valid_shape` parameter can be either static (compile-time constant) or dynamic (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. This allows for variable-sized data within a fixed tile allocation. -- **Default Behavior**: When `valid_shape` is not specified, it defaults to the full `shape`. - -#### Tile Attributes - -| Attribute | Type | Description | -|-----------|------|-------------| -| `shape` | `tuple[int, ...]` | **Static** full tile dimensions (compile-time constant) | -| `element_type` | `Type` | Element data type (e.g., `pto.f32`) | -| `memory_space` | `MemorySpace` | Memory space (GM, UB, etc.) | -| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within tile (can be static/compile-time or dynamic/runtime). Must be ≤ shape in each dimension. | -| `config` | `TileConfig` | Layout and padding configuration | - -#### Tile Configuration - -The tile configuration includes layout and padding information: - -```python -# Layout enums -pto.BLayout.ROW_MAJOR # 0: row-major base layout -pto.BLayout.COL_MAJOR # 1: column-major base layout - -pto.SLayout.NONE_BOX # 0: no secondary layout -pto.SLayout.ROW_MAJOR # 1: row-major secondary layout -pto.SLayout.COL_MAJOR # 2: column-major secondary layout - -pto.PadValue.NULL # 0: no padding -pto.PadValue.ZERO # 1: zero padding -pto.PadValue.MAX # 2: maximum value padding -pto.PadValue.MIN # 3: minimum value padding -``` - -#### Tile Shape Concepts - -- **Static Physical Shape**: The `shape` parameter represents the **static physical dimensions** of the tile allocated in memory. This must be a **compile-time constant** because tile memory allocation is fixed during compilation. The shape determines the total memory footprint and cannot change at runtime. - -- **Valid Shape**: The `valid_shape` parameter represents the logical dimensions of actual data within the tile. It can be either **static** (compile-time constant) or **dynamic** (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. When `valid_shape` is not specified, it defaults to the full `shape`. - -- **Key Distinction**: - - `shape`: **Static, compile-time** - Fixed tile allocation - - `valid_shape`: **Static or Dynamic** - Actual data region (must be ≤ shape) - -- **Constraints**: - - `valid_shape[i] ≤ shape[i]` for each dimension i - - `shape` must be compile-time constants - - `valid_shape` can be compile-time constants or runtime values - -- **Use Cases**: - - Fixed-size tile buffers with variable data (e.g., batch processing with different input sizes) - - Padding scenarios where physical allocation is larger than actual data - - Partial tile utilization in tiled algorithms - -- **Fractal Layout**: The `s_fractal_size` in tile configuration specifies the size of fractal blocks for secondary layout. This is used for optimized memory access patterns in matrix operations. - -- **Padding Behavior**: The `pad_value` determines how out-of-bounds accesses are handled when reading beyond `valid_shape` but within `shape`. Padding values are used for accesses in the padded region (between valid_shape and shape). - -> **⚠️ Important: Shape Constraints** -> -> The tile `shape` must be **compile-time constants**. `valid_shape` can be compile-time constants or determined at runtime, but must satisfy `valid_shape[i] ≤ shape[i]` for all dimensions i. - -### Tile Operations - -#### Basic Access Operations - -```python -# Get tile properties -shape = tile.shape # (256, 128) -elem_type = tile.element_type # pto.f32 -mem_space = tile.memory_space # MemorySpace.UB -valid_shape = tile.valid_shape # (240, 120) or same as shape - -# Get configuration properties -config = tile.config -b_layout = config.b_layout # pto.BLayout.ROW_MAJOR -s_layout = config.s_layout # pto.SLayout.NONE_BOX -s_fractal = config.s_fractal_size # pto.i32(16) -pad = config.pad_value # pto.PadValue.ZERO - -# Dynamic properties -rank = tile.rank # 2 -num_elements = tile.num_elements # 32768 (256 * 128) -valid_elements = tile.valid_elements # 28800 (240 * 120) -``` - -#### Layout and Stride Queries - -```python -# Get layout descriptors -layout_desc = tile.layout_descriptor # Returns layout description object - -# Get strides (in elements) -strides = tile.strides # (128, 1) for row-major 256x128 - -# Get byte strides -byte_strides = tile.byte_strides # (512, 4) for f32 row-major - -# Get base offset (in bytes) -offset = tile.offset # pto.i64(0) or specified offset -``` - -#### Conversion Operations - -**Basic Mode Syntax**: Use tile element-indexing directly in vector operations: -```python -# 2D tile indexing -vec = pto.vlds(tile[row, col:]) -pto.vsts(vec, tile[row, col:], mask) - -# 1D tile indexing -vec = pto.vlds(tile[start:]) -pto.vsts(vec, tile[start:], mask) -``` - -**Advanced Mode Syntax**: Convert tiles to typed pointers for byte-offset operations: -```python -# Convert tile to pointer -ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) - -# Use pointer with byte offset -vec = pto.vlds(ptr, offset) -pto.vsts(vec, ptr, offset, mask) -``` - -**Tile Manipulation Operations**: -```python -# Extract slice of tile -slice_tile = tile.slice((0, 0), (64, 128)) # 64x128 slice from top-left corner - -# Reshape tile (logical reshape, no data movement) -reshaped = tile.reshape((32768,)) # 1D reshape of 256x128 tile -``` - -#### Kernel Parameter Usage - -```python -@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) -def tiled_kernel( - input_tile: pto.Tile, # Tile parameter - output_tile: pto.Tile, # Another tile parameter - scale: pto.f32 -): - # Tiles can be used directly in vector operations (no explicit conversion needed) - all_mask = pto.make_mask(pto.f32, PAT.ALL) - for i in range(0, 256, 64): - # tile element-indexing syntax for basic mode vector operations - vec = pto.vlds(input_tile[i, 0:]) # Load from row i, columns 0 to vector_lanes-1 - scaled = pto.vmuls(vec, scale, all_mask) - pto.vsts(scaled, output_tile[i, 0:], all_mask) # Store to same position -``` - -#### Tile Creation from Existing Buffers - -```python -# Create tile from existing pointer with shape -ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) -tile = pto.tile_from_ptr(ptr, (256, 128), pto.f32) - -# Create tile with explicit stride -tile = pto.tile_with_strides((256, 128), pto.f32, MemorySpace.UB, - strides=(256, 1)) # Column-major strides -``` - diff --git a/tilelang-dsl/docs/user_guide/10-sync-dma-operations.md b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md similarity index 100% rename from tilelang-dsl/docs/user_guide/10-sync-dma-operations.md rename to tilelang-dsl/docs/user_guide/08-sync-dma-operations.md diff --git a/tilelang-dsl/docs/user_guide/11-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md similarity index 99% rename from tilelang-dsl/docs/user_guide/11-vector-memory-operations.md rename to tilelang-dsl/docs/user_guide/09-vector-memory-operations.md index f48cf2780..eec44ddc9 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-memory-operations.md +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -64,7 +64,7 @@ The number of elements loaded/stored in a single vector operation is determined vector_lanes = 256 // element_size_bytes(element_type) ``` -**Convenience API**: Use `pto.elements_per_vreg(dtype)` to compute the number of elements per vector register for a given element type (e.g., `pto.elements_per_vreg(pto.f32)` returns 64, `pto.elements_per_vreg(pto.f16)` returns 128). See [Type Query Operations](09-frontend-operations.md#type-query-operations) for full documentation. +**Convenience API**: Use `pto.elements_per_vreg(dtype)` to compute the number of elements per vector register for a given element type (e.g., `pto.elements_per_vreg(pto.f32)` returns 64, `pto.elements_per_vreg(pto.f16)` returns 128). See [Type Query Operations](07-frontend-operations.md#type-query-operations) for full documentation. Where `element_size_bytes` is: - 1 byte for `i8` diff --git a/tilelang-dsl/docs/user_guide/12-predicate-operations.md b/tilelang-dsl/docs/user_guide/10-predicate-operations.md similarity index 100% rename from tilelang-dsl/docs/user_guide/12-predicate-operations.md rename to tilelang-dsl/docs/user_guide/10-predicate-operations.md diff --git a/tilelang-dsl/docs/user_guide/13-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md similarity index 100% rename from tilelang-dsl/docs/user_guide/13-vector-arithmetic-operations.md rename to tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md diff --git a/tilelang-dsl/docs/user_guide/14-examples.md b/tilelang-dsl/docs/user_guide/12-examples.md similarity index 100% rename from tilelang-dsl/docs/user_guide/14-examples.md rename to tilelang-dsl/docs/user_guide/12-examples.md diff --git a/tilelang-dsl/docs/user_guide/15-common-errors.md b/tilelang-dsl/docs/user_guide/13-common-errors.md similarity index 100% rename from tilelang-dsl/docs/user_guide/15-common-errors.md rename to tilelang-dsl/docs/user_guide/13-common-errors.md diff --git a/tilelang-dsl/docs/user_guide/16-compatibility-notes.md b/tilelang-dsl/docs/user_guide/14-compatibility-notes.md similarity index 100% rename from tilelang-dsl/docs/user_guide/16-compatibility-notes.md rename to tilelang-dsl/docs/user_guide/14-compatibility-notes.md diff --git a/tilelang-dsl/docs/user_guide/17-next-steps.md b/tilelang-dsl/docs/user_guide/15-next-steps.md similarity index 100% rename from tilelang-dsl/docs/user_guide/17-next-steps.md rename to tilelang-dsl/docs/user_guide/15-next-steps.md From 51112bb9953c390ee5da32d83c8ffc9a87e4693a Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 18:23:54 +0800 Subject: [PATCH 029/192] Support more tile attributes --- tilelang-dsl/docs/README.md | 24 ++-- tilelang-dsl/docs/unsupported-features.md | 20 ++- .../docs/user_guide/05-type-system.md | 18 --- tilelang-dsl/python/tilelang_dsl/__init__.py | 6 + tilelang-dsl/python/tilelang_dsl/kernel.py | 32 ++++- tilelang-dsl/python/tilelang_dsl/lowering.py | 102 +++++++++++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 108 +++++++++++++++ .../python/tilelang_dsl/support_matrix.py | 6 +- tilelang-dsl/python/tilelang_dsl/types.py | 127 ++++++++++++++++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 22 +++ 10 files changed, 411 insertions(+), 54 deletions(-) diff --git a/tilelang-dsl/docs/README.md b/tilelang-dsl/docs/README.md index abfb405dd..afacded2e 100644 --- a/tilelang-dsl/docs/README.md +++ b/tilelang-dsl/docs/README.md @@ -13,27 +13,25 @@ TileLang Python DSL 为面向 Ascend NPU 硬件的向量计算内核提供高级 - [模板内核](user_guide/04-template-kernels.md) - 模板内核、多操作内核、编译时代换 ### 类型系统 -- [类型系统](user_guide/05-type-system.md) - 标量类型、向量类型、指针类型 -- [TensorView](user_guide/06-tensorview.md) - TensorView类型、属性、切片语法 -- [Tile类型](user_guide/07-tile-types.md) - Tile类型、属性、配置、操作 +- [类型系统](user_guide/05-type-system.md) - 标量、向量、指针、TensorView、Tile 类型 ### 控制流 -- [控制流](user_guide/08-control-flow.md) - 向量作用域、循环、条件语句 +- [控制流](user_guide/06-control-flow.md) - 向量作用域、循环、条件语句 ### 操作参考 -- [前端操作](user_guide/09-frontend-operations.md) - 前端操作、类型查询、指针构造 -- [同步和DMA操作](user_guide/10-sync-dma-operations.md) - 同步和DMA操作 -- [向量内存操作](user_guide/11-vector-memory-operations.md) - 向量加载和存储操作 -- [谓词操作](user_guide/12-predicate-operations.md) - 谓词操作 -- [向量算术操作](user_guide/13-vector-arithmetic-operations.md) - 向量算术操作 +- [前端操作](user_guide/07-frontend-operations.md) - 前端操作、类型查询、指针构造 +- [同步和DMA操作](user_guide/08-sync-dma-operations.md) - 同步和DMA操作 +- [向量内存操作](user_guide/09-vector-memory-operations.md) - 向量加载和存储操作 +- [谓词操作](user_guide/10-predicate-operations.md) - 谓词操作 +- [向量算术操作](user_guide/11-vector-arithmetic-operations.md) - 向量算术操作 ### 示例和错误处理 -- [示例](user_guide/15-examples.md) - 各种内核示例 -- [常见错误](user_guide/16-common-errors.md) - 常见错误和解决方案 +- [示例](user_guide/12-examples.md) - 各种内核示例 +- [常见错误](user_guide/13-common-errors.md) - 常见错误和解决方案 ### 附录 -- [兼容性说明](user_guide/17-compatibility-notes.md) - 与实验实现的差异 -- [后续步骤](user_guide/18-next-steps.md) - 相关资源链接 +- [兼容性说明](user_guide/14-compatibility-notes.md) - 与实验实现的差异 +- [后续步骤](user_guide/15-next-steps.md) - 相关资源链接 ## 相关文档 - [v1-surface.md](v1-surface.md) - TileLang DSL v1 合约 diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md index 7a815f4dc..2db9ea66a 100644 --- a/tilelang-dsl/docs/unsupported-features.md +++ b/tilelang-dsl/docs/unsupported-features.md @@ -31,7 +31,6 @@ The guide documents a richer type-construction surface that is not exported by the current package: - `pto.tile(...)` -- `BLayout`, `SLayout`, `PadValue` - `SyncOpType` Today, the public package exports annotation markers (`TensorView`, `Tile`), @@ -141,23 +140,20 @@ full guide: ### Tile Attribute Model -`Tile` currently supports only a narrow attribute subset in semantic analysis: +`Tile` currently exposes the documented metadata/query surface used by the user guide: - `shape` - `element_type` -- `valid_shape` - -The guide documents additional properties that are not currently supported: - - `memory_space` +- `valid_shape` - `config` - `rank` -- `num_elements` -- `valid_elements` -- `layout_descriptor` -- `strides` -- `byte_strides` -- `offset` + +Current constraints still apply: + +- only statically specialized rank-1/rank-2 UB tiles are supported +- `TileConfig` is queryable metadata, but lowering still renders the fixed baseline + layout contract unless later backend work teaches richer layout semantics ### Tile Config Semantics diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index a467dc05a..2a68573c4 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -280,24 +280,6 @@ pad = config.pad_value # pto.PadValue.ZERO # Dynamic properties rank = tile.rank # 2 -num_elements = tile.num_elements # 32768 (256 * 128) -valid_elements = tile.valid_elements # 28800 (240 * 120) -``` - -#### Layout and Stride Queries - -```python -# Get layout descriptors -layout_desc = tile.layout_descriptor # Returns layout description object - -# Get strides (in elements) -strides = tile.strides # (128, 1) for row-major 256x128 - -# Get byte strides -byte_strides = tile.byte_strides # (512, 4) for f32 row-major - -# Get base offset (in bytes) -offset = tile.offset # pto.i64(0) or specified offset ``` #### Conversion Operations diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index acd27f302..bbc87ceb7 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -24,6 +24,7 @@ AnyInt, AnyMask, AnyType, + BLayout, EVENT, PIPE, Event, @@ -31,6 +32,7 @@ MemorySpace, MaskPattern, PAT, + PadValue, PadMode, PositionMode, OrderMode, @@ -62,6 +64,7 @@ mask_b16, mask_b32, ptr, + SLayout, vreg, ) @@ -88,6 +91,9 @@ "ptr", "vreg", "MemorySpace", + "BLayout", + "SLayout", + "PadValue", "Pipe", "Event", "PIPE", diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 572271484..fcae842ba 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -679,7 +679,12 @@ def strides(self) -> _ConstraintSequenceView: @property def rank(self) -> _ConstraintValue: - return _ConstraintValue(self._attrs.get("rank")) + rank = self._attrs.get("rank") + if rank is None: + shape = self._attrs.get("shape") + if shape is not None: + rank = len(shape) + return _ConstraintValue(rank) @property def dtype(self) -> Any: @@ -687,11 +692,27 @@ def dtype(self) -> Any: @property def memory_space(self) -> Any: - return self._attrs.get("memory_space") + memory_space = self._attrs.get("memory_space") + if memory_space is None and self._attrs.get("kind") == "tile": + return MemorySpace.UB + if memory_space is None: + return None + if isinstance(memory_space, MemorySpace): + return memory_space + return MemorySpace(memory_space) @property - def config(self) -> Any: - return self._attrs.get("config") + def config(self) -> TileConfig | None: + config = self._attrs.get("config") + if config is None: + if self._attrs.get("kind") == "tile": + return TileConfig() + return None + if isinstance(config, TileConfig): + return config + if isinstance(config, Mapping): + return TileConfig.from_mapping(config) + raise TypeError(f"unsupported Tile config payload {config!r} in constraint view") def __repr__(self) -> str: return f"{self._name}<{self._attrs!r}>" @@ -999,14 +1020,13 @@ def _constraint_context_for_evaluation( param_attrs = attrs.get(name) if not isinstance(param_attrs, dict): param_attrs = {"kind": "tile"} - config_mapping = None if spec.config is None else dict(spec.config.fields) param_attrs.update( { "shape": spec.shape, "rank": len(spec.shape), "memory_space": spec.memory_space.value, "valid_shape": effective_valid_shape, - "config": config_mapping, + "config": spec.config, } ) attrs[name] = param_attrs diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index bd587df07..4eb0edf82 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -55,7 +55,7 @@ SemanticVectorStoreStmt, SemanticWaitFlagStmt, ) -from .types import MaskPattern, ScalarType, get_lanes +from .types import MaskPattern, MemorySpace, ScalarType, TileConfig, get_lanes _I1_TYPE = SemanticScalarType(dtype=ScalarType("i1")) @@ -328,7 +328,15 @@ def _collect_used_tile_buffers_from_expr( self._collect_used_tile_buffers_from_expr(slice_expr.step, used) return if isinstance(expr, SemanticAttributeAccess): - if expr.attr not in {"shape", "valid_shape", "strides", "element_type"}: + if expr.attr not in { + "shape", + "valid_shape", + "strides", + "element_type", + "rank", + "memory_space", + "config", + }: self._collect_used_tile_buffers_from_expr(expr.base, used) return if isinstance(expr, SemanticSubscriptAccess): @@ -919,6 +927,79 @@ def _static_expr_value(self, expr: SemanticExpr | None, *, default: object = Non return expr.value if isinstance(expr, SemanticBindingRef): return expr.binding.value + if isinstance(expr, SemanticAttributeAccess): + base_value = self._static_expr_value(expr.base) + if isinstance(base_value, TileConfig): + if expr.attr == "b_layout": + return base_value.b_layout + if expr.attr == "s_layout": + return base_value.s_layout + if expr.attr == "s_fractal_size": + return base_value.s_fractal_size + if expr.attr == "pad_value": + return base_value.pad_value + if base_value is not None and hasattr(base_value, expr.attr): + return getattr(base_value, expr.attr) + if isinstance(expr.base.type, SemanticTileType): + tile_type = expr.base.type + if expr.attr == "shape": + return tile_type.shape + if expr.attr == "valid_shape": + return None + if expr.attr == "rank": + return tile_type.rank + if expr.attr == "memory_space": + return None if tile_type.memory_space is None else MemorySpace(tile_type.memory_space) + if expr.attr == "config": + return TileConfig() if tile_type.config is None else tile_type.config + if isinstance(expr.base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)) and expr.attr == "rank": + return expr.base.type.rank + return None + if isinstance(expr, SemanticTupleExpr): + values = [] + for element in expr.elements: + value = self._static_expr_value(element) + if value is None: + return None + values.append(value) + return tuple(values) + if isinstance(expr, SemanticSubscriptAccess): + base_value = self._static_expr_value(expr.base) + index_value = self._static_expr_value(expr.index) + if isinstance(base_value, (tuple, list)) and isinstance(index_value, int): + if 0 <= index_value < len(base_value): + return base_value[index_value] + return None + if isinstance(expr, SemanticBinaryExpr): + lhs = self._static_expr_value(expr.lhs) + rhs = self._static_expr_value(expr.rhs) + if lhs is None or rhs is None: + return None + if expr.op == "add" and isinstance(lhs, int) and isinstance(rhs, int): + return lhs + rhs + if expr.op == "sub" and isinstance(lhs, int) and isinstance(rhs, int): + return lhs - rhs + if expr.op == "mul" and isinstance(lhs, int) and isinstance(rhs, int): + return lhs * rhs + if expr.op == "floordiv" and isinstance(lhs, int) and isinstance(rhs, int) and rhs != 0: + return lhs // rhs + if expr.op == "eq": + return lhs == rhs + if expr.op == "ne": + return lhs != rhs + if expr.op == "gt": + return lhs > rhs + if expr.op == "lt": + return lhs < rhs + if expr.op == "ge": + return lhs >= rhs + if expr.op == "le": + return lhs <= rhs + if expr.op == "and": + return bool(lhs) and bool(rhs) + if expr.op == "or": + return bool(lhs) or bool(rhs) + return None return None def _infer_dma_load_transfer( @@ -1971,6 +2052,19 @@ def _lower_expr( name=self._materialize_constant(expr.value, expr.type), type=expr.type, ) + static_value = self._static_expr_value(expr) + if static_value is not None and isinstance(expr.type, (SemanticIndexType, SemanticScalarType)): + if desired_name is not None and into is not None: + into.append( + self._indent(indent) + + f"{desired_name} = arith.constant {self._format_constant(static_value, expr.type)} : " + f"{self._render_type(expr.type)}" + ) + return _RenderedValue(name=desired_name, type=expr.type) + return _RenderedValue( + name=self._materialize_constant(static_value, expr.type), + type=expr.type, + ) if isinstance(expr, SemanticSubscriptAccess): return self._lower_subscript_access( expr, @@ -2834,6 +2928,10 @@ def _extract_shape_subscript_value( expr: SemanticSubscriptAccess, env: dict[str, _RenderedValue], ) -> int | _RenderedValue: + base_static = self._static_expr_value(expr.base) + index_static = self._static_expr_value(expr.index) + if isinstance(base_static, (tuple, list)) and isinstance(index_static, int): + return base_static[index_static] if not isinstance(expr.base, SemanticAttributeAccess): raise NotImplementedError("only shape/stride indexing is supported in TileLang DSL v1 lowering") if expr.base.attr not in {"shape", "valid_shape", "strides"}: diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index d3ca7679f..d429d3085 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -47,16 +47,20 @@ unsupported_feature_message, ) from .types import ( + BLayout, Event, MaskType, MaskPattern, MemorySpace, OrderMode, + PadValue, PadMode, Pipe, PositionMode, PointerType, ScalarType, + SLayout, + TileConfig, VRegType, bf16, bytewidth, @@ -89,6 +93,9 @@ _PIPE_SYMBOLS = {pipe.name: pipe for pipe in Pipe} _EVENT_SYMBOLS = {event.name: event for event in Event} _MEMORY_SPACE_SYMBOLS = {memory_space.name: memory_space for memory_space in MemorySpace} +_B_LAYOUT_SYMBOLS = {layout.name: layout for layout in BLayout} +_S_LAYOUT_SYMBOLS = {layout.name: layout for layout in SLayout} +_PAD_VALUE_SYMBOLS = {pad_value.name: pad_value for pad_value in PadValue} _PAD_MODE_SYMBOLS = {pad_mode.name: pad_mode for pad_mode in PadMode} _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} _ORDER_MODE_SYMBOLS = {order_mode.name: order_mode for order_mode in OrderMode} @@ -218,6 +225,7 @@ class SemanticTileType(SemanticType): shape: tuple[int, ...] | None valid_shape: tuple[int | None, ...] | None memory_space: str | None + config: TileConfig | None = None @dataclass(frozen=True) @@ -618,6 +626,7 @@ def _parameter_type(self, param: Any) -> SemanticType: shape=shape, valid_shape=valid_shape, memory_space=memory_space, + config=None if spec is None else spec.config, ) if param.kind == "ptr": memory_space = param.annotation.memory_space.value @@ -2254,6 +2263,8 @@ def _analyze_expr( return self._valid_shape_expr(base) if expr.attr == "strides": return self._strides_expr(base) + if expr.attr == "rank": + return self._rank_expr(base) attr_type = self._attribute_type(base, expr.attr) return SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type) if isinstance(expr, FrontendSubscriptExpr): @@ -2370,6 +2381,33 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=memory_space, type=SemanticMetaType(kind="memory_space"), ) + if expr.namespace in {"pto.BLayout"}: + b_layout = _B_LAYOUT_SYMBOLS.get(expr.name) + if b_layout is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=b_layout, + type=SemanticMetaType(kind="b_layout"), + ) + if expr.namespace in {"pto.SLayout"}: + s_layout = _S_LAYOUT_SYMBOLS.get(expr.name) + if s_layout is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=s_layout, + type=SemanticMetaType(kind="s_layout"), + ) + if expr.namespace in {"pto.PadValue"}: + pad_value = _PAD_VALUE_SYMBOLS.get(expr.name) + if pad_value is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pad_value, + type=SemanticMetaType(kind="pad_value"), + ) if expr.namespace in {"pto.PadMode"}: pad_mode = _PAD_MODE_SYMBOLS.get(expr.name) if pad_mode is not None: @@ -2411,8 +2449,29 @@ def _attribute_type(self, base: SemanticExpr, attr: str) -> SemanticType: return SemanticShapeType(rank=base_type.rank) if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)) and attr == "valid_shape": return SemanticShapeType(rank=base_type.rank) + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)) and attr == "rank": + return SemanticIndexType() + if isinstance(base_type, SemanticTileType) and attr == "memory_space": + return SemanticMetaType(kind="memory_space") + if isinstance(base_type, SemanticTileType) and attr == "config": + return SemanticMetaType(kind="tile_config") + if isinstance(base_type, SemanticMetaType) and base_type.kind == "tile_config": + if attr == "b_layout": + return SemanticMetaType(kind="b_layout") + if attr == "s_layout": + return SemanticMetaType(kind="s_layout") + if attr == "s_fractal_size": + return SemanticScalarType(dtype=i32) + if attr == "pad_value": + return SemanticMetaType(kind="pad_value") raise TypeError(f"unsupported attribute access '{attr}' in TileLang DSL v1") + def _rank_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): + return SemanticLiteralExpr(value=base_type.rank, type=SemanticIndexType()) + raise TypeError("unsupported attribute access 'rank' in TileLang DSL v1") + def _element_type_expr(self, base: SemanticExpr) -> SemanticExpr: base_type = base.type if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): @@ -2501,6 +2560,15 @@ def _strides_expr(self, base: SemanticExpr) -> SemanticExpr: type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in elements)), ) + def _static_shape_tuple_expr(self, values: tuple[int, ...]) -> SemanticTupleExpr: + return SemanticTupleExpr( + elements=tuple( + SemanticLiteralExpr(value=value, type=SemanticIndexType()) + for value in values + ), + type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in values)), + ) + def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticType: if isinstance(base.type, SemanticShapeType): if not isinstance(index.type, SemanticIndexType): @@ -2516,6 +2584,16 @@ def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticTy f"shape subscript index {index.value} is out of bounds for rank {base.type.rank}" ) return SemanticIndexType() + if isinstance(base.type, SemanticTupleType): + if not isinstance(index.type, SemanticIndexType): + raise TypeError("tuple subscript index must be an index value in TileLang DSL v1") + if isinstance(index, SemanticLiteralExpr) and isinstance(index.value, int): + if index.value < 0 or index.value >= len(base.type.elements): + raise TypeError( + f"tuple subscript index {index.value} is out of bounds for arity {len(base.type.elements)}" + ) + return base.type.elements[index.value] + raise TypeError("tuple subscript index must be a compile-time integer literal in TileLang DSL v1") if isinstance(base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): if not isinstance(index, SemanticTupleExpr): raise TypeError("TensorView slicing expects a tuple of slices in TileLang DSL v1") @@ -3775,6 +3853,36 @@ def _try_static_value(self, expr: SemanticExpr | None) -> Any | None: return expr.value if isinstance(expr, SemanticBindingRef): return expr.binding.value + if isinstance(expr, SemanticAttributeAccess): + base_value = self._try_static_value(expr.base) + if isinstance(base_value, TileConfig): + if expr.attr == "b_layout": + return base_value.b_layout + if expr.attr == "s_layout": + return base_value.s_layout + if expr.attr == "s_fractal_size": + return base_value.s_fractal_size + if expr.attr == "pad_value": + return base_value.pad_value + if base_value is not None and hasattr(base_value, expr.attr): + return getattr(base_value, expr.attr) + if isinstance(expr.base.type, SemanticTileType): + tile_type = expr.base.type + config = TileConfig() if tile_type.config is None else tile_type.config + if expr.attr == "shape": + return tile_type.shape + if expr.attr == "valid_shape": + return self._resolved_tile_valid_shape(tile_type) + if expr.attr == "rank": + return tile_type.rank + if expr.attr == "memory_space": + return None if tile_type.memory_space is None else MemorySpace(tile_type.memory_space) + if expr.attr == "config": + return config + if isinstance(expr.base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + if expr.attr == "rank": + return expr.base.type.rank + return None if isinstance(expr, SemanticTupleExpr): elements = [] for element in expr.elements: diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index bf59af107..e2b510f2b 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -291,9 +291,6 @@ def get_pto_call_tier(call_name: str) -> str: "pto.dma_copy", "pto.vreduce", "pto.tile", - "BLayout", - "SLayout", - "PadValue", "SyncOpType", } ) @@ -312,6 +309,9 @@ def get_pto_call_tier(call_name: str) -> str: "pto.mask_b16": BASIC_TIER, "pto.mask_b32": BASIC_TIER, "PadMode": BASIC_TIER, + "BLayout": BASIC_TIER, + "SLayout": BASIC_TIER, + "PadValue": BASIC_TIER, "constexpr": BASIC_TIER, "pto.constexpr": BASIC_TIER, "tile[start:]": BASIC_TIER, diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 405ece12f..7098f71e6 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -82,6 +82,19 @@ class MemorySpace(str, Enum): UB = "ub" +class BLayout(str, Enum): + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + +class SLayout(str, Enum): + NONE_BOX = "none_box" + + +class PadValue(str, Enum): + ZERO = "zero" + + class Pipe(str, Enum): MTE1 = "PIPE_MTE1" MTE2 = "PIPE_MTE2" @@ -124,6 +137,29 @@ class OrderMode(str, Enum): ASC = "ORDER_ASC" +def _coerce_int_config_value(value: Any, field_name: str) -> int: + if isinstance(value, bool) or not isinstance(value, int): + raise TypeError(f"TileConfig field '{field_name}' must be an integer") + return value + + +def _coerce_enum_config_value( + value: Any, + enum_type: type[Enum], + field_name: str, + default: Enum, +) -> Enum: + if value is None: + return default + if isinstance(value, enum_type): + return value + if isinstance(value, str): + for candidate in enum_type: + if value in {candidate.name, candidate.value}: + return candidate + raise TypeError(f"TileConfig field '{field_name}' must be a {enum_type.__name__} or matching string") + + @dataclass(frozen=True) class TileConfig: fields: tuple[tuple[str, Any], ...] = () @@ -132,6 +168,55 @@ class TileConfig: def from_mapping(cls, mapping: Mapping[str, Any]) -> "TileConfig": return cls(tuple(sorted(mapping.items()))) + def _field(self, *names: str) -> Any | None: + values = dict(self.fields) + for name in names: + if name in values: + return values[name] + return None + + @property + def b_layout(self) -> BLayout: + return _coerce_enum_config_value( + self._field("b_layout", "layout"), + BLayout, + "b_layout", + BLayout.ROW_MAJOR, + ) + + @property + def s_layout(self) -> SLayout: + return _coerce_enum_config_value( + self._field("s_layout", "slayout"), + SLayout, + "s_layout", + SLayout.NONE_BOX, + ) + + @property + def s_fractal_size(self) -> int: + value = self._field("s_fractal_size", "fractal") + if value is None: + return 512 + return _coerce_int_config_value(value, "s_fractal_size") + + @property + def pad_value(self) -> PadValue: + return _coerce_enum_config_value( + self._field("pad_value", "pad"), + PadValue, + "pad_value", + PadValue.ZERO, + ) + + +@dataclass(frozen=True) +class TileLayoutDescriptor: + shape: tuple[int, ...] + strides: tuple[int, ...] + byte_strides: tuple[int, ...] + offset: int = 0 + @dataclass(frozen=True) class TileSpecialization: @@ -210,6 +295,45 @@ def constexpr(value: bool) -> bool: return value +def tile_strides( + shape: tuple[int, ...], + config: TileConfig | None = None, +) -> tuple[int, ...]: + if not shape: + return () + normalized = TileConfig() if config is None else config + if normalized.b_layout == BLayout.COL_MAJOR and len(shape) == 2: + return (1, shape[0]) + strides = [1] + for dim in reversed(shape[1:]): + strides.insert(0, strides[0] * dim) + return tuple(strides) + + +def tile_byte_strides( + shape: tuple[int, ...], + dtype: ScalarType, + config: TileConfig | None = None, +) -> tuple[int, ...]: + element_bytes = bytewidth(dtype) + return tuple(stride * element_bytes for stride in tile_strides(shape, config)) + + +def tile_layout_descriptor( + shape: tuple[int, ...], + dtype: ScalarType, + config: TileConfig | None = None, + *, + offset: int = 0, +) -> TileLayoutDescriptor: + return TileLayoutDescriptor( + shape=shape, + strides=tile_strides(shape, config), + byte_strides=tile_byte_strides(shape, dtype, config), + offset=offset, + ) + + __all__ = [ "ScalarType", "WildcardType", @@ -224,6 +348,9 @@ def constexpr(value: bool) -> bool: "ptr", "vreg", "MemorySpace", + "BLayout", + "SLayout", + "PadValue", "Pipe", "Event", "PIPE", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 639641af5..857bef08f 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -72,6 +72,9 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "mask_b8")) self.assertTrue(hasattr(pto, "mask_b16")) self.assertTrue(hasattr(pto, "mask_b32")) + self.assertTrue(hasattr(pto, "BLayout")) + self.assertTrue(hasattr(pto, "SLayout")) + self.assertTrue(hasattr(pto, "PadValue")) self.assertTrue(hasattr(pto, "constexpr")) self.assertTrue(hasattr(pto, "bytewidth")) self.assertTrue(hasattr(pto, "get_lanes")) @@ -85,6 +88,9 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PadMode.PadNull.value, "PadNull") self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") self.assertEqual(pto.PadMode.PadValue.value, "PadValue") + self.assertEqual(pto.BLayout.ROW_MAJOR.value, "row_major") + self.assertEqual(pto.SLayout.NONE_BOX.value, "none_box") + self.assertEqual(pto.PadValue.ZERO.value, "zero") self.assertEqual(pto.PositionMode.LOWEST.value, "POS_LOWEST") self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") @@ -133,6 +139,9 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.elements_per_vreg"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.constexpr"), BASIC_TIER) self.assertEqual(get_feature_tier("constexpr"), BASIC_TIER) + self.assertEqual(get_feature_tier("BLayout"), BASIC_TIER) + self.assertEqual(get_feature_tier("SLayout"), BASIC_TIER) + self.assertEqual(get_feature_tier("PadValue"), BASIC_TIER) self.assertEqual(get_feature_tier("tile[start:]"), BASIC_TIER) self.assertEqual(get_feature_tier("tile[row, col:]"), BASIC_TIER) @@ -3721,6 +3730,19 @@ def kernel(x: pto.TensorView): self.assertIn("unsupported op surface `pto.not_a_real_surface`", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) + def test_removed_tile_derived_query_surface_is_rejected(self) -> None: + @pto.vkernel(op="removed_tile_query_surface_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + layout = dst.layout_descriptor + return None + + with self.assertRaises(TypeError) as ctx: + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ).mlir_text() + + self.assertIn("unsupported attribute access 'layout_descriptor'", str(ctx.exception)) + def test_strict_vecscope_requires_advanced_mode(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as ctx: From 30dd61a077e07605942604dc83db676a57d40f12 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 18:39:20 +0800 Subject: [PATCH 030/192] Support more event IDs --- .../python/tilelang_dsl/frontend_ast.py | 2 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 6 ++--- tilelang-dsl/python/tilelang_dsl/types.py | 24 +++++++++++++++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 26 +++++++++++++++++++ 4 files changed, 54 insertions(+), 4 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index a48927d5a..ef340a6de 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -776,7 +776,7 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo ) if isinstance(node, ast.Attribute): path = _attribute_path(node) - if path is not None and path[0] in {"pto", "PAT", "PIPE", "EVENT"} and len(path) >= 2: + if path is not None and path[0] in {"pto", "PAT", "PIPE", "Pipe", "EVENT", "Event"} and len(path) >= 2: return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) return FrontendAttributeExpr(base=_build_expr(node.value, context), attr=node.attr) if isinstance(node, ast.Subscript): diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index d429d3085..2a7ec5e32 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -1985,7 +1985,7 @@ def _build_frontend_annotation_expr(self, node: ast.AST) -> FrontendExprNode: return FrontendConstantExpr(value=node.value) if isinstance(node, ast.Attribute): path = self._annotation_attribute_path(node) - if path is not None and path[0] in {"pto", "PAT", "PIPE", "EVENT"} and len(path) >= 2: + if path is not None and path[0] in {"pto", "PAT", "PIPE", "Pipe", "EVENT", "Event"} and len(path) >= 2: return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) return FrontendAttributeExpr( base=self._build_frontend_annotation_expr(node.value), @@ -2354,7 +2354,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pattern, type=SemanticMetaType(kind="mask_pattern"), ) - if expr.namespace in {"PIPE", "pto.PIPE"}: + if expr.namespace in {"PIPE", "Pipe", "pto.PIPE", "pto.Pipe"}: pipe = _PIPE_SYMBOLS.get(expr.name) if pipe is not None: return SemanticSymbolExpr( @@ -2363,7 +2363,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pipe, type=SemanticMetaType(kind="pipe"), ) - if expr.namespace in {"EVENT", "pto.EVENT"}: + if expr.namespace in {"EVENT", "Event", "pto.EVENT", "pto.Event"}: event = _EVENT_SYMBOLS.get(expr.name) if event is not None: return SemanticSymbolExpr( diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 7098f71e6..d592c0fab 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -112,6 +112,30 @@ class Event(str, Enum): ID5 = "EVENT_ID5" ID6 = "EVENT_ID6" ID7 = "EVENT_ID7" + ID8 = "EVENT_ID8" + ID9 = "EVENT_ID9" + ID10 = "EVENT_ID10" + ID11 = "EVENT_ID11" + ID12 = "EVENT_ID12" + ID13 = "EVENT_ID13" + ID14 = "EVENT_ID14" + ID15 = "EVENT_ID15" + ID16 = "EVENT_ID16" + ID17 = "EVENT_ID17" + ID18 = "EVENT_ID18" + ID19 = "EVENT_ID19" + ID20 = "EVENT_ID20" + ID21 = "EVENT_ID21" + ID22 = "EVENT_ID22" + ID23 = "EVENT_ID23" + ID24 = "EVENT_ID24" + ID25 = "EVENT_ID25" + ID26 = "EVENT_ID26" + ID27 = "EVENT_ID27" + ID28 = "EVENT_ID28" + ID29 = "EVENT_ID29" + ID30 = "EVENT_ID30" + ID31 = "EVENT_ID31" class MaskPattern(str, Enum): diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 857bef08f..95cbb92ca 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3323,6 +3323,32 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, flag: pto.i32): self.assertIn("scf.for %lane_", text) self.assertIn("pto.barrier #pto.pipe", text) + def test_sync_ops_accept_event_class_alias_and_full_event_range(self) -> None: + Event = pto.Event + Pipe = pto.Pipe + + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(inp: pto.TensorView, tile: pto.Tile): + pto.set_flag(Pipe.MTE2, Pipe.V, Event.ID31) + pto.wait_flag(Pipe.MTE2, Pipe.V, Event.ID31) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIsInstance(semantic_kernel.body[0], SemanticSetFlagStmt) + self.assertIsInstance(semantic_kernel.body[1], SemanticWaitFlagStmt) + self.assertEqual(pto.Event.ID31.value, "EVENT_ID31") + + text = specialized.mlir_text() + self.assertIn('pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID31"]', text) + self.assertIn('pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID31"]', text) + def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): From 2e6381930190beb813c3ccf3042381386d2aa8d2 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 18:58:22 +0800 Subject: [PATCH 031/192] Align copy_ubuf_to_ubuf --- .../python/tilelang_dsl/frontend_ast.py | 12 ++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 53 +++++++++++++++++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 55 +++++++++++++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index ef340a6de..7f1b80625 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -645,6 +645,18 @@ def _collect_reachable_inline_procs( "ub_stride", } ), + "copy_ubuf_to_ubuf": frozenset( + { + "src", + "dst", + "src_offset", + "src_stride0", + "src_stride1", + "dst_offset", + "dst_stride0", + "dst_stride1", + } + ), } diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 2a7ec5e32..7308d87f6 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -1519,10 +1519,13 @@ def _analyze_low_level_dma_operands( f"pto.{expr.name} does not support mixing positional and keyword operands in TileLang DSL v1" ) if not expr.keywords: - return tuple( + args = tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) for arg in expr.args ) + if expr.name == "copy_ubuf_to_ubuf" and len(args) == 8: + return self._normalize_copy_ubuf_to_ubuf_guide_operands(args) + return args analyzed_keywords: dict[str, SemanticExpr] = { name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) @@ -1596,10 +1599,56 @@ def bool_literal(value: bool) -> SemanticLiteralExpr: analyzed_keywords["ub_stride"], ), ) + if expr.name == "copy_ubuf_to_ubuf": + return self._normalize_copy_ubuf_to_ubuf_guide_operands( + ( + analyzed_keywords["src"], + analyzed_keywords["dst"], + analyzed_keywords["src_offset"], + analyzed_keywords["src_stride0"], + analyzed_keywords["src_stride1"], + analyzed_keywords["dst_offset"], + analyzed_keywords["dst_stride0"], + analyzed_keywords["dst_stride1"], + ) + ) raise TypeError( f"pto.{expr.name} keyword form is not implemented in TileLang DSL v1" ) + def _normalize_copy_ubuf_to_ubuf_guide_operands( + self, + args: tuple[SemanticExpr, ...], + ) -> tuple[SemanticExpr, ...]: + if len(args) != 8: + raise TypeError( + "pto.copy_ubuf_to_ubuf guide form expects exactly 8 operands in TileLang DSL" + ) + source = self._require_pointer_expr(args[0], "pto.copy_ubuf_to_ubuf source", memory_space="ub") + destination = self._require_pointer_expr( + args[1], + "pto.copy_ubuf_to_ubuf destination", + memory_space="ub", + ) + self._require_i64_like_expr(args[2], "pto.copy_ubuf_to_ubuf src_offset") + self._require_i64_like_expr(args[3], "pto.copy_ubuf_to_ubuf src_stride0") + self._require_i64_like_expr(args[4], "pto.copy_ubuf_to_ubuf src_stride1") + self._require_i64_like_expr(args[5], "pto.copy_ubuf_to_ubuf dst_offset") + self._require_i64_like_expr(args[6], "pto.copy_ubuf_to_ubuf dst_stride0") + self._require_i64_like_expr(args[7], "pto.copy_ubuf_to_ubuf dst_stride1") + zero_sid = SemanticLiteralExpr(value=0, type=SemanticIndexType()) + # The guide-level surface rewrites offsets into the base pointers, then + # lowers the remaining four integers to the existing VPTO UB->UB copy ABI. + return ( + SemanticCallExpr(namespace="pto", name="addptr", args=(source, args[2]), type=source.type), + SemanticCallExpr(namespace="pto", name="addptr", args=(destination, args[5]), type=destination.type), + zero_sid, + args[3], + args[4], + args[6], + args[7], + ) + def _require_tensor_slice( self, expr: SemanticExpr, @@ -3012,7 +3061,7 @@ def _analyze_addptr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: raise TypeError("pto.addptr expects exactly 2 positional arguments in TileLang DSL") pointer, offset = args ptr = self._require_pointer_expr(pointer, "pto.addptr pointer") - self._require_index_typed_expr(offset) + self._require_i64_like_expr(offset, "pto.addptr offset") return SemanticCallExpr(namespace="pto", name="addptr", args=(ptr, offset), type=ptr.type) def _analyze_get_lanes( diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 95cbb92ca..cc04ea5b7 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2972,6 +2972,61 @@ def kernel(src: pto.Tile, dst: pto.TensorView): r"pto\.copy_ubuf_to_gm %ub_ptr_\d+, %gm_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", ) + def test_copy_ubuf_to_ubuf_guide_surface_lowers_in_advanced_mode(self) -> None: + @pto.vkernel( + op="ub_to_ub_dma_guide_unique", + dtypes=[(pto.f32, pto.f32, pto.i64, pto.i64)], + advanced=True, + ) + def kernel(src: pto.Tile, dst: pto.Tile, src_offset: pto.i64, dst_offset: pto.i64): + src_ptr = src.as_ptr() + dst_ptr = dst.as_ptr() + + pto.copy_ubuf_to_ubuf(src_ptr, dst_ptr, src_offset, 32, 128, dst_offset, 160, 192) + pto.copy_ubuf_to_ubuf( + src=src_ptr, + dst=dst_ptr, + src_offset=src_offset, + src_stride0=16, + src_stride1=64, + dst_offset=dst_offset, + dst_stride0=96, + dst_stride1=128, + ) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + low_level_copies = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticLowLevelCopyStmt)] + self.assertEqual(len(low_level_copies), 2) + self.assertTrue(all(isinstance(stmt.source, SemanticCallExpr) for stmt in low_level_copies)) + self.assertTrue(all(isinstance(stmt.destination, SemanticCallExpr) for stmt in low_level_copies)) + self.assertTrue(all(stmt.source.name == "addptr" for stmt in low_level_copies)) + self.assertTrue(all(stmt.destination.name == "addptr" for stmt in low_level_copies)) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"%src_ptr_\d+ = pto\.tile_buf_addr %arg0 : !pto\.tile_buf -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%dst_ptr_\d+ = pto\.tile_buf_addr %arg1 : !pto\.tile_buf -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%tmp_\d+ = pto\.addptr %src_ptr_\d+, %arg2 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%tmp_\d+ = pto\.addptr %dst_ptr_\d+, %arg3 : !pto\.ptr -> !pto\.ptr", + ) + self.assertEqual(text.count("pto.copy_ubuf_to_ubuf "), 2) + def test_castptr_rejects_tensorview_or_tile_inputs_in_advanced_mode(self) -> None: @pto.vkernel(op="castptr_tensorview_reject_unique", dtypes=[(pto.f32,)], advanced=True) def tensorview_kernel(inp: pto.TensorView): From f0aeccb84b6cd65dde835b1ba052d0fa054ca9a2 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 20:02:51 +0800 Subject: [PATCH 032/192] Support more load/store ops --- tilelang-dsl/python/tilelang_dsl/__init__.py | 10 + tilelang-dsl/python/tilelang_dsl/lowering.py | 396 +++++++++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 608 ++++++++++++++++-- .../python/tilelang_dsl/support_matrix.py | 12 + tilelang-dsl/python/tilelang_dsl/types.py | 32 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 102 +++ 6 files changed, 1114 insertions(+), 46 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index bbc87ceb7..44fba3583 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -24,10 +24,13 @@ AnyInt, AnyMask, AnyType, + AlignType, BLayout, + DeinterleaveDist, EVENT, PIPE, Event, + InterleaveDist, MaskType, MemorySpace, MaskPattern, @@ -39,6 +42,7 @@ PointerType, Pipe, ScalarType, + StrideMode, TensorView, PartitionTensorView, Tile, @@ -48,6 +52,7 @@ TypeVariable, VRegType, WildcardType, + align, bf16, constexpr, bytewidth, @@ -88,6 +93,7 @@ "PointerType", "VRegType", "MaskType", + "AlignType", "ptr", "vreg", "MemorySpace", @@ -103,6 +109,9 @@ "PadMode", "PositionMode", "OrderMode", + "DeinterleaveDist", + "InterleaveDist", + "StrideMode", "TileConfig", "TileSpecialization", "i1", @@ -117,6 +126,7 @@ "AnyInt", "AnyType", "AnyMask", + "align", "mask_b8", "mask_b16", "mask_b32", diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 4eb0edf82..be08309b2 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -22,6 +22,7 @@ SemanticDmaConfigStmt, SemanticDmaLoadStmt, SemanticDmaStoreStmt, + SemanticAlignType, SemanticExpr, SemanticExprStmt, SemanticForStmt, @@ -55,7 +56,7 @@ SemanticVectorStoreStmt, SemanticWaitFlagStmt, ) -from .types import MaskPattern, MemorySpace, ScalarType, TileConfig, get_lanes +from .types import MaskPattern, MemorySpace, ScalarType, TileConfig, get_lanes, tile_strides _I1_TYPE = SemanticScalarType(dtype=ScalarType("i1")) @@ -304,8 +305,21 @@ def _collect_used_tile_buffers_from_expr( used: set[str], ) -> None: if isinstance(expr, SemanticCallExpr): - if expr.namespace == "pto" and expr.name == "vlds" and expr.args: - self._record_tile_buffer_use(expr.args[0], used) + if expr.namespace == "pto" and expr.args: + if expr.name in { + "vlds", + "vldas", + "vldus", + "vldx2", + "vsld", + "psts", + "vsst", + "vstx2", + "vsta", + }: + self._record_tile_buffer_use(expr.args[0], used) + if expr.name in {"psts", "vsst", "vstx2", "vsta"} and len(expr.args) >= 2: + self._record_tile_buffer_use(expr.args[1], used) for arg in expr.args: self._collect_used_tile_buffers_from_expr(arg, used) return @@ -541,10 +555,10 @@ def _render_multi_result_assign( raise NotImplementedError( f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" ) - if len(stmt.targets) != 2: - raise NotImplementedError("multi-result lowering expects exactly two assignment targets") - if not isinstance(stmt.value.type, SemanticTupleType) or len(stmt.value.type.elements) != 2: - raise NotImplementedError("multi-result lowering expects a two-result tuple type") + if not isinstance(stmt.value.type, SemanticTupleType): + raise NotImplementedError("multi-result lowering expects a tuple-typed call value") + if len(stmt.targets) != len(stmt.value.type.elements): + raise NotImplementedError("multi-result lowering expects tuple assignment arity to match the call result count") if stmt.value.name == "make_mask": dtype_expr, remaining_expr = stmt.value.args @@ -637,6 +651,111 @@ def _render_multi_result_assign( env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) return lines + if stmt.value.name == "vldx2": + lines = [] + source_name, source_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( + stmt.value.args[:-1], + env, + indent=indent, + into=lines, + ) + dist = self._render_string_literal(stmt.value.args[-1]) + rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) + rendered_result_types = ", ".join( + self._render_type(result_type) for result_type in stmt.value.type.elements + ) + lines.append( + self._indent(indent) + + f"{rendered_targets} = pto.vldx2 {source_name}[{offset_name}], {dist} : " + + f"{source_type}, {offset_type} -> {rendered_result_types}" + ) + for target, result_type in zip(stmt.targets, stmt.value.type.elements): + env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) + return lines + + if stmt.value.name == "vldus": + lines = [] + source_name, source_type = self._lower_memory_buffer_without_offset( + stmt.value.args[:-1], + env, + indent=indent, + into=lines, + ) + align = self._lower_expr(stmt.value.args[-1], env, indent=indent, into=lines) + rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) + rendered_result_types = ", ".join( + self._render_type(result_type) for result_type in stmt.value.type.elements + ) + lines.append( + self._indent(indent) + + f"{rendered_targets} = pto.vldus {source_name}, {align.name} : " + + f"{source_type}, {self._render_type(align.type)} -> {rendered_result_types}" + ) + for target, result_type in zip(stmt.targets, stmt.value.type.elements): + env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) + return lines + + if stmt.value.name == "pstu": + lines = [] + align = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + value = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + base = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) + rendered_result_types = ", ".join( + self._render_type(result_type) for result_type in stmt.value.type.elements + ) + lines.append( + self._indent(indent) + + f"{rendered_targets} = pto.pstu {align.name}, {value.name}, {base.name} : " + + f"{self._render_type(align.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {rendered_result_types}" + ) + for target, result_type in zip(stmt.targets, stmt.value.type.elements): + env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) + return lines + + if stmt.value.name == "vstu": + lines = [] + align = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + offset = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + value = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + base = self._lower_expr(stmt.value.args[3], env, indent=indent, into=lines) + mode = self._render_string_literal(stmt.value.args[4]) + rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) + rendered_result_types = ", ".join( + self._render_type(result_type) for result_type in stmt.value.type.elements + ) + lines.append( + self._indent(indent) + + f"{rendered_targets} = pto.vstu {align.name}, {offset.name}, {value.name}, {base.name}, {mode} : " + + f"{self._render_type(align.type)}, {self._render_type(offset.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {rendered_result_types}" + ) + for target, result_type in zip(stmt.targets, stmt.value.type.elements): + env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) + return lines + + if stmt.value.name == "vstus": + lines = [] + align = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + offset = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + value = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + base = self._lower_expr(stmt.value.args[3], env, indent=indent, into=lines) + mode = self._render_string_literal(stmt.value.args[4]) + rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) + rendered_result_types = ", ".join( + self._render_type(result_type) for result_type in stmt.value.type.elements + ) + lines.append( + self._indent(indent) + + f"{rendered_targets} = pto.vstus {align.name}, {offset.name}, {value.name}, {base.name}, {mode} : " + + f"{self._render_type(align.type)}, {self._render_type(offset.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {rendered_result_types}" + ) + for target, result_type in zip(stmt.targets, stmt.value.type.elements): + env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) + return lines + raise NotImplementedError( f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" ) @@ -833,6 +952,160 @@ def _materialize_rank2_tile_subview( ) return _RenderedValue(name=subview_name, type=subview_type) + def _lower_memory_buffer_without_offset( + self, + args: tuple[SemanticExpr, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> tuple[str, str]: + if not args: + raise NotImplementedError("memory buffer lowering expects at least one operand") + source = self._lower_expr(args[0], env, indent=indent, into=into) + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_access_ptr( + source, + args[1:], + env, + indent=indent, + into=into, + ) + return source.name, self._render_type(source.type) + if len(args) != 1: + raise NotImplementedError("pointer memory buffer lowering does not accept tile-style indices") + return source.name, self._render_type(source.type) + + def _lower_memory_buffer_with_offset( + self, + args: tuple[SemanticExpr, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> tuple[str, str, str, str]: + if not args: + raise NotImplementedError("memory buffer lowering expects at least one operand") + source = self._lower_expr(args[0], env, indent=indent, into=into) + if isinstance(source.type, SemanticTileType): + if not args[1:]: + raise NotImplementedError("tile memory buffer lowering requires element indices") + offset = self._materialize_tile_linear_offset( + source, + args[1:], + env, + indent=indent, + into=into, + ) + source = self._materialize_tile_memref(source, indent=indent, into=into) + return ( + source.name, + self._render_type(source.type), + offset.name, + self._render_type(offset.type), + ) + if len(args) != 2: + raise NotImplementedError("pointer memory buffer lowering expects exactly one explicit offset operand") + offset = self._lower_expr(args[1], env, indent=indent, into=into) + return ( + source.name, + self._render_type(source.type), + offset.name, + self._render_type(offset.type), + ) + + def _materialize_tile_access_ptr( + self, + tile_value: _RenderedValue, + indices: tuple[SemanticExpr, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + base_ptr_name, base_ptr_type = self._materialize_copy_buffer_ptr( + tile_value, + indent=indent, + into=into, + ) + if not indices: + return _RenderedValue(name=base_ptr_name, type=tile_value.type if isinstance(tile_value.type, SemanticPtrType) else SemanticPtrType(tile_value.type.element_dtype, tile_value.type.memory_space or "ub")) + offset = self._materialize_tile_linear_offset( + tile_value, + indices, + env, + indent=indent, + into=into, + ) + typed_ptr_type = SemanticPtrType( + element_dtype=tile_value.type.element_dtype, + memory_space=tile_value.type.memory_space or "ub", + ) + offset_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{offset_ptr_name} = pto.addptr {base_ptr_name}, {offset.name} : " + + f"{base_ptr_type} -> {self._render_type(typed_ptr_type)}" + ) + return _RenderedValue(name=offset_ptr_name, type=typed_ptr_type) + + def _materialize_tile_linear_offset( + self, + tile_value: _RenderedValue, + indices: tuple[SemanticExpr, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + tile_type = tile_value.type + if not isinstance(tile_type, SemanticTileType): + raise NotImplementedError("tile linear offset lowering expects a Tile value") + if tile_type.rank == 1: + if len(indices) != 1: + raise NotImplementedError("rank-1 Tile access expects one index") + return self._lower_expr(indices[0], env, indent=indent, into=into) + if tile_type.rank != 2 or tile_type.shape is None: + raise NotImplementedError("Tile linear offset lowering expects a statically specialized rank-2 Tile") + if len(indices) != 2: + raise NotImplementedError("rank-2 Tile access expects two indices") + + row = self._lower_expr(indices[0], env, indent=indent, into=into) + col = self._lower_expr(indices[1], env, indent=indent, into=into) + strides = tile_strides(tile_type.shape, tile_type.config or TileConfig()) + stride0 = _RenderedValue( + name=self._materialize_constant(strides[0], SemanticIndexType()), + type=SemanticIndexType(), + ) + stride1 = _RenderedValue( + name=self._materialize_constant(strides[1], SemanticIndexType()), + type=SemanticIndexType(), + ) + row_term = self._emit_binary_value( + "mul", + row, + stride0, + SemanticIndexType(), + indent=indent, + into=into, + ) + col_term = self._emit_binary_value( + "mul", + col, + stride1, + SemanticIndexType(), + indent=indent, + into=into, + ) + return self._emit_binary_value( + "add", + row_term, + col_term, + SemanticIndexType(), + indent=indent, + into=into, + ) + def _tensor_slice_extents(self, expr: SemanticTensorSliceExpr) -> tuple[int, int]: if expr.type.rank != 2 or len(expr.type.extents) != 2: raise NotImplementedError("TileLang DSL v1 DMA lowering currently only supports rank-2 TensorView slices") @@ -2157,6 +2430,73 @@ def _lower_call_expr( into = [] result_name = desired_name or self._new_temp() + if isinstance(expr.type, SemanticMetaType) and expr.type.kind == "void": + if expr.name == "psts": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination_name, destination_type, offset_name, _ = self._lower_memory_buffer_with_offset( + expr.args[1:], + env, + indent=indent, + into=into, + ) + into.append( + self._indent(indent) + + f"pto.psts {value.name}, {destination_name}[{offset_name}] : " + + f"{self._render_type(value.type)}, {destination_type}" + ) + return _RenderedValue(name="__void_call__", type=expr.type) + + if expr.name == "vsst": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( + expr.args[1:-1], + env, + indent=indent, + into=into, + ) + stride = self._render_string_literal(expr.args[-1]) + into.append( + self._indent(indent) + + f"pto.vsst {value.name}, {destination_name}[{offset_name}], {stride} : " + + f"{self._render_type(value.type)}, {destination_type}" + ) + return _RenderedValue(name="__void_call__", type=expr.type) + + if expr.name == "vstx2": + low = self._lower_expr(expr.args[0], env, indent=indent, into=into) + high = self._lower_expr(expr.args[1], env, indent=indent, into=into) + destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( + expr.args[2:-2], + env, + indent=indent, + into=into, + ) + dist = self._render_string_literal(expr.args[-2]) + mask = self._lower_expr(expr.args[-1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"pto.vstx2 {low.name}, {high.name}, {destination_name}[{offset_name}], {dist}, {mask.name} : " + + f"{self._render_type(low.type)}, {self._render_type(high.type)}, {destination_type}, {offset_type}, {self._render_type(mask.type)}" + ) + return _RenderedValue(name="__void_call__", type=expr.type) + + if expr.name == "vsta": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( + expr.args[1:], + env, + indent=indent, + into=into, + ) + into.append( + self._indent(indent) + + f"pto.vsta {value.name}, {destination_name}[{offset_name}] : " + + f"{self._render_type(value.type)}, {destination_type}, {offset_type}" + ) + return _RenderedValue(name="__void_call__", type=expr.type) + + raise NotImplementedError(f"void pto call `pto.{expr.name}` is not supported in TileLang DSL v1") + if expr.name == "make_mask": dtype_expr, pattern_expr = expr.args if not self._is_dtype_meta_expr(dtype_expr): @@ -2197,6 +2537,46 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vldas": + source_name, source_type = self._lower_memory_buffer_without_offset( + expr.args, + env, + indent=indent, + into=into, + ) + into.append( + self._indent(indent) + + f"{result_name} = pto.vldas {source_name} : {source_type} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vsld": + source_name, source_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( + expr.args[:-1], + env, + indent=indent, + into=into, + ) + stride = self._render_string_literal(expr.args[-1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vsld {source_name}[{offset_name}], {stride} : " + + f"{source_type} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vstur": + align = self._lower_expr(expr.args[0], env, indent=indent, into=into) + value = self._lower_expr(expr.args[1], env, indent=indent, into=into) + base = self._lower_expr(expr.args[2], env, indent=indent, into=into) + mode = self._render_string_literal(expr.args[3]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vstur {align.name}, {value.name}, {base.name}, {mode} : " + + f"{self._render_type(align.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vbr": scalar = self._lower_expr(expr.args[0], env, indent=indent, into=into) into.append( @@ -3039,6 +3419,8 @@ def _render_type(self, ty: SemanticType) -> str: return ty.dtype.name if isinstance(ty, SemanticPtrType): return f"!pto.ptr<{ty.element_dtype.name}, {ty.memory_space}>" + if isinstance(ty, SemanticAlignType): + return "!pto.align" if isinstance(ty, SemanticTensorViewType): return self._render_tensor_view_type( element_dtype=ty.element_dtype.name, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 7308d87f6..ee9ba45b4 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -48,7 +48,9 @@ ) from .types import ( BLayout, + DeinterleaveDist, Event, + InterleaveDist, MaskType, MaskPattern, MemorySpace, @@ -60,6 +62,7 @@ PointerType, ScalarType, SLayout, + StrideMode, TileConfig, VRegType, bf16, @@ -99,6 +102,9 @@ _PAD_MODE_SYMBOLS = {pad_mode.name: pad_mode for pad_mode in PadMode} _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} _ORDER_MODE_SYMBOLS = {order_mode.name: order_mode for order_mode in OrderMode} +_DEINTERLEAVE_DIST_SYMBOLS = {dist.name: dist for dist in DeinterleaveDist} +_INTERLEAVE_DIST_SYMBOLS = {dist.name: dist for dist in InterleaveDist} +_STRIDE_MODE_SYMBOLS = {mode.name: mode for mode in StrideMode} _UNARY_VECTOR_OPS = { "vabs", "vrelu", @@ -191,6 +197,9 @@ | _REARRANGEMENT_OPS | {"vcvt", "vmrgsort4"} ) +_VECTOR_MEMORY_EXPR_OPS = {"vlds", "vldas", "vldus", "vldx2", "vsld"} +_VECTOR_MEMORY_STMT_OPS = {"vsts", "psts", "vsst", "vstx2", "vsta"} +_STATEFUL_MEMORY_EXPR_OPS = {"pstu", "vstu", "vstus", "vstur"} _TENSORVIEW_RANK = 5 @@ -275,7 +284,13 @@ class SemanticVRegType(SemanticType): lanes: int +@dataclass(frozen=True) +class SemanticAlignType(SemanticType): + pass + + _I32_TYPE = SemanticScalarType(dtype=i32) +_ALIGN_TYPE = SemanticAlignType() @dataclass(frozen=True) @@ -819,7 +834,7 @@ def _should_infer_vecscope( return self._block_can_live_in_inferred_vecscope(stmt.body) name = self._frontend_vector_call_name(stmt) return name in ( - {"make_mask", "vlds", "vsts"} + {"make_mask"} | _VECTOR_MEMORY_EXPR_OPS | _VECTOR_MEMORY_STMT_OPS | _STATEFUL_MEMORY_EXPR_OPS | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS @@ -892,7 +907,7 @@ def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> boo return ( expr.namespace == "pto" and expr.name in ( - {"make_mask", "vlds", "vsts"} + {"make_mask"} | _VECTOR_MEMORY_EXPR_OPS | _VECTOR_MEMORY_STMT_OPS | _STATEFUL_MEMORY_EXPR_OPS | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS @@ -998,7 +1013,7 @@ def _semantic_block_contains_vector_activity( def _expr_contains_vector_activity(self, expr: SemanticExpr) -> bool: if isinstance(expr, SemanticCallExpr): if expr.namespace == "pto" and expr.name in ( - {"make_mask", "vlds"} + {"make_mask"} | _VECTOR_MEMORY_EXPR_OPS | _STATEFUL_MEMORY_EXPR_OPS | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS @@ -1053,8 +1068,8 @@ def _analyze_stmt( env, allow_outer_lookup=allow_outer_lookup, ) - if self._is_vector_store_call(stmt.expr): - return self._analyze_vector_store_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_vector_memory_stmt_call(stmt.expr): + return self._analyze_vector_memory_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) expr = self._analyze_expr(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) return SemanticExprStmt(expr=expr), dict(env) if isinstance(stmt, FrontendReturnStmt): @@ -1238,11 +1253,11 @@ def _is_dma_call(self, expr: FrontendExprNode) -> bool: and expr.name in {"dma_load", "dma_store"} ) - def _is_vector_store_call(self, expr: FrontendExprNode) -> bool: + def _is_vector_memory_stmt_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) and expr.namespace == "pto" - and expr.name == "vsts" + and expr.name in _VECTOR_MEMORY_STMT_OPS ) def _is_sync_call(self, expr: FrontendExprNode) -> bool: @@ -1338,46 +1353,203 @@ def _analyze_dma_options( init_out_buffer=init_out_buffer, ) - def _analyze_vector_store_stmt( + def _analyze_vector_memory_stmt( self, expr: FrontendCallExpr, env: dict[str, SemanticBinding], *, allow_outer_lookup: bool, ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: - if len(expr.args) == 3: - value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) - destination, indices = self._analyze_tile_vector_access( - expr.args[1], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.vsts destination", - ) - mask = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) - else: - args = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args + if expr.name == "vsts": + if len(expr.args) == 3: + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsts destination", + ) + mask = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 4: + raise TypeError("pto.vsts expects 3 or 4 positional arguments in TileLang DSL v1") + value, destination, offset, mask = args + indices = (offset,) + self._require_vreg_expr(value, "pto.vsts value") + self._require_vector_pointer_expr(destination, "pto.vsts destination") + for index in indices: + self._require_index_typed_expr(index) + self._require_mask_for_vreg(mask, value.type, "pto.vsts") + self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") + return ( + SemanticVectorStoreStmt( + value=value, + destination=destination, + indices=indices, + mask=mask, + ), + dict(env), ) - if len(args) != 4: - raise TypeError("pto.vsts expects 3 or 4 positional arguments in TileLang DSL v1") - value, destination, offset, mask = args - indices = (offset,) - self._require_vreg_expr(value, "pto.vsts value") - self._require_vector_pointer_expr(destination, "pto.vsts destination") - for index in indices: - self._require_index_typed_expr(index) - self._require_mask_for_vreg(mask, value.type, "pto.vsts") - self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") - return ( - SemanticVectorStoreStmt( - value=value, - destination=destination, - indices=indices, - mask=mask, - ), - dict(env), + + analyzed = self._analyze_vector_memory_stmt_call( + expr, + env, + allow_outer_lookup=allow_outer_lookup, ) + return SemanticExprStmt(expr=analyzed), dict(env) + + def _analyze_vector_memory_stmt_call( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + if expr.name == "psts": + if len(expr.args) == 2: + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.psts destination", + ) + else: + value, destination, offset = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + indices = (offset,) + mask = self._require_mask_expr(value, "pto.psts value") + self._require_vector_pointer_expr(destination, "pto.psts destination") + if isinstance(destination.type, SemanticTileType): + if len(indices) not in {1, 2}: + raise TypeError("pto.psts Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + else: + if len(indices) != 1: + raise TypeError("pto.psts pointer syntax expects exactly one offset operand in TileLang DSL v1") + for index in indices: + self._require_index_typed_expr(index) + return SemanticCallExpr( + namespace="pto", + name="psts", + args=(value, destination, *indices), + type=SemanticMetaType(kind="void"), + ) + + if expr.name == "vsst": + if len(expr.args) == 3: + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsst destination", + ) + stride = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + else: + value, destination, offset, stride = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + indices = (offset,) + vreg = self._require_vreg_expr(value, "pto.vsst value") + self._require_vector_pointer_expr(destination, "pto.vsst destination") + if isinstance(destination.type, SemanticTileType): + if len(indices) not in {1, 2}: + raise TypeError("pto.vsst Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + else: + if len(indices) != 1: + raise TypeError("pto.vsst pointer syntax expects exactly one offset operand in TileLang DSL v1") + for index in indices: + self._require_index_typed_expr(index) + self._require_matching_vector_pointer(vreg, destination.type, "pto.vsst") + normalized_stride = self._normalize_stride_mode(stride, "pto.vsst stride") + return SemanticCallExpr( + namespace="pto", + name="vsst", + args=(value, destination, *indices, normalized_stride), + type=SemanticMetaType(kind="void"), + ) + + if expr.name == "vstx2": + if len(expr.args) == 5: + low = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + high = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[2], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vstx2 destination", + ) + dist = self._analyze_expr(expr.args[3], env, allow_outer_lookup=allow_outer_lookup) + mask = self._analyze_expr(expr.args[4], env, allow_outer_lookup=allow_outer_lookup) + else: + low, high, destination, offset, dist, mask = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + indices = (offset,) + low_type = self._require_vreg_expr(low, "pto.vstx2 low") + high_type = self._require_vreg_expr(high, "pto.vstx2 high") + if low_type != high_type: + raise TypeError("pto.vstx2 requires low/high vector types to match") + self._require_vector_pointer_expr(destination, "pto.vstx2 destination") + if isinstance(destination.type, SemanticTileType): + if len(indices) not in {1, 2}: + raise TypeError("pto.vstx2 Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + else: + if len(indices) != 1: + raise TypeError("pto.vstx2 pointer syntax expects exactly one offset operand in TileLang DSL v1") + for index in indices: + self._require_index_typed_expr(index) + self._require_mask_for_vreg(mask, low_type, "pto.vstx2") + self._require_matching_vector_pointer(low_type, destination.type, "pto.vstx2") + normalized_dist = self._normalize_interleave_dist(dist, "pto.vstx2 dist") + return SemanticCallExpr( + namespace="pto", + name="vstx2", + args=(low, high, destination, *indices, normalized_dist, mask), + type=SemanticMetaType(kind="void"), + ) + + if expr.name == "vsta": + if len(expr.args) == 2: + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsta destination", + ) + else: + value, destination, offset = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + indices = (offset,) + self._require_align_expr(value, "pto.vsta value") + self._require_vector_pointer_expr(destination, "pto.vsta destination") + if isinstance(destination.type, SemanticTileType): + if len(indices) not in {1, 2}: + raise TypeError("pto.vsta Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + else: + if len(indices) != 1: + raise TypeError("pto.vsta pointer syntax expects exactly one offset operand in TileLang DSL v1") + for index in indices: + self._require_index_typed_expr(index) + return SemanticCallExpr( + namespace="pto", + name="vsta", + args=(value, destination, *indices), + type=SemanticMetaType(kind="void"), + ) + + raise ValueError(f"unsupported vector-memory stmt pto.{expr.name}") def _analyze_sync_stmt( self, @@ -2360,6 +2532,61 @@ def _analyze_expr( context="pto.vlds source", ) return self._analyze_vlds((base, *indices)) + if ( + expr.namespace == "pto" + and expr.name == "vldas" + and len(expr.args) == 1 + and isinstance(expr.args[0], FrontendSubscriptExpr) + ): + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vldas source", + ) + return self._analyze_vldas((base, *indices)) + if ( + expr.namespace == "pto" + and expr.name == "vldus" + and len(expr.args) == 2 + and isinstance(expr.args[0], FrontendSubscriptExpr) + ): + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vldus source", + ) + align = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_vldus((base, *indices, align)) + if ( + expr.namespace == "pto" + and expr.name == "vldx2" + and len(expr.args) == 2 + and isinstance(expr.args[0], FrontendSubscriptExpr) + ): + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vldx2 source", + ) + dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_vldx2((base, *indices, dist)) + if ( + expr.namespace == "pto" + and expr.name == "vsld" + and len(expr.args) == 2 + and isinstance(expr.args[0], FrontendSubscriptExpr) + ): + base, indices = self._analyze_tile_scalar_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsld source", + ) + stride = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_vsld((base, *indices, stride)) if expr.keywords: raise TypeError( f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " @@ -2484,6 +2711,33 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=order_mode, type=SemanticMetaType(kind="order_mode"), ) + if expr.namespace in {"DeinterleaveDist", "pto.DeinterleaveDist"}: + dist = _DEINTERLEAVE_DIST_SYMBOLS.get(expr.name) + if dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dist, + type=SemanticMetaType(kind="deinterleave_dist"), + ) + if expr.namespace in {"InterleaveDist", "pto.InterleaveDist"}: + dist = _INTERLEAVE_DIST_SYMBOLS.get(expr.name) + if dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dist, + type=SemanticMetaType(kind="interleave_dist"), + ) + if expr.namespace in {"StrideMode", "pto.StrideMode"}: + stride = _STRIDE_MODE_SYMBOLS.get(expr.name) + if stride is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=stride, + type=SemanticMetaType(kind="stride_mode"), + ) raise TypeError( f"symbol `{expr.namespace}.{expr.name}` is not supported in TileLang DSL v1" ) @@ -2672,6 +2926,29 @@ def _analyze_tile_vector_access( ) return base, indices + def _analyze_tile_scalar_access( + self, + expr: FrontendExprNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + ) -> tuple[SemanticExpr, tuple[SemanticExpr, ...]]: + if not isinstance(expr, FrontendSubscriptExpr): + raise TypeError( + f"{context} expects Tile element-indexing syntax in TileLang DSL v1" + ) + base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) + tile = self._require_tile_expr(base, context) + indices = self._tile_scalar_indices( + expr.index, + tile.type, + env, + allow_outer_lookup=allow_outer_lookup, + context=context, + ) + return base, indices + def _tile_vector_indices( self, index_expr: FrontendExprNode, @@ -2716,6 +2993,33 @@ def _tile_vector_indices( self._require_index_typed_expr(col) return (row, col) + def _tile_scalar_indices( + self, + index_expr: FrontendExprNode, + tile_type: SemanticTileType, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + ) -> tuple[SemanticExpr, ...]: + if tile_type.rank == 1: + if isinstance(index_expr, FrontendSliceExpr): + raise TypeError(f"{context} expects Tile[pos] syntax for rank-1 Tile values") + index = self._analyze_expr(index_expr, env, allow_outer_lookup=allow_outer_lookup) + self._require_index_typed_expr(index) + return (index,) + + if tile_type.rank != 2 or tile_type.shape is None: + raise TypeError(f"{context} currently only supports statically specialized rank-1 or rank-2 Tiles") + if not isinstance(index_expr, FrontendTupleExpr) or len(index_expr.elements) != 2: + raise TypeError(f"{context} expects Tile[row, col] syntax for rank-2 Tile values") + + row = self._analyze_expr(index_expr.elements[0], env, allow_outer_lookup=allow_outer_lookup) + col = self._analyze_expr(index_expr.elements[1], env, allow_outer_lookup=allow_outer_lookup) + self._require_index_typed_expr(row) + self._require_index_typed_expr(col) + return (row, col) + def _tensor_slice_type( self, tensor_type: SemanticTensorViewType | SemanticPartitionTensorViewType, @@ -2851,6 +3155,22 @@ def _analyze_call_expr( return self._analyze_make_mask(args) if name == "vlds": return self._analyze_vlds(args) + if name == "vldas": + return self._analyze_vldas(args) + if name == "vldus": + return self._analyze_vldus(args) + if name == "vldx2": + return self._analyze_vldx2(args) + if name == "vsld": + return self._analyze_vsld(args) + if name == "pstu": + return self._analyze_pstu(args) + if name == "vstu": + return self._analyze_vstu(args) + if name == "vstus": + return self._analyze_vstus(args) + if name == "vstur": + return self._analyze_vstur(args) if name in {"ppack", "punpack"}: return self._analyze_mask_part_op(name, args) if name in {"pnot", "psel"}: @@ -3101,6 +3421,173 @@ def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: type=self._vreg_type_for_dtype(source.type.element_dtype), ) + def _analyze_vldas(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if not 1 <= len(args) <= 3: + raise TypeError("pto.vldas expects 1 positional argument or Tile element-indexing syntax in TileLang DSL v1") + source, *indices = args + if isinstance(source.type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vldas source") + for index in indices: + self._require_index_typed_expr(index) + else: + if indices: + raise TypeError("pto.vldas pointer syntax does not accept an explicit offset in TileLang DSL v1") + source = self._require_pointer_expr(source, "pto.vldas source", memory_space="ub") + return SemanticCallExpr( + namespace="pto", + name="vldas", + args=(source, *indices), + type=_ALIGN_TYPE, + ) + + def _analyze_vldus(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) < 2: + raise TypeError("pto.vldus expects source and align operands in TileLang DSL v1") + source = args[0] + align = args[-1] + indices = args[1:-1] + if isinstance(source.type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vldus source") + for index in indices: + self._require_index_typed_expr(index) + else: + if indices: + raise TypeError("pto.vldus pointer syntax does not accept an explicit offset in TileLang DSL v1") + source = self._require_pointer_expr(source, "pto.vldus source", memory_space="ub") + self._require_align_expr(align, "pto.vldus align") + source_ptr_type = SemanticPtrType( + element_dtype=source.type.element_dtype, + memory_space="ub", + ) + return SemanticCallExpr( + namespace="pto", + name="vldus", + args=(source, *indices, align), + type=SemanticTupleType( + elements=( + self._vreg_type_for_dtype(source.type.element_dtype), + _ALIGN_TYPE, + source_ptr_type, + ) + ), + ) + + def _analyze_vldx2(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) < 3: + raise TypeError("pto.vldx2 expects source, offset, and dist operands in TileLang DSL v1") + source = args[0] + dist = args[-1] + indices = args[1:-1] + if isinstance(source.type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vldx2 source") + if len(indices) not in {1, 2}: + raise TypeError("pto.vldx2 Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + else: + source = self._require_pointer_expr(source, "pto.vldx2 source", memory_space="ub") + if len(indices) != 1: + raise TypeError("pto.vldx2 pointer syntax expects exactly one offset operand in TileLang DSL v1") + for index in indices: + self._require_index_typed_expr(index) + normalized_dist = self._normalize_deinterleave_dist(dist, "pto.vldx2 dist") + result_type = self._vreg_type_for_dtype(source.type.element_dtype) + return SemanticCallExpr( + namespace="pto", + name="vldx2", + args=(source, *indices, normalized_dist), + type=SemanticTupleType(elements=(result_type, result_type)), + ) + + def _analyze_vsld(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) < 3: + raise TypeError("pto.vsld expects source, offset, and stride operands in TileLang DSL v1") + source = args[0] + stride = args[-1] + indices = args[1:-1] + if isinstance(source.type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vsld source") + if len(indices) not in {1, 2}: + raise TypeError("pto.vsld Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + else: + source = self._require_pointer_expr(source, "pto.vsld source", memory_space="ub") + if len(indices) != 1: + raise TypeError("pto.vsld pointer syntax expects exactly one offset operand in TileLang DSL v1") + for index in indices: + self._require_index_typed_expr(index) + normalized_stride = self._normalize_stride_mode(stride, "pto.vsld stride") + return SemanticCallExpr( + namespace="pto", + name="vsld", + args=(source, *indices, normalized_stride), + type=self._vreg_type_for_dtype(source.type.element_dtype), + ) + + def _analyze_pstu(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 3: + raise TypeError("pto.pstu expects exactly 3 positional arguments in TileLang DSL v1") + align, value, base = args + self._require_align_expr(align, "pto.pstu align") + mask = self._require_mask_expr(value, "pto.pstu value") + base_ptr = self._require_pointer_expr(base, "pto.pstu base", memory_space="ub") + return SemanticCallExpr( + namespace="pto", + name="pstu", + args=(align, value, base_ptr), + type=SemanticTupleType(elements=(_ALIGN_TYPE, base_ptr.type)), + ) + + def _analyze_vstu(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 5: + raise TypeError("pto.vstu expects exactly 5 positional arguments in TileLang DSL v1") + align, offset, value, base, mode = args + self._require_align_expr(align, "pto.vstu align") + self._require_index_typed_expr(offset) + vec = self._require_vreg_expr(value, "pto.vstu value") + base_ptr = self._require_pointer_expr(base, "pto.vstu base", memory_space="ub") + if base_ptr.type.element_dtype != vec.element_dtype: + raise TypeError("pto.vstu requires base pointer dtype to match vector element dtype") + normalized_mode = self._normalize_mode_string(mode, "pto.vstu mode") + return SemanticCallExpr( + namespace="pto", + name="vstu", + args=(align, offset, value, base_ptr, normalized_mode), + type=SemanticTupleType(elements=(_ALIGN_TYPE, SemanticIndexType())), + ) + + def _analyze_vstus(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 5: + raise TypeError("pto.vstus expects exactly 5 positional arguments in TileLang DSL v1") + align, offset, value, base, mode = args + self._require_align_expr(align, "pto.vstus align") + self._require_i32_expr(offset, "pto.vstus offset") + vec = self._require_vreg_expr(value, "pto.vstus value") + base_ptr = self._require_pointer_expr(base, "pto.vstus base", memory_space="ub") + if base_ptr.type.element_dtype != vec.element_dtype: + raise TypeError("pto.vstus requires base pointer dtype to match vector element dtype") + normalized_mode = self._normalize_mode_string(mode, "pto.vstus mode") + return SemanticCallExpr( + namespace="pto", + name="vstus", + args=(align, offset, value, base_ptr, normalized_mode), + type=SemanticTupleType(elements=(_ALIGN_TYPE, base_ptr.type)), + ) + + def _analyze_vstur(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 4: + raise TypeError("pto.vstur expects exactly 4 positional arguments in TileLang DSL v1") + align, value, base, mode = args + self._require_align_expr(align, "pto.vstur align") + vec = self._require_vreg_expr(value, "pto.vstur value") + base_ptr = self._require_pointer_expr(base, "pto.vstur base", memory_space="ub") + if base_ptr.type.element_dtype != vec.element_dtype: + raise TypeError("pto.vstur requires base pointer dtype to match vector element dtype") + normalized_mode = self._normalize_mode_string(mode, "pto.vstur mode") + return SemanticCallExpr( + namespace="pto", + name="vstur", + args=(align, value, base_ptr, normalized_mode), + type=_ALIGN_TYPE, + ) + def _analyze_broadcast_vector_op( self, name: str, @@ -3554,6 +4041,10 @@ def _require_vreg_expr(self, expr: SemanticExpr, context: str) -> SemanticVRegTy raise TypeError(f"{context} must be a vector register value in TileLang DSL v1") return expr.type + def _require_align_expr(self, expr: SemanticExpr, context: str) -> None: + if not isinstance(expr.type, SemanticAlignType): + raise TypeError(f"{context} must be an align state value in TileLang DSL v1") + def _require_scalar_expr(self, expr: SemanticExpr, context: str) -> SemanticScalarType: if not isinstance(expr.type, SemanticScalarType): raise TypeError(f"{context} must be a scalar value in TileLang DSL v1") @@ -3650,6 +4141,39 @@ def _require_string_expr(self, expr: SemanticExpr, context: str) -> str: return expr.binding.value raise TypeError(f"{context} must be a string literal in TileLang DSL") + def _normalize_deinterleave_dist(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "deinterleave_dist" + and isinstance(expr.value, DeinterleaveDist) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) + + def _normalize_interleave_dist(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "interleave_dist" + and isinstance(expr.value, InterleaveDist) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) + + def _normalize_stride_mode(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "stride_mode" + and isinstance(expr.value, StrideMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) + + def _normalize_mode_string(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: + return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) + def _require_i1_expr(self, expr: SemanticExpr, context: str) -> None: scalar = self._require_scalar_expr(expr, context) if scalar.dtype != i1: @@ -3662,6 +4186,11 @@ def _require_i64_like_expr(self, expr: SemanticExpr, context: str) -> None: if scalar.dtype != i64: raise TypeError(f"{context} must be an i64 or index value in TileLang DSL") + def _require_i32_expr(self, expr: SemanticExpr, context: str) -> None: + scalar = self._require_scalar_expr(expr, context) + if scalar.dtype != i32: + raise TypeError(f"{context} must be an i32 value in TileLang DSL") + def _require_tail_remaining_expr(self, expr: SemanticExpr, context: str) -> None: if isinstance(expr.type, SemanticIndexType): return @@ -4140,6 +4669,7 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticBinding", "SemanticBindingRef", "SemanticCallExpr", + "SemanticAlignType", "SemanticDmaOptions", "SemanticDmaLoadStmt", "SemanticDmaStoreStmt", diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index e2b510f2b..fff227865 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -41,7 +41,15 @@ { "make_mask", "vlds", + "vldas", + "vldus", + "vldx2", + "vsld", "vsts", + "psts", + "vsst", + "vstx2", + "vsta", "vabs", "vrelu", "vexp", @@ -132,6 +140,10 @@ "vdintlv", "vintlvv2", "vdintlvv2", + "pstu", + "vstu", + "vstus", + "vstur", } ) diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index d592c0fab..0d21acfe0 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -61,6 +61,12 @@ def __repr__(self) -> str: return f"mask_{self.granularity}" +@dataclass(frozen=True) +class AlignType: + def __repr__(self) -> str: + return "align" + + @dataclass(frozen=True) class WildcardType: name: str @@ -161,6 +167,26 @@ class OrderMode(str, Enum): ASC = "ORDER_ASC" +class DeinterleaveDist(str, Enum): + B8 = "DINTLV_B8" + B16 = "DINTLV_B16" + B32 = "DINTLV_B32" + BD = "BDINTLV" + + +class InterleaveDist(str, Enum): + B8 = "INTLV_B8" + B16 = "INTLV_B16" + B32 = "INTLV_B32" + + +class StrideMode(str, Enum): + S3_B16 = "STRIDE_S3_B16" + S4_B64 = "STRIDE_S4_B64" + S8_B32 = "STRIDE_S8_B32" + S2_B64 = "STRIDE_S2_B64" + + def _coerce_int_config_value(value: Any, field_name: str) -> int: if isinstance(value, bool) or not isinstance(value, int): raise TypeError(f"TileConfig field '{field_name}' must be an integer") @@ -265,6 +291,7 @@ class TileSpecialization: AnyInt = WildcardType("AnyInt") AnyType = WildcardType("AnyType") AnyMask = WildcardType("AnyMask") +align = AlignType() mask_b8 = MaskType("b8") mask_b16 = MaskType("b16") mask_b32 = MaskType("b32") @@ -369,6 +396,7 @@ def tile_layout_descriptor( "PointerType", "VRegType", "MaskType", + "AlignType", "ptr", "vreg", "MemorySpace", @@ -384,6 +412,9 @@ def tile_layout_descriptor( "PadMode", "PositionMode", "OrderMode", + "DeinterleaveDist", + "InterleaveDist", + "StrideMode", "TileConfig", "TileSpecialization", "i1", @@ -398,6 +429,7 @@ def tile_layout_descriptor( "AnyInt", "AnyType", "AnyMask", + "align", "mask_b8", "mask_b16", "mask_b32", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index cc04ea5b7..fee3a46a3 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -29,6 +29,7 @@ ) from tilelang_dsl.lowering import AuthoringModule, lower_semantic_kernel from tilelang_dsl.semantic import ( + SemanticAlignType, SemanticAssignStmt, SemanticBinaryExpr, SemanticCallExpr, @@ -67,8 +68,10 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "PointerType")) self.assertTrue(hasattr(pto, "VRegType")) self.assertTrue(hasattr(pto, "MaskType")) + self.assertTrue(hasattr(pto, "AlignType")) self.assertTrue(hasattr(pto, "ptr")) self.assertTrue(hasattr(pto, "vreg")) + self.assertTrue(hasattr(pto, "align")) self.assertTrue(hasattr(pto, "mask_b8")) self.assertTrue(hasattr(pto, "mask_b16")) self.assertTrue(hasattr(pto, "mask_b32")) @@ -83,8 +86,12 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "PadMode")) self.assertTrue(hasattr(pto, "PositionMode")) self.assertTrue(hasattr(pto, "OrderMode")) + self.assertTrue(hasattr(pto, "DeinterleaveDist")) + self.assertTrue(hasattr(pto, "InterleaveDist")) + self.assertTrue(hasattr(pto, "StrideMode")) self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) + self.assertEqual(repr(pto.align), "align") self.assertEqual(pto.PadMode.PadNull.value, "PadNull") self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") self.assertEqual(pto.PadMode.PadValue.value, "PadValue") @@ -93,6 +100,9 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PadValue.ZERO.value, "zero") self.assertEqual(pto.PositionMode.LOWEST.value, "POS_LOWEST") self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") + self.assertEqual(pto.DeinterleaveDist.B32.value, "DINTLV_B32") + self.assertEqual(pto.InterleaveDist.B16.value, "INTLV_B16") + self.assertEqual(pto.StrideMode.S4_B64.value, "STRIDE_S4_B64") class TileLangDSLSupportMatrixTests(unittest.TestCase): @@ -106,7 +116,15 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertIn("Tile", AUTHORING_TIER_SURFACE_GROUPS["Tile"]) self.assertNotIn("dma_load/store", AUTHORING_TIER_SURFACE_GROUPS) self.assertIn("pto.vlds", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vldas", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vldus", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vldx2", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vsld", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("pto.vsts", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.psts", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vsst", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vstx2", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vsta", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("pto.vadd", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("pto.vmuls", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("tile[start:]", BASIC_TILE_INDEXING_SURFACES) @@ -115,7 +133,15 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("TensorView"), BASIC_TIER) self.assertEqual(get_feature_tier("Tile"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vlds"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vldas"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vldus"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vldx2"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vsld"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vsts"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.psts"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vsst"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vstx2"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vsta"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vadd"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vmuls"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vaddrelu"), BASIC_TIER) @@ -127,6 +153,10 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vci"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vpack"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vsort32"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.pstu"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.vstu"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.vstus"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.vstur"), ADVANCED_TIER) self.assertEqual(get_feature_tier("PadMode"), BASIC_TIER) self.assertEqual(get_feature_tier("VRegType"), BASIC_TIER) self.assertEqual(get_feature_tier("MaskType"), BASIC_TIER) @@ -1633,6 +1663,78 @@ def kernel(tile: pto.Tile, scale: pto.f32): self.assertRegex(text, r"%activated_\d+ = pto\.vrelu %summed_\d+, %mask_\d+ : !pto\.vreg<64xf32>, !pto\.mask -> !pto\.vreg<64xf32>") self.assertRegex(text, r"pto\.vsts %activated_\d+, %dst_\d+\[%lane_\d+\], %mask_\d+ : !pto\.vreg<64xf32>, !pto\.ptr, !pto\.mask") + def test_basic_vector_memory_family_surfaces_lower_from_tile_indexing(self) -> None: + @pto.vkernel(op="vector_memory_basic_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + align = pto.vldas(src[0, 0:]) + vec, next_align, base_out = pto.vldus(src[0, 0:], align) + low, high = pto.vldx2(src[0, 0:], pto.DeinterleaveDist.B32) + strided = pto.vsld(src[0, 0], pto.StrideMode.S4_B64) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.psts(mask, dst[0, 0:]) + pto.vsst(vec, dst[0, 0:], pto.StrideMode.S4_B64) + pto.vstx2(low, high, dst[0, 0:], pto.InterleaveDist.B32, mask) + pto.vsta(next_align, dst[0, 0:]) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = semantic_kernel.body[0] + self.assertIsInstance(vecscope, SemanticVecscopeStmt) + self.assertIsInstance(vecscope.body[0], SemanticAssignStmt) + self.assertIsInstance(vecscope.body[0].targets[0].type, SemanticAlignType) + self.assertEqual(vecscope.body[1].targets[0].type.element_dtype, pto.f32) + self.assertEqual(vecscope.body[2].value.name, "vldx2") + + text = specialized.mlir_text() + self.assertIn("pto.vldas", text) + self.assertIn("pto.vldus", text) + self.assertIn('pto.vldx2', text) + self.assertIn('"DINTLV_B32"', text) + self.assertIn("pto.vsld", text) + self.assertIn('"STRIDE_S4_B64"', text) + self.assertIn("pto.psts", text) + self.assertIn("pto.vsst", text) + self.assertIn("pto.vstx2", text) + self.assertIn('"INTLV_B32"', text) + self.assertIn("pto.vsta", text) + + def test_advanced_stateful_vector_memory_surfaces_lower_with_pointer_state(self) -> None: + @pto.vkernel(op="vector_memory_stateful_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + ub_src = src.as_ptr() + ub_dst = dst.as_ptr() + align0 = pto.vldas(ub_src) + vec = pto.vlds(ub_src, 0) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + align1, base1 = pto.pstu(align0, mask, ub_dst) + align2, offset2 = pto.vstu(align1, 0, vec, ub_dst, "MODE_ZEROING") + align3, base3 = pto.vstus(align2, pto.i32(16), vec, ub_dst, "MODE_ZEROING") + align4 = pto.vstur(align3, vec, ub_dst, "MODE_ZEROING") + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIsInstance(semantic_kernel.body[0], SemanticAssignStmt) + self.assertEqual(semantic_kernel.body[0].value.name, "tile_as_ptr") + + text = specialized.mlir_text() + self.assertIn("pto.pstu", text) + self.assertIn("pto.vstu", text) + self.assertIn("pto.vstus", text) + self.assertIn("pto.vstur", text) + self.assertIn('"MODE_ZEROING"', text) + self.assertRegex(text, r"= pto\.vstu %align1_\d+, %c0, %vec_\d+, %ub_dst_\d+, \"MODE_ZEROING\"") + self.assertRegex(text, r"= pto\.vstus %align2_\d+, %(?:c16_i32|tmp_\d+), %vec_\d+, %ub_dst_\d+, \"MODE_ZEROING\"") + def test_tail_make_mask_lowers_to_typed_plt_and_updates_remaining(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.i32)], advanced=True) def kernel(tile: pto.Tile, remaining: pto.i32): From 17f22688c421a6a9212d4fcd5d0d69711fbe9c7c Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 20:20:23 +0800 Subject: [PATCH 033/192] Support colmajor indexing syntax --- tilelang-dsl/docs/unsupported-features.md | 27 +++------ tilelang-dsl/python/tilelang_dsl/lowering.py | 3 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 37 +++++++++++- .../python/tilelang_dsl/support_matrix.py | 2 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 59 +++++++++++++++++++ 5 files changed, 105 insertions(+), 23 deletions(-) diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md index 2db9ea66a..4cd1b03b9 100644 --- a/tilelang-dsl/docs/unsupported-features.md +++ b/tilelang-dsl/docs/unsupported-features.md @@ -58,21 +58,12 @@ These documented surfaces are not accepted by the current frontend: ### Missing Vector Load/Store Families -Only `pto.vlds(...)` and `pto.vsts(...)` are implemented from the guide's -load/store families. The following documented ops are still unsupported: - -- `pto.vldas(...)` -- `pto.vldus(...)` -- `pto.vldx2(...)` -- `pto.vsld(...)` -- `pto.psts(...)` -- `pto.vsst(...)` -- `pto.vstx2(...)` -- `pto.vsta(...)` -- `pto.pstu(...)` -- `pto.vstu(...)` -- `pto.vstus(...)` -- `pto.vstur(...)` +The previously missing vector-memory surfaces +`pto.vldas(...)`, `pto.vldus(...)`, `pto.vldx2(...)`, `pto.vsld(...)`, +`pto.psts(...)`, `pto.vsst(...)`, `pto.vstx2(...)`, `pto.vsta(...)`, +`pto.pstu(...)`, `pto.vstu(...)`, `pto.vstus(...)`, and `pto.vstur(...)` +are now implemented. Remaining guide gaps are concentrated in the wider +indexed/flush families that are still not wired through this DSL package. ### Missing Direct Predicate Constructor/Compare APIs @@ -190,16 +181,14 @@ Currently supported: - rank-1: `tile[start:]` - rank-2: `tile[row, col:]` -- only for `pto.vlds(...)` and `pto.vsts(...)` +- rank-2 column-major: `tile[row_start:, col_index]` +- available across the current basic vector-memory tile-indexing family Not currently supported from the guide's broader indexing model: -- column-major syntax such as `tile[row_start:, col_index]` - single-element syntax such as `tile[row, col]` and `tile[pos]` - explicit slice `stop` - stepped tile vector slices -- the guide's wider indexed op family (`vldas`, `vldus`, `vldx2`, - `vsld`, `psts`, `vsst`, `vstx2`, `vsta`) ### Control-Flow Result Merging diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index be08309b2..9a4fdcd98 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -3499,11 +3499,12 @@ def _render_tile_buf_type(self, ty: SemanticTileType) -> str: valid_shape = ty.valid_shape or ty.shape v_row = valid_shape[0] v_col = 1 if ty.rank == 1 else valid_shape[1] + config = ty.config or TileConfig() return ( f"!pto.tile_buf" + f"blayout={config.b_layout.value}, slayout=none_box, fractal=512, pad=0>" ) def _render_tile_buf_loc(self, memory_space: str) -> str: diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index ee9ba45b4..4a11d78fc 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -2974,11 +2974,33 @@ def _tile_vector_indices( if tile_type.rank != 2 or tile_type.shape is None: raise TypeError(f"{context} currently only supports statically specialized rank-1 or rank-2 Tiles") if not isinstance(index_expr, FrontendTupleExpr) or len(index_expr.elements) != 2: - raise TypeError(f"{context} expects Tile[row, col:] syntax for rank-2 Tile values") + raise TypeError( + f"{context} expects {self._tile_vector_rank2_syntax(tile_type)} syntax for rank-2 Tile values" + ) row_expr, col_expr = index_expr.elements - if not isinstance(col_expr, FrontendSliceExpr): - raise TypeError(f"{context} expects Tile[row, col:] syntax for rank-2 Tile values") + if self._tile_b_layout(tile_type) == BLayout.COL_MAJOR: + if not isinstance(row_expr, FrontendSliceExpr) or isinstance(col_expr, FrontendSliceExpr): + raise TypeError( + f"{context} expects {self._tile_vector_rank2_syntax(tile_type)} syntax for rank-2 Tile values" + ) + if row_expr.stop is not None: + raise TypeError(f"{context} does not support explicit slice stop in TileLang DSL advanced mode") + if row_expr.step is not None: + raise TypeError(f"{context} does not support stepped Tile vector slices in TileLang DSL advanced mode") + if row_expr.start is None: + row = SemanticLiteralExpr(value=0, type=SemanticIndexType()) + else: + row = self._analyze_expr(row_expr.start, env, allow_outer_lookup=allow_outer_lookup) + self._require_index_typed_expr(row) + col = self._analyze_expr(col_expr, env, allow_outer_lookup=allow_outer_lookup) + self._require_index_typed_expr(col) + return (row, col) + + if not isinstance(col_expr, FrontendSliceExpr) or isinstance(row_expr, FrontendSliceExpr): + raise TypeError( + f"{context} expects {self._tile_vector_rank2_syntax(tile_type)} syntax for rank-2 Tile values" + ) if col_expr.stop is not None: raise TypeError(f"{context} does not support explicit slice stop in TileLang DSL advanced mode") if col_expr.step is not None: @@ -2993,6 +3015,15 @@ def _tile_vector_indices( self._require_index_typed_expr(col) return (row, col) + def _tile_b_layout(self, tile_type: SemanticTileType) -> BLayout: + config = TileConfig() if tile_type.config is None else tile_type.config + return config.b_layout + + def _tile_vector_rank2_syntax(self, tile_type: SemanticTileType) -> str: + if self._tile_b_layout(tile_type) == BLayout.COL_MAJOR: + return "Tile[row_start:, col_index]" + return "Tile[row, col:]" + def _tile_scalar_indices( self, index_expr: FrontendExprNode, diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index fff227865..efb16b9b0 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -225,6 +225,7 @@ { "tile[start:]", "tile[row, col:]", + "tile[row_start:, col_index]", } ) @@ -328,6 +329,7 @@ def get_pto_call_tier(call_name: str) -> str: "pto.constexpr": BASIC_TIER, "tile[start:]": BASIC_TIER, "tile[row, col:]": BASIC_TIER, + "tile[row_start:, col_index]": BASIC_TIER, # Advanced tier constructs "ptr": ADVANCED_TIER, # raw pointer constructor "strict_vecscope": ADVANCED_TIER, # explicit vecscope management diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index fee3a46a3..161629d16 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -129,6 +129,7 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertIn("pto.vmuls", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("tile[start:]", BASIC_TILE_INDEXING_SURFACES) self.assertIn("tile[row, col:]", BASIC_TILE_INDEXING_SURFACES) + self.assertIn("tile[row_start:, col_index]", BASIC_TILE_INDEXING_SURFACES) self.assertEqual(get_feature_tier("TensorView"), BASIC_TIER) self.assertEqual(get_feature_tier("Tile"), BASIC_TIER) @@ -174,6 +175,7 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("PadValue"), BASIC_TIER) self.assertEqual(get_feature_tier("tile[start:]"), BASIC_TIER) self.assertEqual(get_feature_tier("tile[row, col:]"), BASIC_TIER) + self.assertEqual(get_feature_tier("tile[row_start:, col_index]"), BASIC_TIER) def test_non_stable_surface_groups_keep_advanced_boundaries(self) -> None: self.assertEqual(get_surface_group_tier("strict_vecscope"), ADVANCED_TIER) @@ -1703,6 +1705,63 @@ def kernel(src: pto.Tile, dst: pto.Tile): self.assertIn('"INTLV_B32"', text) self.assertIn("pto.vsta", text) + def test_col_major_tile_vector_indexing_lowers_with_column_major_layout(self) -> None: + @pto.vkernel(op="vector_memory_col_major_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + align = pto.vldas(src[2:, 3]) + streamed, next_align, base_out = pto.vldus(src[2:, 3], align) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src[2:, 3]) + pto.vsts(vec, dst[2:, 3], mask) + pto.vsta(next_align, dst[2:, 3]) + return None + + col_major = pto.TileConfig.from_mapping({"layout": "col_major"}) + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB, config=col_major), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB, config=col_major), + ) + + text = specialized.mlir_text() + self.assertIn("blayout=col_major", text) + self.assertRegex(text, r"pto\.vlds %tmp_\d+\[%c2, %c3\] : memref<8x64xf32, #pto\.address_space> -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"%tmp_\d+ = arith\.muli %c3, %c8 : index") + self.assertIn("pto.addptr", text) + self.assertIn("pto.vldus", text) + self.assertIn("pto.vsta", text) + + def test_row_major_tile_rejects_column_major_vector_indexing_syntax(self) -> None: + @pto.vkernel(op="row_major_rejects_col_major_index_unique", dtypes=[(pto.f32,)]) + def kernel(src: pto.Tile): + pto.vlds(src[1:, 2]) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn("Tile[row, col:]", str(ctx.exception)) + + def test_col_major_tile_rejects_row_major_vector_indexing_syntax(self) -> None: + @pto.vkernel(op="col_major_rejects_row_major_index_unique", dtypes=[(pto.f32,)]) + def kernel(src: pto.Tile): + pto.vlds(src[1, 2:]) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization( + shape=(8, 64), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping({"layout": "col_major"}), + ), + ) + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn("Tile[row_start:, col_index]", str(ctx.exception)) + def test_advanced_stateful_vector_memory_surfaces_lower_with_pointer_state(self) -> None: @pto.vkernel(op="vector_memory_stateful_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) def kernel(src: pto.Tile, dst: pto.Tile): From 59368d7fb43af5ad14c4db087a602faf69fd545f Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 20:25:40 +0800 Subject: [PATCH 034/192] Enhance Tile indexing syntax documentation to clarify restrictions on slice forms --- tilelang-dsl/docs/user_guide/01-introduction.md | 2 +- .../docs/user_guide/09-vector-memory-operations.md | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tilelang-dsl/docs/user_guide/01-introduction.md b/tilelang-dsl/docs/user_guide/01-introduction.md index 1c51cb8ca..1ef148b77 100644 --- a/tilelang-dsl/docs/user_guide/01-introduction.md +++ b/tilelang-dsl/docs/user_guide/01-introduction.md @@ -26,7 +26,7 @@ The TileLang DSL provides two distinct authoring modes: **Basic Mode (default)** - Uses **Tile element/slice semantics** for buffer access -- Direct tile indexing syntax: `tile[start:]`, `tile[row, col:]` +- Direct tile indexing syntax: `tile[start:]`, `tile[row, col:]`, `tile[row:, col]` (Tile indexing sugar only supports open-ended vector slices; explicit `stop` and `step` forms are not accepted for `Tile` indexing) - Vector operations use element-indexing syntax: `pto.vlds(tile[row, col:])`, `pto.vsts(vec, tile[start:], mask)` - No pointer arithmetic or explicit offset calculations - Suitable for most kernel authoring with high-level abstractions diff --git a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md index eec44ddc9..539ea1895 100644 --- a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -41,6 +41,11 @@ The syntax supports two indexing modes for different operations: - **1D tile indexing**: `tile[start:]` (or equivalently `tile[0, start:]` for row-major or `tile[start:, 0]` for column-major) - `start:`: Starting element index followed by colon + Tile indexing sugar only accepts an open-ended vector slice. Python slice + forms with an explicit `stop` or `step` are not supported for `Tile` + indexing. For example, `tile[row, col:col_end]`, `tile[row, col::2]`, + `tile[row_start:row_end, col]`, and `tile[start:stop:step]` are invalid. + 2. **Single-element indexing** (for scalar load operations like `pto.vsld`): - **Row-major layout (default)**: `tile[row_index, col_index]` - `row_index`: Row index (0-based) @@ -118,6 +123,12 @@ The byte offset is automatically computed based on tile layout: 4. **Single-element operations**: The single-element indexing syntax (`tile[row, col]` or `tile[pos]`) is only supported for scalar load operations like `pto.vsld`. For other operations, use vector-range indexing with `:` syntax. +5. **No explicit slice bounds/stride for `Tile` indexing**: `Tile` vector-range + indexing only supports the open-ended forms `tile[start:]`, + `tile[row, col:]`, and `tile[row_start:, col_index]` (for column-major + layout). `stop` and `step` syntax are not accepted in user-guide Tile + indexing. + #### Supported Operations The indexing syntax is supported for all vector load and store operations with the following syntax mapping: From cad7e7fbe4c9660d52bb972004cda1f9c4ddbb3f Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 20:31:49 +0800 Subject: [PATCH 035/192] Update openspec --- .../.openspec.yaml | 2 + .../design.md | 173 ++++++++++++++++++ .../proposal.md | 93 ++++++++++ .../specs/tilelang-dsl-surface/spec.md | 19 ++ .../specs/tilelang-dsl-vpto-lowering/spec.md | 86 +++++++++ .../tasks.md | 31 ++++ .../.openspec.yaml | 2 + .../design.md | 161 ++++++++++++++++ .../proposal.md | 81 ++++++++ .../specs/tilelang-dsl-diagnostics/spec.md | 42 +++++ .../specs/tilelang-dsl-surface/spec.md | 20 ++ .../specs/tilelang-dsl-vpto-lowering/spec.md | 50 +++++ .../tasks.md | 47 +++++ openspec/specs/tilelang-dsl-surface/spec.md | 32 ++++ 14 files changed, 839 insertions(+) create mode 100644 openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/.openspec.yaml create mode 100644 openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/design.md create mode 100644 openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/proposal.md create mode 100644 openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-surface/spec.md create mode 100644 openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-vpto-lowering/spec.md create mode 100644 openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/tasks.md create mode 100644 openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/.openspec.yaml create mode 100644 openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/design.md create mode 100644 openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/proposal.md create mode 100644 openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-diagnostics/spec.md create mode 100644 openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-surface/spec.md create mode 100644 openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-vpto-lowering/spec.md create mode 100644 openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/tasks.md diff --git a/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/.openspec.yaml b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/.openspec.yaml new file mode 100644 index 000000000..2fe001e67 --- /dev/null +++ b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-07 diff --git a/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/design.md b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/design.md new file mode 100644 index 000000000..8ea8d12be --- /dev/null +++ b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/design.md @@ -0,0 +1,173 @@ +## Context + +### 范围 + +本 design 只覆盖 TileLang DSL stable 模式下的 2D TensorView slicing 与 high-level DMA surface: + +- `pto.dma_load` +- `pto.dma_store` +- `PadMode` +- 与 DMA inference 直接相关的 TensorView slice `start/stop/step` + +它不覆盖: + +- matcher / registry / advanced surface +- rank > 2 的 TensorView/DMA profile +- backend 新 capability +- A5 text / LLVM emission 规则变更 + +### 当前状态 + +当前实现存在四个直接相关的事实: + +1. `tilelang-dsl/docs/tilelang-dsl-guide.md` 已把 stable DMA 描述为高层自动推导接口,并公开了 `PadMode`、padding 参数以及更宽的 slicing 语义。 +2. `tilelang-dsl/python/tilelang_dsl/frontend_ast.py` 的 `FrontendCallExpr` 目前只保存 positional `args`,无法承载 DMA keyword 参数。 +3. `tilelang-dsl/python/tilelang_dsl/semantic.py` 当前把 TensorView slice 限定为: + - rank-2 + - 显式 `stop` + - `start == 0` + - `step == 1` + 同时 `SemanticTensorSliceType` 只保留 extent,不能表达 `start/stop/step`。 +4. `tilelang-dsl/python/tilelang_dsl/lowering.py` 目前把 stable DMA 展开为固定常量参数的 `set_loop_size_* + copy_*` 组合,无法从 slice layout 推导 offset / stride / trim / padding。 + +与此同时,repo 内部的 PTO/VPTO lowering 已经存在可参考的 shape/stride contract,但 TileLang DSL stable path 当前仍选择直接产出 authoring-form VPTO text,而不是先 materialize `pto.tload` / `pto.tstore` 再复用 backend lowering。 + +### 实现约束 + +- 本 change 保持当前 `tilelang-dsl -> authoring-form VPTO` 主线,不引入新的公开中间 IR。 +- stable DMA profile 继续限定为 statically specialized rank-2 UB Tile。 +- 设计必须显式区分“前端可综合的高层行为”和“当前 authoring/backend 路径无法真实承载的行为”,不能靠 silent no-op 冒充支持。 +- `unsupported-features.md`、guide 和 OpenSpec 需要同步更新,避免继续出现 contract 漂移。 + +## Goals / Non-Goals + +**Goals:** + +- 让 stable DMA surface 真正接受 `PadMode` 和 keyword 参数,而不是继续停留在 2 参数最小形态。 +- 让 stable TensorView slice 在 2D profile 内支持 non-zero/dynamic start、dynamic stop 和静态正步长。 +- 让 lowering 基于 normalized slice layout、TensorView shape 和 Tile `valid_shape` 推导 offset / stride / loop size。 +- 为 padded `dma_load` 与 trimmed `dma_store` 定义稳定、可测试的 frontend-only 行为。 +- 对当前 frontend-only 路径不能真实表达的行为给出明确边界和 diagnostics。 + +**Non-Goals:** + +- 不支持 rank > 2 slice / DMA。 +- 不支持 dynamic `step`。 +- 不支持第 1 轴 stepped DMA。 +- 不把 stable DMA 改写为 backend-driven `pto.tload` / `pto.tstore` pipeline。 +- 不在本 change 中补齐 GM-side fill 或 backend-init capability。 + +## Decisions + +### 1. 保持 stable DMA 继续直接 lower 到 authoring-form VPTO,而不是切换到 `pto.tload` / `pto.tstore` + +决策: + +- `tilelang-dsl/python/tilelang_dsl/lowering.py` 继续直接生成 `pto.addptr`、`set_loop*_stride_*`、`set_loop_size_*` 和 `copy_*`。 +- 不在本 change 中改变 TileLang DSL 的 lowering 架构边界。 + +原因: + +- 当前 `descriptor.mlir_text()` / `verify()` 已经围绕 authoring-form VPTO 建立稳定路径。 +- 这次 change 的核心问题是 stable contract 缺口,不是重新设计整体 lowering pipeline。 + +备选方案: + +- 先 materialize `pto.tload` / `pto.tstore` 再复用 repo lowering + - 放弃原因:会扩大本 change 的影响面,并把“补 stable DMA 缺口”变成“改写 DSL lowering 主线”。 + +### 2. 扩展前端 AST 与 semantic 模型,保留 normalized slice 与 DMA options,而不是只保存 extent + +决策: + +- `FrontendCallExpr` 增加 keyword 参数承载。 +- stable DMA call 在 semantic 层保存: + - `pad_mode` + - `pad_value` + - `left_padding` + - `right_padding` + - `init_out_buffer` +- TensorView slice 保存标准化后的每轴 `start/stop/step`,不再只保存 `extent`。 + +原因: + +- 真实的 DMA inference 依赖 offset、step 和 trim 信息;仅保留 extent 无法推导 `pto.addptr` 与 loop stride。 +- keyword 参数必须在 frontend 就固定下来,不能等到 lowering 再反向猜测。 + +备选方案: + +- 继续沿用 “slice -> extents only” 的窄模型 + - 放弃原因:无法闭合 non-zero start、outer-axis step、padding/trim 等本 change 目标。 + +### 3. stable slicing profile 固定为“2D + 显式 stop + dynamic start/stop + 静态正步长”,且第 1 轴仍要求 `step == 1` + +决策: + +- 继续只支持 rank-2 TensorView slicing。 +- `stop` 必须显式给出。 +- `start` 可以是常量或 runtime index expr。 +- `step` 必须是静态正整数。 +- 第 0 轴允许 `step > 1`。 +- 第 1 轴必须保持 `step == 1`。 + +原因: + +- 这样既能覆盖 guide 中最核心的 stride-aware authoring,又不会把 stable DMA 推入当前 copy-family 无法表达的 gather/scatter 形态。 + +备选方案: + +- 完全放开 dynamic `step` + - 放弃原因:dynamic stride 会把 DMA legality、shape 校验和 lowering 复杂度整体抬高,不适合作为 stable 2D profile 的第一步。 +- 两个轴都允许 stepped DMA + - 放弃原因:第 1 轴 stepped copy 需要当前 stable authoring path 尚未具备的 gather/scatter 语义。 + +### 4. `dma_load` 采用 “prefill + interior copy” 的 frontend-only padding 方案;`dma_store` 采用 trim-and-validate 方案 + +决策: + +- `dma_load`: + - 先按 `pad_mode`、`pad_value`、`left_padding`、`right_padding` 决定目的 Tile 需要的 padding band。 + - padding band 由 frontend 生成稳定的 prefill lowering。 + - interior 数据仍通过 `copy_gm_to_ubuf` 进入 UB。 +- `dma_store`: + - `left_padding/right_padding` 定义 source tile 的 interior window。 + - lowering 只把 interior window 写回 destination slice。 + - GM-side fill 不属于本 change。 + +原因: + +- `dma_load` 的 padding 可以在 UB 侧通过 frontend 合成,不需要 backend 新 contract。 +- `dma_store` 若要把 padding value 主动写到 GM 边缘,会要求当前 frontend-only 路径之外的 GM-side fill 语义,因此本 change 选择 trim-and-validate。 + +备选方案: + +- 让 `dma_store` 也尝试合成完整 GM-side fill + - 放弃原因:会越过当前 stable authoring path 的真实边界,并引入额外 backend 依赖。 +- 继续让 non-default padding 参数全部 reject + - 放弃原因:无法收敛 guide 与 stable surface 的主要缺口。 + +### 5. 对 frontend-only 路径尚无稳定承载的行为保持显式诊断,而不是伪支持 + +决策: + +- `dma_store` 的 `pad_mode != PadNull`、GM-side fill 和等价 backend-init 语义继续显式 reject。 +- `init_out_buffer` 仅在能够映射到本次定义的 frontend prefill 行为时放行;超出该 profile 的组合继续报错。 +- diagnostics 要直接说明“当前 stable frontend-only DMA profile 未支持的原因”,而不是退化成模糊 type error。 + +原因: + +- 这能避免 public surface 看起来“已经支持”,但实际 lowering 只是在 silently ignore 参数。 + +## Risks / Trade-offs + +- [Risk] 在 frontend 合成 padded-load prefill 会引入额外 lowering 复杂度 + Mitigation:严格限定到 rank-2、静态 Tile shape、outer-axis stepped profile,并用 regression 锁定生成形态。 + +- [Risk] `dma_store` 仍然不能提供 guide 中最理想的 GM-side fill 语义 + Mitigation:在 spec 和 guide 中明确把 store padding 收敛为 trim-and-validate,而不是继续保留模糊承诺。 + +- [Risk] dynamic start/stop 与步长推导会增加 semantic 和 lowering 状态量 + Mitigation:统一在 semantic 层做 normalized slice 表示,避免后续多个 lowering helper 各自解释 slice。 + +- [Risk] `PadMode` public surface 放开后,用户更容易触碰尚未实现的组合 + Mitigation:在 diagnostics 和 `unsupported-features.md` 中精确列出仍被限制的组合,不使用“只支持 2 参数”这类过粗表述。 diff --git a/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/proposal.md b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/proposal.md new file mode 100644 index 000000000..7a699d14f --- /dev/null +++ b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/proposal.md @@ -0,0 +1,93 @@ +# Proposal: 补齐 TileLang DSL stable 模式的 DMA 与切片契约 + +## 概述 + +`tilelang-dsl/docs/tilelang-dsl-guide.md` 已经把 stable `dma_load` / `dma_store` 描述为高层、可自动推导 DMA 参数的默认数据搬运接口,并公开了 `PadMode`、padding 参数和更宽的 TensorView slicing 语义。 +但当前 `tilelang-dsl/python/` 实现仍停留在最小 2 参数 DMA 和极窄 slice profile,导致 guide、unsupported 文档和真实 lowering 行为长期脱节。 + +本 change 的目标是在不切换出当前 `tilelang-dsl -> authoring-form VPTO` 前端主线的前提下,把 stable DMA/slice 契约收敛到一个真实、可测试、可文档化的 2D profile:补齐非零/dynamic start、静态正步长、DMA 参数自动推导、stable `PadMode` 入口,以及前端可安全落地的 padding / trim 行为。 + +## 背景与动机 + +当前仓库中已经存在以下明显缺口: + +- guide 在 stable 章节承诺 `dma_load` / `dma_store` 会从 TensorView slice 自动推导 stride 与 loop size,但实现仍只支持 contiguous、zero-based、unit-step 的 rank-2 slice。 +- guide 公开了 `PadMode`、`pad_value`、`left_padding`、`right_padding`、`init_out_buffer` 等高层接口,但当前 stable path 只接受 2 参数 DMA。 +- `semantic.py` 目前把 TensorView slice 压缩成仅有 extent 的窄模型,无法忠实表达 `start/stop/step`,也无法支撑更真实的 DMA inference。 +- `unsupported-features.md` 已经承认这些缺口存在,但 stable surface 长期停留在“文档已承诺、实现未补齐”的状态,会持续误导使用者并阻碍后续 sample / regression 编写。 + +如果不把这部分契约收敛清楚,stable 模式就无法成为可信的默认 authoring path,后续 matcher、advanced surface 和 sample 覆盖也会继续建立在不稳定边界之上。 + +## 目标 + +- 补齐 stable `dma_load` / `dma_store` 的 keyword surface,使 `PadMode`、padding 参数和相关配置能进入 frontend 语义分析。 +- 把 stable TensorView slicing 扩展到可支持 non-zero/dynamic start、dynamic stop、静态正步长,并把这些信息保留到 lowering。 +- 让 stable DMA inference 从 slice `start/stop/step`、TensorView shape、Tile `shape/valid_shape` 推导 pointer offset、loop stride 和 loop size,而不再只支持 full-tile contiguous profile。 +- 定义 frontend-only stable padding 行为: + - `dma_load` 支持 padding 参数与可实现的 padded-load lowering。 + - `dma_store` 支持基于 `left_padding/right_padding` 的 interior trim;GM-side fill 继续保持显式限制。 +- 同步更新 tests、guide、unsupported 文档与 OpenSpec,使 stable contract 与实际实现重新一致。 + +## 非目标 + +- 不在本 change 中扩展到 rank > 2 的 TensorView slicing 或 DMA profile。 +- 不支持省略 `stop` 的 Python slice 语义。 +- 不支持 dynamic `step`,也不支持第 1 轴的 stepped gather/scatter DMA。 +- 不借本 change 重新设计 TileLang DSL 的整体 lowering 架构,不把 stable DMA 改写为新的公开中间 IR。 +- 不在本 change 中引入新的 backend capability;GM-side fill、任意 `pad_value` 直通 backend、以及 `init_out_buffer` 的完整 backend 语义不作为完成标准。 + +## What Changes + +- 扩展 stable `dma_load` / `dma_store` public surface,支持 `pad_mode`、`pad_value`、`left_padding`、`right_padding`、`init_out_buffer` keyword 参数,并公开 `PadMode`。 +- 扩展 stable TensorView slicing,支持 rank-2 profile 下的 non-zero/dynamic start、dynamic stop 和静态正步长,并把标准化后的 slice 元信息保留到 semantic layer。 +- 修改 stable DMA lowering,使其从 normalized slice layout 与 Tile `valid_shape` 推导 `pto.addptr`、`set_loop*_stride_*`、`set_loop_size_*` 和 copy-family 参数。 +- 为 padded load 与 trimmed store 定义 frontend-only authoring 行为,并对当前 frontend-only 路径无法真实表达的 GM-side fill / backend-init 语义给出显式诊断边界。 +- 更新 `tilelang-dsl/docs/tilelang-dsl-guide.md`、`tilelang-dsl/docs/unsupported-features.md`、相关 tests 与 examples,使文档承诺、OpenSpec 契约和实现支持面保持一致。 + +## Capabilities + +### New Capabilities + +- 无 + +### Modified Capabilities + +- `tilelang-dsl-surface`: 为 stable `dma_load` / `dma_store` 明确 keyword 参数和 `PadMode` public surface。 +- `tilelang-dsl-vpto-lowering`: 扩展 stable 2D slicing / DMA inference / padding-trim lowering 契约,并明确 frontend-only 边界。 + +## 预期结果 + +- stable `dma_load` / `dma_store` 不再只是文档中的高层愿景,而是具备清晰、可测试的 2D authoring contract。 +- stable slicing 和 DMA inference 能覆盖 guide 中最核心的 dynamic start / stride-aware authoring 场景,而不是继续被 zero-based contiguous profile 卡死。 +- guide、unsupported 文档和 OpenSpec 不再对 stable DMA 能力给出互相冲突的描述。 +- 对当前 frontend-only 路径暂时无法承载的行为,frontend 会给出明确、工程化的限制诊断,而不是 silently accept 或产生误导性输出。 + +## 成功标准 + +- 新增 `openspec/changes/close-tilelang-dsl-stable-dma-gaps/`,包含 `proposal.md`、`design.md`、`tasks.md`。 +- 新增 spec delta: + - `openspec/changes/close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-surface/spec.md` + - `openspec/changes/close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-vpto-lowering/spec.md` +- proposal/design/tasks/specs 明确写清: + - stable DMA keyword surface 和 `PadMode` 入口 + - stable 2D slice profile 的动态起点与静态步长边界 + - automatic DMA inference 的输入来源 + - padded `dma_load` 与 trimmed `dma_store` 的 frontend-only 行为 + - 当前 frontend-only 路径下继续延期的 GM-side fill / backend-init 语义 + +## Impact + +- 受影响目录: + - `tilelang-dsl/python/` + - `tilelang-dsl/tests/` + - `tilelang-dsl/examples/` + - `tilelang-dsl/docs/` + - `openspec/specs/` +- 受影响 public API: + - `pto.dma_load(...)` + - `pto.dma_store(...)` + - `PadMode` + - TensorView slicing 的 stable accepted profile +- 受影响验证路径: + - stable authoring kernel 的 semantic/lowering regression + - `descriptor.mlir_text()` 输出的 DMA programming / copy-family 形态 diff --git a/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-surface/spec.md b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-surface/spec.md new file mode 100644 index 000000000..ab122f10e --- /dev/null +++ b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-surface/spec.md @@ -0,0 +1,19 @@ +## ADDED Requirements + +### Requirement: stable `dma_load` / `dma_store` surface MUST expose keyword DMA options and `PadMode` + +TileLang DSL stable surface MUST 为 `pto.dma_load` / `pto.dma_store` 暴露正式的 keyword 参数入口,而不再把 high-level DMA 限定为 2 参数最小形态。 +stable public API MUST 提供 `PadMode`,并允许用户在 `dma_load` / `dma_store` 中通过 `pad_mode`、`pad_value`、`left_padding`、`right_padding`、`init_out_buffer` 表达高层 DMA 意图。 +frontend MUST 保留这些参数的语义信息进入后续 semantic/lowering,而不是在 AST 构建阶段静默丢弃 keyword 参数。 + +#### Scenario: stable DMA call accepts keyword arguments + +- **WHEN** 用户在 stable kernel 中编写 `pto.dma_load(src_slice, dst_tile, pad_mode=PadMode.PadValue, pad_value=..., left_padding=2, right_padding=2)` +- **THEN** frontend MUST 接受该 call surface 的 keyword 形态 +- **AND** keyword 参数 MUST 保留到后续 semantic 分析,而不是被静默忽略 + +#### Scenario: `PadMode` is part of the stable public surface + +- **WHEN** 用户在 TileLang DSL 中引用 `PadMode.PadNull`、`PadMode.PadFirstElem` 或 `PadMode.PadValue` +- **THEN** frontend MUST 识别这些 stable surface 符号 +- **AND** `PadMode` MUST NOT 继续被归类为 unsupported language construct diff --git a/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-vpto-lowering/spec.md b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-vpto-lowering/spec.md new file mode 100644 index 000000000..9134a208d --- /dev/null +++ b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-vpto-lowering/spec.md @@ -0,0 +1,86 @@ +## MODIFIED Requirements + +### Requirement: TileLang DSL v1 MUST support static physical Tile shape with dynamic TensorView views and loop bounds + +TileLang DSL v1 中,Tile physical shape MUST 是静态编译期常量。 +TensorView shape、slice 边界、loop bound 和 tail 相关 remaining value MAY 包含 runtime value。 +stable TensorView slicing 在本 profile 下 MUST 满足以下约束: + +- 仍只支持 rank-2 TensorView slice +- `stop` MUST 显式给出 +- `start` MAY 为常量或 runtime index expr +- `step` MUST 是静态正整数 +- 第 0 轴 MAY 使用 `step > 1` +- 第 1 轴 MUST 保持 `step == 1` + +`valid_shape` 仅可使用静态值或由 TensorView partition 直接推导。 +semantic / lowering MUST 保留标准化后的 `start/stop/step`,而不是只保留 slice extent。 + +#### Scenario: dynamic TensorView slice with non-zero start lowers successfully + +- **WHEN** 用户在 stable kernel 中使用 `tensor[row_start:row_stop, col_start:col_stop]`,且 `row_start`、`row_stop`、`col_start`、`col_stop` 含 runtime index value +- **THEN** frontend MUST 生成合法的 authoring-form VPTO IR +- **AND** lowering MUST 能从标准化 slice 推导对应的 pointer offset 与 transfer extent +- **AND** Tile physical shape MUST 继续保持静态契约 + +#### Scenario: outer-axis static stepped slice is accepted for stable DMA inference + +- **WHEN** 用户在 stable kernel 中使用第 0 轴带静态正步长的 TensorView slice,例如 `tensor[0:rows:2, 0:16]` +- **THEN** frontend MUST 接受该 stable slice profile +- **AND** lowering MUST 把该步长纳入 DMA stride / loop-size 推导,而不是退化成 contiguous copy + +#### Scenario: unsupported slice profile is rejected before lowering + +- **WHEN** 用户使用 rank > 2 slice、缺失 `stop`、dynamic `step` 或第 1 轴 `step != 1` 的 TensorView slice +- **THEN** frontend MUST 在生成 VPTO IR 之前报错 +- **AND** 诊断 MUST 明确指出超出了当前 stable 2D slice / DMA profile + +### Requirement: `dma_load` and `dma_store` MUST lower to VPTO DMA programming plus copy ops + +TileLang DSL 的高层 `dma_load` / `dma_store` MUST 在 frontend lower 到当前合法 VPTO authoring surface: + +- GM -> UB:必要的 `set_loop*_stride_outtoub` / `set_loop_size_outtoub` + `copy_gm_to_ubuf` +- UB -> GM:必要的 `set_loop*_stride_ubtoout` / `set_loop_size_ubtoout` + `copy_ubuf_to_gm` + +DMA 参数 MUST 由标准化后的 TensorView slice `start/stop/step`、TensorView shape、Tile `shape/valid_shape` 和 padding/trim 配置推导。 +对 non-zero start 的 stable DMA,lowering MUST 生成等价的 pointer offset(例如 `pto.addptr`),而不是假定 source / destination 总是 zero-based contiguous view。 +对带 outer-axis 静态步长的 stable DMA,lowering MUST 把步长纳入 stride/loop-size 推导。 + +`dma_load` 的 stable padding contract MUST 满足: + +- `left_padding` / `right_padding` 影响 destination Tile 内的有效 copy window +- padded load MAY 在 `copy_gm_to_ubuf` 之前生成额外的 frontend prefill lowering +- `PadNull`、`PadFirstElem`、`PadValue` 的可实现子集 MUST 有明确 lowering 语义 + +`dma_store` 的 stable contract MUST 满足: + +- `left_padding` / `right_padding` 定义 source Tile 的 interior trim window +- lowering MUST 只把 trim 后的 interior window 写回 destination TensorView slice +- 需要 GM-side fill 的 store padding 组合 MUST 显式 reject,直到 stable path 具备对应公开承载 + +当前 frontend-only stable path 无法真实表达的 backend-init / GM-side fill 语义 MUST 显式诊断,MUST NOT 通过 silently ignore 参数伪装成“已支持”。 + +#### Scenario: non-zero-start DMA lowers through inferred pointer offset and dynamic strides + +- **WHEN** 用户在 stable kernel 中编写 `pto.dma_load(inp[row_start:row_stop, 0:16], tile)` 或 `pto.dma_store(tile, out[row_start:row_stop, 0:16])` +- **THEN** lowering MUST 生成对应的 DMA programming op 和 copy op +- **AND** 生成结果 MUST 反映 slice start 带来的 pointer offset 与 stride 变化 +- **AND** 生成结果 MUST 符合当前 VPTO copy-family 的 authoring contract + +#### Scenario: padded stable load lowers through prefill plus interior copy + +- **WHEN** 用户在 stable kernel 中编写带 `pad_mode`、`left_padding`、`right_padding` 的 `pto.dma_load(src_slice, dst_tile, ...)` +- **THEN** lowering MUST 生成与该 padding 语义一致的 prefill/copy 组合 +- **AND** destination Tile 的 interior copy window MUST 与 padding 后的 shape contract 一致 + +#### Scenario: trimmed stable store writes the interior window only + +- **WHEN** 用户在 stable kernel 中编写 `pto.dma_store(src_tile, dst_slice, left_padding=1, right_padding=1)` +- **THEN** lowering MUST 只把 `src_tile` 的 interior trim window 写回 `dst_slice` +- **AND** `dst_slice` 的 extent MUST 与 trim 后的 source window 相匹配 + +#### Scenario: unsupported store fill or init combination is rejected explicitly + +- **WHEN** 用户在 stable kernel 中请求当前 frontend-only path 尚无真实承载的 GM-side fill、store `pad_mode != PadNull` 或其他等价 backend-init 组合 +- **THEN** frontend MUST 在生成 VPTO IR 之前报错 +- **AND** 诊断 MUST 明确指出该组合超出了当前 stable frontend-only DMA profile diff --git a/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/tasks.md b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/tasks.md new file mode 100644 index 000000000..8584d70fa --- /dev/null +++ b/openspec/changes/archive/2026-04-11-close-tilelang-dsl-stable-dma-gaps/tasks.md @@ -0,0 +1,31 @@ +## 1. OpenSpec 契约落定 + +- [x] 1.1 新增 `openspec/changes/close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-surface/spec.md`,固定 stable DMA keyword surface、`PadMode` 和 public accepted profile。 +- [x] 1.2 新增 `openspec/changes/close-tilelang-dsl-stable-dma-gaps/specs/tilelang-dsl-vpto-lowering/spec.md`,固定 stable 2D slicing、DMA inference、padded `dma_load` 和 trimmed `dma_store` 的 lowering 契约。 +- [x] 1.3 在 `proposal.md` 和 `design.md` 中明确本 change 保持 frontend-only authoring path,不引入新的 backend capability。 + +## 2. Frontend surface 与语义模型 + +- [x] 2.1 在 `tilelang-dsl/python/tilelang_dsl/frontend_ast.py` 和相关 validator 中为 `pto.dma_load` / `pto.dma_store` 增加 keyword 参数承载与校验。 +- [x] 2.2 在 `tilelang-dsl/python/tilelang_dsl/types.py`、`__init__.py`、`support_matrix.py` 中补齐 stable `PadMode` public surface 与 tier/unsupported 边界。 +- [x] 2.3 在 `tilelang-dsl/python/tilelang_dsl/semantic.py` 中引入 normalized TensorView slice 表示,保留每轴 `start/stop/step` 与 DMA option 字段。 +- [x] 2.4 基于 Tile `shape/valid_shape`、slice extent 与 padding/trim 规则实现 stable DMA shape/profile 校验,并为 unsupported 组合提供明确 diagnostics。 + +## 3. Stable DMA lowering + +- [x] 3.1 在 `tilelang-dsl/python/tilelang_dsl/lowering.py` 中实现 non-zero/dynamic start 的 pointer offset lowering,以及基于 slice layout 的 loop stride / loop size 推导。 +- [x] 3.2 实现 outer-axis static-step DMA inference,确保 `set_loop*_stride_*` 与 copy-family 参数不再固定为 full-tile contiguous 常量。 +- [x] 3.3 实现 padded `dma_load` 的 frontend-only prefill + interior copy lowering,并处理 `PadNull` / `PadFirstElem` / `PadValue` 的稳定子集。 +- [x] 3.4 实现 `dma_store` 的 interior trim lowering,并对 GM-side fill / 非 `PadNull` store padding 组合保持 fail-fast reject。 + +## 4. 回归、文档与验证 + +- [x] 4.1 在 `tilelang-dsl/tests/test_tilelang_dsl_v1.py` 中增加正向 regression,覆盖 dynamic start/stop、outer-axis static step、padded `dma_load` 和 trimmed `dma_store`。 +- [x] 4.2 增加负向 regression,覆盖 dynamic step、第 1 轴 stepped slice、shape/padding 不匹配、unsupported store fill、unsupported init 组合等边界。 +- [x] 4.3 更新 `tilelang-dsl/docs/tilelang-dsl-guide.md`、`tilelang-dsl/docs/unsupported-features.md` 和相关 migration/support 文档,使 stable DMA contract 与实现一致。 +- [x] 4.4 运行并记录最小验证命令,确认新增 stable DMA regression 能通过 `tilelang-dsl/tests/` 路径下的最小相关测试集。 + +### 验证记录 + +- `PYTHONPATH=tilelang-dsl/python python3 -m unittest tilelang-dsl.tests.test_tilelang_dsl_v1` +- `openspec validate close-tilelang-dsl-stable-dma-gaps --type change --strict --json --no-interactive` diff --git a/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/.openspec.yaml b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/.openspec.yaml new file mode 100644 index 000000000..e49efd11c --- /dev/null +++ b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-10 diff --git a/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/design.md b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/design.md new file mode 100644 index 000000000..24c1abf11 --- /dev/null +++ b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/design.md @@ -0,0 +1,161 @@ +## Context + +### 范围 + +本 design 覆盖 `inline_proc` 的端到端迁移: + +- TileLang frontend:不再展开 inline_proc,改为保留 helper/call。 +- TileLang semantic/lowering:建立 helper 函数与 `func.call` 语义模型,支持返回值调用。 +- `ptoas` backend:在 VPTO backend 主线中强制 inline 并删除死 helper。 +- OpenSpec/文档/测试:同步新契约。 + +不覆盖: + +- 多模块 inline_proc 解析与跨模块可见性扩展。 +- `*args` / `**kwargs` / kw-only 参数模型。 +- matcher/registry 与 inline_proc 的新交互语义。 + +### 当前状态 + +1. `inline_proc` 当前在 frontend AST 阶段做语句级展开,导致宏替换行为与函数语义混杂。 +2. 当前实现通过捕获校验与名字改写保障可用,但复杂度和维护成本较高。 +3. `specialized.mlir_text()` 目前默认看不到 `func.call`,inline 边界与 backend 优化边界不清晰。 +4. `ptoas` 已有 `PTOInlineLibCall` pass 与 TileLang template 路径,但当前 pass 对“带返回值 call”与“TileLang kernel 作为 caller”支持不完整,不足以直接承接新版 inline_proc。 + +### 实现约束 + +- 不保留 feature switch,直接替换现有 frontend-expand 语义。 +- 解析范围固定为“同模块显式注册 inline_proc”。 +- 必须保持 fail-fast:隐式捕获、递归/互递归、不支持参数模型在 frontend 直接报错。 +- `mlir_text()` 允许 `func.call`;最终 backend 产物必须消除 inline_proc 调用。 + +## Goals / Non-Goals + +**Goals:** + +- 以函数语义重建 inline_proc:默认参数、关键字调用、返回表达式、表达式调用全部可用。 +- 让 frontend 专注语义建模,inline 优化责任收敛到 `ptoas` backend 主线。 +- 保证 `PTOInlineLibCall` 能内联带结果调用,并清理私有死 helper。 +- 用回归测试同时锁定 frontend 语义与 backend inline 收敛行为。 + +**Non-Goals:** + +- 不新增跨模块 inline_proc 导入调用规则。 +- 不支持 `*args` / `**kwargs` / kw-only。 +- 不把 inline_proc 语义下放到 emission 阶段;仍在 VPTO backend 主线 early stage 完成。 + +## Decisions + +### 1. 前端不展开 inline_proc,改为 helper+call 模型 + +决策: + +- `build_frontend_kernel_node` 同时产出 kernel body 与 inline_proc body(受控子集)。 +- kernel/body 中 inline_proc 调用保留为命名调用节点,不在前端替换为语句块。 + +原因: + +- 彻底消除前端宏展开复杂度,避免返回语义与作用域卫生问题反复回归。 + +备选方案: + +- 保留 frontend 展开并继续扩展语义。 + - 放弃原因:维护复杂度高,且与“函数语义优先”目标冲突。 + +### 2. 参数绑定采用 Python 子集:位置参数 + 关键字 + 默认值 + +决策: + +- 调用绑定支持 positional + keyword + defaults。 +- 继续 reject:`*args` / `**kwargs` / kw-only。 +- 绑定错误(重复赋值、缺参、未知关键字)在 frontend fail-fast。 + +原因: + +- 能满足可用性诉求,同时保持实现边界可控。 + +备选方案: + +- 一次性支持完整 Python 参数模型。 + - 放弃原因:复杂度和歧义显著上升,不适合作为本次迁移范围。 + +### 3. 解析范围固定为同模块显式注册 + +决策: + +- 仅允许解析当前 kernel 所在模块内已注册的 `@inline_proc`。 +- 不做全局唯一名 fallback,不做跨模块自动解析。 + +原因: + +- 避免解析歧义,确保 descriptor 构建与 materialization 行为一致。 + +### 4. backend-inline 接入 VPTO backend 主线早期 + +决策: + +- `PTOInlineLibCall` 在 `addVPTOBackendMainlinePasses` 中无条件接入(不依赖 `--enable-tile-op-expand`)。 +- pass 运行位置固定在 VPTO authoring validation 之前。 + +原因: + +- 保证后续依赖平坦 body 的 pass 不会看到 inline_proc helper 调用。 + +### 5. inline 目标限定为 TileLang inline helper,且支持带结果 call + +决策: + +- helper 函数增加专用属性(`pto.tilelang.inline_proc`)。 +- `PTOInlineLibCall` 仅对受控 inlineable callee 生效。 +- 扩展 `inlineCall` 支持 call result 映射与替换,不再仅限 `() -> ()`。 + +原因: + +- 既满足功能需求,又避免误内联普通调用。 + +### 6. `mlir_text()` 保留 `func.call` 可观察边界 + +决策: + +- `mlir_text()` 输出允许包含 kernel + private inline helper + `func.call`。 +- “最终无调用残留”只承诺在 `ptoas` backend 主线后成立。 + +原因: + +- 调试可观测性更好,前后端职责边界清晰。 + +## 测试策略 + +- Python 单测: + - 默认参数、关键字调用、返回表达式、表达式调用正向覆盖。 + - 隐式捕获、递归/互递归、不支持参数模型负向覆盖。 + - `mlir_text()` 断言包含 helper 函数和 `func.call`。 +- backend `lit` 回归: + - 构造带 `pto.tilelang.inline_proc` helper call 的输入,验证 VPTO backend 主线后调用被消除。 + - 覆盖带返回值调用的 inline 正确性。 + +## Risks / Trade-offs + +- [Risk] 增加 helper 函数后 module 结构更复杂,可能影响现有文本断言测试。 + Mitigation:更新相关断言为“关键行为断言”,避免过度绑定无关格式。 + +- [Risk] `PTOInlineLibCall` 行为变更可能影响现有 OP-Lib 路径。 + Mitigation:保持 callee 筛选策略受控;新增回归覆盖 TileLang inline 与 OP-Lib 两条路径。 + +- [Risk] 语义层新增 inline_proc specialization 可能引入类型绑定回归。 + Mitigation:按 call signature specialization 建 helper,并增加表达式返回值/类型回归。 + +## Migration Plan + +1. 先落 OpenSpec delta,冻结新契约。 +2. 迁移 Python frontend/semantic/lowering 到 helper+call 模型。 +3. 扩展并接线 `PTOInlineLibCall` 到 VPTO backend 主线。 +4. 更新测试与文档,跑最小验证命令。 + +回滚策略: + +- 如 backend inline 接线出现阻断,可先回滚 pass 接线与 helper attr 识别改动,保留 OpenSpec change 未归档状态并修复后再提交。 + +## Open Questions + +- 无(本 change 边界已锁定)。 diff --git a/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/proposal.md b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/proposal.md new file mode 100644 index 000000000..37c44bc52 --- /dev/null +++ b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/proposal.md @@ -0,0 +1,81 @@ +# Proposal: 将 TileLang DSL `inline_proc` 迁移为 backend-inline 主线 + +## 概述 + +当前 `inline_proc` 采用 frontend AST 展开,导致参数绑定、返回语义、作用域卫生和递归检测耦合在前端重写逻辑中,维护成本高且行为容易偏离“函数调用”直觉。 +本 change 将 `inline_proc` 迁移为 backend-inline:TileLang frontend 保留 helper `func.func` 与 `func.call`,在 `ptoas` VPTO backend 主线早期强制 inline 并清理死 helper。 + +## 背景与动机 + +现状存在三类直接问题: + +1. 前端展开路径把 `inline_proc` 变成“宏替换”语义,难以稳定支持默认参数、关键字参数、返回值调用与表达式调用。 +2. 前端展开已引入复杂的 capture 校验与名字改写逻辑,调试与行为验证成本持续上升。 +3. `mlir_text()` 与最终 backend 产物的调用边界不清晰,用户难以理解“语义正确性”和“性能内联”各由哪一层负责。 + +因此需要把 `inline_proc` 还原为函数语义建模,并把 inline 优化责任收敛到 `ptoas` backend pipeline。 + +## 目标 + +- 把 `inline_proc` 从 frontend 展开迁移为 backend-inline,不保留 feature switch。 +- 支持 `inline_proc` 的默认参数、关键字调用、返回表达式和表达式位置调用。 +- 保持 fail-fast 约束:禁止隐式捕获、禁止递归/互递归、禁止 `*args` / `**kwargs` / kw-only 参数。 +- 让 `specialized.mlir_text()` 可观察 helper `func.func` 与 `func.call`。 +- 在 `ptoas --pto-backend=vpto` 主线中保证 `inline_proc` helper 调用被强制消除。 + +## 非目标 + +- 不在本 change 中引入新的 matcher 或多模块解析策略;`inline_proc` 仍限定为同模块显式注册可解析。 +- 不在本 change 中放开 `*args` / `**kwargs` / kw-only 参数语义。 +- 不在本 change 中保留旧 frontend-expand 行为作为兼容入口。 +- 不在本 change 中新增独立 capability;沿用现有 TileLang DSL capability 并追加 delta。 + +## What Changes + +- **BREAKING**:`inline_proc` 不再在 frontend 物化阶段展开,`mlir_text()` 允许出现 helper `func.func` 与 `func.call`。 +- `@inline_proc` 参数模型升级:允许默认参数与关键字调用;允许返回表达式;允许表达式位置调用。 +- frontend 继续对隐式捕获、递归/互递归、不支持参数模型 (`*args/**kwargs/kw-only`) 做 source-located reject。 +- semantic/lowering 增加 `inline_proc` helper function + callsite 建模,保留函数调用语义。 +- `ptoas` VPTO backend 主线新增(或扩展)强制 inline 阶段,目标限定为 TileLang inline helper,inline 后清理私有死函数。 + +## Capabilities + +### New Capabilities + +- 无 + +### Modified Capabilities + +- `tilelang-dsl-surface`: `inline_proc` public surface 从 statement-only frontend-expand 迁移为函数语义 + backend-inline,补齐默认参数、关键字调用、返回值和表达式调用能力。 +- `tilelang-dsl-diagnostics`: 新增/调整 `inline_proc` 参数绑定、递归检测、捕获规则与不支持参数模型的 fail-fast 诊断契约。 +- `tilelang-dsl-vpto-lowering`: 定义 `inline_proc` helper/call 的 lowering 形态,以及 VPTO backend 主线强制 inline 与死 helper 清理要求。 + +## 预期结果 + +- `inline_proc` 行为与 Python 函数调用直觉对齐,前端不再维护复杂 AST 宏展开。 +- `mlir_text()` 保留调用边界用于调试;最终 backend 产物由 `ptoas` 主线统一消除 `inline_proc` 调用。 +- 回归测试可同时覆盖 frontend 语义(参数绑定/诊断)与 backend inline 收敛行为(调用消除)。 + +## 成功标准 + +- 新增 `openspec/changes/migrate-tilelang-dsl-inline-proc-to-backend-inline/`,包含 `proposal.md`、`design.md`、`tasks.md`。 +- 新增 spec delta: + - `specs/tilelang-dsl-surface/spec.md` + - `specs/tilelang-dsl-diagnostics/spec.md` + - `specs/tilelang-dsl-vpto-lowering/spec.md` +- 代码层完成 backend-inline 迁移并通过最小验证: + - `python3 -m unittest tilelang-dsl/tests/test_tilelang_dsl_v1.py -k inline_proc` + - 覆盖 inline 主线的 `lit` 回归 + - `openspec validate migrate-tilelang-dsl-inline-proc-to-backend-inline --type change --strict --json --no-interactive` + +## Impact + +- 受影响目录: + - `tilelang-dsl/python/tilelang_dsl/` + - `tilelang-dsl/tests/` + - `tilelang-dsl/docs/user_guide/` + - `tools/ptoas/` + - `lib/PTO/Transforms/` + - `openspec/changes/migrate-tilelang-dsl-inline-proc-to-backend-inline/` +- 受影响 public API:`@inline_proc` 调用与返回语义。 +- 受影响 backend 行为:VPTO backend 主线新增(或扩展)强制 inline 阶段。 diff --git a/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-diagnostics/spec.md b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-diagnostics/spec.md new file mode 100644 index 000000000..ff5a5041d --- /dev/null +++ b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-diagnostics/spec.md @@ -0,0 +1,42 @@ +## MODIFIED Requirements + +### Requirement: v1 MUST reject unsupported Python syntax and unsupported DSL calls before IR generation + +TileLang DSL v1 frontend MUST 只接受受限 Python 子集。 +`while`、list/dict/set comprehension、arbitrary external function call、未注册 DSL op、以及其他超出 v1 surface 的 Python 结构 MUST 在 frontend 被拒绝。 +对命名调用,frontend MUST 仅放行“同模块显式注册的 `@inline_proc`”目标;其余 `namespace=None` 命名调用 MUST 继续按 unsupported external function call 拒绝。 + +#### Scenario: unsupported Python construct is rejected before lowering + +- **WHEN** kernel body 使用 `while`、comprehension、任意非 `pto.*` function call(且不是同模块已注册 inline_proc)或未纳入 v1 support matrix 的 DSL call +- **THEN** frontend MUST 在生成任何 VPTO IR 之前报错 +- **AND** 诊断 MUST 指明违规的 Python construct 或 DSL call 名称 + +## ADDED Requirements + +### Requirement: inline_proc diagnostics MUST fail fast on capture/recursion/unsupported parameter forms + +针对 `@inline_proc`,frontend diagnostics MUST fail-fast 覆盖以下语义约束: + +- 隐式捕获 MUST 报错 +- 递归/互递归 MUST 报错 +- `*args` / `**kwargs` / kw-only 参数 MUST 报错 +- 调用绑定错误(重复赋值、缺参、未知关键字)MUST 报错 + +#### Scenario: implicit capture in inline_proc is rejected with source location + +- **WHEN** `@inline_proc` helper 体内引用了非参数、非局部定义的外部符号 +- **THEN** frontend MUST 在 materialization 前报错 +- **AND** 诊断 MUST 指向 helper 源位置并标明 capture 违反约束 + +#### Scenario: recursion and mutual recursion are rejected + +- **WHEN** 某个 `inline_proc` 直接递归调用自身,或两个 helper 形成互递归调用环 +- **THEN** frontend MUST 拒绝该定义/调用图 +- **AND** 诊断 MUST 明确指出 recursion 或 mutual recursion + +#### Scenario: unsupported inline_proc signature forms are rejected + +- **WHEN** 用户为 `@inline_proc` 使用 kw-only 参数、`*args` 或 `**kwargs` +- **THEN** frontend MUST 直接报错 +- **AND** 诊断 MUST 明确指出当前 v1 不支持该参数模型 diff --git a/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-surface/spec.md b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-surface/spec.md new file mode 100644 index 000000000..d149366c4 --- /dev/null +++ b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-surface/spec.md @@ -0,0 +1,20 @@ +## ADDED Requirements + +### Requirement: `@inline_proc` MUST expose function-call semantics and defer inlining to backend + +TileLang DSL v1 中,`@inline_proc` MUST 采用函数语义建模,而不是 frontend AST 宏展开。 +`inline_proc` 定义 MUST 支持默认参数;调用 MUST 支持位置参数与关键字参数混合绑定;helper body MUST 允许返回表达式。 +`inline_proc` 调用 MUST 允许出现在语句位置和表达式位置。 +`specialized.mlir_text()` MUST 允许暴露入口 kernel `func.func`、private helper `func.func` 与 `func.call`,而不是强制在 frontend 阶段消除调用。 + +#### Scenario: inline_proc accepts defaults and keyword call syntax + +- **WHEN** 用户定义 `@pto.inline_proc` helper,参数包含默认值,并在 kernel 中用关键字调用该 helper +- **THEN** frontend MUST 按 Python 子集完成参数绑定并接受该调用 +- **AND** 该调用 MUST NOT 因“仅支持位置参数”而被拒绝 + +#### Scenario: inline_proc return expression can be used in expression position + +- **WHEN** 用户在 `@inline_proc` helper 中返回表达式,并把 helper 调用放在另一个表达式上下文中 +- **THEN** frontend MUST 接受该 surface 并保留调用结果语义 +- **AND** `mlir_text()` 结果 MAY 包含对应的 `func.call` 结果值绑定 diff --git a/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-vpto-lowering/spec.md b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-vpto-lowering/spec.md new file mode 100644 index 000000000..4ce4a05a3 --- /dev/null +++ b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/specs/tilelang-dsl-vpto-lowering/spec.md @@ -0,0 +1,50 @@ +## MODIFIED Requirements + +### Requirement: TileLang DSL v1 MUST support the fixed elementwise lowering profile + +TileLang DSL v1 lowering MUST 支持以下固定 support matrix: + +- 2D `TensorView` +- 1D/2D `Tile` +- `dma_load` +- `dma_store` +- `make_mask(dtype, PAT.*)` / `make_mask(dtype, remaining)` +- `vlds` +- `vsts` +- unary:`vabs`, `vrelu`, `vexp`, `vnot` +- binary:`vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor` +- vector-scalar:`vadds`, `vsubs`, `vmuls`, `vdivs`, `vmaxs`, `vmins` +- `for range(lb, ub, step)` +- `if/else` +- `set_flag`, `wait_flag`, `pipe_barrier` +- `@inline_proc` helper function 定义与 call(含返回值 call) + +support matrix 外的 surface MUST 在 frontend reject。 +对 `@inline_proc`,frontend/lowering MUST 生成“入口 kernel + private helper + `func.call`”的 authoring-form module 结构;`mlir_text()` 阶段 MAY 保留这些 call 边界。 + +#### Scenario: representative elementwise kernel lowers to authoring-form VPTO with inline_proc calls + +- **WHEN** 用户编写由 `TensorView`、`Tile`、高层 DMA、typed mask、elementwise vector op、`for`、`if`、基础 sync 和 `inline_proc` 调用组成的 kernel +- **THEN** frontend MUST 产出只包含 `func.func`、`arith`、`scf` 和合法 `pto.*` authoring surface 的 VPTO IR +- **AND** 该 IR MAY 包含标注为 inline_proc helper 的 private `func.func` 与 `func.call` + +## ADDED Requirements + +### Requirement: VPTO backend mainline MUST force-inline TileLang inline_proc helpers + +在 `ptoas --pto-backend=vpto` 主线中,backend MUST 在早期 pass 阶段强制 inline TileLang inline_proc helper 调用。 +强制 inline 的目标 MUST 通过 helper 属性(例如 `pto.tilelang.inline_proc`)筛选,避免误作用于普通函数调用。 +backend MUST 支持“带返回值 `func.call`”内联替换,不仅是 `() -> ()` 调用。 +inline 完成后,面向 inline_proc helper 的 `func.call` MUST 被消除,且无引用 private helper MUST 被清理。 + +#### Scenario: backend pipeline removes inline_proc calls before downstream lowering + +- **WHEN** 输入 module 包含 `pto.tilelang.instance` kernel 对 `pto.tilelang.inline_proc` private helper 的 `func.call` +- **THEN** VPTO backend 主线 MUST 在后续依赖平坦 body 的 pass 之前完成 inline +- **AND** inline 后 module MUST NOT 残留面向 inline_proc helper 的 `func.call` + +#### Scenario: return-valued helper call is inlined with SSA replacement + +- **WHEN** inline_proc helper 返回一个值,caller 使用 `func.call` 结果参与后续计算 +- **THEN** backend inline MUST 正确替换 call result 的 SSA use 链 +- **AND** 结果 IR MUST 保持类型一致并通过后续 legality 校验 diff --git a/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/tasks.md b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/tasks.md new file mode 100644 index 000000000..394d161ee --- /dev/null +++ b/openspec/changes/archive/2026-04-11-migrate-tilelang-dsl-inline-proc-to-backend-inline/tasks.md @@ -0,0 +1,47 @@ +## 1. OpenSpec delta 落定 + +- [x] 1.1 完成 `specs/tilelang-dsl-surface/spec.md`,定义 `inline_proc` 函数语义与 `mlir_text()` 可见调用边界 +- [x] 1.2 完成 `specs/tilelang-dsl-diagnostics/spec.md`,定义 inline_proc fail-fast 诊断与命名调用放行边界 +- [x] 1.3 完成 `specs/tilelang-dsl-vpto-lowering/spec.md`,定义 helper/call lowering 形态与 backend 强制 inline 契约 + +## 2. TileLang frontend/semantic/lowering 迁移 + +- [x] 2.1 移除 frontend AST 展开路径,保留 inline_proc 为命名调用并放开表达式位置调用 +- [x] 2.2 在 frontend 参数绑定中支持 inline_proc 默认参数和关键字调用,同时保留 `*args/**kwargs/kw-only` 拒绝 +- [x] 2.3 在 semantic 层建立 inline_proc helper call 语义节点,允许 `namespace=None` 的受控命名调用进入分析 +- [x] 2.4 在 lowering 层渲染 kernel + private helper 多函数 module,生成 `func.call` 与返回值绑定 + +## 3. ptoas backend inline pass 与 pipeline 接线 + +- [x] 3.1 扩展 `PTOInlineLibCall` 以支持带返回值 `func.call` 的 inline 替换 +- [x] 3.2 将 inline 目标限定为 TileLang inline_proc helper 属性,避免误内联普通调用 +- [x] 3.4 inline 后清理无引用 private helper,并验证主线不残留 inline_proc helper call + +## 4. 回归测试与文档迁移 + +- [x] 4.1 更新 `tilelang-dsl/tests/test_tilelang_dsl_v1.py`:默认参数/关键字/返回表达式改为正向测试 +- [x] 4.2 新增 inline_proc 表达式位置调用和 `mlir_text()` helper+`func.call` 断言测试 +- [x] 4.3 保留并验证负向测试:隐式捕获、递归/互递归、`*args/**kwargs/kw-only` +- [x] 4.4 增加 ptoas 侧 `lit` 回归:验证 VPTO backend 主线消除 inline_proc helper 调用(含带返回值 case) +- [x] 4.5 更新 `tilelang-dsl/docs/user_guide/08-control-flow.md`,迁移为 backend-inline 语义描述 + +## 5. 验证命令与结果记录 + +- [x] 5.1 执行 `python3 -m unittest tilelang-dsl/tests/test_tilelang_dsl_v1.py -k inline_proc` +- [x] 5.2 执行覆盖强制 inline 生效路径的 `lit`/pipeline 回归并记录结果 +- [x] 5.3 执行 `openspec validate migrate-tilelang-dsl-inline-proc-to-backend-inline --type change --strict --json --no-interactive` +- [x] 5.4 在 change 记录中汇总本次实现的通过项与未覆盖项(若有) + +### 5.4 验证结果记录(2026-04-10) + +- 5.1 命令与结果: + - `PYTHONPATH=tilelang-dsl/python python3 -m unittest tilelang-dsl/tests/test_tilelang_dsl_v1.py -k inline_proc` + - `Ran 15 tests in 0.009s, OK` +- 5.2 命令与结果: + - `llvm-lit -sv test/basic/tilelang_inline_proc_backend_inline.pto test/basic/inline_libcall_result_rewrite.pto test/basic/inline_libcall_filter_tilelang_scope.pto test/basic/vpto_mainline_inline_proc_cleanup.pto test/basic/expand_tile_op_tilelang.pto` + - `Passed: 5/5` +- 5.3 命令与结果: + - `openspec validate migrate-tilelang-dsl-inline-proc-to-backend-inline --type change --strict --json --no-interactive` + - `valid: true, issues: []` +- 未覆盖项: + - 未执行全量 `lit`/`ctest` 套件;本次仅覆盖 inline_proc 迁移相关的定向回归路径。 diff --git a/openspec/specs/tilelang-dsl-surface/spec.md b/openspec/specs/tilelang-dsl-surface/spec.md index a3824074b..6a1dd7ccd 100644 --- a/openspec/specs/tilelang-dsl-surface/spec.md +++ b/openspec/specs/tilelang-dsl-surface/spec.md @@ -54,3 +54,35 @@ Tile 的 physical shape、memory space 和配置 MUST 在 specialization 阶段 - **WHEN** kernel 含 bare `Tile` 参数,且调用方通过 `descriptor.specialize(**bindings)` 为所有 bare `Tile` 参数补齐静态 shape / space / config - **THEN** 返回的 specialized descriptor MUST 允许调用 `mlir_text()`, `mlir_module()`, `verify()` 和 `emit(path)` - **AND** specialization 之后的 Tile physical shape MUST 作为编译期静态契约固定下来 + +### Requirement: TileLang DSL v1 MUST expose fixed-width vector type construction as `pto.vreg(dtype)` + +TileLang DSL v1 MUST 提供 public type constructor `pto.vreg(dtype)`。 +该 surface MUST 只接受元素类型,不接受显式 lanes 参数。 +frontend MUST 依据固定 256-byte vector register 宽度自动推导 lanes。 +当前 v1 若元素类型不在已支持的 vector lowering 子集内,frontend MUST fail fast。 + +#### Scenario: `pto.vreg(dtype)` returns the inferred fixed-width vector type + +- **WHEN** 用户在 DSL 中书写 `pto.vreg(pto.f32)` 或 `pto.vreg(dst.element_type)` +- **THEN** frontend MUST 将其识别为 vector type expression +- **AND** `pto.f32` MUST 对应 `!pto.vreg<64xf32>`,`pto.f16` MUST 对应 `!pto.vreg<128xf16>` +- **AND** MUST NOT 要求用户显式提供 lanes 参数 + +### Requirement: TileLang DSL v1 MUST expose typed mask markers and MUST NOT expose `pto.memref(...)` + +TileLang DSL v1 MUST 提供 public typed-mask marker:`pto.mask_b8`、`pto.mask_b16`、`pto.mask_b32`。 +frontend MUST 将这些 surface 识别为对应 `!pto.mask`、`!pto.mask`、`!pto.mask` 的 type expression。 +与此同时,DSL public surface MUST NOT 暴露 `pto.memref(...)` constructor;memref 只允许作为内部 IR / lowering 表达出现,不得作为 DSL authoring type surface。 + +#### Scenario: typed mask marker matches `make_mask` result type + +- **WHEN** 用户书写 `mask: pto.mask_b32 = pto.make_mask(pto.f32, pto.PAT.ALL)` +- **THEN** frontend MUST 接受该注解 +- **AND** 若注解 granularity 与 `make_mask` 推导结果不一致,frontend MUST fail fast + +#### Scenario: `pto.memref(...)` is not part of the DSL public type surface + +- **WHEN** 用户查看或使用 TileLang DSL v1 public type surface +- **THEN** 文档和 package surface MUST 只暴露 `TensorView`、`Tile`、typed pointer、`pto.vreg(...)`、typed mask 等 authoring type +- **AND** MUST NOT 将 `pto.memref(...)` 描述为 DSL 侧可用 constructor From 5d7ec5aa7ba510011381005b45fa19a26a0f04fc Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 22:14:57 +0800 Subject: [PATCH 036/192] Support more predicate ops --- tilelang-dsl/docs/unsupported-features.md | 18 - .../user_guide/10-predicate-operations.md | 186 +++++++-- .../11-vector-arithmetic-operations.md | 28 +- tilelang-dsl/python/tilelang_dsl/__init__.py | 4 + tilelang-dsl/python/tilelang_dsl/lowering.py | 147 ++++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 379 +++++++++++++++++- .../python/tilelang_dsl/support_matrix.py | 21 + tilelang-dsl/python/tilelang_dsl/types.py | 14 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 159 +++++++- 9 files changed, 840 insertions(+), 116 deletions(-) diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md index 4cd1b03b9..94a085a9e 100644 --- a/tilelang-dsl/docs/unsupported-features.md +++ b/tilelang-dsl/docs/unsupported-features.md @@ -65,30 +65,12 @@ The previously missing vector-memory surfaces are now implemented. Remaining guide gaps are concentrated in the wider indexed/flush families that are still not wired through this DSL package. -### Missing Direct Predicate Constructor/Compare APIs - -The implementation expects users to go through `pto.make_mask(...)` rather than -call the underlying mask ops directly. These guide-documented APIs are not part -of the supported authoring surface: - -- `pto.pset_b8(...)`, `pto.pset_b16(...)`, `pto.pset_b32(...)` -- `pto.pge_b8(...)`, `pto.pge_b16(...)`, `pto.pge_b32(...)` -- `pto.plt_b8(...)`, `pto.plt_b16(...)`, `pto.plt_b32(...)` - ### Missing Extended Vector Arithmetic Families The previously missing `11-vector-arithmetic-operations.md` gap list is now implemented in the current package surface (including fused ops, broadcast/index generation, reduction-flavored ops, and rearrangement/sort groups). -### Missing Predicate Rearrangement Shorthands - -The guide documents mask-specific rearrangement helpers that are not currently -implemented: - -- `pto.pdintlv_b8(...)` -- `pto.pintlv_b16(...)` - ### Deferred Surface `pto.vreduce(...)` is still explicitly deferred and remains rejected even in diff --git a/tilelang-dsl/docs/user_guide/10-predicate-operations.md b/tilelang-dsl/docs/user_guide/10-predicate-operations.md index f76d2f8cb..227888c4d 100644 --- a/tilelang-dsl/docs/user_guide/10-predicate-operations.md +++ b/tilelang-dsl/docs/user_guide/10-predicate-operations.md @@ -4,11 +4,13 @@ Operations for creating and manipulating typed masks. **Recommended API**: For most use cases, prefer the unified `pto.make_mask()` function which automatically selects the appropriate mask granularity based on element type and supports both tail processing (remaining element count) and pattern-based mask generation. This eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` (tail processing) and `pset_b8`/`pset_b16`/`pset_b32` (pattern generation) operations. -**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. +**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. -**Part Mode Enum**: The `PartMode` enum provides type-safe part selection for `pto.ppack` and `pto.punpack` operations. It includes the following values: `EVEN` (selects even-indexed elements) and `ODD` (selects odd-indexed elements). +**Predicate Part Enum**: `pto.ppack` and `pto.punpack` require the `PredicatePart` enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`; these lower to the VPTO canonical `PART` tokens `"LOWER"` and `"HIGHER"`. -**Predicate Dist Enum**: The `PredicateDist` enum provides type-safe distribution mode selection for predicate load/store families. Common values include `NORM`, `US`, and `DS`. +**Predicate Dist Enum**: The `PredicateDist` enum provides type-safe distribution mode selection for predicate memory families. Load families (`plds`, `pld`, `pldi`) use `NORM`, `US`, and `DS`. Store families (`pst`, `psti`) use `NORM` and `PK`. + +**Pattern coverage**: The VPTO canonical predicate-generation families use `PAT_*` tokens such as `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, `PAT_VL*`, `PAT_M3`, and `PAT_M4`. The Python DSL surface may expose only a subset through `pto.MaskPattern`; check the enum for currently available values. #### `pto.pset_b8(pattern: pto.MaskPattern) -> pto.mask_b8` @@ -17,7 +19,7 @@ Operations for creating and manipulating typed masks. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | **Returns**: | Return Value | Type | Description | @@ -29,7 +31,7 @@ Operations for creating and manipulating typed masks. **Example**: ```python -mask8 = pto.make_mask(pto.i8, PAT.ALL) +mask8 = pto.pset_b8(PAT.ALL) ``` #### `pto.pset_b16(pattern: pto.MaskPattern) -> pto.mask_b16` @@ -39,7 +41,7 @@ mask8 = pto.make_mask(pto.i8, PAT.ALL) **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | **Returns**: | Return Value | Type | Description | @@ -51,7 +53,7 @@ mask8 = pto.make_mask(pto.i8, PAT.ALL) **Example**: ```python -mask16 = pto.make_mask(pto.f16, PAT.ALL) +mask16 = pto.pset_b16(PAT.ALL) ``` #### `pto.pset_b32(pattern: pto.MaskPattern) -> pto.mask_b32` @@ -61,7 +63,7 @@ mask16 = pto.make_mask(pto.f16, PAT.ALL) **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | **Returns**: | Return Value | Type | Description | @@ -73,7 +75,7 @@ mask16 = pto.make_mask(pto.f16, PAT.ALL) **Example**: ```python -mask32 = pto.make_mask(pto.f32, PAT.ALL) +mask32 = pto.pset_b32(PAT.ALL) ``` #### `pto.pge_b8(pattern: pto.MaskPattern) -> pto.mask_b8` @@ -83,7 +85,7 @@ mask32 = pto.make_mask(pto.f32, PAT.ALL) **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Tail mask pattern enum (e.g., `pto.MaskPattern.PAT_VL8`, `pto.MaskPattern.PAT_VL16`) | +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | **Returns**: | Return Value | Type | Description | @@ -96,8 +98,8 @@ mask32 = pto.make_mask(pto.f32, PAT.ALL) **Example**: ```python -# Tail mask for first 8 lanes -tail_mask = pto.pge_b8(PAT.VL8) +# Tail mask pattern lowered as `PAT_VL16` +tail_mask = pto.pge_b8(PAT.VL16) ``` #### `pto.pge_b16(pattern: pto.MaskPattern) -> pto.mask_b16` @@ -107,7 +109,7 @@ tail_mask = pto.pge_b8(PAT.VL8) **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Tail mask pattern enum (e.g., `pto.MaskPattern.PAT_VL8`, `pto.MaskPattern.PAT_VL16`) | +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | **Returns**: | Return Value | Type | Description | @@ -131,7 +133,7 @@ tail_mask = pto.pge_b16(PAT.VL16) **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `pattern` | `pto.MaskPattern` | Tail mask pattern enum (e.g., `pto.MaskPattern.PAT_VL8`, `pto.MaskPattern.PAT_VL16`, `pto.MaskPattern.PAT_VL32`) | +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | **Returns**: | Return Value | Type | Description | @@ -231,7 +233,7 @@ mask, remaining = pto.plt_b32(remaining) # generates mask for next chunk, updat | Parameter | Type | Description | |-----------|------|-------------| | `element_type` | `Type` | Element type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | -| `value` | `pto.i32` \| `pto.MaskPattern` | Either:
- Remaining element count (as `pto.i32`) for tail processing
- Mask pattern enum value for fixed mask generation (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_VL32`) | +| `value` | `pto.i32` \| `pto.MaskPattern` | Either:
- Remaining element count (as `pto.i32`) for tail processing
- Mask pattern enum value for fixed mask generation (for example `pto.MaskPattern.ALL` or `pto.MaskPattern.VL32`) | **Returns**: | Return Value | Type | Description | @@ -276,35 +278,45 @@ mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode ``` -#### `pto.ppack(mask: MaskType, part: PartMode) -> MaskType` +#### `pto.ppack(mask: MaskType, part: PredicatePart) -> MaskType` -**Description**: Rearranges a mask according to the requested `part` selector. +**Description**: Narrowing pack of a predicate register. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | -| `part` | `PartMode` | Part selector enum: `PartMode.EVEN` or `PartMode.ODD`. Determines which half of the mask to pack (even-indexed or odd-indexed elements). | +| `part` | `PredicatePart` | Part selector enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`. | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `packed` | `MaskType` | Reordered mask | +| `packed` | `MaskType` | Packed mask | + +**Example**: +```python +packed = pto.ppack(mask, pto.PredicatePart.LOWER) +``` -#### `pto.punpack(mask: MaskType, part: PartMode) -> MaskType` +#### `pto.punpack(mask: MaskType, part: PredicatePart) -> MaskType` -**Description**: Applies the inverse mask-part rearrangement selected by `part`. +**Description**: Widening unpack of a predicate register. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `mask` | `MaskType` | Input mask | -| `part` | `PartMode` | Part selector enum: `PartMode.EVEN` or `PartMode.ODD`. Determines which half of the mask to unpack (even-indexed or odd-indexed elements). | +| `part` | `PredicatePart` | Part selector enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`. | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `mask` | `MaskType` | Reordered mask | +| `mask` | `MaskType` | Unpacked mask | + +**Example**: +```python +unpacked = pto.punpack(mask, pto.PredicatePart.HIGHER) +``` #### `pto.pnot(mask: MaskType, gate: MaskType) -> MaskType` @@ -400,99 +412,187 @@ mask = pto.pld(buf, offset, PredicateDist.NORM) mask = pto.pldi(buf, 0, PredicateDist.NORM) ``` -#### `pto.pst(mask: MaskType, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.psts(mask: MaskType, buf: ptr, offset: Index) -> None` [Advanced Tier] -**Description**: Stores a predicate mask to buffer. +**Description**: Stores a predicate mask to UB memory using scalar-offset form. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `mask` | `MaskType` | Predicate mask to store | | `buf` | `ptr` | Pointer to destination buffer | -| `offset` | `Index` | Byte offset | +| `offset` | `Index` | Scalar/index-style offset | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.psts(mask, buf, offset) +``` + +#### `pto.pst(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Stores a predicate mask to UB memory using areg/index offset encoding. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `offset` | `Index` | Areg/index-style offset | +| `dist` | `PredicateDist` | Distribution mode for predicate store. Use `PredicateDist.NORM` or `PredicateDist.PK`. Default is `PredicateDist.NORM`. | **Returns**: None (side-effect operation) **Example**: ```python -pto.pst(mask, buf, offset) +pto.pst(mask, buf, offset, PredicateDist.NORM) ``` -#### `pto.psti(mask: MaskType, imm: pto.i32) -> None` +#### `pto.psti(mask: MaskType, buf: ptr, imm_offset: pto.i32, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] -**Description**: Stores a predicate mask to immediate destination. +**Description**: Stores a predicate mask to UB memory using immediate-offset encoding. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `mask` | `MaskType` | Predicate mask to store | -| `imm` | `pto.i32` | Immediate destination identifier | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `imm_offset` | `pto.i32` | Immediate-offset operand | +| `dist` | `PredicateDist` | Distribution mode for predicate store. Use `PredicateDist.NORM` or `PredicateDist.PK`. Default is `PredicateDist.NORM`. | **Returns**: None (side-effect operation) **Example**: ```python -pto.psti(mask, 1) +pto.psti(mask, buf, pto.i32(8), PredicateDist.PK) +``` + +#### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Unaligned predicate store with align-state update. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Input alignment state | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated alignment state | +| `base_out` | `ptr` | Updated destination pointer | + +**Example**: +```python +align_out, base_out = pto.pstu(align_in, mask, buf) ``` -#### `pto.pand(src0: MaskType, src1: MaskType) -> MaskType` +#### `pto.pand(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` -**Description**: Bitwise AND of two predicate masks. +**Description**: Bitwise AND of two predicate masks under a gating mask. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `src0` | `MaskType` | First input mask | | `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `MaskType` | Bitwise AND of input masks | +| `result` | `MaskType` | Bitwise AND result | **Example**: ```python -result = pto.pand(mask1, mask2) +result = pto.pand(mask1, mask2, gate) ``` -#### `pto.por(src0: MaskType, src1: MaskType) -> MaskType` +#### `pto.por(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` -**Description**: Bitwise OR of two predicate masks. +**Description**: Bitwise OR of two predicate masks under a gating mask. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `src0` | `MaskType` | First input mask | | `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `MaskType` | Bitwise OR of input masks | +| `result` | `MaskType` | Bitwise OR result | **Example**: ```python -result = pto.por(mask1, mask2) +result = pto.por(mask1, mask2, gate) ``` -#### `pto.pxor(src0: MaskType, src1: MaskType) -> MaskType` +#### `pto.pxor(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` -**Description**: Bitwise XOR of two predicate masks. +**Description**: Bitwise XOR of two predicate masks under a gating mask. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `src0` | `MaskType` | First input mask | | `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise XOR result | + +**Example**: +```python +result = pto.pxor(mask1, mask2, gate) +``` + +#### `pto.pdintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` + +**Description**: Predicate deinterleave for 8-bit masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.mask_b8` | First input mask | +| `src1` | `pto.mask_b8` | Second input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `pto.mask_b8` | First result mask | +| `high` | `pto.mask_b8` | Second result mask | + +**Example**: +```python +low8, high8 = pto.pdintlv_b8(mask_a, mask_b) +``` + +#### `pto.pintlv_b16(src0: pto.mask_b16, src1: pto.mask_b16) -> (pto.mask_b16, pto.mask_b16)` + +**Description**: Predicate interleave for 16-bit masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.mask_b16` | First input mask | +| `src1` | `pto.mask_b16` | Second input mask | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `MaskType` | Bitwise XOR of input masks | +| `low` | `pto.mask_b16` | First result mask | +| `high` | `pto.mask_b16` | Second result mask | **Example**: ```python -result = pto.pxor(mask1, mask2) +low16, high16 = pto.pintlv_b16(mask_a, mask_b) ``` **Note**: Prefer `pto.make_mask()` for automatic bitwidth selection and unified tail/pattern mask generation. diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index d2f05d6f7..872bfb80b 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1130,33 +1130,7 @@ Reduction operations across vector lanes or channels. Operations for rearranging data within vectors. -#### `pto.pdintlv_b8(mask: pto.mask_b8) -> pto.mask_b8` - -**Description**: Deinterleave 8-bit mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `pto.mask_b8` | Input 8-bit mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `pto.mask_b8` | Deinterleaved mask | - -#### `pto.pintlv_b16(mask: pto.mask_b16) -> pto.mask_b16` - -**Description**: Interleave 16-bit mask. - -**Parameters**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `pto.mask_b16` | Input 16-bit mask | - -**Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `pto.mask_b16` | Interleaved mask | +Predicate rearrangement ops `pto.pdintlv_b8` and `pto.pintlv_b16` are documented in `10-predicate-operations.md` because they operate on predicate masks rather than vector registers. Implemented current-package rearrangement surface also includes: - `pto.vintlvv2(vec0, vec1, part) -> VRegType` diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 44fba3583..d9525c342 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -35,6 +35,8 @@ MemorySpace, MaskPattern, PAT, + PredicateDist, + PredicatePart, PadValue, PadMode, PositionMode, @@ -111,6 +113,8 @@ "OrderMode", "DeinterleaveDist", "InterleaveDist", + "PredicateDist", + "PredicatePart", "StrideMode", "TileConfig", "TileSpecialization", diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 9a4fdcd98..4180274a6 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -560,19 +560,22 @@ def _render_multi_result_assign( if len(stmt.targets) != len(stmt.value.type.elements): raise NotImplementedError("multi-result lowering expects tuple assignment arity to match the call result count") - if stmt.value.name == "make_mask": - dtype_expr, remaining_expr = stmt.value.args - if not self._is_dtype_meta_expr(dtype_expr): - raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") - + if stmt.value.name == "make_mask" or stmt.value.name in {"plt_b8", "plt_b16", "plt_b32"}: lines: list[str] = [] + if stmt.value.name == "make_mask": + dtype_expr, remaining_expr = stmt.value.args + if not self._is_dtype_meta_expr(dtype_expr): + raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") + opname = f"pto.plt_{self._mask_suffix(stmt.value.type.elements[0])}" + else: + remaining_expr = stmt.value.args[0] + opname = f"pto.{stmt.value.name}" remaining = self._lower_remaining_to_i32(remaining_expr, env, indent=indent, into=lines) mask_target, remaining_target = stmt.targets mask_type, remaining_type = stmt.value.type.elements - suffix = self._mask_suffix(mask_type) lines.append( self._indent(indent) - + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = pto.plt_{suffix} {remaining.name} : " + + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = {opname} {remaining.name} : " + f"i32 -> {self._render_type(mask_type)}, {self._render_type(remaining_type)}" ) env[mask_target.name] = _RenderedValue(name=mask_target.ssa_name, type=mask_type) @@ -617,7 +620,7 @@ def _render_multi_result_assign( env[carry_target.name] = _RenderedValue(name=carry_target.ssa_name, type=carry_type) return lines - if stmt.value.name in {"vintlv", "vdintlv"}: + if stmt.value.name in {"vintlv", "vdintlv", "pdintlv_b8", "pintlv_b16"}: lines = [] lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) @@ -2425,7 +2428,9 @@ def _lower_call_expr( if expr.namespace != "pto": raise NotImplementedError(f"unsupported call namespace {expr.namespace!r}") if isinstance(expr.type, SemanticTupleType): - raise NotImplementedError("multi-result call values must be assigned directly in TileLang DSL v1") + raise NotImplementedError( + f"multi-result call `pto.{expr.name}` must be assigned directly in TileLang DSL v1" + ) if into is None: into = [] result_name = desired_name or self._new_temp() @@ -2446,6 +2451,34 @@ def _lower_call_expr( ) return _RenderedValue(name="__void_call__", type=expr.type) + if expr.name == "pst": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( + expr.args[1:-1], + env, + indent=indent, + into=into, + ) + dist = self._render_string_literal(expr.args[-1]) + into.append( + self._indent(indent) + + f"pto.pst {value.name}, {destination_name}[{offset_name}], {dist} : " + + f"{self._render_type(value.type)}, {destination_type}, {offset_type}" + ) + return _RenderedValue(name="__void_call__", type=expr.type) + + if expr.name == "psti": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination = self._lower_expr(expr.args[1], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[2], env, indent=indent, into=into) + dist = self._render_string_literal(expr.args[-1]) + into.append( + self._indent(indent) + + f"pto.psti {value.name}, {destination.name}, {offset.name}, {dist} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(offset.type)}" + ) + return _RenderedValue(name="__void_call__", type=expr.type) + if expr.name == "vsst": value = self._lower_expr(expr.args[0], env, indent=indent, into=into) destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( @@ -2501,12 +2534,23 @@ def _lower_call_expr( dtype_expr, pattern_expr = expr.args if not self._is_dtype_meta_expr(dtype_expr): raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") - if not isinstance(pattern_expr, SemanticSymbolExpr) or not isinstance(pattern_expr.value, MaskPattern): + pattern_value = self._extract_mask_pattern_value(pattern_expr) + if pattern_value is None: raise NotImplementedError("make_mask pattern lowering expects a MaskPattern symbol") suffix = expr.type.granularity into.append( self._indent(indent) - + f'{result_name} = pto.pset_{suffix} "{pattern_expr.value.value}" : {self._render_type(expr.type)}' + + f'{result_name} = pto.pset_{suffix} "{pattern_value}" : {self._render_type(expr.type)}' + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"pset_b8", "pset_b16", "pset_b32", "pge_b8", "pge_b16", "pge_b32"}: + pattern_value = self._extract_mask_pattern_value(expr.args[0]) + if pattern_value is None: + raise NotImplementedError(f"{expr.name} lowering expects a MaskPattern symbol") + into.append( + self._indent(indent) + + f'{result_name} = pto.{expr.name} "{pattern_value}" : {self._render_type(expr.type)}' ) return _RenderedValue(name=result_name, type=expr.type) @@ -2550,6 +2594,47 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "plds": + source_name, source_type, offset_name, _ = self._lower_memory_buffer_with_offset( + expr.args[:-1], + env, + indent=indent, + into=into, + ) + dist = self._render_string_literal(expr.args[-1]) + into.append( + self._indent(indent) + + f'{result_name} = pto.plds {source_name}[{offset_name}] {{dist = {dist}}} : ' + + f"{source_type} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "pld": + source_name, source_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( + expr.args[:-1], + env, + indent=indent, + into=into, + ) + dist = self._render_string_literal(expr.args[-1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.pld {source_name}[{offset_name}], {dist} : " + + f"{source_type}, {offset_type} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "pldi": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + dist = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.pldi {source.name}, {offset.name}, {dist} : " + + f"{self._render_type(source.type)}, {self._render_type(offset.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vsld": source_name, source_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( expr.args[:-1], @@ -2681,6 +2766,18 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name in {"pand", "por", "pxor"}: + src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {src0.name}, {src1.name}, {mask.name} : " + + f"{self._render_type(src0.type)}, {self._render_type(src1.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vcmp": lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) @@ -3163,6 +3260,34 @@ def _mask_suffix(self, ty: SemanticType) -> str: raise NotImplementedError("tail make_mask lowering expects a mask result type") return ty.granularity + def _extract_mask_pattern_value(self, expr: SemanticExpr) -> str | None: + if isinstance(expr, SemanticSymbolExpr) and isinstance(expr.value, MaskPattern): + return expr.value.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "mask_pattern" + and isinstance(expr.binding.value, MaskPattern) + ): + return expr.binding.value.value + return None + + def _emit_full_mask_for_type( + self, + ty: SemanticType, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if not isinstance(ty, SemanticMaskType): + raise NotImplementedError("full-mask synthesis expects a mask type") + result_name = self._new_temp() + into.append( + self._indent(indent) + + f'{result_name} = pto.pset_{ty.granularity} "PAT_ALL" : {self._render_type(ty)}' + ) + return _RenderedValue(name=result_name, type=ty) + def _is_dtype_meta_expr(self, expr: SemanticExpr) -> bool: if isinstance(expr, SemanticSymbolExpr): return isinstance(expr.value, ScalarType) and expr.type.kind == "dtype" diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 4a11d78fc..0f9ceba04 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -59,6 +59,8 @@ PadMode, Pipe, PositionMode, + PredicateDist, + PredicatePart, PointerType, ScalarType, SLayout, @@ -104,7 +106,22 @@ _ORDER_MODE_SYMBOLS = {order_mode.name: order_mode for order_mode in OrderMode} _DEINTERLEAVE_DIST_SYMBOLS = {dist.name: dist for dist in DeinterleaveDist} _INTERLEAVE_DIST_SYMBOLS = {dist.name: dist for dist in InterleaveDist} +_PREDICATE_DIST_SYMBOLS = {dist.name: dist for dist in PredicateDist} +_PREDICATE_PART_SYMBOLS = {part.name: part for part in PredicatePart} _STRIDE_MODE_SYMBOLS = {mode.name: mode for mode in StrideMode} +_DIRECT_PREDICATE_PATTERN_OPS = { + "pset_b8", + "pset_b16", + "pset_b32", + "pge_b8", + "pge_b16", + "pge_b32", +} +_DIRECT_PREDICATE_TAIL_OPS = {"plt_b8", "plt_b16", "plt_b32"} +_PREDICATE_MEMORY_EXPR_OPS = {"plds", "pld", "pldi"} +_PREDICATE_MEMORY_STMT_OPS = {"pst", "psti"} +_PREDICATE_BINARY_LOGIC_OPS = {"pand", "por", "pxor"} +_PREDICATE_REARRANGEMENT_OPS = {"pdintlv_b8", "pintlv_b16"} _UNARY_VECTOR_OPS = { "vabs", "vrelu", @@ -187,14 +204,18 @@ "copy_ubuf_to_ubuf", } _COMPARE_SELECT_OPS = {"vcmp", "vcmps", "vsel", "vselr", "vselrv2"} -_PREDICATE_MOVEMENT_OPS = {"pnot", "psel", "ppack", "punpack"} +_PREDICATE_MOVEMENT_OPS = {"pnot", "psel", "ppack", "punpack"} | _PREDICATE_BINARY_LOGIC_OPS _CARRY_OPS = {"vaddc", "vsubc", "vaddcs", "vsubcs"} -_REARRANGEMENT_OPS = {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"} +_REARRANGEMENT_OPS = {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"} | _PREDICATE_REARRANGEMENT_OPS _ADVANCED_VECTOR_ACTIVITY_OPS = ( _COMPARE_SELECT_OPS | _PREDICATE_MOVEMENT_OPS | _CARRY_OPS | _REARRANGEMENT_OPS + | _DIRECT_PREDICATE_PATTERN_OPS + | _DIRECT_PREDICATE_TAIL_OPS + | _PREDICATE_MEMORY_EXPR_OPS + | _PREDICATE_MEMORY_STMT_OPS | {"vcvt", "vmrgsort4"} ) _VECTOR_MEMORY_EXPR_OPS = {"vlds", "vldas", "vldus", "vldx2", "vsld"} @@ -1257,7 +1278,7 @@ def _is_vector_memory_stmt_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) and expr.namespace == "pto" - and expr.name in _VECTOR_MEMORY_STMT_OPS + and expr.name in (_VECTOR_MEMORY_STMT_OPS | _PREDICATE_MEMORY_STMT_OPS) ) def _is_sync_call(self, expr: FrontendExprNode) -> bool: @@ -1441,6 +1462,77 @@ def _analyze_vector_memory_stmt_call( type=SemanticMetaType(kind="void"), ) + if expr.name == "pst": + if len(expr.args) in {2, 3} and isinstance(expr.args[1], FrontendSubscriptExpr): + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.pst destination", + ) + dist_expr = ( + self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + if len(expr.args) == 3 + else None + ) + else: + if len(expr.args) not in {3, 4}: + raise TypeError("pto.pst expects value, destination, offset, and optional dist in TileLang DSL v1") + analyzed = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + value, destination, offset = analyzed[:3] + indices = (offset,) + dist_expr = analyzed[3] if len(analyzed) == 4 else None + mask = self._require_mask_expr(value, "pto.pst value") + self._require_vector_pointer_expr(destination, "pto.pst destination") + if isinstance(destination.type, SemanticTileType): + if len(indices) not in {1, 2}: + raise TypeError("pto.pst Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + else: + if len(indices) != 1: + raise TypeError("pto.pst pointer syntax expects exactly one offset operand in TileLang DSL v1") + for index in indices: + self._require_index_typed_expr(index) + dist = self._normalize_predicate_dist( + dist_expr, + "pto.pst dist", + allowed={"NORM", "PK"}, + default="NORM", + ) + return SemanticCallExpr( + namespace="pto", + name="pst", + args=(value, destination, *indices, dist), + type=SemanticMetaType(kind="void"), + ) + + if expr.name == "psti": + if len(expr.args) not in {3, 4}: + raise TypeError("pto.psti expects value, destination, imm_offset, and optional dist in TileLang DSL v1") + analyzed = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + value = analyzed[0] + destination = self._require_pointer_expr(analyzed[1], "pto.psti destination", memory_space="ub") + self._require_mask_expr(value, "pto.psti value") + self._require_i32_expr(analyzed[2], "pto.psti offset") + dist = self._normalize_predicate_dist( + analyzed[3] if len(analyzed) == 4 else None, + "pto.psti dist", + allowed={"NORM", "PK"}, + default="NORM", + ) + return SemanticCallExpr( + namespace="pto", + name="psti", + args=(value, destination, analyzed[2], dist), + type=SemanticMetaType(kind="void"), + ) + if expr.name == "vsst": if len(expr.args) == 3: value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) @@ -2126,7 +2218,9 @@ def _bind_assignment_target( for axis in range(value.type.rank) ) elif isinstance(value, SemanticCallExpr): - tuple_values = value.args + tuple_values = tuple( + SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types + ) else: tuple_values = tuple( SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types @@ -2587,6 +2681,23 @@ def _analyze_expr( ) stride = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) return self._analyze_vsld((base, *indices, stride)) + if ( + expr.namespace == "pto" + and expr.name == "plds" + and len(expr.args) in {1, 2} + and isinstance(expr.args[0], FrontendSubscriptExpr) + ): + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.plds source", + ) + extra_args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args[1:] + ) + return self._analyze_predicate_memory_expr_op("plds", (base, *indices, *extra_args)) if expr.keywords: raise TypeError( f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " @@ -2729,6 +2840,24 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=dist, type=SemanticMetaType(kind="interleave_dist"), ) + if expr.namespace in {"PredicateDist", "pto.PredicateDist"}: + dist = _PREDICATE_DIST_SYMBOLS.get(expr.name) + if dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dist, + type=SemanticMetaType(kind="predicate_dist"), + ) + if expr.namespace in {"PredicatePart", "pto.PredicatePart"}: + part = _PREDICATE_PART_SYMBOLS.get(expr.name) + if part is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=part, + type=SemanticMetaType(kind="predicate_part"), + ) if expr.namespace in {"StrideMode", "pto.StrideMode"}: stride = _STRIDE_MODE_SYMBOLS.get(expr.name) if stride is not None: @@ -3184,6 +3313,10 @@ def _analyze_call_expr( ) if name == "make_mask": return self._analyze_make_mask(args) + if name in _DIRECT_PREDICATE_PATTERN_OPS: + return self._analyze_direct_predicate_pattern_op(name, args) + if name in _DIRECT_PREDICATE_TAIL_OPS: + return self._analyze_direct_predicate_tail_op(name, args) if name == "vlds": return self._analyze_vlds(args) if name == "vldas": @@ -3194,6 +3327,8 @@ def _analyze_call_expr( return self._analyze_vldx2(args) if name == "vsld": return self._analyze_vsld(args) + if name in _PREDICATE_MEMORY_EXPR_OPS: + return self._analyze_predicate_memory_expr_op(name, args) if name == "pstu": return self._analyze_pstu(args) if name == "vstu": @@ -3204,7 +3339,7 @@ def _analyze_call_expr( return self._analyze_vstur(args) if name in {"ppack", "punpack"}: return self._analyze_mask_part_op(name, args) - if name in {"pnot", "psel"}: + if name in {"pnot", "psel"} | _PREDICATE_BINARY_LOGIC_OPS: return self._analyze_mask_logic_op(name, args) if name in {"vcmp", "vcmps"}: return self._analyze_compare_op(name, args) @@ -3212,6 +3347,8 @@ def _analyze_call_expr( return self._analyze_select_op(name, args) if name in {"vaddc", "vsubc", "vaddcs", "vsubcs"}: return self._analyze_carry_op(name, args) + if name in _PREDICATE_REARRANGEMENT_OPS: + return self._analyze_predicate_rearrangement_op(name, args) if name in {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"}: return self._analyze_rearrangement_op(name, args) if name == "vcvt": @@ -3259,6 +3396,142 @@ def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: ), ) + def _mask_type_from_named_family(self, name: str) -> SemanticMaskType: + if name.endswith("_b8"): + return SemanticMaskType(granularity="b8") + if name.endswith("_b16"): + return SemanticMaskType(granularity="b16") + if name.endswith("_b32"): + return SemanticMaskType(granularity="b32") + raise TypeError(f"unsupported predicate family `{name}` in TileLang DSL v1") + + def _analyze_direct_predicate_pattern_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") + pattern = args[0] + if not ( + ( + isinstance(pattern, SemanticSymbolExpr) + and isinstance(pattern.type, SemanticMetaType) + and pattern.type.kind == "mask_pattern" + ) + or ( + isinstance(pattern, SemanticBindingRef) + and isinstance(pattern.type, SemanticMetaType) + and pattern.type.kind == "mask_pattern" + ) + ): + raise TypeError(f"pto.{name} pattern must be a pto.MaskPattern value in TileLang DSL v1") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=self._mask_type_from_named_family(name), + ) + + def _analyze_direct_predicate_tail_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") + self._require_tail_remaining_expr(args[0], f"pto.{name} scalar") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(self._mask_type_from_named_family(name), _I32_TYPE)), + ) + + def _predicate_mask_type_for_buffer( + self, + source: SemanticExpr, + context: str, + ) -> SemanticMaskType: + if isinstance(source.type, SemanticTileType): + tile = self._require_tile_expr(source, context) + return SemanticMaskType(granularity=self._mask_granularity_for_dtype(tile.type.element_dtype)) + pointer = self._require_pointer_expr(source, context, memory_space="ub") + return SemanticMaskType(granularity=self._mask_granularity_for_dtype(pointer.type.element_dtype)) + + def _analyze_predicate_memory_expr_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "plds": + if len(args) < 2: + raise TypeError("pto.plds expects source, offset, and optional dist in TileLang DSL v1") + source = args[0] + if isinstance(source.type, SemanticTileType): + source = self._require_tile_expr(source, "pto.plds source") + else: + source = self._require_pointer_expr(source, "pto.plds source", memory_space="ub") + trailing = args[1:] + has_dist = ( + len(trailing) > 1 + and isinstance(trailing[-1].type, SemanticMetaType) + and trailing[-1].type.kind in {"predicate_dist", "string"} + ) + index_args = trailing[:-1] if has_dist else trailing + dist_arg = trailing[-1] if has_dist else None + for index in index_args: + self._require_index_typed_expr(index) + if isinstance(source.type, SemanticPtrType) and len(index_args) != 1: + raise TypeError("pto.plds pointer syntax expects exactly one offset operand in TileLang DSL v1") + dist = self._normalize_predicate_dist( + dist_arg, + "pto.plds dist", + allowed={"NORM", "US", "DS"}, + default="NORM", + ) + return SemanticCallExpr( + namespace="pto", + name=name, + args=(source, *index_args, dist), + type=self._predicate_mask_type_for_buffer(source, "pto.plds source"), + ) + + if name == "pld": + if len(args) != 3: + raise TypeError("pto.pld expects source, offset, and dist in TileLang DSL v1") + source = self._require_pointer_expr(args[0], "pto.pld source", memory_space="ub") + self._require_index_typed_expr(args[1]) + dist = self._normalize_predicate_dist( + args[2], + "pto.pld dist", + allowed={"NORM", "US", "DS"}, + default="NORM", + ) + return SemanticCallExpr( + namespace="pto", + name=name, + args=(source, args[1], dist), + type=self._predicate_mask_type_for_buffer(source, "pto.pld source"), + ) + + if len(args) != 3: + raise TypeError("pto.pldi expects source, imm_offset, and dist in TileLang DSL v1") + source = self._require_pointer_expr(args[0], "pto.pldi source", memory_space="ub") + self._require_i32_expr(args[1], "pto.pldi offset") + dist = self._normalize_predicate_dist( + args[2], + "pto.pldi dist", + allowed={"NORM", "US", "DS"}, + default="NORM", + ) + return SemanticCallExpr( + namespace="pto", + name=name, + args=(source, args[1], dist), + type=self._predicate_mask_type_for_buffer(source, "pto.pldi source"), + ) + def _analyze_scalar_constructor( self, name: str, @@ -3777,8 +4050,8 @@ def _analyze_mask_part_op( if len(args) != 2: raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") mask = self._require_mask_expr(args[0], f"pto.{name} mask") - self._require_string_expr(args[1], f"pto.{name} part") - return SemanticCallExpr(namespace="pto", name=name, args=args, type=mask) + part = self._normalize_predicate_part(args[1], f"pto.{name} part") + return SemanticCallExpr(namespace="pto", name=name, args=(args[0], part), type=mask) def _analyze_mask_logic_op( self, @@ -3792,13 +4065,23 @@ def _analyze_mask_logic_op( mask = self._require_mask_expr(args[1], "pto.pnot mask") self._require_matching_mask_types(value, mask, "pto.pnot") return SemanticCallExpr(namespace="pto", name=name, args=args, type=value) + if name == "psel": + if len(args) != 3: + raise TypeError("pto.psel expects exactly 3 positional arguments in TileLang DSL") + src0 = self._require_mask_expr(args[0], "pto.psel src0") + src1 = self._require_mask_expr(args[1], "pto.psel src1") + mask = self._require_mask_expr(args[2], "pto.psel mask") + self._require_matching_mask_types(src0, src1, "pto.psel") + self._require_matching_mask_types(src0, mask, "pto.psel") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + if len(args) != 3: - raise TypeError("pto.psel expects exactly 3 positional arguments in TileLang DSL") - src0 = self._require_mask_expr(args[0], "pto.psel src0") - src1 = self._require_mask_expr(args[1], "pto.psel src1") - mask = self._require_mask_expr(args[2], "pto.psel mask") - self._require_matching_mask_types(src0, src1, "pto.psel") - self._require_matching_mask_types(src0, mask, "pto.psel") + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") + src0 = self._require_mask_expr(args[0], f"pto.{name} src0") + src1 = self._require_mask_expr(args[1], f"pto.{name} src1") + self._require_matching_mask_types(src0, src1, f"pto.{name}") + mask = self._require_mask_expr(args[2], f"pto.{name} mask") + self._require_matching_mask_types(src0, mask, f"pto.{name}") return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) def _analyze_compare_op( @@ -3928,6 +4211,26 @@ def _analyze_rearrangement_op( self._require_string_expr(args[2], f"pto.{name} part") return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) + def _analyze_predicate_rearrangement_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") + lhs = self._require_mask_expr(args[0], f"pto.{name} lhs") + rhs = self._require_mask_expr(args[1], f"pto.{name} rhs") + self._require_matching_mask_types(lhs, rhs, f"pto.{name}") + expected = "b8" if name == "pdintlv_b8" else "b16" + if lhs.granularity != expected: + raise TypeError(f"pto.{name} requires {expected} mask operands in TileLang DSL v1") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, rhs)), + ) + def _analyze_vcvt(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 3: raise TypeError("pto.vcvt expects exactly 3 positional arguments in TileLang DSL") @@ -4192,6 +4495,56 @@ def _normalize_interleave_dist(self, expr: SemanticExpr, context: str) -> Semant return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) + def _normalize_predicate_dist( + self, + expr: SemanticExpr | None, + context: str, + *, + allowed: set[str], + default: str, + ) -> SemanticLiteralExpr: + if expr is None: + return SemanticLiteralExpr(value=default, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.value, PredicateDist) + ): + value = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.binding.value, PredicateDist) + ): + value = expr.binding.value.value + else: + value = self._require_string_expr(expr, context) + if value not in allowed: + supported = ", ".join(sorted(allowed)) + raise TypeError(f"{context} must be one of {supported} in TileLang DSL v1") + return SemanticLiteralExpr(value=value, type=SemanticMetaType(kind="string")) + + def _normalize_predicate_part(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_part" + and isinstance(expr.value, PredicatePart) + ): + value = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_part" + and isinstance(expr.binding.value, PredicatePart) + ): + value = expr.binding.value.value + else: + raise TypeError(f"{context} must be PredicatePart.LOWER or PredicatePart.HIGHER in TileLang DSL v1") + return SemanticLiteralExpr(value=value, type=SemanticMetaType(kind="string")) + def _normalize_stride_mode(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: if ( isinstance(expr, SemanticSymbolExpr) diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index efb16b9b0..257b50c0f 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -123,15 +123,34 @@ ADVANCED_VECSCOPE_PTO_CALLS = frozenset( { + "pset_b8", + "pset_b16", + "pset_b32", + "pge_b8", + "pge_b16", + "pge_b32", + "plt_b8", + "plt_b16", + "plt_b32", + "plds", + "pld", + "pldi", + "pst", + "psti", "vcmp", "vcmps", "vsel", "vselr", "vselrv2", "pnot", + "pand", + "por", + "pxor", "psel", "ppack", "punpack", + "pdintlv_b8", + "pintlv_b16", "vaddc", "vsubc", "vaddcs", @@ -325,6 +344,8 @@ def get_pto_call_tier(call_name: str) -> str: "BLayout": BASIC_TIER, "SLayout": BASIC_TIER, "PadValue": BASIC_TIER, + "PredicateDist": ADVANCED_TIER, + "PredicatePart": ADVANCED_TIER, "constexpr": BASIC_TIER, "pto.constexpr": BASIC_TIER, "tile[start:]": BASIC_TIER, diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 0d21acfe0..f60cf4875 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -180,6 +180,18 @@ class InterleaveDist(str, Enum): B32 = "INTLV_B32" +class PredicateDist(str, Enum): + NORM = "NORM" + US = "US" + DS = "DS" + PK = "PK" + + +class PredicatePart(str, Enum): + LOWER = "LOWER" + HIGHER = "HIGHER" + + class StrideMode(str, Enum): S3_B16 = "STRIDE_S3_B16" S4_B64 = "STRIDE_S4_B64" @@ -414,6 +426,8 @@ def tile_layout_descriptor( "OrderMode", "DeinterleaveDist", "InterleaveDist", + "PredicateDist", + "PredicatePart", "StrideMode", "TileConfig", "TileSpecialization", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 161629d16..5ab900e55 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -88,6 +88,8 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "OrderMode")) self.assertTrue(hasattr(pto, "DeinterleaveDist")) self.assertTrue(hasattr(pto, "InterleaveDist")) + self.assertTrue(hasattr(pto, "PredicateDist")) + self.assertTrue(hasattr(pto, "PredicatePart")) self.assertTrue(hasattr(pto, "StrideMode")) self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) @@ -102,6 +104,10 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") self.assertEqual(pto.DeinterleaveDist.B32.value, "DINTLV_B32") self.assertEqual(pto.InterleaveDist.B16.value, "INTLV_B16") + self.assertEqual(pto.PredicateDist.NORM.value, "NORM") + self.assertEqual(pto.PredicateDist.PK.value, "PK") + self.assertEqual(pto.PredicatePart.LOWER.value, "LOWER") + self.assertEqual(pto.PredicatePart.HIGHER.value, "HIGHER") self.assertEqual(pto.StrideMode.S4_B64.value, "STRIDE_S4_B64") @@ -154,10 +160,16 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vci"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vpack"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vsort32"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.pset_b32"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.plds"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.pand"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.pintlv_b16"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.pstu"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.vstu"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.vstus"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.vstur"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("PredicateDist"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("PredicatePart"), ADVANCED_TIER) self.assertEqual(get_feature_tier("PadMode"), BASIC_TIER) self.assertEqual(get_feature_tier("VRegType"), BASIC_TIER) self.assertEqual(get_feature_tier("MaskType"), BASIC_TIER) @@ -1794,6 +1806,122 @@ def kernel(src: pto.Tile, dst: pto.Tile): self.assertRegex(text, r"= pto\.vstu %align1_\d+, %c0, %vec_\d+, %ub_dst_\d+, \"MODE_ZEROING\"") self.assertRegex(text, r"= pto\.vstus %align2_\d+, %(?:c16_i32|tmp_\d+), %vec_\d+, %ub_dst_\d+, \"MODE_ZEROING\"") + def test_advanced_direct_predicate_surfaces_lower_with_typed_families(self) -> None: + @pto.vkernel(op="predicate_surface_unique", dtypes=[(pto.f32, pto.i32)], advanced=True) + def kernel(tile: pto.Tile, remaining: pto.i32): + all8 = pto.pset_b8(pto.PAT.ALL) + all16 = pto.pset_b16(pto.PAT.ALL) + all32 = pto.pset_b32(pto.PAT.ALL) + tail8 = pto.pge_b8(pto.PAT.ALL) + tail16 = pto.pge_b16(pto.PAT.ALL) + tail32 = pto.pge_b32(pto.PAT.ALL) + mask8, rem8 = pto.plt_b8(remaining) + mask16, rem16 = pto.plt_b16(remaining) + mask32, rem32 = pto.plt_b32(remaining) + gate32 = all32 + and_mask = pto.pand(all32, tail32, gate32) + or_mask = pto.por(and_mask, all32, gate32) + xor_mask = pto.pxor(or_mask, tail32, gate32) + low8, high8 = pto.pdintlv_b8(all8, tail8) + low16, high16 = pto.pintlv_b16(all16, tail16) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn(' = pto.pset_b8 "PAT_ALL"', text) + self.assertIn(' = pto.pset_b16 "PAT_ALL"', text) + self.assertIn(' = pto.pset_b32 "PAT_ALL"', text) + self.assertIn(' = pto.pge_b8 "PAT_ALL"', text) + self.assertIn(' = pto.pge_b16 "PAT_ALL"', text) + self.assertIn(' = pto.pge_b32 "PAT_ALL"', text) + self.assertIn(" = pto.plt_b8 ", text) + self.assertIn(" = pto.plt_b16 ", text) + self.assertIn(" = pto.plt_b32 ", text) + self.assertIn(" = pto.pand ", text) + self.assertIn(" = pto.por ", text) + self.assertIn(" = pto.pxor ", text) + self.assertIn(" = pto.pdintlv_b8 ", text) + self.assertIn(" = pto.pintlv_b16 ", text) + self.assertEqual(text.count('pto.pset_b32 "PAT_ALL"'), 1) + + def test_advanced_predicate_memory_surfaces_lower_with_dist_enums(self) -> None: + @pto.vkernel(op="predicate_memory_surface_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(tile: pto.Tile): + ub = tile.as_ptr() + align = pto.vldas(ub) + all_mask = pto.pset_b32(pto.PAT.ALL) + loaded0 = pto.plds(ub, 0) + loaded1 = pto.pld(ub, 0, pto.PredicateDist.NORM) + loaded2 = pto.pldi(ub, pto.i32(4), pto.PredicateDist.US) + pto.psts(all_mask, ub, 0) + pto.pst(loaded0, ub, 0) + pto.psti(loaded1, ub, pto.i32(8), pto.PredicateDist.PK) + next_align, next_base = pto.pstu(align, all_mask, ub) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn('{dist = "NORM"}', text) + self.assertIn(" = pto.pld ", text) + self.assertIn(', "NORM" : !pto.ptr, index -> !pto.mask', text) + self.assertIn(" = pto.pldi ", text) + self.assertIn(', "US" : !pto.ptr, i32 -> !pto.mask', text) + self.assertIn("pto.psts ", text) + self.assertIn(': !pto.mask, !pto.ptr', text) + self.assertIn("pto.pst ", text) + self.assertIn(', "NORM" : !pto.mask, !pto.ptr, index', text) + self.assertIn("pto.psti ", text) + self.assertIn(', "PK" : !pto.mask, !pto.ptr, i32', text) + self.assertIn("pto.pstu", text) + + def test_pld_and_pldi_require_explicit_predicate_dist(self) -> None: + @pto.vkernel(op="predicate_memory_missing_dist", dtypes=[(pto.f32,)], advanced=True) + def kernel(tile: pto.Tile): + ub = tile.as_ptr() + loaded1 = pto.pld(ub, 0) + loaded2 = pto.pldi(ub, pto.i32(4)) + pto.pst(loaded1, ub, 0) + pto.pst(loaded2, ub, 0) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + message = str(ctx.exception) + self.assertTrue( + "pto.pld expects source, offset, and dist" in message + or "pto.pldi expects source, imm_offset, and dist" in message + ) + + def test_pand_por_pxor_require_explicit_gate_mask(self) -> None: + @pto.vkernel(op="predicate_logic_missing_gate", dtypes=[(pto.f32, pto.i32)], advanced=True) + def kernel(tile: pto.Tile, remaining: pto.i32): + all32 = pto.pset_b32(pto.PAT.ALL) + tail32 = pto.pge_b32(pto.PAT.ALL) + and_mask = pto.pand(all32, tail32) + or_mask = pto.por(and_mask, all32) + xor_mask = pto.pxor(or_mask, tail32) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + self.assertIn("expects exactly 3 positional arguments", str(ctx.exception)) + def test_tail_make_mask_lowers_to_typed_plt_and_updates_remaining(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.i32)], advanced=True) def kernel(tile: pto.Tile, remaining: pto.i32): @@ -3363,8 +3491,8 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): cmp_scalar_mask = pto.vcmps(lhs, scalar, all_mask, "gt") negated = pto.pnot(cmp_mask, all_mask) picked = pto.psel(cmp_mask, negated, cmp_scalar_mask) - packed = pto.ppack(picked, "PART_EVEN") - unpacked = pto.punpack(packed, "PART_ODD") + packed = pto.ppack(picked, pto.PredicatePart.LOWER) + unpacked = pto.punpack(packed, pto.PredicatePart.HIGHER) sum_vec, carry_mask = pto.vaddc(lhs, rhs, all_mask) diff_vec, borrow_mask = pto.vsubc(lhs, rhs, all_mask) sum_with_carry, carry_mask2 = pto.vaddcs(sum_vec, diff_vec, carry_mask, all_mask) @@ -3398,9 +3526,9 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): self.assertIn(" = pto.pnot ", text) self.assertIn(" = pto.psel ", text) self.assertIn(' = pto.ppack ', text) - self.assertIn('"PART_EVEN"', text) + self.assertIn('"LOWER"', text) self.assertIn(' = pto.punpack ', text) - self.assertIn('"PART_ODD"', text) + self.assertIn('"HIGHER"', text) self.assertRegex( text, r"%sum_vec_\d+, %carry_mask_\d+ = pto\.vaddc %lhs_\d+, %rhs_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", @@ -3432,6 +3560,29 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): self.assertIn(" = pto.vselrv2 ", text) self.assertIn("pto.vsts ", text) + def test_ppack_and_punpack_require_predicate_part_enum(self) -> None: + @pto.vkernel(op="predicate_part_typecheck", dtypes=[(pto.i32, pto.i32, pto.i32)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + lhs = pto.vlds(src0[0, 0:]) + rhs = pto.vlds(src1[0, 0:]) + cmp_mask = pto.vcmp(lhs, rhs, all_mask, "lt") + packed = pto.ppack(cmp_mask, "LOWER") + unpacked = pto.punpack(packed, "HIGHER") + pto.vsts(pto.vsel(lhs, rhs, unpacked), dst[0, 0:], all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + self.assertIn("PredicatePart.LOWER or PredicatePart.HIGHER", str(ctx.exception)) + def test_elementwise_kernel_positive_regression_covers_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): From 70026b15476e2863b15a22fe0b07ebd965f4e29b Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 11 Apr 2026 22:52:20 +0800 Subject: [PATCH 037/192] Support keyword arguments for vdup and vci operations in TileLang DSL --- .../11-vector-arithmetic-operations.md | 3 + .../python/tilelang_dsl/frontend_ast.py | 4 +- tilelang-dsl/python/tilelang_dsl/kernel.py | 2 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 90 ++++++++++++++++++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 6 +- 5 files changed, 96 insertions(+), 9 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 872bfb80b..817eeeb94 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1385,4 +1385,7 @@ Type conversion and specialized operations. ```python # Generate ascending indices starting from 0 indices = pto.vci(pto.i32(0), OrderMode.ASC) + +# Keyword form for the optional order argument is also supported +indices_kw = pto.vci(pto.i32(0), order=OrderMode.ASC) ``` diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 7f1b80625..5ae619606 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -609,6 +609,8 @@ def _collect_reachable_inline_procs( } _DMA_CALL_KEYWORDS: dict[str, frozenset[str]] = { + "vdup": frozenset({"position"}), + "vci": frozenset({"order"}), "set_loop2_stride_outtoub": frozenset({"src_stride", "dst_stride"}), "set_loop1_stride_outtoub": frozenset({"src_stride", "dst_stride"}), "set_loop_size_outtoub": frozenset({"loop1", "loop2"}), @@ -728,7 +730,7 @@ def _build_call_keywords( raise context.error( node, f"`{call_name}` does not support keyword arguments in TileLang DSL v1; " - "no public call surface currently accepts them", + "keyword arguments are only supported on selected public call surfaces", ) seen: set[str] = set() diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index fcae842ba..784f328a3 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -372,7 +372,7 @@ def _validate_call_keywords(self, node: ast.Call) -> None: raise self.source_info.error( node, f"`{call_name}` does not support keyword arguments in TileLang DSL v1; " - "no public call surface currently accepts them", + "keyword arguments are only supported on selected public call surfaces", ) seen: set[str] = set() diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 0f9ceba04..d911b4337 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -2699,10 +2699,10 @@ def _analyze_expr( ) return self._analyze_predicate_memory_expr_op("plds", (base, *indices, *extra_args)) if expr.keywords: - raise TypeError( - f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " - "carries keyword arguments, but semantic keyword handling is not implemented " - "in TileLang DSL v1 yet" + return self._analyze_keyword_call_expr( + expr, + env, + allow_outer_lookup=allow_outer_lookup, ) args = tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) @@ -3371,6 +3371,33 @@ def _analyze_call_expr( return self._analyze_ternary_vector_op(name, args) raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") + def _analyze_keyword_call_expr( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + + if expr.namespace == "pto" and expr.name == "vdup": + return self._analyze_vdup_keyword_call(args, analyzed_keywords) + if expr.namespace == "pto" and expr.name == "vci": + return self._analyze_vci_keyword_call(args, analyzed_keywords) + + raise TypeError( + f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " + "carries keyword arguments, but semantic keyword handling is not implemented " + "in TileLang DSL v1 yet" + ) + def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 2: raise TypeError("pto.make_mask expects exactly 2 positional arguments in TileLang DSL v1") @@ -3939,6 +3966,61 @@ def _analyze_broadcast_vector_op( raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") + def _analyze_vdup_keyword_call( + self, + args: tuple[SemanticExpr, ...], + keywords: dict[str, SemanticExpr], + ) -> SemanticExpr: + if not args: + raise TypeError("pto.vdup expects at least 1 positional argument in TileLang DSL v1") + if len(args) > 2: + raise TypeError("pto.vdup expects 1 or 2 operands in TileLang DSL v1") + if len(args) == 2 and "position" in keywords: + raise TypeError("pto.vdup got multiple values for argument `position` in TileLang DSL v1") + + value = args[0] + if isinstance(value.type, SemanticVRegType): + vec_type = value.type + else: + vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vdup input") + position = self._normalize_position_mode( + keywords.get("position", args[1] if len(args) == 2 else None), + "pto.vdup position", + ) + return SemanticCallExpr( + namespace="pto", + name="vdup", + args=(value, position), + type=vec_type, + ) + + def _analyze_vci_keyword_call( + self, + args: tuple[SemanticExpr, ...], + keywords: dict[str, SemanticExpr], + ) -> SemanticExpr: + if not args: + raise TypeError("pto.vci expects at least 1 positional argument in TileLang DSL v1") + if len(args) > 2: + raise TypeError("pto.vci expects 1 or 2 operands in TileLang DSL v1") + if len(args) == 2 and "order" in keywords: + raise TypeError("pto.vci got multiple values for argument `order` in TileLang DSL v1") + + index = self._require_scalar_or_index_expr(args[0], "pto.vci index") + index_dtype = i32 if isinstance(index.type, SemanticIndexType) else index.type.dtype + if index_dtype.name not in {"i8", "i16", "i32"}: + raise TypeError("pto.vci index only supports i8/i16/i32 in TileLang DSL v1") + order = self._normalize_order_mode( + keywords.get("order", args[1] if len(args) == 2 else None), + "pto.vci order", + ) + return SemanticCallExpr( + namespace="pto", + name="vci", + args=(index, order), + type=self._vreg_type_for_dtype(index_dtype), + ) + def _analyze_unary_vector_op( self, name: str, diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 5ab900e55..8635f7e1d 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -1129,7 +1129,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): pto.vlds(tile, offset=0) return None - self.assertIn("no public call surface currently accepts them", str(ctx.exception)) + self.assertIn("keyword arguments are only supported on selected public call surfaces", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) def test_frontend_rewrites_template_slot_to_selected_real_op(self) -> None: @@ -2395,9 +2395,9 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): broadcast = pto.vbr(seed) dup_from_vec = pto.vdup(vec0) - dup_from_scalar = pto.vdup(seed, pto.PositionMode.LOWEST) + dup_from_scalar = pto.vdup(seed, position=pto.PositionMode.LOWEST) idx0 = pto.vci(seed) - idx1 = pto.vci(seed, pto.OrderMode.ASC) + idx1 = pto.vci(seed, order=pto.OrderMode.ASC) out = pto.vadd(broadcast, dup_from_vec, all_mask) out = pto.vadd(out, dup_from_scalar, all_mask) From 64b295bee483f67360a276ca2d9bb7fb4fe2fd21 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 13 Apr 2026 11:12:30 +0800 Subject: [PATCH 038/192] Support more dma operations --- tilelang-dsl/docs/unsupported-features.md | 8 - tilelang-dsl/python/tilelang_dsl/__init__.py | 2 + .../python/tilelang_dsl/frontend_ast.py | 23 ++- tilelang-dsl/python/tilelang_dsl/lowering.py | 97 ++++++++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 166 ++++++++++++++++-- .../python/tilelang_dsl/support_matrix.py | 11 +- tilelang-dsl/python/tilelang_dsl/types.py | 5 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 75 +++++++- 8 files changed, 360 insertions(+), 27 deletions(-) diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md index 94a085a9e..3a288c85e 100644 --- a/tilelang-dsl/docs/unsupported-features.md +++ b/tilelang-dsl/docs/unsupported-features.md @@ -48,14 +48,6 @@ The following guide surfaces are not implemented as public APIs: - `pto.tile_with_strides(...)` - `pto.tile_config(...)` -### Missing Sync/Buffer Control Ops - -These documented surfaces are not accepted by the current frontend: - -- `pto.get_buf(...)` -- `pto.rls_buf(...)` - - ### Missing Vector Load/Store Families The previously missing vector-memory surfaces diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index d9525c342..5a957abdc 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -26,6 +26,7 @@ AnyType, AlignType, BLayout, + BarrierType, DeinterleaveDist, EVENT, PIPE, @@ -108,6 +109,7 @@ "EVENT", "MaskPattern", "PAT", + "BarrierType", "PadMode", "PositionMode", "OrderMode", diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 5ae619606..f0a5dc96c 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -790,7 +790,28 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo ) if isinstance(node, ast.Attribute): path = _attribute_path(node) - if path is not None and path[0] in {"pto", "PAT", "PIPE", "Pipe", "EVENT", "Event"} and len(path) >= 2: + if path is not None and path[0] in { + "pto", + "PAT", + "PIPE", + "EVENT", + "MaskPattern", + "Pipe", + "Event", + "BarrierType", + "MemorySpace", + "PadMode", + "PositionMode", + "OrderMode", + "BLayout", + "SLayout", + "PadValue", + "DeinterleaveDist", + "InterleaveDist", + "PredicateDist", + "PredicatePart", + "StrideMode", + } and len(path) >= 2: return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) return FrontendAttributeExpr(base=_build_expr(node.value, context), attr=node.attr) if isinstance(node, ast.Subscript): diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 4180274a6..c962c1e33 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -26,19 +26,25 @@ SemanticExpr, SemanticExprStmt, SemanticForStmt, + SemanticGetBufStmt, SemanticIfStmt, SemanticIndexType, SemanticIfResult, SemanticKernel, SemanticLiteralExpr, + SemanticMemBarStmt, SemanticLowLevelCopyStmt, SemanticMaskType, SemanticMetaType, SemanticPipeBarrierStmt, SemanticPtrType, SemanticReturnStmt, + SemanticRlsBufStmt, SemanticScalarType, + SemanticSetCrossCoreStmt, SemanticSetFlagStmt, + SemanticSetIntraBlockStmt, + SemanticSetIntraCoreStmt, SemanticShapeType, SemanticStmt, SemanticVecscopeStmt, @@ -54,7 +60,9 @@ SemanticTupleType, SemanticVRegType, SemanticVectorStoreStmt, + SemanticWaitFlagDevStmt, SemanticWaitFlagStmt, + SemanticWaitIntraCoreStmt, ) from .types import MaskPattern, MemorySpace, ScalarType, TileConfig, get_lanes, tile_strides @@ -408,6 +416,22 @@ def _render_stmt( ] if isinstance(stmt, SemanticPipeBarrierStmt): return [self._indent(indent) + f"pto.barrier #pto.pipe<{stmt.pipe}>"] + if isinstance(stmt, SemanticGetBufStmt): + return self._render_buffer_sync_stmt("get_buf", stmt.pipe, stmt.buf_id, stmt.mode, env, indent=indent) + if isinstance(stmt, SemanticRlsBufStmt): + return self._render_buffer_sync_stmt("rls_buf", stmt.pipe, stmt.buf_id, stmt.mode, env, indent=indent) + if isinstance(stmt, SemanticMemBarStmt): + return [self._indent(indent) + f'pto.mem_bar "{stmt.barrier_type}"'] + if isinstance(stmt, SemanticSetCrossCoreStmt): + return self._render_i64_pair_stmt("set_cross_core", stmt.core_id, stmt.event_id, env, indent=indent) + if isinstance(stmt, SemanticSetIntraBlockStmt): + return self._render_i64_pair_stmt("set_intra_block", stmt.block_id, stmt.event_id, env, indent=indent) + if isinstance(stmt, SemanticSetIntraCoreStmt): + return self._render_i32_stmt("set_intra_core", stmt.config, env, indent=indent) + if isinstance(stmt, SemanticWaitFlagDevStmt): + return self._render_i64_pair_stmt("wait_flag_dev", stmt.core_id, stmt.event_id, env, indent=indent) + if isinstance(stmt, SemanticWaitIntraCoreStmt): + return self._render_i64_pair_stmt("wait_intra_core", stmt.block_id, stmt.event_id, env, indent=indent) if isinstance(stmt, SemanticDmaConfigStmt): return self._render_dma_config(stmt, env, indent=indent) if isinstance(stmt, SemanticLowLevelCopyStmt): @@ -445,6 +469,59 @@ def _render_dma_config( ) return lines + def _render_buffer_sync_stmt( + self, + name: str, + pipe: str, + buf_id: SemanticExpr, + mode: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + rendered_buf_id = self._lower_to_i64(buf_id, env, indent=indent, into=lines) + rendered_mode = self._lower_to_i64(mode, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f'pto.{name} "{pipe}", {rendered_buf_id.name}, {rendered_mode.name} : i64, i64' + ) + return lines + + def _render_i64_pair_stmt( + self, + name: str, + first: SemanticExpr, + second: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + rendered_first = self._lower_to_i64(first, env, indent=indent, into=lines) + rendered_second = self._lower_to_i64(second, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.{name} {rendered_first.name}, {rendered_second.name} : i64, i64" + ) + return lines + + def _render_i32_stmt( + self, + name: str, + value: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + rendered_value = self._lower_to_i32(value, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.{name} {rendered_value.name} : i32" + ) + return lines + def _render_low_level_copy( self, stmt: SemanticLowLevelCopyStmt, @@ -3098,6 +3175,26 @@ def _lower_to_i64( value = self._lower_expr(expr, env, indent=indent, into=into) return self._coerce_rendered_to_i64(value, indent=indent, into=into) + def _lower_to_i32( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i32": + return value + if isinstance(value.type, SemanticIndexType): + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.index_cast {value.name} : index to i32" + ) + return _RenderedValue(name=cast_name, type=_I32_TYPE) + raise NotImplementedError("expected an i32 or index operand during TileLang DSL v1 lowering") + def _coerce_rendered_to_i64( self, value: _RenderedValue, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index d911b4337..1a56f8311 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -48,6 +48,7 @@ ) from .types import ( BLayout, + BarrierType, DeinterleaveDist, Event, InterleaveDist, @@ -97,6 +98,7 @@ _PATTERN_SYMBOLS = {pattern.name: pattern for pattern in MaskPattern} _PIPE_SYMBOLS = {pipe.name: pipe for pipe in Pipe} _EVENT_SYMBOLS = {event.name: event for event in Event} +_BARRIER_TYPE_SYMBOLS = {barrier_type.name: barrier_type for barrier_type in BarrierType} _MEMORY_SPACE_SYMBOLS = {memory_space.name: memory_space for memory_space in MemorySpace} _B_LAYOUT_SYMBOLS = {layout.name: layout for layout in BLayout} _S_LAYOUT_SYMBOLS = {layout.name: layout for layout in SLayout} @@ -486,6 +488,54 @@ class SemanticPipeBarrierStmt(SemanticStmt): pipe: str +@dataclass(frozen=True) +class SemanticGetBufStmt(SemanticStmt): + pipe: str + buf_id: SemanticExpr + mode: SemanticExpr + + +@dataclass(frozen=True) +class SemanticRlsBufStmt(SemanticStmt): + pipe: str + buf_id: SemanticExpr + mode: SemanticExpr + + +@dataclass(frozen=True) +class SemanticMemBarStmt(SemanticStmt): + barrier_type: str + + +@dataclass(frozen=True) +class SemanticSetCrossCoreStmt(SemanticStmt): + core_id: SemanticExpr + event_id: SemanticExpr + + +@dataclass(frozen=True) +class SemanticSetIntraBlockStmt(SemanticStmt): + block_id: SemanticExpr + event_id: SemanticExpr + + +@dataclass(frozen=True) +class SemanticSetIntraCoreStmt(SemanticStmt): + config: SemanticExpr + + +@dataclass(frozen=True) +class SemanticWaitFlagDevStmt(SemanticStmt): + core_id: SemanticExpr + event_id: SemanticExpr + + +@dataclass(frozen=True) +class SemanticWaitIntraCoreStmt(SemanticStmt): + block_id: SemanticExpr + event_id: SemanticExpr + + @dataclass(frozen=True) class SemanticDmaConfigStmt(SemanticStmt): name: str @@ -1285,7 +1335,20 @@ def _is_sync_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) and expr.namespace == "pto" - and expr.name in {"set_flag", "wait_flag", "pipe_barrier", "barrier"} + and expr.name in { + "set_flag", + "wait_flag", + "pipe_barrier", + "barrier", + "get_buf", + "rls_buf", + "mem_bar", + "set_cross_core", + "set_intra_block", + "set_intra_core", + "wait_flag_dev", + "wait_intra_core", + } ) def _is_low_level_dma_call(self, expr: FrontendExprNode) -> bool: @@ -1663,6 +1726,40 @@ def _analyze_sync_stmt( if expr.name == "set_flag": return SemanticSetFlagStmt(src_pipe=src_pipe, dst_pipe=dst_pipe, event=event), dict(env) return SemanticWaitFlagStmt(src_pipe=src_pipe, dst_pipe=dst_pipe, event=event), dict(env) + if expr.name in {"get_buf", "rls_buf"}: + if len(args) not in {2, 3}: + raise TypeError(f"pto.{expr.name} expects 2 or 3 positional arguments in TileLang DSL v1") + pipe = self._require_sync_pipe(args[0], f"pto.{expr.name} pipe") + self._require_i64_like_expr(args[1], f"pto.{expr.name} buf_id") + mode = args[2] if len(args) == 3 else SemanticLiteralExpr(value=0, type=SemanticScalarType(dtype=i64)) + self._require_i64_like_expr(mode, f"pto.{expr.name} mode") + if expr.name == "get_buf": + return SemanticGetBufStmt(pipe=pipe, buf_id=args[1], mode=mode), dict(env) + return SemanticRlsBufStmt(pipe=pipe, buf_id=args[1], mode=mode), dict(env) + if expr.name == "mem_bar": + if len(args) != 1: + raise TypeError("pto.mem_bar expects exactly 1 positional argument in TileLang DSL v1") + barrier_type = self._require_barrier_type(args[0], "pto.mem_bar barrier_type") + return SemanticMemBarStmt(barrier_type=barrier_type), dict(env) + if expr.name in {"set_cross_core", "set_intra_block", "wait_flag_dev", "wait_intra_core"}: + if len(args) != 2: + raise TypeError(f"pto.{expr.name} expects exactly 2 positional arguments in TileLang DSL v1") + identifier = self._require_scalar_or_index_expr(args[0], f"pto.{expr.name} first operand") + self._require_i64_like_expr(identifier, f"pto.{expr.name} first operand") + event_id = self._normalize_event_id_expr(args[1], f"pto.{expr.name} event_id") + if expr.name == "set_cross_core": + return SemanticSetCrossCoreStmt(core_id=identifier, event_id=event_id), dict(env) + if expr.name == "set_intra_block": + return SemanticSetIntraBlockStmt(block_id=identifier, event_id=event_id), dict(env) + if expr.name == "wait_flag_dev": + return SemanticWaitFlagDevStmt(core_id=identifier, event_id=event_id), dict(env) + return SemanticWaitIntraCoreStmt(block_id=identifier, event_id=event_id), dict(env) + if expr.name == "set_intra_core": + if len(args) != 1: + raise TypeError("pto.set_intra_core expects exactly 1 positional argument in TileLang DSL v1") + config = self._require_scalar_or_index_expr(args[0], "pto.set_intra_core config") + self._require_i32_like_expr(config, "pto.set_intra_core config") + return SemanticSetIntraCoreStmt(config=config), dict(env) if expr.name in {"pipe_barrier", "barrier"}: if len(args) != 1: raise TypeError(f"pto.{expr.name} expects exactly 1 positional argument in TileLang DSL v1") @@ -2741,7 +2838,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pattern, type=SemanticMetaType(kind="mask_pattern"), ) - if expr.namespace in {"PIPE", "Pipe", "pto.PIPE", "pto.Pipe"}: + if expr.namespace in {"PIPE", "pto.PIPE", "Pipe", "pto.Pipe"}: pipe = _PIPE_SYMBOLS.get(expr.name) if pipe is not None: return SemanticSymbolExpr( @@ -2750,7 +2847,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pipe, type=SemanticMetaType(kind="pipe"), ) - if expr.namespace in {"EVENT", "Event", "pto.EVENT", "pto.Event"}: + if expr.namespace in {"EVENT", "pto.EVENT", "Event", "pto.Event"}: event = _EVENT_SYMBOLS.get(expr.name) if event is not None: return SemanticSymbolExpr( @@ -2759,7 +2856,16 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=event, type=SemanticMetaType(kind="event"), ) - if expr.namespace in {"pto.MemorySpace"}: + if expr.namespace in {"BarrierType", "pto.BarrierType"}: + barrier_type = _BARRIER_TYPE_SYMBOLS.get(expr.name) + if barrier_type is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=barrier_type, + type=SemanticMetaType(kind="barrier_type"), + ) + if expr.namespace in {"MemorySpace", "pto.MemorySpace"}: memory_space = _MEMORY_SPACE_SYMBOLS.get(expr.name) if memory_space is not None: return SemanticSymbolExpr( @@ -2768,7 +2874,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=memory_space, type=SemanticMetaType(kind="memory_space"), ) - if expr.namespace in {"pto.BLayout"}: + if expr.namespace in {"BLayout", "pto.BLayout"}: b_layout = _B_LAYOUT_SYMBOLS.get(expr.name) if b_layout is not None: return SemanticSymbolExpr( @@ -2777,7 +2883,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=b_layout, type=SemanticMetaType(kind="b_layout"), ) - if expr.namespace in {"pto.SLayout"}: + if expr.namespace in {"SLayout", "pto.SLayout"}: s_layout = _S_LAYOUT_SYMBOLS.get(expr.name) if s_layout is not None: return SemanticSymbolExpr( @@ -2786,7 +2892,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=s_layout, type=SemanticMetaType(kind="s_layout"), ) - if expr.namespace in {"pto.PadValue"}: + if expr.namespace in {"PadValue", "pto.PadValue"}: pad_value = _PAD_VALUE_SYMBOLS.get(expr.name) if pad_value is not None: return SemanticSymbolExpr( @@ -2795,7 +2901,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pad_value, type=SemanticMetaType(kind="pad_value"), ) - if expr.namespace in {"pto.PadMode"}: + if expr.namespace in {"PadMode", "pto.PadMode"}: pad_mode = _PAD_MODE_SYMBOLS.get(expr.name) if pad_mode is not None: return SemanticSymbolExpr( @@ -2804,7 +2910,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pad_mode, type=SemanticMetaType(kind="pad_mode"), ) - if expr.namespace in {"pto.PositionMode"}: + if expr.namespace in {"PositionMode", "pto.PositionMode"}: position_mode = _POSITION_MODE_SYMBOLS.get(expr.name) if position_mode is not None: return SemanticSymbolExpr( @@ -2813,7 +2919,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=position_mode, type=SemanticMetaType(kind="position_mode"), ) - if expr.namespace in {"pto.OrderMode"}: + if expr.namespace in {"OrderMode", "pto.OrderMode"}: order_mode = _ORDER_MODE_SYMBOLS.get(expr.name) if order_mode is not None: return SemanticSymbolExpr( @@ -4645,6 +4751,13 @@ def _require_i1_expr(self, expr: SemanticExpr, context: str) -> None: if scalar.dtype != i1: raise TypeError(f"{context} must be an i1 value in TileLang DSL") + def _require_i32_like_expr(self, expr: SemanticExpr, context: str) -> None: + if isinstance(expr.type, SemanticIndexType): + return + scalar = self._require_scalar_expr(expr, context) + if scalar.dtype != i32: + raise TypeError(f"{context} must be an i32 or index value in TileLang DSL") + def _require_i64_like_expr(self, expr: SemanticExpr, context: str) -> None: if isinstance(expr.type, SemanticIndexType): return @@ -4789,6 +4902,31 @@ def _require_sync_event(self, expr: SemanticExpr, context: str) -> str: return expr.value raise TypeError(f"{context} must be an EVENT symbol or event string in TileLang DSL v1") + def _require_barrier_type(self, expr: SemanticExpr, context: str) -> str: + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "barrier_type": + return expr.value.value + if isinstance(expr, SemanticBindingRef) and isinstance(expr.type, SemanticMetaType): + if expr.type.kind == "barrier_type" and isinstance(expr.binding.value, BarrierType): + return expr.binding.value.value + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": + return expr.value + raise TypeError(f"{context} must be a BarrierType symbol or string literal in TileLang DSL v1") + + def _normalize_event_id_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "event" and isinstance(expr.value, Event): + return SemanticLiteralExpr( + value=int(expr.value.name[2:]), + type=SemanticScalarType(dtype=i64), + ) + if isinstance(expr, SemanticBindingRef) and isinstance(expr.type, SemanticMetaType): + if expr.type.kind == "event" and isinstance(expr.binding.value, Event): + return SemanticLiteralExpr( + value=int(expr.binding.value.name[2:]), + type=SemanticScalarType(dtype=i64), + ) + self._require_i64_like_expr(expr, context) + return expr + def _pad_mode_value( self, expr: SemanticExpr | None, @@ -5142,17 +5280,23 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticExpr", "SemanticExprStmt", "SemanticForStmt", + "SemanticGetBufStmt", "SemanticIfResult", "SemanticIfStmt", "SemanticIndexType", "SemanticKernel", "SemanticLiteralExpr", + "SemanticMemBarStmt", "SemanticMaskType", "SemanticParameter", "SemanticPipeBarrierStmt", + "SemanticRlsBufStmt", "SemanticReturnStmt", "SemanticScalarType", + "SemanticSetCrossCoreStmt", "SemanticSetFlagStmt", + "SemanticSetIntraBlockStmt", + "SemanticSetIntraCoreStmt", "SemanticShapeType", "SemanticSliceExpr", "SemanticSliceType", @@ -5173,6 +5317,8 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticType", "SemanticVRegType", "SemanticVectorStoreStmt", + "SemanticWaitFlagDevStmt", "SemanticWaitFlagStmt", + "SemanticWaitIntraCoreStmt", "analyze_frontend_kernel", ] diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 257b50c0f..fba2f10be 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -34,6 +34,14 @@ "wait_flag", "pipe_barrier", "barrier", + "get_buf", + "rls_buf", + "mem_bar", + "set_cross_core", + "set_intra_block", + "set_intra_core", + "wait_flag_dev", + "wait_intra_core", } ) @@ -318,8 +326,6 @@ def get_pto_call_tier(call_name: str) -> str: "dma_store", "pto.dma_load", "pto.dma_store", - "pto.get_buf", - "pto.rls_buf", "pto.dma_copy", "pto.vreduce", "pto.tile", @@ -340,6 +346,7 @@ def get_pto_call_tier(call_name: str) -> str: "pto.mask_b8": BASIC_TIER, "pto.mask_b16": BASIC_TIER, "pto.mask_b32": BASIC_TIER, + "BarrierType": BASIC_TIER, "PadMode": BASIC_TIER, "BLayout": BASIC_TIER, "SLayout": BASIC_TIER, diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index f60cf4875..ffdfbb7ae 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -142,6 +142,10 @@ class Event(str, Enum): ID29 = "EVENT_ID29" ID30 = "EVENT_ID30" ID31 = "EVENT_ID31" +class BarrierType(str, Enum): + VV_ALL = "VV_ALL" + VST_VLD = "VST_VLD" + VLD_VST = "VLD_VST" class MaskPattern(str, Enum): @@ -421,6 +425,7 @@ def tile_layout_descriptor( "EVENT", "MaskPattern", "PAT", + "BarrierType", "PadMode", "PositionMode", "OrderMode", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 8635f7e1d..2e90b822e 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -36,14 +36,20 @@ SemanticDmaConfigStmt, SemanticExprStmt, SemanticForStmt, + SemanticGetBufStmt, SemanticIfStmt, SemanticIndexType, + SemanticMemBarStmt, SemanticLowLevelCopyStmt, SemanticMaskType, SemanticPipeBarrierStmt, SemanticPtrType, + SemanticRlsBufStmt, SemanticScalarType, + SemanticSetCrossCoreStmt, SemanticSetFlagStmt, + SemanticSetIntraBlockStmt, + SemanticSetIntraCoreStmt, SemanticStrictVecscopeStmt, SemanticSymbolExpr, SemanticTensorViewType, @@ -51,7 +57,9 @@ SemanticVecscopeStmt, SemanticVectorStoreStmt, SemanticVRegType, + SemanticWaitFlagDevStmt, SemanticWaitFlagStmt, + SemanticWaitIntraCoreStmt, analyze_frontend_kernel, ) @@ -84,6 +92,7 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "elements_per_vreg")) self.assertTrue(hasattr(pto, "PAT")) self.assertTrue(hasattr(pto, "PadMode")) + self.assertTrue(hasattr(pto, "BarrierType")) self.assertTrue(hasattr(pto, "PositionMode")) self.assertTrue(hasattr(pto, "OrderMode")) self.assertTrue(hasattr(pto, "DeinterleaveDist")) @@ -94,6 +103,7 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) self.assertEqual(repr(pto.align), "align") + self.assertEqual(pto.BarrierType.VST_VLD.value, "VST_VLD") self.assertEqual(pto.PadMode.PadNull.value, "PadNull") self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") self.assertEqual(pto.PadMode.PadValue.value, "PadValue") @@ -109,6 +119,7 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PredicatePart.LOWER.value, "LOWER") self.assertEqual(pto.PredicatePart.HIGHER.value, "HIGHER") self.assertEqual(pto.StrideMode.S4_B64.value, "STRIDE_S4_B64") + self.assertEqual(pto.Event.ID31.value, "EVENT_ID31") class TileLangDSLSupportMatrixTests(unittest.TestCase): @@ -151,6 +162,15 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vsta"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vadd"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vmuls"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.get_buf"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.rls_buf"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.mem_bar"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.set_cross_core"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.set_intra_block"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.set_intra_core"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.wait_flag_dev"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.wait_intra_core"), BASIC_TIER) + self.assertEqual(get_feature_tier("BarrierType"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vaddrelu"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vaxpy"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vmull"), BASIC_TIER) @@ -3698,23 +3718,66 @@ def test_sync_ops_accept_event_class_alias_and_full_event_range(self) -> None: def kernel(inp: pto.TensorView, tile: pto.Tile): pto.set_flag(Pipe.MTE2, Pipe.V, Event.ID31) pto.wait_flag(Pipe.MTE2, Pipe.V, Event.ID31) + + def test_extended_sync_buffer_ops_lower_to_authoring_surface(self) -> None: + Pipe = pto.Pipe + Event = pto.Event + BarrierType = pto.BarrierType + + @pto.vkernel( + op="extended_sync_surface", + dtypes=[(pto.f32, pto.i64, pto.i64, pto.i64, pto.i64, pto.i32)], + advanced=True, + ) + def kernel( + tile: pto.Tile, + buf_id: pto.i64, + mode: pto.i64, + core_id: pto.i64, + block_id: pto.i64, + config: pto.i32, + ): + pto.get_buf(Pipe.MTE2, buf_id, mode) + pto.rls_buf(Pipe.V, buf_id) + pto.mem_bar(BarrierType.VST_VLD) + pto.set_cross_core(core_id, Event.ID7) + pto.set_intra_block(block_id, Event.ID16) + pto.set_intra_core(config) + pto.wait_flag_dev(core_id, Event.ID8) + pto.wait_intra_core(block_id, Event.ID31) + with pto.strict_vecscope(tile, tile, 0, 128, 64) as (src, dst, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) return None specialized = kernel.specialize( tile=pto.TileSpecialization( - shape=(16, 16), + shape=(8, 16), memory_space=pto.MemorySpace.UB, ) ) semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - self.assertIsInstance(semantic_kernel.body[0], SemanticSetFlagStmt) - self.assertIsInstance(semantic_kernel.body[1], SemanticWaitFlagStmt) - self.assertEqual(pto.Event.ID31.value, "EVENT_ID31") + self.assertIsInstance(semantic_kernel.body[0], SemanticGetBufStmt) + self.assertIsInstance(semantic_kernel.body[1], SemanticRlsBufStmt) + self.assertIsInstance(semantic_kernel.body[2], SemanticMemBarStmt) + self.assertIsInstance(semantic_kernel.body[3], SemanticSetCrossCoreStmt) + self.assertIsInstance(semantic_kernel.body[4], SemanticSetIntraBlockStmt) + self.assertIsInstance(semantic_kernel.body[5], SemanticSetIntraCoreStmt) + self.assertIsInstance(semantic_kernel.body[6], SemanticWaitFlagDevStmt) + self.assertIsInstance(semantic_kernel.body[7], SemanticWaitIntraCoreStmt) text = specialized.mlir_text() - self.assertIn('pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID31"]', text) - self.assertIn('pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID31"]', text) + self.assertIn('pto.get_buf "PIPE_MTE2", %arg1, %arg2 : i64, i64', text) + self.assertIn('pto.rls_buf "PIPE_V", %arg1, %c0_i64 : i64, i64', text) + self.assertIn('pto.mem_bar "VST_VLD"', text) + self.assertIn("pto.set_cross_core %arg3, %c7_i64 : i64, i64", text) + self.assertIn("pto.set_intra_block %arg4, %c16_i64 : i64, i64", text) + self.assertIn("pto.set_intra_core %arg5 : i32", text) + self.assertIn("pto.wait_flag_dev %arg3, %c8_i64 : i64, i64", text) + self.assertIn("pto.wait_intra_core %arg4, %c31_i64 : i64, i64", text) def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) From c3111affa01bb1eab9ec9e2cf17ff04be330093e Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 13 Apr 2026 14:30:47 +0800 Subject: [PATCH 039/192] Align DSL interface with VPTO v0.3 --- tilelang-dsl/docs/unsupported-features.md | 45 +- .../docs/user_guide/03-kernel-declaration.md | 7 +- .../docs/user_guide/05-type-system.md | 36 +- .../docs/user_guide/07-frontend-operations.md | 185 +- .../user_guide/09-vector-memory-operations.md | 93 +- .../user_guide/10-predicate-operations.md | 4 +- .../11-vector-arithmetic-operations.md | 55 +- .../docs/vpto_spec/vpto-spec-current.md | 5349 +++++++++++++++++ tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md | 5072 ++++++++++++++++ tilelang-dsl/python/tilelang_dsl/__init__.py | 19 + .../python/tilelang_dsl/expand_helper.py | 18 +- .../python/tilelang_dsl/frontend_ast.py | 2 + tilelang-dsl/python/tilelang_dsl/kernel.py | 3 +- tilelang-dsl/python/tilelang_dsl/lowering.py | 231 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 422 +- .../python/tilelang_dsl/support_matrix.py | 16 + tilelang-dsl/python/tilelang_dsl/types.py | 112 +- .../skills/auto-update-vpto-spec/SKILL.md | 79 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 161 +- 19 files changed, 11711 insertions(+), 198 deletions(-) create mode 100644 tilelang-dsl/docs/vpto_spec/vpto-spec-current.md create mode 100644 tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md create mode 100644 tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md index 3a288c85e..3d09832d6 100644 --- a/tilelang-dsl/docs/unsupported-features.md +++ b/tilelang-dsl/docs/unsupported-features.md @@ -50,12 +50,37 @@ The following guide surfaces are not implemented as public APIs: ### Missing Vector Load/Store Families -The previously missing vector-memory surfaces -`pto.vldas(...)`, `pto.vldus(...)`, `pto.vldx2(...)`, `pto.vsld(...)`, -`pto.psts(...)`, `pto.vsst(...)`, `pto.vstx2(...)`, `pto.vsta(...)`, -`pto.pstu(...)`, `pto.vstu(...)`, `pto.vstus(...)`, and `pto.vstur(...)` -are now implemented. Remaining guide gaps are concentrated in the wider -indexed/flush families that are still not wired through this DSL package. +The current package supports the core v0.3 load/store subset: + +- `pto.vlds(...)` +- `pto.vsts(...)` +- `pto.vldsx2(...)` +- `pto.vstsx2(...)` +- `pto.load_scalar(...)` +- `pto.store_scalar(...)` + +The following documented load/store families are still unsupported: + +- `pto.vldas(...)` +- `pto.vldus(...)` +- `pto.vsld(...)` +- `pto.psts(...)` +- `pto.vsst(...)` +- `pto.vsta(...)` +- `pto.pstu(...)` +- `pto.vstu(...)` +- `pto.vstus(...)` +- `pto.vstur(...)` + +### Missing Direct Predicate Constructor/Compare APIs + +The implementation expects users to go through `pto.make_mask(...)` rather than +call the underlying mask ops directly. These guide-documented APIs are not part +of the supported authoring surface: + +- `pto.pset_b8(...)`, `pto.pset_b16(...)`, `pto.pset_b32(...)` +- `pto.pge_b8(...)`, `pto.pge_b16(...)`, `pto.pge_b32(...)` +- `pto.plt_b8(...)`, `pto.plt_b16(...)`, `pto.plt_b32(...)` ### Missing Extended Vector Arithmetic Families @@ -156,13 +181,15 @@ Currently supported: - rank-1: `tile[start:]` - rank-2: `tile[row, col:]` - rank-2 column-major: `tile[row_start:, col_index]` -- available across the current basic vector-memory tile-indexing family +- for `pto.vlds(...)`, `pto.vsts(...)`, `pto.vldsx2(...)`, and `pto.vstsx2(...)` Not currently supported from the guide's broader indexing model: - single-element syntax such as `tile[row, col]` and `tile[pos]` - explicit slice `stop` - stepped tile vector slices +- the guide's wider indexed op family (`vldas`, `vldus`, `vsld`, + `psts`, `vsst`, `vsta`) ### Control-Flow Result Merging @@ -194,8 +221,10 @@ For quick orientation, the current package head is strongest in these areas: - `templates={...}` and `pto.tpl(...)` - `ptr(...)`, `pto.castptr(...)`, `pto.addptr(...)` - low-level DMA config/copy ops +- runtime block queries (`pto.get_block_idx`, `pto.get_block_num`, ...) - `pto.make_mask(...)` -- `pto.vlds(...)` and `pto.vsts(...)` +- `pto.vlds(...)`, `pto.vsts(...)`, `pto.vldsx2(...)`, `pto.vstsx2(...)` +- `pto.load_scalar(...)` and `pto.store_scalar(...)` - base unary/binary/vector-scalar vector ops - advanced compare/select/carry/rearrangement families diff --git a/tilelang-dsl/docs/user_guide/03-kernel-declaration.md b/tilelang-dsl/docs/user_guide/03-kernel-declaration.md index 3044ef005..348bf170c 100644 --- a/tilelang-dsl/docs/user_guide/03-kernel-declaration.md +++ b/tilelang-dsl/docs/user_guide/03-kernel-declaration.md @@ -79,12 +79,15 @@ The `dtypes` parameter supports flexible type matching: 1. **Concrete Types**: Exact type matches using DSL scalar types: - `pto.f16`, `pto.f32`, `pto.bf16` - - `pto.i8`, `pto.i16`, `pto.i32`, `pto.i64` + - `pto.i8`, `pto.si8`, `pto.ui8` + - `pto.i16`, `pto.si16`, `pto.ui16` + - `pto.i32`, `pto.si32`, `pto.ui32` + - `pto.i64`, `pto.si64`, `pto.ui64` - `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` 2. **Type Wildcards**: Generic type patterns: - `pto.AnyFloat`: Matches any floating-point type (`f16`, `bf16`, `f32`) - - `pto.AnyInt`: Matches any integer type (`i8`, `i16`, `i32`, `i64`) + - `pto.AnyInt`: Matches any integer type (`i*`, `si*`, `ui*`) - `pto.AnyType`: Matches any scalar type - `pto.AnyMask`: Matches any mask type (`mask_b8`, `mask_b16`, `mask_b32`) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 2a68573c4..8cb898250 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -5,10 +5,18 @@ | DSL Type | Description | Bit Width | |----------|-------------|-----------| | `pto.i1` | Boolean | 1 | -| `pto.i8` | 8-bit integer | 8 | -| `pto.i16` | 16-bit integer | 16 | -| `pto.i32` | 32-bit integer | 32 | -| `pto.i64` | 64-bit integer | 64 | +| `pto.i8` | 8-bit signless integer | 8 | +| `pto.si8` | 8-bit signed integer | 8 | +| `pto.ui8` | 8-bit unsigned integer | 8 | +| `pto.i16` | 16-bit signless integer | 16 | +| `pto.si16` | 16-bit signed integer | 16 | +| `pto.ui16` | 16-bit unsigned integer | 16 | +| `pto.i32` | 32-bit signless integer | 32 | +| `pto.si32` | 32-bit signed integer | 32 | +| `pto.ui32` | 32-bit unsigned integer | 32 | +| `pto.i64` | 64-bit signless integer | 64 | +| `pto.si64` | 64-bit signed integer | 64 | +| `pto.ui64` | 64-bit unsigned integer | 64 | | `pto.f16` | Half precision float | 16 | | `pto.bf16` | Brain float 16 | 16 | | `pto.f32` | Single precision float | 32 | @@ -22,8 +30,13 @@ For explicit typing, use type constructors: ```python x = pto.i32(1024) # Explicit i32 constant y: pto.i32 = 1024 # Type annotation +z = pto.ui16(7) # Explicit unsigned 16-bit constant ``` +Integer sign semantics are part of the DSL type surface. `pto.si16`, +`pto.ui16`, and `pto.i16` are distinct scalar dtypes and lower to `si16`, +`ui16`, and `i16` respectively in VPTO IR. + ### Floating-Point Literal Forms `pto.f16(...)`, `pto.bf16(...)`, and `pto.f32(...)` accept multiple literal forms. @@ -67,8 +80,14 @@ v_i8 = pto.vreg(pto.i8) # !pto.vreg<256xi8> - `pto.f16` → `!pto.vreg<128xf16>` - `pto.bf16` → `!pto.vreg<128xbf16>` - `pto.i32` → `!pto.vreg<64xi32>` +- `pto.si32` → `!pto.vreg<64xsi32>` +- `pto.ui32` → `!pto.vreg<64xui32>` - `pto.i16` → `!pto.vreg<128xi16>` +- `pto.si16` → `!pto.vreg<128xsi16>` +- `pto.ui16` → `!pto.vreg<128xui16>` - `pto.i8` → `!pto.vreg<256xi8>` +- `pto.si8` → `!pto.vreg<256xsi8>` +- `pto.ui8` → `!pto.vreg<256xui8>` Constraint: `element_count × bitwidth(element_type) = 2048` @@ -80,7 +99,8 @@ lanes0 = v_dtype.elements_per_vreg # 64 lanes1 = pto.elements_per_vreg(pto.f32) # 64 ``` -Current TileLang DSL v1 vector lowering supports `i8`, `i16`, `i32`, `f16`, `bf16`, and `f32` element types. +Current TileLang DSL v1 vector lowering supports the 8/16/32-bit integer +families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32` element types. ### Typed Masks @@ -98,9 +118,9 @@ mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) ``` Mask operations must match the vector element family: -- `f32` vectors use `mask_b32` -- `f16` vectors use `mask_b16` -- `i8` vectors use `mask_b8` +- `f32`, `i32`, `si32`, and `ui32` vectors use `mask_b32` +- `f16`, `bf16`, `i16`, `si16`, and `ui16` vectors use `mask_b16` +- `i8`, `si8`, and `ui8` vectors use `mask_b8` ```python # Correct: f32 vector with b32 mask diff --git a/tilelang-dsl/docs/user_guide/07-frontend-operations.md b/tilelang-dsl/docs/user_guide/07-frontend-operations.md index 5f8bd2893..9b564b98d 100644 --- a/tilelang-dsl/docs/user_guide/07-frontend-operations.md +++ b/tilelang-dsl/docs/user_guide/07-frontend-operations.md @@ -65,7 +65,7 @@ Operations for querying type properties. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | +| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`, `pto.si16`, `pto.ui32`) | **Returns**: | Return Value | Type | Description | @@ -77,6 +77,7 @@ Operations for querying type properties. f32_size = pto.bytewidth(pto.f32) # Returns 4 f16_size = pto.bytewidth(pto.f16) # Returns 2 i8_size = pto.bytewidth(pto.i8) # Returns 1 +ui64_size = pto.bytewidth(pto.ui64) # Returns 8 ``` **Common Use Case**: Calculate byte offsets for memory access: @@ -92,7 +93,7 @@ byte_offset = index * pto.bytewidth(element_type) **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | +| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`, `pto.si16`, `pto.ui32`) | **Returns**: | Return Value | Type | Description | @@ -104,6 +105,7 @@ byte_offset = index * pto.bytewidth(element_type) f32_elems_per_vreg = pto.elements_per_vreg(pto.f32) # Returns 64 (256 / 4) f16_elems_per_vreg = pto.elements_per_vreg(pto.f16) # Returns 128 (256 / 2) i8_elems_per_vreg = pto.elements_per_vreg(pto.i8) # Returns 256 (256 / 1) +si16_elems_per_vreg = pto.elements_per_vreg(pto.si16) # Returns 128 (256 / 2) ``` **Common Use Case**: Loop stride calculation for vector operations: @@ -123,6 +125,184 @@ elems = 256 // pto.bytewidth(dtype) elems = pto.elements_per_vreg(dtype) ``` +### Runtime Block Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar +code. They are pure scalar producers: + +- they do not move data +- they do not allocate buffers +- they do not by themselves create `vecscope` boundaries + +Their main purpose is workload partitioning. A common pattern is: + +1. query the current block or subblock id +2. compute a per-instance starting offset +3. use that offset to derive GM/UB pointers or TensorView slices +4. run the local tile or vector loop for that partition + +#### `pto.get_block_idx() -> pto.i64` + +**Description**: Returns the current block ID for the running kernel instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `block` | `pto.i64` | Current block index in the range `[0, pto.get_block_num())` | + +**Behavior**: +- The returned value is launch-instance-local and may differ across concurrently running blocks. +- The value is stable for the lifetime of one kernel instance. +- The op is scalar-only and can be used before pointer arithmetic, TensorView partitioning, DMA setup, or loop construction. + +#### `pto.get_subblock_idx() -> pto.i64` + +**Description**: Returns the current subblock ID visible to the running kernel instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `subblock` | `pto.i64` | Current subblock index in the range `[0, pto.get_subblock_num())` | + +**Behavior**: +- Used when one block is further subdivided by the launch/runtime model. +- Like `pto.get_block_idx()`, this is a pure scalar query with no side effects. + +#### `pto.get_block_num() -> pto.i64` + +**Description**: Returns the total number of blocks visible to the current kernel launch. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `block_num` | `pto.i64` | Total block count for the current launch domain | + +**Behavior**: +- Typically paired with `pto.get_block_idx()` to compute per-block ranges. +- The result is a runtime value and should not be assumed to be a compile-time constant. + +#### `pto.get_subblock_num() -> pto.i64` + +**Description**: Returns the total number of subblocks visible to the current execution instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `subblock_num` | `pto.i64` | Total subblock count in the current runtime execution domain | + +**Behavior**: +- Typically paired with `pto.get_subblock_idx()` for finer-grained partitioning inside one block. + +**Example**: +```python +block = pto.get_block_idx() +block_num = pto.get_block_num() +subblock = pto.get_subblock_idx() +subblock_num = pto.get_subblock_num() +``` + +**Typical Use Case**: Compute a per-block base pointer. +```python +block = pto.get_block_idx() +block_len = 2048 +base_elem = block * block_len +block_src = pto.addptr(src_gm, base_elem) +block_dst = pto.addptr(dst_gm, base_elem) +``` + +**Constraints**: +- These ops return runtime scalar values, not compile-time specialization constants. +- They are intended for scalar address/control computation, not as vector operands. +- When mixing them with pointer arithmetic, remember that `pto.addptr(...)` uses element offsets, not byte offsets. + +### Scalar Pointer Helpers [Advanced Tier] + +These ops perform scalar element access on typed PTO pointers. Unlike +`pto.vlds(...)` / `pto.vsts(...)`, they operate on exactly one element and do +not create or consume vector registers or masks. + +They are useful when a kernel needs a small amount of scalar state next to +vector code, for example: + +- reading a scalar coefficient or loop-carried value from UB +- writing a scalar flag or reduction result +- patching a small header/metadata area without vector load-store semantics + +#### `pto.load_scalar(ptr: PtrType, offset: Index) -> ScalarType` +#### `pto.load_scalar(dtype: Type, ptr: PtrType, offset: Index) -> ScalarType` + +**Description**: Loads one scalar element from a typed PTO pointer at the given element offset. + +**Parameters (`load_scalar(ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed pointer created by `pto.ptr(...)`, `pto.castptr(...)`, `Tile.as_ptr()`, or `TensorView.as_ptr()` | +| `offset` | `Index` | Element displacement from `ptr` | + +**Parameters (`load_scalar(dtype, ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Optional explicit result dtype; must match the pointer element type | +| `ptr` | `PtrType` | Typed pointer source | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `value` | `ScalarType` | One scalar element loaded from `ptr[offset]` | + +**Behavior**: +- Access is element-based, not byte-based. +- The loaded value has the same scalar dtype as the pointer element type. +- This is a scalar memory helper; it does not participate in vector distribution families such as `dist`. +- It may target any memory space represented by the pointer type; the memory-space legality follows the pointer producer. + +#### `pto.store_scalar(ptr: PtrType, offset: Index, value: ScalarType) -> None` +#### `pto.store_scalar(value: ScalarType, ptr: PtrType, offset: Index) -> None` + +**Description**: Stores one scalar element to a typed PTO pointer at the given element offset. + +**Parameters (`store_scalar(ptr, offset, value)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | Element displacement from `ptr` | +| `value` | `ScalarType` | Scalar value to write | + +**Parameters (`store_scalar(value, ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `ScalarType` | Scalar value to write | +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: None (side-effect operation) + +**Behavior**: +- Stores exactly one scalar element to `ptr[offset]`. +- Does not consume a predicate mask. +- Does not imply vector-store ordering semantics such as `dist` or unaligned store state. + +**Example**: +```python +value = pto.load_scalar(src_ptr, 0) +pto.store_scalar(dst_ptr, 0, value) +``` + +**Typical Use Case**: Read-modify-write scalar metadata next to vector code. +```python +flag = pto.load_scalar(status_ptr, 0) +# scalar compute on `flag` +pto.store_scalar(status_ptr, 0, flag) +``` + +**Constraints**: +- `ptr` must be a typed `pto.ptr(...)` value. +- `offset` is element-based and must be index-typed after frontend normalization. + Plain integer literals such as `0` are accepted and lowered as index constants. +- The scalar dtype must match the pointer element dtype. +- These ops are advanced pointer-surface operations; prefer Tile/TensorView authoring surfaces when scalar pointer manipulation is not required. + ### Pointer Construction [Advanced Tier] Operations for creating and manipulating typed pointers. @@ -169,4 +349,3 @@ ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) # Advance pointer by 1024 f32 elements (not bytes) next_ptr = pto.addptr(ub_ptr, 1024) ``` - diff --git a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md index 539ea1895..2d432941b 100644 --- a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -1,17 +1,28 @@ ### Enum Types for Vector Memory Operations -The following enum types provide type-safe parameter specification for vector memory operations: +The current DSL exposes type-safe Enum operands for the dual load/store +distribution families: -- **`DeinterleaveDist`**: Deinterleave distribution modes for `pto.vldx2` - - `B8`: 8-bit element deinterleave (for i8) - - `B16`: 16-bit element deinterleave (for i16, f16, bf16) - - `B32`: 32-bit element deinterleave (for i32, f32) - - `BD`: Broadcast deinterleave mode +- **`DeinterleaveDist`** for `pto.vldsx2` + - `DeinterleaveDist.DINTLV`: alternating-element deinterleave + - `DeinterleaveDist.BDINTLV`: block deinterleave + - compatibility aliases: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, + `DeinterleaveDist.B32`, `DeinterleaveDist.BD` -- **`InterleaveDist`**: Interleave distribution modes for `pto.vstx2` - - `B8`: 8-bit element interleave (for i8) - - `B16`: 16-bit element interleave (for i16, f16, bf16) - - `B32`: 32-bit element interleave (for i32, f32) +- **`InterleaveDist`** for `pto.vstsx2` + - `InterleaveDist.INTLV`: interleave two vectors into one destination stream + - compatibility aliases: `InterleaveDist.B8`, `InterleaveDist.B16`, + `InterleaveDist.B32` + +The canonical VPTO v0.3 spellings are the enum values: + +- `DeinterleaveDist.DINTLV.value == "DINTLV"` +- `DeinterleaveDist.BDINTLV.value == "BDINTLV"` +- `InterleaveDist.INTLV.value == "INTLV"` + +For migration convenience, the implementation still accepts legacy raw strings +such as `"DINTLV_B32"` and `"INTLV_B32"`, but new DSL code should prefer the +Enum operands. - **`StrideMode`**: Stride modes for `pto.vsld` - `S3_B16`: Stride 3, block size 16 @@ -72,10 +83,10 @@ vector_lanes = 256 // element_size_bytes(element_type) **Convenience API**: Use `pto.elements_per_vreg(dtype)` to compute the number of elements per vector register for a given element type (e.g., `pto.elements_per_vreg(pto.f32)` returns 64, `pto.elements_per_vreg(pto.f16)` returns 128). See [Type Query Operations](07-frontend-operations.md#type-query-operations) for full documentation. Where `element_size_bytes` is: -- 1 byte for `i8` -- 2 bytes for `i16`, `f16`, `bf16` -- 4 bytes for `i32`, `f32` -- 8 bytes for `i64` +- 1 byte for `i8`, `si8`, `ui8` +- 2 bytes for `i16`, `si16`, `ui16`, `f16`, `bf16` +- 4 bytes for `i32`, `si32`, `ui32`, `f32` +- 8 bytes for `i64`, `si64`, `ui64` #### Offset Computation @@ -134,8 +145,8 @@ The byte offset is automatically computed based on tile layout: The indexing syntax is supported for all vector load and store operations with the following syntax mapping: - **Vector-range indexing** (`tile[row, col:]` or `tile[start:]`): - - Load operations: `vlds`, `vldas`, `vldus`, `vldx2` - - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstx2` + - Load operations: `vlds`, `vldas`, `vldus`, `vldsx2` + - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstsx2` - **Single-element indexing** (`tile[row, col]` or `tile[pos]`): - Load operations: `vsld` (scalar load with broadcast) @@ -154,7 +165,7 @@ vec = pto.vlds(tile[k:]) # Load vector from elements k to k+vector_l pto.vsts(vec, tile[k:], mask) # Store vector with mask # Dual load with deinterleave -low, high = pto.vldx2(tile[i, j:], DeinterleaveDist.B32) +low, high = pto.vldsx2(tile[i, j:], "DINTLV") # Aligned load with indexing vec = pto.vldas(tile[i, j:], align) @@ -319,9 +330,9 @@ for n in range(4): ``` -#### `pto.vldx2(buf: ptr, offset: Index, dist: DeinterleaveDist) -> (VRegType, VRegType)` [Advanced Tier] -#### `pto.vldx2(tile[row, col:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] -#### `pto.vldx2(tile[start:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] +#### `pto.vldsx2(buf: ptr, offset: Index, dist: DeinterleaveDist) -> (VRegType, VRegType)` [Advanced Tier] +#### `pto.vldsx2(tile[row, col:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] +#### `pto.vldsx2(tile[start:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] **Description**: Dual vector load with deinterleave (AoS → SoA conversion). Loads interleaved data from a single buffer and deinterleaves into two vectors. Supports both byte-offset and element-indexing syntax. @@ -330,16 +341,16 @@ for n in range(4): |-----------|------|-------------| | `buf` | `ptr` | Pointer to source buffer in UB memory space (Advanced mode only - requires explicit pointer) | | `offset` | `Index` | Byte offset | -| `dist` | `DeinterleaveDist` | Deinterleave distribution mode: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, `DeinterleaveDist.B32`, `DeinterleaveDist.BD`. Determines element size and interleave pattern. | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | **Parameters (element-indexing syntax)**: | Parameter | Type | Description | |-----------|------|-------------| | `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `dist` | `DeinterleaveDist` | Deinterleave distribution mode: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, `DeinterleaveDist.B32`, `DeinterleaveDist.BD`. Determines element size and interleave pattern. | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | | _or_ | | | | `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `dist` | `DeinterleaveDist` | Deinterleave distribution mode: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, `DeinterleaveDist.B32`, `DeinterleaveDist.BD`. | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | **Returns**: | Return Value | Type | Description | @@ -351,21 +362,19 @@ for n in range(4): - Source buffer must be in UB memory space - Offset must satisfy alignment requirements for the selected distribution mode - The requested vector region must be within tile bounds (for element-indexing syntax) -- Distribution mode must match element type (e.g., `DeinterleaveDist.B32` for 32-bit elements) +- Distribution mode must match element type (e.g., `"DINTLV"` for 32-bit elements) **Examples**: ```python -from pto import DeinterleaveDist - # Byte-offset syntax -low, high = pto.vldx2(ub_ptr, offset, DeinterleaveDist.B32) +low, high = pto.vldsx2(ub_ptr, offset, pto.DeinterleaveDist.DINTLV) # Element-indexing syntax -low, high = pto.vldx2(tile[i, j:], DeinterleaveDist.B32) -low, high = pto.vldx2(tile[k:], DeinterleaveDist.B16) +low, high = pto.vldsx2(tile[i, j:], pto.DeinterleaveDist.DINTLV) +low, high = pto.vldsx2(tile[k:], pto.DeinterleaveDist.DINTLV) # Example: Load interleaved XY pairs into separate X/Y vectors -x_vec, y_vec = pto.vldx2(xy_tile[i, j:], DeinterleaveDist.B32) +x_vec, y_vec = pto.vldsx2(xy_tile[i, j:], pto.DeinterleaveDist.DINTLV) ``` #### `pto.vsld(buf: ptr, offset: Index, stride: StrideMode) -> VRegType` [Advanced Tier] @@ -637,9 +646,9 @@ def generic_store(src: pto.Tile, dst: pto.Tile): **Returns**: None (side-effect operation) -#### `pto.vstx2(low: VRegType, high: VRegType, buf: ptr, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` [Advanced Tier] -#### `pto.vstx2(low: VRegType, high: VRegType, tile[row, col:], dist: InterleaveDist, mask: MaskType) -> None` -#### `pto.vstx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstsx2(low: VRegType, high: VRegType, buf: ptr, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[row, col:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` **Description**: Dual interleaved store (SoA → AoS conversion). Stores two vectors interleaved into a single buffer. Supports both byte-offset and element-indexing syntax. @@ -650,7 +659,7 @@ def generic_store(src: pto.Tile, dst: pto.Tile): | `high` | `VRegType` | Second vector (odd elements in interleaved stream) | | `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | | `offset` | `Index` | Byte offset | -| `dist` | `InterleaveDist` | Interleave distribution mode: `InterleaveDist.B8`, `InterleaveDist.B16`, `InterleaveDist.B32`. Determines element size and interleave pattern. | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | | `mask` | `MaskType` | Predicate mask | **Parameters (element-indexing syntax)**: @@ -659,7 +668,7 @@ def generic_store(src: pto.Tile, dst: pto.Tile): | `low` | `VRegType` | First vector (even elements in interleaved stream) | | `high` | `VRegType` | Second vector (odd elements in interleaved stream) | | `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | -| `dist` | `InterleaveDist` | Interleave distribution mode: `InterleaveDist.B8`, `InterleaveDist.B16`, `InterleaveDist.B32`. | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | | `mask` | `MaskType` | Predicate mask | **Parameters (1D element-indexing syntax)**: @@ -668,7 +677,7 @@ def generic_store(src: pto.Tile, dst: pto.Tile): | `low` | `VRegType` | First vector (even elements in interleaved stream) | | `high` | `VRegType` | Second vector (odd elements in interleaved stream) | | `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | -| `dist` | `InterleaveDist` | Interleave distribution mode: `InterleaveDist.B8`, `InterleaveDist.B16`, `InterleaveDist.B32`. | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | | `mask` | `MaskType` | Predicate mask | **Returns**: None (side-effect operation) @@ -677,22 +686,20 @@ def generic_store(src: pto.Tile, dst: pto.Tile): - Destination buffer must be in UB memory space - Offset must satisfy alignment requirements for the selected distribution mode - The destination vector region must be within tile bounds (for element-indexing syntax) -- Distribution mode must match element type (e.g., `InterleaveDist.B32` for 32-bit elements) +- Distribution mode must match element type (e.g., `"INTLV"` for 32-bit elements) - The two source vectors form an ordered pair; interleave semantics must be preserved **Examples**: ```python -from pto import InterleaveDist - # Byte-offset syntax -pto.vstx2(x_vec, y_vec, ub_ptr, offset, InterleaveDist.B32, mask) +pto.vstsx2(x_vec, y_vec, ub_ptr, offset, pto.InterleaveDist.INTLV, mask) # Element-indexing syntax -pto.vstx2(x_vec, y_vec, tile[i, j:], InterleaveDist.B32, mask) -pto.vstx2(x_vec, y_vec, tile[k:], InterleaveDist.B16, mask) +pto.vstsx2(x_vec, y_vec, tile[i, j:], pto.InterleaveDist.INTLV, mask) +pto.vstsx2(x_vec, y_vec, tile[k:], pto.InterleaveDist.INTLV, mask) # Example: Store separate X/Y vectors as interleaved XY pairs -pto.vstx2(x_vec, y_vec, xy_tile[i, j:], InterleaveDist.B32, all_mask) +pto.vstsx2(x_vec, y_vec, xy_tile[i, j:], pto.InterleaveDist.INTLV, all_mask) ``` #### `pto.vsta(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] diff --git a/tilelang-dsl/docs/user_guide/10-predicate-operations.md b/tilelang-dsl/docs/user_guide/10-predicate-operations.md index 227888c4d..c0b959943 100644 --- a/tilelang-dsl/docs/user_guide/10-predicate-operations.md +++ b/tilelang-dsl/docs/user_guide/10-predicate-operations.md @@ -242,8 +242,8 @@ mask, remaining = pto.plt_b32(remaining) # generates mask for next chunk, updat | `remaining` | `pto.i32` | Updated remaining element count (only returned when `value` is a `pto.i32` for tail processing) | **Constraints**: -- The `element_type` must be one of: `f32`, `i32`, `f16`, `bf16`, `i16`, `i8` -- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`, 16-bit for `f16`/`bf16`/`i16`, 8-bit for `i8` +- The `element_type` must be one of: `f32`, `f16`, `bf16`, or an 8/16/32-bit integer family member (`i*`, `si*`, `ui*`) +- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`/`si32`/`ui32`, 16-bit for `f16`/`bf16`/`i16`/`si16`/`ui16`, and 8-bit for `i8`/`si8`/`ui8` - The function infers the operation mode from the `value` parameter type at compile time: - `pto.i32` value → tail processing mode (returns `(mask, updated_remaining)`) - `pto.MaskPattern` enum value → pattern mode (returns `mask` only) diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 817eeeb94..a8d0d0005 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -879,7 +879,7 @@ scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) | `result` | `VRegType` | Vector whose active lanes all carry `value` | **Constraints**: -- Supported scalar types are `i8`, `i16`, `i32`, `f16`, `bf16`, `f32`. +- Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. - For integer types, only the low bits of the scalar source are consumed according to the bit width (8, 16, or 32 bits). **Example**: @@ -893,44 +893,59 @@ rowmax_seed_f32 = pto.vbr(pto.f32("-inf")) rowmax_seed_f16 = pto.vbr(pto.f16("0xFC00")) ``` -**Position Mode Enum**: The `PositionMode` enum provides type-safe position selection for `pto.vdup` operations. Currently only `LOWEST` (selects the lowest-index element) is supported, with more position options planned for future releases. +**Position Mode Enum**: The `PositionMode` enum provides type-safe source-lane +selection for `pto.vdup`. `LOWEST` selects the lowest-index element of the +source vector and `HIGHEST` selects the highest-index element. When the input is +a scalar, the duplicated scalar value is independent of `position`. -#### `pto.vdup(input: ScalarType | VRegType, position: PositionMode = PositionMode.LOWEST) -> VRegType` +#### `pto.vdup(input: ScalarType | VRegType, mask: MaskType, position: PositionMode = PositionMode.LOWEST) -> VRegType` -**Description**: Duplicate scalar or vector element to all lanes. +**Description**: Duplicate a scalar value or one selected vector element into +the active lanes of a destination vector. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `input` | `ScalarType` or `VRegType` | Input scalar or source vector | -| `position` | `PositionMode` | Optional enum selecting which source element to duplicate (default: `PositionMode.LOWEST`) | +| `mask` | `MaskType` | Predicate mask controlling which lanes are written | +| `position` | `PositionMode` | Optional enum selecting the source vector element to duplicate (default: `PositionMode.LOWEST`) | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `VRegType` | Vector with duplicated value in all lanes | +| `result` | `VRegType` | Vector whose active lanes receive the duplicated value | **Constraints**: -- When `input` is a scalar, it is broadcast to all lanes (similar to `pto.vbr` but with `position` attribute). -- When `input` is a vector, the element selected by `position` is duplicated to all lanes. -- Supported scalar types are `i8`, `i16`, `i32`, `f16`, `bf16`, `f32`. -- The `position` enum selects which source element or scalar position is duplicated. Currently only `PositionMode.LOWEST` is supported, which selects the lowest-index element. +- `mask` granularity must match the destination vector element type. For + example, `f32`/`i32`/`si32`/`ui32` vectors require `mask_b32`. +- When `input` is a scalar, the scalar value is duplicated to every active lane. +- When `input` is a vector, `position` selects a single source element and that + value is duplicated to every active lane. +- Inactive lanes follow VPTO predicate semantics and are not guaranteed to carry + meaningful values for subsequent masked-off use. +- Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. +- `position` is only meaningful for vector input. TileLang DSL currently exposes + `PositionMode.LOWEST` and `PositionMode.HIGHEST`, matching VPTO v0.3. **Example**: ```python -# Broadcast scalar to vector (similar to pto.vbr) -broadcast = pto.vdup(3.14) # position defaults to "POS_LOWEST" +mask32 = pto.make_mask(pto.f32, pto.PAT.ALL) -# Use dtype constructor when the semantic value is floating-point special value -seed = pto.vdup(pto.f32("-inf")) -seed_f16 = pto.vdup(pto.f16("0xFC00")) +# Duplicate a scalar into all active lanes. +broadcast = pto.vdup(3.14, mask32) # position defaults to "LOWEST" -# Duplicate lowest element of vector to all lanes -vec = pto.vreg_f32(64) # 64-element vector -dup_lowest = pto.vdup(vec) # position defaults to "POS_LOWEST" +# Use dtype constructors for floating-point special values. +seed = pto.vdup(pto.f32("-inf"), mask32) +seed_f16 = pto.vdup(pto.f16("0xFC00"), pto.make_mask(pto.f16, pto.PAT.ALL)) -# Explicit position specification -dup_explicit = pto.vdup(vec, position=PositionMode.LOWEST) +# Assume `vec` is an existing `f32` vector register value. +vec = pto.vlds(src, 0) + +# Duplicate the lowest source lane to all active lanes. +dup_lowest = pto.vdup(vec, mask32) # position defaults to "LOWEST" + +# Duplicate the highest source lane to all active lanes. +dup_highest = pto.vdup(vec, mask32, pto.PositionMode.HIGHEST) ``` **Type Safety Note**: diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md new file mode 100644 index 000000000..8de281795 --- /dev/null +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md @@ -0,0 +1,5349 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.3: Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +##### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +##### 2. No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +##### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +##### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +##### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | +| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | +| `INTLV` | `b8`, `b16`, `b32` | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +##### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +##### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +#### Movement + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. `%result` + uses an integer element type, and the scalar `%index` type matches that + result element type. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. This is typically used in even/odd placement forms such +as `32 -> 16` or `16 -> 32` style conversions. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | + +--- + +##### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +###### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +###### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +##### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | | Y | | + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### Sorting Operations + +##### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +##### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md new file mode 100644 index 000000000..3c1e31419 --- /dev/null +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md @@ -0,0 +1,5072 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/u8 | 32 | 256 | +| i16/u16/f16/bf16 | 16 | 128 | +| i32/u32/f32 | 8 | 64 | +| i64/u64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +- `vreg`: `!pto.vreg` + Fixed-width VPTO vector type with total width exactly 256 bytes. +- `mask`: `!pto.mask` + Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. +- `align`: `!pto.align` +- `buf`: buffer-like LLVM pointer type accepted by the dialect +- `buf_like`: `memref<...>` or `!llvm.ptr` for stateless/predicate + `vld*/vst*` families +- `idx`: `index` +- `i32`: `i32` +- `i64`: `i64` + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `s8` / `u8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `s16` / `u16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `s32` / `u32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `s64` / `u64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | +| `f8e4m3` | 8 | FP8 (4-bit exponent, 3-bit mantissa) | +| `f8e5m2` | 8 | FP8 (5-bit exponent, 2-bit mantissa) | + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through pointer construction, pointer arithmetic, structured control flow, and PTO memory ops: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out, %base_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/u8 +// N = 128 for i16/u16/f16/bf16 +// N = 64 for i32/u32/f32 +// N = 32 for i64/u64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"ROUND_MODE"` | Rounding mode: `ROUND_R \| ROUND_A \| ROUND_F \| ROUND_C \| ROUND_Z` | +| `"SAT_MODE"` | Saturation: `RS_ENABLE \| RS_DISABLE` | +| `"PART_MODE"` | Half selector: `PART_EVEN \| PART_ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldx2`, `pto.vgather2`, `pto.vsts`, `pto.vstx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 7 | `pto.plds`, `pto.pld`, `pto.pldi`, `pto.psts`, `pto.pst`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 9 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrec`, `pto.vrelu`, `pto.vnot`, `pto.vbcnt`, `pto.vcls` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 8 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 3 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 5 | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr`, `pto.vselrv2` | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 4 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 5 | `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | +| Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | None | +| Post-loop teardown | `wait_flag` to drain all primed signals | None | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV_B32` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM_B32` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_*`** on **`RV_VSTI`** are **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV_B32` | `RV_VLDI` | **9** | +| `DINTLV_B16` | `RV_VLDI` | **9** | +| `DINTLV_B8` | `RV_VLDI` | **9** | +| `BRC_B32` | `RV_VLD` | **9** | +| `BRC_B8` | `RV_VLD` | **9** | +| `BRC_B16` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV_B32` | `RV_VSTI` | **12** | +| `INTLV_B16` | `RV_VSTI` | **12** | +| `INTLV_B8` | `RV_VSTI` | **12** | +| `UNPK_B8` | `RV_VLD` | **9** | +| `UNPK_B16` | `RV_VLD` | **9** | +| `UNPK_B32` | `RV_VLD` | **9** | +| `NORM_B32` | `RV_VSTI` | **9** | +| `NORM_B16` | `RV_VSTI` | **9** | +| `NORM_B8` | `RV_VSTI` | **9** | +| `PK_B32` | `RV_VSTI` | **9** | +| `PK_B16` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK_B8`, `UNPK_B16`, `UNPK_B32` | **9** cycles | +| `DINTLV_B32` | **9** cycles (`RV_VLDI`) | +| `DINTLV_B16`, `DINTLV_B8` | **9** cycles (same `RV_VLDI` + `dist:DINTLV_*` path as `DINTLV_B32`) | +| `BRC_B32` | **9** cycles | +| `BRC_B8`, `BRC_B16` | **9** cycles (`RV_VLD`) | +| `BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US_*`, `DS_*`, `SPLT*` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM_B8`, `NORM_B16`, `NORM_B32` | **9** cycles (`RV_VSTI`) | +| `PK_B16`, `PK_B32` | **9** cycles | +| `INTLV_B32` (`pto.vstx2`) | **12** cycles | +| `INTLV_B16`, `INTLV_B8` | **12** cycles (same interleave store path as `INTLV_B32`) | +| `MRG4CHN_B8`, `MRG2CHN_*` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. + +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM` | Contiguous 256B load | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC_B32` | Broadcast single element | `dst[i] = UB[base]` for all i | **9** cycles | +| `BRC_B8`, `BRC_B16` | Broadcast first lane element | Same idea at B8/B16 width | **9** cycles | +| `US_B8/B16` | Upsample (duplicate each element) | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS_B8/B16` | Downsample (every 2nd element) | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK_B8/B16/B32` | Unpack (zero-extend to wider type) | `dst_i32[i] = (uint32_t)UB_i16[base + 2*i]` | **9** cycles | +| `SPLT4CHN_B8` | Split 4-channel (RGBA → R plane) | Extract every 4th byte | **9** cycles | +| `SPLT2CHN_B8/B16` | Split 2-channel | Extract every 2nd element | **9** cycles | +| `DINTLV_B32` | Deinterleave 32-bit | Even elements only | **9** cycles | +| `DINTLV_B16`, `DINTLV_B8` | Deinterleave 16-bit / 8-bit | Pair lanes from interleaved UB | **9** cycles | +| `BDINTLV` | Block deinterleave | (see PTO headers for exact tiling) | **9** cycles | +| `BLK` | Block load | Blocked / tiled access pattern (see PTO headers) | **9** cycles (`dist:BRC_BLK` on `RV_VLD`) | + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out, %base_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align, !pto.ptr` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value, `%align_out` is the updated alignment + state, and `%base_out` is the post-update base pointer state exposed in SSA + form. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. Both the alignment state and the base address + advance across the stream, and the PTO micro Instruction representation exposes those updates as SSA results. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2, %ub2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldx2` + +- **syntax:** `%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. +- **Latency:** **`DINTLV_B32` → 9** cycles on `RV_VLDI`. **`DINTLV_B16` / `DINTLV_B8` → 9** cycles on `RV_VLDI`. **`BDINTLV` → 9** cycles on `RV_VLDI`. + +**Distribution modes:** `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` + +```c +// DINTLV_B32: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +--- + +#### Strided Loads + +##### `pto.vsld` + +- **syntax:** `%result = pto.vsld %source[%offset], "STRIDE" : !pto.ptr -> !pto.vreg` +- **semantics:** Strided load with fixed stride pattern. +- **inputs:** + `%source` is the UB base pointer and `%offset` is the displacement encoded + with the selected fixed stride mode. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + This is a deprecated compatibility family. The selected stride token + determines which sub-elements are read from each source block. +- **Latency:** **9** cycles. + +**Stride modes:** `STRIDE_S3_B16`, `STRIDE_S4_B64`, `STRIDE_S8_B32`, `STRIDE_S2_B64` + +--- + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %offset, %mask : !pto.ptr, i32, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer, `%offset` is the packed stride/control word, + and `%mask` controls which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + `%offset` is not a plain byte displacement; it encodes the block stride and + repeat pattern. If a block is masked off, the corresponding destination block + is zeroed and MUST NOT raise an address overflow exception for that block. +- **Latency:** **9** cycles. + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Byte-granularity indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains per-block byte offsets, + and `%active_lanes` bounds the number of active gathered blocks. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a block gather, not a byte-per-lane gather. `%source` MUST be 32-byte + aligned, each participating offset MUST describe a 32-byte-aligned block, and + inactive blocks are zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i]]; // byte-addressed +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. Narrowing/packing modes may only preserve a subset of the + source bits. Merge-channel modes reinterpret the source vector as channel + planes and interleave them on store. + +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM_B8/B16/B32` | Contiguous store | `UB[base + i] = src[i]` | **9** cycles | +| `PK_B16/B32` | Pack/narrowing store | `UB_i16[base + 2*i] = truncate_16(src_i32[i])` | **9** cycles | +| `MRG4CHN_B8` | Merge 4 channels (R,G,B,A → RGBA) | Interleave 4 planes | **9** cycles | +| `MRG2CHN_B8/B16` | Merge 2 channels | Interleave 2 planes | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstx2` + +- **syntax:** `pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. +- **Latency:** **`INTLV_B32` / `INTLV_B16` / `INTLV_B8` → 12** cycles on `RV_VSTI`. + +**Distribution modes:** `INTLV_B8`, `INTLV_B16`, `INTLV_B32` + +```c +// INTLV_B32: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +--- + +#### Strided Stores + +##### `pto.vsst` + +- **syntax:** `pto.vsst %value, %dest[%offset], "STRIDE" : !pto.vreg, !pto.ptr` +- **semantics:** Strided store with fixed stride pattern. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, and `%offset` + / `STRIDE` select the fixed strided layout. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + This is a deprecated compatibility family. The stride token, not the vector + lane number alone, determines which destination elements are written. +- **Latency:** **9** cycles. + +--- + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %offset, %mask : !pto.vreg, !pto.ptr, i32, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the packed stride/control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + `%offset` is a control word, not a plain byte displacement. This is a + deprecated compatibility family kept for surface coverage. +- **Latency:** **9** cycles. + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vsta` + +- **syntax:** `pto.vsta %value, %dest[%offset] : !pto.align, !pto.ptr, index` +- **semantics:** Flush alignment state to memory. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base pointer, + and `%offset` is the flush displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The flush address MUST match the post-updated address expected by the + preceding unaligned-store stream. After the flush, the corresponding store + alignment state is consumed. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family uses the same buffered-tail semantics as `pto.vsta` but keeps the + scalar-offset form explicit. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstu` +- **syntax:** `%align_out, %base_out = pto.vstu %align_in, %base_in, %value, %dest, %mode : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, index -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with explicit threaded alignment/base state. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%base_in` is the current + stream base, `%value` is the vector to store, `%dest` is the UB base pointer, + and `%mode` selects the post-update behavior. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the + post-update base pointer state. +- **constraints and limitations:** + This op models a stateful unaligned-store sequence in SSA form. A final + `pto.vsta` / `pto.vstas` / `pto.vstar` is still required to flush the trailing + buffered bytes. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstus` +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %base_in, %value, %dest, %offset : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, i32 -> !pto.align, !pto.ptr` +- **semantics:** Scalar-offset unaligned store with threaded state. +- **inputs:** + Same roles as `pto.vstu`, but `%offset` is provided explicitly as the scalar + displacement. +- **outputs:** + Updated alignment state and base state. +- **constraints and limitations:** + The same final flush requirement and state-threading constraints as + `pto.vstu` apply. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` +- **syntax:** `%align_out = pto.vstur %align_in, %value, %dest : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Register-update unaligned store form. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%dest` is the UB base pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This op updates only the residual alignment state. A matching flush op is + still required to emit the trailing bytes. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstu` + +- **syntax:** `%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, "MODE" : !pto.align, index, !pto.vreg, !pto.ptr -> !pto.align, index` +- **semantics:** Unaligned store with align + offset state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset_in` is the current + logical byte/element displacement, `%value` is the vector being stored, and + `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated alignment/tail state and `%offset_out` is the + next offset after applying the selected post-update rule. +- **constraints and limitations:** + The alignment state MUST be threaded in program order. A terminating flush + form such as `pto.vstar`/`pto.vstas` is still required to commit the buffered + tail bytes. +- **Latency:** **9** cycles. + +**Mode tokens:** `POST_UPDATE`, `NO_POST_UPDATE` + +--- + +##### `pto.vstus` + +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %offset, %value, %base, "MODE" : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with scalar offset and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the next + base pointer when the lowering chooses a post-update form. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width and update mode MUST match the selected form, and a later + flush op is still required. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + This form exposes only the evolving state; it does not by itself guarantee + that all buffered bytes have reached memory. A compatible final flush is still + required unless the surrounding sequence is known to be complete. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is paired with `f32` +vector compares or selects. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.mask` +- **semantics:** Load predicate register with scalar offset. + +**Distribution modes:** `NORM`, `US`, `DS` + +**Example:** +```mlir +%mask = pto.plds %ub[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask +``` + +--- + +##### `pto.pld` + +- **syntax:** `%result = pto.pld %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with areg offset. + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source, %offset, "DIST" : !pto.ptr, i32 -> !pto.mask` +- **semantics:** Load predicate register with immediate offset. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset] : !pto.mask, !pto.ptr` +- **semantics:** Store predicate register with scalar offset. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0] : !pto.mask, !pto.ptr +``` + +--- + +##### `pto.pst` + +- **syntax:** `pto.pst %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with areg offset. + +**Distribution modes:** `NORM`, `PK` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest, %offset, "DIST" : !pto.mask, !pto.ptr, i32` +- **semantics:** Store predicate register with immediate offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align state update. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0] : !pto.mask, !pto.ptr + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input {position = "POSITION"} : T|!pto.vreg -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source element or scalar position is duplicated. The + current PTO micro Instruction representation models that selector as an attribute rather than a + separate operand. + +```c +for (int i = 0; i < N; i++) + dst[i] = input_scalar_or_element; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate predicate from pattern. + +**Patterns:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate tail mask — first N lanes active. + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate predicate state together with updated scalar state. + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +**Part tokens:** `LOWER`, `HIGHER` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] & src1[i]; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] | src1[i]; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] ^ src1[i]; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = ~src[i]; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +#### Predicate Movement + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src[i]; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +##### `pto.pdintlv_b8` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate deinterleave. + +--- + +##### `pto.pintlv_b16` + +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate interleave. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrsqrt` | `RV_VSQRT` / `RV_VDIV` | **17** / **17** | **22** / **22** | — | +| `pto.vrec` | `RV_VDIV` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. Integer + overflow on the most-negative signed value follows the target-defined + behavior. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vrsqrt` + +- **syntax:** `%result = pto.vrsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds reciprocal-square-root values per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +##### `pto.vrec` + +- **syntax:** `%result = pto.vrec %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the reciprocal per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vbcnt` + +- **syntax:** `%result = pto.vbcnt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = __builtin_popcount(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the population count for each active lane. +- **constraints and limitations:** Integer element types only. The count is + over the source element width, not over the full vector register. + +--- + +##### `pto.vcls` + +- **syntax:** `%result = pto.vcls %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = count_leading_sign_bits(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the leading-sign-bit count per active lane. +- **constraints and limitations:** Integer element types only. This operation is + sign-aware, so signed interpretation matters. + +--- + +#### Movement + +##### `pto.vmov` + +- **syntax:** `%result = pto.vmov %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Vector register copy. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is a copy of the source vector. +- **constraints and limitations:** Predicated `pto.vmov` behaves like a masked + copy, while the unpredicated form behaves like a full-register copy. + +--- + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Reciprocal for division +%sum_rcp = pto.vrec %sum, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/u8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/u8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, it SHOULD be treated as an unsigned integer + operation. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + borrow[i] = (src0[i] < src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%borrow` marks lanes + that borrowed. +- **constraints and limitations:** This operation SHOULD be treated as an + unsigned 32-bit carry-chain family unless and until the verifier states + otherwise. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each active lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Inactive lanes follow the predication + behavior defined for this family. On the current surface, inactive lanes are + treated as zeroing lanes. + +--- + +##### `pto.vsubs` + +- **syntax:** `%result = pto.vsubs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] - scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Integer or floating-point legality depends on + the selected type family in lowering. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common numeric cases. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vands` + +- **syntax:** `%result = pto.vands %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] & scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vors` + +- **syntax:** `%result = pto.vors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] | scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxors` + +- **syntax:** `%result = pto.vxors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] ^ scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **constraints and limitations:** This is the scalar-extended carry-chain + family. Treat it as an unsigned integer operation unless the verifier states a + wider legal domain. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow-in and borrow-out. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - borrow_in[i]; + borrow_out[i] = (src0[i] < src1[i] + borrow_in[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%borrow_in` is the + incoming borrow predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%borrow` is the + borrow-out predicate. +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and SHOULD be treated as an unsigned integer operation. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%result` is the destination vector register value. +- `round_mode`, `sat`, and `part` control rounding, saturation, and lane-part + selection in attribute form. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input {round_mode = "ROUND_MODE", sat = "SAT_MODE", part = "PART_MODE"} : !pto.vreg -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + dst[i] = convert(src[i], T0, T1, round_mode); +``` + +- **inputs:** + `%input` is the source vector; attributes select rounding, saturation, and + even/odd placement when the conversion changes width. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. `PART_EVEN` / + `PART_ODD` is only meaningful for width-changing forms that pack two source + streams into one destination register. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `ROUND_R` | Round to nearest, ties to even (default) | +| `ROUND_A` | Round away from zero | +| `ROUND_F` | Round toward negative infinity (floor) | +| `ROUND_C` | Round toward positive infinity (ceil) | +| `ROUND_Z` | Round toward zero (truncate) | +| `ROUND_O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `RS_ENABLE` | Saturate on overflow | +| `RS_DISABLE` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes (for width-changing conversions) + +| Mode | Description | +|------|-------------| +| `PART_EVEN` | Output to even-indexed lanes | +| `PART_ODD` | Output to odd-indexed lanes | + +--- + +##### A5 Supported Conversions + +**Float-Float (vcvtff):** +- f32 ↔ f16 +- f32 ↔ bf16 +- f16 ↔ bf16 + +**Float-Int (vcvtfi):** +- f16 → i16, f16 → i32 +- f32 → i16, f32 → i32 +- bf16 → i32 + +**Int-Float (vcvtif):** +- i16 → f16 +- i32 → f32 + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_ODD"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, "ROUND_MODE" : !pto.vreg -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], round_mode); +``` + +- **inputs:** + `%input` is the floating-point source vector and `ROUND_MODE` selects the + truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `ROUND_O` is supported for avoiding + double-rounding errors during staged conversions. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, "ROUND_R" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled {round_mode = "ROUND_R", sat = "RS_ENABLE"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input {round_mode = "ROUND_R"} + : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, "ROUND_F" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. Result value + index in lane 0. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst_val[0] = mx; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** This family computes both the extremum and + location information, but the exact packing of that information into the + destination vector depends on the chosen form. If all predicate bits are zero, + the result follows the zero-filled convention. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. Result value + index in lane 0. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst_val[0] = mn; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** As with `pto.vcmax`, the exact value/index + packing depends on the chosen form and MUST be preserved consistently. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; // reversed from vsel +``` + +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This family preserves reversed-select + semantics. If the concrete lowering uses an implicit predicate source, that + predicate source MUST be documented by the surrounding IR pattern. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Slide / Shift + +##### `pto.vslide` + +- **syntax:** `%result = pto.vslide %src0, %src1, %amt : !pto.vreg, !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Concatenate two vectors and extract N-element window at offset. + +```c +// Conceptually: tmp[0..2N-1] = {src1, src0} +// dst[i] = tmp[amt + i] +if (amt >= 0) + for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src0[i - amt] : src1[N - amt + i]; +``` + +**Use case:** Sliding window operations, shift register patterns. + +- **inputs:** `%src0` and `%src1` provide the concatenated source window and + `%amt` selects the extraction offset. +- **outputs:** `%result` is the extracted destination window. +- **constraints and limitations:** `pto.vslide` operates on the logical + concatenation of `%src1` and `%src0`. The source order and extraction offset + MUST be preserved exactly. + +--- + +##### `pto.vshift` + +- **syntax:** `%result = pto.vshift %src, %amt : !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Single-source slide (shift with zero fill). + +```c +for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src[i - amt] : 0; +``` + +- **inputs:** `%src` is the source vector and `%amt` is the slide amount. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** This surface represents the single-source + slide/shift family. Zero-fill versus other fill behavior MUST match the + selected form. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %mask : !pto.mask -> !pto.vreg` +- **semantics:** Expand — scatter front elements to active positions. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src_front[j++]; + else dst[i] = 0; +``` + +- **inputs:** `%mask` is the expansion/placement predicate. +- **outputs:** `%result` is the expanded vector image. +- **constraints and limitations:** The source-front stream is implicit in the + current surface. Lane placement for active and inactive positions MUST be + preserved exactly. + +--- + +#### Permutation + +##### `pto.vperm` + +- **syntax:** `%result = pto.vperm %src, %index : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** In-register permute (table lookup). **Not** memory gather. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[index[i] % N]; +``` + +**Note:** This operates on register contents, unlike `pto.vgather2` which reads from UB memory. + +- **inputs:** `%src` is the source vector and `%index` supplies per-lane source + indices. +- **outputs:** `%result` is the permuted vector. +- **constraints and limitations:** This is an in-register permutation family. + `%index` values outside the legal range follow the wrap/clamp behavior of the + selected form. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Register select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; +``` + +- **inputs:** `%src0` and `%src1` are source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src0, %src1, %part : !pto.vreg, !pto.vreg, index -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrowing pack — two wide vectors to one narrow vector. + +```c +// e.g., two vreg<64xi32> → one vreg<128xi16> +for (int i = 0; i < N; i++) { + dst[i] = truncate(src0[i]); + dst[N + i] = truncate(src1[i]); +} +``` + +- **inputs:** `%src0` and `%src1` are wide source vectors and `%part` selects + the packing submode. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion. Source + values that do not fit the destination width follow the truncation semantics + of the selected packing mode. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Sliding window sum +%prev_window = pto.vslide %curr, %prev, %c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, i16 -> !pto.vreg<64xf32> +%window_sum = pto.vadd %curr, %prev_window, %all + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide0_i32, %wide1_i32, %c0 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, index -> !pto.vreg<128xi16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. +- **outputs:** `%result` is the fused `exp(input - max)` vector. +- **constraints and limitations:** Floating-point element types only. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaddrelu` + +- **syntax:** `%result = pto.vaddrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused add + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] + src1[i], 0); +``` + +- **inputs:** `%lhs` and `%rhs` are the two addends. +- **outputs:** `%result` is the fused add-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vsubrelu` + +- **syntax:** `%result = pto.vsubrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused sub + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] - src1[i], 0); +``` + +- **inputs:** `%lhs` is the minuend and `%rhs` is the subtrahend. +- **outputs:** `%result` is the fused sub-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaddreluconv` + +- **syntax:** `%result = pto.vaddreluconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused add + ReLU + type conversion (HW fusion). + +```c +// f32→f16 variant: +for (int i = 0; i < 64; i++) + dst_f16[i] = f32_to_f16(max(src0_f32[i] + src1_f32[i], 0)); + +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(max(src0_f16[i] + src1_f16[i], 0)); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused add/ReLU/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. Rounding, saturation, and packing rules follow the + semantics of this fused operation, not an arbitrary sequence of standalone + ops. + +--- + +##### `pto.vmulconv` + +- **syntax:** `%result = pto.vmulconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused mul + type conversion (HW fusion). + +```c +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(src0_f16[i] * src1_f16[i]); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused mul/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/u32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### UB-to-UB Operations + +##### `pto.vtranspose` + +- **syntax:** `pto.vtranspose %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** UB-to-UB transpose operation (not vreg-to-vreg). + +**Note:** This operates on UB memory directly, not on vector registers. + +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is not a `vreg -> vreg` op even though + it lives in the `pto.v*` namespace. Its correctness depends on the control + word and UB layout contract. + +--- + +#### Sorting Operations + +##### `pto.vsort32` + +- **syntax:** `pto.vsort32 %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** Sort 32 elements in UB. +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is a UB-to-UB accelerator helper, not a + pure vector-register op. + +--- + +##### `pto.vmrgsort` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr x4, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. This page uses the shorter mnemonic + `pto.vmrgsort`, while the current implementation summary still refers to + `pto.vmrgsort4`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Fused residual add + ReLU +%residual = pto.vaddrelu %conv_out, %skip_connection : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `u8` | 8 | 256 | Signed/unsigned 8-bit integer | +| `i16` / `u16` | 16 | 128 | Signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `u32` | 32 | 64 | Signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `u64` | 64 | 32 | Signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Fused add + ReLU +%fused = pto.vaddrelu %a, %b : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC_*` dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_*` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 5a957abdc..3fb9b2c70 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -29,6 +29,7 @@ BarrierType, DeinterleaveDist, EVENT, + InterleaveDist, PIPE, Event, InterleaveDist, @@ -65,9 +66,17 @@ get_lanes, i1, i8, + si8, + ui8, i16, + si16, + ui16, i32, + si32, + ui32, i64, + si64, + ui64, mask_b8, mask_b16, mask_b32, @@ -110,6 +119,8 @@ "MaskPattern", "PAT", "BarrierType", + "DeinterleaveDist", + "InterleaveDist", "PadMode", "PositionMode", "OrderMode", @@ -122,9 +133,17 @@ "TileSpecialization", "i1", "i8", + "si8", + "ui8", "i16", + "si16", + "ui16", "i32", + "si32", + "ui32", "i64", + "si64", + "ui64", "f16", "bf16", "f32", diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index 3db9cf225..a01ef7a3c 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -31,7 +31,23 @@ def _populate_dtype_map() -> None: from . import types as _t - for name in ("f16", "bf16", "f32", "i8", "i16", "i32", "i64"): + for name in ( + "f16", + "bf16", + "f32", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + ): obj = getattr(_t, name, None) if isinstance(obj, ScalarType): _DTYPE_MAP[name] = obj diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index f0a5dc96c..ee3008e40 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -801,6 +801,8 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo "BarrierType", "MemorySpace", "PadMode", + "DeinterleaveDist", + "InterleaveDist", "PositionMode", "OrderMode", "BLayout", diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 784f328a3..3e4fbc93e 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -34,6 +34,7 @@ TileSpecialization, TypeVariable, WildcardType, + is_integer_dtype, ) from .frontend_ast import _DMA_CALL_KEYWORDS, build_frontend_kernel_node from .lowering import lower_semantic_kernel @@ -1606,7 +1607,7 @@ def _matches_wildcard(pattern: WildcardType, actual: ScalarType | MaskType) -> b if pattern.name == "AnyFloat": return isinstance(actual, ScalarType) and actual.name in {"f16", "bf16", "f32"} if pattern.name == "AnyInt": - return isinstance(actual, ScalarType) and actual.name.startswith("i") + return isinstance(actual, ScalarType) and is_integer_dtype(actual) if pattern.name == "AnyMask": return isinstance(actual, MaskType) raise TypeError(f"unsupported wildcard matcher {pattern.name!r}") diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index c962c1e33..634102753 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -40,6 +40,7 @@ SemanticPtrType, SemanticReturnStmt, SemanticRlsBufStmt, + SemanticScalarStoreStmt, SemanticScalarType, SemanticSetCrossCoreStmt, SemanticSetFlagStmt, @@ -59,12 +60,24 @@ SemanticTupleExpr, SemanticTupleType, SemanticVRegType, + SemanticVectorPairStoreStmt, SemanticVectorStoreStmt, SemanticWaitFlagDevStmt, SemanticWaitFlagStmt, SemanticWaitIntraCoreStmt, ) -from .types import MaskPattern, MemorySpace, ScalarType, TileConfig, get_lanes, tile_strides +from .types import ( + MaskPattern, + MemorySpace, + ScalarType, + TileConfig, + bytewidth, + get_lanes, + integer_bitwidth, + integer_signedness, + is_integer_dtype, + tile_strides, +) _I1_TYPE = SemanticScalarType(dtype=ScalarType("i1")) @@ -404,6 +417,10 @@ def _render_stmt( return self._render_dma_store(stmt, env, indent=indent) if isinstance(stmt, SemanticVectorStoreStmt): return self._render_vector_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticVectorPairStoreStmt): + return self._render_vector_pair_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticScalarStoreStmt): + return self._render_scalar_store(stmt, env, indent=indent) if isinstance(stmt, SemanticSetFlagStmt): return [ self._indent(indent) @@ -834,6 +851,41 @@ def _render_multi_result_assign( ) for target, result_type in zip(stmt.targets, stmt.value.type.elements): env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) + + if stmt.value.name == "vldsx2": + lines = [] + source = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_memref(source, indent=indent, into=lines) + index_args = stmt.value.args[1:-1] + if ( + isinstance(stmt.value.args[0].type, SemanticTileType) + and stmt.value.args[0].type.rank == 2 + and len(index_args) == 2 + ): + source = self._materialize_rank2_tile_subview( + source, + stmt.value.args[0].type, + index_args, + env, + indent=indent, + into=lines, + ) + rendered_indices = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_indices = self._render_index_list(index_args, env, indent=indent, into=lines) + dist = self._render_string_literal(stmt.value.args[-1]) + low_target, high_target = stmt.targets + low_type, high_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{low_target.ssa_name}, {high_target.ssa_name} = pto.vldsx2 " + + f"{source.name}[{rendered_indices}], {dist} : " + + f"{self._render_type(source.type)}, index -> " + + f"{self._render_type(low_type)}, {self._render_type(high_type)}" + ) + env[low_target.name] = _RenderedValue(name=low_target.ssa_name, type=low_type) + env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) return lines raise NotImplementedError( @@ -975,6 +1027,64 @@ def _render_vector_store( ) return lines + def _render_vector_pair_store( + self, + stmt: SemanticVectorPairStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + low = self._lower_expr(stmt.low, env, indent=indent, into=lines) + high = self._lower_expr(stmt.high, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + if isinstance(destination.type, SemanticTileType): + destination = self._materialize_tile_memref(destination, indent=indent, into=lines) + if ( + isinstance(stmt.destination.type, SemanticTileType) + and stmt.destination.type.rank == 2 + and len(stmt.indices) == 2 + ): + destination = self._materialize_rank2_tile_subview( + destination, + stmt.destination.type, + stmt.indices, + env, + indent=indent, + into=lines, + ) + rendered_indices = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_indices = self._render_index_list(stmt.indices, env, indent=indent, into=lines) + dist = self._render_string_literal(stmt.dist) + mask = self._lower_expr(stmt.mask, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + "pto.vstsx2 " + + f"{low.name}, {high.name}, {destination.name}[{rendered_indices}], {dist}, {mask.name} : " + + f"{self._render_type(low.type)}, {self._render_type(high.type)}, " + + f"{self._render_type(destination.type)}, {self._render_type(mask.type)}" + ) + return lines + + def _render_scalar_store( + self, + stmt: SemanticScalarStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + offset = self._lower_expr(stmt.offset, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.store_scalar {value.name}, {destination.name}[{offset.name}] : " + + f"{self._render_type(destination.type)}, {self._render_type(value.type)}" + ) + return lines + def _render_index_list( self, indices: tuple[SemanticExpr, ...], @@ -1759,20 +1869,22 @@ def _materialize_tile_window_extent( ) def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: - if dtype.name in {"f32", "i32"}: + int_bits = integer_bitwidth(dtype) + if dtype.name == "f32" or int_bits == 32: return "b32" - if dtype.name in {"f16", "bf16", "i16"}: + if dtype.name in {"f16", "bf16"} or int_bits == 16: return "b16" - if dtype.name == "i8": + if int_bits == 8: return "b8" raise NotImplementedError(f"dtype `{dtype.name}` is not supported by DMA load prefill lowering") def _broadcast_dist_for_dtype(self, dtype: ScalarType) -> str: - if dtype.name in {"f32", "i32"}: + int_bits = integer_bitwidth(dtype) + if dtype.name == "f32" or int_bits == 32: return "BRC_B32" - if dtype.name in {"f16", "bf16", "i16"}: + if dtype.name in {"f16", "bf16"} or int_bits == 16: return "BRC_B16" - if dtype.name == "i8": + if int_bits == 8: return "BRC_B8" raise NotImplementedError(f"dtype `{dtype.name}` is not supported by DMA load broadcast lowering") @@ -2737,6 +2849,24 @@ def _lower_call_expr( + f"{result_name} = pto.vstur {align.name}, {value.name}, {base.name}, {mode} : " + f"{self._render_type(align.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} -> {self._render_type(expr.type)}" ) + + if expr.name == "load_scalar": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.load_scalar {source.name}[{offset.name}] : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in { + "get_block_idx", + "get_subblock_idx", + "get_block_num", + "get_subblock_num", + }: + into.append(self._indent(indent) + f"{result_name} = pto.{expr.name}") return _RenderedValue(name=result_name, type=expr.type) if expr.name == "vbr": @@ -2750,11 +2880,12 @@ def _lower_call_expr( if expr.name == "vdup": value = self._lower_expr(expr.args[0], env, indent=indent, into=into) - position = self._render_string_literal(expr.args[1]) + mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) + position = self._render_string_literal(expr.args[2]) into.append( self._indent(indent) - + f"{result_name} = pto.vdup {value.name} {{position = {position}}} : " - + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + + f"{result_name} = pto.vdup {value.name}, {mask.name} {{position = {position}}} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) @@ -2797,7 +2928,24 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name in {"i1", "i8", "i16", "i32", "i64", "f16", "bf16", "f32"}: + if expr.name in { + "i1", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + "f16", + "bf16", + "f32", + }: value = self._lower_expr(expr.args[0], env, indent=indent, into=into) return self._coerce_rendered_value(value, expr.type, indent=indent, into=into) @@ -3268,21 +3416,34 @@ def _coerce_rendered_value( indent: int, into: list[str], ) -> _RenderedValue: + def _scalar_int_bits(dtype: ScalarType) -> int | None: + if dtype.name == "i1": + return 1 + return integer_bitwidth(dtype) + + def _scalar_int_sign(dtype: ScalarType) -> str: + sign = integer_signedness(dtype) + return "signless" if sign is None else sign + if type(value.type) is type(target_type) and value.type == target_type: return value if isinstance(value.type, SemanticIndexType) and isinstance(target_type, SemanticScalarType): - if target_type.dtype.name == "i32": + target_int_bits = _scalar_int_bits(target_type.dtype) + target_sign = _scalar_int_sign(target_type.dtype) + if target_int_bits == 32: + op = "arith.index_castui" if target_sign == "unsigned" else "arith.index_cast" cast_name = self._new_temp() into.append( self._indent(indent) - + f"{cast_name} = arith.index_cast {value.name} : index to i32" + + f"{cast_name} = {op} {value.name} : index to {target_type.dtype.name}" ) return _RenderedValue(name=cast_name, type=target_type) - if target_type.dtype.name == "i64": + if target_int_bits == 64: + op = "arith.index_castui" if target_sign in {"signless", "unsigned"} else "arith.index_cast" cast_name = self._new_temp() into.append( self._indent(indent) - + f"{cast_name} = arith.index_castui {value.name} : index to i64" + + f"{cast_name} = {op} {value.name} : index to {target_type.dtype.name}" ) return _RenderedValue(name=cast_name, type=target_type) if target_type.dtype.name in {"f16", "bf16", "f32"}: @@ -3298,25 +3459,32 @@ def _coerce_rendered_value( if src == dst: return value cast_name = self._new_temp() - if src.startswith("i") and dst.startswith("i"): - src_bits = int(src[1:]) - dst_bits = int(dst[1:]) - op = "arith.extsi" if src_bits < dst_bits else "arith.trunci" + src_bits = _scalar_int_bits(value.type.dtype) + dst_bits = _scalar_int_bits(target_type.dtype) + if src_bits is not None and dst_bits is not None: + if src_bits == dst_bits: + op = "arith.bitcast" + elif src_bits < dst_bits: + op = "arith.extui" if _scalar_int_sign(value.type.dtype) == "unsigned" else "arith.extsi" + else: + op = "arith.trunci" into.append( self._indent(indent) + f"{cast_name} = {op} {value.name} : {src} to {dst}" ) return _RenderedValue(name=cast_name, type=target_type) - if src.startswith("i") and dst in {"f16", "bf16", "f32"}: + if src_bits is not None and dst in {"f16", "bf16", "f32"}: + op = "arith.uitofp" if _scalar_int_sign(value.type.dtype) == "unsigned" else "arith.sitofp" into.append( self._indent(indent) - + f"{cast_name} = arith.sitofp {value.name} : {src} to {dst}" + + f"{cast_name} = {op} {value.name} : {src} to {dst}" ) return _RenderedValue(name=cast_name, type=target_type) - if src in {"f16", "bf16", "f32"} and dst.startswith("i"): + if src in {"f16", "bf16", "f32"} and dst_bits is not None: + op = "arith.fptoui" if _scalar_int_sign(target_type.dtype) == "unsigned" else "arith.fptosi" into.append( self._indent(indent) - + f"{cast_name} = arith.fptosi {value.name} : {src} to {dst}" + + f"{cast_name} = {op} {value.name} : {src} to {dst}" ) return _RenderedValue(name=cast_name, type=target_type) if src in {"f16", "bf16", "f32"} and dst in {"f16", "bf16", "f32"}: @@ -3740,19 +3908,10 @@ def _render_tile_buf_dim(self, dim: int | None) -> str: return "?" if dim is None else str(dim) def _dtype_byte_width(self, dtype: ScalarType) -> int: - widths = { - "i8": 1, - "i16": 2, - "i32": 4, - "i64": 8, - "f16": 2, - "bf16": 2, - "f32": 4, - } - width = widths.get(dtype.name) - if width is None: - raise NotImplementedError(f"unsupported DMA dtype '{dtype.name}' in TileLang DSL v1 lowering") - return width + try: + return bytewidth(dtype) + except TypeError as exc: + raise NotImplementedError(f"unsupported DMA dtype '{dtype.name}' in TileLang DSL v1 lowering") from exc def _indent(self, indent: int) -> str: return " " * indent diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 1a56f8311..34b2ddc86 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -77,15 +77,35 @@ i16, i32, i64, + integer_bitwidth, + integer_signedness, + is_float_dtype, + is_integer_dtype, + si8, + si16, + si32, + si64, + ui8, + ui16, + ui32, + ui64, ) _DTYPE_SYMBOLS = { "i1": i1, "i8": i8, + "si8": si8, + "ui8": ui8, "i16": i16, + "si16": si16, + "ui16": ui16, "i32": i32, + "si32": si32, + "ui32": ui32, "i64": i64, + "si64": si64, + "ui64": ui64, "f16": f16, "bf16": bf16, "f32": f32, @@ -104,6 +124,8 @@ _S_LAYOUT_SYMBOLS = {layout.name: layout for layout in SLayout} _PAD_VALUE_SYMBOLS = {pad_value.name: pad_value for pad_value in PadValue} _PAD_MODE_SYMBOLS = {pad_mode.name: pad_mode for pad_mode in PadMode} +_DEINTERLEAVE_DIST_SYMBOLS = dict(DeinterleaveDist.__members__) +_INTERLEAVE_DIST_SYMBOLS = dict(InterleaveDist.__members__) _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} _ORDER_MODE_SYMBOLS = {order_mode.name: order_mode for order_mode in OrderMode} _DEINTERLEAVE_DIST_SYMBOLS = {dist.name: dist for dist in DeinterleaveDist} @@ -190,7 +212,7 @@ } _VECTOR_IMMEDIATE_OPS = {"vshift", "vslide"} _TERNARY_VECTOR_OPS = {"vaxpy", "vmula"} -_MULTI_RESULT_VECTOR_OPS = {"vmull"} +_MULTI_RESULT_VECTOR_OPS = {"vmull", "vldsx2"} _BROADCAST_VECTOR_OPS = {"vbr", "vdup", "vci"} _LOW_LEVEL_DMA_CONFIG_OPS = { "set_loop2_stride_outtoub", @@ -464,6 +486,23 @@ class SemanticVectorStoreStmt(SemanticStmt): mask: SemanticExpr +@dataclass(frozen=True) +class SemanticVectorPairStoreStmt(SemanticStmt): + low: SemanticExpr + high: SemanticExpr + destination: SemanticExpr + indices: tuple[SemanticExpr, ...] + dist: SemanticExpr + mask: SemanticExpr + + +@dataclass(frozen=True) +class SemanticScalarStoreStmt(SemanticStmt): + value: SemanticExpr + destination: SemanticExpr + offset: SemanticExpr + + @dataclass(frozen=True) class SemanticVecscopeStmt(SemanticStmt): body: tuple[SemanticStmt, ...] @@ -964,6 +1003,11 @@ def _frontend_stmt_is_scalar_vecscope_stmt( stmt: FrontendStmtNode, ) -> bool: return isinstance(stmt, FrontendAssignStmt) or ( + isinstance(stmt, FrontendExprStmt) + and isinstance(stmt.expr, FrontendCallExpr) + and stmt.expr.namespace == "pto" + and stmt.expr.name == "store_scalar" + ) or ( isinstance(stmt, FrontendIfStmt) and stmt.is_constexpr ) @@ -1141,6 +1185,8 @@ def _analyze_stmt( ) if self._is_vector_memory_stmt_call(stmt.expr): return self._analyze_vector_memory_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_scalar_store_call(stmt.expr): + return self._analyze_scalar_store_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) expr = self._analyze_expr(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) return SemanticExprStmt(expr=expr), dict(env) if isinstance(stmt, FrontendReturnStmt): @@ -1328,7 +1374,14 @@ def _is_vector_memory_stmt_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) and expr.namespace == "pto" - and expr.name in (_VECTOR_MEMORY_STMT_OPS | _PREDICATE_MEMORY_STMT_OPS) + and expr.name in (_VECTOR_MEMORY_STMT_OPS | _PREDICATE_MEMORY_STMT_OPS | {"vstsx2"}) + ) + + def _is_scalar_store_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name == "store_scalar" ) def _is_sync_call(self, expr: FrontendExprNode) -> bool: @@ -1479,6 +1532,49 @@ def _analyze_vector_memory_stmt( dict(env), ) + if expr.name == "vstsx2": + if len(expr.args) == 5: + low = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + high = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[2], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vstsx2 destination", + ) + dist = self._analyze_expr(expr.args[3], env, allow_outer_lookup=allow_outer_lookup) + mask = self._analyze_expr(expr.args[4], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 6: + raise TypeError("pto.vstsx2 expects 5 or 6 positional arguments in TileLang DSL v1") + low, high, destination, offset, dist, mask = args + indices = (offset,) + low_type = self._require_vreg_expr(low, "pto.vstsx2 low") + high_type = self._require_vreg_expr(high, "pto.vstsx2 high") + if low_type != high_type: + raise TypeError("pto.vstsx2 requires low/high vectors to use the same vector type") + self._require_vector_pointer_expr(destination, "pto.vstsx2 destination") + for index in indices: + self._require_index_typed_expr(index) + dist = self._normalize_vstsx2_dist(dist) + self._require_mask_for_vreg(mask, low_type, "pto.vstsx2") + self._require_matching_vector_pointer(low_type, destination.type, "pto.vstsx2") + return ( + SemanticVectorPairStoreStmt( + low=low, + high=high, + destination=destination, + indices=indices, + dist=dist, + mask=mask, + ), + dict(env), + ) + analyzed = self._analyze_vector_memory_stmt_call( expr, env, @@ -1703,9 +1799,38 @@ def _analyze_vector_memory_stmt_call( args=(value, destination, *indices), type=SemanticMetaType(kind="void"), ) - raise ValueError(f"unsupported vector-memory stmt pto.{expr.name}") + def _analyze_scalar_store_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 3: + raise TypeError("pto.store_scalar expects exactly 3 positional arguments in TileLang DSL v1") + if isinstance(args[0].type, SemanticPtrType): + destination = self._require_pointer_expr(args[0], "pto.store_scalar destination") + offset = args[1] + value = args[2] + else: + value = args[0] + destination = self._require_pointer_expr(args[1], "pto.store_scalar destination") + offset = args[2] + self._require_index_typed_expr(offset) + value_type = self._require_scalar_expr(value, "pto.store_scalar value") + if value_type.dtype != destination.type.element_dtype: + raise TypeError("pto.store_scalar value dtype must match destination pointer element dtype") + return ( + SemanticScalarStoreStmt(value=value, destination=destination, offset=offset), + dict(env), + ) + def _analyze_sync_stmt( self, expr: FrontendCallExpr, @@ -2764,6 +2889,15 @@ def _analyze_expr( ) dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) return self._analyze_vldx2((base, *indices, dist)) + if expr.namespace == "pto" and expr.name == "vldsx2" and len(expr.args) == 2: + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vldsx2 source", + ) + dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_vldsx2((base, *indices, dist)) if ( expr.namespace == "pto" and expr.name == "vsld" @@ -2795,6 +2929,15 @@ def _analyze_expr( for arg in expr.args[1:] ) return self._analyze_predicate_memory_expr_op("plds", (base, *indices, *extra_args)) + if expr.namespace == "pto" and expr.name == "vldsx2" and len(expr.args) == 2: + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vldsx2 source", + ) + dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_vldsx2((base, *indices, dist)) if expr.keywords: return self._analyze_keyword_call_expr( expr, @@ -2910,6 +3053,24 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pad_mode, type=SemanticMetaType(kind="pad_mode"), ) + if expr.namespace in {"DeinterleaveDist", "pto.DeinterleaveDist"}: + dist = _DEINTERLEAVE_DIST_SYMBOLS.get(expr.name) + if dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dist, + type=SemanticMetaType(kind="deinterleave_dist"), + ) + if expr.namespace in {"InterleaveDist", "pto.InterleaveDist"}: + dist = _INTERLEAVE_DIST_SYMBOLS.get(expr.name) + if dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dist, + type=SemanticMetaType(kind="interleave_dist"), + ) if expr.namespace in {"PositionMode", "pto.PositionMode"}: position_mode = _POSITION_MODE_SYMBOLS.get(expr.name) if position_mode is not None: @@ -3443,6 +3604,17 @@ def _analyze_call_expr( return self._analyze_vstus(args) if name == "vstur": return self._analyze_vstur(args) + if name in { + "get_block_idx", + "get_subblock_idx", + "get_block_num", + "get_subblock_num", + }: + return self._analyze_runtime_block_query(name, args) + if name == "vldsx2": + return self._analyze_vldsx2(args) + if name == "load_scalar": + return self._analyze_load_scalar(args) if name in {"ppack", "punpack"}: return self._analyze_mask_part_op(name, args) if name in {"pnot", "psel"} | _PREDICATE_BINARY_LOGIC_OPS: @@ -3700,7 +3872,7 @@ def _analyze_scalar_constructor( return SemanticLiteralExpr(value=bool(literal_value), type=SemanticScalarType(dtype=i1)) if isinstance(literal_value, float): return SemanticLiteralExpr(value=bool(literal_value), type=SemanticScalarType(dtype=i1)) - elif target_dtype.name.startswith("i"): + elif is_integer_dtype(target_dtype): if isinstance(literal_value, bool): casted = int(literal_value) elif isinstance(literal_value, (int, float)): @@ -3708,9 +3880,15 @@ def _analyze_scalar_constructor( else: casted = None if casted is not None: - bits = int(target_dtype.name[1:]) - min_value = -(1 << (bits - 1)) - max_value = (1 << (bits - 1)) - 1 + bits = integer_bitwidth(target_dtype) + signedness = integer_signedness(target_dtype) + assert bits is not None + if signedness == "unsigned": + min_value = 0 + max_value = (1 << bits) - 1 + else: + min_value = -(1 << (bits - 1)) + max_value = (1 << (bits - 1)) - 1 if casted < min_value or casted > max_value: raise TypeError( f"pto.{name} value {casted} is out of range for {target_dtype.name} in TileLang DSL v1" @@ -4025,6 +4203,66 @@ def _analyze_vstur(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: type=_ALIGN_TYPE, ) + def _analyze_vldsx2(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) not in {3, 4}: + raise TypeError("pto.vldsx2 expects 3 or 4 positional arguments in TileLang DSL v1") + source, *rest = args + if len(rest) == 2: + index_args = rest[:1] + dist = rest[1] + else: + index_args = rest[:2] + dist = rest[2] + source_type = source.type + if isinstance(source_type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vldsx2 source") + else: + source = self._require_pointer_expr(source, "pto.vldsx2 source", memory_space="ub") + for index in index_args: + self._require_index_typed_expr(index) + dist = self._normalize_vldsx2_dist(dist) + vreg_type = self._vreg_type_for_dtype(source.type.element_dtype) + return SemanticCallExpr( + namespace="pto", + name="vldsx2", + args=(source, *index_args, dist), + type=SemanticTupleType(elements=(vreg_type, vreg_type)), + ) + + def _analyze_load_scalar(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) == 2: + destination_dtype = None + pointer, offset = args + elif len(args) == 3: + destination_dtype = self._require_dtype_symbol(args[0], "pto.load_scalar result type") + pointer, offset = args[1:] + else: + raise TypeError("pto.load_scalar expects 2 or 3 positional arguments in TileLang DSL v1") + pointer = self._require_pointer_expr(pointer, "pto.load_scalar source") + self._require_index_typed_expr(offset) + if destination_dtype is not None and destination_dtype != pointer.type.element_dtype: + raise TypeError("pto.load_scalar result type must match source pointer element dtype") + return SemanticCallExpr( + namespace="pto", + name="load_scalar", + args=(pointer, offset), + type=SemanticScalarType(dtype=pointer.type.element_dtype), + ) + + def _analyze_runtime_block_query( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if args: + raise TypeError(f"pto.{name} does not accept positional arguments in TileLang DSL v1") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(), + type=SemanticScalarType(dtype=i64), + ) + def _analyze_broadcast_vector_op( self, name: str, @@ -4038,19 +4276,21 @@ def _analyze_broadcast_vector_op( return SemanticCallExpr(namespace="pto", name=name, args=args, type=vec_type) if name == "vdup": - if len(args) not in {1, 2}: - raise TypeError("pto.vdup expects 1 or 2 positional arguments in TileLang DSL v1") + if len(args) not in {2, 3}: + raise TypeError("pto.vdup expects 2 or 3 positional arguments in TileLang DSL v1") value = args[0] if isinstance(value.type, SemanticVRegType): vec_type = value.type else: vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vdup input") - position_arg = args[1] if len(args) == 2 else None + mask = args[1] + self._require_mask_for_vreg(mask, vec_type, "pto.vdup") + position_arg = args[2] if len(args) == 3 else None position = self._normalize_position_mode(position_arg, "pto.vdup position") return SemanticCallExpr( namespace="pto", name=name, - args=(value, position), + args=(value, mask, position), type=vec_type, ) @@ -4059,8 +4299,8 @@ def _analyze_broadcast_vector_op( raise TypeError("pto.vci expects 1 or 2 positional arguments in TileLang DSL v1") index = self._require_scalar_or_index_expr(args[0], "pto.vci index") index_dtype = i32 if isinstance(index.type, SemanticIndexType) else index.type.dtype - if index_dtype.name not in {"i8", "i16", "i32"}: - raise TypeError("pto.vci index only supports i8/i16/i32 in TileLang DSL v1") + if not (is_integer_dtype(index_dtype) and integer_bitwidth(index_dtype) in {8, 16, 32}): + raise TypeError("pto.vci index only supports 8/16/32-bit integer dtypes in TileLang DSL v1") order_arg = args[1] if len(args) == 2 else None order = self._normalize_order_mode(order_arg, "pto.vci order") return SemanticCallExpr( @@ -4151,7 +4391,7 @@ def _analyze_binary_vector_op( lhs = self._require_vreg_expr(lhs_expr, f"pto.{name} lhs") rhs = self._require_vreg_expr(rhs_expr, f"pto.{name} rhs") if name == "vperm": - if rhs.element_dtype.name not in {"i8", "i16", "i32"}: + if not (is_integer_dtype(rhs.element_dtype) and integer_bitwidth(rhs.element_dtype) in {8, 16, 32}): raise TypeError("pto.vperm indices vector only supports integer vector dtypes in TileLang DSL v1") if lhs.lanes != rhs.lanes: raise TypeError("pto.vperm requires data/indices vectors to use the same lane width") @@ -4186,8 +4426,10 @@ def _analyze_vector_immediate_op( raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") vector = self._require_vreg_expr(args[0], f"pto.{name} vector") immediate = self._require_scalar_or_index_expr(args[1], f"pto.{name} immediate") - if isinstance(immediate.type, SemanticScalarType) and immediate.type.dtype.name not in {"i8", "i16", "i32"}: - raise TypeError(f"pto.{name} immediate only supports i8/i16/i32 in TileLang DSL v1") + if isinstance(immediate.type, SemanticScalarType) and not ( + is_integer_dtype(immediate.type.dtype) and integer_bitwidth(immediate.type.dtype) in {8, 16, 32} + ): + raise TypeError(f"pto.{name} immediate only supports 8/16/32-bit integer dtypes in TileLang DSL v1") self._require_mask_for_vreg(args[2], vector, f"pto.{name}") self._validate_vector_immediate_dtype(name, vector.element_dtype) return SemanticCallExpr(namespace="pto", name=name, args=args, type=vector) @@ -4605,9 +4847,11 @@ def _normalize_position_mode( ): return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) position = self._require_string_expr(expr, context) - if position != PositionMode.LOWEST.value: + if position == "POS_LOWEST": + position = PositionMode.LOWEST.value + if position not in {PositionMode.LOWEST.value, PositionMode.HIGHEST.value}: raise TypeError( - "pto.vdup currently only supports position `PositionMode.LOWEST` in TileLang DSL v1" + "pto.vdup position must be `PositionMode.LOWEST` or `PositionMode.HIGHEST` in TileLang DSL v1" ) return SemanticLiteralExpr(value=position, type=SemanticMetaType(kind="string")) @@ -4809,26 +5053,76 @@ def _require_matching_vector_pointer( return raise TypeError(f"{context} requires a Tile or pointer destination in TileLang DSL") + def _normalize_vldsx2_dist(self, expr: SemanticExpr) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "deinterleave_dist" + and isinstance(expr.value, DeinterleaveDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "deinterleave_dist" + and isinstance(expr.binding.value, DeinterleaveDist) + ): + dist = expr.binding.value.value + else: + dist = self._require_string_expr(expr, "pto.vldsx2 dist") + legacy_map = { + "DINTLV_B8": "DINTLV", + "DINTLV_B16": "DINTLV", + "DINTLV_B32": "DINTLV", + "BD": "BDINTLV", + } + normalized = legacy_map.get(dist, dist) + if normalized not in {"DINTLV", "BDINTLV"}: + raise TypeError( + "pto.vldsx2 dist must be one of \"DINTLV\" or \"BDINTLV\" in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) + + def _normalize_vstsx2_dist(self, expr: SemanticExpr) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "interleave_dist" + and isinstance(expr.value, InterleaveDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "interleave_dist" + and isinstance(expr.binding.value, InterleaveDist) + ): + dist = expr.binding.value.value + else: + dist = self._require_string_expr(expr, "pto.vstsx2 dist") + legacy_map = { + "INTLV_B8": "INTLV", + "INTLV_B16": "INTLV", + "INTLV_B32": "INTLV", + } + normalized = legacy_map.get(dist, dist) + if normalized != "INTLV": + raise TypeError("pto.vstsx2 dist must be \"INTLV\" in TileLang DSL v1") + return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) + def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: - if dtype.name in {"f32", "i32"}: + int_bits = integer_bitwidth(dtype) + if dtype.name == "f32" or int_bits == 32: return "b32" - if dtype.name in {"f16", "bf16", "i16"}: + if dtype.name in {"f16", "bf16"} or int_bits == 16: return "b16" - if dtype.name == "i8": + if int_bits == 8: return "b8" raise TypeError(f"dtype `{dtype.name}` is not supported by make_mask/vector lowering in TileLang DSL v1") def _vreg_type_for_dtype(self, dtype: ScalarType) -> SemanticVRegType: - byte_widths = { - "i8": 1, - "i16": 2, - "i32": 4, - "f16": 2, - "bf16": 2, - "f32": 4, - } - width = byte_widths.get(dtype.name) - if width is None: + width = bytewidth(dtype) + if width not in {1, 2, 4}: raise TypeError(f"dtype `{dtype.name}` is not supported by vlds/vsts in TileLang DSL v1") return SemanticVRegType(element_dtype=dtype, lanes=256 // width) @@ -4837,9 +5131,13 @@ def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: raise TypeError(f"pto.{name} only supports f16/f32 in TileLang DSL v1") if name == "vrelu" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vrelu only supports f16/f32 in TileLang DSL v1") - if name in {"vnot", "vbcnt", "vcls", "vsunpack", "vzunpack", "vusqz", "vsqz"} and dtype.name not in {"i8", "i16", "i32"}: + if name in {"vnot", "vbcnt", "vcls", "vsunpack", "vzunpack", "vusqz", "vsqz"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") - if name in {"vabs", "vneg", "vmov", "vtrc", "vbitsort", "vcadd", "vcmax", "vcmin"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + if name in {"vabs", "vneg", "vmov", "vtrc", "vbitsort", "vcadd", "vcmax", "vcmin"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: @@ -4849,17 +5147,29 @@ def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: raise TypeError("pto.vprelu only supports f16/f32 in TileLang DSL v1") if name in {"vaddreluconv", "vmulconv"} and dtype.name not in {"f16", "bf16", "f32"}: raise TypeError(f"pto.{name} only supports f16/bf16/f32 in TileLang DSL v1") - if name in {"vand", "vor", "vxor"} and dtype.name not in {"i8", "i16", "i32"}: + if name in {"vand", "vor", "vxor"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") - if name in {"vshl", "vshr"} and dtype.name not in {"i8", "i16", "i32"}: + if name in {"vshl", "vshr"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") - if name == "vmul" and dtype.name not in {"i16", "i32", "f16", "f32"}: - raise TypeError("pto.vmul only supports i16/i32/f16/f32 in TileLang DSL v1") - if name == "vperm" and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + if name == "vmul" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {16, 32}) or dtype.name in {"f16", "f32"} + ): + raise TypeError("pto.vmul only supports 16/32-bit integer families and f16/f32 in TileLang DSL v1") + if name == "vperm" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): raise TypeError("pto.vperm does not support this data vector dtype in TileLang DSL v1") - if name in {"vadd", "vsub", "vmax", "vmin", "vaddrelu", "vsubrelu"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + if name in {"vadd", "vsub", "vmax", "vmin", "vaddrelu", "vsubrelu"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") - if name in {"vpack", "vmrgsort"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + if name in {"vpack", "vmrgsort"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") def _validate_vector_scalar_dtype(self, name: str, dtype: ScalarType) -> None: @@ -4867,26 +5177,38 @@ def _validate_vector_scalar_dtype(self, name: str, dtype: ScalarType) -> None: raise TypeError("pto.vdivs only supports f16/f32 in TileLang DSL v1") if name == "vlrelu" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vlrelu only supports f16/f32 in TileLang DSL v1") - if name in {"vshls", "vshrs"} and dtype.name not in {"i8", "i16", "i32"}: + if name in {"vshls", "vshrs"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") - if name in {"vands", "vors", "vxors"} and dtype.name not in {"i8", "i16", "i32"}: + if name in {"vands", "vors", "vxors"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") - if name in {"vadds", "vsubs", "vmuls", "vmaxs", "vmins"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + if name in {"vadds", "vsubs", "vmuls", "vmaxs", "vmins"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") def _validate_vector_immediate_dtype(self, name: str, dtype: ScalarType) -> None: - if name in {"vshift", "vslide"} and dtype.name not in {"i8", "i16", "i32", "f16", "bf16", "f32"}: + if name in {"vshift", "vslide"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): raise TypeError(f"pto.{name} does not support this vector dtype in TileLang DSL v1") def _validate_ternary_vector_dtype(self, name: str, dtype: ScalarType) -> None: - if name == "vaxpy" and dtype.name not in {"i16", "i32", "f16", "f32"}: - raise TypeError("pto.vaxpy only supports i16/i32/f16/f32 in TileLang DSL v1") - if name == "vmula" and dtype.name not in {"i16", "i32", "f16", "f32"}: - raise TypeError("pto.vmula only supports i16/i32/f16/f32 in TileLang DSL v1") + if name == "vaxpy" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {16, 32}) or dtype.name in {"f16", "f32"} + ): + raise TypeError("pto.vaxpy only supports 16/32-bit integer families and f16/f32 in TileLang DSL v1") + if name == "vmula" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {16, 32}) or dtype.name in {"f16", "f32"} + ): + raise TypeError("pto.vmula only supports 16/32-bit integer families and f16/f32 in TileLang DSL v1") def _validate_multi_result_vector_dtype(self, name: str, dtype: ScalarType) -> None: - if name == "vmull" and dtype.name != "i32": - raise TypeError("pto.vmull only supports i32 vectors in TileLang DSL v1") + if name == "vmull" and not (is_integer_dtype(dtype) and integer_bitwidth(dtype) == 32): + raise TypeError("pto.vmull only supports 32-bit integer vector families in TileLang DSL v1") def _require_sync_pipe(self, expr: SemanticExpr, context: str) -> str: if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "pipe": diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index fba2f10be..1d4a76c81 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -24,12 +24,24 @@ "vreg", "i1", "i8", + "si8", + "ui8", "i16", + "si16", + "ui16", "i32", + "si32", + "ui32", "i64", + "si64", + "ui64", "f16", "bf16", "f32", + "get_block_idx", + "get_subblock_idx", + "get_block_num", + "get_subblock_num", "set_flag", "wait_flag", "pipe_barrier", @@ -52,8 +64,10 @@ "vldas", "vldus", "vldx2", + "vldsx2", "vsld", "vsts", + "vstsx2", "psts", "vsst", "vstx2", @@ -179,12 +193,14 @@ "ptr", "castptr", "addptr", + "load_scalar", } ) ADVANCED_TOPLEVEL_PTO_CALLS = frozenset( { "strict_vecscope", + "store_scalar", "copy_gm_to_ubuf", "copy_ubuf_to_gm", "copy_ubuf_to_ubuf", diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index ffdfbb7ae..fb7932013 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -23,6 +23,48 @@ def __repr__(self) -> str: return self.name +_INTEGER_DTYPE_WIDTHS = { + "i8": 8, + "si8": 8, + "ui8": 8, + "i16": 16, + "si16": 16, + "ui16": 16, + "i32": 32, + "si32": 32, + "ui32": 32, + "i64": 64, + "si64": 64, + "ui64": 64, +} + +_INTEGER_DTYPE_SIGNS = { + "i8": "signless", + "si8": "signed", + "ui8": "unsigned", + "i16": "signless", + "si16": "signed", + "ui16": "unsigned", + "i32": "signless", + "si32": "signed", + "ui32": "unsigned", + "i64": "signless", + "si64": "signed", + "ui64": "unsigned", +} + +_FLOAT_DTYPE_WIDTHS = { + "f16": 16, + "bf16": 16, + "f32": 32, +} + +_DTYPE_BYTE_WIDTHS = { + name: bits // 8 for name, bits in _INTEGER_DTYPE_WIDTHS.items() +} +_DTYPE_BYTE_WIDTHS.update({name: bits // 8 for name, bits in _FLOAT_DTYPE_WIDTHS.items()}) + + class TensorView: """Bare TensorView annotation marker for TileLang DSL v1.""" @@ -163,15 +205,9 @@ class PadMode(str, Enum): PadValue = "PadValue" -class PositionMode(str, Enum): - LOWEST = "POS_LOWEST" - - -class OrderMode(str, Enum): - ASC = "ORDER_ASC" - - class DeinterleaveDist(str, Enum): + DINTLV = "DINTLV" + BDINTLV = "BDINTLV" B8 = "DINTLV_B8" B16 = "DINTLV_B16" B32 = "DINTLV_B32" @@ -179,11 +215,21 @@ class DeinterleaveDist(str, Enum): class InterleaveDist(str, Enum): + INTLV = "INTLV" B8 = "INTLV_B8" B16 = "INTLV_B16" B32 = "INTLV_B32" +class PositionMode(str, Enum): + LOWEST = "LOWEST" + HIGHEST = "HIGHEST" + + +class OrderMode(str, Enum): + ASC = "ORDER_ASC" + + class PredicateDist(str, Enum): NORM = "NORM" US = "US" @@ -292,11 +338,19 @@ class TileSpecialization: valid_shape: tuple[int | None, ...] | None = None -i8 = ScalarType("i8") i1 = ScalarType("i1") +i8 = ScalarType("i8") +si8 = ScalarType("si8") +ui8 = ScalarType("ui8") i16 = ScalarType("i16") +si16 = ScalarType("si16") +ui16 = ScalarType("ui16") i32 = ScalarType("i32") +si32 = ScalarType("si32") +ui32 = ScalarType("ui32") i64 = ScalarType("i64") +si64 = ScalarType("si64") +ui64 = ScalarType("ui64") f16 = ScalarType("f16") bf16 = ScalarType("bf16") f32 = ScalarType("f32") @@ -333,18 +387,30 @@ def vreg(dtype: ScalarType) -> VRegType: return VRegType(element_dtype=dtype, lanes=get_lanes(dtype)) +def integer_bitwidth(dtype: ScalarType) -> int | None: + if not isinstance(dtype, ScalarType): + return None + return _INTEGER_DTYPE_WIDTHS.get(dtype.name) + + +def integer_signedness(dtype: ScalarType) -> str | None: + if not isinstance(dtype, ScalarType): + return None + return _INTEGER_DTYPE_SIGNS.get(dtype.name) + + +def is_integer_dtype(dtype: ScalarType) -> bool: + return integer_bitwidth(dtype) is not None + + +def is_float_dtype(dtype: ScalarType) -> bool: + return isinstance(dtype, ScalarType) and dtype.name in _FLOAT_DTYPE_WIDTHS + + def bytewidth(dtype: ScalarType) -> int: if not isinstance(dtype, ScalarType): raise TypeError("bytewidth expects a TileLang scalar dtype") - byte_widths = { - "i8": 1, - "i16": 2, - "i32": 4, - "f16": 2, - "bf16": 2, - "f32": 4, - } - width = byte_widths.get(dtype.name) + width = _DTYPE_BYTE_WIDTHS.get(dtype.name) if width is None: raise TypeError(f"dtype `{dtype.name}` is not supported by bytewidth") return width @@ -427,6 +493,8 @@ def tile_layout_descriptor( "PAT", "BarrierType", "PadMode", + "DeinterleaveDist", + "InterleaveDist", "PositionMode", "OrderMode", "DeinterleaveDist", @@ -438,9 +506,17 @@ def tile_layout_descriptor( "TileSpecialization", "i1", "i8", + "si8", + "ui8", "i16", + "si16", + "ui16", "i32", + "si32", + "ui32", "i64", + "si64", + "ui64", "f16", "bf16", "f32", diff --git a/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md b/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md new file mode 100644 index 000000000..38bea7675 --- /dev/null +++ b/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md @@ -0,0 +1,79 @@ +--- +name: auto-update-vpto-spec +description: 自动对齐 TileLang DSL 与最新 VPTO 规范:比较 spec 差异并指导实现、lowering 与测试同步更新。 +license: MIT +--- + +根据最新 VPTO 规范自动更新 TileLang DSL 的规范与实现。 + +--- + +## 输入 + +- 最新规范文件(建议命名:`vpto-latest.md`) +- 当前 DSL 对齐规范文件(建议命名:`vpto-current.md`) + +如果用户没有提供路径,先询问文件路径后再继续。 + +--- + +## 执行步骤 + +1. 读取最新 VPTO 规范 `vpto-latest.md`(如果最新版本来自网络,则先下载保存)。 +2. 读取当前 DSL 使用的规范 `vpto-current.md`。 +3. 对比两者差异并生成差异报告: + - 若无差异:输出“无需更新”,结束。 + - 若有差异:按变更类型分类并执行对应动作(执行动作前向用户询问确认)。 + +4. 差异分类与处理规则: + +### A. 新增 op + +- 更新 DSL spec 中对应章节,添加新 op 的描述、参数、返回值、示例等。 +- 在 DSL 实现中新增该 op(包含前端定义与必要 lowering)。 +- 补齐对应测试用例(语法、语义、lowering/代码生成路径)。 + +### B. 修改 op 语义 + +- 同步修改 DSL spec 对应 op 的语义描述。 +- 评估并更新 DSL 实现语义行为。 +- 增加/更新回归测试覆盖新语义。 + +### C. 修改 op 参数格式 + +- 优先保持 DSL 前端接口不变(向后兼容用户调用方式)。 +- 在 lowering/转换逻辑层吸收格式变化。 +- 增加测试验证旧接口与新规范语义一致。 + +### D. 删除 op + +- 在 DSL spec 中删除对应 op。 +- 在 DSL 实现中将该 op 标记为不受支持,并在用户使用时显式报错。 +- 增加测试验证报错信息清晰可见。 + +5. 统一补充测试: + - 至少覆盖:新增/变更/删除的 golden path。 + - 包含失败路径(非法参数、已删除 op 调用)验证。 + +6. 将vpto-spec-current.md改名为vpto-spec-*.md(如vpto-spec-2024-06.md),并将vpto-latest.md改名为vpto-spec-current.md,保持版本迭代记录。 + +--- + +## 输出要求 + +- 输出变更摘要: + - 差异总览(新增/修改/删除清单) + - 更新的 DSL spec 章节 + - 更新的实现文件 + - 新增/修改的测试文件 +- 若存在无法自动判定的语义映射,先向用户提问后再继续。 + +--- + +## 护栏 + +- 不在未确认语义的情况下擅自改变 DSL 前端接口。 +- 优先在 lowering 层处理规范格式变更。 +- 删除 op 时必须提供显式错误信息与测试覆盖。 +- 若某一差异无法映射到当前 DSL 架构,先报告阻塞点并请求用户决策。 +- 始终先改文档再落实现 diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 2e90b822e..8a9de4fb1 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -45,6 +45,7 @@ SemanticPipeBarrierStmt, SemanticPtrType, SemanticRlsBufStmt, + SemanticScalarStoreStmt, SemanticScalarType, SemanticSetCrossCoreStmt, SemanticSetFlagStmt, @@ -55,6 +56,7 @@ SemanticTensorViewType, SemanticTileType, SemanticVecscopeStmt, + SemanticVectorPairStoreStmt, SemanticVectorStoreStmt, SemanticVRegType, SemanticWaitFlagDevStmt, @@ -93,6 +95,8 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "PAT")) self.assertTrue(hasattr(pto, "PadMode")) self.assertTrue(hasattr(pto, "BarrierType")) + self.assertTrue(hasattr(pto, "DeinterleaveDist")) + self.assertTrue(hasattr(pto, "InterleaveDist")) self.assertTrue(hasattr(pto, "PositionMode")) self.assertTrue(hasattr(pto, "OrderMode")) self.assertTrue(hasattr(pto, "DeinterleaveDist")) @@ -103,6 +107,14 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) self.assertEqual(repr(pto.align), "align") + self.assertTrue(hasattr(pto, "si8")) + self.assertTrue(hasattr(pto, "ui8")) + self.assertTrue(hasattr(pto, "si16")) + self.assertTrue(hasattr(pto, "ui16")) + self.assertTrue(hasattr(pto, "si32")) + self.assertTrue(hasattr(pto, "ui32")) + self.assertTrue(hasattr(pto, "si64")) + self.assertTrue(hasattr(pto, "ui64")) self.assertEqual(pto.BarrierType.VST_VLD.value, "VST_VLD") self.assertEqual(pto.PadMode.PadNull.value, "PadNull") self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") @@ -110,7 +122,11 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.BLayout.ROW_MAJOR.value, "row_major") self.assertEqual(pto.SLayout.NONE_BOX.value, "none_box") self.assertEqual(pto.PadValue.ZERO.value, "zero") - self.assertEqual(pto.PositionMode.LOWEST.value, "POS_LOWEST") + self.assertEqual(pto.PositionMode.LOWEST.value, "LOWEST") + self.assertEqual(pto.DeinterleaveDist.DINTLV.value, "DINTLV") + self.assertEqual(pto.DeinterleaveDist.BDINTLV.value, "BDINTLV") + self.assertEqual(pto.InterleaveDist.INTLV.value, "INTLV") + self.assertEqual(pto.PositionMode.HIGHEST.value, "HIGHEST") self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") self.assertEqual(pto.DeinterleaveDist.B32.value, "DINTLV_B32") self.assertEqual(pto.InterleaveDist.B16.value, "INTLV_B16") @@ -120,6 +136,18 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PredicatePart.HIGHER.value, "HIGHER") self.assertEqual(pto.StrideMode.S4_B64.value, "STRIDE_S4_B64") self.assertEqual(pto.Event.ID31.value, "EVENT_ID31") + self.assertIs(pto.DeinterleaveDist.B32, pto.DeinterleaveDist.DINTLV) + self.assertIs(pto.InterleaveDist.B32, pto.InterleaveDist.INTLV) + self.assertEqual(pto.si8.name, "si8") + self.assertEqual(pto.ui16.name, "ui16") + self.assertEqual(pto.si32.name, "si32") + self.assertEqual(pto.ui64.name, "ui64") + self.assertIsNot(pto.si8, pto.i8) + self.assertIsNot(pto.ui32, pto.i32) + self.assertEqual(pto.bytewidth(pto.si16), 2) + self.assertEqual(pto.bytewidth(pto.ui64), 8) + self.assertEqual(pto.get_lanes(pto.ui32), 64) + self.assertEqual(pto.elements_per_vreg(pto.si8), 256) class TileLangDSLSupportMatrixTests(unittest.TestCase): @@ -164,6 +192,8 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vmuls"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.get_buf"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.rls_buf"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.get_block_idx"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.get_subblock_num"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.mem_bar"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.set_cross_core"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.set_intra_block"), BASIC_TIER) @@ -190,6 +220,8 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vstur"), ADVANCED_TIER) self.assertEqual(get_feature_tier("PredicateDist"), ADVANCED_TIER) self.assertEqual(get_feature_tier("PredicatePart"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.vldsx2"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vstsx2"), BASIC_TIER) self.assertEqual(get_feature_tier("PadMode"), BASIC_TIER) self.assertEqual(get_feature_tier("VRegType"), BASIC_TIER) self.assertEqual(get_feature_tier("MaskType"), BASIC_TIER) @@ -225,6 +257,8 @@ def test_non_stable_surface_groups_keep_advanced_boundaries(self) -> None: self.assertEqual(get_feature_tier("pto.strict_vecscope"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.ptr"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.castptr"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.load_scalar"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.store_scalar"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.copy_ubuf_to_ubuf"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.tile_with_strides"), ADVANCED_TIER) @@ -378,6 +412,15 @@ def kernel(lhs: pto.TensorView, rhs: pto.Tile): self.assertEqual(selected.parameters[0].dtype, pto.f32) self.assertEqual(selected.parameters[1].dtype, pto.i32) + selected_int = pto.select_kernel( + "a5", + "matcher_wildcard_unique", + (pto.ui16, pto.si16), + ) + self.assertEqual(selected_int.dtype_signature, (pto.ui16, pto.si16)) + self.assertEqual(selected_int.parameters[0].dtype, pto.ui16) + self.assertEqual(selected_int.parameters[1].dtype, pto.si16) + def test_select_kernel_enforces_typevar_consistency_per_signature(self) -> None: @pto.vkernel( op="matcher_typevar_unique", @@ -2414,8 +2457,8 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): vec0 = pto.vlds(src, 0) broadcast = pto.vbr(seed) - dup_from_vec = pto.vdup(vec0) - dup_from_scalar = pto.vdup(seed, position=pto.PositionMode.LOWEST) + dup_from_vec = pto.vdup(vec0, all_mask, pto.PositionMode.HIGHEST) + dup_from_scalar = pto.vdup(seed, all_mask) idx0 = pto.vci(seed) idx1 = pto.vci(seed, order=pto.OrderMode.ASC) @@ -2437,12 +2480,13 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): self.assertIn("pto.vci", text) self.assertRegex( text, - r'pto\.vdup\s+%[^\s]+\s+\{position = "POS_LOWEST"\}\s+:', + r'pto\.vdup\s+%[^\s]+,\s+%[^\s]+\s+\{position = "HIGHEST"\}\s+:', ) - self.assertNotRegex( + self.assertRegex( text, - r'pto\.vdup\s+%[^\s]+,\s*"POS_LOWEST"\s+:', + r'pto\.vdup\s+%[^\s]+,\s+%[^\s]+\s+\{position = "LOWEST"\}\s+:', ) + self.assertNotIn('position = "POS_LOWEST"', text) self.assertRegex( text, r'pto\.vci\s+%[^\s]+\s+\{order = "ORDER_ASC"\}\s+:', @@ -2452,6 +2496,36 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): r'pto\.vci\s+%[^\s]+,\s*"ORDER_ASC"\s+:', ) + def test_signed_and_unsigned_integer_dtypes_lower_distinctly(self) -> None: + @pto.vkernel( + op="signed_unsigned_integer_types_unique", + dtypes=[(pto.si16, pto.si16, pto.ui16, pto.ui16)], + advanced=True, + ) + def kernel(dst_s: pto.Tile, src_s: pto.Tile, dst_u: pto.Tile, src_u: pto.Tile): + signed_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + unsigned_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + signed_vec = pto.vlds(src_s, 0) + unsigned_vec = pto.vlds(src_u, 0) + signed_out = pto.vadds(signed_vec, pto.si16(-1), signed_mask) + unsigned_out = pto.vadds(unsigned_vec, pto.ui16(1), unsigned_mask) + pto.vsts(signed_out, dst_s, 0, signed_mask) + pto.vsts(unsigned_out, dst_u, 0, unsigned_mask) + return None + + specialized = kernel.specialize( + dst_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("dtype=si16", text) + self.assertIn("dtype=ui16", text) + self.assertIn("!pto.vreg<128xsi16>", text) + self.assertIn("!pto.vreg<128xui16>", text) + def test_vbr_accepts_float_literal_constant(self) -> None: @pto.vkernel( op="broadcast_float_literal_constant_unique", @@ -3779,6 +3853,81 @@ def kernel( self.assertIn("pto.wait_flag_dev %arg3, %c8_i64 : i64, i64", text) self.assertIn("pto.wait_intra_core %arg4, %c31_i64 : i64, i64", text) + def test_runtime_block_queries_and_scalar_pointer_helpers_lower_to_v0_3_surface(self) -> None: + @pto.vkernel( + op="runtime_block_queries_and_scalar_helpers", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel( + src: pto.ptr(pto.f32, pto.MemorySpace.UB), + dst: pto.ptr(pto.f32, pto.MemorySpace.UB), + ): + block = pto.get_block_idx() + block_num = pto.get_block_num() + subblock = pto.get_subblock_idx() + subblock_num = pto.get_subblock_num() + value = pto.load_scalar(src, 0) + pto.store_scalar(dst, 0, value) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + store_stmt = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticScalarStoreStmt)) + self.assertIsInstance(store_stmt, SemanticScalarStoreStmt) + self.assertEqual(store_stmt.destination.type.element_dtype, pto.f32) + + text = specialized.mlir_text() + self.assertIn("= pto.get_block_idx", text) + self.assertIn("= pto.get_block_num", text) + self.assertIn("= pto.get_subblock_idx", text) + self.assertIn("= pto.get_subblock_num", text) + self.assertIn("= pto.load_scalar %arg0[%c0] : !pto.ptr -> f32", text) + self.assertIn("pto.store_scalar", text) + + def test_vldsx2_and_vstsx2_tile_sugar_lower_with_normalized_dist_tokens(self) -> None: + @pto.vkernel(op="vldsx2_vstsx2_tile_sugar", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + low, high = pto.vldsx2(src[0, 0:], pto.DeinterleaveDist.B32) + pto.vstsx2(low, high, dst[0, 0:], pto.InterleaveDist.B32, mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(1, 128), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(1, 128), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)) + pair_store = next(stmt for stmt in vecscope.body if isinstance(stmt, SemanticVectorPairStoreStmt)) + self.assertIsInstance(pair_store, SemanticVectorPairStoreStmt) + + text = specialized.mlir_text() + self.assertIn("pto.vldsx2", text) + self.assertIn("pto.vstsx2", text) + self.assertIn('"DINTLV"', text) + self.assertIn('"INTLV"', text) + self.assertNotIn("DINTLV_B32", text) + self.assertNotIn("INTLV_B32", text) + + def test_vldsx2_and_vstsx2_still_accept_legacy_string_tokens_for_compatibility(self) -> None: + @pto.vkernel(op="vldsx2_vstsx2_legacy_tokens", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + low, high = pto.vldsx2(src[0, 0:], "DINTLV_B32") + pto.vstsx2(low, high, dst[0, 0:], "INTLV_B32", mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(1, 128), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(1, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn('"DINTLV"', text) + self.assertIn('"INTLV"', text) + def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): From 05231274208b4b3808055fea86bc593b908d85fe Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 13 Apr 2026 15:37:03 +0800 Subject: [PATCH 040/192] Add more memory op support --- tilelang-dsl/docs/unsupported-features.md | 11 +- .../user_guide/09-vector-memory-operations.md | 18 +- tilelang-dsl/python/tilelang_dsl/__init__.py | 23 +- .../python/tilelang_dsl/frontend_ast.py | 25 +- tilelang-dsl/python/tilelang_dsl/lowering.py | 820 +++------ tilelang-dsl/python/tilelang_dsl/semantic.py | 1537 +++++------------ .../python/tilelang_dsl/support_matrix.py | 46 +- tilelang-dsl/python/tilelang_dsl/types.py | 171 +- .../skills/auto-update-vpto-spec/SKILL.md | 11 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 505 ++---- 10 files changed, 777 insertions(+), 2390 deletions(-) diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md index 3d09832d6..971ccfb7b 100644 --- a/tilelang-dsl/docs/unsupported-features.md +++ b/tilelang-dsl/docs/unsupported-features.md @@ -61,16 +61,8 @@ The current package supports the core v0.3 load/store subset: The following documented load/store families are still unsupported: -- `pto.vldas(...)` -- `pto.vldus(...)` - `pto.vsld(...)` -- `pto.psts(...)` -- `pto.vsst(...)` -- `pto.vsta(...)` -- `pto.pstu(...)` - `pto.vstu(...)` -- `pto.vstus(...)` -- `pto.vstur(...)` ### Missing Direct Predicate Constructor/Compare APIs @@ -188,8 +180,7 @@ Not currently supported from the guide's broader indexing model: - single-element syntax such as `tile[row, col]` and `tile[pos]` - explicit slice `stop` - stepped tile vector slices -- the guide's wider indexed op family (`vldas`, `vldus`, `vsld`, - `psts`, `vsst`, `vsta`) +- the remaining wider indexed op family gap (`vsld`) ### Control-Flow Result Merging diff --git a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md index 2d432941b..3355f9932 100644 --- a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -14,11 +14,21 @@ distribution families: - compatibility aliases: `InterleaveDist.B8`, `InterleaveDist.B16`, `InterleaveDist.B32` +- **`PostUpdateMode`** for `pto.vstur` + - `PostUpdateMode.NO_POST_UPDATE`: preserve the current hardware AR state + - `PostUpdateMode.POST_UPDATE`: advance the hardware AR state after the store + The canonical VPTO v0.3 spellings are the enum values: - `DeinterleaveDist.DINTLV.value == "DINTLV"` - `DeinterleaveDist.BDINTLV.value == "BDINTLV"` - `InterleaveDist.INTLV.value == "INTLV"` +- `PostUpdateMode.NO_POST_UPDATE.value == "NO_POST_UPDATE"` +- `PostUpdateMode.POST_UPDATE.value == "POST_UPDATE"` + +`pto.vstur` mode is intentionally Enum-only in the DSL. Unlike the legacy +distribution-token compatibility retained for some older load/store families, +raw strings such as `"POST_UPDATE"` are not accepted for `PostUpdateMode`. For migration convenience, the implementation still accepts legacy raw strings such as `"DINTLV_B32"` and `"INTLV_B32"`, but new DSL code should prefer the @@ -966,9 +976,9 @@ align2, base2 = pto.vstus(align1, base1, vec1, ub_ptr, offset1) pto.vstas(align2, ub_ptr, flush_offset) ``` -#### `pto.vstur(align_in: pto.align, vec: VRegType, buf: ptr) -> pto.align` [Advanced Tier] +#### `pto.vstur(align_in: pto.align, vec: VRegType, buf: ptr, mode: PostUpdateMode = pto.PostUpdateMode.NO_POST_UPDATE) -> pto.align` [Advanced Tier] -**Description**: Register-update unaligned store form. Updates only the residual alignment state without base pointer update. Requires matching flush operation to emit trailing bytes. +**Description**: Register-update unaligned store form. Updates only the residual alignment state without base pointer update. Requires matching flush operation to emit trailing bytes. The optional `mode` operand is a typed Enum and controls whether the hardware performs post-update on the implicit AR state. **Parameters**: | Parameter | Type | Description | @@ -976,6 +986,7 @@ pto.vstas(align2, ub_ptr, flush_offset) | `align_in` | `pto.align` | Incoming store-alignment state | | `vec` | `VRegType` | Vector to store | | `buf` | `ptr` | Destination buffer in UB memory space | +| `mode` | `PostUpdateMode` | Optional post-update mode. Defaults to `pto.PostUpdateMode.NO_POST_UPDATE`. | **Returns**: | Return Value | Type | Description | @@ -992,6 +1003,9 @@ pto.vstas(align2, ub_ptr, flush_offset) align1 = pto.vstur(align0, vec0, ub_ptr) align2 = pto.vstur(align1, vec1, ub_ptr) pto.vstar(align2, ub_ptr) + +# Explicit post-update mode with typed Enum +align3 = pto.vstur(align2, vec2, ub_ptr, pto.PostUpdateMode.POST_UPDATE) ``` #### Align-State Store Closed Loop diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 3fb9b2c70..3981ca458 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -20,33 +20,28 @@ vkernel, ) from .types import ( + AlignType, AnyFloat, AnyInt, AnyMask, AnyType, - AlignType, - BLayout, BarrierType, DeinterleaveDist, EVENT, InterleaveDist, PIPE, Event, - InterleaveDist, MaskType, MemorySpace, MaskPattern, PAT, - PredicateDist, - PredicatePart, - PadValue, PadMode, PositionMode, OrderMode, PointerType, + PostUpdateMode, Pipe, ScalarType, - StrideMode, TensorView, PartitionTensorView, Tile, @@ -56,7 +51,6 @@ TypeVariable, VRegType, WildcardType, - align, bf16, constexpr, bytewidth, @@ -81,7 +75,7 @@ mask_b16, mask_b32, ptr, - SLayout, + align, vreg, ) @@ -108,10 +102,8 @@ "AlignType", "ptr", "vreg", + "align", "MemorySpace", - "BLayout", - "SLayout", - "PadValue", "Pipe", "Event", "PIPE", @@ -124,11 +116,7 @@ "PadMode", "PositionMode", "OrderMode", - "DeinterleaveDist", - "InterleaveDist", - "PredicateDist", - "PredicatePart", - "StrideMode", + "PostUpdateMode", "TileConfig", "TileSpecialization", "i1", @@ -151,7 +139,6 @@ "AnyInt", "AnyType", "AnyMask", - "align", "mask_b8", "mask_b16", "mask_b32", diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index ee3008e40..30b4dd13d 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -609,8 +609,6 @@ def _collect_reachable_inline_procs( } _DMA_CALL_KEYWORDS: dict[str, frozenset[str]] = { - "vdup": frozenset({"position"}), - "vci": frozenset({"order"}), "set_loop2_stride_outtoub": frozenset({"src_stride", "dst_stride"}), "set_loop1_stride_outtoub": frozenset({"src_stride", "dst_stride"}), "set_loop_size_outtoub": frozenset({"loop1", "loop2"}), @@ -647,18 +645,6 @@ def _collect_reachable_inline_procs( "ub_stride", } ), - "copy_ubuf_to_ubuf": frozenset( - { - "src", - "dst", - "src_offset", - "src_stride0", - "src_stride1", - "dst_offset", - "dst_stride0", - "dst_stride1", - } - ), } @@ -730,7 +716,7 @@ def _build_call_keywords( raise context.error( node, f"`{call_name}` does not support keyword arguments in TileLang DSL v1; " - "keyword arguments are only supported on selected public call surfaces", + "no public call surface currently accepts them", ) seen: set[str] = set() @@ -805,14 +791,7 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo "InterleaveDist", "PositionMode", "OrderMode", - "BLayout", - "SLayout", - "PadValue", - "DeinterleaveDist", - "InterleaveDist", - "PredicateDist", - "PredicatePart", - "StrideMode", + "PostUpdateMode", } and len(path) >= 2: return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) return FrontendAttributeExpr(base=_build_expr(node.value, context), attr=node.attr) diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 634102753..22da807db 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -14,6 +14,8 @@ from dataclasses import dataclass from .semantic import ( + SemanticAlignStoreStmt, + SemanticAlignType, SemanticAssignStmt, SemanticAttributeAccess, SemanticBinaryExpr, @@ -22,7 +24,6 @@ SemanticDmaConfigStmt, SemanticDmaLoadStmt, SemanticDmaStoreStmt, - SemanticAlignType, SemanticExpr, SemanticExprStmt, SemanticForStmt, @@ -37,6 +38,7 @@ SemanticMaskType, SemanticMetaType, SemanticPipeBarrierStmt, + SemanticPredicateStoreStmt, SemanticPtrType, SemanticReturnStmt, SemanticRlsBufStmt, @@ -68,15 +70,12 @@ ) from .types import ( MaskPattern, - MemorySpace, ScalarType, - TileConfig, bytewidth, get_lanes, integer_bitwidth, integer_signedness, is_integer_dtype, - tile_strides, ) @@ -292,6 +291,21 @@ def _collect_used_tile_buffers_from_stmt( self._collect_used_tile_buffers_from_expr(index, used) self._collect_used_tile_buffers_from_expr(stmt.mask, used) return + if isinstance(stmt, SemanticPredicateStoreStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + self._record_tile_buffer_use(stmt.destination, used) + for index in stmt.indices: + self._collect_used_tile_buffers_from_expr(index, used) + self._collect_used_tile_buffers_from_expr(stmt.dist, used) + return + if isinstance(stmt, SemanticAlignStoreStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + self._record_tile_buffer_use(stmt.destination, used) + for index in stmt.indices: + self._collect_used_tile_buffers_from_expr(index, used) + if stmt.offset is not None: + self._collect_used_tile_buffers_from_expr(stmt.offset, used) + return if isinstance(stmt, SemanticVecscopeStmt): for nested in stmt.body: self._collect_used_tile_buffers_from_stmt(nested, used) @@ -326,21 +340,8 @@ def _collect_used_tile_buffers_from_expr( used: set[str], ) -> None: if isinstance(expr, SemanticCallExpr): - if expr.namespace == "pto" and expr.args: - if expr.name in { - "vlds", - "vldas", - "vldus", - "vldx2", - "vsld", - "psts", - "vsst", - "vstx2", - "vsta", - }: - self._record_tile_buffer_use(expr.args[0], used) - if expr.name in {"psts", "vsst", "vstx2", "vsta"} and len(expr.args) >= 2: - self._record_tile_buffer_use(expr.args[1], used) + if expr.namespace == "pto" and expr.name in {"vlds", "vldas", "vldus"} and expr.args: + self._record_tile_buffer_use(expr.args[0], used) for arg in expr.args: self._collect_used_tile_buffers_from_expr(arg, used) return @@ -363,15 +364,7 @@ def _collect_used_tile_buffers_from_expr( self._collect_used_tile_buffers_from_expr(slice_expr.step, used) return if isinstance(expr, SemanticAttributeAccess): - if expr.attr not in { - "shape", - "valid_shape", - "strides", - "element_type", - "rank", - "memory_space", - "config", - }: + if expr.attr not in {"shape", "valid_shape", "strides", "element_type"}: self._collect_used_tile_buffers_from_expr(expr.base, used) return if isinstance(expr, SemanticSubscriptAccess): @@ -419,6 +412,10 @@ def _render_stmt( return self._render_vector_store(stmt, env, indent=indent) if isinstance(stmt, SemanticVectorPairStoreStmt): return self._render_vector_pair_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticPredicateStoreStmt): + return self._render_predicate_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticAlignStoreStmt): + return self._render_align_store(stmt, env, indent=indent) if isinstance(stmt, SemanticScalarStoreStmt): return self._render_scalar_store(stmt, env, indent=indent) if isinstance(stmt, SemanticSetFlagStmt): @@ -649,27 +646,24 @@ def _render_multi_result_assign( raise NotImplementedError( f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" ) - if not isinstance(stmt.value.type, SemanticTupleType): - raise NotImplementedError("multi-result lowering expects a tuple-typed call value") - if len(stmt.targets) != len(stmt.value.type.elements): - raise NotImplementedError("multi-result lowering expects tuple assignment arity to match the call result count") + if len(stmt.targets) != 2: + raise NotImplementedError("multi-result lowering expects exactly two assignment targets") + if not isinstance(stmt.value.type, SemanticTupleType) or len(stmt.value.type.elements) != 2: + raise NotImplementedError("multi-result lowering expects a two-result tuple type") + + if stmt.value.name == "make_mask": + dtype_expr, remaining_expr = stmt.value.args + if not self._is_dtype_meta_expr(dtype_expr): + raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") - if stmt.value.name == "make_mask" or stmt.value.name in {"plt_b8", "plt_b16", "plt_b32"}: lines: list[str] = [] - if stmt.value.name == "make_mask": - dtype_expr, remaining_expr = stmt.value.args - if not self._is_dtype_meta_expr(dtype_expr): - raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") - opname = f"pto.plt_{self._mask_suffix(stmt.value.type.elements[0])}" - else: - remaining_expr = stmt.value.args[0] - opname = f"pto.{stmt.value.name}" remaining = self._lower_remaining_to_i32(remaining_expr, env, indent=indent, into=lines) mask_target, remaining_target = stmt.targets mask_type, remaining_type = stmt.value.type.elements + suffix = self._mask_suffix(mask_type) lines.append( self._indent(indent) - + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = {opname} {remaining.name} : " + + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = pto.plt_{suffix} {remaining.name} : " + f"i32 -> {self._render_type(mask_type)}, {self._render_type(remaining_type)}" ) env[mask_target.name] = _RenderedValue(name=mask_target.ssa_name, type=mask_type) @@ -714,7 +708,7 @@ def _render_multi_result_assign( env[carry_target.name] = _RenderedValue(name=carry_target.ssa_name, type=carry_type) return lines - if stmt.value.name in {"vintlv", "vdintlv", "pdintlv_b8", "pintlv_b16"}: + if stmt.value.name in {"vintlv", "vdintlv"}: lines = [] lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) @@ -748,110 +742,6 @@ def _render_multi_result_assign( env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) return lines - if stmt.value.name == "vldx2": - lines = [] - source_name, source_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( - stmt.value.args[:-1], - env, - indent=indent, - into=lines, - ) - dist = self._render_string_literal(stmt.value.args[-1]) - rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) - rendered_result_types = ", ".join( - self._render_type(result_type) for result_type in stmt.value.type.elements - ) - lines.append( - self._indent(indent) - + f"{rendered_targets} = pto.vldx2 {source_name}[{offset_name}], {dist} : " - + f"{source_type}, {offset_type} -> {rendered_result_types}" - ) - for target, result_type in zip(stmt.targets, stmt.value.type.elements): - env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) - return lines - - if stmt.value.name == "vldus": - lines = [] - source_name, source_type = self._lower_memory_buffer_without_offset( - stmt.value.args[:-1], - env, - indent=indent, - into=lines, - ) - align = self._lower_expr(stmt.value.args[-1], env, indent=indent, into=lines) - rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) - rendered_result_types = ", ".join( - self._render_type(result_type) for result_type in stmt.value.type.elements - ) - lines.append( - self._indent(indent) - + f"{rendered_targets} = pto.vldus {source_name}, {align.name} : " - + f"{source_type}, {self._render_type(align.type)} -> {rendered_result_types}" - ) - for target, result_type in zip(stmt.targets, stmt.value.type.elements): - env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) - return lines - - if stmt.value.name == "pstu": - lines = [] - align = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) - value = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) - base = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) - rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) - rendered_result_types = ", ".join( - self._render_type(result_type) for result_type in stmt.value.type.elements - ) - lines.append( - self._indent(indent) - + f"{rendered_targets} = pto.pstu {align.name}, {value.name}, {base.name} : " - + f"{self._render_type(align.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " - + f"-> {rendered_result_types}" - ) - for target, result_type in zip(stmt.targets, stmt.value.type.elements): - env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) - return lines - - if stmt.value.name == "vstu": - lines = [] - align = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) - offset = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) - value = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) - base = self._lower_expr(stmt.value.args[3], env, indent=indent, into=lines) - mode = self._render_string_literal(stmt.value.args[4]) - rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) - rendered_result_types = ", ".join( - self._render_type(result_type) for result_type in stmt.value.type.elements - ) - lines.append( - self._indent(indent) - + f"{rendered_targets} = pto.vstu {align.name}, {offset.name}, {value.name}, {base.name}, {mode} : " - + f"{self._render_type(align.type)}, {self._render_type(offset.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " - + f"-> {rendered_result_types}" - ) - for target, result_type in zip(stmt.targets, stmt.value.type.elements): - env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) - return lines - - if stmt.value.name == "vstus": - lines = [] - align = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) - offset = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) - value = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) - base = self._lower_expr(stmt.value.args[3], env, indent=indent, into=lines) - mode = self._render_string_literal(stmt.value.args[4]) - rendered_targets = ", ".join(target.ssa_name for target in stmt.targets) - rendered_result_types = ", ".join( - self._render_type(result_type) for result_type in stmt.value.type.elements - ) - lines.append( - self._indent(indent) - + f"{rendered_targets} = pto.vstus {align.name}, {offset.name}, {value.name}, {base.name}, {mode} : " - + f"{self._render_type(align.type)}, {self._render_type(offset.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " - + f"-> {rendered_result_types}" - ) - for target, result_type in zip(stmt.targets, stmt.value.type.elements): - env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) - if stmt.value.name == "vldsx2": lines = [] source = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) @@ -888,6 +778,60 @@ def _render_multi_result_assign( env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) return lines + if stmt.value.name == "vldus": + lines = [] + source = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + index_args = stmt.value.args[1:-1] + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_memref(source, indent=indent, into=lines) + if ( + isinstance(stmt.value.args[0].type, SemanticTileType) + and stmt.value.args[0].type.rank == 2 + and len(index_args) == 2 + ): + source = self._materialize_rank2_tile_subview( + source, + stmt.value.args[0].type, + index_args, + env, + indent=indent, + into=lines, + ) + if self._is_memref_like_type(source.type): + ptr_name, ptr_type = self._materialize_copy_buffer_ptr(source, indent=indent, into=lines) + source = _RenderedValue(name=ptr_name, type=_RenderedTextualType(ptr_type)) + align = self._lower_expr(stmt.value.args[-1], env, indent=indent, into=lines) + result_target, align_target = stmt.targets + result_type, align_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{result_target.ssa_name}, {align_target.ssa_name} = pto.vldus " + + f"{source.name}, {align.name} : " + + f"{self._render_type(source.type)}, {self._render_type(align.type)} -> " + + f"{self._render_type(result_type)}, {self._render_type(align_type)}" + ) + env[result_target.name] = _RenderedValue(name=result_target.ssa_name, type=result_type) + env[align_target.name] = _RenderedValue(name=align_target.ssa_name, type=align_type) + return lines + + if stmt.value.name == "pstu": + lines = [] + align_in = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + value = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + base = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + align_target, base_target = stmt.targets + align_type, base_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{align_target.ssa_name}, {base_target.ssa_name} = pto.pstu " + + f"{align_in.name}, {value.name}, {base.name} : " + + f"{self._render_type(align_in.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {self._render_type(align_type)}, {self._render_type(base_type)}" + ) + env[align_target.name] = _RenderedValue(name=align_target.ssa_name, type=align_type) + env[base_target.name] = _RenderedValue(name=base_target.ssa_name, type=base_type) + return lines + raise NotImplementedError( f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" ) @@ -1067,6 +1011,85 @@ def _render_vector_pair_store( ) return lines + def _render_predicate_store( + self, + stmt: SemanticPredicateStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + if isinstance(destination.type, SemanticTileType): + destination = self._materialize_tile_memref(destination, indent=indent, into=lines) + if ( + isinstance(stmt.destination.type, SemanticTileType) + and stmt.destination.type.rank == 2 + and len(stmt.indices) == 2 + ): + destination = self._materialize_rank2_tile_subview( + destination, + stmt.destination.type, + stmt.indices, + env, + indent=indent, + into=lines, + ) + rendered_offset = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_offset = self._lower_expr(stmt.indices[0], env, indent=indent, into=lines) + dist = self._render_string_literal(stmt.dist) + lines.append( + self._indent(indent) + + f"pto.psts {value.name}, {destination.name}[{rendered_offset.name}], {dist} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(rendered_offset.type)}" + ) + return lines + + def _render_align_store( + self, + stmt: SemanticAlignStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + if isinstance(destination.type, SemanticTileType): + destination = self._materialize_tile_memref(destination, indent=indent, into=lines) + if ( + isinstance(stmt.destination.type, SemanticTileType) + and stmt.destination.type.rank == 2 + and len(stmt.indices) == 2 + ): + destination = self._materialize_rank2_tile_subview( + destination, + stmt.destination.type, + stmt.indices, + env, + indent=indent, + into=lines, + ) + if stmt.op_name == "vstar": + lines.append( + self._indent(indent) + + f"pto.vstar {value.name}, {destination.name} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}" + ) + return lines + if stmt.offset is None: + raise NotImplementedError("vstas lowering expects an explicit offset operand") + offset = self._lower_expr(stmt.offset, env, indent=indent, into=lines) + offset = self._coerce_rendered_value(offset, _I32_TYPE, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.vstas {value.name}, {destination.name}, {offset.name} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(offset.type)}" + ) + return lines + def _render_scalar_store( self, stmt: SemanticScalarStoreStmt, @@ -1142,160 +1165,6 @@ def _materialize_rank2_tile_subview( ) return _RenderedValue(name=subview_name, type=subview_type) - def _lower_memory_buffer_without_offset( - self, - args: tuple[SemanticExpr, ...], - env: dict[str, _RenderedValue], - *, - indent: int, - into: list[str], - ) -> tuple[str, str]: - if not args: - raise NotImplementedError("memory buffer lowering expects at least one operand") - source = self._lower_expr(args[0], env, indent=indent, into=into) - if isinstance(source.type, SemanticTileType): - source = self._materialize_tile_access_ptr( - source, - args[1:], - env, - indent=indent, - into=into, - ) - return source.name, self._render_type(source.type) - if len(args) != 1: - raise NotImplementedError("pointer memory buffer lowering does not accept tile-style indices") - return source.name, self._render_type(source.type) - - def _lower_memory_buffer_with_offset( - self, - args: tuple[SemanticExpr, ...], - env: dict[str, _RenderedValue], - *, - indent: int, - into: list[str], - ) -> tuple[str, str, str, str]: - if not args: - raise NotImplementedError("memory buffer lowering expects at least one operand") - source = self._lower_expr(args[0], env, indent=indent, into=into) - if isinstance(source.type, SemanticTileType): - if not args[1:]: - raise NotImplementedError("tile memory buffer lowering requires element indices") - offset = self._materialize_tile_linear_offset( - source, - args[1:], - env, - indent=indent, - into=into, - ) - source = self._materialize_tile_memref(source, indent=indent, into=into) - return ( - source.name, - self._render_type(source.type), - offset.name, - self._render_type(offset.type), - ) - if len(args) != 2: - raise NotImplementedError("pointer memory buffer lowering expects exactly one explicit offset operand") - offset = self._lower_expr(args[1], env, indent=indent, into=into) - return ( - source.name, - self._render_type(source.type), - offset.name, - self._render_type(offset.type), - ) - - def _materialize_tile_access_ptr( - self, - tile_value: _RenderedValue, - indices: tuple[SemanticExpr, ...], - env: dict[str, _RenderedValue], - *, - indent: int, - into: list[str], - ) -> _RenderedValue: - base_ptr_name, base_ptr_type = self._materialize_copy_buffer_ptr( - tile_value, - indent=indent, - into=into, - ) - if not indices: - return _RenderedValue(name=base_ptr_name, type=tile_value.type if isinstance(tile_value.type, SemanticPtrType) else SemanticPtrType(tile_value.type.element_dtype, tile_value.type.memory_space or "ub")) - offset = self._materialize_tile_linear_offset( - tile_value, - indices, - env, - indent=indent, - into=into, - ) - typed_ptr_type = SemanticPtrType( - element_dtype=tile_value.type.element_dtype, - memory_space=tile_value.type.memory_space or "ub", - ) - offset_ptr_name = self._new_temp() - into.append( - self._indent(indent) - + f"{offset_ptr_name} = pto.addptr {base_ptr_name}, {offset.name} : " - + f"{base_ptr_type} -> {self._render_type(typed_ptr_type)}" - ) - return _RenderedValue(name=offset_ptr_name, type=typed_ptr_type) - - def _materialize_tile_linear_offset( - self, - tile_value: _RenderedValue, - indices: tuple[SemanticExpr, ...], - env: dict[str, _RenderedValue], - *, - indent: int, - into: list[str], - ) -> _RenderedValue: - tile_type = tile_value.type - if not isinstance(tile_type, SemanticTileType): - raise NotImplementedError("tile linear offset lowering expects a Tile value") - if tile_type.rank == 1: - if len(indices) != 1: - raise NotImplementedError("rank-1 Tile access expects one index") - return self._lower_expr(indices[0], env, indent=indent, into=into) - if tile_type.rank != 2 or tile_type.shape is None: - raise NotImplementedError("Tile linear offset lowering expects a statically specialized rank-2 Tile") - if len(indices) != 2: - raise NotImplementedError("rank-2 Tile access expects two indices") - - row = self._lower_expr(indices[0], env, indent=indent, into=into) - col = self._lower_expr(indices[1], env, indent=indent, into=into) - strides = tile_strides(tile_type.shape, tile_type.config or TileConfig()) - stride0 = _RenderedValue( - name=self._materialize_constant(strides[0], SemanticIndexType()), - type=SemanticIndexType(), - ) - stride1 = _RenderedValue( - name=self._materialize_constant(strides[1], SemanticIndexType()), - type=SemanticIndexType(), - ) - row_term = self._emit_binary_value( - "mul", - row, - stride0, - SemanticIndexType(), - indent=indent, - into=into, - ) - col_term = self._emit_binary_value( - "mul", - col, - stride1, - SemanticIndexType(), - indent=indent, - into=into, - ) - return self._emit_binary_value( - "add", - row_term, - col_term, - SemanticIndexType(), - indent=indent, - into=into, - ) - def _tensor_slice_extents(self, expr: SemanticTensorSliceExpr) -> tuple[int, int]: if expr.type.rank != 2 or len(expr.type.extents) != 2: raise NotImplementedError("TileLang DSL v1 DMA lowering currently only supports rank-2 TensorView slices") @@ -1390,79 +1259,6 @@ def _static_expr_value(self, expr: SemanticExpr | None, *, default: object = Non return expr.value if isinstance(expr, SemanticBindingRef): return expr.binding.value - if isinstance(expr, SemanticAttributeAccess): - base_value = self._static_expr_value(expr.base) - if isinstance(base_value, TileConfig): - if expr.attr == "b_layout": - return base_value.b_layout - if expr.attr == "s_layout": - return base_value.s_layout - if expr.attr == "s_fractal_size": - return base_value.s_fractal_size - if expr.attr == "pad_value": - return base_value.pad_value - if base_value is not None and hasattr(base_value, expr.attr): - return getattr(base_value, expr.attr) - if isinstance(expr.base.type, SemanticTileType): - tile_type = expr.base.type - if expr.attr == "shape": - return tile_type.shape - if expr.attr == "valid_shape": - return None - if expr.attr == "rank": - return tile_type.rank - if expr.attr == "memory_space": - return None if tile_type.memory_space is None else MemorySpace(tile_type.memory_space) - if expr.attr == "config": - return TileConfig() if tile_type.config is None else tile_type.config - if isinstance(expr.base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)) and expr.attr == "rank": - return expr.base.type.rank - return None - if isinstance(expr, SemanticTupleExpr): - values = [] - for element in expr.elements: - value = self._static_expr_value(element) - if value is None: - return None - values.append(value) - return tuple(values) - if isinstance(expr, SemanticSubscriptAccess): - base_value = self._static_expr_value(expr.base) - index_value = self._static_expr_value(expr.index) - if isinstance(base_value, (tuple, list)) and isinstance(index_value, int): - if 0 <= index_value < len(base_value): - return base_value[index_value] - return None - if isinstance(expr, SemanticBinaryExpr): - lhs = self._static_expr_value(expr.lhs) - rhs = self._static_expr_value(expr.rhs) - if lhs is None or rhs is None: - return None - if expr.op == "add" and isinstance(lhs, int) and isinstance(rhs, int): - return lhs + rhs - if expr.op == "sub" and isinstance(lhs, int) and isinstance(rhs, int): - return lhs - rhs - if expr.op == "mul" and isinstance(lhs, int) and isinstance(rhs, int): - return lhs * rhs - if expr.op == "floordiv" and isinstance(lhs, int) and isinstance(rhs, int) and rhs != 0: - return lhs // rhs - if expr.op == "eq": - return lhs == rhs - if expr.op == "ne": - return lhs != rhs - if expr.op == "gt": - return lhs > rhs - if expr.op == "lt": - return lhs < rhs - if expr.op == "ge": - return lhs >= rhs - if expr.op == "le": - return lhs <= rhs - if expr.op == "and": - return bool(lhs) and bool(rhs) - if expr.op == "or": - return bool(lhs) or bool(rhs) - return None return None def _infer_dma_load_transfer( @@ -2517,19 +2313,6 @@ def _lower_expr( name=self._materialize_constant(expr.value, expr.type), type=expr.type, ) - static_value = self._static_expr_value(expr) - if static_value is not None and isinstance(expr.type, (SemanticIndexType, SemanticScalarType)): - if desired_name is not None and into is not None: - into.append( - self._indent(indent) - + f"{desired_name} = arith.constant {self._format_constant(static_value, expr.type)} : " - f"{self._render_type(expr.type)}" - ) - return _RenderedValue(name=desired_name, type=expr.type) - return _RenderedValue( - name=self._materialize_constant(static_value, expr.type), - type=expr.type, - ) if isinstance(expr, SemanticSubscriptAccess): return self._lower_subscript_access( expr, @@ -2617,129 +2400,28 @@ def _lower_call_expr( if expr.namespace != "pto": raise NotImplementedError(f"unsupported call namespace {expr.namespace!r}") if isinstance(expr.type, SemanticTupleType): - raise NotImplementedError( - f"multi-result call `pto.{expr.name}` must be assigned directly in TileLang DSL v1" - ) + raise NotImplementedError("multi-result call values must be assigned directly in TileLang DSL v1") if into is None: into = [] result_name = desired_name or self._new_temp() - if isinstance(expr.type, SemanticMetaType) and expr.type.kind == "void": - if expr.name == "psts": - value = self._lower_expr(expr.args[0], env, indent=indent, into=into) - destination_name, destination_type, offset_name, _ = self._lower_memory_buffer_with_offset( - expr.args[1:], - env, - indent=indent, - into=into, - ) - into.append( - self._indent(indent) - + f"pto.psts {value.name}, {destination_name}[{offset_name}] : " - + f"{self._render_type(value.type)}, {destination_type}" - ) - return _RenderedValue(name="__void_call__", type=expr.type) - - if expr.name == "pst": - value = self._lower_expr(expr.args[0], env, indent=indent, into=into) - destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( - expr.args[1:-1], - env, - indent=indent, - into=into, - ) - dist = self._render_string_literal(expr.args[-1]) - into.append( - self._indent(indent) - + f"pto.pst {value.name}, {destination_name}[{offset_name}], {dist} : " - + f"{self._render_type(value.type)}, {destination_type}, {offset_type}" - ) - return _RenderedValue(name="__void_call__", type=expr.type) - - if expr.name == "psti": - value = self._lower_expr(expr.args[0], env, indent=indent, into=into) - destination = self._lower_expr(expr.args[1], env, indent=indent, into=into) - offset = self._lower_expr(expr.args[2], env, indent=indent, into=into) - dist = self._render_string_literal(expr.args[-1]) - into.append( - self._indent(indent) - + f"pto.psti {value.name}, {destination.name}, {offset.name}, {dist} : " - + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(offset.type)}" - ) - return _RenderedValue(name="__void_call__", type=expr.type) - - if expr.name == "vsst": - value = self._lower_expr(expr.args[0], env, indent=indent, into=into) - destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( - expr.args[1:-1], - env, - indent=indent, - into=into, - ) - stride = self._render_string_literal(expr.args[-1]) - into.append( - self._indent(indent) - + f"pto.vsst {value.name}, {destination_name}[{offset_name}], {stride} : " - + f"{self._render_type(value.type)}, {destination_type}" - ) - return _RenderedValue(name="__void_call__", type=expr.type) - - if expr.name == "vstx2": - low = self._lower_expr(expr.args[0], env, indent=indent, into=into) - high = self._lower_expr(expr.args[1], env, indent=indent, into=into) - destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( - expr.args[2:-2], - env, - indent=indent, - into=into, - ) - dist = self._render_string_literal(expr.args[-2]) - mask = self._lower_expr(expr.args[-1], env, indent=indent, into=into) - into.append( - self._indent(indent) - + f"pto.vstx2 {low.name}, {high.name}, {destination_name}[{offset_name}], {dist}, {mask.name} : " - + f"{self._render_type(low.type)}, {self._render_type(high.type)}, {destination_type}, {offset_type}, {self._render_type(mask.type)}" - ) - return _RenderedValue(name="__void_call__", type=expr.type) - - if expr.name == "vsta": - value = self._lower_expr(expr.args[0], env, indent=indent, into=into) - destination_name, destination_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( - expr.args[1:], - env, - indent=indent, - into=into, - ) - into.append( - self._indent(indent) - + f"pto.vsta {value.name}, {destination_name}[{offset_name}] : " - + f"{self._render_type(value.type)}, {destination_type}, {offset_type}" - ) - return _RenderedValue(name="__void_call__", type=expr.type) - - raise NotImplementedError(f"void pto call `pto.{expr.name}` is not supported in TileLang DSL v1") - if expr.name == "make_mask": dtype_expr, pattern_expr = expr.args if not self._is_dtype_meta_expr(dtype_expr): raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") - pattern_value = self._extract_mask_pattern_value(pattern_expr) - if pattern_value is None: + if not isinstance(pattern_expr, SemanticSymbolExpr) or not isinstance(pattern_expr.value, MaskPattern): raise NotImplementedError("make_mask pattern lowering expects a MaskPattern symbol") suffix = expr.type.granularity into.append( self._indent(indent) - + f'{result_name} = pto.pset_{suffix} "{pattern_value}" : {self._render_type(expr.type)}' + + f'{result_name} = pto.pset_{suffix} "{pattern_expr.value.value}" : {self._render_type(expr.type)}' ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name in {"pset_b8", "pset_b16", "pset_b32", "pge_b8", "pge_b16", "pge_b32"}: - pattern_value = self._extract_mask_pattern_value(expr.args[0]) - if pattern_value is None: - raise NotImplementedError(f"{expr.name} lowering expects a MaskPattern symbol") + if expr.name == "init_align": into.append( self._indent(indent) - + f'{result_name} = pto.{expr.name} "{pattern_value}" : {self._render_type(expr.type)}' + + f"{result_name} = pto.init_align : {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) @@ -2771,92 +2453,67 @@ def _lower_call_expr( return _RenderedValue(name=result_name, type=expr.type) if expr.name == "vldas": - source_name, source_type = self._lower_memory_buffer_without_offset( - expr.args, - env, - indent=indent, - into=into, - ) - into.append( - self._indent(indent) - + f"{result_name} = pto.vldas {source_name} : {source_type} -> {self._render_type(expr.type)}" - ) - return _RenderedValue(name=result_name, type=expr.type) - - if expr.name == "plds": - source_name, source_type, offset_name, _ = self._lower_memory_buffer_with_offset( - expr.args[:-1], - env, - indent=indent, - into=into, - ) - dist = self._render_string_literal(expr.args[-1]) - into.append( - self._indent(indent) - + f'{result_name} = pto.plds {source_name}[{offset_name}] {{dist = {dist}}} : ' - + f"{source_type} -> {self._render_type(expr.type)}" - ) - return _RenderedValue(name=result_name, type=expr.type) - - if expr.name == "pld": - source_name, source_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( - expr.args[:-1], - env, - indent=indent, - into=into, - ) - dist = self._render_string_literal(expr.args[-1]) + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + index_args = expr.args[1:] + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_memref(source, indent=indent, into=into) + if ( + isinstance(expr.args[0].type, SemanticTileType) + and expr.args[0].type.rank == 2 + and len(index_args) == 2 + ): + source = self._materialize_rank2_tile_subview( + source, + expr.args[0].type, + index_args, + env, + indent=indent, + into=into, + ) + if self._is_memref_like_type(source.type): + ptr_name, ptr_type = self._materialize_copy_buffer_ptr(source, indent=indent, into=into) + source = _RenderedValue(name=ptr_name, type=_RenderedTextualType(ptr_type)) into.append( self._indent(indent) - + f"{result_name} = pto.pld {source_name}[{offset_name}], {dist} : " - + f"{source_type}, {offset_type} -> {self._render_type(expr.type)}" + + f"{result_name} = pto.vldas {source.name} : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name == "pldi": + if expr.name == "load_scalar": source = self._lower_expr(expr.args[0], env, indent=indent, into=into) offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) - dist = self._render_string_literal(expr.args[2]) into.append( self._indent(indent) - + f"{result_name} = pto.pldi {source.name}, {offset.name}, {dist} : " - + f"{self._render_type(source.type)}, {self._render_type(offset.type)} -> {self._render_type(expr.type)}" + + f"{result_name} = pto.load_scalar {source.name}[{offset.name}] : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name == "vsld": - source_name, source_type, offset_name, offset_type = self._lower_memory_buffer_with_offset( - expr.args[:-1], - env, - indent=indent, - into=into, - ) - stride = self._render_string_literal(expr.args[-1]) + if expr.name == "vstus": + align_in = self._lower_expr(expr.args[0], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + offset = self._coerce_rendered_value(offset, _I32_TYPE, indent=indent, into=into) + value = self._lower_expr(expr.args[2], env, indent=indent, into=into) + base = self._lower_expr(expr.args[3], env, indent=indent, into=into) into.append( self._indent(indent) - + f"{result_name} = pto.vsld {source_name}[{offset_name}], {stride} : " - + f"{source_type} -> {self._render_type(expr.type)}" + + f"{result_name} = pto.vstus {align_in.name}, {offset.name}, {value.name}, {base.name} : " + + f"{self._render_type(align_in.type)}, {self._render_type(offset.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) if expr.name == "vstur": - align = self._lower_expr(expr.args[0], env, indent=indent, into=into) + align_in = self._lower_expr(expr.args[0], env, indent=indent, into=into) value = self._lower_expr(expr.args[1], env, indent=indent, into=into) base = self._lower_expr(expr.args[2], env, indent=indent, into=into) mode = self._render_string_literal(expr.args[3]) into.append( self._indent(indent) - + f"{result_name} = pto.vstur {align.name}, {value.name}, {base.name}, {mode} : " - + f"{self._render_type(align.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} -> {self._render_type(expr.type)}" - ) - - if expr.name == "load_scalar": - source = self._lower_expr(expr.args[0], env, indent=indent, into=into) - offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) - into.append( - self._indent(indent) - + f"{result_name} = pto.load_scalar {source.name}[{offset.name}] : " - + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + + f"{result_name} = pto.vstur {align_in.name}, {value.name}, {base.name}, {mode} : " + + f"{self._render_type(align_in.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) @@ -2991,18 +2648,6 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name in {"pand", "por", "pxor"}: - src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) - src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) - mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) - into.append( - self._indent(indent) - + f"{result_name} = pto.{expr.name} {src0.name}, {src1.name}, {mask.name} : " - + f"{self._render_type(src0.type)}, {self._render_type(src1.type)}, {self._render_type(mask.type)} " - + f"-> {self._render_type(expr.type)}" - ) - return _RenderedValue(name=result_name, type=expr.type) - if expr.name == "vcmp": lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) @@ -3525,34 +3170,6 @@ def _mask_suffix(self, ty: SemanticType) -> str: raise NotImplementedError("tail make_mask lowering expects a mask result type") return ty.granularity - def _extract_mask_pattern_value(self, expr: SemanticExpr) -> str | None: - if isinstance(expr, SemanticSymbolExpr) and isinstance(expr.value, MaskPattern): - return expr.value.value - if ( - isinstance(expr, SemanticBindingRef) - and isinstance(expr.type, SemanticMetaType) - and expr.type.kind == "mask_pattern" - and isinstance(expr.binding.value, MaskPattern) - ): - return expr.binding.value.value - return None - - def _emit_full_mask_for_type( - self, - ty: SemanticType, - *, - indent: int, - into: list[str], - ) -> _RenderedValue: - if not isinstance(ty, SemanticMaskType): - raise NotImplementedError("full-mask synthesis expects a mask type") - result_name = self._new_temp() - into.append( - self._indent(indent) - + f'{result_name} = pto.pset_{ty.granularity} "PAT_ALL" : {self._render_type(ty)}' - ) - return _RenderedValue(name=result_name, type=ty) - def _is_dtype_meta_expr(self, expr: SemanticExpr) -> bool: if isinstance(expr, SemanticSymbolExpr): return isinstance(expr.value, ScalarType) and expr.type.kind == "dtype" @@ -3698,10 +3315,6 @@ def _extract_shape_subscript_value( expr: SemanticSubscriptAccess, env: dict[str, _RenderedValue], ) -> int | _RenderedValue: - base_static = self._static_expr_value(expr.base) - index_static = self._static_expr_value(expr.index) - if isinstance(base_static, (tuple, list)) and isinstance(index_static, int): - return base_static[index_static] if not isinstance(expr.base, SemanticAttributeAccess): raise NotImplementedError("only shape/stride indexing is supported in TileLang DSL v1 lowering") if expr.base.attr not in {"shape", "valid_shape", "strides"}: @@ -3809,8 +3422,6 @@ def _render_type(self, ty: SemanticType) -> str: return ty.dtype.name if isinstance(ty, SemanticPtrType): return f"!pto.ptr<{ty.element_dtype.name}, {ty.memory_space}>" - if isinstance(ty, SemanticAlignType): - return "!pto.align" if isinstance(ty, SemanticTensorViewType): return self._render_tensor_view_type( element_dtype=ty.element_dtype.name, @@ -3823,6 +3434,8 @@ def _render_type(self, ty: SemanticType) -> str: ) if isinstance(ty, SemanticTileType): return self._render_tile_buf_type(ty) + if isinstance(ty, SemanticAlignType): + return "!pto.align" if isinstance(ty, SemanticMaskType): return f"!pto.mask<{ty.granularity}>" if isinstance(ty, SemanticVRegType): @@ -3889,12 +3502,11 @@ def _render_tile_buf_type(self, ty: SemanticTileType) -> str: valid_shape = ty.valid_shape or ty.shape v_row = valid_shape[0] v_col = 1 if ty.rank == 1 else valid_shape[1] - config = ty.config or TileConfig() return ( f"!pto.tile_buf" + "blayout=row_major, slayout=none_box, fractal=512, pad=0>" ) def _render_tile_buf_loc(self, memory_space: str) -> str: diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 34b2ddc86..5cdb28cc2 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -47,7 +47,7 @@ unsupported_feature_message, ) from .types import ( - BLayout, + AlignType, BarrierType, DeinterleaveDist, Event, @@ -56,17 +56,12 @@ MaskPattern, MemorySpace, OrderMode, - PadValue, PadMode, Pipe, + PostUpdateMode, PositionMode, - PredicateDist, - PredicatePart, PointerType, ScalarType, - SLayout, - StrideMode, - TileConfig, VRegType, bf16, bytewidth, @@ -85,6 +80,7 @@ si16, si32, si64, + align, ui8, ui16, ui32, @@ -120,32 +116,12 @@ _EVENT_SYMBOLS = {event.name: event for event in Event} _BARRIER_TYPE_SYMBOLS = {barrier_type.name: barrier_type for barrier_type in BarrierType} _MEMORY_SPACE_SYMBOLS = {memory_space.name: memory_space for memory_space in MemorySpace} -_B_LAYOUT_SYMBOLS = {layout.name: layout for layout in BLayout} -_S_LAYOUT_SYMBOLS = {layout.name: layout for layout in SLayout} -_PAD_VALUE_SYMBOLS = {pad_value.name: pad_value for pad_value in PadValue} _PAD_MODE_SYMBOLS = {pad_mode.name: pad_mode for pad_mode in PadMode} _DEINTERLEAVE_DIST_SYMBOLS = dict(DeinterleaveDist.__members__) _INTERLEAVE_DIST_SYMBOLS = dict(InterleaveDist.__members__) _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} _ORDER_MODE_SYMBOLS = {order_mode.name: order_mode for order_mode in OrderMode} -_DEINTERLEAVE_DIST_SYMBOLS = {dist.name: dist for dist in DeinterleaveDist} -_INTERLEAVE_DIST_SYMBOLS = {dist.name: dist for dist in InterleaveDist} -_PREDICATE_DIST_SYMBOLS = {dist.name: dist for dist in PredicateDist} -_PREDICATE_PART_SYMBOLS = {part.name: part for part in PredicatePart} -_STRIDE_MODE_SYMBOLS = {mode.name: mode for mode in StrideMode} -_DIRECT_PREDICATE_PATTERN_OPS = { - "pset_b8", - "pset_b16", - "pset_b32", - "pge_b8", - "pge_b16", - "pge_b32", -} -_DIRECT_PREDICATE_TAIL_OPS = {"plt_b8", "plt_b16", "plt_b32"} -_PREDICATE_MEMORY_EXPR_OPS = {"plds", "pld", "pldi"} -_PREDICATE_MEMORY_STMT_OPS = {"pst", "psti"} -_PREDICATE_BINARY_LOGIC_OPS = {"pand", "por", "pxor"} -_PREDICATE_REARRANGEMENT_OPS = {"pdintlv_b8", "pintlv_b16"} +_POST_UPDATE_MODE_SYMBOLS = {mode.name: mode for mode in PostUpdateMode} _UNARY_VECTOR_OPS = { "vabs", "vrelu", @@ -212,7 +188,7 @@ } _VECTOR_IMMEDIATE_OPS = {"vshift", "vslide"} _TERNARY_VECTOR_OPS = {"vaxpy", "vmula"} -_MULTI_RESULT_VECTOR_OPS = {"vmull", "vldsx2"} +_MULTI_RESULT_VECTOR_OPS = {"vmull", "vldsx2", "vldus", "pstu"} _BROADCAST_VECTOR_OPS = {"vbr", "vdup", "vci"} _LOW_LEVEL_DMA_CONFIG_OPS = { "set_loop2_stride_outtoub", @@ -228,23 +204,16 @@ "copy_ubuf_to_ubuf", } _COMPARE_SELECT_OPS = {"vcmp", "vcmps", "vsel", "vselr", "vselrv2"} -_PREDICATE_MOVEMENT_OPS = {"pnot", "psel", "ppack", "punpack"} | _PREDICATE_BINARY_LOGIC_OPS +_PREDICATE_MOVEMENT_OPS = {"pnot", "psel", "ppack", "punpack"} _CARRY_OPS = {"vaddc", "vsubc", "vaddcs", "vsubcs"} -_REARRANGEMENT_OPS = {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"} | _PREDICATE_REARRANGEMENT_OPS +_REARRANGEMENT_OPS = {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"} _ADVANCED_VECTOR_ACTIVITY_OPS = ( _COMPARE_SELECT_OPS | _PREDICATE_MOVEMENT_OPS | _CARRY_OPS | _REARRANGEMENT_OPS - | _DIRECT_PREDICATE_PATTERN_OPS - | _DIRECT_PREDICATE_TAIL_OPS - | _PREDICATE_MEMORY_EXPR_OPS - | _PREDICATE_MEMORY_STMT_OPS | {"vcvt", "vmrgsort4"} ) -_VECTOR_MEMORY_EXPR_OPS = {"vlds", "vldas", "vldus", "vldx2", "vsld"} -_VECTOR_MEMORY_STMT_OPS = {"vsts", "psts", "vsst", "vstx2", "vsta"} -_STATEFUL_MEMORY_EXPR_OPS = {"pstu", "vstu", "vstus", "vstur"} _TENSORVIEW_RANK = 5 @@ -279,7 +248,6 @@ class SemanticTileType(SemanticType): shape: tuple[int, ...] | None valid_shape: tuple[int | None, ...] | None memory_space: str | None - config: TileConfig | None = None @dataclass(frozen=True) @@ -318,6 +286,11 @@ class SemanticMetaType(SemanticType): kind: str +@dataclass(frozen=True) +class SemanticAlignType(SemanticType): + pass + + @dataclass(frozen=True) class SemanticMaskType(SemanticType): granularity: str @@ -329,13 +302,7 @@ class SemanticVRegType(SemanticType): lanes: int -@dataclass(frozen=True) -class SemanticAlignType(SemanticType): - pass - - _I32_TYPE = SemanticScalarType(dtype=i32) -_ALIGN_TYPE = SemanticAlignType() @dataclass(frozen=True) @@ -496,6 +463,23 @@ class SemanticVectorPairStoreStmt(SemanticStmt): mask: SemanticExpr +@dataclass(frozen=True) +class SemanticPredicateStoreStmt(SemanticStmt): + value: SemanticExpr + destination: SemanticExpr + indices: tuple[SemanticExpr, ...] + dist: SemanticExpr + + +@dataclass(frozen=True) +class SemanticAlignStoreStmt(SemanticStmt): + op_name: str + value: SemanticExpr + destination: SemanticExpr + indices: tuple[SemanticExpr, ...] = () + offset: SemanticExpr | None = None + + @dataclass(frozen=True) class SemanticScalarStoreStmt(SemanticStmt): value: SemanticExpr @@ -751,7 +735,6 @@ def _parameter_type(self, param: Any) -> SemanticType: shape=shape, valid_shape=valid_shape, memory_space=memory_space, - config=None if spec is None else spec.config, ) if param.kind == "ptr": memory_space = param.annotation.memory_space.value @@ -944,7 +927,23 @@ def _should_infer_vecscope( return self._block_can_live_in_inferred_vecscope(stmt.body) name = self._frontend_vector_call_name(stmt) return name in ( - {"make_mask"} | _VECTOR_MEMORY_EXPR_OPS | _VECTOR_MEMORY_STMT_OPS | _STATEFUL_MEMORY_EXPR_OPS + { + "make_mask", + "init_align", + "vlds", + "vldas", + "vldus", + "psts", + "pstu", + "vsst", + "vsta", + "vstas", + "vstar", + "vsts", + "vstsx2", + "vstus", + "vstur", + } | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS @@ -1022,7 +1021,23 @@ def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> boo return ( expr.namespace == "pto" and expr.name in ( - {"make_mask"} | _VECTOR_MEMORY_EXPR_OPS | _VECTOR_MEMORY_STMT_OPS | _STATEFUL_MEMORY_EXPR_OPS + { + "make_mask", + "init_align", + "vlds", + "vldas", + "vldus", + "psts", + "pstu", + "vsst", + "vsta", + "vstas", + "vstar", + "vsts", + "vstsx2", + "vstus", + "vstur", + } | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS @@ -1128,7 +1143,7 @@ def _semantic_block_contains_vector_activity( def _expr_contains_vector_activity(self, expr: SemanticExpr) -> bool: if isinstance(expr, SemanticCallExpr): if expr.namespace == "pto" and expr.name in ( - {"make_mask"} | _VECTOR_MEMORY_EXPR_OPS | _STATEFUL_MEMORY_EXPR_OPS + {"make_mask", "vlds"} | _UNARY_VECTOR_OPS | _BINARY_VECTOR_OPS | _VECTOR_SCALAR_OPS @@ -1183,8 +1198,8 @@ def _analyze_stmt( env, allow_outer_lookup=allow_outer_lookup, ) - if self._is_vector_memory_stmt_call(stmt.expr): - return self._analyze_vector_memory_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_vector_store_call(stmt.expr): + return self._analyze_vector_store_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) if self._is_scalar_store_call(stmt.expr): return self._analyze_scalar_store_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) expr = self._analyze_expr(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) @@ -1370,11 +1385,11 @@ def _is_dma_call(self, expr: FrontendExprNode) -> bool: and expr.name in {"dma_load", "dma_store"} ) - def _is_vector_memory_stmt_call(self, expr: FrontendExprNode) -> bool: + def _is_vector_store_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) and expr.namespace == "pto" - and expr.name in (_VECTOR_MEMORY_STMT_OPS | _PREDICATE_MEMORY_STMT_OPS | {"vstsx2"}) + and expr.name in {"psts", "vsst", "vsta", "vstas", "vstar", "vsts", "vstsx2"} ) def _is_scalar_store_call(self, expr: FrontendExprNode) -> bool: @@ -1490,316 +1505,256 @@ def _analyze_dma_options( init_out_buffer=init_out_buffer, ) - def _analyze_vector_memory_stmt( + def _analyze_vector_store_stmt( self, expr: FrontendCallExpr, env: dict[str, SemanticBinding], *, allow_outer_lookup: bool, ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: - if expr.name == "vsts": - if len(expr.args) == 3: + if expr.name == "psts": + dist_expr: SemanticExpr | None = None + if len(expr.args) == 2: value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) destination, indices = self._analyze_tile_vector_access( expr.args[1], env, allow_outer_lookup=allow_outer_lookup, - context="pto.vsts destination", + context="pto.psts destination", ) - mask = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + elif len(expr.args) == 3 and isinstance(expr.args[1], FrontendSubscriptExpr): + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.psts destination", + ) + dist_expr = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) else: args = tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) for arg in expr.args ) - if len(args) != 4: - raise TypeError("pto.vsts expects 3 or 4 positional arguments in TileLang DSL v1") - value, destination, offset, mask = args - indices = (offset,) - self._require_vreg_expr(value, "pto.vsts value") - self._require_vector_pointer_expr(destination, "pto.vsts destination") + if len(args) == 3: + value, destination, offset = args + indices = (offset,) + elif len(args) == 4: + value, destination, offset, dist_expr = args + indices = (offset,) + else: + raise TypeError("pto.psts expects Tile element-indexing syntax or 3/4 positional arguments") + self._require_mask_expr(value, "pto.psts value") + self._require_vector_pointer_expr(destination, "pto.psts destination") for index in indices: self._require_index_typed_expr(index) - self._require_mask_for_vreg(mask, value.type, "pto.vsts") - self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") + dist = self._normalize_predicate_store_dist(dist_expr, "pto.psts dist") return ( - SemanticVectorStoreStmt( + SemanticPredicateStoreStmt( value=value, destination=destination, indices=indices, - mask=mask, + dist=dist, ), dict(env), ) - if expr.name == "vstsx2": - if len(expr.args) == 5: - low = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) - high = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) - destination, indices = self._analyze_tile_vector_access( - expr.args[2], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.vstsx2 destination", - ) - dist = self._analyze_expr(expr.args[3], env, allow_outer_lookup=allow_outer_lookup) - mask = self._analyze_expr(expr.args[4], env, allow_outer_lookup=allow_outer_lookup) + if expr.name in {"vsta", "vstas", "vstar"}: + offset: SemanticExpr | None = None + op_name = "vstas" if expr.name == "vsta" else expr.name + if expr.name == "vsta": + if len(expr.args) == 2: + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsta destination", + ) + offset = SemanticLiteralExpr(value=0, type=SemanticScalarType(dtype=i32)) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 3: + raise TypeError("pto.vsta expects 2 or 3 positional arguments in TileLang DSL v1") + value, destination, offset = args + indices = () + elif expr.name == "vstas": + if len(expr.args) == 3 and isinstance(expr.args[1], FrontendSubscriptExpr): + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vstas destination", + ) + offset = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 3: + raise TypeError("pto.vstas expects exactly 3 positional arguments in TileLang DSL v1") + value, destination, offset = args + indices = () else: - args = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args - ) - if len(args) != 6: - raise TypeError("pto.vstsx2 expects 5 or 6 positional arguments in TileLang DSL v1") - low, high, destination, offset, dist, mask = args - indices = (offset,) - low_type = self._require_vreg_expr(low, "pto.vstsx2 low") - high_type = self._require_vreg_expr(high, "pto.vstsx2 high") - if low_type != high_type: - raise TypeError("pto.vstsx2 requires low/high vectors to use the same vector type") - self._require_vector_pointer_expr(destination, "pto.vstsx2 destination") + if len(expr.args) == 2 and isinstance(expr.args[1], FrontendSubscriptExpr): + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vstar destination", + ) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 2: + raise TypeError("pto.vstar expects exactly 2 positional arguments in TileLang DSL v1") + value, destination = args + indices = () + self._require_align_expr(value, f"pto.{expr.name} value") + self._require_vector_pointer_expr(destination, f"pto.{expr.name} destination") for index in indices: self._require_index_typed_expr(index) - dist = self._normalize_vstsx2_dist(dist) - self._require_mask_for_vreg(mask, low_type, "pto.vstsx2") - self._require_matching_vector_pointer(low_type, destination.type, "pto.vstsx2") + if offset is not None: + self._require_i32_like_expr(offset, f"pto.{expr.name} offset") return ( - SemanticVectorPairStoreStmt( - low=low, - high=high, + SemanticAlignStoreStmt( + op_name=op_name, + value=value, destination=destination, indices=indices, - dist=dist, - mask=mask, + offset=offset, ), dict(env), ) - analyzed = self._analyze_vector_memory_stmt_call( - expr, - env, - allow_outer_lookup=allow_outer_lookup, - ) - return SemanticExprStmt(expr=analyzed), dict(env) - - def _analyze_vector_memory_stmt_call( - self, - expr: FrontendCallExpr, - env: dict[str, SemanticBinding], - *, - allow_outer_lookup: bool, - ) -> SemanticExpr: - if expr.name == "psts": - if len(expr.args) == 2: - value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) - destination, indices = self._analyze_tile_vector_access( - expr.args[1], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.psts destination", - ) - else: - value, destination, offset = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args - ) - indices = (offset,) - mask = self._require_mask_expr(value, "pto.psts value") - self._require_vector_pointer_expr(destination, "pto.psts destination") - if isinstance(destination.type, SemanticTileType): - if len(indices) not in {1, 2}: - raise TypeError("pto.psts Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") - else: - if len(indices) != 1: - raise TypeError("pto.psts pointer syntax expects exactly one offset operand in TileLang DSL v1") - for index in indices: - self._require_index_typed_expr(index) - return SemanticCallExpr( - namespace="pto", - name="psts", - args=(value, destination, *indices), - type=SemanticMetaType(kind="void"), - ) - - if expr.name == "pst": - if len(expr.args) in {2, 3} and isinstance(expr.args[1], FrontendSubscriptExpr): - value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) - destination, indices = self._analyze_tile_vector_access( - expr.args[1], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.pst destination", - ) - dist_expr = ( - self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) - if len(expr.args) == 3 - else None - ) - else: - if len(expr.args) not in {3, 4}: - raise TypeError("pto.pst expects value, destination, offset, and optional dist in TileLang DSL v1") - analyzed = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args - ) - value, destination, offset = analyzed[:3] - indices = (offset,) - dist_expr = analyzed[3] if len(analyzed) == 4 else None - mask = self._require_mask_expr(value, "pto.pst value") - self._require_vector_pointer_expr(destination, "pto.pst destination") - if isinstance(destination.type, SemanticTileType): - if len(indices) not in {1, 2}: - raise TypeError("pto.pst Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") - else: - if len(indices) != 1: - raise TypeError("pto.pst pointer syntax expects exactly one offset operand in TileLang DSL v1") - for index in indices: - self._require_index_typed_expr(index) - dist = self._normalize_predicate_dist( - dist_expr, - "pto.pst dist", - allowed={"NORM", "PK"}, - default="NORM", - ) - return SemanticCallExpr( - namespace="pto", - name="pst", - args=(value, destination, *indices, dist), - type=SemanticMetaType(kind="void"), - ) - - if expr.name == "psti": - if len(expr.args) not in {3, 4}: - raise TypeError("pto.psti expects value, destination, imm_offset, and optional dist in TileLang DSL v1") - analyzed = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args - ) - value = analyzed[0] - destination = self._require_pointer_expr(analyzed[1], "pto.psti destination", memory_space="ub") - self._require_mask_expr(value, "pto.psti value") - self._require_i32_expr(analyzed[2], "pto.psti offset") - dist = self._normalize_predicate_dist( - analyzed[3] if len(analyzed) == 4 else None, - "pto.psti dist", - allowed={"NORM", "PK"}, - default="NORM", - ) - return SemanticCallExpr( - namespace="pto", - name="psti", - args=(value, destination, analyzed[2], dist), - type=SemanticMetaType(kind="void"), - ) - if expr.name == "vsst": if len(expr.args) == 3: - value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + scalar = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) destination, indices = self._analyze_tile_vector_access( expr.args[1], env, allow_outer_lookup=allow_outer_lookup, context="pto.vsst destination", ) - stride = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + mask = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) else: - value, destination, offset, stride = tuple( + args = tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) for arg in expr.args ) + if len(args) != 4: + raise TypeError("pto.vsst expects 3 or 4 positional arguments in TileLang DSL v1") + scalar, destination, offset, mask = args indices = (offset,) - vreg = self._require_vreg_expr(value, "pto.vsst value") + scalar_type = self._require_scalar_expr(scalar, "pto.vsst scalar") self._require_vector_pointer_expr(destination, "pto.vsst destination") - if isinstance(destination.type, SemanticTileType): - if len(indices) not in {1, 2}: - raise TypeError("pto.vsst Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") - else: - if len(indices) != 1: - raise TypeError("pto.vsst pointer syntax expects exactly one offset operand in TileLang DSL v1") for index in indices: self._require_index_typed_expr(index) - self._require_matching_vector_pointer(vreg, destination.type, "pto.vsst") - normalized_stride = self._normalize_stride_mode(stride, "pto.vsst stride") - return SemanticCallExpr( + destination_dtype = destination.type.element_dtype + if scalar_type.dtype != destination_dtype: + raise TypeError("pto.vsst scalar dtype must match destination element dtype in TileLang DSL v1") + value = SemanticCallExpr( namespace="pto", - name="vsst", - args=(value, destination, *indices, normalized_stride), - type=SemanticMetaType(kind="void"), + name="vbr", + args=(scalar,), + type=self._vreg_type_for_dtype(destination_dtype), ) - - if expr.name == "vstx2": - if len(expr.args) == 5: - low = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) - high = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) - destination, indices = self._analyze_tile_vector_access( - expr.args[2], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.vstx2 destination", - ) - dist = self._analyze_expr(expr.args[3], env, allow_outer_lookup=allow_outer_lookup) - mask = self._analyze_expr(expr.args[4], env, allow_outer_lookup=allow_outer_lookup) - else: - low, high, destination, offset, dist, mask = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args - ) - indices = (offset,) - low_type = self._require_vreg_expr(low, "pto.vstx2 low") - high_type = self._require_vreg_expr(high, "pto.vstx2 high") - if low_type != high_type: - raise TypeError("pto.vstx2 requires low/high vector types to match") - self._require_vector_pointer_expr(destination, "pto.vstx2 destination") - if isinstance(destination.type, SemanticTileType): - if len(indices) not in {1, 2}: - raise TypeError("pto.vstx2 Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") - else: - if len(indices) != 1: - raise TypeError("pto.vstx2 pointer syntax expects exactly one offset operand in TileLang DSL v1") - for index in indices: - self._require_index_typed_expr(index) - self._require_mask_for_vreg(mask, low_type, "pto.vstx2") - self._require_matching_vector_pointer(low_type, destination.type, "pto.vstx2") - normalized_dist = self._normalize_interleave_dist(dist, "pto.vstx2 dist") - return SemanticCallExpr( - namespace="pto", - name="vstx2", - args=(low, high, destination, *indices, normalized_dist, mask), - type=SemanticMetaType(kind="void"), + self._require_mask_for_vreg(mask, value.type, "pto.vsst") + self._require_matching_vector_pointer(value.type, destination.type, "pto.vsst") + return ( + SemanticVectorStoreStmt( + value=value, + destination=destination, + indices=indices, + mask=mask, + ), + dict(env), ) - if expr.name == "vsta": - if len(expr.args) == 2: + if expr.name == "vsts": + if len(expr.args) == 3: value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) destination, indices = self._analyze_tile_vector_access( expr.args[1], env, allow_outer_lookup=allow_outer_lookup, - context="pto.vsta destination", + context="pto.vsts destination", ) + mask = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) else: - value, destination, offset = tuple( + args = tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) for arg in expr.args ) + if len(args) != 4: + raise TypeError("pto.vsts expects 3 or 4 positional arguments in TileLang DSL v1") + value, destination, offset, mask = args indices = (offset,) - self._require_align_expr(value, "pto.vsta value") - self._require_vector_pointer_expr(destination, "pto.vsta destination") - if isinstance(destination.type, SemanticTileType): - if len(indices) not in {1, 2}: - raise TypeError("pto.vsta Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") - else: - if len(indices) != 1: - raise TypeError("pto.vsta pointer syntax expects exactly one offset operand in TileLang DSL v1") + self._require_vreg_expr(value, "pto.vsts value") + self._require_vector_pointer_expr(destination, "pto.vsts destination") for index in indices: self._require_index_typed_expr(index) - return SemanticCallExpr( - namespace="pto", - name="vsta", - args=(value, destination, *indices), - type=SemanticMetaType(kind="void"), + self._require_mask_for_vreg(mask, value.type, "pto.vsts") + self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") + return ( + SemanticVectorStoreStmt( + value=value, + destination=destination, + indices=indices, + mask=mask, + ), + dict(env), + ) + + if len(expr.args) == 5: + low = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + high = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[2], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vstsx2 destination", + ) + dist = self._analyze_expr(expr.args[3], env, allow_outer_lookup=allow_outer_lookup) + mask = self._analyze_expr(expr.args[4], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args ) - raise ValueError(f"unsupported vector-memory stmt pto.{expr.name}") + if len(args) != 6: + raise TypeError("pto.vstsx2 expects 5 or 6 positional arguments in TileLang DSL v1") + low, high, destination, offset, dist, mask = args + indices = (offset,) + low_type = self._require_vreg_expr(low, "pto.vstsx2 low") + high_type = self._require_vreg_expr(high, "pto.vstsx2 high") + if low_type != high_type: + raise TypeError("pto.vstsx2 requires low/high vectors to use the same vector type") + self._require_vector_pointer_expr(destination, "pto.vstsx2 destination") + for index in indices: + self._require_index_typed_expr(index) + dist = self._normalize_vstsx2_dist(dist) + self._require_mask_for_vreg(mask, low_type, "pto.vstsx2") + self._require_matching_vector_pointer(low_type, destination.type, "pto.vstsx2") + return ( + SemanticVectorPairStoreStmt( + low=low, + high=high, + destination=destination, + indices=indices, + dist=dist, + mask=mask, + ), + dict(env), + ) def _analyze_scalar_store_stmt( self, @@ -2005,13 +1960,10 @@ def _analyze_low_level_dma_operands( f"pto.{expr.name} does not support mixing positional and keyword operands in TileLang DSL v1" ) if not expr.keywords: - args = tuple( + return tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) for arg in expr.args ) - if expr.name == "copy_ubuf_to_ubuf" and len(args) == 8: - return self._normalize_copy_ubuf_to_ubuf_guide_operands(args) - return args analyzed_keywords: dict[str, SemanticExpr] = { name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) @@ -2085,56 +2037,10 @@ def bool_literal(value: bool) -> SemanticLiteralExpr: analyzed_keywords["ub_stride"], ), ) - if expr.name == "copy_ubuf_to_ubuf": - return self._normalize_copy_ubuf_to_ubuf_guide_operands( - ( - analyzed_keywords["src"], - analyzed_keywords["dst"], - analyzed_keywords["src_offset"], - analyzed_keywords["src_stride0"], - analyzed_keywords["src_stride1"], - analyzed_keywords["dst_offset"], - analyzed_keywords["dst_stride0"], - analyzed_keywords["dst_stride1"], - ) - ) raise TypeError( f"pto.{expr.name} keyword form is not implemented in TileLang DSL v1" ) - def _normalize_copy_ubuf_to_ubuf_guide_operands( - self, - args: tuple[SemanticExpr, ...], - ) -> tuple[SemanticExpr, ...]: - if len(args) != 8: - raise TypeError( - "pto.copy_ubuf_to_ubuf guide form expects exactly 8 operands in TileLang DSL" - ) - source = self._require_pointer_expr(args[0], "pto.copy_ubuf_to_ubuf source", memory_space="ub") - destination = self._require_pointer_expr( - args[1], - "pto.copy_ubuf_to_ubuf destination", - memory_space="ub", - ) - self._require_i64_like_expr(args[2], "pto.copy_ubuf_to_ubuf src_offset") - self._require_i64_like_expr(args[3], "pto.copy_ubuf_to_ubuf src_stride0") - self._require_i64_like_expr(args[4], "pto.copy_ubuf_to_ubuf src_stride1") - self._require_i64_like_expr(args[5], "pto.copy_ubuf_to_ubuf dst_offset") - self._require_i64_like_expr(args[6], "pto.copy_ubuf_to_ubuf dst_stride0") - self._require_i64_like_expr(args[7], "pto.copy_ubuf_to_ubuf dst_stride1") - zero_sid = SemanticLiteralExpr(value=0, type=SemanticIndexType()) - # The guide-level surface rewrites offsets into the base pointers, then - # lowers the remaining four integers to the existing VPTO UB->UB copy ABI. - return ( - SemanticCallExpr(namespace="pto", name="addptr", args=(source, args[2]), type=source.type), - SemanticCallExpr(namespace="pto", name="addptr", args=(destination, args[5]), type=destination.type), - zero_sid, - args[3], - args[4], - args[6], - args[7], - ) - def _require_tensor_slice( self, expr: SemanticExpr, @@ -2440,9 +2346,7 @@ def _bind_assignment_target( for axis in range(value.type.rank) ) elif isinstance(value, SemanticCallExpr): - tuple_values = tuple( - SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types - ) + tuple_values = value.args else: tuple_values = tuple( SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types @@ -2505,6 +2409,8 @@ def _annotation_type( f"annotated mask type `{mask_type!r}` does not match inferred !pto.mask<{inferred_type.granularity}>" ) return inferred_type + if annotation_expr.type.kind == "align_type" and isinstance(inferred_type, SemanticAlignType): + return inferred_type raise TypeError("unsupported annotated assignment type in TileLang DSL v1") def _analyze_annotation_expr( @@ -2522,7 +2428,7 @@ def _build_frontend_annotation_expr(self, node: ast.AST) -> FrontendExprNode: return FrontendConstantExpr(value=node.value) if isinstance(node, ast.Attribute): path = self._annotation_attribute_path(node) - if path is not None and path[0] in {"pto", "PAT", "PIPE", "Pipe", "EVENT", "Event"} and len(path) >= 2: + if path is not None and path[0] in {"pto", "PAT", "PIPE", "EVENT"} and len(path) >= 2: return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) return FrontendAttributeExpr( base=self._build_frontend_annotation_expr(node.value), @@ -2800,8 +2706,6 @@ def _analyze_expr( return self._valid_shape_expr(base) if expr.attr == "strides": return self._strides_expr(base) - if expr.attr == "rank": - return self._rank_expr(base) attr_type = self._attribute_type(base, expr.attr) return SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type) if isinstance(expr, FrontendSubscriptExpr): @@ -2873,22 +2777,8 @@ def _analyze_expr( allow_outer_lookup=allow_outer_lookup, context="pto.vldus source", ) - align = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) - return self._analyze_vldus((base, *indices, align)) - if ( - expr.namespace == "pto" - and expr.name == "vldx2" - and len(expr.args) == 2 - and isinstance(expr.args[0], FrontendSubscriptExpr) - ): - base, indices = self._analyze_tile_vector_access( - expr.args[0], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.vldx2 source", - ) - dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) - return self._analyze_vldx2((base, *indices, dist)) + align_expr = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_vldus((base, *indices, align_expr)) if expr.namespace == "pto" and expr.name == "vldsx2" and len(expr.args) == 2: base, indices = self._analyze_tile_vector_access( expr.args[0], @@ -2898,51 +2788,11 @@ def _analyze_expr( ) dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) return self._analyze_vldsx2((base, *indices, dist)) - if ( - expr.namespace == "pto" - and expr.name == "vsld" - and len(expr.args) == 2 - and isinstance(expr.args[0], FrontendSubscriptExpr) - ): - base, indices = self._analyze_tile_scalar_access( - expr.args[0], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.vsld source", - ) - stride = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) - return self._analyze_vsld((base, *indices, stride)) - if ( - expr.namespace == "pto" - and expr.name == "plds" - and len(expr.args) in {1, 2} - and isinstance(expr.args[0], FrontendSubscriptExpr) - ): - base, indices = self._analyze_tile_vector_access( - expr.args[0], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.plds source", - ) - extra_args = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args[1:] - ) - return self._analyze_predicate_memory_expr_op("plds", (base, *indices, *extra_args)) - if expr.namespace == "pto" and expr.name == "vldsx2" and len(expr.args) == 2: - base, indices = self._analyze_tile_vector_access( - expr.args[0], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.vldsx2 source", - ) - dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) - return self._analyze_vldsx2((base, *indices, dist)) - if expr.keywords: - return self._analyze_keyword_call_expr( - expr, - env, - allow_outer_lookup=allow_outer_lookup, + if expr.keywords: + raise TypeError( + f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " + "carries keyword arguments, but semantic keyword handling is not implemented " + "in TileLang DSL v1 yet" ) args = tuple( self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) @@ -2969,6 +2819,13 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=mask_type, type=SemanticMetaType(kind="mask_type"), ) + if expr.name == "align": + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=align, + type=SemanticMetaType(kind="align_type"), + ) if expr.namespace in {"PAT", "pto.PAT", "pto.MaskPattern"}: pattern = _PATTERN_SYMBOLS.get(expr.name) if pattern is None and expr.name.startswith("PAT_"): @@ -3017,33 +2874,6 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=memory_space, type=SemanticMetaType(kind="memory_space"), ) - if expr.namespace in {"BLayout", "pto.BLayout"}: - b_layout = _B_LAYOUT_SYMBOLS.get(expr.name) - if b_layout is not None: - return SemanticSymbolExpr( - namespace=expr.namespace, - name=expr.name, - value=b_layout, - type=SemanticMetaType(kind="b_layout"), - ) - if expr.namespace in {"SLayout", "pto.SLayout"}: - s_layout = _S_LAYOUT_SYMBOLS.get(expr.name) - if s_layout is not None: - return SemanticSymbolExpr( - namespace=expr.namespace, - name=expr.name, - value=s_layout, - type=SemanticMetaType(kind="s_layout"), - ) - if expr.namespace in {"PadValue", "pto.PadValue"}: - pad_value = _PAD_VALUE_SYMBOLS.get(expr.name) - if pad_value is not None: - return SemanticSymbolExpr( - namespace=expr.namespace, - name=expr.name, - value=pad_value, - type=SemanticMetaType(kind="pad_value"), - ) if expr.namespace in {"PadMode", "pto.PadMode"}: pad_mode = _PAD_MODE_SYMBOLS.get(expr.name) if pad_mode is not None: @@ -3089,50 +2919,14 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=order_mode, type=SemanticMetaType(kind="order_mode"), ) - if expr.namespace in {"DeinterleaveDist", "pto.DeinterleaveDist"}: - dist = _DEINTERLEAVE_DIST_SYMBOLS.get(expr.name) - if dist is not None: - return SemanticSymbolExpr( - namespace=expr.namespace, - name=expr.name, - value=dist, - type=SemanticMetaType(kind="deinterleave_dist"), - ) - if expr.namespace in {"InterleaveDist", "pto.InterleaveDist"}: - dist = _INTERLEAVE_DIST_SYMBOLS.get(expr.name) - if dist is not None: - return SemanticSymbolExpr( - namespace=expr.namespace, - name=expr.name, - value=dist, - type=SemanticMetaType(kind="interleave_dist"), - ) - if expr.namespace in {"PredicateDist", "pto.PredicateDist"}: - dist = _PREDICATE_DIST_SYMBOLS.get(expr.name) - if dist is not None: + if expr.namespace in {"PostUpdateMode", "pto.PostUpdateMode"}: + post_update_mode = _POST_UPDATE_MODE_SYMBOLS.get(expr.name) + if post_update_mode is not None: return SemanticSymbolExpr( namespace=expr.namespace, name=expr.name, - value=dist, - type=SemanticMetaType(kind="predicate_dist"), - ) - if expr.namespace in {"PredicatePart", "pto.PredicatePart"}: - part = _PREDICATE_PART_SYMBOLS.get(expr.name) - if part is not None: - return SemanticSymbolExpr( - namespace=expr.namespace, - name=expr.name, - value=part, - type=SemanticMetaType(kind="predicate_part"), - ) - if expr.namespace in {"StrideMode", "pto.StrideMode"}: - stride = _STRIDE_MODE_SYMBOLS.get(expr.name) - if stride is not None: - return SemanticSymbolExpr( - namespace=expr.namespace, - name=expr.name, - value=stride, - type=SemanticMetaType(kind="stride_mode"), + value=post_update_mode, + type=SemanticMetaType(kind="post_update_mode"), ) raise TypeError( f"symbol `{expr.namespace}.{expr.name}` is not supported in TileLang DSL v1" @@ -3148,29 +2942,8 @@ def _attribute_type(self, base: SemanticExpr, attr: str) -> SemanticType: return SemanticShapeType(rank=base_type.rank) if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)) and attr == "valid_shape": return SemanticShapeType(rank=base_type.rank) - if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)) and attr == "rank": - return SemanticIndexType() - if isinstance(base_type, SemanticTileType) and attr == "memory_space": - return SemanticMetaType(kind="memory_space") - if isinstance(base_type, SemanticTileType) and attr == "config": - return SemanticMetaType(kind="tile_config") - if isinstance(base_type, SemanticMetaType) and base_type.kind == "tile_config": - if attr == "b_layout": - return SemanticMetaType(kind="b_layout") - if attr == "s_layout": - return SemanticMetaType(kind="s_layout") - if attr == "s_fractal_size": - return SemanticScalarType(dtype=i32) - if attr == "pad_value": - return SemanticMetaType(kind="pad_value") raise TypeError(f"unsupported attribute access '{attr}' in TileLang DSL v1") - def _rank_expr(self, base: SemanticExpr) -> SemanticExpr: - base_type = base.type - if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): - return SemanticLiteralExpr(value=base_type.rank, type=SemanticIndexType()) - raise TypeError("unsupported attribute access 'rank' in TileLang DSL v1") - def _element_type_expr(self, base: SemanticExpr) -> SemanticExpr: base_type = base.type if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): @@ -3259,15 +3032,6 @@ def _strides_expr(self, base: SemanticExpr) -> SemanticExpr: type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in elements)), ) - def _static_shape_tuple_expr(self, values: tuple[int, ...]) -> SemanticTupleExpr: - return SemanticTupleExpr( - elements=tuple( - SemanticLiteralExpr(value=value, type=SemanticIndexType()) - for value in values - ), - type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in values)), - ) - def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticType: if isinstance(base.type, SemanticShapeType): if not isinstance(index.type, SemanticIndexType): @@ -3283,16 +3047,6 @@ def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticTy f"shape subscript index {index.value} is out of bounds for rank {base.type.rank}" ) return SemanticIndexType() - if isinstance(base.type, SemanticTupleType): - if not isinstance(index.type, SemanticIndexType): - raise TypeError("tuple subscript index must be an index value in TileLang DSL v1") - if isinstance(index, SemanticLiteralExpr) and isinstance(index.value, int): - if index.value < 0 or index.value >= len(base.type.elements): - raise TypeError( - f"tuple subscript index {index.value} is out of bounds for arity {len(base.type.elements)}" - ) - return base.type.elements[index.value] - raise TypeError("tuple subscript index must be a compile-time integer literal in TileLang DSL v1") if isinstance(base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): if not isinstance(index, SemanticTupleExpr): raise TypeError("TensorView slicing expects a tuple of slices in TileLang DSL v1") @@ -3322,29 +3076,6 @@ def _analyze_tile_vector_access( ) return base, indices - def _analyze_tile_scalar_access( - self, - expr: FrontendExprNode, - env: dict[str, SemanticBinding], - *, - allow_outer_lookup: bool, - context: str, - ) -> tuple[SemanticExpr, tuple[SemanticExpr, ...]]: - if not isinstance(expr, FrontendSubscriptExpr): - raise TypeError( - f"{context} expects Tile element-indexing syntax in TileLang DSL v1" - ) - base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) - tile = self._require_tile_expr(base, context) - indices = self._tile_scalar_indices( - expr.index, - tile.type, - env, - allow_outer_lookup=allow_outer_lookup, - context=context, - ) - return base, indices - def _tile_vector_indices( self, index_expr: FrontendExprNode, @@ -3370,33 +3101,11 @@ def _tile_vector_indices( if tile_type.rank != 2 or tile_type.shape is None: raise TypeError(f"{context} currently only supports statically specialized rank-1 or rank-2 Tiles") if not isinstance(index_expr, FrontendTupleExpr) or len(index_expr.elements) != 2: - raise TypeError( - f"{context} expects {self._tile_vector_rank2_syntax(tile_type)} syntax for rank-2 Tile values" - ) + raise TypeError(f"{context} expects Tile[row, col:] syntax for rank-2 Tile values") row_expr, col_expr = index_expr.elements - if self._tile_b_layout(tile_type) == BLayout.COL_MAJOR: - if not isinstance(row_expr, FrontendSliceExpr) or isinstance(col_expr, FrontendSliceExpr): - raise TypeError( - f"{context} expects {self._tile_vector_rank2_syntax(tile_type)} syntax for rank-2 Tile values" - ) - if row_expr.stop is not None: - raise TypeError(f"{context} does not support explicit slice stop in TileLang DSL advanced mode") - if row_expr.step is not None: - raise TypeError(f"{context} does not support stepped Tile vector slices in TileLang DSL advanced mode") - if row_expr.start is None: - row = SemanticLiteralExpr(value=0, type=SemanticIndexType()) - else: - row = self._analyze_expr(row_expr.start, env, allow_outer_lookup=allow_outer_lookup) - self._require_index_typed_expr(row) - col = self._analyze_expr(col_expr, env, allow_outer_lookup=allow_outer_lookup) - self._require_index_typed_expr(col) - return (row, col) - - if not isinstance(col_expr, FrontendSliceExpr) or isinstance(row_expr, FrontendSliceExpr): - raise TypeError( - f"{context} expects {self._tile_vector_rank2_syntax(tile_type)} syntax for rank-2 Tile values" - ) + if not isinstance(col_expr, FrontendSliceExpr): + raise TypeError(f"{context} expects Tile[row, col:] syntax for rank-2 Tile values") if col_expr.stop is not None: raise TypeError(f"{context} does not support explicit slice stop in TileLang DSL advanced mode") if col_expr.step is not None: @@ -3411,42 +3120,6 @@ def _tile_vector_indices( self._require_index_typed_expr(col) return (row, col) - def _tile_b_layout(self, tile_type: SemanticTileType) -> BLayout: - config = TileConfig() if tile_type.config is None else tile_type.config - return config.b_layout - - def _tile_vector_rank2_syntax(self, tile_type: SemanticTileType) -> str: - if self._tile_b_layout(tile_type) == BLayout.COL_MAJOR: - return "Tile[row_start:, col_index]" - return "Tile[row, col:]" - - def _tile_scalar_indices( - self, - index_expr: FrontendExprNode, - tile_type: SemanticTileType, - env: dict[str, SemanticBinding], - *, - allow_outer_lookup: bool, - context: str, - ) -> tuple[SemanticExpr, ...]: - if tile_type.rank == 1: - if isinstance(index_expr, FrontendSliceExpr): - raise TypeError(f"{context} expects Tile[pos] syntax for rank-1 Tile values") - index = self._analyze_expr(index_expr, env, allow_outer_lookup=allow_outer_lookup) - self._require_index_typed_expr(index) - return (index,) - - if tile_type.rank != 2 or tile_type.shape is None: - raise TypeError(f"{context} currently only supports statically specialized rank-1 or rank-2 Tiles") - if not isinstance(index_expr, FrontendTupleExpr) or len(index_expr.elements) != 2: - raise TypeError(f"{context} expects Tile[row, col] syntax for rank-2 Tile values") - - row = self._analyze_expr(index_expr.elements[0], env, allow_outer_lookup=allow_outer_lookup) - col = self._analyze_expr(index_expr.elements[1], env, allow_outer_lookup=allow_outer_lookup) - self._require_index_typed_expr(row) - self._require_index_typed_expr(col) - return (row, col) - def _tensor_slice_type( self, tensor_type: SemanticTensorViewType | SemanticPartitionTensorViewType, @@ -3580,44 +3253,34 @@ def _analyze_call_expr( ) if name == "make_mask": return self._analyze_make_mask(args) - if name in _DIRECT_PREDICATE_PATTERN_OPS: - return self._analyze_direct_predicate_pattern_op(name, args) - if name in _DIRECT_PREDICATE_TAIL_OPS: - return self._analyze_direct_predicate_tail_op(name, args) + if name in { + "get_block_idx", + "get_subblock_idx", + "get_block_num", + "get_subblock_num", + }: + return self._analyze_runtime_block_query(name, args) + if name == "init_align": + return self._analyze_init_align(args) if name == "vlds": return self._analyze_vlds(args) if name == "vldas": return self._analyze_vldas(args) if name == "vldus": return self._analyze_vldus(args) - if name == "vldx2": - return self._analyze_vldx2(args) - if name == "vsld": - return self._analyze_vsld(args) - if name in _PREDICATE_MEMORY_EXPR_OPS: - return self._analyze_predicate_memory_expr_op(name, args) + if name == "vldsx2": + return self._analyze_vldsx2(args) if name == "pstu": return self._analyze_pstu(args) - if name == "vstu": - return self._analyze_vstu(args) if name == "vstus": return self._analyze_vstus(args) if name == "vstur": return self._analyze_vstur(args) - if name in { - "get_block_idx", - "get_subblock_idx", - "get_block_num", - "get_subblock_num", - }: - return self._analyze_runtime_block_query(name, args) - if name == "vldsx2": - return self._analyze_vldsx2(args) if name == "load_scalar": return self._analyze_load_scalar(args) if name in {"ppack", "punpack"}: return self._analyze_mask_part_op(name, args) - if name in {"pnot", "psel"} | _PREDICATE_BINARY_LOGIC_OPS: + if name in {"pnot", "psel"}: return self._analyze_mask_logic_op(name, args) if name in {"vcmp", "vcmps"}: return self._analyze_compare_op(name, args) @@ -3625,8 +3288,6 @@ def _analyze_call_expr( return self._analyze_select_op(name, args) if name in {"vaddc", "vsubc", "vaddcs", "vsubcs"}: return self._analyze_carry_op(name, args) - if name in _PREDICATE_REARRANGEMENT_OPS: - return self._analyze_predicate_rearrangement_op(name, args) if name in {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"}: return self._analyze_rearrangement_op(name, args) if name == "vcvt": @@ -3649,33 +3310,6 @@ def _analyze_call_expr( return self._analyze_ternary_vector_op(name, args) raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") - def _analyze_keyword_call_expr( - self, - expr: FrontendCallExpr, - env: dict[str, SemanticBinding], - *, - allow_outer_lookup: bool, - ) -> SemanticExpr: - args = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args - ) - analyzed_keywords = { - name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) - for name, value in expr.keywords - } - - if expr.namespace == "pto" and expr.name == "vdup": - return self._analyze_vdup_keyword_call(args, analyzed_keywords) - if expr.namespace == "pto" and expr.name == "vci": - return self._analyze_vci_keyword_call(args, analyzed_keywords) - - raise TypeError( - f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " - "carries keyword arguments, but semantic keyword handling is not implemented " - "in TileLang DSL v1 yet" - ) - def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 2: raise TypeError("pto.make_mask expects exactly 2 positional arguments in TileLang DSL v1") @@ -3701,142 +3335,6 @@ def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: ), ) - def _mask_type_from_named_family(self, name: str) -> SemanticMaskType: - if name.endswith("_b8"): - return SemanticMaskType(granularity="b8") - if name.endswith("_b16"): - return SemanticMaskType(granularity="b16") - if name.endswith("_b32"): - return SemanticMaskType(granularity="b32") - raise TypeError(f"unsupported predicate family `{name}` in TileLang DSL v1") - - def _analyze_direct_predicate_pattern_op( - self, - name: str, - args: tuple[SemanticExpr, ...], - ) -> SemanticExpr: - if len(args) != 1: - raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") - pattern = args[0] - if not ( - ( - isinstance(pattern, SemanticSymbolExpr) - and isinstance(pattern.type, SemanticMetaType) - and pattern.type.kind == "mask_pattern" - ) - or ( - isinstance(pattern, SemanticBindingRef) - and isinstance(pattern.type, SemanticMetaType) - and pattern.type.kind == "mask_pattern" - ) - ): - raise TypeError(f"pto.{name} pattern must be a pto.MaskPattern value in TileLang DSL v1") - return SemanticCallExpr( - namespace="pto", - name=name, - args=args, - type=self._mask_type_from_named_family(name), - ) - - def _analyze_direct_predicate_tail_op( - self, - name: str, - args: tuple[SemanticExpr, ...], - ) -> SemanticExpr: - if len(args) != 1: - raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") - self._require_tail_remaining_expr(args[0], f"pto.{name} scalar") - return SemanticCallExpr( - namespace="pto", - name=name, - args=args, - type=SemanticTupleType(elements=(self._mask_type_from_named_family(name), _I32_TYPE)), - ) - - def _predicate_mask_type_for_buffer( - self, - source: SemanticExpr, - context: str, - ) -> SemanticMaskType: - if isinstance(source.type, SemanticTileType): - tile = self._require_tile_expr(source, context) - return SemanticMaskType(granularity=self._mask_granularity_for_dtype(tile.type.element_dtype)) - pointer = self._require_pointer_expr(source, context, memory_space="ub") - return SemanticMaskType(granularity=self._mask_granularity_for_dtype(pointer.type.element_dtype)) - - def _analyze_predicate_memory_expr_op( - self, - name: str, - args: tuple[SemanticExpr, ...], - ) -> SemanticExpr: - if name == "plds": - if len(args) < 2: - raise TypeError("pto.plds expects source, offset, and optional dist in TileLang DSL v1") - source = args[0] - if isinstance(source.type, SemanticTileType): - source = self._require_tile_expr(source, "pto.plds source") - else: - source = self._require_pointer_expr(source, "pto.plds source", memory_space="ub") - trailing = args[1:] - has_dist = ( - len(trailing) > 1 - and isinstance(trailing[-1].type, SemanticMetaType) - and trailing[-1].type.kind in {"predicate_dist", "string"} - ) - index_args = trailing[:-1] if has_dist else trailing - dist_arg = trailing[-1] if has_dist else None - for index in index_args: - self._require_index_typed_expr(index) - if isinstance(source.type, SemanticPtrType) and len(index_args) != 1: - raise TypeError("pto.plds pointer syntax expects exactly one offset operand in TileLang DSL v1") - dist = self._normalize_predicate_dist( - dist_arg, - "pto.plds dist", - allowed={"NORM", "US", "DS"}, - default="NORM", - ) - return SemanticCallExpr( - namespace="pto", - name=name, - args=(source, *index_args, dist), - type=self._predicate_mask_type_for_buffer(source, "pto.plds source"), - ) - - if name == "pld": - if len(args) != 3: - raise TypeError("pto.pld expects source, offset, and dist in TileLang DSL v1") - source = self._require_pointer_expr(args[0], "pto.pld source", memory_space="ub") - self._require_index_typed_expr(args[1]) - dist = self._normalize_predicate_dist( - args[2], - "pto.pld dist", - allowed={"NORM", "US", "DS"}, - default="NORM", - ) - return SemanticCallExpr( - namespace="pto", - name=name, - args=(source, args[1], dist), - type=self._predicate_mask_type_for_buffer(source, "pto.pld source"), - ) - - if len(args) != 3: - raise TypeError("pto.pldi expects source, imm_offset, and dist in TileLang DSL v1") - source = self._require_pointer_expr(args[0], "pto.pldi source", memory_space="ub") - self._require_i32_expr(args[1], "pto.pldi offset") - dist = self._normalize_predicate_dist( - args[2], - "pto.pldi dist", - allowed={"NORM", "US", "DS"}, - default="NORM", - ) - return SemanticCallExpr( - namespace="pto", - name=name, - args=(source, args[1], dist), - type=self._predicate_mask_type_for_buffer(source, "pto.pldi source"), - ) - def _analyze_scalar_constructor( self, name: str, @@ -3996,7 +3494,7 @@ def _analyze_addptr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: raise TypeError("pto.addptr expects exactly 2 positional arguments in TileLang DSL") pointer, offset = args ptr = self._require_pointer_expr(pointer, "pto.addptr pointer") - self._require_i64_like_expr(offset, "pto.addptr offset") + self._require_index_typed_expr(offset) return SemanticCallExpr(namespace="pto", name="addptr", args=(ptr, offset), type=ptr.type) def _analyze_get_lanes( @@ -4018,6 +3516,11 @@ def _analyze_bytewidth(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: dtype = self._require_dtype_symbol(args[0], "pto.bytewidth dtype") return SemanticLiteralExpr(value=bytewidth(dtype), type=SemanticIndexType()) + def _analyze_init_align(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if args: + raise TypeError("pto.init_align does not accept positional arguments in TileLang DSL v1") + return SemanticCallExpr(namespace="pto", name="init_align", args=(), type=SemanticAlignType()) + def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) < 2: raise TypeError("pto.vlds expects at least 2 positional arguments in TileLang DSL v1") @@ -4037,196 +3540,126 @@ def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: ) def _analyze_vldas(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if not 1 <= len(args) <= 3: - raise TypeError("pto.vldas expects 1 positional argument or Tile element-indexing syntax in TileLang DSL v1") + if len(args) not in {1, 2, 3}: + raise TypeError("pto.vldas expects 1 positional source or Tile[start:]/Tile[row, col:] in TileLang DSL v1") source, *indices = args - if isinstance(source.type, SemanticTileType): + source_type = source.type + if isinstance(source_type, SemanticTileType): source = self._require_tile_expr(source, "pto.vldas source") for index in indices: self._require_index_typed_expr(index) else: if indices: - raise TypeError("pto.vldas pointer syntax does not accept an explicit offset in TileLang DSL v1") + raise TypeError("pto.vldas pointer syntax does not accept explicit indices in TileLang DSL v1") source = self._require_pointer_expr(source, "pto.vldas source", memory_space="ub") return SemanticCallExpr( namespace="pto", name="vldas", args=(source, *indices), - type=_ALIGN_TYPE, + type=SemanticAlignType(), ) def _analyze_vldus(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if len(args) < 2: - raise TypeError("pto.vldus expects source and align operands in TileLang DSL v1") - source = args[0] - align = args[-1] - indices = args[1:-1] - if isinstance(source.type, SemanticTileType): + if len(args) not in {2, 3, 4}: + raise TypeError("pto.vldus expects (source, align) or Tile element-indexing syntax in TileLang DSL v1") + source, *rest = args + align_expr = rest[-1] + index_args = rest[:-1] + source_type = source.type + if isinstance(source_type, SemanticTileType): source = self._require_tile_expr(source, "pto.vldus source") - for index in indices: + for index in index_args: self._require_index_typed_expr(index) else: - if indices: - raise TypeError("pto.vldus pointer syntax does not accept an explicit offset in TileLang DSL v1") + if index_args: + raise TypeError("pto.vldus pointer syntax does not accept explicit indices in TileLang DSL v1") source = self._require_pointer_expr(source, "pto.vldus source", memory_space="ub") - self._require_align_expr(align, "pto.vldus align") - source_ptr_type = SemanticPtrType( - element_dtype=source.type.element_dtype, - memory_space="ub", - ) + self._require_align_expr(align_expr, "pto.vldus align") return SemanticCallExpr( namespace="pto", name="vldus", - args=(source, *indices, align), - type=SemanticTupleType( - elements=( - self._vreg_type_for_dtype(source.type.element_dtype), - _ALIGN_TYPE, - source_ptr_type, - ) - ), + args=(source, *index_args, align_expr), + type=SemanticTupleType(elements=(self._vreg_type_for_dtype(source.type.element_dtype), SemanticAlignType())), ) - def _analyze_vldx2(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if len(args) < 3: - raise TypeError("pto.vldx2 expects source, offset, and dist operands in TileLang DSL v1") - source = args[0] - dist = args[-1] - indices = args[1:-1] - if isinstance(source.type, SemanticTileType): - source = self._require_tile_expr(source, "pto.vldx2 source") - if len(indices) not in {1, 2}: - raise TypeError("pto.vldx2 Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + def _analyze_vldsx2(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) not in {3, 4}: + raise TypeError("pto.vldsx2 expects 3 or 4 positional arguments in TileLang DSL v1") + source, *rest = args + if len(rest) == 2: + index_args = rest[:1] + dist = rest[1] else: - source = self._require_pointer_expr(source, "pto.vldx2 source", memory_space="ub") - if len(indices) != 1: - raise TypeError("pto.vldx2 pointer syntax expects exactly one offset operand in TileLang DSL v1") - for index in indices: - self._require_index_typed_expr(index) - normalized_dist = self._normalize_deinterleave_dist(dist, "pto.vldx2 dist") - result_type = self._vreg_type_for_dtype(source.type.element_dtype) - return SemanticCallExpr( - namespace="pto", - name="vldx2", - args=(source, *indices, normalized_dist), - type=SemanticTupleType(elements=(result_type, result_type)), - ) - - def _analyze_vsld(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if len(args) < 3: - raise TypeError("pto.vsld expects source, offset, and stride operands in TileLang DSL v1") - source = args[0] - stride = args[-1] - indices = args[1:-1] - if isinstance(source.type, SemanticTileType): - source = self._require_tile_expr(source, "pto.vsld source") - if len(indices) not in {1, 2}: - raise TypeError("pto.vsld Tile syntax expects rank-1 or rank-2 element indexing in TileLang DSL v1") + index_args = rest[:2] + dist = rest[2] + source_type = source.type + if isinstance(source_type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vldsx2 source") else: - source = self._require_pointer_expr(source, "pto.vsld source", memory_space="ub") - if len(indices) != 1: - raise TypeError("pto.vsld pointer syntax expects exactly one offset operand in TileLang DSL v1") - for index in indices: + source = self._require_pointer_expr(source, "pto.vldsx2 source", memory_space="ub") + for index in index_args: self._require_index_typed_expr(index) - normalized_stride = self._normalize_stride_mode(stride, "pto.vsld stride") + dist = self._normalize_vldsx2_dist(dist) + vreg_type = self._vreg_type_for_dtype(source.type.element_dtype) return SemanticCallExpr( namespace="pto", - name="vsld", - args=(source, *indices, normalized_stride), - type=self._vreg_type_for_dtype(source.type.element_dtype), + name="vldsx2", + args=(source, *index_args, dist), + type=SemanticTupleType(elements=(vreg_type, vreg_type)), ) def _analyze_pstu(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 3: raise TypeError("pto.pstu expects exactly 3 positional arguments in TileLang DSL v1") - align, value, base = args - self._require_align_expr(align, "pto.pstu align") - mask = self._require_mask_expr(value, "pto.pstu value") - base_ptr = self._require_pointer_expr(base, "pto.pstu base", memory_space="ub") + align_expr, value, base = args + self._require_align_expr(align_expr, "pto.pstu align_in") + mask_type = self._require_mask_expr(value, "pto.pstu value") + base = self._require_pointer_expr(base, "pto.pstu base", memory_space="ub") + if mask_type.granularity == "b16": + expected = ui16 + elif mask_type.granularity == "b32": + expected = ui32 + else: + raise TypeError("pto.pstu only supports !pto.mask and !pto.mask in TileLang DSL v1") + if base.type.element_dtype != expected: + raise TypeError( + f"pto.pstu requires !pto.ptr<{expected.name}, ub> for mask granularity {mask_type.granularity}" + ) return SemanticCallExpr( namespace="pto", name="pstu", - args=(align, value, base_ptr), - type=SemanticTupleType(elements=(_ALIGN_TYPE, base_ptr.type)), - ) - - def _analyze_vstu(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if len(args) != 5: - raise TypeError("pto.vstu expects exactly 5 positional arguments in TileLang DSL v1") - align, offset, value, base, mode = args - self._require_align_expr(align, "pto.vstu align") - self._require_index_typed_expr(offset) - vec = self._require_vreg_expr(value, "pto.vstu value") - base_ptr = self._require_pointer_expr(base, "pto.vstu base", memory_space="ub") - if base_ptr.type.element_dtype != vec.element_dtype: - raise TypeError("pto.vstu requires base pointer dtype to match vector element dtype") - normalized_mode = self._normalize_mode_string(mode, "pto.vstu mode") - return SemanticCallExpr( - namespace="pto", - name="vstu", - args=(align, offset, value, base_ptr, normalized_mode), - type=SemanticTupleType(elements=(_ALIGN_TYPE, SemanticIndexType())), + args=(align_expr, value, base), + type=SemanticTupleType(elements=(SemanticAlignType(), base.type)), ) def _analyze_vstus(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if len(args) != 5: - raise TypeError("pto.vstus expects exactly 5 positional arguments in TileLang DSL v1") - align, offset, value, base, mode = args - self._require_align_expr(align, "pto.vstus align") - self._require_i32_expr(offset, "pto.vstus offset") - vec = self._require_vreg_expr(value, "pto.vstus value") - base_ptr = self._require_pointer_expr(base, "pto.vstus base", memory_space="ub") - if base_ptr.type.element_dtype != vec.element_dtype: - raise TypeError("pto.vstus requires base pointer dtype to match vector element dtype") - normalized_mode = self._normalize_mode_string(mode, "pto.vstus mode") + if len(args) != 4: + raise TypeError("pto.vstus expects exactly 4 positional arguments in TileLang DSL v1") + align_expr, offset, value, base = args + self._require_align_expr(align_expr, "pto.vstus align_in") + self._require_i32_like_expr(offset, "pto.vstus offset") + self._require_vreg_expr(value, "pto.vstus value") + base = self._require_pointer_expr(base, "pto.vstus base", memory_space="ub") return SemanticCallExpr( namespace="pto", name="vstus", - args=(align, offset, value, base_ptr, normalized_mode), - type=SemanticTupleType(elements=(_ALIGN_TYPE, base_ptr.type)), + args=(align_expr, offset, value, base), + type=SemanticAlignType(), ) def _analyze_vstur(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if len(args) != 4: - raise TypeError("pto.vstur expects exactly 4 positional arguments in TileLang DSL v1") - align, value, base, mode = args - self._require_align_expr(align, "pto.vstur align") - vec = self._require_vreg_expr(value, "pto.vstur value") - base_ptr = self._require_pointer_expr(base, "pto.vstur base", memory_space="ub") - if base_ptr.type.element_dtype != vec.element_dtype: - raise TypeError("pto.vstur requires base pointer dtype to match vector element dtype") - normalized_mode = self._normalize_mode_string(mode, "pto.vstur mode") - return SemanticCallExpr( - namespace="pto", - name="vstur", - args=(align, value, base_ptr, normalized_mode), - type=_ALIGN_TYPE, - ) - - def _analyze_vldsx2(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) not in {3, 4}: - raise TypeError("pto.vldsx2 expects 3 or 4 positional arguments in TileLang DSL v1") - source, *rest = args - if len(rest) == 2: - index_args = rest[:1] - dist = rest[1] - else: - index_args = rest[:2] - dist = rest[2] - source_type = source.type - if isinstance(source_type, SemanticTileType): - source = self._require_tile_expr(source, "pto.vldsx2 source") - else: - source = self._require_pointer_expr(source, "pto.vldsx2 source", memory_space="ub") - for index in index_args: - self._require_index_typed_expr(index) - dist = self._normalize_vldsx2_dist(dist) - vreg_type = self._vreg_type_for_dtype(source.type.element_dtype) + raise TypeError("pto.vstur expects 3 or 4 positional arguments in TileLang DSL v1") + align_expr, value, base = args[:3] + mode = self._normalize_post_update_mode(args[3] if len(args) == 4 else None, "pto.vstur mode") + self._require_align_expr(align_expr, "pto.vstur align_in") + self._require_vreg_expr(value, "pto.vstur value") + base = self._require_pointer_expr(base, "pto.vstur base", memory_space="ub") return SemanticCallExpr( namespace="pto", - name="vldsx2", - args=(source, *index_args, dist), - type=SemanticTupleType(elements=(vreg_type, vreg_type)), + name="vstur", + args=(align_expr, value, base, mode), + type=SemanticAlignType(), ) def _analyze_load_scalar(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: @@ -4312,61 +3745,6 @@ def _analyze_broadcast_vector_op( raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") - def _analyze_vdup_keyword_call( - self, - args: tuple[SemanticExpr, ...], - keywords: dict[str, SemanticExpr], - ) -> SemanticExpr: - if not args: - raise TypeError("pto.vdup expects at least 1 positional argument in TileLang DSL v1") - if len(args) > 2: - raise TypeError("pto.vdup expects 1 or 2 operands in TileLang DSL v1") - if len(args) == 2 and "position" in keywords: - raise TypeError("pto.vdup got multiple values for argument `position` in TileLang DSL v1") - - value = args[0] - if isinstance(value.type, SemanticVRegType): - vec_type = value.type - else: - vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vdup input") - position = self._normalize_position_mode( - keywords.get("position", args[1] if len(args) == 2 else None), - "pto.vdup position", - ) - return SemanticCallExpr( - namespace="pto", - name="vdup", - args=(value, position), - type=vec_type, - ) - - def _analyze_vci_keyword_call( - self, - args: tuple[SemanticExpr, ...], - keywords: dict[str, SemanticExpr], - ) -> SemanticExpr: - if not args: - raise TypeError("pto.vci expects at least 1 positional argument in TileLang DSL v1") - if len(args) > 2: - raise TypeError("pto.vci expects 1 or 2 operands in TileLang DSL v1") - if len(args) == 2 and "order" in keywords: - raise TypeError("pto.vci got multiple values for argument `order` in TileLang DSL v1") - - index = self._require_scalar_or_index_expr(args[0], "pto.vci index") - index_dtype = i32 if isinstance(index.type, SemanticIndexType) else index.type.dtype - if index_dtype.name not in {"i8", "i16", "i32"}: - raise TypeError("pto.vci index only supports i8/i16/i32 in TileLang DSL v1") - order = self._normalize_order_mode( - keywords.get("order", args[1] if len(args) == 2 else None), - "pto.vci order", - ) - return SemanticCallExpr( - namespace="pto", - name="vci", - args=(index, order), - type=self._vreg_type_for_dtype(index_dtype), - ) - def _analyze_unary_vector_op( self, name: str, @@ -4480,8 +3858,8 @@ def _analyze_mask_part_op( if len(args) != 2: raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") mask = self._require_mask_expr(args[0], f"pto.{name} mask") - part = self._normalize_predicate_part(args[1], f"pto.{name} part") - return SemanticCallExpr(namespace="pto", name=name, args=(args[0], part), type=mask) + self._require_string_expr(args[1], f"pto.{name} part") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=mask) def _analyze_mask_logic_op( self, @@ -4495,23 +3873,13 @@ def _analyze_mask_logic_op( mask = self._require_mask_expr(args[1], "pto.pnot mask") self._require_matching_mask_types(value, mask, "pto.pnot") return SemanticCallExpr(namespace="pto", name=name, args=args, type=value) - if name == "psel": - if len(args) != 3: - raise TypeError("pto.psel expects exactly 3 positional arguments in TileLang DSL") - src0 = self._require_mask_expr(args[0], "pto.psel src0") - src1 = self._require_mask_expr(args[1], "pto.psel src1") - mask = self._require_mask_expr(args[2], "pto.psel mask") - self._require_matching_mask_types(src0, src1, "pto.psel") - self._require_matching_mask_types(src0, mask, "pto.psel") - return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) - if len(args) != 3: - raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") - src0 = self._require_mask_expr(args[0], f"pto.{name} src0") - src1 = self._require_mask_expr(args[1], f"pto.{name} src1") - self._require_matching_mask_types(src0, src1, f"pto.{name}") - mask = self._require_mask_expr(args[2], f"pto.{name} mask") - self._require_matching_mask_types(src0, mask, f"pto.{name}") + raise TypeError("pto.psel expects exactly 3 positional arguments in TileLang DSL") + src0 = self._require_mask_expr(args[0], "pto.psel src0") + src1 = self._require_mask_expr(args[1], "pto.psel src1") + mask = self._require_mask_expr(args[2], "pto.psel mask") + self._require_matching_mask_types(src0, src1, "pto.psel") + self._require_matching_mask_types(src0, mask, "pto.psel") return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) def _analyze_compare_op( @@ -4641,26 +4009,6 @@ def _analyze_rearrangement_op( self._require_string_expr(args[2], f"pto.{name} part") return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) - def _analyze_predicate_rearrangement_op( - self, - name: str, - args: tuple[SemanticExpr, ...], - ) -> SemanticExpr: - if len(args) != 2: - raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") - lhs = self._require_mask_expr(args[0], f"pto.{name} lhs") - rhs = self._require_mask_expr(args[1], f"pto.{name} rhs") - self._require_matching_mask_types(lhs, rhs, f"pto.{name}") - expected = "b8" if name == "pdintlv_b8" else "b16" - if lhs.granularity != expected: - raise TypeError(f"pto.{name} requires {expected} mask operands in TileLang DSL v1") - return SemanticCallExpr( - namespace="pto", - name=name, - args=args, - type=SemanticTupleType(elements=(lhs, rhs)), - ) - def _analyze_vcvt(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 3: raise TypeError("pto.vcvt expects exactly 3 positional arguments in TileLang DSL") @@ -4805,10 +4153,6 @@ def _require_vreg_expr(self, expr: SemanticExpr, context: str) -> SemanticVRegTy raise TypeError(f"{context} must be a vector register value in TileLang DSL v1") return expr.type - def _require_align_expr(self, expr: SemanticExpr, context: str) -> None: - if not isinstance(expr.type, SemanticAlignType): - raise TypeError(f"{context} must be an align state value in TileLang DSL v1") - def _require_scalar_expr(self, expr: SemanticExpr, context: str) -> SemanticScalarType: if not isinstance(expr.type, SemanticScalarType): raise TypeError(f"{context} must be a scalar value in TileLang DSL v1") @@ -4886,6 +4230,10 @@ def _require_mask_expr(self, expr: SemanticExpr, context: str) -> SemanticMaskTy raise TypeError(f"{context} must be a mask value in TileLang DSL") return expr.type + def _require_align_expr(self, expr: SemanticExpr, context: str) -> None: + if not isinstance(expr.type, SemanticAlignType): + raise TypeError(f"{context} must be a pto.align value in TileLang DSL v1") + def _require_matching_mask_types( self, lhs: SemanticMaskType, @@ -4907,88 +4255,43 @@ def _require_string_expr(self, expr: SemanticExpr, context: str) -> str: return expr.binding.value raise TypeError(f"{context} must be a string literal in TileLang DSL") - def _normalize_deinterleave_dist(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: - if ( - isinstance(expr, SemanticSymbolExpr) - and isinstance(expr.type, SemanticMetaType) - and expr.type.kind == "deinterleave_dist" - and isinstance(expr.value, DeinterleaveDist) - ): - return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) - return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) - - def _normalize_interleave_dist(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: - if ( - isinstance(expr, SemanticSymbolExpr) - and isinstance(expr.type, SemanticMetaType) - and expr.type.kind == "interleave_dist" - and isinstance(expr.value, InterleaveDist) - ): - return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) - return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) - - def _normalize_predicate_dist( + def _normalize_post_update_mode( self, expr: SemanticExpr | None, context: str, - *, - allowed: set[str], - default: str, - ) -> SemanticLiteralExpr: + ) -> SemanticExpr: if expr is None: - return SemanticLiteralExpr(value=default, type=SemanticMetaType(kind="string")) + return SemanticLiteralExpr(value="NO_POST_UPDATE", type=SemanticMetaType(kind="string")) if ( isinstance(expr, SemanticSymbolExpr) and isinstance(expr.type, SemanticMetaType) - and expr.type.kind == "predicate_dist" - and isinstance(expr.value, PredicateDist) - ): - value = expr.value.value - elif ( - isinstance(expr, SemanticBindingRef) - and isinstance(expr.type, SemanticMetaType) - and expr.type.kind == "predicate_dist" - and isinstance(expr.binding.value, PredicateDist) + and expr.type.kind == "post_update_mode" + and isinstance(expr.value, PostUpdateMode) ): - value = expr.binding.value.value - else: - value = self._require_string_expr(expr, context) - if value not in allowed: - supported = ", ".join(sorted(allowed)) - raise TypeError(f"{context} must be one of {supported} in TileLang DSL v1") - return SemanticLiteralExpr(value=value, type=SemanticMetaType(kind="string")) - - def _normalize_predicate_part(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) if ( - isinstance(expr, SemanticSymbolExpr) - and isinstance(expr.type, SemanticMetaType) - and expr.type.kind == "predicate_part" - and isinstance(expr.value, PredicatePart) - ): - value = expr.value.value - elif ( isinstance(expr, SemanticBindingRef) and isinstance(expr.type, SemanticMetaType) - and expr.type.kind == "predicate_part" - and isinstance(expr.binding.value, PredicatePart) - ): - value = expr.binding.value.value - else: - raise TypeError(f"{context} must be PredicatePart.LOWER or PredicatePart.HIGHER in TileLang DSL v1") - return SemanticLiteralExpr(value=value, type=SemanticMetaType(kind="string")) - - def _normalize_stride_mode(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: - if ( - isinstance(expr, SemanticSymbolExpr) - and isinstance(expr.type, SemanticMetaType) - and expr.type.kind == "stride_mode" - and isinstance(expr.value, StrideMode) + and expr.type.kind == "post_update_mode" + and isinstance(expr.binding.value, PostUpdateMode) ): - return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) - return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + raise TypeError( + "pto.vstur mode must be a PostUpdateMode enum such as " + "`pto.PostUpdateMode.NO_POST_UPDATE` or `pto.PostUpdateMode.POST_UPDATE` in TileLang DSL v1" + ) - def _normalize_mode_string(self, expr: SemanticExpr, context: str) -> SemanticLiteralExpr: - return SemanticLiteralExpr(value=self._require_string_expr(expr, context), type=SemanticMetaType(kind="string")) + def _normalize_predicate_store_dist( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value="NORM", type=SemanticMetaType(kind="string")) + dist = self._require_string_expr(expr, context) + if dist not in {"NORM", "PK"}: + raise TypeError("predicate store dist must be \"NORM\" or \"PK\" in TileLang DSL v1") + return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) def _require_i1_expr(self, expr: SemanticExpr, context: str) -> None: scalar = self._require_scalar_expr(expr, context) @@ -5009,11 +4312,6 @@ def _require_i64_like_expr(self, expr: SemanticExpr, context: str) -> None: if scalar.dtype != i64: raise TypeError(f"{context} must be an i64 or index value in TileLang DSL") - def _require_i32_expr(self, expr: SemanticExpr, context: str) -> None: - scalar = self._require_scalar_expr(expr, context) - if scalar.dtype != i32: - raise TypeError(f"{context} must be an i32 value in TileLang DSL") - def _require_tail_remaining_expr(self, expr: SemanticExpr, context: str) -> None: if isinstance(expr.type, SemanticIndexType): return @@ -5357,36 +4655,6 @@ def _try_static_value(self, expr: SemanticExpr | None) -> Any | None: return expr.value if isinstance(expr, SemanticBindingRef): return expr.binding.value - if isinstance(expr, SemanticAttributeAccess): - base_value = self._try_static_value(expr.base) - if isinstance(base_value, TileConfig): - if expr.attr == "b_layout": - return base_value.b_layout - if expr.attr == "s_layout": - return base_value.s_layout - if expr.attr == "s_fractal_size": - return base_value.s_fractal_size - if expr.attr == "pad_value": - return base_value.pad_value - if base_value is not None and hasattr(base_value, expr.attr): - return getattr(base_value, expr.attr) - if isinstance(expr.base.type, SemanticTileType): - tile_type = expr.base.type - config = TileConfig() if tile_type.config is None else tile_type.config - if expr.attr == "shape": - return tile_type.shape - if expr.attr == "valid_shape": - return self._resolved_tile_valid_shape(tile_type) - if expr.attr == "rank": - return tile_type.rank - if expr.attr == "memory_space": - return None if tile_type.memory_space is None else MemorySpace(tile_type.memory_space) - if expr.attr == "config": - return config - if isinstance(expr.base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): - if expr.attr == "rank": - return expr.base.type.rank - return None if isinstance(expr, SemanticTupleExpr): elements = [] for element in expr.elements: @@ -5595,7 +4863,6 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticBinding", "SemanticBindingRef", "SemanticCallExpr", - "SemanticAlignType", "SemanticDmaOptions", "SemanticDmaLoadStmt", "SemanticDmaStoreStmt", @@ -5603,6 +4870,8 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticExprStmt", "SemanticForStmt", "SemanticGetBufStmt", + "SemanticAlignStoreStmt", + "SemanticAlignType", "SemanticIfResult", "SemanticIfStmt", "SemanticIndexType", @@ -5612,6 +4881,7 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticMaskType", "SemanticParameter", "SemanticPipeBarrierStmt", + "SemanticPredicateStoreStmt", "SemanticRlsBufStmt", "SemanticReturnStmt", "SemanticScalarType", @@ -5638,6 +4908,7 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticTupleType", "SemanticType", "SemanticVRegType", + "SemanticVectorPairStoreStmt", "SemanticVectorStoreStmt", "SemanticWaitFlagDevStmt", "SemanticWaitFlagStmt", diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 1d4a76c81..6805350e9 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -60,18 +60,21 @@ SUPPORTED_VECSCOPE_PTO_CALLS = frozenset( { "make_mask", + "init_align", "vlds", "vldas", "vldus", - "vldx2", "vldsx2", - "vsld", - "vsts", - "vstsx2", "psts", + "pstu", "vsst", - "vstx2", "vsta", + "vstas", + "vstar", + "vsts", + "vstsx2", + "vstus", + "vstur", "vabs", "vrelu", "vexp", @@ -145,34 +148,15 @@ ADVANCED_VECSCOPE_PTO_CALLS = frozenset( { - "pset_b8", - "pset_b16", - "pset_b32", - "pge_b8", - "pge_b16", - "pge_b32", - "plt_b8", - "plt_b16", - "plt_b32", - "plds", - "pld", - "pldi", - "pst", - "psti", "vcmp", "vcmps", "vsel", "vselr", "vselrv2", "pnot", - "pand", - "por", - "pxor", "psel", "ppack", "punpack", - "pdintlv_b8", - "pintlv_b16", "vaddc", "vsubc", "vaddcs", @@ -181,10 +165,6 @@ "vdintlv", "vintlvv2", "vdintlvv2", - "pstu", - "vstu", - "vstus", - "vstur", } ) @@ -268,7 +248,6 @@ { "tile[start:]", "tile[row, col:]", - "tile[row_start:, col_index]", } ) @@ -345,6 +324,9 @@ def get_pto_call_tier(call_name: str) -> str: "pto.dma_copy", "pto.vreduce", "pto.tile", + "BLayout", + "SLayout", + "PadValue", "SyncOpType", } ) @@ -364,16 +346,10 @@ def get_pto_call_tier(call_name: str) -> str: "pto.mask_b32": BASIC_TIER, "BarrierType": BASIC_TIER, "PadMode": BASIC_TIER, - "BLayout": BASIC_TIER, - "SLayout": BASIC_TIER, - "PadValue": BASIC_TIER, - "PredicateDist": ADVANCED_TIER, - "PredicatePart": ADVANCED_TIER, "constexpr": BASIC_TIER, "pto.constexpr": BASIC_TIER, "tile[start:]": BASIC_TIER, "tile[row, col:]": BASIC_TIER, - "tile[row_start:, col_index]": BASIC_TIER, # Advanced tier constructs "ptr": ADVANCED_TIER, # raw pointer constructor "strict_vecscope": ADVANCED_TIER, # explicit vecscope management diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index fb7932013..e35476876 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -130,19 +130,6 @@ class MemorySpace(str, Enum): UB = "ub" -class BLayout(str, Enum): - ROW_MAJOR = "row_major" - COL_MAJOR = "col_major" - - -class SLayout(str, Enum): - NONE_BOX = "none_box" - - -class PadValue(str, Enum): - ZERO = "zero" - - class Pipe(str, Enum): MTE1 = "PIPE_MTE1" MTE2 = "PIPE_MTE2" @@ -184,6 +171,8 @@ class Event(str, Enum): ID29 = "EVENT_ID29" ID30 = "EVENT_ID30" ID31 = "EVENT_ID31" + + class BarrierType(str, Enum): VV_ALL = "VV_ALL" VST_VLD = "VST_VLD" @@ -208,17 +197,17 @@ class PadMode(str, Enum): class DeinterleaveDist(str, Enum): DINTLV = "DINTLV" BDINTLV = "BDINTLV" - B8 = "DINTLV_B8" - B16 = "DINTLV_B16" - B32 = "DINTLV_B32" + B8 = "DINTLV" + B16 = "DINTLV" + B32 = "DINTLV" BD = "BDINTLV" class InterleaveDist(str, Enum): INTLV = "INTLV" - B8 = "INTLV_B8" - B16 = "INTLV_B16" - B32 = "INTLV_B32" + B8 = "INTLV" + B16 = "INTLV" + B32 = "INTLV" class PositionMode(str, Enum): @@ -230,46 +219,9 @@ class OrderMode(str, Enum): ASC = "ORDER_ASC" -class PredicateDist(str, Enum): - NORM = "NORM" - US = "US" - DS = "DS" - PK = "PK" - - -class PredicatePart(str, Enum): - LOWER = "LOWER" - HIGHER = "HIGHER" - - -class StrideMode(str, Enum): - S3_B16 = "STRIDE_S3_B16" - S4_B64 = "STRIDE_S4_B64" - S8_B32 = "STRIDE_S8_B32" - S2_B64 = "STRIDE_S2_B64" - - -def _coerce_int_config_value(value: Any, field_name: str) -> int: - if isinstance(value, bool) or not isinstance(value, int): - raise TypeError(f"TileConfig field '{field_name}' must be an integer") - return value - - -def _coerce_enum_config_value( - value: Any, - enum_type: type[Enum], - field_name: str, - default: Enum, -) -> Enum: - if value is None: - return default - if isinstance(value, enum_type): - return value - if isinstance(value, str): - for candidate in enum_type: - if value in {candidate.name, candidate.value}: - return candidate - raise TypeError(f"TileConfig field '{field_name}' must be a {enum_type.__name__} or matching string") +class PostUpdateMode(str, Enum): + POST_UPDATE = "POST_UPDATE" + NO_POST_UPDATE = "NO_POST_UPDATE" @dataclass(frozen=True) @@ -280,55 +232,6 @@ class TileConfig: def from_mapping(cls, mapping: Mapping[str, Any]) -> "TileConfig": return cls(tuple(sorted(mapping.items()))) - def _field(self, *names: str) -> Any | None: - values = dict(self.fields) - for name in names: - if name in values: - return values[name] - return None - - @property - def b_layout(self) -> BLayout: - return _coerce_enum_config_value( - self._field("b_layout", "layout"), - BLayout, - "b_layout", - BLayout.ROW_MAJOR, - ) - - @property - def s_layout(self) -> SLayout: - return _coerce_enum_config_value( - self._field("s_layout", "slayout"), - SLayout, - "s_layout", - SLayout.NONE_BOX, - ) - - @property - def s_fractal_size(self) -> int: - value = self._field("s_fractal_size", "fractal") - if value is None: - return 512 - return _coerce_int_config_value(value, "s_fractal_size") - - @property - def pad_value(self) -> PadValue: - return _coerce_enum_config_value( - self._field("pad_value", "pad"), - PadValue, - "pad_value", - PadValue.ZERO, - ) - - -@dataclass(frozen=True) -class TileLayoutDescriptor: - shape: tuple[int, ...] - strides: tuple[int, ...] - byte_strides: tuple[int, ...] - offset: int = 0 - @dataclass(frozen=True) class TileSpecialization: @@ -361,10 +264,10 @@ class TileSpecialization: AnyInt = WildcardType("AnyInt") AnyType = WildcardType("AnyType") AnyMask = WildcardType("AnyMask") -align = AlignType() mask_b8 = MaskType("b8") mask_b16 = MaskType("b16") mask_b32 = MaskType("b32") +align = AlignType() def TypeVar(name: str) -> TypeVariable: @@ -428,45 +331,6 @@ def constexpr(value: bool) -> bool: return value -def tile_strides( - shape: tuple[int, ...], - config: TileConfig | None = None, -) -> tuple[int, ...]: - if not shape: - return () - normalized = TileConfig() if config is None else config - if normalized.b_layout == BLayout.COL_MAJOR and len(shape) == 2: - return (1, shape[0]) - strides = [1] - for dim in reversed(shape[1:]): - strides.insert(0, strides[0] * dim) - return tuple(strides) - - -def tile_byte_strides( - shape: tuple[int, ...], - dtype: ScalarType, - config: TileConfig | None = None, -) -> tuple[int, ...]: - element_bytes = bytewidth(dtype) - return tuple(stride * element_bytes for stride in tile_strides(shape, config)) - - -def tile_layout_descriptor( - shape: tuple[int, ...], - dtype: ScalarType, - config: TileConfig | None = None, - *, - offset: int = 0, -) -> TileLayoutDescriptor: - return TileLayoutDescriptor( - shape=shape, - strides=tile_strides(shape, config), - byte_strides=tile_byte_strides(shape, dtype, config), - offset=offset, - ) - - __all__ = [ "ScalarType", "WildcardType", @@ -478,13 +342,9 @@ def tile_layout_descriptor( "PointerType", "VRegType", "MaskType", - "AlignType", "ptr", "vreg", "MemorySpace", - "BLayout", - "SLayout", - "PadValue", "Pipe", "Event", "PIPE", @@ -497,11 +357,7 @@ def tile_layout_descriptor( "InterleaveDist", "PositionMode", "OrderMode", - "DeinterleaveDist", - "InterleaveDist", - "PredicateDist", - "PredicatePart", - "StrideMode", + "PostUpdateMode", "TileConfig", "TileSpecialization", "i1", @@ -524,7 +380,6 @@ def tile_layout_descriptor( "AnyInt", "AnyType", "AnyMask", - "align", "mask_b8", "mask_b16", "mask_b32", diff --git a/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md b/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md index 38bea7675..d610f283b 100644 --- a/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md +++ b/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md @@ -21,11 +21,12 @@ license: MIT 1. 读取最新 VPTO 规范 `vpto-latest.md`(如果最新版本来自网络,则先下载保存)。 2. 读取当前 DSL 使用的规范 `vpto-current.md`。 -3. 对比两者差异并生成差异报告: +3. 对比两者差异: - 若无差异:输出“无需更新”,结束。 - - 若有差异:按变更类型分类并执行对应动作(执行动作前向用户询问确认)。 + - 若有差异:生成差异报告。 +4. 根据差异报告,逐项与用户确认每个差异变更的处理方式(新增/修改/删除),按照分类进行处理 -4. 差异分类与处理规则: +5. 差异分类与处理规则: ### A. 新增 op @@ -51,11 +52,11 @@ license: MIT - 在 DSL 实现中将该 op 标记为不受支持,并在用户使用时显式报错。 - 增加测试验证报错信息清晰可见。 -5. 统一补充测试: +6. 统一补充测试: - 至少覆盖:新增/变更/删除的 golden path。 - 包含失败路径(非法参数、已删除 op 调用)验证。 -6. 将vpto-spec-current.md改名为vpto-spec-*.md(如vpto-spec-2024-06.md),并将vpto-latest.md改名为vpto-spec-current.md,保持版本迭代记录。 +7. 将vpto-spec-current.md改名为vpto-spec-*.md(如vpto-spec-2024-06.md),并将vpto-latest.md改名为vpto-spec-current.md,保持版本迭代记录。 --- diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 8a9de4fb1..11998acaa 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -29,6 +29,7 @@ ) from tilelang_dsl.lowering import AuthoringModule, lower_semantic_kernel from tilelang_dsl.semantic import ( + SemanticAlignStoreStmt, SemanticAlignType, SemanticAssignStmt, SemanticBinaryExpr, @@ -44,6 +45,7 @@ SemanticMaskType, SemanticPipeBarrierStmt, SemanticPtrType, + SemanticPredicateStoreStmt, SemanticRlsBufStmt, SemanticScalarStoreStmt, SemanticScalarType, @@ -85,9 +87,6 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "mask_b8")) self.assertTrue(hasattr(pto, "mask_b16")) self.assertTrue(hasattr(pto, "mask_b32")) - self.assertTrue(hasattr(pto, "BLayout")) - self.assertTrue(hasattr(pto, "SLayout")) - self.assertTrue(hasattr(pto, "PadValue")) self.assertTrue(hasattr(pto, "constexpr")) self.assertTrue(hasattr(pto, "bytewidth")) self.assertTrue(hasattr(pto, "get_lanes")) @@ -99,14 +98,9 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "InterleaveDist")) self.assertTrue(hasattr(pto, "PositionMode")) self.assertTrue(hasattr(pto, "OrderMode")) - self.assertTrue(hasattr(pto, "DeinterleaveDist")) - self.assertTrue(hasattr(pto, "InterleaveDist")) - self.assertTrue(hasattr(pto, "PredicateDist")) - self.assertTrue(hasattr(pto, "PredicatePart")) - self.assertTrue(hasattr(pto, "StrideMode")) + self.assertTrue(hasattr(pto, "PostUpdateMode")) self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) - self.assertEqual(repr(pto.align), "align") self.assertTrue(hasattr(pto, "si8")) self.assertTrue(hasattr(pto, "ui8")) self.assertTrue(hasattr(pto, "si16")) @@ -119,22 +113,14 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PadMode.PadNull.value, "PadNull") self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") self.assertEqual(pto.PadMode.PadValue.value, "PadValue") - self.assertEqual(pto.BLayout.ROW_MAJOR.value, "row_major") - self.assertEqual(pto.SLayout.NONE_BOX.value, "none_box") - self.assertEqual(pto.PadValue.ZERO.value, "zero") - self.assertEqual(pto.PositionMode.LOWEST.value, "LOWEST") self.assertEqual(pto.DeinterleaveDist.DINTLV.value, "DINTLV") self.assertEqual(pto.DeinterleaveDist.BDINTLV.value, "BDINTLV") self.assertEqual(pto.InterleaveDist.INTLV.value, "INTLV") + self.assertEqual(pto.PositionMode.LOWEST.value, "LOWEST") self.assertEqual(pto.PositionMode.HIGHEST.value, "HIGHEST") self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") - self.assertEqual(pto.DeinterleaveDist.B32.value, "DINTLV_B32") - self.assertEqual(pto.InterleaveDist.B16.value, "INTLV_B16") - self.assertEqual(pto.PredicateDist.NORM.value, "NORM") - self.assertEqual(pto.PredicateDist.PK.value, "PK") - self.assertEqual(pto.PredicatePart.LOWER.value, "LOWER") - self.assertEqual(pto.PredicatePart.HIGHER.value, "HIGHER") - self.assertEqual(pto.StrideMode.S4_B64.value, "STRIDE_S4_B64") + self.assertEqual(pto.PostUpdateMode.POST_UPDATE.value, "POST_UPDATE") + self.assertEqual(pto.PostUpdateMode.NO_POST_UPDATE.value, "NO_POST_UPDATE") self.assertEqual(pto.Event.ID31.value, "EVENT_ID31") self.assertIs(pto.DeinterleaveDist.B32, pto.DeinterleaveDist.DINTLV) self.assertIs(pto.InterleaveDist.B32, pto.InterleaveDist.INTLV) @@ -148,6 +134,7 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.bytewidth(pto.ui64), 8) self.assertEqual(pto.get_lanes(pto.ui32), 64) self.assertEqual(pto.elements_per_vreg(pto.si8), 256) + self.assertEqual(repr(pto.align), "align") class TileLangDSLSupportMatrixTests(unittest.TestCase): @@ -161,33 +148,16 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertIn("Tile", AUTHORING_TIER_SURFACE_GROUPS["Tile"]) self.assertNotIn("dma_load/store", AUTHORING_TIER_SURFACE_GROUPS) self.assertIn("pto.vlds", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) - self.assertIn("pto.vldas", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) - self.assertIn("pto.vldus", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) - self.assertIn("pto.vldx2", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) - self.assertIn("pto.vsld", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("pto.vsts", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) - self.assertIn("pto.psts", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) - self.assertIn("pto.vsst", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) - self.assertIn("pto.vstx2", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) - self.assertIn("pto.vsta", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("pto.vadd", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("pto.vmuls", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("tile[start:]", BASIC_TILE_INDEXING_SURFACES) self.assertIn("tile[row, col:]", BASIC_TILE_INDEXING_SURFACES) - self.assertIn("tile[row_start:, col_index]", BASIC_TILE_INDEXING_SURFACES) self.assertEqual(get_feature_tier("TensorView"), BASIC_TIER) self.assertEqual(get_feature_tier("Tile"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vlds"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.vldas"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.vldus"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.vldx2"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.vsld"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vsts"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.psts"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.vsst"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.vstx2"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.vsta"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vadd"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vmuls"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.get_buf"), BASIC_TIER) @@ -210,16 +180,6 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vci"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vpack"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vsort32"), BASIC_TIER) - self.assertEqual(get_feature_tier("pto.pset_b32"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("pto.plds"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("pto.pand"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("pto.pintlv_b16"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("pto.pstu"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("pto.vstu"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("pto.vstus"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("pto.vstur"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("PredicateDist"), ADVANCED_TIER) - self.assertEqual(get_feature_tier("PredicatePart"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.vldsx2"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vstsx2"), BASIC_TIER) self.assertEqual(get_feature_tier("PadMode"), BASIC_TIER) @@ -234,12 +194,8 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.elements_per_vreg"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.constexpr"), BASIC_TIER) self.assertEqual(get_feature_tier("constexpr"), BASIC_TIER) - self.assertEqual(get_feature_tier("BLayout"), BASIC_TIER) - self.assertEqual(get_feature_tier("SLayout"), BASIC_TIER) - self.assertEqual(get_feature_tier("PadValue"), BASIC_TIER) self.assertEqual(get_feature_tier("tile[start:]"), BASIC_TIER) self.assertEqual(get_feature_tier("tile[row, col:]"), BASIC_TIER) - self.assertEqual(get_feature_tier("tile[row_start:, col_index]"), BASIC_TIER) def test_non_stable_surface_groups_keep_advanced_boundaries(self) -> None: self.assertEqual(get_surface_group_tier("strict_vecscope"), ADVANCED_TIER) @@ -1192,7 +1148,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): pto.vlds(tile, offset=0) return None - self.assertIn("keyword arguments are only supported on selected public call surfaces", str(ctx.exception)) + self.assertIn("no public call surface currently accepts them", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) def test_frontend_rewrites_template_slot_to_selected_real_op(self) -> None: @@ -1740,251 +1696,6 @@ def kernel(tile: pto.Tile, scale: pto.f32): self.assertRegex(text, r"%activated_\d+ = pto\.vrelu %summed_\d+, %mask_\d+ : !pto\.vreg<64xf32>, !pto\.mask -> !pto\.vreg<64xf32>") self.assertRegex(text, r"pto\.vsts %activated_\d+, %dst_\d+\[%lane_\d+\], %mask_\d+ : !pto\.vreg<64xf32>, !pto\.ptr, !pto\.mask") - def test_basic_vector_memory_family_surfaces_lower_from_tile_indexing(self) -> None: - @pto.vkernel(op="vector_memory_basic_unique", dtypes=[(pto.f32, pto.f32)]) - def kernel(src: pto.Tile, dst: pto.Tile): - align = pto.vldas(src[0, 0:]) - vec, next_align, base_out = pto.vldus(src[0, 0:], align) - low, high = pto.vldx2(src[0, 0:], pto.DeinterleaveDist.B32) - strided = pto.vsld(src[0, 0], pto.StrideMode.S4_B64) - mask = pto.make_mask(pto.f32, pto.PAT.ALL) - pto.psts(mask, dst[0, 0:]) - pto.vsst(vec, dst[0, 0:], pto.StrideMode.S4_B64) - pto.vstx2(low, high, dst[0, 0:], pto.InterleaveDist.B32, mask) - pto.vsta(next_align, dst[0, 0:]) - return None - - specialized = kernel.specialize( - src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - vecscope = semantic_kernel.body[0] - self.assertIsInstance(vecscope, SemanticVecscopeStmt) - self.assertIsInstance(vecscope.body[0], SemanticAssignStmt) - self.assertIsInstance(vecscope.body[0].targets[0].type, SemanticAlignType) - self.assertEqual(vecscope.body[1].targets[0].type.element_dtype, pto.f32) - self.assertEqual(vecscope.body[2].value.name, "vldx2") - - text = specialized.mlir_text() - self.assertIn("pto.vldas", text) - self.assertIn("pto.vldus", text) - self.assertIn('pto.vldx2', text) - self.assertIn('"DINTLV_B32"', text) - self.assertIn("pto.vsld", text) - self.assertIn('"STRIDE_S4_B64"', text) - self.assertIn("pto.psts", text) - self.assertIn("pto.vsst", text) - self.assertIn("pto.vstx2", text) - self.assertIn('"INTLV_B32"', text) - self.assertIn("pto.vsta", text) - - def test_col_major_tile_vector_indexing_lowers_with_column_major_layout(self) -> None: - @pto.vkernel(op="vector_memory_col_major_unique", dtypes=[(pto.f32, pto.f32)]) - def kernel(src: pto.Tile, dst: pto.Tile): - align = pto.vldas(src[2:, 3]) - streamed, next_align, base_out = pto.vldus(src[2:, 3], align) - mask = pto.make_mask(pto.f32, pto.PAT.ALL) - vec = pto.vlds(src[2:, 3]) - pto.vsts(vec, dst[2:, 3], mask) - pto.vsta(next_align, dst[2:, 3]) - return None - - col_major = pto.TileConfig.from_mapping({"layout": "col_major"}) - specialized = kernel.specialize( - src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB, config=col_major), - dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB, config=col_major), - ) - - text = specialized.mlir_text() - self.assertIn("blayout=col_major", text) - self.assertRegex(text, r"pto\.vlds %tmp_\d+\[%c2, %c3\] : memref<8x64xf32, #pto\.address_space> -> !pto\.vreg<64xf32>") - self.assertRegex(text, r"%tmp_\d+ = arith\.muli %c3, %c8 : index") - self.assertIn("pto.addptr", text) - self.assertIn("pto.vldus", text) - self.assertIn("pto.vsta", text) - - def test_row_major_tile_rejects_column_major_vector_indexing_syntax(self) -> None: - @pto.vkernel(op="row_major_rejects_col_major_index_unique", dtypes=[(pto.f32,)]) - def kernel(src: pto.Tile): - pto.vlds(src[1:, 2]) - return None - - specialized = kernel.specialize( - src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - with self.assertRaises(TypeError) as ctx: - analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - self.assertIn("Tile[row, col:]", str(ctx.exception)) - - def test_col_major_tile_rejects_row_major_vector_indexing_syntax(self) -> None: - @pto.vkernel(op="col_major_rejects_row_major_index_unique", dtypes=[(pto.f32,)]) - def kernel(src: pto.Tile): - pto.vlds(src[1, 2:]) - return None - - specialized = kernel.specialize( - src=pto.TileSpecialization( - shape=(8, 64), - memory_space=pto.MemorySpace.UB, - config=pto.TileConfig.from_mapping({"layout": "col_major"}), - ), - ) - - with self.assertRaises(TypeError) as ctx: - analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - self.assertIn("Tile[row_start:, col_index]", str(ctx.exception)) - - def test_advanced_stateful_vector_memory_surfaces_lower_with_pointer_state(self) -> None: - @pto.vkernel(op="vector_memory_stateful_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) - def kernel(src: pto.Tile, dst: pto.Tile): - ub_src = src.as_ptr() - ub_dst = dst.as_ptr() - align0 = pto.vldas(ub_src) - vec = pto.vlds(ub_src, 0) - mask = pto.make_mask(pto.f32, pto.PAT.ALL) - align1, base1 = pto.pstu(align0, mask, ub_dst) - align2, offset2 = pto.vstu(align1, 0, vec, ub_dst, "MODE_ZEROING") - align3, base3 = pto.vstus(align2, pto.i32(16), vec, ub_dst, "MODE_ZEROING") - align4 = pto.vstur(align3, vec, ub_dst, "MODE_ZEROING") - return None - - specialized = kernel.specialize( - src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - self.assertIsInstance(semantic_kernel.body[0], SemanticAssignStmt) - self.assertEqual(semantic_kernel.body[0].value.name, "tile_as_ptr") - - text = specialized.mlir_text() - self.assertIn("pto.pstu", text) - self.assertIn("pto.vstu", text) - self.assertIn("pto.vstus", text) - self.assertIn("pto.vstur", text) - self.assertIn('"MODE_ZEROING"', text) - self.assertRegex(text, r"= pto\.vstu %align1_\d+, %c0, %vec_\d+, %ub_dst_\d+, \"MODE_ZEROING\"") - self.assertRegex(text, r"= pto\.vstus %align2_\d+, %(?:c16_i32|tmp_\d+), %vec_\d+, %ub_dst_\d+, \"MODE_ZEROING\"") - - def test_advanced_direct_predicate_surfaces_lower_with_typed_families(self) -> None: - @pto.vkernel(op="predicate_surface_unique", dtypes=[(pto.f32, pto.i32)], advanced=True) - def kernel(tile: pto.Tile, remaining: pto.i32): - all8 = pto.pset_b8(pto.PAT.ALL) - all16 = pto.pset_b16(pto.PAT.ALL) - all32 = pto.pset_b32(pto.PAT.ALL) - tail8 = pto.pge_b8(pto.PAT.ALL) - tail16 = pto.pge_b16(pto.PAT.ALL) - tail32 = pto.pge_b32(pto.PAT.ALL) - mask8, rem8 = pto.plt_b8(remaining) - mask16, rem16 = pto.plt_b16(remaining) - mask32, rem32 = pto.plt_b32(remaining) - gate32 = all32 - and_mask = pto.pand(all32, tail32, gate32) - or_mask = pto.por(and_mask, all32, gate32) - xor_mask = pto.pxor(or_mask, tail32, gate32) - low8, high8 = pto.pdintlv_b8(all8, tail8) - low16, high16 = pto.pintlv_b16(all16, tail16) - return None - - specialized = kernel.specialize( - tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - text = specialized.mlir_text() - self.assertIn(' = pto.pset_b8 "PAT_ALL"', text) - self.assertIn(' = pto.pset_b16 "PAT_ALL"', text) - self.assertIn(' = pto.pset_b32 "PAT_ALL"', text) - self.assertIn(' = pto.pge_b8 "PAT_ALL"', text) - self.assertIn(' = pto.pge_b16 "PAT_ALL"', text) - self.assertIn(' = pto.pge_b32 "PAT_ALL"', text) - self.assertIn(" = pto.plt_b8 ", text) - self.assertIn(" = pto.plt_b16 ", text) - self.assertIn(" = pto.plt_b32 ", text) - self.assertIn(" = pto.pand ", text) - self.assertIn(" = pto.por ", text) - self.assertIn(" = pto.pxor ", text) - self.assertIn(" = pto.pdintlv_b8 ", text) - self.assertIn(" = pto.pintlv_b16 ", text) - self.assertEqual(text.count('pto.pset_b32 "PAT_ALL"'), 1) - - def test_advanced_predicate_memory_surfaces_lower_with_dist_enums(self) -> None: - @pto.vkernel(op="predicate_memory_surface_unique", dtypes=[(pto.f32,)], advanced=True) - def kernel(tile: pto.Tile): - ub = tile.as_ptr() - align = pto.vldas(ub) - all_mask = pto.pset_b32(pto.PAT.ALL) - loaded0 = pto.plds(ub, 0) - loaded1 = pto.pld(ub, 0, pto.PredicateDist.NORM) - loaded2 = pto.pldi(ub, pto.i32(4), pto.PredicateDist.US) - pto.psts(all_mask, ub, 0) - pto.pst(loaded0, ub, 0) - pto.psti(loaded1, ub, pto.i32(8), pto.PredicateDist.PK) - next_align, next_base = pto.pstu(align, all_mask, ub) - return None - - specialized = kernel.specialize( - tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - text = specialized.mlir_text() - self.assertIn('{dist = "NORM"}', text) - self.assertIn(" = pto.pld ", text) - self.assertIn(', "NORM" : !pto.ptr, index -> !pto.mask', text) - self.assertIn(" = pto.pldi ", text) - self.assertIn(', "US" : !pto.ptr, i32 -> !pto.mask', text) - self.assertIn("pto.psts ", text) - self.assertIn(': !pto.mask, !pto.ptr', text) - self.assertIn("pto.pst ", text) - self.assertIn(', "NORM" : !pto.mask, !pto.ptr, index', text) - self.assertIn("pto.psti ", text) - self.assertIn(', "PK" : !pto.mask, !pto.ptr, i32', text) - self.assertIn("pto.pstu", text) - - def test_pld_and_pldi_require_explicit_predicate_dist(self) -> None: - @pto.vkernel(op="predicate_memory_missing_dist", dtypes=[(pto.f32,)], advanced=True) - def kernel(tile: pto.Tile): - ub = tile.as_ptr() - loaded1 = pto.pld(ub, 0) - loaded2 = pto.pldi(ub, pto.i32(4)) - pto.pst(loaded1, ub, 0) - pto.pst(loaded2, ub, 0) - return None - - specialized = kernel.specialize( - tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - with self.assertRaises(TypeError) as ctx: - specialized.mlir_text() - - message = str(ctx.exception) - self.assertTrue( - "pto.pld expects source, offset, and dist" in message - or "pto.pldi expects source, imm_offset, and dist" in message - ) - - def test_pand_por_pxor_require_explicit_gate_mask(self) -> None: - @pto.vkernel(op="predicate_logic_missing_gate", dtypes=[(pto.f32, pto.i32)], advanced=True) - def kernel(tile: pto.Tile, remaining: pto.i32): - all32 = pto.pset_b32(pto.PAT.ALL) - tail32 = pto.pge_b32(pto.PAT.ALL) - and_mask = pto.pand(all32, tail32) - or_mask = pto.por(and_mask, all32) - xor_mask = pto.pxor(or_mask, tail32) - return None - - specialized = kernel.specialize( - tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - with self.assertRaises(TypeError) as ctx: - specialized.mlir_text() - - self.assertIn("expects exactly 3 positional arguments", str(ctx.exception)) - def test_tail_make_mask_lowers_to_typed_plt_and_updates_remaining(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.i32)], advanced=True) def kernel(tile: pto.Tile, remaining: pto.i32): @@ -2460,7 +2171,7 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): dup_from_vec = pto.vdup(vec0, all_mask, pto.PositionMode.HIGHEST) dup_from_scalar = pto.vdup(seed, all_mask) idx0 = pto.vci(seed) - idx1 = pto.vci(seed, order=pto.OrderMode.ASC) + idx1 = pto.vci(seed, pto.OrderMode.ASC) out = pto.vadd(broadcast, dup_from_vec, all_mask) out = pto.vadd(out, dup_from_scalar, all_mask) @@ -3355,61 +3066,6 @@ def kernel(src: pto.Tile, dst: pto.TensorView): r"pto\.copy_ubuf_to_gm %ub_ptr_\d+, %gm_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", ) - def test_copy_ubuf_to_ubuf_guide_surface_lowers_in_advanced_mode(self) -> None: - @pto.vkernel( - op="ub_to_ub_dma_guide_unique", - dtypes=[(pto.f32, pto.f32, pto.i64, pto.i64)], - advanced=True, - ) - def kernel(src: pto.Tile, dst: pto.Tile, src_offset: pto.i64, dst_offset: pto.i64): - src_ptr = src.as_ptr() - dst_ptr = dst.as_ptr() - - pto.copy_ubuf_to_ubuf(src_ptr, dst_ptr, src_offset, 32, 128, dst_offset, 160, 192) - pto.copy_ubuf_to_ubuf( - src=src_ptr, - dst=dst_ptr, - src_offset=src_offset, - src_stride0=16, - src_stride1=64, - dst_offset=dst_offset, - dst_stride0=96, - dst_stride1=128, - ) - return None - - specialized = kernel.specialize( - src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - low_level_copies = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticLowLevelCopyStmt)] - self.assertEqual(len(low_level_copies), 2) - self.assertTrue(all(isinstance(stmt.source, SemanticCallExpr) for stmt in low_level_copies)) - self.assertTrue(all(isinstance(stmt.destination, SemanticCallExpr) for stmt in low_level_copies)) - self.assertTrue(all(stmt.source.name == "addptr" for stmt in low_level_copies)) - self.assertTrue(all(stmt.destination.name == "addptr" for stmt in low_level_copies)) - - text = specialized.mlir_text() - self.assertRegex( - text, - r"%src_ptr_\d+ = pto\.tile_buf_addr %arg0 : !pto\.tile_buf -> !pto\.ptr", - ) - self.assertRegex( - text, - r"%dst_ptr_\d+ = pto\.tile_buf_addr %arg1 : !pto\.tile_buf -> !pto\.ptr", - ) - self.assertRegex( - text, - r"%tmp_\d+ = pto\.addptr %src_ptr_\d+, %arg2 : !pto\.ptr -> !pto\.ptr", - ) - self.assertRegex( - text, - r"%tmp_\d+ = pto\.addptr %dst_ptr_\d+, %arg3 : !pto\.ptr -> !pto\.ptr", - ) - self.assertEqual(text.count("pto.copy_ubuf_to_ubuf "), 2) - def test_castptr_rejects_tensorview_or_tile_inputs_in_advanced_mode(self) -> None: @pto.vkernel(op="castptr_tensorview_reject_unique", dtypes=[(pto.f32,)], advanced=True) def tensorview_kernel(inp: pto.TensorView): @@ -3585,8 +3241,8 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): cmp_scalar_mask = pto.vcmps(lhs, scalar, all_mask, "gt") negated = pto.pnot(cmp_mask, all_mask) picked = pto.psel(cmp_mask, negated, cmp_scalar_mask) - packed = pto.ppack(picked, pto.PredicatePart.LOWER) - unpacked = pto.punpack(packed, pto.PredicatePart.HIGHER) + packed = pto.ppack(picked, "PART_EVEN") + unpacked = pto.punpack(packed, "PART_ODD") sum_vec, carry_mask = pto.vaddc(lhs, rhs, all_mask) diff_vec, borrow_mask = pto.vsubc(lhs, rhs, all_mask) sum_with_carry, carry_mask2 = pto.vaddcs(sum_vec, diff_vec, carry_mask, all_mask) @@ -3620,9 +3276,9 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): self.assertIn(" = pto.pnot ", text) self.assertIn(" = pto.psel ", text) self.assertIn(' = pto.ppack ', text) - self.assertIn('"LOWER"', text) + self.assertIn('"PART_EVEN"', text) self.assertIn(' = pto.punpack ', text) - self.assertIn('"HIGHER"', text) + self.assertIn('"PART_ODD"', text) self.assertRegex( text, r"%sum_vec_\d+, %carry_mask_\d+ = pto\.vaddc %lhs_\d+, %rhs_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", @@ -3654,29 +3310,6 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): self.assertIn(" = pto.vselrv2 ", text) self.assertIn("pto.vsts ", text) - def test_ppack_and_punpack_require_predicate_part_enum(self) -> None: - @pto.vkernel(op="predicate_part_typecheck", dtypes=[(pto.i32, pto.i32, pto.i32)], advanced=True) - def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): - all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) - lhs = pto.vlds(src0[0, 0:]) - rhs = pto.vlds(src1[0, 0:]) - cmp_mask = pto.vcmp(lhs, rhs, all_mask, "lt") - packed = pto.ppack(cmp_mask, "LOWER") - unpacked = pto.punpack(packed, "HIGHER") - pto.vsts(pto.vsel(lhs, rhs, unpacked), dst[0, 0:], all_mask) - return None - - specialized = kernel.specialize( - dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ) - - with self.assertRaises(TypeError) as ctx: - specialized.mlir_text() - - self.assertIn("PredicatePart.LOWER or PredicatePart.HIGHER", str(ctx.exception)) - def test_elementwise_kernel_positive_regression_covers_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): @@ -3784,15 +3417,6 @@ def kernel(inp: pto.TensorView, tile: pto.Tile, flag: pto.i32): self.assertIn("scf.for %lane_", text) self.assertIn("pto.barrier #pto.pipe", text) - def test_sync_ops_accept_event_class_alias_and_full_event_range(self) -> None: - Event = pto.Event - Pipe = pto.Pipe - - @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) - def kernel(inp: pto.TensorView, tile: pto.Tile): - pto.set_flag(Pipe.MTE2, Pipe.V, Event.ID31) - pto.wait_flag(Pipe.MTE2, Pipe.V, Event.ID31) - def test_extended_sync_buffer_ops_lower_to_authoring_surface(self) -> None: Pipe = pto.Pipe Event = pto.Event @@ -3928,6 +3552,82 @@ def kernel(src: pto.Tile, dst: pto.Tile): self.assertIn('"DINTLV"', text) self.assertIn('"INTLV"', text) + def test_align_load_and_stateful_store_ops_lower_to_current_vpto_surface(self) -> None: + @pto.vkernel( + op="align_load_and_stateful_store_ops", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel( + src: pto.ptr(pto.f32, pto.MemorySpace.UB), + dst: pto.ptr(pto.f32, pto.MemorySpace.UB), + ): + load_align = pto.vldas(src) + vec, load_align = pto.vldus(src, load_align) + store_align = pto.init_align() + store_align = pto.vstus(store_align, 0, vec, dst) + store_align = pto.vstur(store_align, vec, dst) + pto.vstas(store_align, dst, 0) + post_align = pto.vstur(pto.init_align(), vec, dst, pto.PostUpdateMode.POST_UPDATE) + pto.vstar(post_align, dst) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)) + align_store_stmts = [stmt for stmt in vecscope.body if isinstance(stmt, SemanticAlignStoreStmt)] + + self.assertTrue(any(isinstance(stmt, SemanticAssignStmt) and isinstance(stmt.value.type, SemanticAlignType) for stmt in vecscope.body)) + self.assertEqual(len(align_store_stmts), 2) + self.assertEqual([stmt.op_name for stmt in align_store_stmts], ["vstas", "vstar"]) + + text = specialized.mlir_text() + self.assertIn("pto.vldas", text) + self.assertIn("pto.vldus", text) + self.assertIn("pto.init_align", text) + self.assertIn("pto.vstus", text) + self.assertIn("pto.vstur", text) + self.assertIn("pto.vstas", text) + self.assertIn("pto.vstar", text) + self.assertIn('"POST_UPDATE"', text) + self.assertIn('"NO_POST_UPDATE"', text) + self.assertIn("!pto.align", text) + + def test_predicate_store_and_compatibility_store_sugar_lower_to_supported_ops(self) -> None: + @pto.vkernel( + op="predicate_store_and_store_sugar", + dtypes=[(pto.f32, pto.ui32)], + advanced=True, + ) + def kernel( + dst: pto.ptr(pto.f32, pto.MemorySpace.UB), + mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB), + ): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.psts(mask, mask_dst, 0) + align = pto.init_align() + align, mask_base = pto.pstu(align, mask, mask_dst) + pto.vsta(align, mask_base, 0) + pto.vsst(1.0, dst, 0, mask) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)) + + self.assertTrue(any(isinstance(stmt, SemanticPredicateStoreStmt) for stmt in vecscope.body)) + self.assertTrue(any(isinstance(stmt, SemanticAlignStoreStmt) and stmt.op_name == "vstas" for stmt in vecscope.body)) + + text = specialized.mlir_text() + self.assertIn("pto.psts", text) + self.assertIn('"NORM"', text) + self.assertIn("pto.pstu", text) + self.assertIn("pto.vbr", text) + self.assertIn("pto.vsts", text) + self.assertIn("pto.vstas", text) + self.assertNotIn("pto.vsst", text) + self.assertNotIn("pto.vsta ", text) + def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): @@ -4324,6 +4024,20 @@ def kernel(x: pto.TensorView): self.assertIn("arbitrary external call `helper`", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) + def test_vstur_rejects_raw_string_mode_and_requires_enum(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel(op="vstur_raw_string_mode_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.ptr(pto.f32, pto.MemorySpace.UB)): + align = pto.init_align() + vec = pto.vbr(1.0) + pto.vstur(align, vec, dst, "POST_UPDATE") + return None + + kernel.specialize().mlir_text() + + self.assertIn("pto.vstur mode must be a PostUpdateMode enum", str(ctx.exception)) + def test_unsupported_pto_surface_reports_source_location(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as ctx: @@ -4335,19 +4049,6 @@ def kernel(x: pto.TensorView): self.assertIn("unsupported op surface `pto.not_a_real_surface`", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) - def test_removed_tile_derived_query_surface_is_rejected(self) -> None: - @pto.vkernel(op="removed_tile_query_surface_unique", dtypes=[(pto.f32,)], advanced=True) - def kernel(dst: pto.Tile): - layout = dst.layout_descriptor - return None - - with self.assertRaises(TypeError) as ctx: - kernel.specialize( - dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), - ).mlir_text() - - self.assertIn("unsupported attribute access 'layout_descriptor'", str(ctx.exception)) - def test_strict_vecscope_requires_advanced_mode(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as ctx: From 5364861015f1a0f6d4b5ae50b89b5e9d835b2ceb Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 13 Apr 2026 16:07:19 +0800 Subject: [PATCH 041/192] Fix failed lit case --- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 11998acaa..513224ec8 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -1148,7 +1148,10 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): pto.vlds(tile, offset=0) return None - self.assertIn("no public call surface currently accepts them", str(ctx.exception)) + self.assertIn( + "`pto.vlds` does not support keyword arguments in TileLang DSL v1", + str(ctx.exception), + ) self.assertIn(f"{__file__}:", str(ctx.exception)) def test_frontend_rewrites_template_slot_to_selected_real_op(self) -> None: From e4642102ce93c3e54881850840f1abfdf2d2c232 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 13 Apr 2026 16:49:40 +0800 Subject: [PATCH 042/192] Dump source location for DSL frontend errors --- .../python/tilelang_dsl/frontend_ast.py | 247 +++++++++++++----- tilelang-dsl/python/tilelang_dsl/semantic.py | 135 ++++++++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 39 +++ 3 files changed, 322 insertions(+), 99 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 30b4dd13d..87ed9f43d 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -36,6 +36,13 @@ class FrontendTileSpecializationNode: valid_shape: tuple[int | None, ...] | None +@dataclass(frozen=True) +class FrontendSourceLocation: + path: str + line: int + column: int + + class FrontendExprNode: """Base class for lowered frontend expressions.""" @@ -232,6 +239,26 @@ def enter_inline_proc(self, name: str, source_info: Any) -> "_FrontendBuildConte ) +def _attach_source_location( + frontend_node: FrontendExprNode | FrontendStmtNode, + ast_node: ast.AST, + context: _FrontendBuildContext, +) -> FrontendExprNode | FrontendStmtNode: + if context.source_info is None: + return frontend_node + line, column = context.source_info.location(ast_node) + object.__setattr__( + frontend_node, + "source_location", + FrontendSourceLocation( + path=context.source_info.path, + line=line, + column=column, + ), + ) + return frontend_node + + def _inline_proc_param_specs(inline_proc: _FrontendInlineProc) -> tuple[tuple[str, ast.expr | None], ...]: function_def = inline_proc.source_info.function_def params = function_def.args.args @@ -740,9 +767,9 @@ def _build_call_keywords( def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNode: if isinstance(node, ast.Name): - return FrontendNameExpr(name=node.id) + return _attach_source_location(FrontendNameExpr(name=node.id), node, context) if isinstance(node, ast.Constant): - return FrontendConstantExpr(value=node.value) + return _attach_source_location(FrontendConstantExpr(value=node.value), node, context) if isinstance(node, ast.UnaryOp): if isinstance(node.op, ast.UAdd): sign = 1 @@ -764,15 +791,27 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo node, "unary +/- currently only supports numeric literals in TileLang DSL v1", ) - return FrontendConstantExpr(value=literal if sign > 0 else -literal) + return _attach_source_location( + FrontendConstantExpr(value=literal if sign > 0 else -literal), + node, + context, + ) if isinstance(node, ast.Slice): start = None if node.lower is None else _build_expr(node.lower, context) stop = None if node.upper is None else _build_expr(node.upper, context) step = None if node.step is None else _build_expr(node.step, context) - return FrontendSliceExpr(start=start, stop=stop, step=step) + return _attach_source_location( + FrontendSliceExpr(start=start, stop=stop, step=step), + node, + context, + ) if isinstance(node, ast.Tuple): - return FrontendTupleExpr( - elements=tuple(_build_expr(elt, context) for elt in node.elts) + return _attach_source_location( + FrontendTupleExpr( + elements=tuple(_build_expr(elt, context) for elt in node.elts) + ), + node, + context, ) if isinstance(node, ast.Attribute): path = _attribute_path(node) @@ -793,12 +832,24 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo "OrderMode", "PostUpdateMode", } and len(path) >= 2: - return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) - return FrontendAttributeExpr(base=_build_expr(node.value, context), attr=node.attr) + return _attach_source_location( + FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]), + node, + context, + ) + return _attach_source_location( + FrontendAttributeExpr(base=_build_expr(node.value, context), attr=node.attr), + node, + context, + ) if isinstance(node, ast.Subscript): - return FrontendSubscriptExpr( - base=_build_expr(node.value, context), - index=_build_expr(node.slice, context), + return _attach_source_location( + FrontendSubscriptExpr( + base=_build_expr(node.value, context), + index=_build_expr(node.slice, context), + ), + node, + context, ) if isinstance(node, ast.BinOp): op_name = _BINARY_OP_NAMES.get(type(node.op)) @@ -807,10 +858,14 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo node, f"unsupported binary operator `{type(node.op).__name__}` in TileLang DSL v1", ) - return FrontendBinaryExpr( - lhs=_build_expr(node.left, context), - op=op_name, - rhs=_build_expr(node.right, context), + return _attach_source_location( + FrontendBinaryExpr( + lhs=_build_expr(node.left, context), + op=op_name, + rhs=_build_expr(node.right, context), + ), + node, + context, ) if isinstance(node, ast.Compare): if len(node.ops) != 1 or len(node.comparators) != 1: @@ -824,10 +879,14 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo node, f"unsupported comparison operator `{type(node.ops[0]).__name__}` in TileLang DSL v1", ) - return FrontendBinaryExpr( - lhs=_build_expr(node.left, context), - op=op_name, - rhs=_build_expr(node.comparators[0], context), + return _attach_source_location( + FrontendBinaryExpr( + lhs=_build_expr(node.left, context), + op=op_name, + rhs=_build_expr(node.comparators[0], context), + ), + node, + context, ) if isinstance(node, ast.BoolOp): op_name = _BOOL_OP_NAMES.get(type(node.op)) @@ -848,7 +907,7 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo op=op_name, rhs=_build_expr(value, context), ) - return expr + return _attach_source_location(expr, node, context) if isinstance(node, ast.Call): if isinstance(node.func, ast.Name) and node.func.id in context.inline_procs: inline_proc = context.inline_procs[node.func.id] @@ -857,11 +916,15 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo node, f"recursive inline_proc call `{node.func.id}` is not supported in TileLang DSL v1", ) - return FrontendCallExpr( - namespace=None, - name=node.func.id, - args=_bind_inline_proc_call(node, inline_proc, context), - keywords=(), + return _attach_source_location( + FrontendCallExpr( + namespace=None, + name=node.func.id, + args=_bind_inline_proc_call(node, inline_proc, context), + keywords=(), + ), + node, + context, ) if ( isinstance(node.func, ast.Attribute) @@ -904,40 +967,52 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo f"selected op {context.selected_op!r}", ) _validate_resolved_template_op_surface(resolved_op, node, context) - return FrontendCallExpr( - namespace="pto", - name=resolved_op, - args=tuple(_build_expr(arg, context) for arg in node.args[1:]), - keywords=_build_call_keywords( - node, + return _attach_source_location( + FrontendCallExpr( namespace="pto", name=resolved_op, - context=context, + args=tuple(_build_expr(arg, context) for arg in node.args[1:]), + keywords=_build_call_keywords( + node, + namespace="pto", + name=resolved_op, + context=context, + ), ), + node, + context, ) if isinstance(node.func, ast.Name): - return FrontendCallExpr( - namespace=None, - name=node.func.id, - args=tuple(_build_expr(arg, context) for arg in node.args), - keywords=_build_call_keywords( - node, + return _attach_source_location( + FrontendCallExpr( namespace=None, name=node.func.id, - context=context, + args=tuple(_build_expr(arg, context) for arg in node.args), + keywords=_build_call_keywords( + node, + namespace=None, + name=node.func.id, + context=context, + ), ), + node, + context, ) if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): - return FrontendCallExpr( - namespace=node.func.value.id, - name=node.func.attr, - args=tuple(_build_expr(arg, context) for arg in node.args), - keywords=_build_call_keywords( - node, + return _attach_source_location( + FrontendCallExpr( namespace=node.func.value.id, name=node.func.attr, - context=context, + args=tuple(_build_expr(arg, context) for arg in node.args), + keywords=_build_call_keywords( + node, + namespace=node.func.value.id, + name=node.func.attr, + context=context, + ), ), + node, + context, ) raise context.error( node, @@ -969,26 +1044,38 @@ def _build_stmt(node: ast.stmt, context: _FrontendBuildContext) -> FrontendStmtN if isinstance(node, ast.Assign): if len(node.targets) != 1: raise context.error(node, "multiple assignment targets are not supported in TileLang DSL v1") - return FrontendAssignStmt( - target=_build_target(node.targets[0], context), - value=_build_expr(node.value, context), + return _attach_source_location( + FrontendAssignStmt( + target=_build_target(node.targets[0], context), + value=_build_expr(node.value, context), + ), + node, + context, ) if isinstance(node, ast.AnnAssign): if node.value is None: raise context.error(node, "annotation-only assignments are not supported in TileLang DSL v1") - return FrontendAssignStmt( - target=_build_target(node.target, context), - value=_build_expr(node.value, context), - annotation=node.annotation, + return _attach_source_location( + FrontendAssignStmt( + target=_build_target(node.target, context), + value=_build_expr(node.value, context), + annotation=node.annotation, + ), + node, + context, ) if isinstance(node, ast.Expr): - return FrontendExprStmt(expr=_build_expr(node.value, context)) + return _attach_source_location( + FrontendExprStmt(expr=_build_expr(node.value, context)), + node, + context, + ) if isinstance(node, ast.Return): value = None if node.value is not None: if not (isinstance(node.value, ast.Constant) and node.value.value is None): value = _build_expr(node.value, context) - return FrontendReturnStmt(value=value) + return _attach_source_location(FrontendReturnStmt(value=value), node, context) if isinstance(node, ast.For): if not isinstance(node.target, ast.Name): raise context.error(node.target, "for target must be a single name") @@ -996,12 +1083,16 @@ def _build_stmt(node: ast.stmt, context: _FrontendBuildContext) -> FrontendStmtN raise context.error(node.iter, "only Python range(lb, ub, step) loops are supported") if len(node.iter.args) != 3: raise context.error(node.iter, "range() expects exactly 3 arguments in TileLang DSL v1") - return FrontendForStmt( - target=node.target.id, - lower_bound=_build_expr(node.iter.args[0], context), - upper_bound=_build_expr(node.iter.args[1], context), - step=_build_expr(node.iter.args[2], context), - body=_build_stmt_list(node.body, context), + return _attach_source_location( + FrontendForStmt( + target=node.target.id, + lower_bound=_build_expr(node.iter.args[0], context), + upper_bound=_build_expr(node.iter.args[1], context), + step=_build_expr(node.iter.args[2], context), + body=_build_stmt_list(node.body, context), + ), + node, + context, ) if isinstance(node, ast.If): is_constexpr = False @@ -1025,11 +1116,15 @@ def _build_stmt(node: ast.stmt, context: _FrontendBuildContext) -> FrontendStmtN ) is_constexpr = True condition_node = node.test.args[0] - return FrontendIfStmt( - condition=_build_expr(condition_node, context), - then_body=_build_stmt_list(node.body, context), - else_body=_build_stmt_list(node.orelse, context), - is_constexpr=is_constexpr, + return _attach_source_location( + FrontendIfStmt( + condition=_build_expr(condition_node, context), + then_body=_build_stmt_list(node.body, context), + else_body=_build_stmt_list(node.orelse, context), + is_constexpr=is_constexpr, + ), + node, + context, ) if isinstance(node, ast.With): if len(node.items) != 1: @@ -1055,8 +1150,12 @@ def _build_stmt(node: ast.stmt, context: _FrontendBuildContext) -> FrontendStmtN ) if item.optional_vars is not None: raise context.error(item, "pto.vecscope() does not support `as` bindings in TileLang DSL v1") - return FrontendVecscopeStmt( - body=_build_stmt_list(node.body, context.nested_vecscope()), + return _attach_source_location( + FrontendVecscopeStmt( + body=_build_stmt_list(node.body, context.nested_vecscope()), + ), + node, + context, ) if with_name != "strict_vecscope": raise context.error( @@ -1075,10 +1174,14 @@ def _build_stmt(node: ast.stmt, context: _FrontendBuildContext) -> FrontendStmtN if not isinstance(elt, ast.Name): raise context.error(elt, "pto.strict_vecscope bindings must be names") block_arguments.append(elt.id) - return FrontendStrictVecscopeStmt( - captures=tuple(_build_expr(arg, context) for arg in item.context_expr.args), - block_arguments=tuple(block_arguments), - body=_build_stmt_list(node.body, context.nested_vecscope()), + return _attach_source_location( + FrontendStrictVecscopeStmt( + captures=tuple(_build_expr(arg, context) for arg in item.context_expr.args), + block_arguments=tuple(block_arguments), + body=_build_stmt_list(node.body, context.nested_vecscope()), + ), + node, + context, ) raise context.error( node, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 5cdb28cc2..98dba1c31 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -31,6 +31,7 @@ FrontendNameTarget, FrontendReturnStmt, FrontendSliceExpr, + FrontendSourceLocation, FrontendStrictVecscopeStmt, FrontendStmtNode, FrontendSubscriptExpr, @@ -664,6 +665,44 @@ def __init__(self, node: FrontendKernelNode): self._inline_proc_order: list[tuple[str, tuple[SemanticType, ...]]] = [] self._inline_proc_active_stack: list[tuple[str, tuple[SemanticType, ...]]] = [] + def _expr_source_location( + self, + expr: FrontendExprNode | SemanticExpr, + ) -> FrontendSourceLocation | None: + return getattr(expr, "source_location", None) + + def _attach_expr_source_location( + self, + semantic_expr: SemanticExpr, + frontend_expr: FrontendExprNode, + ) -> SemanticExpr: + source_location = self._expr_source_location(frontend_expr) + if source_location is not None: + object.__setattr__(semantic_expr, "source_location", source_location) + return semantic_expr + + def _format_source_message( + self, + message: str, + expr: FrontendExprNode | SemanticExpr | None = None, + ) -> str: + if expr is None: + return message + source_location = self._expr_source_location(expr) + if source_location is None: + return message + return ( + f"{source_location.path}:{source_location.line}:{source_location.column}: " + f"{message}" + ) + + def _raise_expr_type_error( + self, + message: str, + expr: FrontendExprNode | SemanticExpr | None = None, + ) -> None: + raise TypeError(self._format_source_message(message, expr)) + def analyze(self) -> SemanticKernel: env: dict[str, SemanticBinding] = {} parameters = [] @@ -2655,27 +2694,48 @@ def _analyze_expr( raise ValueError( f"implicit capture of '{expr.name}' is not allowed in pto.strict_vecscope" ) - return SemanticBindingRef(binding=binding, type=binding.type) + return self._attach_expr_source_location( + SemanticBindingRef(binding=binding, type=binding.type), + expr, + ) if isinstance(expr, FrontendConstantExpr): if isinstance(expr.value, bool): - return SemanticLiteralExpr(value=expr.value, type=SemanticScalarType(dtype=i1)) + return self._attach_expr_source_location( + SemanticLiteralExpr(value=expr.value, type=SemanticScalarType(dtype=i1)), + expr, + ) if isinstance(expr.value, int): - return SemanticLiteralExpr(value=expr.value, type=SemanticIndexType()) + return self._attach_expr_source_location( + SemanticLiteralExpr(value=expr.value, type=SemanticIndexType()), + expr, + ) if isinstance(expr.value, float): - return SemanticLiteralExpr( - value=expr.value, - type=SemanticScalarType(dtype=f32), + return self._attach_expr_source_location( + SemanticLiteralExpr( + value=expr.value, + type=SemanticScalarType(dtype=f32), + ), + expr, ) if isinstance(expr.value, str): - return SemanticLiteralExpr( - value=expr.value, - type=SemanticMetaType(kind="string"), + return self._attach_expr_source_location( + SemanticLiteralExpr( + value=expr.value, + type=SemanticMetaType(kind="string"), + ), + expr, ) if expr.value is None: - return SemanticLiteralExpr(value=None, type=SemanticIndexType()) + return self._attach_expr_source_location( + SemanticLiteralExpr(value=None, type=SemanticIndexType()), + expr, + ) raise TypeError(f"unsupported constant {expr.value!r} in TileLang DSL v1") if isinstance(expr, FrontendSymbolExpr): - return self._analyze_symbol_expr(expr) + return self._attach_expr_source_location( + self._analyze_symbol_expr(expr), + expr, + ) if isinstance(expr, FrontendSliceExpr): start = None if expr.start is None else self._analyze_expr(expr.start, env, allow_outer_lookup=allow_outer_lookup) stop = None if expr.stop is None else self._analyze_expr(expr.stop, env, allow_outer_lookup=allow_outer_lookup) @@ -2683,44 +2743,62 @@ def _analyze_expr( for item in (start, stop, step): if item is not None: self._require_index_typed_expr(item) - return SemanticSliceExpr( - start=start, - stop=stop, - step=step, - type=SemanticSliceType(), + return self._attach_expr_source_location( + SemanticSliceExpr( + start=start, + stop=stop, + step=step, + type=SemanticSliceType(), + ), + expr, ) if isinstance(expr, FrontendTupleExpr): elements = tuple( self._analyze_expr(element, env, allow_outer_lookup=allow_outer_lookup) for element in expr.elements ) - return SemanticTupleExpr( - elements=elements, - type=SemanticTupleType(elements=tuple(element.type for element in elements)), + return self._attach_expr_source_location( + SemanticTupleExpr( + elements=elements, + type=SemanticTupleType(elements=tuple(element.type for element in elements)), + ), + expr, ) if isinstance(expr, FrontendAttributeExpr): base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) if expr.attr == "element_type": - return self._element_type_expr(base) + return self._attach_expr_source_location(self._element_type_expr(base), expr) if expr.attr == "valid_shape": - return self._valid_shape_expr(base) + return self._attach_expr_source_location(self._valid_shape_expr(base), expr) if expr.attr == "strides": - return self._strides_expr(base) + return self._attach_expr_source_location(self._strides_expr(base), expr) attr_type = self._attribute_type(base, expr.attr) - return SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type) + return self._attach_expr_source_location( + SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type), + expr, + ) if isinstance(expr, FrontendSubscriptExpr): base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) index = self._analyze_expr(expr.index, env, allow_outer_lookup=allow_outer_lookup) result_type = self._subscript_type(base, index) if isinstance(result_type, SemanticTensorSliceType): slices = self._normalize_tensor_slice(index, base.type.rank) - return SemanticTensorSliceExpr(base=base, slices=slices, type=result_type) - return SemanticSubscriptAccess(base=base, index=index, type=result_type) + return self._attach_expr_source_location( + SemanticTensorSliceExpr(base=base, slices=slices, type=result_type), + expr, + ) + return self._attach_expr_source_location( + SemanticSubscriptAccess(base=base, index=index, type=result_type), + expr, + ) if isinstance(expr, FrontendBinaryExpr): lhs = self._analyze_expr(expr.lhs, env, allow_outer_lookup=allow_outer_lookup) rhs = self._analyze_expr(expr.rhs, env, allow_outer_lookup=allow_outer_lookup) result_type = self._binary_type(lhs, rhs, expr.op) - return SemanticBinaryExpr(lhs=lhs, op=expr.op, rhs=rhs, type=result_type) + return self._attach_expr_source_location( + SemanticBinaryExpr(lhs=lhs, op=expr.op, rhs=rhs, type=result_type), + expr, + ) if isinstance(expr, FrontendCallExpr): if expr.namespace is None and expr.name in self._inline_proc_nodes: if expr.keywords: @@ -4601,7 +4679,10 @@ def _merge_loop_carried_types( def _require_index_typed_expr(self, expr: SemanticExpr) -> None: if not isinstance(expr.type, SemanticIndexType): - raise TypeError("slice bounds and vector offsets must be index-typed in TileLang DSL v1") + self._raise_expr_type_error( + "slice bounds and vector offsets must be index-typed in TileLang DSL v1", + expr, + ) def _try_static_dtype(self, expr: SemanticExpr) -> ScalarType | None: if ( diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 513224ec8..aea7b0aef 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -4112,6 +4112,45 @@ def kernel(x: pto.TensorView, tile: pto.Tile): self.assertIn("valid_shape axis 0=5 must be <= shape axis 0=4", str(valid_shape_ctx.exception)) self.assertIn(f"{__file__}:", str(valid_shape_ctx.exception)) + def test_slice_index_type_error_reports_template_source_location(self) -> None: + source = """ +import tilelang_dsl as pto + +@pto.inline_proc +def store_row(dst: pto.Tile, src: pto.Tile, row: pto.i32): + vec = pto.vlds(src[row, 0:]) + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + pto.vsts(vec, dst[row, 0:], mask) + return None + +@pto.vkernel(op="diag_index_type_unique", dtypes=[(pto.f32, pto.f32, pto.i32)]) +def kernel(dst: pto.Tile, src: pto.Tile, row: pto.i32): + store_row(dst, src, row) + return None +""" + with tempfile.TemporaryDirectory() as tmpdir: + module_path = Path(tmpdir) / "diag_index_type_kernel.py" + module_path.write_text(source, encoding="utf-8") + spec = util.spec_from_file_location("diag_index_type_kernel", module_path) + self.assertIsNotNone(spec) + self.assertIsNotNone(spec.loader) + module = util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + + specialized = module.kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + message = str(ctx.exception) + self.assertIn(str(module_path), message) + self.assertIn(":6:", message) + self.assertIn("slice bounds and vector offsets must be index-typed", message) + if __name__ == "__main__": unittest.main() From 0d1961bd2373df0eb66b3ca0bda03c1c85947af6 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 13 Apr 2026 17:04:51 +0800 Subject: [PATCH 043/192] Fix inline_proc codegen --- lib/PTO/Transforms/ExpandTileOp.cpp | 53 +++++-- .../PTOInstantiateAndInlineOpLib.cpp | 134 +++++++++--------- tilelang-dsl/python/tilelang_dsl/lowering.py | 10 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 3 +- 4 files changed, 120 insertions(+), 80 deletions(-) diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 9257e4fa4..9d5c823e8 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -36,11 +36,13 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Parser/Parser.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" @@ -376,23 +378,23 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, return nullptr; } - // 9. Find func.func in the parsed module and clone into target module. - func::FuncOp srcFn; - for (auto fn : parsedMod->getOps()) { - srcFn = fn; - break; - } - if (!srcFn) { + // 9. Clone the generated function set into the target module. The TileLang + // output may include private inline helper funcs referenced by the entry. + SmallVector parsedFuncs; + for (auto fn : parsedMod->getOps()) + parsedFuncs.push_back(fn); + if (parsedFuncs.empty()) { llvm::errs() << "ExpandTileOp: no func.func in DSL output\n"; return nullptr; } + func::FuncOp srcFn = parsedFuncs.front(); OpBuilder builder(ctx); builder.setInsertionPointToEnd(mod.getBody()); - IRMapping mapping; - auto cloned = cast(builder.clone(*srcFn, mapping)); + SmallVector clonedFuncs; + llvm::StringMap renamedSymbols; - // Build a unique name from all operand types. + // Build a unique base name from all operand types. std::string uniqueName = "__pto_tilelang_" + key.opName; for (const auto &op : key.operands) { uniqueName += op.kind == OperandKind::Tile ? "_tile" : "_scalar"; @@ -400,8 +402,35 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, for (int64_t d : op.shape) uniqueName += "_" + std::to_string(d); } - cloned.setName(uniqueName); - cloned.setVisibility(SymbolTable::Visibility::Private); + + for (auto [index, fn] : llvm::enumerate(parsedFuncs)) { + IRMapping mapping; + auto cloned = cast(builder.clone(*fn, mapping)); + std::string newName; + if (index == 0) { + newName = uniqueName; + cloned.setVisibility(SymbolTable::Visibility::Private); + } else { + newName = uniqueName + "__" + std::string(fn.getSymName()); + } + renamedSymbols[fn.getSymName()] = newName; + cloned.setName(newName); + clonedFuncs.push_back(cloned); + } + + for (func::FuncOp fn : clonedFuncs) { + fn.walk([&](func::CallOp call) { + StringRef callee = call.getCallee(); + if (callee.empty()) + return; + auto renameIt = renamedSymbols.find(callee); + if (renameIt == renamedSymbols.end()) + return; + call.setCallee(renameIt->second); + }); + } + + auto cloned = clonedFuncs.front(); // The pto.tilelang.instance attribute should already be set by the // TileLang DSL frontend in the generated MLIR. Verify it exists. if (!cloned->hasAttr("pto.tilelang.instance")) { diff --git a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp index 2973510e9..56bd116bf 100644 --- a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp +++ b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp @@ -188,75 +188,81 @@ struct PTOInlineLibCallPass if (func.empty()) continue; - SmallVector calls; - func.walk([&](func::CallOp call) { calls.push_back(call); }); - bool changedThisFunc = false; - for (func::CallOp oldCall : calls) { - if (!oldCall || !oldCall->getBlock()) - continue; - - auto calleeAttr = oldCall.getCalleeAttr(); - if (!calleeAttr) - continue; - - func::FuncOp callee = - module.lookupSymbol(calleeAttr.getValue()); - if (!callee || !isInlineableLibFunc(callee)) - continue; - - if (callee.isExternal()) { - oldCall.emitError() << kErrInstanceBodyMissing - << ": OP-Lib instance body is missing for @" - << callee.getSymName(); - if (auto variant = - callee->getAttrOfType(kOpLibAttrInstVariantId)) { - oldCall.emitRemark() << "variant_id=" << variant.getValue(); - } - if (auto op = callee->getAttrOfType(kOpLibAttrInstOp)) { - oldCall.emitRemark() << "op=" << op.getValue(); - } - if (auto dtype = - callee->getAttrOfType(kOpLibAttrInstDType)) { - oldCall.emitRemark() << "dtype=" << dtype.getValue(); + bool madeProgress = true; + while (madeProgress) { + madeProgress = false; + + SmallVector calls; + func.walk([&](func::CallOp call) { calls.push_back(call); }); + + for (func::CallOp oldCall : calls) { + if (!oldCall || !oldCall->getBlock()) + continue; + + auto calleeAttr = oldCall.getCalleeAttr(); + if (!calleeAttr) + continue; + + func::FuncOp callee = + module.lookupSymbol(calleeAttr.getValue()); + if (!callee || !isInlineableLibFunc(callee)) + continue; + + if (callee.isExternal()) { + oldCall.emitError() << kErrInstanceBodyMissing + << ": OP-Lib instance body is missing for @" + << callee.getSymName(); + if (auto variant = + callee->getAttrOfType(kOpLibAttrInstVariantId)) { + oldCall.emitRemark() << "variant_id=" << variant.getValue(); + } + if (auto op = callee->getAttrOfType(kOpLibAttrInstOp)) { + oldCall.emitRemark() << "op=" << op.getValue(); + } + if (auto dtype = + callee->getAttrOfType(kOpLibAttrInstDType)) { + oldCall.emitRemark() << "dtype=" << dtype.getValue(); + } + signalPassFailure(); + return; } - signalPassFailure(); - return; - } - func::CallOp call = oldCall; - SmallVector concreteOperands; - concreteOperands.reserve(call.getNumOperands()); - for (auto [operand, expectedTy] : llvm::zip( - call.getOperands(), callee.getFunctionType().getInputs())) { - concreteOperands.push_back( - maybeUnwrapCastToExpected(operand, expectedTy)); - } + func::CallOp call = oldCall; + SmallVector concreteOperands; + concreteOperands.reserve(call.getNumOperands()); + for (auto [operand, expectedTy] : llvm::zip( + call.getOperands(), callee.getFunctionType().getInputs())) { + concreteOperands.push_back( + maybeUnwrapCastToExpected(operand, expectedTy)); + } - OpBuilder builder(call); - auto newCall = builder.create(call.getLoc(), callee, - concreteOperands); - if (call.getNumResults() != newCall.getNumResults()) { - call.emitOpError("call result arity mismatch during inline staging"); - signalPassFailure(); - return; - } - for (auto [oldResult, newResult] : - llvm::zip(call.getResults(), newCall.getResults())) - oldResult.replaceAllUsesWith(newResult); - call.erase(); - - if (failed(inlineCall(newCall, callee))) { - signalPassFailure(); - return; - } + OpBuilder builder(call); + auto newCall = builder.create(call.getLoc(), callee, + concreteOperands); + if (call.getNumResults() != newCall.getNumResults()) { + call.emitOpError("call result arity mismatch during inline staging"); + signalPassFailure(); + return; + } + for (auto [oldResult, newResult] : + llvm::zip(call.getResults(), newCall.getResults())) + oldResult.replaceAllUsesWith(newResult); + call.erase(); + + if (failed(inlineCall(newCall, callee))) { + signalPassFailure(); + return; + } - ++inlinedCalls; - changedThisFunc = true; - if (debug) { - llvm::errs() << "[op-fusion] inline-libcall: inlined @" - << callee.getSymName() << " into @" << func.getSymName() - << "\n"; + ++inlinedCalls; + changedThisFunc = true; + madeProgress = true; + if (debug) { + llvm::errs() << "[op-fusion] inline-libcall: inlined @" + << callee.getSymName() << " into @" << func.getSymName() + << "\n"; + } } } diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 22da807db..0e081472e 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -144,14 +144,18 @@ def _extract_single_function_lines(rendered_text: str) -> list[str]: def _rewrite_inline_helper_attrs(function_line: str) -> str: kernel_attr = "attributes { pto.tilelang.instance }" - helper_attr = 'attributes { sym_visibility = "private", pto.tilelang.inline_proc }' + helper_attr = "private " + helper_marker_attr = "attributes { pto.tilelang.inline_proc }" if kernel_attr in function_line: - return function_line.replace(kernel_attr, helper_attr) + rewritten = function_line.replace("func.func ", f"func.func {helper_attr}", 1) + return rewritten.replace(kernel_attr, helper_marker_attr) if "attributes {" in function_line: return function_line if function_line.rstrip().endswith("{"): stripped = function_line.rstrip() - return stripped[:-1] + f" {helper_attr} {{" + if stripped.startswith("func.func "): + stripped = stripped.replace("func.func ", f"func.func {helper_attr}", 1) + return stripped[:-1] + f" {helper_marker_attr} {{" return function_line diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index aea7b0aef..c32b1b06a 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3919,7 +3919,8 @@ def kernel(dst: pto.Tile, src: pto.Tile): dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), ).mlir_text() - self.assertIn('sym_visibility = "private", pto.tilelang.inline_proc', text) + self.assertIn("func.func private @__tl_inline_", text) + self.assertIn("attributes { pto.tilelang.inline_proc }", text) self.assertGreaterEqual(text.count("func.func"), 3) self.assertGreaterEqual(text.count("pto.tilelang.inline_proc"), 2) self.assertRegex(text, r"= func\.call @__tl_inline_[A-Za-z0-9_]+\(.*\) : \([^\)]*\) -> index") From 5e14184a81a85cbbf229b8dc7afb3c5601cfddec Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 13 Apr 2026 19:21:48 +0800 Subject: [PATCH 044/192] Support optional RoundMode/SatMode/Part arguments in vcvt op --- .../docs/user_guide/04-template-kernels.md | 3 +- .../11-vector-arithmetic-operations.md | 62 ++++++- tilelang-dsl/python/tilelang_dsl/__init__.py | 6 + .../python/tilelang_dsl/frontend_ast.py | 4 + tilelang-dsl/python/tilelang_dsl/lowering.py | 17 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 165 +++++++++++++++++- tilelang-dsl/python/tilelang_dsl/types.py | 19 ++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 71 +++++++- 8 files changed, 337 insertions(+), 10 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/04-template-kernels.md b/tilelang-dsl/docs/user_guide/04-template-kernels.md index a9bffd3ab..9fcda0fd0 100644 --- a/tilelang-dsl/docs/user_guide/04-template-kernels.md +++ b/tilelang-dsl/docs/user_guide/04-template-kernels.md @@ -307,11 +307,12 @@ else: def template_trowsum(dst: pto.Tile, src: pto.Tile, tmp: pto.Tile): acc_dtype = tmp.element_type dst_dtype = dst.element_type + acc_mask_1, _ = pto.make_mask(acc_dtype, 1) dst_mask_1, _ = pto.make_mask(dst_dtype, 1) if pto.constexpr(acc_dtype != dst_dtype): # Type conversion required - v_acc_casted = pto.vcvt(v_acc, dst_mask_1, dst_dtype) + v_acc_casted = pto.vcvt(v_acc, dst_dtype, acc_mask_1) pto.vsts(v_acc_casted, dst[row, 0:], dst_mask_1) else: # No conversion needed diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index a8d0d0005..99b70f9b7 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1325,21 +1325,73 @@ Type conversion and specialized operations. |--------------|------|-------------| | `result` | `VRegType` | Truncated vector | -#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType) -> VRegType` +#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType, rnd: pto.VcvtRoundMode | None = None, sat: pto.VcvtSatMode | None = None, part: pto.VcvtPartMode | None = None) -> VRegType` -**Description**: Type conversion of vector elements. +**Description**: Convert vector elements between supported float and integer +families. This is the TileLang DSL surface for the VPTO `pto.vcvt` conversion +family. + +**Attribute Enums**: +- `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` +- `pto.VcvtSatMode`: `SAT`, `NOSAT` +- `pto.VcvtPartMode`: `EVEN`, `ODD` **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `vec` | `VRegType` | Input vector | -| `to_type` | `Type` | Target element type | -| `mask` | `MaskType` | Predicate mask | +| `to_type` | `Type` | Target scalar dtype symbol for the result vector element type | +| `mask` | `MaskType` | Predicate mask selecting active source lanes. Its granularity must match the source vector family, not the destination family | +| `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `rnd` | +| `sat` | `pto.VcvtSatMode` \| `None` | Optional saturation attribute lowered to VPTO `sat` | +| `part` | `pto.VcvtPartMode` \| `None` | Optional even/odd packing selector lowered to VPTO `part` | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `VRegType` | Converted vector | +| `result` | `VRegType` | Converted vector with the vreg shape implied by `to_type` | + +**Constraints**: +- Current TileLang DSL v1 accepts exactly three positional arguments: + `pto.vcvt(vec, to_type, mask)`. Optional attributes are exposed as keyword + arguments: `rnd=...`, `sat=...`, `part=...`. +- The underlying VPTO op family is the fuller + `pto.vcvt %input, %mask {rnd, sat, part}` surface, and the DSL keywords map + directly to those VPTO attributes. +- `mask` always follows the source vector family: + `f32`/`i32`/`si32`/`ui32` use `mask_b32`; + `f16`/`bf16`/`i16`/`si16`/`ui16` use `mask_b16`; + `i8`/`si8`/`ui8` use `mask_b8`. +- The enum form is preferred. For compatibility, canonical strings such as + `"R"`, `"SAT"`, and `"EVEN"` are also accepted. +- Only backend-supported source/destination type pairs are legal. For the full + A5 `vcvt` type matrix, width-changing packing rules, and attribute-sensitive + forms, refer to + [`../vpto_spec/vpto-spec-current.md`](../vpto_spec/vpto-spec-current.md). +- VPTO does not define a `mask_b64` form. Conversions that produce `si64` + results still use the typed mask granularity of the source vector family. +- Width-changing conversions continue to follow VPTO packing semantics even on + the simplified DSL surface. For example, `f16 -> f32` uses an `f16`-family + `mask_b16`, because the mask is attached to the source vector family. + +**Example**: +```python +mask16 = pto.make_mask(pto.f16, PAT.ALL) +vec_f16 = pto.vlds(src, 0) +vec_f32 = pto.vcvt(vec_f16, pto.f32, mask16) + +mask32 = pto.make_mask(pto.f32, PAT.ALL) +vec_i32 = pto.vcvt(vec_f32, pto.si32, mask32) + +vec_f16_narrow = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD, +) +``` #### `pto.vbitsort(vec: VRegType, mask: MaskType) -> VRegType` diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 3981ca458..055f8d7db 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -38,6 +38,9 @@ PadMode, PositionMode, OrderMode, + VcvtPartMode, + VcvtRoundMode, + VcvtSatMode, PointerType, PostUpdateMode, Pipe, @@ -116,6 +119,9 @@ "PadMode", "PositionMode", "OrderMode", + "VcvtRoundMode", + "VcvtSatMode", + "VcvtPartMode", "PostUpdateMode", "TileConfig", "TileSpecialization", diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 87ed9f43d..9ce2f06eb 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -672,6 +672,7 @@ def _collect_reachable_inline_procs( "ub_stride", } ), + "vcvt": frozenset({"rnd", "sat", "part"}), } @@ -830,6 +831,9 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo "InterleaveDist", "PositionMode", "OrderMode", + "VcvtRoundMode", + "VcvtSatMode", + "VcvtPartMode", "PostUpdateMode", } and len(path) >= 2: return _attach_source_location( diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 0e081472e..a09bac840 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2715,9 +2715,17 @@ def _lower_call_expr( value = self._lower_expr(expr.args[0], env, indent=indent, into=into) target_dtype = self._render_dtype_symbol(expr.args[1], context="pto.vcvt to_type") mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + attr_parts: list[str] = [] + if self._has_optional_string_literal(expr.args[3]): + attr_parts.append(f"rnd = {self._render_string_literal(expr.args[3])}") + if self._has_optional_string_literal(expr.args[4]): + attr_parts.append(f"sat = {self._render_string_literal(expr.args[4])}") + if self._has_optional_string_literal(expr.args[5]): + attr_parts.append(f"part = {self._render_string_literal(expr.args[5])}") + attr_suffix = f" {{{', '.join(attr_parts)}}}" if attr_parts else "" into.append( self._indent(indent) - + f"{result_name} = pto.vcvt {value.name}, {target_dtype}, {mask.name} : " + + f"{result_name} = pto.vcvt {value.name}, {target_dtype}, {mask.name}{attr_suffix} : " + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) @@ -2936,6 +2944,13 @@ def _render_string_literal(self, expr: SemanticExpr) -> str: return f'"{escaped}"' raise NotImplementedError("expected a string literal for TileLang DSL advanced-family lowering") + def _has_optional_string_literal(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticLiteralExpr): + return isinstance(expr.value, str) + if isinstance(expr, SemanticBindingRef): + return isinstance(expr.binding.value, str) + return False + def _render_dtype_symbol(self, expr: SemanticExpr, *, context: str) -> str: if isinstance(expr, SemanticSymbolExpr) and isinstance(expr.value, ScalarType): return expr.value.name diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 98dba1c31..121b83f76 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -63,6 +63,9 @@ PositionMode, PointerType, ScalarType, + VcvtPartMode, + VcvtRoundMode, + VcvtSatMode, VRegType, bf16, bytewidth, @@ -122,6 +125,9 @@ _INTERLEAVE_DIST_SYMBOLS = dict(InterleaveDist.__members__) _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} _ORDER_MODE_SYMBOLS = {order_mode.name: order_mode for order_mode in OrderMode} +_VCVT_ROUND_MODE_SYMBOLS = {mode.name: mode for mode in VcvtRoundMode} +_VCVT_SAT_MODE_SYMBOLS = {mode.name: mode for mode in VcvtSatMode} +_VCVT_PART_MODE_SYMBOLS = {mode.name: mode for mode in VcvtPartMode} _POST_UPDATE_MODE_SYMBOLS = {mode.name: mode for mode in PostUpdateMode} _UNARY_VECTOR_OPS = { "vabs", @@ -2866,6 +2872,12 @@ def _analyze_expr( ) dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) return self._analyze_vldsx2((base, *indices, dist)) + if expr.namespace == "pto" and expr.name == "vcvt": + return self._analyze_vcvt_frontend_call( + expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) if expr.keywords: raise TypeError( f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " @@ -2997,6 +3009,33 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=order_mode, type=SemanticMetaType(kind="order_mode"), ) + if expr.namespace in {"VcvtRoundMode", "pto.VcvtRoundMode"}: + round_mode = _VCVT_ROUND_MODE_SYMBOLS.get(expr.name) + if round_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=round_mode, + type=SemanticMetaType(kind="vcvt_round_mode"), + ) + if expr.namespace in {"VcvtSatMode", "pto.VcvtSatMode"}: + sat_mode = _VCVT_SAT_MODE_SYMBOLS.get(expr.name) + if sat_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=sat_mode, + type=SemanticMetaType(kind="vcvt_sat_mode"), + ) + if expr.namespace in {"VcvtPartMode", "pto.VcvtPartMode"}: + part_mode = _VCVT_PART_MODE_SYMBOLS.get(expr.name) + if part_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=part_mode, + type=SemanticMetaType(kind="vcvt_part_mode"), + ) if expr.namespace in {"PostUpdateMode", "pto.PostUpdateMode"}: post_update_mode = _POST_UPDATE_MODE_SYMBOLS.get(expr.name) if post_update_mode is not None: @@ -4087,7 +4126,44 @@ def _analyze_rearrangement_op( self._require_string_expr(args[2], f"pto.{name} part") return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) - def _analyze_vcvt(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + def _missing_optional_meta_expr(self) -> SemanticLiteralExpr: + return SemanticLiteralExpr(value=None, type=SemanticMetaType(kind="none")) + + def _analyze_vcvt_frontend_call( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + if len(expr.args) != 3: + raise TypeError( + "pto.vcvt expects exactly 3 positional operands `(vec, to_type, mask)` " + "before optional keyword attrs in TileLang DSL v1" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + return self._analyze_vcvt( + args, + rnd=self._normalize_vcvt_round_mode(analyzed_keywords.get("rnd")), + sat=self._normalize_vcvt_sat_mode(analyzed_keywords.get("sat")), + part=self._normalize_vcvt_part_mode(analyzed_keywords.get("part")), + ) + + def _analyze_vcvt( + self, + args: tuple[SemanticExpr, ...], + *, + rnd: SemanticExpr | None = None, + sat: SemanticExpr | None = None, + part: SemanticExpr | None = None, + ) -> SemanticExpr: if len(args) != 3: raise TypeError("pto.vcvt expects exactly 3 positional arguments in TileLang DSL") vector = self._require_vreg_expr(args[0], "pto.vcvt vector") @@ -4096,7 +4172,14 @@ def _analyze_vcvt(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: return SemanticCallExpr( namespace="pto", name="vcvt", - args=args, + args=( + args[0], + args[1], + args[2], + rnd if rnd is not None else self._missing_optional_meta_expr(), + sat if sat is not None else self._missing_optional_meta_expr(), + part if part is not None else self._missing_optional_meta_expr(), + ), type=self._vreg_type_for_dtype(target_dtype), ) @@ -4303,6 +4386,84 @@ def _normalize_order_mode( raise TypeError("pto.vci currently only supports order `OrderMode.ASC` in TileLang DSL v1") return SemanticLiteralExpr(value=order, type=SemanticMetaType(kind="string")) + def _normalize_vcvt_round_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: + if expr is None: + return None + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_round_mode" + and isinstance(expr.value, VcvtRoundMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_round_mode" + and isinstance(expr.binding.value, VcvtRoundMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + round_mode = self._require_string_expr(expr, "pto.vcvt rnd") + if round_mode not in {mode.value for mode in VcvtRoundMode}: + raise TypeError( + "pto.vcvt rnd must be a VcvtRoundMode enum such as " + "`pto.VcvtRoundMode.R` or one of the canonical strings " + '`"R"`, `"A"`, `"F"`, `"C"`, `"Z"`, `"O"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=round_mode, type=SemanticMetaType(kind="string")) + + def _normalize_vcvt_sat_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: + if expr is None: + return None + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_sat_mode" + and isinstance(expr.value, VcvtSatMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_sat_mode" + and isinstance(expr.binding.value, VcvtSatMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + sat_mode = self._require_string_expr(expr, "pto.vcvt sat") + if sat_mode not in {mode.value for mode in VcvtSatMode}: + raise TypeError( + "pto.vcvt sat must be a VcvtSatMode enum such as " + "`pto.VcvtSatMode.SAT` or `pto.VcvtSatMode.NOSAT`, or one of the " + 'canonical strings `"SAT"` / `"NOSAT"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=sat_mode, type=SemanticMetaType(kind="string")) + + def _normalize_vcvt_part_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: + if expr is None: + return None + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_part_mode" + and isinstance(expr.value, VcvtPartMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_part_mode" + and isinstance(expr.binding.value, VcvtPartMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + part_mode = self._require_string_expr(expr, "pto.vcvt part") + if part_mode not in {mode.value for mode in VcvtPartMode}: + raise TypeError( + "pto.vcvt part must be a VcvtPartMode enum such as " + "`pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD`, or one of the " + 'canonical strings `"EVEN"` / `"ODD"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=part_mode, type=SemanticMetaType(kind="string")) + def _require_mask_expr(self, expr: SemanticExpr, context: str) -> SemanticMaskType: if not isinstance(expr.type, SemanticMaskType): raise TypeError(f"{context} must be a mask value in TileLang DSL") diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index e35476876..02752b277 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -219,6 +219,25 @@ class OrderMode(str, Enum): ASC = "ORDER_ASC" +class VcvtRoundMode(str, Enum): + R = "R" + A = "A" + F = "F" + C = "C" + Z = "Z" + O = "O" + + +class VcvtSatMode(str, Enum): + SAT = "SAT" + NOSAT = "NOSAT" + + +class VcvtPartMode(str, Enum): + EVEN = "EVEN" + ODD = "ODD" + + class PostUpdateMode(str, Enum): POST_UPDATE = "POST_UPDATE" NO_POST_UPDATE = "NO_POST_UPDATE" diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index c32b1b06a..b586b0ba3 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -98,6 +98,9 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "InterleaveDist")) self.assertTrue(hasattr(pto, "PositionMode")) self.assertTrue(hasattr(pto, "OrderMode")) + self.assertTrue(hasattr(pto, "VcvtRoundMode")) + self.assertTrue(hasattr(pto, "VcvtSatMode")) + self.assertTrue(hasattr(pto, "VcvtPartMode")) self.assertTrue(hasattr(pto, "PostUpdateMode")) self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) @@ -119,6 +122,9 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PositionMode.LOWEST.value, "LOWEST") self.assertEqual(pto.PositionMode.HIGHEST.value, "HIGHEST") self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") + self.assertEqual(pto.VcvtRoundMode.R.value, "R") + self.assertEqual(pto.VcvtSatMode.SAT.value, "SAT") + self.assertEqual(pto.VcvtPartMode.ODD.value, "ODD") self.assertEqual(pto.PostUpdateMode.POST_UPDATE.value, "POST_UPDATE") self.assertEqual(pto.PostUpdateMode.NO_POST_UPDATE.value, "NO_POST_UPDATE") self.assertEqual(pto.Event.ID31.value, "EVENT_ID31") @@ -2052,6 +2058,69 @@ def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): self.assertIn("pto.vcvt", text) self.assertIn("pto.vmrgsort4", text) + def test_vcvt_supports_keyword_attrs_with_enums(self) -> None: + @pto.vkernel( + op="vcvt_keyword_attrs_unique", + dtypes=[(pto.f16, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn('pto.vcvt', text) + self.assertIn('rnd = "R"', text) + self.assertIn('sat = "SAT"', text) + self.assertIn('part = "ODD"', text) + + def test_vcvt_rejects_legacy_string_spellings(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_keyword_attrs_legacy_unique", + dtypes=[(pto.f16, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd="ROUND_R", + sat="RS_ENABLE", + part="PART_EVEN", + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("pto.vcvt rnd must be a VcvtRoundMode enum", str(ctx.exception)) + def test_extended_integer_vector_ops_surface_lowers(self) -> None: @pto.vkernel( op="extended_integer_vector_ops_unique", @@ -2774,7 +2843,7 @@ def kernel(dst: pto.Tile, src: pto.Tile, tmp: pto.Tile): acc = pto.vadd(acc, reduced, one_mask) out_mask, _ = pto.make_mask(src_dtype, 1) if pto.constexpr(src_dtype != dst.element_type): - casted = pto.vcvt(acc, out_mask, dst.element_type) + casted = pto.vcvt(acc, dst.element_type, out_mask) pto.vsts(casted, dst[row, 0:], out_mask) else: pto.vsts(acc, dst[row, 0:], out_mask) From 48192c2a9037cc694565b54f7c4b00cb6404004f Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 09:26:10 +0800 Subject: [PATCH 045/192] Fix tile attribute extraction --- tilelang-dsl/docs/unsupported-features.md | 6 +- .../docs/user_guide/05-type-system.md | 4 +- tilelang-dsl/python/tilelang_dsl/__init__.py | 6 + tilelang-dsl/python/tilelang_dsl/semantic.py | 118 +++++++++++++++- .../python/tilelang_dsl/support_matrix.py | 6 +- tilelang-dsl/python/tilelang_dsl/types.py | 126 +++++++++++++++++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 90 +++++++++++++ 7 files changed, 346 insertions(+), 10 deletions(-) diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md index 971ccfb7b..d47cba7d8 100644 --- a/tilelang-dsl/docs/unsupported-features.md +++ b/tilelang-dsl/docs/unsupported-features.md @@ -34,9 +34,9 @@ the current package: - `SyncOpType` Today, the public package exports annotation markers (`TensorView`, `Tile`), -scalar dtypes, `ptr(...)`, `PadMode`, `TileConfig`, matcher APIs, and a small -set of enums. The list above covers the remaining missing public constructors -and aliases from the guide. +scalar dtypes, `ptr(...)`, `PadMode`, `BLayout`, `SLayout`, `PadValue`, +`TileConfig`, matcher APIs, and a small set of enums. The list above covers the +remaining missing public constructors and aliases from the guide. ### Missing Tile/Tensor Utility Methods diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 8cb898250..19c28082f 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -295,8 +295,8 @@ valid_shape = tile.valid_shape # (240, 120) or same as shape config = tile.config b_layout = config.b_layout # pto.BLayout.ROW_MAJOR s_layout = config.s_layout # pto.SLayout.NONE_BOX -s_fractal = config.s_fractal_size # pto.i32(16) -pad = config.pad_value # pto.PadValue.ZERO +s_fractal = config.s_fractal_size # pto.i32(512) +pad = config.pad_value # pto.PadValue.NULL # Dynamic properties rank = tile.rank # 2 diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 055f8d7db..d3f159845 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -26,6 +26,7 @@ AnyMask, AnyType, BarrierType, + BLayout, DeinterleaveDist, EVENT, InterleaveDist, @@ -38,6 +39,7 @@ PadMode, PositionMode, OrderMode, + PadValue, VcvtPartMode, VcvtRoundMode, VcvtSatMode, @@ -45,6 +47,7 @@ PostUpdateMode, Pipe, ScalarType, + SLayout, TensorView, PartitionTensorView, Tile, @@ -114,11 +117,14 @@ "MaskPattern", "PAT", "BarrierType", + "BLayout", "DeinterleaveDist", "InterleaveDist", "PadMode", + "PadValue", "PositionMode", "OrderMode", + "SLayout", "VcvtRoundMode", "VcvtSatMode", "VcvtPartMode", diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 121b83f76..0e0680dca 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -50,6 +50,7 @@ from .types import ( AlignType, BarrierType, + BLayout, DeinterleaveDist, Event, InterleaveDist, @@ -58,11 +59,14 @@ MemorySpace, OrderMode, PadMode, + PadValue, Pipe, PostUpdateMode, PositionMode, PointerType, ScalarType, + SLayout, + TileConfig, VcvtPartMode, VcvtRoundMode, VcvtSatMode, @@ -121,6 +125,9 @@ _BARRIER_TYPE_SYMBOLS = {barrier_type.name: barrier_type for barrier_type in BarrierType} _MEMORY_SPACE_SYMBOLS = {memory_space.name: memory_space for memory_space in MemorySpace} _PAD_MODE_SYMBOLS = {pad_mode.name: pad_mode for pad_mode in PadMode} +_B_LAYOUT_SYMBOLS = {layout.name: layout for layout in BLayout} +_S_LAYOUT_SYMBOLS = {layout.name: layout for layout in SLayout} +_PAD_VALUE_SYMBOLS = {pad_value.name: pad_value for pad_value in PadValue} _DEINTERLEAVE_DIST_SYMBOLS = dict(DeinterleaveDist.__members__) _INTERLEAVE_DIST_SYMBOLS = dict(InterleaveDist.__members__) _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} @@ -255,6 +262,12 @@ class SemanticTileType(SemanticType): shape: tuple[int, ...] | None valid_shape: tuple[int | None, ...] | None memory_space: str | None + config: TileConfig | None + + +@dataclass(frozen=True) +class SemanticTileConfigType(SemanticType): + pass @dataclass(frozen=True) @@ -780,6 +793,7 @@ def _parameter_type(self, param: Any) -> SemanticType: shape=shape, valid_shape=valid_shape, memory_space=memory_space, + config=None if spec is None else (spec.config or TileConfig()), ) if param.kind == "ptr": memory_space = param.annotation.memory_space.value @@ -1295,7 +1309,7 @@ def _collect_inline_helper_tile_bindings( shape=parameter.type.shape, valid_shape=parameter.type.valid_shape, memory_space=parameter.type.memory_space or "ub", - config=None, + config=parameter.type.config or TileConfig(), ) ) return tuple(tile_bindings) @@ -2774,10 +2788,18 @@ def _analyze_expr( base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) if expr.attr == "element_type": return self._attach_expr_source_location(self._element_type_expr(base), expr) + if expr.attr == "rank": + return self._attach_expr_source_location(self._rank_expr(base), expr) + if expr.attr == "memory_space": + return self._attach_expr_source_location(self._memory_space_expr(base), expr) + if expr.attr == "config": + return self._attach_expr_source_location(self._tile_config_expr(base), expr) if expr.attr == "valid_shape": return self._attach_expr_source_location(self._valid_shape_expr(base), expr) if expr.attr == "strides": return self._attach_expr_source_location(self._strides_expr(base), expr) + if isinstance(base.type, SemanticTileConfigType): + return self._attach_expr_source_location(self._tile_config_attr_expr(base, expr.attr), expr) attr_type = self._attribute_type(base, expr.attr) return self._attach_expr_source_location( SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type), @@ -2973,6 +2995,33 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pad_mode, type=SemanticMetaType(kind="pad_mode"), ) + if expr.namespace in {"BLayout", "pto.BLayout"}: + b_layout = _B_LAYOUT_SYMBOLS.get(expr.name) + if b_layout is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=b_layout, + type=SemanticMetaType(kind="b_layout"), + ) + if expr.namespace in {"SLayout", "pto.SLayout"}: + s_layout = _S_LAYOUT_SYMBOLS.get(expr.name) + if s_layout is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=s_layout, + type=SemanticMetaType(kind="s_layout"), + ) + if expr.namespace in {"PadValue", "pto.PadValue"}: + pad_value = _PAD_VALUE_SYMBOLS.get(expr.name) + if pad_value is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pad_value, + type=SemanticMetaType(kind="pad_value"), + ) if expr.namespace in {"DeinterleaveDist", "pto.DeinterleaveDist"}: dist = _DEINTERLEAVE_DIST_SYMBOLS.get(expr.name) if dist is not None: @@ -3072,6 +3121,72 @@ def _element_type_expr(self, base: SemanticExpr) -> SemanticExpr: ) raise TypeError("unsupported attribute access 'element_type' in TileLang DSL v1") + def _rank_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): + return SemanticLiteralExpr(value=base_type.rank, type=SemanticIndexType()) + raise TypeError("unsupported attribute access 'rank' in TileLang DSL v1") + + def _memory_space_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + return SemanticSymbolExpr( + namespace="pto", + name=MemorySpace.GM.name, + value=MemorySpace.GM, + type=SemanticMetaType(kind="memory_space"), + ) + if isinstance(base_type, SemanticTileType): + memory_space = MemorySpace.UB if base_type.memory_space is None else MemorySpace(base_type.memory_space) + return SemanticSymbolExpr( + namespace="pto", + name=memory_space.name, + value=memory_space, + type=SemanticMetaType(kind="memory_space"), + ) + raise TypeError("unsupported attribute access 'memory_space' in TileLang DSL v1") + + def _tile_config_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, SemanticTileType): + return SemanticLiteralExpr( + value=base_type.config or TileConfig(), + type=SemanticTileConfigType(), + ) + raise TypeError("unsupported attribute access 'config' in TileLang DSL v1") + + def _tile_config_attr_expr(self, base: SemanticExpr, attr: str) -> SemanticExpr: + config = self._try_static_value(base) + if not isinstance(config, TileConfig): + raise TypeError("Tile config metadata must be statically known in TileLang DSL v1") + if attr == "b_layout": + return SemanticSymbolExpr( + namespace="pto", + name=config.b_layout.name, + value=config.b_layout, + type=SemanticMetaType(kind="b_layout"), + ) + if attr == "s_layout": + return SemanticSymbolExpr( + namespace="pto", + name=config.s_layout.name, + value=config.s_layout, + type=SemanticMetaType(kind="s_layout"), + ) + if attr == "s_fractal_size": + return SemanticLiteralExpr( + value=config.s_fractal_size, + type=SemanticScalarType(dtype=i32), + ) + if attr == "pad_value": + return SemanticSymbolExpr( + namespace="pto", + name=config.pad_value.name, + value=config.pad_value, + type=SemanticMetaType(kind="pad_value"), + ) + raise TypeError(f"unsupported TileConfig attribute access '{attr}' in TileLang DSL v1") + def _analyze_as_ptr_method(self, base: SemanticExpr) -> SemanticExpr: base_type = base.type if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): @@ -5145,6 +5260,7 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticTensorViewType", "SemanticPartitionTensorViewType", "SemanticTileBinding", + "SemanticTileConfigType", "SemanticTileType", "SemanticTupleExpr", "SemanticTupleType", diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 6805350e9..e574d9bad 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -324,9 +324,6 @@ def get_pto_call_tier(call_name: str) -> str: "pto.dma_copy", "pto.vreduce", "pto.tile", - "BLayout", - "SLayout", - "PadValue", "SyncOpType", } ) @@ -346,6 +343,9 @@ def get_pto_call_tier(call_name: str) -> str: "pto.mask_b32": BASIC_TIER, "BarrierType": BASIC_TIER, "PadMode": BASIC_TIER, + "BLayout": BASIC_TIER, + "SLayout": BASIC_TIER, + "PadValue": BASIC_TIER, "constexpr": BASIC_TIER, "pto.constexpr": BASIC_TIER, "tile[start:]": BASIC_TIER, diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 02752b277..bd62b70ff 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -194,6 +194,24 @@ class PadMode(str, Enum): PadValue = "PadValue" +class BLayout(str, Enum): + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + +class SLayout(str, Enum): + NONE_BOX = "none_box" + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + +class PadValue(str, Enum): + NULL = "null" + ZERO = "zero" + MAX = "max" + MIN = "min" + + class DeinterleaveDist(str, Enum): DINTLV = "DINTLV" BDINTLV = "BDINTLV" @@ -249,7 +267,110 @@ class TileConfig: @classmethod def from_mapping(cls, mapping: Mapping[str, Any]) -> "TileConfig": - return cls(tuple(sorted(mapping.items()))) + if not isinstance(mapping, Mapping): + raise TypeError("TileConfig.from_mapping expects a mapping") + normalized: dict[str, Any] = {} + for key, value in mapping.items(): + canonical_key = cls._canonical_key(key) + if canonical_key in normalized: + raise ValueError(f"duplicate TileConfig field '{canonical_key}'") + normalized[canonical_key] = cls._normalize_field_value(canonical_key, value) + return cls(tuple(sorted(normalized.items()))) + + @staticmethod + def _canonical_key(key: Any) -> str: + if not isinstance(key, str): + raise TypeError("TileConfig field names must be strings") + aliases = { + "layout": "b_layout", + "blayout": "b_layout", + "b_layout": "b_layout", + "slayout": "s_layout", + "s_layout": "s_layout", + "fractal": "s_fractal_size", + "s_fractal_size": "s_fractal_size", + "pad": "pad_value", + "pad_value": "pad_value", + } + return aliases.get(key, key) + + @staticmethod + def _normalize_field_value(key: str, value: Any) -> Any: + if key == "b_layout": + return TileConfig._normalize_b_layout(value) + if key == "s_layout": + return TileConfig._normalize_s_layout(value) + if key == "s_fractal_size": + if isinstance(value, bool) or not isinstance(value, int): + raise TypeError("TileConfig.s_fractal_size must be an integer") + return value + if key == "pad_value": + return TileConfig._normalize_pad_value(value) + return value + + @staticmethod + def _normalize_b_layout(value: Any) -> BLayout: + if isinstance(value, BLayout): + return value + if isinstance(value, str): + normalized = value.strip().upper().replace("-", "_") + if normalized == "ROW_MAJOR": + return BLayout.ROW_MAJOR + if normalized == "COL_MAJOR": + return BLayout.COL_MAJOR + raise ValueError(f"unsupported TileConfig b_layout value {value!r}") + + @staticmethod + def _normalize_s_layout(value: Any) -> SLayout: + if isinstance(value, SLayout): + return value + if isinstance(value, str): + normalized = value.strip().upper().replace("-", "_") + if normalized == "NONE_BOX": + return SLayout.NONE_BOX + if normalized == "ROW_MAJOR": + return SLayout.ROW_MAJOR + if normalized == "COL_MAJOR": + return SLayout.COL_MAJOR + raise ValueError(f"unsupported TileConfig s_layout value {value!r}") + + @staticmethod + def _normalize_pad_value(value: Any) -> PadValue: + if isinstance(value, PadValue): + return value + if isinstance(value, str): + normalized = value.strip().upper().replace("-", "_") + if normalized == "NULL": + return PadValue.NULL + if normalized == "ZERO": + return PadValue.ZERO + if normalized == "MAX": + return PadValue.MAX + if normalized == "MIN": + return PadValue.MIN + raise ValueError(f"unsupported TileConfig pad_value value {value!r}") + + @property + def b_layout(self) -> BLayout: + value = dict(self.fields).get("b_layout", BLayout.ROW_MAJOR) + return self._normalize_b_layout(value) + + @property + def s_layout(self) -> SLayout: + value = dict(self.fields).get("s_layout", SLayout.NONE_BOX) + return self._normalize_s_layout(value) + + @property + def s_fractal_size(self) -> int: + value = dict(self.fields).get("s_fractal_size", 512) + if isinstance(value, bool) or not isinstance(value, int): + raise TypeError("TileConfig.s_fractal_size must be an integer") + return value + + @property + def pad_value(self) -> PadValue: + value = dict(self.fields).get("pad_value", PadValue.NULL) + return self._normalize_pad_value(value) @dataclass(frozen=True) @@ -372,6 +493,9 @@ def constexpr(value: bool) -> bool: "PAT", "BarrierType", "PadMode", + "BLayout", + "SLayout", + "PadValue", "DeinterleaveDist", "InterleaveDist", "PositionMode", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index b586b0ba3..0d57af82c 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -40,6 +40,7 @@ SemanticGetBufStmt, SemanticIfStmt, SemanticIndexType, + SemanticLiteralExpr, SemanticMemBarStmt, SemanticLowLevelCopyStmt, SemanticMaskType, @@ -56,6 +57,7 @@ SemanticStrictVecscopeStmt, SemanticSymbolExpr, SemanticTensorViewType, + SemanticTileConfigType, SemanticTileType, SemanticVecscopeStmt, SemanticVectorPairStoreStmt, @@ -94,14 +96,17 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "PAT")) self.assertTrue(hasattr(pto, "PadMode")) self.assertTrue(hasattr(pto, "BarrierType")) + self.assertTrue(hasattr(pto, "BLayout")) self.assertTrue(hasattr(pto, "DeinterleaveDist")) self.assertTrue(hasattr(pto, "InterleaveDist")) self.assertTrue(hasattr(pto, "PositionMode")) self.assertTrue(hasattr(pto, "OrderMode")) + self.assertTrue(hasattr(pto, "PadValue")) self.assertTrue(hasattr(pto, "VcvtRoundMode")) self.assertTrue(hasattr(pto, "VcvtSatMode")) self.assertTrue(hasattr(pto, "VcvtPartMode")) self.assertTrue(hasattr(pto, "PostUpdateMode")) + self.assertTrue(hasattr(pto, "SLayout")) self.assertTrue(hasattr(pto, "PIPE")) self.assertTrue(hasattr(pto, "EVENT")) self.assertTrue(hasattr(pto, "si8")) @@ -116,6 +121,9 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PadMode.PadNull.value, "PadNull") self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") self.assertEqual(pto.PadMode.PadValue.value, "PadValue") + self.assertEqual(pto.BLayout.ROW_MAJOR.value, "row_major") + self.assertEqual(pto.SLayout.NONE_BOX.value, "none_box") + self.assertEqual(pto.PadValue.NULL.value, "null") self.assertEqual(pto.DeinterleaveDist.DINTLV.value, "DINTLV") self.assertEqual(pto.DeinterleaveDist.BDINTLV.value, "BDINTLV") self.assertEqual(pto.InterleaveDist.INTLV.value, "INTLV") @@ -142,6 +150,26 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.elements_per_vreg(pto.si8), 256) self.assertEqual(repr(pto.align), "align") + def test_tile_config_exposes_normalized_query_properties(self) -> None: + default_config = pto.TileConfig() + self.assertEqual(default_config.b_layout, pto.BLayout.ROW_MAJOR) + self.assertEqual(default_config.s_layout, pto.SLayout.NONE_BOX) + self.assertEqual(default_config.s_fractal_size, 512) + self.assertEqual(default_config.pad_value, pto.PadValue.NULL) + + config = pto.TileConfig.from_mapping( + { + "layout": "col_major", + "s_layout": "row_major", + "fractal": 16, + "pad": "max", + } + ) + self.assertEqual(config.b_layout, pto.BLayout.COL_MAJOR) + self.assertEqual(config.s_layout, pto.SLayout.ROW_MAJOR) + self.assertEqual(config.s_fractal_size, 16) + self.assertEqual(config.pad_value, pto.PadValue.MAX) + class TileLangDSLSupportMatrixTests(unittest.TestCase): def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: @@ -1657,6 +1685,68 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): self.assertEqual(assign_stmt.value.value, pto.PadMode.PadFirstElem) self.assertEqual(assign_stmt.value.type.kind, "pad_mode") + def test_tile_config_attributes_bind_as_static_metadata(self) -> None: + @pto.vkernel(op="tile_config_attrs_unique", dtypes=[(pto.f16,)]) + def kernel(tile: pto.Tile): + config = tile.config + layout = config.b_layout + secondary = config.s_layout + fractal = config.s_fractal_size + pad = config.pad_value + rank = tile.rank + space = tile.memory_space + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.ROW_MAJOR, + "s_fractal_size": 16, + "pad_value": pto.PadValue.ZERO, + } + ), + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + config_assign, layout_assign, secondary_assign, fractal_assign, pad_assign, rank_assign, space_assign = ( + semantic_kernel.body[:7] + ) + + self.assertIsInstance(config_assign, SemanticAssignStmt) + self.assertIsInstance(config_assign.targets[0].type, SemanticTileConfigType) + self.assertIsInstance(config_assign.value, SemanticLiteralExpr) + self.assertEqual(config_assign.targets[0].value, config_assign.value.value) + self.assertIsInstance(config_assign.value.type, SemanticTileConfigType) + + self.assertIsInstance(layout_assign.value, SemanticSymbolExpr) + self.assertEqual(layout_assign.value.value, pto.BLayout.COL_MAJOR) + self.assertEqual(layout_assign.value.type.kind, "b_layout") + + self.assertIsInstance(secondary_assign.value, SemanticSymbolExpr) + self.assertEqual(secondary_assign.value.value, pto.SLayout.ROW_MAJOR) + self.assertEqual(secondary_assign.value.type.kind, "s_layout") + + self.assertIsInstance(fractal_assign.value, SemanticLiteralExpr) + self.assertEqual(fractal_assign.value.value, 16) + self.assertIsInstance(fractal_assign.targets[0].type, SemanticScalarType) + self.assertEqual(fractal_assign.targets[0].type.dtype, pto.i32) + + self.assertIsInstance(pad_assign.value, SemanticSymbolExpr) + self.assertEqual(pad_assign.value.value, pto.PadValue.ZERO) + self.assertEqual(pad_assign.value.type.kind, "pad_value") + + self.assertEqual(rank_assign.value.value, 2) + self.assertIsInstance(rank_assign.targets[0].type, SemanticIndexType) + + self.assertIsInstance(space_assign.value, SemanticSymbolExpr) + self.assertEqual(space_assign.value.value, pto.MemorySpace.UB) + self.assertEqual(space_assign.value.type.kind, "memory_space") + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) From 4157241d6b7ad65491b288ef21db6af0b6b443c7 Mon Sep 17 00:00:00 2001 From: qukelin Date: Fri, 10 Apr 2026 20:42:43 +0800 Subject: [PATCH 046/192] support expand tload/tstore to tile lib in ptoas --- docs/designs/ptoas-tileop-expand-design.md | 153 +++++--- docs/designs/tilelang-st-framework.md | 332 ++++++++++++++++++ include/PTO/Transforms/Passes.td | 20 +- lib/PTO/Transforms/ExpandTileOp.cpp | 196 ++++++++--- lib/PTO/Transforms/FoldTileBufIntrinsics.cpp | 320 ++++++++++++++++- test/tilelang_st/npu/a5/src/st/CMakeLists.txt | 79 +++++ .../npu/a5/src/st/testcase/CMakeLists.txt | 97 +++++ .../a5/src/st/testcase/tadd/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tadd/compare.py | 60 ++++ .../npu/a5/src/st/testcase/tadd/gen_data.py | 33 ++ .../npu/a5/src/st/testcase/tadd/launch.cpp | 33 ++ .../npu/a5/src/st/testcase/tadd/main.cpp | 144 ++++++++ .../npu/a5/src/st/testcase/tadd/tadd.pto | 140 ++++++++ test/tilelang_st/script/run_st.py | 251 +++++++++++++ .../python/tilelang_dsl/expand_helper.py | 47 +++ 15 files changed, 1808 insertions(+), 106 deletions(-) create mode 100644 docs/designs/tilelang-st-framework.md create mode 100644 test/tilelang_st/npu/a5/src/st/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto create mode 100755 test/tilelang_st/script/run_st.py diff --git a/docs/designs/ptoas-tileop-expand-design.md b/docs/designs/ptoas-tileop-expand-design.md index 4a71b25a8..8e9967ed6 100644 --- a/docs/designs/ptoas-tileop-expand-design.md +++ b/docs/designs/ptoas-tileop-expand-design.md @@ -417,7 +417,7 @@ PTOAS 编译器的输入可以是 Tile 指令、向量指令、或两者的混 ↓ Inline ← 将模板函数体 inline 到调用点 ↓ - Fold TileBuf Intrinsics ← 折叠 tile_buf_addr / tile_valid_rows / tile_valid_cols + Fold TileBuf Intrinsics ← 折叠 tile_buf / tensor_view intrinsic,解析到具体值 ↓ VF Fusion ← 合并相邻向量循环,消除中间 UB 读写 ↓ @@ -428,20 +428,20 @@ Tile 指令到向量指令的展开由三个 pass 协作完成: 1. **Expand TileOp**:核心 pass。调用 TileLang Python DSL 实例化模板库,生成以 `tile_buf` 为参数的向量实现函数,将原 Tile op 替换为对该函数的 `func.call`。 2. **Inline**:将模板函数体 inline 到调用点,使模板函数的 `tile_buf` 形参与调用点的实际 `tile_buf` 值绑定。 -3. **Fold TileBuf Intrinsics**:折叠 inline 后留下的 `pto.tile_buf_addr`、`pto.tile_valid_rows`、`pto.tile_valid_cols` 等 intrinsic,将 `tile_buf` 的静态属性(地址、shape、布局)折叠为具体的 memref 和常量。 +3. **Fold TileBuf Intrinsics**:折叠 inline 后留下的 tile_buf 系列(`pto.tile_buf_addr`、`pto.tile_valid_rows`、`pto.tile_valid_cols`)和 tensor_view 系列(`pto.tensor_view_addr`、`pto.get_tensor_view_dim`、`pto.get_tensor_view_stride`)intrinsic,将 `tile_buf` / `partition_tensor_view` 的属性折叠为具体的 memref、常量和 SSA 值。 ### 3.2 Expand TileOp Pass 的工作流程 以编译时遇到 `pto.tadd` 为例,Expand TileOp pass 的处理步骤如下: ``` -Step 1: 识别 Tile Op -─────────────────── - 遍历函数体中所有 Tile op(pto.tadd, pto.tsub, ...) - 遇到 pto.tadd ins(%a, %b) outs(%c) - 从所有操作数的 tile_buf 类型提取属性: - dtype=f32, rows=16, cols=64, v_row=16, v_col=64, - blayout=row_major, slayout=none_box, fractal=512, pad=0 +Step 1: 识别 Tile Op 并分类操作数 +──────────────────────────────── + 遍历函数体中所有 Tile op(pto.tadd, pto.tload, ...) + 每个操作数按 IR 类型分为三类: + Tile — TileBufType(如 pto.tadd 的输入/输出 tile_buf) + View — MemRefType(如 pto.tload 的 src,由 PTOViewToMemref 降级的 partition_tensor_view) + Scalar — 标量类型(如 pto.tadds 的 scalar 操作数) Step 2: 构造 Specialization Key + 查询缓存 ────────────────────────────────────────── @@ -451,66 +451,79 @@ Step 2: 构造 Specialization Key + 查询缓存 Step 3: 实例化模板(缓存未命中时执行) ───────────────────────────────────── - 调用 TileLang Python DSL,传入 op 名称和各操作数的 tile_buf 类型信息 - Python DSL 查找匹配的 @vkernel 模板,填入具体 tile_buf 参数进行特化 - 输出实例化后的 MLIR 函数(以 tile_buf 为参数,内含向量循环体) - 解析 MLIR 文本,克隆函数到目标 Module,写入缓存 + 调用 TileLang Python DSL,传入 op 名称和各操作数的类型信息 + Python DSL 查找匹配的 @vkernel 模板,填入具体参数进行特化 + 输出实例化后的 MLIR 函数,解析文本,克隆到目标 Module,写入缓存 Step 4: 生成调用并替换原 Tile Op ─────────────────────────────── - 在原 Tile op 位置插入 func.call @__pto_tilelang_tadd_f32_16_64(%a, %b, %c) - 操作数直接传递(类型均为 tile_buf,无需桥接转换) + 在原 Tile op 位置插入 func.call @__pto_tilelang_...(%a, %b, %c) + - Tile 操作数:类型一致,直接传递 + - View 操作数:调用方类型为 memref,模板参数类型为 partition_tensor_view, + 插入 builtin.unrealized_conversion_cast 桥接(由后续 FoldTileBufIntrinsics 消除) + - Scalar 操作数:直接传递 删除原 Tile op ``` #### 3.2.1 Specialization Key 与缓存 -模板展开本质上是一个特化过程。当同一个 module 中存在多个相同类型的 Tile op(如多处 `pto.tadd` 且所有 `tile_buf` 操作数类型完全相同),应复用已实例化的结果而非重复展开。 +模板展开本质上是一个特化过程。当同一个 module 中存在多个相同类型的 Tile op(如多处 `pto.tadd` 且所有操作数类型完全相同),应复用已实例化的结果而非重复展开。 -**重要**:SpecKey 必须基于 **所有操作数** 的 `tile_buf` 类型构建,而不仅仅是第一个操作数。因为同一个 op 的不同操作数可能有不同的类型(如不同的 dtype 或 shape),仅用第一个操作数无法区分这些情况。 +**重要**:SpecKey 必须基于 **所有操作数** 的类型构建,而不仅仅是第一个操作数。因为同一个 op 的不同操作数可能有不同的类型(如不同的 dtype 或 shape),仅用第一个操作数无法区分这些情况。 -Expand TileOp pass 维护一个实例化缓存,key 包含以下字段: +操作数按 IR 类型分为三类,每类参与 SpecKey 的字段不同: -| Key 字段 | 说明 | -|----------|------| -| `op_name` | Tile op 名称(如 `tadd`) | -| `operand_types` | **所有操作数**的 tile_buf 类型签名,每个操作数包含以下信息 | -| ├─ `dtype` | 元素数据类型(如 `f32`) | -| ├─ `shape` | Tile 的静态 shape(如 `(16, 64)`) | -| └─ `config` | blayout、slayout、fractal、pad 等配置 | +| 操作数类型 | IR 类型 | 参与 SpecKey 的字段 | 不参与 SpecKey 但传给 Python DSL 的字段 | +|-----------|---------|--------------------|-----------------------------------------| +| **Tile** | `TileBufType` | `dtype` + `shape` + `memorySpace` + `config`(blayout/slayout/fractal/pad) | — | +| **View** | `MemRefType`(降级后的 `PartitionTensorViewType`) | `dtype` | `shape`、`strides`、`memorySpace`(仅用于约束检查) | +| **Scalar** | 标量类型 | `dtype` | — | -`valid_shape` **不参与** key——因为它可能是动态的,作为运行时值在 inline 后通过 `pto.tile_valid_rows`/`pto.tile_valid_cols` 提取。相同 `(op, operand_types)` 但不同 `valid_shape` 的 Tile op 可以共享同一份实例化结果。 +**View 操作数的特化策略**:View 对应的模板参数类型为 `!pto.partition_tensor_view`,维度全部动态,shape/strides 通过 intrinsic 在运行时查询。因此不同 view shape 的 Tile op 可以共享同一份模板实例——`shape`/`strides`/`memorySpace` 不参与 SpecKey 的判等和 hash。这些字段通过 `--operand-specs` JSON 传给 Python DSL 的 `expand_helper`,注入到约束上下文中(如 `src.strides[4] == 1`),但不影响模板代码生成。 + +**Tile 操作数的排除字段**:`valid_shape` 不参与 SpecKey——因为它可能是动态的,作为运行时值在 inline 后通过 `pto.tile_valid_rows`/`pto.tile_valid_cols` 提取。相同 `(op, operand_types)` 但不同 `valid_shape` 的 Tile op 可以共享同一份实例化结果。 #### 3.2.2 模板实例化过程 Expand TileOp 通过调用 Python 子进程来实例化模板。具体流程: -1. **调用 Python helper**:`python3 -m tilelang_dsl.expand_helper`,传入 op 名称、各操作数的 dtype/shape/memory_space 等参数。 +1. **调用 Python helper**:`python3 -m tilelang_dsl.expand_helper --op pto. --operand-specs `,其中 JSON 描述每个操作数的类型信息。 2. **Python 端处理**: - 扫描模板目录下的 `.py` 文件,查找标注了 `@pto.vkernel` 装饰器的模板函数 - 按 `op` 名称和 `dtype` 签名匹配模板 - - 对所有 `pto.Tile` 参数使用给定的 shape 和 memory_space 进行特化 + - 对 `pto.Tile` 参数使用给定的 shape 和 memory_space 进行特化 + - 对 `pto.PartitionTensorView` 参数,将 shape/strides 注入约束上下文用于前置条件检查,但不影响模板特化(参数类型保持全动态) - 输出特化后的 MLIR 文本 3. **C++ 端处理**: - 解析 MLIR 文本为 `ModuleOp` - 提取 `func.func`,克隆到目标 Module 末尾 - - 重命名为 `__pto_tilelang____`(如 `__pto_tilelang_tadd_f32_16_64`),设为 `private` 可见性 + - 重命名为 `__pto_tilelang__tile____view__...`(Tile 操作数拼 shape,View/Scalar 只拼 dtype),设为 `private` 可见性 - 存入 specCache **关键约束**:Python DSL 实例化输出的函数需要满足以下要求: -1. **参数类型为 `!pto.tile_buf`**,而非 memref。DSL 在实例化时将具体的元素类型、静态 shape、布局配置等信息编码进 `tile_buf` 类型参数。 -2. **函数必须带有 `pto.tilelang.instance` 属性**(UnitAttr)。Inline pass 通过此属性识别需要内联的模板实例函数,而非依赖函数名前缀。 +1. **参数类型**可以是 `!pto.tile_buf`、`!pto.partition_tensor_view` 或标量类型。DSL 在实例化时将 Tile 参数的元素类型、静态 shape、布局配置等信息编码进 `tile_buf` 类型;View 参数保持全动态维度(`!pto.partition_tensor_view`)。 +2. **函数必须带有 `pto.tilelang.instance` 属性**(UnitAttr)。Inline pass 通过此属性识别需要内联的模板实例函数。 + +函数体内部通过以下 intrinsic 提取信息: + +**tile_buf 系列**(从 `!pto.tile_buf` 提取): + +| Intrinsic | 功能 | 输出类型 | +|-----------|------|----------| +| `pto.tile_buf_addr` | 提取数据区域的 memref 指针 | `memref, #pto.address_space<...>>` | +| `pto.tile_valid_rows` | 提取有效行数 | `index` | +| `pto.tile_valid_cols` | 提取有效列数 | `index` | -函数体内部通过以下 intrinsic 从 `tile_buf` 中提取信息: +**tensor_view 系列**(从 `!pto.partition_tensor_view` 提取): | Intrinsic | 功能 | 输出类型 | |-----------|------|----------| -| `pto.tile_buf_addr` | 从 tile_buf 提取数据区域的 memref 指针 | `memref, #pto.address_space<...>>` | -| `pto.tile_valid_rows` | 从 tile_buf 提取有效行数 | `index` | -| `pto.tile_valid_cols` | 从 tile_buf 提取有效列数 | `index` | +| `pto.tensor_view_addr` | 提取 memref/ptr 基地址 | `memref<...>` 或 `!pto.ptr<...>` | +| `pto.get_tensor_view_dim` | 按维度索引提取 shape 大小 | `index` | +| `pto.get_tensor_view_stride` | 按维度索引提取 stride | `index` | -这样设计的好处是:Expand TileOp pass 的调用点不需要做任何类型桥接,直接将 `tile_buf` 操作数透传给实例化的函数。类型转换和属性提取的工作统一在后续的 Fold pass 中处理。 +对于 Tile 操作数,Expand TileOp 直接将 `tile_buf` 透传。对于 View 操作数,调用方类型为 `memref`,模板参数类型为 `!pto.partition_tensor_view`,因此 Expand TileOp 在调用点插入 `builtin.unrealized_conversion_cast` 桥接。类型转换和 intrinsic 折叠统一在后续的 Fold pass 中处理。 ### 3.3 实例化模板函数的 IR 结构 @@ -650,10 +663,11 @@ func.func @TADD(%a: !pto.tile_buf<...>, %b: !pto.tile_buf<...>, %c: !pto.tile_bu #### 3.4.4 经过 Fold TileBuf Intrinsics 后 -Fold pass 通过严格的模式匹配,将 `pto.tile_buf_addr`、`pto.tile_valid_rows`、`pto.tile_valid_cols` -解析回 `MemrefToTileBuf` pass 在调用点构造的具体 SSA 值。 +Fold pass 处理两族 intrinsic,通过严格的模式匹配将它们解析回调用点的具体 SSA 值。 -**严格模式匹配**:每一个被折叠的 intrinsic,其 `tile_buf` 操作数必须由如下固定链定义 +##### tile_buf 系列折叠 + +每一个被折叠的 tile_buf intrinsic,其 `tile_buf` 操作数必须由如下固定链定义 (由 `MemrefToTileBuf` pass 保证),否则 pass 直接报错并失败: ```mlir @@ -683,6 +697,44 @@ Fold pass 通过严格的模式匹配,将 `pto.tile_buf_addr`、`pto.tile_vali 若是动态值(`v_row=?`),折叠为 `bind_tile` 的 **第二个操作数**(`valid_row`,已经是 `index` 类型)。 - `pto.tile_valid_cols %a` → 同理,使用 `validShape[1]` 或 `bind_tile` 的 **第三个操作数**。 +##### tensor_view 系列折叠 + +每一个被折叠的 tensor_view intrinsic,其 `partition_tensor_view` 操作数必须由如下固定链定义 +(由 `ExpandTileOp` 和 `PTOViewToMemref` pass 保证),否则 pass 直接报错并失败: + +```mlir +%rc = memref.reinterpret_cast %arg0 + to offset: [0], sizes: [%c1, %c1, %c1, %c16, %c64], + strides: [%c1024, %c1024, %c1024, %c64, %c1] + : memref → memref, gm> + +%sv = memref.subview %rc [0,0,0,0,0] [1,1,1,16,64] [1,1,1,1,1] + : → memref<1x1x1x16x64xf32, strided<[?,?,?,?,?], offset:?>, gm> + +%tv = builtin.unrealized_conversion_cast %sv + : memref<...> → !pto.partition_tensor_view<...> +``` + +也即:`partition_tensor_view ← unrealized_conversion_cast ← memref.subview ← memref.reinterpret_cast`。 + +pass 贯穿整条链,**一步到位**折叠到最终结果,不生成中间的 `memref.dim`、`memref.extract_strided_metadata` 或 `pto.castptr %subview`: + +- `pto.get_tensor_view_dim %tv, %cN` → + - subview 结果类型 shape[N] 是静态的:折叠为 `arith.constant`(如 dim 3 → `arith.constant 16`) + - shape[N] 是动态的:取 subview 的 `getMixedSizes()[N]`(可能追溯到 reinterpret_cast 的 size operand) + +- `pto.get_tensor_view_stride %tv, %cN` → + 直接取 reinterpret_cast 的 stride operand(通过 `getMixedStrides()[N]`)。 + 若 subview 的 stride[N] 不为 1,则生成 `arith.muli(rc_stride, sv_stride)`。 + reinterpret_cast 的 stride 可以是静态属性(生成 `arith.constant`)或动态 SSA 值(直接复用)。 + +- `pto.tensor_view_addr %tv` → + - subview 和 reinterpret_cast 的 offset 均为 0:折叠为 `pto.castptr %arg0`(直接用 base memref) + - 有非零 offset:折叠为 `pto.addptr(pto.castptr %arg0, linear_offset)`, + 其中 `linear_offset = rc_offset + sum(sv_offset[i] * rc_stride[i])` + +##### 通用规则 + **跳过 TileLang 模板实例**:被 `PTOInlineLibCall` 内联完且作为 dead callee 删除之前, 带 `pto.tilelang.instance` 属性的私有模板函数仍可能保留在 module 中。这些函数体内的 `pto.tile_buf_addr` 等 intrinsic 直接作用在 `tile_buf` 类型的 BlockArgument 上, @@ -799,12 +851,29 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): ### 4.3 PTOAS 编译器:Fold TileBuf Intrinsics Pass +**tile_buf 系列**: + | 工作项 | 说明 | |--------|------| | 严格模式匹配 | 要求 `tile_buf` 由 `unrealized_conversion_cast ← pto.bind_tile` 链定义,否则 emit error 并 fail pass | -| `tile_buf_addr` 折叠 | 替换为 `bind_tile.getSource()`(即 `pto.pointer_cast` 的静态布局 memref),绕过 `bind_tile` 产出的动态 offset 布局,避免下游 VPTO 后端无法处理 `offset: ?` | -| 结果类型自适应 | 若 `tile_buf_addr` 声明类型与 source memref 实际布局不一致,就地更新结果类型(下游向量算子对 strided 布局多态) | -| `tile_valid_rows/cols` 折叠 | 优先按 `TileBufType.validShape` 静态折叠为 `arith.constant`;动态时取 `bind_tile` 的 `valid_row`/`valid_col` 操作数(均为 `index` 类型,无需 cast) | +| `tile_buf_addr` 折叠 | 替换为 `bind_tile.getSource()`(即 `pto.pointer_cast` 的静态布局 memref),绕过 `bind_tile` 产出的动态 offset 布局 | +| 结果类型自适应 | 若 `tile_buf_addr` 声明类型与 source memref 实际布局不一致,就地更新结果类型 | +| `tile_valid_rows/cols` 折叠 | 优先按 `TileBufType.validShape` 静态折叠为 `arith.constant`;动态时取 `bind_tile` 的 `valid_row`/`valid_col` 操作数 | + +**tensor_view 系列**: + +| 工作项 | 说明 | +|--------|------| +| 严格模式匹配 | 要求 `partition_tensor_view` 由 `unrealized_conversion_cast ← memref.subview ← memref.reinterpret_cast` 链定义,否则 emit error 并 fail pass | +| `tensor_view_addr` 折叠 | 贯穿 subview → reinterpret_cast 链,折叠为 `pto.castptr %base_memref`;有非零 offset 时生成 `pto.addptr` | +| `get_tensor_view_dim` 折叠 | 静态 shape 维度折叠为 `arith.constant`;动态维度取 subview 的 `getMixedSizes()` operand | +| `get_tensor_view_stride` 折叠 | 直接取 reinterpret_cast 的 stride operand(`getMixedStrides()`),乘以 subview stride(通常为 1 可短路) | +| Dead op 清理 | 折叠完成后清理无 user 的 `unrealized_conversion_cast`、`memref.subview`、`memref.reinterpret_cast` | + +**通用**: + +| 工作项 | 说明 | +|--------|------| | 跳过模板实例 | 检测 `pto.tilelang.instance` 属性,跳过 `PTOInlineLibCall` 删除前残留的私有模板函数 | ### 4.4 测试与文档 diff --git a/docs/designs/tilelang-st-framework.md b/docs/designs/tilelang-st-framework.md new file mode 100644 index 000000000..8a72a4b08 --- /dev/null +++ b/docs/designs/tilelang-st-framework.md @@ -0,0 +1,332 @@ +# TileLang ST 精度验证框架 + +## 概述 + +TileLang ST(System Test)框架用于在 Ascend NPU 硬件或仿真器上端到端验证 TileLang DSL 模板库生成的 kernel 精度。框架参考 pto-isa 的 ST 目录结构和运行流程,但针对 TileLang 的编译路径(`.pto → LLVM IR → .o`,而非 pto-isa 的 `.cpp → -xcce → .o`)做了适配。 + +### 与 pto-isa ST 的关键差异 + +| | pto-isa ST | TileLang ST | +|--|-----------|-------------| +| kernel 源码 | 手写 C++(`kernel.cpp`) | PTO DSL(`tadd.pto`) | +| kernel 编译 | `bisheng -xcce kernel.cpp` | `ptoas .pto → .ll` + `bisheng -x ir .ll → .o` | +| 精度比较 | C++ `ResultCmp()`(GTest) | Python `np.allclose`(`compare.py`) | +| 多 case 支持 | 单文件多 GTest TEST_F | 单 `.pto` 多 kernel 函数 + case table | + +### 执行流程 + +``` +run_st.py + ├── set_env # 设置 ASCEND / simulator 环境变量 + ├── cmake + make # ptoas→.ll → bisheng→.o → link .so → build 可执行文件 + ├── gen_data.py # numpy 生成 input + golden(per-case 子目录) + ├── ./tadd [case] # 运行 kernel,写 output.bin(per-case 子目录) + └── compare.py # np.allclose 逐 case 比较 golden vs output +``` + +## 目录结构 + +``` +test/tilelang_st/ +├── script/ +│ └── run_st.py # 统一入口脚本 +└── npu/ + └── a5/ # SoC 架构 + └── src/st/ + ├── CMakeLists.txt # 顶层 CMake(编译器/环境配置) + └── testcase/ + ├── CMakeLists.txt # pto_tilelang_vec_st() 宏定义 + op 注册 + └── tadd/ # 每个 op 一个目录 + ├── CMakeLists.txt # 一行:pto_tilelang_vec_st(tadd) + ├── tadd.pto # kernel DSL(可包含多个函数) + ├── launch.cpp # kernel 声明 + launch wrapper + ├── main.cpp # host driver(case table 驱动) + ├── gen_data.py # 数据生成 + └── compare.py # 精度比较 +``` + +## 快速上手 + +### 运行已有测试 + +```bash +# 在 NPU 上跑 tadd 全部 case +python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd + +# 在仿真器上跑 +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd + +# 只跑某个 case +python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd -c f32_16x64 + +# 跳过编译(已有 build 产物时) +python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd -w +``` + +### run_st.py 参数 + +| 参数 | 说明 | 示例 | +|------|------|------| +| `-r, --run-mode` | 运行模式 | `npu` 或 `sim` | +| `-v, --soc-version` | 架构版本 | `a5` | +| `-t, --testcase` | op 名称 | `tadd` | +| `-c, --case` | 指定单个 case(可选) | `f32_16x64` | +| `-p, --ptoas-bin` | ptoas 路径(可选,默认自动查找) | `/path/to/ptoas` | +| `-w, --without-build` | 跳过编译 | — | + +ptoas 路径查找顺序:`-p` 参数 → `PTOAS_BIN` 环境变量 → 从脚本位置向上遍历 `build/bin/ptoas`。 + +## 新增一个 op 测试 + +以新增 `tsub` 为例。 + +### 第 1 步:创建目录和文件 + +```bash +mkdir test/tilelang_st/npu/a5/src/st/testcase/tsub +``` + +需要创建 6 个文件,下面逐一说明。 + +### 第 2 步:编写 kernel(tsub.pto) + +单个 `.pto` 文件中包含所有 case 对应的 kernel 函数,函数名格式为 `@TSUB__x`: + +```mlir +module { + // Case 0: f32 16x64 + func.func @TSUB_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // make_tensor_view × 3, partition_view × 3, alloc_tile × 3 + // tload × 2, tsub, tstore + ... + return + } + + // Case 1: f32 32x32 + func.func @TSUB_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + ... + return + } +} +``` + +### 第 3 步:编写 launch wrapper(launch.cpp) + +每个 kernel 函数需要一个 `__global__` 声明和一个 `Launch*` C++ wrapper: + +```cpp +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#include +#include +#include + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +__global__ AICORE void TSUB_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream) { + TSUB_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +__global__ AICORE void TSUB_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTSUB_f32_32x32(float *a, float *b, float *c, void *stream) { + TSUB_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} +``` + +**注意:** `__global__`、`AICORE`、`__gm__`、`<<<>>>` 是 CCE 扩展语法,本地 clang 会报错,这是预期行为——launch.cpp 由 bisheng `-xcce` 编译。 + +### 第 4 步:编写 host driver(main.cpp) + +使用 case table 驱动,每个 case 从独立子目录读写数据: + +```cpp +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", \ + #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + return 1; \ + } \ + } while (0) + +// launch wrappers +void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTSUB_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; // 与 gen_data.py / compare.py 的 case name 一致 + LaunchFn launch; + size_t rows; + size_t cols; + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTSUB_f32_16x64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTSUB_f32_32x32, 32, 32, sizeof(float)}, +}; + +// main() 中循环 kCases,对每个 case: +// 1. 从 .//input1.bin, input2.bin 读数据 +// 2. H2D → launch kernel → D2H +// 3. 写 .//output.bin +// 支持 ./tsub [case_name] 过滤单个 case +``` + +完整实现参考 `testcase/tadd/main.cpp`。 + +### 第 5 步:数据生成与精度比较 + +**gen_data.py** — 为每个 case 生成独立子目录的 `input1.bin`、`input2.bin`、`golden.bin`: + +```python +import os +import numpy as np + +np.random.seed(42) + +CASES = [ + {"name": "f32_16x64", "dtype": np.float32, "shape": (16, 64)}, + {"name": "f32_32x32", "dtype": np.float32, "shape": (32, 32)}, +] + +for case in CASES: + case_dir = case["name"] + os.makedirs(case_dir, exist_ok=True) + + input1 = np.random.randint(1, 10, size=case["shape"]).astype(case["dtype"]) + input2 = np.random.randint(1, 10, size=case["shape"]).astype(case["dtype"]) + golden = (input1 - input2).astype(case["dtype"], copy=False) # tsub: 减法 + + input1.tofile(os.path.join(case_dir, "input1.bin")) + input2.tofile(os.path.join(case_dir, "input2.bin")) + golden.tofile(os.path.join(case_dir, "golden.bin")) +``` + +**compare.py** — 逐 case 比较,支持 `python compare.py [case_name]` 过滤: + +```python +import sys, os +import numpy as np + +CASES = [ + {"name": "f32_16x64", "dtype": np.float32, "eps": 1e-6}, + {"name": "f32_32x32", "dtype": np.float32, "eps": 1e-6}, +] + +def compare_bin(golden_path, output_path, dtype, eps): + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + return False + return np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + +case_filter = sys.argv[1] if len(sys.argv) > 1 else None +all_passed = True +for case in CASES: + if case_filter and case["name"] != case_filter: + continue + ok = compare_bin( + os.path.join(case["name"], "golden.bin"), + os.path.join(case["name"], "output.bin"), + case["dtype"], case["eps"]) + if not ok: + all_passed = False +if not all_passed: + sys.exit(2) +``` + +### 第 6 步:CMake 注册 + +**testcase/tsub/CMakeLists.txt**(一行): + +```cmake +pto_tilelang_vec_st(tsub) +``` + +**testcase/CMakeLists.txt** 中注册: + +```cmake +set(ALL_TESTCASES + tadd + tsub # <-- 新增 +) +``` + +## 在已有 op 下新增 case + +不需要修改 CMake。只需要同步修改 4 个文件: + +| 步骤 | 文件 | 修改内容 | +|------|------|---------| +| 1 | `tadd.pto` | 新增 `func.func @TADD__(...)` | +| 2 | `launch.cpp` | 新增 `__global__` 声明 + `LaunchTADD__` wrapper | +| 3 | `main.cpp` | `kCases[]` 数组新增一行 | +| 4 | `gen_data.py` + `compare.py` | `CASES` 列表各新增一行 | + +### 命名约定 + +- kernel 函数名:`@__x`,例如 `@TADD_f32_16x64` +- launch wrapper:`Launch__x` +- case name / 子目录名:`_x`,例如 `f32_16x64` +- 三处 case 列表(`main.cpp` `kCases[]`、`gen_data.py` `CASES`、`compare.py` `CASES`)必须保持一致 + +## CMake 编译流水线 + +`pto_tilelang_vec_st(NAME)` 宏定义在 `testcase/CMakeLists.txt` 中,完成 4 步编译: + +``` +Step 1: ptoas NAME.pto → NAME_kernel.ll (LLVM IR) + ptoas --pto-arch=a5 --pto-backend=vpto + --enable-tile-op-expand --vpto-emit-hivm-llvm + NAME.pto -o NAME_kernel.ll + +Step 2: bisheng NAME_kernel.ll → NAME_kernel.o (object file) + bisheng --target=hiipu64-hisilicon-cce + -march=dav-c310-vec + --cce-aicore-arch=dav-c310-vec + --cce-aicore-only + -c -x ir NAME_kernel.ll -o NAME_kernel.o + +Step 3: launch.cpp (-xcce) + NAME_kernel.o → libNAME_kernel.so + bisheng -xcce launch.cpp + link NAME_kernel.o + +Step 4: main.cpp (-xc++) → NAME (可执行文件) + link libNAME_kernel.so + ACL runtime libraries +``` + +## 精度比较说明 + +使用 `np.allclose(golden, output, atol=eps, rtol=eps)` 进行比较。不同 dtype 建议的 eps 值: + +| dtype | 建议 eps | +|-------|---------| +| float32 | 1e-6 | +| float16 | 1e-3 | +| bfloat16 | 1e-2 | +| int8/int32 | 0(精确匹配) | + +比较失败时会输出 max diff、出错位置和 golden/output 值,便于定位问题。 diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 0ec3877af..6c07552d0 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -276,18 +276,30 @@ def ExpandTileOp : Pass<"pto-expand-tile-op", "ModuleOp"> { } def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::FuncOp"> { - let summary = "Fold tile_buf_addr / tile_valid_rows / tile_valid_cols"; + let summary = "Fold structured-view intrinsics after template inlining"; let description = [{ After TileLang DSL template functions are inlined, the IR contains - pto.tile_buf_addr, pto.tile_valid_rows, and pto.tile_valid_cols ops - whose tile_buf operands are now bound to concrete values. + structured-view intrinsics whose operands are now bound to concrete values. This pass resolves them: - pto.tile_buf_addr → replaced by the underlying pto.bind_tile source memref, or pto.castptr when the requested result type is !pto.ptr - pto.tile_valid_rows → folded to arith.constant if v_row is static, - or replaced with the dynamic index value from tile_buf + or replaced with the dynamic index value from bind_tile - pto.tile_valid_cols → same as above for v_col + + tensor_view family: + - pto.tensor_view_addr → traces through unrealized_conversion_cast → + subview → reinterpret_cast, then folds to the base memref or to + pto.castptr/pto.addptr on the base memref + - pto.get_tensor_view_dim → folded to arith.constant for static subview + sizes, or to the subview size SSA operand for dynamic dims + - pto.get_tensor_view_stride → folded to the reinterpret_cast stride + operand, multiplied by the subview stride when needed + + Dead unrealized_conversion_cast, memref.subview, and + memref.reinterpret_cast ops exposed by folding are cleaned up after the + rewrite. }]; let constructor = "mlir::pto::createFoldTileBufIntrinsicsPass()"; let dependentDialects = [ diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 9d5c823e8..4eb830a37 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -72,28 +72,48 @@ namespace pto { namespace { // ============================================================================ -// OperandTypeInfo: captures the tile_buf type info for one operand. +// OperandTypeInfo: describes one operand for template specialization. +// +// Three kinds of operands: +// Tile — from TileBufType. dtype + shape + memorySpace + config +// all participate in the specialization key (SpecKey). +// View — from MemRefType (lowered PartitionTensorViewType). Only dtype +// participates in SpecKey — the template is fully dynamic so +// shape/strides/memorySpace don't affect code generation. They +// are carried here solely for JSON serialization to the Python +// DSL for constraint checking. +// Scalar — from a scalar element type. Only dtype participates in SpecKey. // ============================================================================ -enum class OperandKind { - Tile, - Scalar, -}; +enum class OperandKind { Tile, View, Scalar }; struct OperandTypeInfo { OperandKind kind = OperandKind::Tile; - std::string dtype; - SmallVector shape; + std::string dtype; // all kinds: element type string (e.g. "f32") + + // --- Tile-only (TileBufType) --- + SmallVector tileShape; + std::string tileMemorySpace; // "ub" or "gm" int32_t blayout = 0; int32_t slayout = 0; int32_t fractal = 0; int32_t pad = 0; - std::string memorySpace; + // --- View-only (MemRefType) — for JSON / constraint checking only --- + SmallVector viewShape; + SmallVector viewStrides; + std::string viewMemorySpace; // "gm" or "ub" + + /// Equality for SpecKey caching — only compares fields relevant to each kind. bool operator==(const OperandTypeInfo &rhs) const { - return kind == rhs.kind && dtype == rhs.dtype && shape == rhs.shape && - blayout == rhs.blayout && slayout == rhs.slayout && - fractal == rhs.fractal && pad == rhs.pad && - memorySpace == rhs.memorySpace; + if (kind != rhs.kind || dtype != rhs.dtype) + return false; + if (kind == OperandKind::Tile) + return tileShape == rhs.tileShape && + tileMemorySpace == rhs.tileMemorySpace && + blayout == rhs.blayout && slayout == rhs.slayout && + fractal == rhs.fractal && pad == rhs.pad; + // View and Scalar: dtype alone is sufficient for template caching. + return true; } }; @@ -115,12 +135,14 @@ struct SpecKeyInfo : public llvm::DenseMapInfo { static unsigned getHashValue(const SpecKey &key) { unsigned h = llvm::hash_value(key.opName); for (const auto &op : key.operands) { - h = llvm::hash_combine(h, static_cast(op.kind), op.dtype, - op.blayout, op.slayout, - op.fractal, op.pad); - h = llvm::hash_combine(h, op.memorySpace); - for (int64_t d : op.shape) - h = llvm::hash_combine(h, d); + h = llvm::hash_combine(h, static_cast(op.kind), op.dtype); + if (op.kind == OperandKind::Tile) { + h = llvm::hash_combine(h, op.tileMemorySpace, op.blayout, + op.slayout, op.fractal, op.pad); + for (int64_t d : op.tileShape) + h = llvm::hash_combine(h, d); + } + // View/Scalar: only kind + dtype contribute to hash. } return h; } @@ -155,30 +177,51 @@ static std::string getMemorySpaceString(pto::TileBufType tbTy) { return "ub"; } -static std::optional -buildOperandTypeInfo(pto::TileBufType tbTy) { - OperandTypeInfo info; - info.kind = OperandKind::Tile; - info.dtype = getDtypeString(tbTy.getElementType()); - if (info.dtype.empty()) - return std::nullopt; - info.shape.assign(tbTy.getShape().begin(), tbTy.getShape().end()); - info.memorySpace = getMemorySpaceString(tbTy); - if (auto config = tbTy.getConfigAttr()) { - info.blayout = static_cast(config.getBLayout().getValue()); - info.slayout = static_cast(config.getSLayout().getValue()); - info.fractal = config.getSFractalSize() - ? static_cast(config.getSFractalSize().getInt()) - : 0; - info.pad = static_cast(config.getPad().getValue()); - } - return info; +static std::string getMemorySpaceString(MemRefType mrTy) { + auto msAttr = dyn_cast_or_null(mrTy.getMemorySpace()); + if (!msAttr) return "gm"; + if (msAttr.getAddressSpace() == pto::AddressSpace::GM) return "gm"; + return "ub"; } static std::optional buildOperandTypeInfo(Type ty) { - if (auto tbTy = dyn_cast(ty)) - return buildOperandTypeInfo(tbTy); + // Tile operand — from TileBufType. + if (auto tbTy = dyn_cast(ty)) { + OperandTypeInfo info; + info.kind = OperandKind::Tile; + info.dtype = getDtypeString(tbTy.getElementType()); + if (info.dtype.empty()) + return std::nullopt; + info.tileShape.assign(tbTy.getShape().begin(), tbTy.getShape().end()); + info.tileMemorySpace = getMemorySpaceString(tbTy); + if (auto config = tbTy.getConfigAttr()) { + info.blayout = static_cast(config.getBLayout().getValue()); + info.slayout = static_cast(config.getSLayout().getValue()); + info.fractal = config.getSFractalSize() + ? static_cast(config.getSFractalSize().getInt()) + : 0; + info.pad = static_cast(config.getPad().getValue()); + } + return info; + } + + // View operand — from MemRefType (lowered PartitionTensorViewType). + if (auto mrTy = dyn_cast(ty)) { + OperandTypeInfo info; + info.kind = OperandKind::View; + info.dtype = getDtypeString(mrTy.getElementType()); + if (info.dtype.empty()) + return std::nullopt; + info.viewShape.assign(mrTy.getShape().begin(), mrTy.getShape().end()); + info.viewMemorySpace = getMemorySpaceString(mrTy); + int64_t offset = ShapedType::kDynamic; + if (succeeded(getStridesAndOffset(mrTy, info.viewStrides, offset))) { + // strides populated — dynamic dims remain ShapedType::kDynamic. + } + return info; + } + // Scalar operand — from a scalar element type. OperandTypeInfo info; info.kind = OperandKind::Scalar; info.dtype = getDtypeString(ty); @@ -231,29 +274,52 @@ struct ExpandTileOpPass void runOnOperation() override; }; +/// Serialize a JSON array of integers. +static void appendJsonIntArray(std::string &json, ArrayRef arr) { + json += "["; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) + json += ","; + json += std::to_string(arr[i]); + } + json += "]"; +} + static std::string buildOperandSpecsJson(const SpecKey &key) { std::string json = "["; for (size_t i = 0; i < key.operands.size(); ++i) { const auto &op = key.operands[i]; if (i > 0) json += ","; + if (op.kind == OperandKind::Tile) { - json += "{\"kind\":\"tile\",\"dtype\":\""; - json += op.dtype; - json += "\",\"shape\":["; - for (size_t dim = 0; dim < op.shape.size(); ++dim) { - if (dim > 0) - json += ","; - json += std::to_string(op.shape[dim]); + json += "{\"kind\":\"tile\",\"dtype\":\"" + op.dtype + "\",\"shape\":"; + appendJsonIntArray(json, op.tileShape); + json += ",\"memory_space\":\"" + op.tileMemorySpace + "\"}"; + continue; + } + + if (op.kind == OperandKind::View) { + json += "{\"kind\":\"view\",\"dtype\":\"" + op.dtype + "\",\"shape\":"; + appendJsonIntArray(json, op.viewShape); + if (!op.viewStrides.empty()) { + json += ",\"strides\":["; + for (size_t dim = 0; dim < op.viewStrides.size(); ++dim) { + if (dim > 0) + json += ","; + if (ShapedType::isDynamic(op.viewStrides[dim])) + json += "null"; + else + json += std::to_string(op.viewStrides[dim]); + } + json += "]"; } - json += "],\"memory_space\":\""; - json += op.memorySpace; - json += "\"}"; + json += ",\"memory_space\":\"" + op.viewMemorySpace + "\"}"; continue; } - json += "{\"kind\":\"scalar\",\"dtype\":\""; - json += op.dtype; - json += "\"}"; + + // Scalar + json += "{\"kind\":\"scalar\",\"dtype\":\"" + op.dtype + "\"}"; } json += "]"; return json; @@ -394,13 +460,17 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, SmallVector clonedFuncs; llvm::StringMap renamedSymbols; - // Build a unique base name from all operand types. + // Build a unique name from the spec-key-relevant operand fields. std::string uniqueName = "__pto_tilelang_" + key.opName; for (const auto &op : key.operands) { - uniqueName += op.kind == OperandKind::Tile ? "_tile" : "_scalar"; + uniqueName += op.kind == OperandKind::Tile ? "_tile" + : op.kind == OperandKind::View ? "_view" + : "_scalar"; uniqueName += "_" + op.dtype; - for (int64_t d : op.shape) - uniqueName += "_" + std::to_string(d); + if (op.kind == OperandKind::Tile) { + for (int64_t d : op.tileShape) + uniqueName += "_" + std::to_string(d); + } } for (auto [index, fn] : llvm::enumerate(parsedFuncs)) { @@ -478,9 +548,21 @@ LogicalResult ExpandState::expandTileOpsInFunction(func::FuncOp func, return failure(); } - // Replace tile op with func.call, passing tile_buf operands directly. + // Replace tile op with func.call. For view operands whose caller type + // (memref) differs from the template parameter type (tensor_view / + // partition_tensor_view), insert an unrealized_conversion_cast bridge. + // FoldTileBufIntrinsics will later resolve these casts. builder.setInsertionPoint(op); - SmallVector operands(op->getOperands()); + SmallVector operands; + auto fnArgTypes = dslFn.getArgumentTypes(); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + if (i < fnArgTypes.size() && operand.getType() != fnArgTypes[i]) { + operand = builder.create( + op->getLoc(), fnArgTypes[i], operand).getResult(0); + } + operands.push_back(operand); + } builder.create(op->getLoc(), dslFn, operands); op->erase(); } diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp index a25a05624..35502d8a4 100644 --- a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -8,19 +8,33 @@ //===- FoldTileBufIntrinsics.cpp ------------------------------------------===// // -// After TileLang DSL template functions are inlined, the IR contains: +// After TileLang DSL template functions are inlined, the IR contains +// structured-view intrinsics that reference template parameters: +// +// tile_buf family: // - pto.tile_buf_addr → extract memref address from tile_buf // - pto.tile_valid_rows → extract valid row count // - pto.tile_valid_cols → extract valid column count // -// This pass resolves them against the concrete tile_buf values at the -// call site. +// tensor_view family: +// - pto.tensor_view_addr → extract memref/ptr from tensor_view +// - pto.get_tensor_view_dim → extract dimension size +// - pto.get_tensor_view_stride → extract dimension stride +// +// This pass resolves them against the concrete values at the call site. +// For tensor_view intrinsics, the pass traces through the full +// unrealized_conversion_cast → memref.subview → memref.reinterpret_cast +// chain to fold directly to constants or SSA operands from the +// reinterpret_cast, without generating intermediate memref.dim / +// memref.extract_strided_metadata ops. // //===----------------------------------------------------------------------===// #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" +#include + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" @@ -64,6 +78,160 @@ static pto::BindTileOp findBindTileForTileBuf(Value tileBuf, Operation *user) { return bindOp; } +struct ViewChain { + UnrealizedConversionCastOp cast; + memref::SubViewOp subview; + memref::ReinterpretCastOp reinterpretCast; + Value baseMemref; +}; + +static std::optional traceViewChain(Value tensorView, + Operation *user) { + Value memrefVal; + UnrealizedConversionCastOp castOp; + + if (isa(tensorView.getType())) { + memrefVal = tensorView; + } else { + castOp = tensorView.getDefiningOp(); + if (!castOp || castOp.getNumOperands() != 1) { + user->emitError( + "FoldTileBufIntrinsics: expected tensor_view to be defined by a " + "single-operand builtin.unrealized_conversion_cast"); + return std::nullopt; + } + memrefVal = castOp.getOperand(0); + if (!isa(memrefVal.getType())) { + user->emitError( + "FoldTileBufIntrinsics: expected cast operand to be a memref, got ") + << memrefVal.getType(); + return std::nullopt; + } + } + + auto subviewOp = memrefVal.getDefiningOp(); + if (!subviewOp) { + user->emitError("FoldTileBufIntrinsics: expected memref to be defined by " + "memref.subview, got ") + << (memrefVal.getDefiningOp() + ? memrefVal.getDefiningOp()->getName().getStringRef() + : StringRef("block argument")); + return std::nullopt; + } + + auto rcOp = subviewOp.getSource().getDefiningOp(); + if (!rcOp) { + user->emitError( + "FoldTileBufIntrinsics: expected subview source to be defined by " + "memref.reinterpret_cast, got ") + << (subviewOp.getSource().getDefiningOp() + ? subviewOp.getSource().getDefiningOp()->getName().getStringRef() + : StringRef("block argument")); + return std::nullopt; + } + + return ViewChain{castOp, subviewOp, rcOp, rcOp.getSource()}; +} + +static bool getConstIndexValue(Value v, int64_t &out) { + if (auto cOp = v.getDefiningOp()) { + out = cOp.value(); + return true; + } + if (auto cInt = v.getDefiningOp()) { + out = cInt.value(); + return true; + } + if (auto cOp = v.getDefiningOp()) { + if (auto ia = dyn_cast(cOp.getValue())) { + out = ia.getInt(); + return true; + } + } + if (auto castOp = v.getDefiningOp()) + return getConstIndexValue(castOp.getIn(), out); + if (auto extOp = v.getDefiningOp()) + return getConstIndexValue(extOp.getIn(), out); + if (auto extOp = v.getDefiningOp()) + return getConstIndexValue(extOp.getIn(), out); + if (auto truncOp = v.getDefiningOp()) + return getConstIndexValue(truncOp.getIn(), out); + return false; +} + +static Value getValueOrCreateConstant(OpBuilder &builder, Location loc, + OpFoldResult ofr) { + if (auto val = dyn_cast(ofr)) + return val; + auto intAttr = dyn_cast(cast(ofr)); + assert(intAttr && "expected integer attribute in OpFoldResult"); + return builder.create(loc, intAttr.getInt()); +} + +static bool isAllStaticZero(ArrayRef ofrs) { + for (OpFoldResult ofr : ofrs) { + auto attr = dyn_cast(ofr); + if (!attr) + return false; + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() != 0) + return false; + } + return true; +} + +static Value computeResultStride(OpBuilder &builder, Location loc, + OpFoldResult rcStride, + OpFoldResult svStride) { + if (auto attr = dyn_cast(svStride)) { + auto intAttr = dyn_cast(attr); + if (intAttr && intAttr.getInt() == 1) + return getValueOrCreateConstant(builder, loc, rcStride); + } + + Value lhs = getValueOrCreateConstant(builder, loc, rcStride); + Value rhs = getValueOrCreateConstant(builder, loc, svStride); + return builder.create(loc, lhs, rhs); +} + +static Value computeLinearOffset(OpBuilder &builder, Location loc, + ArrayRef rcOffsets, + ArrayRef svOffsets, + ArrayRef rcStrides) { + bool rcAllZero = isAllStaticZero(rcOffsets); + bool svAllZero = isAllStaticZero(svOffsets); + + if (rcAllZero && svAllZero) + return Value(); + + Value svPart; + if (!svAllZero) { + for (auto [svOffset, rcStride] : llvm::zip(svOffsets, rcStrides)) { + if (auto attr = dyn_cast(svOffset)) { + auto intAttr = dyn_cast(attr); + if (intAttr && intAttr.getInt() == 0) + continue; + } + + Value off = getValueOrCreateConstant(builder, loc, svOffset); + Value stride = getValueOrCreateConstant(builder, loc, rcStride); + Value term = builder.create(loc, off, stride); + svPart = svPart ? builder.create(loc, svPart, term) : term; + } + } + + Value rcPart; + if (!rcAllZero) { + if (rcOffsets.empty()) + return Value(); + rcPart = getValueOrCreateConstant(builder, loc, rcOffsets.front()); + } + + if (rcPart && svPart) + return builder.create(loc, rcPart, svPart); + return rcPart ? rcPart : svPart; +} + struct FoldTileBufIntrinsicsPass : public pto::impl::FoldTileBufIntrinsicsBase { using FoldTileBufIntrinsicsBase::FoldTileBufIntrinsicsBase; @@ -83,6 +251,9 @@ struct FoldTileBufIntrinsicsPass SmallVector addrOps; SmallVector rowsOps; SmallVector colsOps; + SmallVector tvAddrOps; + SmallVector tvDimOps; + SmallVector tvStrideOps; func.walk([&](Operation *op) { if (auto addr = dyn_cast(op)) @@ -91,6 +262,12 @@ struct FoldTileBufIntrinsicsPass rowsOps.push_back(rows); else if (auto cols = dyn_cast(op)) colsOps.push_back(cols); + else if (auto tvAddr = dyn_cast(op)) + tvAddrOps.push_back(tvAddr); + else if (auto tvDim = dyn_cast(op)) + tvDimOps.push_back(tvDim); + else if (auto tvStride = dyn_cast(op)) + tvStrideOps.push_back(tvStride); }); // Fold pto.tile_buf_addr → bind_tile's source memref (the static-layout @@ -205,6 +382,143 @@ struct FoldTileBufIntrinsicsPass colsOp.getResult().replaceAllUsesWith(replacement); colsOp.erase(); } + + for (auto addrOp : tvAddrOps) { + auto chain = traceViewChain(addrOp.getSrc(), addrOp); + if (!chain) + return signalPassFailure(); + + builder.setInsertionPoint(addrOp); + + auto resultPtrType = dyn_cast(addrOp.getDst().getType()); + if (!resultPtrType) { + if (auto resultMemrefType = + dyn_cast(addrOp.getDst().getType())) { + Value base = chain->baseMemref; + if (base.getType() != resultMemrefType) + addrOp.getDst().setType(cast(base.getType())); + addrOp.getDst().replaceAllUsesWith(base); + addrOp.erase(); + continue; + } + addrOp.emitError( + "FoldTileBufIntrinsics: tensor_view_addr result must be memref or " + "!pto.ptr"); + return signalPassFailure(); + } + + Value linearOffset = + computeLinearOffset(builder, addrOp.getLoc(), + chain->reinterpretCast.getMixedOffsets(), + chain->subview.getMixedOffsets(), + chain->reinterpretCast.getMixedStrides()); + + Value basePtr = builder.create( + addrOp.getLoc(), resultPtrType, chain->baseMemref); + Value replacement = + linearOffset + ? builder.create(addrOp.getLoc(), resultPtrType, + basePtr, linearOffset) + : basePtr; + + addrOp.getDst().replaceAllUsesWith(replacement); + addrOp.erase(); + } + + for (auto dimOp : tvDimOps) { + auto chain = traceViewChain(dimOp.getTensorView(), dimOp); + if (!chain) + return signalPassFailure(); + + int64_t dimIdx = 0; + if (!getConstIndexValue(dimOp.getDimIndex(), dimIdx)) { + dimOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_dim requires a constant " + "dim index"); + return signalPassFailure(); + } + + auto svTy = cast(chain->subview.getType()); + if (dimIdx < 0 || dimIdx >= svTy.getRank()) { + dimOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_dim dim index out of " + "bounds"); + return signalPassFailure(); + } + + builder.setInsertionPoint(dimOp); + Value replacement; + if (!svTy.isDynamicDim(dimIdx)) { + replacement = + builder.create(dimOp.getLoc(), + svTy.getDimSize(dimIdx)); + } else { + replacement = getValueOrCreateConstant( + builder, dimOp.getLoc(), chain->subview.getMixedSizes()[dimIdx]); + } + + dimOp.getResult().replaceAllUsesWith(replacement); + dimOp.erase(); + } + + for (auto strideOp : tvStrideOps) { + auto chain = traceViewChain(strideOp.getTensorView(), strideOp); + if (!chain) + return signalPassFailure(); + + int64_t dimIdx = 0; + if (!getConstIndexValue(strideOp.getDimIndex(), dimIdx)) { + strideOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_stride requires a " + "constant dim index"); + return signalPassFailure(); + } + + auto svTy = cast(chain->subview.getType()); + if (dimIdx < 0 || dimIdx >= svTy.getRank()) { + strideOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_stride dim index out of " + "bounds"); + return signalPassFailure(); + } + + builder.setInsertionPoint(strideOp); + Value replacement = computeResultStride( + builder, strideOp.getLoc(), + chain->reinterpretCast.getMixedStrides()[dimIdx], + chain->subview.getMixedStrides()[dimIdx]); + + strideOp.getResult().replaceAllUsesWith(replacement); + strideOp.erase(); + } + + // Clean up dead unrealized_conversion_cast ops that bridged + // memref -> partition_tensor_view / tile_buf and are now unused + // after folding. + SmallVector deadCasts; + func.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp.use_empty() && castOp.getNumOperands() == 1 && + isa(castOp.getOperand(0).getType()) && + isa( + castOp.getResult(0).getType())) + deadCasts.push_back(castOp); + }); + for (auto castOp : llvm::reverse(deadCasts)) + castOp.erase(); + + while (true) { + SmallVector deadMemrefOps; + func.walk([&](Operation *op) { + if ((isa(op) || + isa(op)) && + op->use_empty()) + deadMemrefOps.push_back(op); + }); + if (deadMemrefOps.empty()) + break; + for (auto *op : llvm::reverse(deadMemrefOps)) + op->erase(); + } } }; diff --git a/test/tilelang_st/npu/a5/src/st/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt new file mode 100644 index 000000000..20248be01 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt @@ -0,0 +1,79 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +cmake_minimum_required(VERSION 3.16) +project(tilelang_st) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + +# -------------------------------------------------------------------------- +# PTOAS binary — passed by run_st.py via -DPTOAS_BIN=... +# -------------------------------------------------------------------------- +if(NOT DEFINED PTOAS_BIN) + message(FATAL_ERROR "PTOAS_BIN is not set. Pass -DPTOAS_BIN=/path/to/ptoas to cmake.") +endif() + +# -------------------------------------------------------------------------- +# ASCEND environment +# -------------------------------------------------------------------------- +if(NOT DEFINED ENV{ASCEND_HOME_PATH}) + message(FATAL_ERROR "Cannot find ASCEND_HOME_PATH, please run set_env.sh.") +else() + set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH}) +endif() + +set(PTO_ISA_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../../../../../pto-isa" CACHE PATH "Path to pto-isa repo") +set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver) + +set(CMAKE_COMPILER bisheng) +set(CMAKE_C_COMPILER ${CMAKE_COMPILER}) +set(CMAKE_CXX_COMPILER ${CMAKE_COMPILER}) + +add_compile_options( + -D_FORTIFY_SOURCE=2 + -O2 -std=c++17 + -Wno-macro-redefined -Wno-ignored-attributes -Wno-unknown-attributes + -fstack-protector-strong + -fPIC +) +add_link_options( + -s + -Wl,-z,relro + -Wl,-z,now +) + +set(CMAKE_CCE_COMPILE_OPTIONS + -xcce + -fPIC + -Xhost-start -Xhost-end + "SHELL:-mllvm -cce-aicore-stack-size=0x8000" + "SHELL:-mllvm -cce-aicore-function-stack-size=0x8000" + "SHELL:-mllvm -cce-aicore-record-overflow=true" + "SHELL:-mllvm -cce-aicore-addr-transform" + "SHELL:-mllvm -cce-aicore-dcci-insert-for-scalar=false" +) + +set(CMAKE_CPP_COMPILE_OPTIONS + -xc++ + "SHELL:-include stdint.h" + "SHELL:-include stddef.h" +) + +include_directories( + ${PTO_ISA_ROOT}/include + ${PTO_ISA_ROOT}/tests/common + ${ASCEND_HOME_PATH}/include + ${ASCEND_DRIVER_PATH}/kernel/inc +) + +add_subdirectory(testcase) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt new file mode 100644 index 000000000..1de192084 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -0,0 +1,97 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# -------------------------------------------------------------------------- +# pto_tilelang_vec_st(NAME) +# +# CMake macro for TileLang ST test cases. Unlike pto-isa's pto_vec_st() +# which compiles a hand-written kernel.cpp with -xcce, this macro: +# 1. Runs ptoas to compile .pto → kernel.ll (LLVM IR) +# 2. Runs bisheng -x ir kernel.ll → kernel.o (object file) +# 3. Compiles launch.cpp with -xcce and links kernel.o → shared library +# 4. Builds host executable from main.cpp (no GTest — comparison via compare.py) +# -------------------------------------------------------------------------- +function(pto_tilelang_vec_st NAME) + # Step 1: ptoas .pto → kernel.ll + set(PTO_SRC ${CMAKE_CURRENT_SOURCE_DIR}/${NAME}.pto) + set(KERNEL_LL ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_kernel.ll) + add_custom_command( + OUTPUT ${KERNEL_LL} + COMMAND ${PTOAS_BIN} + --pto-arch=a5 --pto-backend=vpto + --enable-tile-op-expand --vpto-emit-hivm-llvm + ${PTO_SRC} -o ${KERNEL_LL} + DEPENDS ${PTO_SRC} + COMMENT "ptoas: ${NAME}.pto -> ${NAME}_kernel.ll" + ) + + # Step 2: bisheng kernel.ll → kernel.o + set(KERNEL_OBJ ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_kernel.o) + add_custom_command( + OUTPUT ${KERNEL_OBJ} + COMMAND bisheng + --target=hiipu64-hisilicon-cce + -march=dav-c310-vec + --cce-aicore-arch=dav-c310-vec + --cce-aicore-only + -c -x ir ${KERNEL_LL} + -o ${KERNEL_OBJ} + DEPENDS ${KERNEL_LL} + COMMENT "bisheng: ${NAME}_kernel.ll -> ${NAME}_kernel.o" + ) + + # Step 3: launch.cpp (-xcce) + kernel.o → shared library + add_library(${NAME}_kernel SHARED launch.cpp ${KERNEL_OBJ}) + set_source_files_properties(${KERNEL_OBJ} PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE) + target_compile_options(${NAME}_kernel PRIVATE + ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c310-vec -std=c++17) + target_include_directories(${NAME}_kernel PRIVATE + ${ASCEND_HOME_PATH}/pkg_inc/ + ${ASCEND_HOME_PATH}/pkg_inc/profiling/ + ${ASCEND_HOME_PATH}/pkg_inc/runtime/runtime + ) + target_link_options(${NAME}_kernel PRIVATE --cce-fatobj-link) + + # Step 4: main.cpp → host executable + add_executable(${NAME} main.cpp) + target_compile_options(${NAME} PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) + target_include_directories(${NAME} PRIVATE + ${PTO_ISA_ROOT}/tests/common + ) + + target_link_directories(${NAME} PUBLIC + ${ASCEND_HOME_PATH}/lib64 + ${ASCEND_HOME_PATH}/tools/simulator/${SOC_VERSION}/lib + ) + + target_link_libraries(${NAME} PRIVATE + ${NAME}_kernel + $:runtime_camodel>> + $:runtime>> + stdc++ ascendcl m tiling_api platform c_sec dl nnopbase pthread + ) +endfunction() + +# -------------------------------------------------------------------------- +# Test case registry — add new ops here. +# -------------------------------------------------------------------------- +set(ALL_TESTCASES + tadd +) + +if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) + message(STATUS "run: ${TEST_CASE}") +else() + message(FATAL_ERROR "not found TEST_CASE: ${TEST_CASE}, supported: ${ALL_TESTCASES}") +endif() + +foreach(TESTCASE ${ALL_TESTCASES}) + if((DEFINED TEST_CASE AND TEST_CASE STREQUAL TESTCASE) OR (NOT DEFINED TEST_CASE) OR (TEST_CASE STREQUAL "all")) + add_subdirectory(${TESTCASE}) + endif() +endforeach() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tadd/CMakeLists.txt new file mode 100644 index 000000000..84928bcdb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tadd) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py new file mode 100644 index 000000000..360b009fb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py @@ -0,0 +1,60 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import sys +import os +import numpy as np + + +CASES = [ + {"name": "f32_16x64", "dtype": np.float32, "eps": 1e-6}, + {"name": "f32_32x32", "dtype": np.float32, "eps": 1e-6}, +] + + +def compare_bin(golden_path, output_path, dtype, eps): + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: golden {golden.shape} vs output {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] Mismatch: max diff={float(abs_diff[idx])} at idx={idx} " + f"(golden={g[idx]}, output={o[idx]})") + return False + return True + + +if __name__ == "__main__": + # Optional filter: python compare.py [case_name] + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + case_dir = case["name"] + golden_path = os.path.join(case_dir, "golden.bin") + output_path = os.path.join(case_dir, "output.bin") + ok = compare_bin(golden_path, output_path, case["dtype"], case["eps"]) + if ok: + print(f"[INFO] {case['name']}: compare passed") + else: + print(f"[ERROR] {case['name']}: compare failed") + all_passed = False + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py new file mode 100644 index 000000000..1d983c64f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import numpy as np + +np.random.seed(19) + +CASES = [ + {"name": "f32_16x64", "dtype": np.float32, "shape": (16, 64)}, + {"name": "f32_32x32", "dtype": np.float32, "shape": (32, 32)}, +] + +for case in CASES: + case_dir = case["name"] + os.makedirs(case_dir, exist_ok=True) + + input1 = np.random.randint(1, 10, size=case["shape"]).astype(case["dtype"]) + input2 = np.random.randint(1, 10, size=case["shape"]).astype(case["dtype"]) + golden = (input1 + input2).astype(case["dtype"], copy=False) + + input1.tofile(os.path.join(case_dir, "input1.bin")) + input2.tofile(os.path.join(case_dir, "input2.bin")) + golden.tofile(os.path.join(case_dir, "golden.bin")) + print(f"[INFO] gen_data: {case['name']} shape={case['shape']} dtype={case['dtype'].__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp new file mode 100644 index 000000000..ba35a01e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#include +#include +#include + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +// Case 0: f32 16x64 +__global__ AICORE void TADD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream) { + TADD_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +__global__ AICORE void TADD_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream) { + TADD_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp new file mode 100644 index 000000000..be03ef998 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp @@ -0,0 +1,144 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tadd ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, \ + __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + return 1; \ + } \ + } while (0) + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; + size_t cols; + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTADD_f32_16x64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTADD_f32_32x32, 32, 32, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (%zux%zu) ===\n", tc.name, tc.rows, tc.cols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + ACL_CHECK(aclrtMallocHost((void **)(&src0Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&src1Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), fileSize)); + + ACL_CHECK(aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, fileSize); + ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, fileSize); + + ACL_CHECK(aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize); + + aclrtFree(src0Device); + aclrtFree(src1Device); + aclrtFree(dstDevice); + aclrtFreeHost(src0Host); + aclrtFreeHost(src1Host); + aclrtFreeHost(dstHost); + + std::printf("[INFO] case %s done\n", tc.name); + return 0; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tadd [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) { + aclrtDestroyStream(stream); + } + if (deviceSet) { + aclrtResetDevice(deviceId); + } + if (aclInited) { + aclFinalize(); + } + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto new file mode 100644 index 000000000..efe79cf4a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tadd: tload(a) + tload(b) + tadd(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-tile-op-expand --vpto-emit-hivm-llvm to produce LLVM IR. + +module { + // Case 0: f32 16x64 (1024 elements) + func.func @TADD_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TADD_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } +} diff --git a/test/tilelang_st/script/run_st.py b/test/tilelang_st/script/run_st.py new file mode 100755 index 000000000..f8c4d0f3b --- /dev/null +++ b/test/tilelang_st/script/run_st.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +""" +TileLang ST runner — validates TileLang DSL template library on NPU / simulator. + +Usage: + python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd + python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd +""" + +import os +import sys +import subprocess +import shutil +import argparse + + +def run_command(command, cwd=None, check=True): + try: + print(f"run command: {' '.join(command)}") + subprocess.run(command, cwd=cwd, check=check, stdout=None, stderr=None, text=True) + except subprocess.CalledProcessError as e: + print(f"run command failed with return code {e.returncode}") + raise + + +def find_ptoas_bin(): + """Locate the ptoas binary by walking up from this script to the repo root.""" + env_bin = os.environ.get("PTOAS_BIN") + if env_bin and os.path.isfile(env_bin): + return os.path.abspath(env_bin) + + search_dir = os.path.dirname(os.path.abspath(__file__)) + for _ in range(8): + candidate = os.path.join(search_dir, "build", "bin", "ptoas") + if os.path.isfile(candidate): + return os.path.abspath(candidate) + parent = os.path.dirname(search_dir) + if parent == search_dir: + break + search_dir = parent + return None + + +def set_env_variables(run_mode, soc_version): + if run_mode == "sim": + ld_lib_path = os.environ.get("LD_LIBRARY_PATH", "") + if ld_lib_path: + filtered_paths = [ + path for path in ld_lib_path.split(":") + if "/runtime/lib64" not in path + ] + os.environ["LD_LIBRARY_PATH"] = ":".join(filtered_paths) + + ascend_home = os.environ.get("ASCEND_HOME_PATH") + if not ascend_home: + raise EnvironmentError("ASCEND_HOME_PATH is not set") + + os.environ["LD_LIBRARY_PATH"] = ( + f"{ascend_home}/runtime/lib64/stub:{os.environ.get('LD_LIBRARY_PATH', '')}" + ) + setenv_path = os.path.join(ascend_home, "bin", "setenv.bash") + if os.path.exists(setenv_path): + print(f"run env shell: {setenv_path}") + result = subprocess.run( + f"source {setenv_path} && env", + shell=True, + executable=shutil.which("bash") or "bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + for line in result.stdout.splitlines(): + if "=" in line: + key, value = line.split("=", 1) + os.environ[key] = value + else: + print(f"warning: not found {setenv_path}") + + simulator_lib_path = os.path.join( + ascend_home, "tools", "simulator", soc_version, "lib" + ) + os.environ["LD_LIBRARY_PATH"] = ( + f"{simulator_lib_path}:{os.environ.get('LD_LIBRARY_PATH', '')}" + ) + + +def build_project(run_mode, soc_version, testcase, ptoas_bin): + build_dir = "build" + if os.path.exists(build_dir): + print(f"clean build: {build_dir}") + shutil.rmtree(build_dir) + os.makedirs(build_dir, exist_ok=True) + + try: + cmake_cmd = [ + "cmake", + f"-DRUN_MODE={run_mode}", + f"-DSOC_VERSION={soc_version}", + f"-DTEST_CASE={testcase}", + f"-DPTOAS_BIN={ptoas_bin}", + "..", + ] + subprocess.run( + cmake_cmd, + cwd=build_dir, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + cpu_count = os.cpu_count() or 4 + make_cmd = ["make", "VERBOSE=1", "-j", str(cpu_count)] + result = subprocess.run( + make_cmd, + cwd=build_dir, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + print("compile process:\n", result.stdout) + except subprocess.CalledProcessError as e: + print(f"build failed: {e.stdout}") + raise + + +def run_gen_data(golden_path): + original_dir = os.getcwd() + try: + run_command(["cp", golden_path, "build/gen_data.py"]) + os.chdir("build/") + run_command([sys.executable, "gen_data.py"]) + except Exception as e: + print(f"gen golden failed: {e}") + raise + finally: + os.chdir(original_dir) + + +def run_binary(testcase, case_filter=None): + original_dir = os.getcwd() + try: + os.chdir("build/bin/") + cmd = ["./" + testcase] + if case_filter: + cmd.append(case_filter) + run_command(cmd) + except Exception as e: + print(f"run binary failed: {e}") + raise + finally: + os.chdir(original_dir) + + +def run_compare(compare_path, case_filter=None): + original_dir = os.getcwd() + try: + run_command(["cp", compare_path, "build/compare.py"]) + os.chdir("build/") + cmd = [sys.executable, "compare.py"] + if case_filter: + cmd.append(case_filter) + run_command(cmd) + except Exception as e: + print(f"compare failed: {e}") + raise + finally: + os.chdir(original_dir) + + +def main(): + parser = argparse.ArgumentParser(description="TileLang ST runner") + parser.add_argument("-r", "--run-mode", required=True, + help="Run mode: sim or npu") + parser.add_argument("-v", "--soc-version", required=True, + help="SoC version: a5") + parser.add_argument("-t", "--testcase", required=True, + help="Test case name (e.g. tadd)") + parser.add_argument("-p", "--ptoas-bin", required=False, + help="Path to ptoas binary (auto-detected if omitted)") + parser.add_argument("-c", "--case", required=False, default=None, + help="Run a specific case within the testcase (e.g. f32_16x64)") + parser.add_argument("-w", "--without-build", action="store_true", + help="Skip build (requires prior build)") + + args = parser.parse_args() + + if args.soc_version == "a5": + default_soc_version = "Ascend950PR_9599" + else: + print(f"[ERROR] Unsupported soc-version: {args.soc_version}, only a5 is supported", + file=sys.stderr) + sys.exit(1) + + testcase = args.testcase + + ptoas_bin = args.ptoas_bin or find_ptoas_bin() + if not ptoas_bin: + print("[ERROR] Cannot find ptoas binary. " + "Set PTOAS_BIN env or use -p flag.", file=sys.stderr) + sys.exit(1) + ptoas_bin = os.path.abspath(ptoas_bin) + print(f"[INFO] ptoas: {ptoas_bin}") + + original_dir = os.getcwd() + try: + script_path = os.path.abspath(__file__) + tilelang_st_root = os.path.dirname(os.path.dirname(script_path)) + target_dir = os.path.join(tilelang_st_root, "npu", args.soc_version, "src", "st") + + if not os.path.isdir(target_dir): + print(f"[ERROR] Target dir not found: {target_dir}", file=sys.stderr) + sys.exit(1) + + print(f"target_dir: {target_dir}") + os.chdir(target_dir) + + set_env_variables(args.run_mode, default_soc_version) + + if not args.without_build: + build_project(args.run_mode, default_soc_version, testcase, ptoas_bin) + + # gen golden → run binary → compare + golden_path = f"testcase/{testcase}/gen_data.py" + run_gen_data(golden_path) + + run_binary(testcase, args.case) + + compare_path = f"testcase/{testcase}/compare.py" + run_compare(compare_path, args.case) + + except Exception as e: + print(f"run failed: {str(e)}", file=sys.stderr) + sys.exit(1) + finally: + os.chdir(original_dir) + + +if __name__ == "__main__": + main() diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index a01ef7a3c..79d8c82e9 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -143,6 +143,29 @@ def _parse_operand_specs(spec_text: str) -> list[dict]: } ) continue + if kind == "view": + shape = raw.get("shape") + if not isinstance(shape, list) or not shape: + raise ValueError(f"operand-specs[{index}] view shape must be a non-empty list") + memory_space = _MEMSPACE_MAP.get(raw.get("memory_space", "gm")) + if memory_space is None: + raise ValueError( + f"operand-specs[{index}] has unknown memory-space {raw.get('memory_space')!r}" + ) + view_spec: dict = { + "kind": "view", + "dtype": dtype, + "shape": tuple(int(dim) for dim in shape), + "memory_space": memory_space, + } + raw_strides = raw.get("strides") + if isinstance(raw_strides, list) and raw_strides: + # null entries represent dynamic strides — keep as None. + view_spec["strides"] = tuple( + None if s is None else int(s) for s in raw_strides + ) + specs.append(view_spec) + continue raise ValueError(f"operand-specs[{index}] has unknown kind {kind!r}") return specs @@ -222,7 +245,12 @@ def main(argv: list[str] | None = None) -> int: return 1 # Specialize Tile parameters positionally from operand-specs. + # View operands match tensorview/partition_tensor_view parameters without + # specialization — shape/strides are resolved dynamically via intrinsics. + # However, their shape/strides are injected into the constraint context so + # that precondition checks (e.g. src.strides[4] == 1) can evaluate. tile_specs = {} + view_context_attrs: dict[str, object] = {} for param, operand_spec in zip(desc.parameters, operand_specs): if param.kind == "tile": if operand_spec["kind"] != "tile": @@ -236,6 +264,19 @@ def main(argv: list[str] | None = None) -> int: memory_space=operand_spec["memory_space"], ) continue + if param.kind in ("tensorview", "partition_tensor_view"): + if operand_spec["kind"] != "view": + print( + f"expand_helper: error: descriptor {param.kind} parameter " + f"does not match operand-specs kind {operand_spec['kind']!r}", + file=sys.stderr, + ) + return 1 + # Inject shape/strides for constraint evaluation. + view_context_attrs[f"{param.name}_shape"] = operand_spec["shape"] + if "strides" in operand_spec: + view_context_attrs[f"{param.name}_strides"] = operand_spec["strides"] + continue if param.kind == "scalar" and operand_spec["kind"] != "scalar": print( "expand_helper: error: descriptor scalar parameter does not match operand-specs", @@ -243,6 +284,12 @@ def main(argv: list[str] | None = None) -> int: ) return 1 + # Bind view context attrs so constraint checking has access to shape/strides. + if view_context_attrs: + desc = desc._bind_constraint_context_attrs( + {**desc.constraint_context_attrs, **view_context_attrs} + ) + specialized = desc.specialize(**tile_specs) # Emit MLIR to stdout. From fc1da2f667057440855d46dda9dc016f3dd42be0 Mon Sep 17 00:00:00 2001 From: qukelin Date: Sat, 11 Apr 2026 10:37:04 +0800 Subject: [PATCH 047/192] Fix TileLang ST ptoas invocation --- .../npu/a5/src/st/testcase/CMakeLists.txt | 16 ++++-- .../src/st/testcase/run_ptoas_to_file.cmake | 51 +++++++++++++++++++ test/tilelang_st/script/run_st.py | 2 +- 3 files changed, 63 insertions(+), 6 deletions(-) create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 1de192084..887c5e713 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -16,18 +16,24 @@ # 3. Compiles launch.cpp with -xcce and links kernel.o → shared library # 4. Builds host executable from main.cpp (no GTest — comparison via compare.py) # -------------------------------------------------------------------------- +set(PTO_TILELANG_ST_TESTCASE_DIR ${CMAKE_CURRENT_LIST_DIR}) + function(pto_tilelang_vec_st NAME) # Step 1: ptoas .pto → kernel.ll set(PTO_SRC ${CMAKE_CURRENT_SOURCE_DIR}/${NAME}.pto) set(KERNEL_LL ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_kernel.ll) + set(PTOAS_CAPTURE_SCRIPT + ${PTO_TILELANG_ST_TESTCASE_DIR}/run_ptoas_to_file.cmake) add_custom_command( OUTPUT ${KERNEL_LL} - COMMAND ${PTOAS_BIN} - --pto-arch=a5 --pto-backend=vpto - --enable-tile-op-expand --vpto-emit-hivm-llvm - ${PTO_SRC} -o ${KERNEL_LL} - DEPENDS ${PTO_SRC} + COMMAND ${CMAKE_COMMAND} + -DPTOAS_BIN=${PTOAS_BIN} + -DPTO_SRC=${PTO_SRC} + -DKERNEL_LL=${KERNEL_LL} + -P ${PTOAS_CAPTURE_SCRIPT} + DEPENDS ${PTO_SRC} ${PTOAS_CAPTURE_SCRIPT} COMMENT "ptoas: ${NAME}.pto -> ${NAME}_kernel.ll" + VERBATIM ) # Step 2: bisheng kernel.ll → kernel.o diff --git a/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake b/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake new file mode 100644 index 000000000..037ef9f00 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake @@ -0,0 +1,51 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +if(NOT DEFINED PTOAS_BIN OR NOT DEFINED PTO_SRC OR NOT DEFINED KERNEL_LL) + message(FATAL_ERROR "PTOAS_BIN, PTO_SRC, and KERNEL_LL must be provided") +endif() + +get_filename_component(KERNEL_LL_DIR "${KERNEL_LL}" DIRECTORY) +file(MAKE_DIRECTORY "${KERNEL_LL_DIR}") + +execute_process( + COMMAND "${PTOAS_BIN}" + --pto-arch=a5 + --pto-backend=vpto + --enable-tile-op-expand + --vpto-emit-hivm-llvm + "${PTO_SRC}" + -o - + OUTPUT_FILE "${KERNEL_LL}" + ERROR_VARIABLE PTOAS_STDERR + RESULT_VARIABLE PTOAS_RESULT +) + +if(NOT PTOAS_RESULT EQUAL 0) + file(REMOVE "${KERNEL_LL}") + string(STRIP "${PTOAS_STDERR}" PTOAS_STDERR) + if(PTOAS_STDERR) + message(FATAL_ERROR "ptoas failed while generating ${KERNEL_LL}:\n${PTOAS_STDERR}") + endif() + message(FATAL_ERROR "ptoas failed while generating ${KERNEL_LL}") +endif() + +if(NOT EXISTS "${KERNEL_LL}") + message(FATAL_ERROR "ptoas completed without producing ${KERNEL_LL}") +endif() + +file(SIZE "${KERNEL_LL}" KERNEL_LL_SIZE) +if(KERNEL_LL_SIZE EQUAL 0) + file(REMOVE "${KERNEL_LL}") + string(STRIP "${PTOAS_STDERR}" PTOAS_STDERR) + if(PTOAS_STDERR) + message(FATAL_ERROR + "ptoas produced empty LLVM IR for ${PTO_SRC}:\n${PTOAS_STDERR}") + endif() + message(FATAL_ERROR "ptoas produced empty LLVM IR for ${PTO_SRC}") +endif() diff --git a/test/tilelang_st/script/run_st.py b/test/tilelang_st/script/run_st.py index f8c4d0f3b..3f16f7bb0 100755 --- a/test/tilelang_st/script/run_st.py +++ b/test/tilelang_st/script/run_st.py @@ -41,7 +41,7 @@ def find_ptoas_bin(): search_dir = os.path.dirname(os.path.abspath(__file__)) for _ in range(8): - candidate = os.path.join(search_dir, "build", "bin", "ptoas") + candidate = os.path.join(search_dir, "build", "tools", "ptoas", "ptoas") if os.path.isfile(candidate): return os.path.abspath(candidate) parent = os.path.dirname(search_dir) From fa4a4b5d2a556c607fb1d0652f98a44ba28ed407 Mon Sep 17 00:00:00 2001 From: qukelin Date: Sat, 11 Apr 2026 12:54:03 +0800 Subject: [PATCH 048/192] Fold TileLang dead branches after intrinsic folding --- tools/ptoas/ptoas.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index bacdd6f52..176a0668e 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1129,16 +1129,22 @@ static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { // 3. InlineLibCall: inline template function bodies // 4. FoldTileBufIntrinsics: fold tile_buf_addr / tile_valid_rows / // tile_valid_cols to concrete memref/constant values - backendPM.addPass(pto::createMemrefToTileBufPass()); + pm.addPass(pto::createMemrefToTileBufPass()); pto::ExpandTileOpOptions expandOpts; expandOpts.tilelangPath = tilelangPath; expandOpts.tilelangPkgPath = tilelangPkgPath; - backendPM.addPass(pto::createExpandTileOpPass(expandOpts)); - - backendPM.addPass(pto::createPTOInlineLibCallPass()); - backendPM.addNestedPass( + pm.addPass(pto::createExpandTileOpPass(expandOpts)); + + pm.addPass(pto::createPTOInlineLibCallPass()); + pm.addNestedPass( pto::createFoldTileBufIntrinsicsPass()); + // FoldTileBufIntrinsics materializes many constant branch conditions. + // Clean them up immediately on the TileOp expansion path before the + // authoring-stage VPTO verifier and let the existing CSE passes remove the + // resulting dead values later in the pipeline. + pm.addPass(mlir::createSCCPPass()); + pm.addPass(mlir::createCanonicalizerPass()); } if (failed(backendPM.run(module))) { llvm::errs() << "Error: backend lowering pass execution failed.\n"; From 22fc03bd97f86a74ba723784ecbb6d7678f6465c Mon Sep 17 00:00:00 2001 From: qukelin Date: Sat, 11 Apr 2026 18:15:38 +0800 Subject: [PATCH 049/192] Fix tilelang ST tadd simulator flow --- .../npu/a5/src/st/testcase/CMakeLists.txt | 48 ++++- .../src/st/testcase/repack_tilelang_kernel.sh | 198 ++++++++++++++++++ .../src/st/testcase/run_ptoas_to_file.cmake | 1 + .../npu/a5/src/st/testcase/tadd/compare.py | 32 ++- .../npu/a5/src/st/testcase/tadd/launch.cpp | 14 +- .../npu/a5/src/st/testcase/tadd/main.cpp | 12 +- .../npu/a5/src/st/testcase/tadd/tadd.pto | 3 +- test/tilelang_st/script/run_st.py | 28 ++- 8 files changed, 296 insertions(+), 40 deletions(-) create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/repack_tilelang_kernel.sh diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 887c5e713..a1116f0b4 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -12,8 +12,9 @@ # CMake macro for TileLang ST test cases. Unlike pto-isa's pto_vec_st() # which compiles a hand-written kernel.cpp with -xcce, this macro: # 1. Runs ptoas to compile .pto → kernel.ll (LLVM IR) -# 2. Runs bisheng -x ir kernel.ll → kernel.o (object file) -# 3. Compiles launch.cpp with -xcce and links kernel.o → shared library +# 2. Runs bisheng -x ir kernel.ll → device .o +# 3. Repacks the device .o into a host-linkable fatobj object +# 4. Links the repacked object → shared library # 4. Builds host executable from main.cpp (no GTest — comparison via compare.py) # -------------------------------------------------------------------------- set(PTO_TILELANG_ST_TESTCASE_DIR ${CMAKE_CURRENT_LIST_DIR}) @@ -36,24 +37,49 @@ function(pto_tilelang_vec_st NAME) VERBATIM ) - # Step 2: bisheng kernel.ll → kernel.o - set(KERNEL_OBJ ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_kernel.o) + # Step 2: bisheng kernel.ll → device .o + set(KERNEL_DEVICE_OBJ ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_kernel_device.o) add_custom_command( - OUTPUT ${KERNEL_OBJ} + OUTPUT ${KERNEL_DEVICE_OBJ} COMMAND bisheng --target=hiipu64-hisilicon-cce -march=dav-c310-vec --cce-aicore-arch=dav-c310-vec --cce-aicore-only + -O2 -c -x ir ${KERNEL_LL} - -o ${KERNEL_OBJ} + -o ${KERNEL_DEVICE_OBJ} DEPENDS ${KERNEL_LL} - COMMENT "bisheng: ${NAME}_kernel.ll -> ${NAME}_kernel.o" + COMMENT "bisheng: ${NAME}_kernel.ll -> ${NAME}_kernel_device.o" + VERBATIM + ) + + # Step 3: repack device .o into a host-linkable fatobj object using a + # generated stub source derived from launch.cpp's kernel declarations. + set(REPACK_SCRIPT ${PTO_TILELANG_ST_TESTCASE_DIR}/repack_tilelang_kernel.sh) + set(KERNEL_LAUNCH_SRC ${CMAKE_CURRENT_SOURCE_DIR}/launch.cpp) + set(KERNEL_REPACK_OBJ ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_kernel_repack.o) + string(MD5 KERNEL_MODULE_ID_HASH "${NAME}") + string(SUBSTRING "${KERNEL_MODULE_ID_HASH}" 0 16 KERNEL_MODULE_ID) + add_custom_command( + OUTPUT ${KERNEL_REPACK_OBJ} + COMMAND bash ${REPACK_SCRIPT} + ${ASCEND_HOME_PATH} + ${PTO_ISA_ROOT} + dav-c310-vec + ${KERNEL_LAUNCH_SRC} + ${KERNEL_DEVICE_OBJ} + ${KERNEL_REPACK_OBJ} + ${KERNEL_MODULE_ID} + DEPENDS ${KERNEL_LAUNCH_SRC} ${KERNEL_DEVICE_OBJ} ${REPACK_SCRIPT} + COMMENT "repack: ${NAME}_kernel_device.o -> ${NAME}_kernel_repack.o" + VERBATIM ) - # Step 3: launch.cpp (-xcce) + kernel.o → shared library - add_library(${NAME}_kernel SHARED launch.cpp ${KERNEL_OBJ}) - set_source_files_properties(${KERNEL_OBJ} PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE) + # Step 4: launch.cpp (-xcce) + repack object → shared library + add_library(${NAME}_kernel SHARED launch.cpp ${KERNEL_REPACK_OBJ}) + set_source_files_properties(${KERNEL_REPACK_OBJ} + PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE) target_compile_options(${NAME}_kernel PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c310-vec -std=c++17) target_include_directories(${NAME}_kernel PRIVATE @@ -63,7 +89,7 @@ function(pto_tilelang_vec_st NAME) ) target_link_options(${NAME}_kernel PRIVATE --cce-fatobj-link) - # Step 4: main.cpp → host executable + # Step 5: main.cpp → host executable add_executable(${NAME} main.cpp) target_compile_options(${NAME} PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(${NAME} PRIVATE diff --git a/test/tilelang_st/npu/a5/src/st/testcase/repack_tilelang_kernel.sh b/test/tilelang_st/npu/a5/src/st/testcase/repack_tilelang_kernel.sh new file mode 100644 index 000000000..6a3cd4d7f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/repack_tilelang_kernel.sh @@ -0,0 +1,198 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [[ $# -ne 7 ]]; then + echo "usage: $0 " >&2 + exit 1 +fi + +ASCEND_HOME_PATH="$1" +PTO_ISA_ROOT="$2" +AICORE_ARCH="$3" +KERNEL_STUB_SRC="$4" +DEVICE_OBJ="$5" +OUTPUT_OBJ="$6" +MODULE_ID="$7" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../../../../../../../" && pwd)" + +BISHENG_BIN="${BISHENG_BIN:-${ASCEND_HOME_PATH}/bin/bisheng}" +BISHENG_CC1_BIN="${BISHENG_CC1_BIN:-${ASCEND_HOME_PATH}/tools/bisheng_compiler/bin/bisheng}" +CCE_LD_BIN="${CCE_LD_BIN:-${ASCEND_HOME_PATH}/bin/cce-ld}" +LD_LLD_BIN="${LD_LLD_BIN:-${ASCEND_HOME_PATH}/bin/ld.lld}" +CLANG_RESOURCE_DIR="${CLANG_RESOURCE_DIR:-${ASCEND_HOME_PATH}/tools/bisheng_compiler/lib/clang/15.0.5}" +CCE_STUB_DIR="${CCE_STUB_DIR:-${CLANG_RESOURCE_DIR}/include/cce_stub}" + +HOST_ARCH="$(uname -m)" +HOST_TRIPLE="" +HOST_TARGET_CPU="" +HOST_TARGET_ABI="" +HOST_FEATURE_FLAGS=() + +case "${HOST_ARCH}" in + aarch64) + HOST_TRIPLE="aarch64-unknown-linux-gnu" + HOST_TARGET_CPU="generic" + HOST_TARGET_ABI="aapcs" + HOST_FEATURE_FLAGS=(-target-feature +neon -target-feature +v8a) + ;; + x86_64) + HOST_TRIPLE="x86_64-unknown-linux-gnu" + HOST_TARGET_CPU="x86-64" + ;; + *) + echo "unsupported host arch from uname -m: ${HOST_ARCH}" >&2 + exit 1 + ;; +esac + +for required in "${BISHENG_BIN}" "${BISHENG_CC1_BIN}" "${CCE_LD_BIN}" "${LD_LLD_BIN}"; do + if [[ ! -x "${required}" ]]; then + echo "missing required tool: ${required}" >&2 + exit 1 + fi +done + +readarray -t BISHENG_SYSTEM_INCLUDES < <( + "${BISHENG_BIN}" -xc++ -E -v - &1 | + awk ' + /#include <...> search starts here:/ {capture=1; next} + /End of search list\./ {capture=0} + capture && $0 ~ /^ / {sub(/^ +/, "", $0); print} + ' +) + +if [[ "${#BISHENG_SYSTEM_INCLUDES[@]}" -eq 0 ]]; then + echo "failed to discover bisheng system include directories" >&2 + exit 1 +fi + +CC1_INCLUDE_FLAGS=() +for inc in "${BISHENG_SYSTEM_INCLUDES[@]}"; do + if [[ "${inc}" == */include/c++/* || "${inc}" == */backward ]]; then + CC1_INCLUDE_FLAGS+=(-internal-isystem "${inc}") + elif [[ "${inc}" == "/usr/include" ]]; then + CC1_INCLUDE_FLAGS+=(-internal-externc-isystem "${inc}") + else + CC1_INCLUDE_FLAGS+=(-internal-isystem "${inc}") + fi +done + +OUTPUT_DIR="$(cd "$(dirname "${OUTPUT_OBJ}")" && pwd)" +BASE_NAME="$(basename "${OUTPUT_OBJ}" .o)" +HOST_STUB_OBJ="${OUTPUT_DIR}/${BASE_NAME}_host_from_llvm.o" +GENERATED_STUB_SRC="${OUTPUT_DIR}/${BASE_NAME}_stub.cpp" + +{ + cat <<'EOF' +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef AICORE +#define AICORE [aicore] +#endif +EOF + if ! sed -n '/__global__ AICORE void /s/;$/ {}/p' "${KERNEL_STUB_SRC}"; then + echo "failed to derive stub declarations from ${KERNEL_STUB_SRC}" >&2 + exit 1 + fi +} > "${GENERATED_STUB_SRC}" + +if ! grep -q "__global__ AICORE void" "${GENERATED_STUB_SRC}"; then + echo "no kernel declarations found in ${KERNEL_STUB_SRC}" >&2 + exit 1 +fi + +host_target_args=( + -triple "${HOST_TRIPLE}" + -target-cpu "${HOST_TARGET_CPU}" +) +if [[ -n "${HOST_TARGET_ABI}" ]]; then + host_target_args+=(-target-abi "${HOST_TARGET_ABI}") +fi +if [[ ${#HOST_FEATURE_FLAGS[@]} -gt 0 ]]; then + host_target_args+=("${HOST_FEATURE_FLAGS[@]}") +fi + +"${BISHENG_CC1_BIN}" -cc1 \ + "${host_target_args[@]}" \ + -fcce-aicpu-legacy-launch \ + -fcce-is-host \ + -cce-launch-with-flagv2-impl \ + -fcce-aicore-arch "${AICORE_ARCH}" \ + -fcce-fatobj-compile \ + -emit-obj \ + --mrelax-relocations \ + -disable-free \ + -clear-ast-before-backend \ + -disable-llvm-verifier \ + -discard-value-names \ + -main-file-name "$(basename "${KERNEL_STUB_SRC}")" \ + -mrelocation-model pic \ + -pic-level 2 \ + -fhalf-no-semantic-interposition \ + -fenable-matrix \ + -mllvm -enable-matrix \ + -mframe-pointer=non-leaf \ + -fmath-errno \ + -ffp-contract=on \ + -fno-rounding-math \ + -mconstructor-aliases \ + -funwind-tables=2 \ + -fallow-half-arguments-and-returns \ + -mllvm -treat-scalable-fixed-error-as-warning \ + -fcoverage-compilation-dir="${ROOT_DIR}" \ + -resource-dir "${CLANG_RESOURCE_DIR}" \ + -include __clang_cce_runtime_wrapper.h \ + -D _FORTIFY_SOURCE=2 \ + -D REGISTER_BASE \ + -I "${PTO_ISA_ROOT}/include" \ + -I "${ASCEND_HOME_PATH}/include" \ + -I "${ASCEND_HOME_PATH}/pkg_inc" \ + -I "${ASCEND_HOME_PATH}/pkg_inc/profiling" \ + -I "${ASCEND_HOME_PATH}/pkg_inc/runtime/runtime" \ + "${CC1_INCLUDE_FLAGS[@]}" \ + -O2 \ + -Wno-macro-redefined \ + -Wno-ignored-attributes \ + -std=c++17 \ + -fdeprecated-macro \ + -fdebug-compilation-dir="${ROOT_DIR}" \ + -ferror-limit 19 \ + -stack-protector 2 \ + -fno-signed-char \ + -fgnuc-version=4.2.1 \ + -fcxx-exceptions \ + -fexceptions \ + -vectorize-loops \ + -vectorize-slp \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + -fcce-include-aibinary "${DEVICE_OBJ}" \ + -fcce-device-module-id "${MODULE_ID}" \ + -target-feature +outline-atomics \ + -faddrsig \ + -D__GCC_HAVE_DWARF2_CFI_ASM=1 \ + -o "${HOST_STUB_OBJ}" \ + -x cce "${GENERATED_STUB_SRC}" + +"${CCE_LD_BIN}" \ + "${LD_LLD_BIN}" \ + -x \ + -cce-lite-bin-module-id "${MODULE_ID}" \ + -cce-aicore-arch="${AICORE_ARCH}" \ + -r \ + -o "${OUTPUT_OBJ}" \ + -cce-stub-dir "${CCE_STUB_DIR}" \ + -cce-install-dir "$(dirname "${BISHENG_CC1_BIN}")" \ + -cce-inputs-number 1 \ + "${HOST_STUB_OBJ}" diff --git a/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake b/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake index 037ef9f00..8922fbab8 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake +++ b/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake @@ -17,6 +17,7 @@ execute_process( COMMAND "${PTOAS_BIN}" --pto-arch=a5 --pto-backend=vpto + --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm "${PTO_SRC}" diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py index 360b009fb..56d41377c 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py @@ -13,6 +13,10 @@ import os import numpy as np +ANSI_RESET = "\033[0m" +ANSI_BOLD_GREEN = "\033[1;32m" +ANSI_BOLD_RED = "\033[1;31m" + CASES = [ {"name": "f32_16x64", "dtype": np.float32, "eps": 1e-6}, @@ -20,19 +24,35 @@ ] +def supports_color(): + return sys.stdout.isatty() and os.environ.get("TERM") not in (None, "", "dumb") + + +def style_pass(text): + if not supports_color(): + return text + return f"{ANSI_BOLD_GREEN}{text}{ANSI_RESET}" + + +def style_fail(text): + if not supports_color(): + return text + return f"{ANSI_BOLD_RED}{text}{ANSI_RESET}" + + def compare_bin(golden_path, output_path, dtype, eps): golden = np.fromfile(golden_path, dtype=dtype) output = np.fromfile(output_path, dtype=dtype) if golden.shape != output.shape: - print(f"[ERROR] Shape mismatch: golden {golden.shape} vs output {output.shape}") + print(style_fail(f"[ERROR] Shape mismatch: golden {golden.shape} vs output {output.shape}")) return False if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): g = golden.astype(np.float64, copy=False) o = output.astype(np.float64, copy=False) abs_diff = np.abs(g - o) idx = int(np.argmax(abs_diff)) - print(f"[ERROR] Mismatch: max diff={float(abs_diff[idx])} at idx={idx} " - f"(golden={g[idx]}, output={o[idx]})") + print(style_fail(f"[ERROR] Mismatch: max diff={float(abs_diff[idx])} at idx={idx} " + f"(golden={g[idx]}, output={o[idx]})")) return False return True @@ -50,11 +70,11 @@ def compare_bin(golden_path, output_path, dtype, eps): output_path = os.path.join(case_dir, "output.bin") ok = compare_bin(golden_path, output_path, case["dtype"], case["eps"]) if ok: - print(f"[INFO] {case['name']}: compare passed") + print(style_pass(f"[INFO] {case['name']}: compare passed")) else: - print(f"[ERROR] {case['name']}: compare failed") + print(style_fail(f"[ERROR] {case['name']}: compare failed")) all_passed = False if not all_passed: sys.exit(2) - print("[INFO] all cases passed") + print(style_pass("[INFO] all cases passed")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp index ba35a01e6..f1074c838 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp @@ -6,27 +6,21 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -#ifndef __VEC_SCOPE__ -#define __VEC_SCOPE__ -#endif - #include -#include -#include -#ifndef __CPU_SIM -#include "acl/acl.h" +#ifndef AICORE +#define AICORE [aicore] #endif // Case 0: f32 16x64 -__global__ AICORE void TADD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); +extern "C" __global__ AICORE void TADD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream) { TADD_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } // Case 1: f32 32x32 -__global__ AICORE void TADD_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); +extern "C" __global__ AICORE void TADD_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream) { TADD_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp index be03ef998..3c1703d8e 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp @@ -75,8 +75,16 @@ static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { ACL_CHECK(aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); ACL_CHECK(aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); - ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, fileSize); - ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, fileSize); + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + return 1; + } + if (!ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + return 1; + } ACL_CHECK(aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto index efe79cf4a..340e416c3 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto @@ -8,7 +8,8 @@ // TileLang ST kernels for pto.tadd: tload(a) + tload(b) + tadd(a,b)->c + tstore(c). // Multiple cases with different shapes in a single module. -// Compiled by ptoas --enable-tile-op-expand --vpto-emit-hivm-llvm to produce LLVM IR. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. module { // Case 0: f32 16x64 (1024 elements) diff --git a/test/tilelang_st/script/run_st.py b/test/tilelang_st/script/run_st.py index 3f16f7bb0..8ab4b5979 100755 --- a/test/tilelang_st/script/run_st.py +++ b/test/tilelang_st/script/run_st.py @@ -94,6 +94,10 @@ def set_env_variables(run_mode, soc_version): ) +def get_testcase_work_dir(testcase): + return os.path.join("build", "testcase", testcase) + + def build_project(run_mode, soc_version, testcase, ptoas_bin): build_dir = "build" if os.path.exists(build_dir): @@ -135,11 +139,13 @@ def build_project(run_mode, soc_version, testcase, ptoas_bin): raise -def run_gen_data(golden_path): +def run_gen_data(golden_path, testcase): original_dir = os.getcwd() try: - run_command(["cp", golden_path, "build/gen_data.py"]) - os.chdir("build/") + work_dir = get_testcase_work_dir(testcase) + os.makedirs(work_dir, exist_ok=True) + run_command(["cp", golden_path, os.path.join(work_dir, "gen_data.py")]) + os.chdir(work_dir) run_command([sys.executable, "gen_data.py"]) except Exception as e: print(f"gen golden failed: {e}") @@ -151,8 +157,8 @@ def run_gen_data(golden_path): def run_binary(testcase, case_filter=None): original_dir = os.getcwd() try: - os.chdir("build/bin/") - cmd = ["./" + testcase] + os.chdir(get_testcase_work_dir(testcase)) + cmd = [os.path.join("..", "..", "bin", testcase)] if case_filter: cmd.append(case_filter) run_command(cmd) @@ -163,11 +169,13 @@ def run_binary(testcase, case_filter=None): os.chdir(original_dir) -def run_compare(compare_path, case_filter=None): +def run_compare(compare_path, testcase, case_filter=None): original_dir = os.getcwd() try: - run_command(["cp", compare_path, "build/compare.py"]) - os.chdir("build/") + work_dir = get_testcase_work_dir(testcase) + os.makedirs(work_dir, exist_ok=True) + run_command(["cp", compare_path, os.path.join(work_dir, "compare.py")]) + os.chdir(work_dir) cmd = [sys.executable, "compare.py"] if case_filter: cmd.append(case_filter) @@ -233,12 +241,12 @@ def main(): # gen golden → run binary → compare golden_path = f"testcase/{testcase}/gen_data.py" - run_gen_data(golden_path) + run_gen_data(golden_path, testcase) run_binary(testcase, args.case) compare_path = f"testcase/{testcase}/compare.py" - run_compare(compare_path, args.case) + run_compare(compare_path, testcase, args.case) except Exception as e: print(f"run failed: {str(e)}", file=sys.stderr) From 6b40947ba09e5789e3d64e79f2a1c71101d6a5fa Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Mon, 13 Apr 2026 19:59:40 +0800 Subject: [PATCH 050/192] fixup ptr normalize --- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 6 -- lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp | 104 ------------------- lib/PTO/Transforms/VPTOPtrNormalize.cpp | 75 +++++++------ 3 files changed, 44 insertions(+), 141 deletions(-) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 16c6a8bd7..1aa877ef6 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -27,8 +27,6 @@ namespace mlir::pto { void materializeVecScopeCarrierLoops(ModuleOp module); -LogicalResult normalizePtoMemRefSpaces(ModuleOp module, - llvm::raw_ostream &diagOS); LogicalResult applyQueriedTargetAttrs(ModuleOp module, const VPTOEmissionOptions &options, llvm::raw_ostream &diagOS); @@ -4855,10 +4853,6 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, materializeVecScopeCarrierLoops(clonedModule); - if (failed(normalizePtoMemRefSpaces(clonedModule, diagOS))) { - diagOS << "VPTO LLVM emission failed: normalizePtoMemRefSpaces failed\n"; - return failure(); - } if (failed(lowerVPTOOps(clonedModule, diagOS))) { diagOS << "VPTO LLVM emission failed: lowerVPTOOps failed\n"; return failure(); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp index 931f4966f..914ce0b53 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp @@ -370,110 +370,6 @@ queryDefaultTargetAttrs(const VPTOEmissionOptions &options, } // namespace -LogicalResult normalizePtoMemRefSpaces(ModuleOp module, - llvm::raw_ostream &diagOS) { - MLIRContext *context = module.getContext(); - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addConversion([&](MemRefType type) -> Type { - auto addrSpace = dyn_cast_or_null(type.getMemorySpace()); - if (!addrSpace) - return type; - return MemRefType::get( - type.getShape(), type.getElementType(), type.getLayout(), - IntegerAttr::get(IntegerType::get(context, 64), - static_cast(addrSpace.getAddressSpace()))); - }); - typeConverter.addTypeAttributeConversion( - [](MemRefType, pto::AddressSpaceAttr attr) -> Attribute { - return IntegerAttr::get(IntegerType::get(attr.getContext(), 64), - static_cast(attr.getAddressSpace())); - }); - auto materializeMemRefCast = [](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1) - return {}; - return builder - .create(loc, TypeRange{resultType}, inputs) - .getResult(0); - }; - typeConverter.addSourceMaterialization(materializeMemRefCast); - typeConverter.addTargetMaterialization(materializeMemRefCast); - typeConverter.addArgumentMaterialization(materializeMemRefCast); - - ConversionTarget target(*context); - target.addLegalOp(); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - target.addDynamicallyLegalOp( - [&](func::CallOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](Operation *op) { - return isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter); - }); - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); - - RewritePatternSet patterns(context); - scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, - target); - populateFunctionOpInterfaceTypeConversionPattern(patterns, - typeConverter); - populateCallOpTypeConversionPattern(patterns, typeConverter); - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); - populateReturnOpTypeConversionPattern(patterns, typeConverter); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - diagOS << "VPTO LLVM emission failed: memref address-space normalization " - "failed\n"; - return failure(); - } - - SmallVector castsToFold; - module.walk([&](UnrealizedConversionCastOp castOp) { - if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) - return; - if (!hasPtoMemRefMemorySpace(castOp->getOperandTypes()) && - !hasPtoMemRefMemorySpace(castOp->getResultTypes())) - return; - Type convertedResultType = - typeConverter.convertType(castOp.getResult(0).getType()); - if (convertedResultType && - convertedResultType == castOp.getOperand(0).getType()) - castsToFold.push_back(castOp); - }); - for (UnrealizedConversionCastOp castOp : castsToFold) { - castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); - castOp.erase(); - } - - WalkResult leftover = module.walk([&](Operation *op) { - if (hasPtoMemRefMemorySpace(op->getOperandTypes()) || - hasPtoMemRefMemorySpace(op->getResultTypes())) { - diagOS << "VPTO LLVM emission failed: residual PTO memref address space " - "on op " - << op->getName().getStringRef() << "\n"; - op->print(diagOS); - diagOS << "\n"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (leftover.wasInterrupted()) - return failure(); - return success(); -} - void materializeVecScopeCarrierLoops(ModuleOp module) { MLIRContext *ctx = module.getContext(); (void)ctx->getOrLoadDialect(); diff --git a/lib/PTO/Transforms/VPTOPtrNormalize.cpp b/lib/PTO/Transforms/VPTOPtrNormalize.cpp index 7230db089..cba0bf027 100644 --- a/lib/PTO/Transforms/VPTOPtrNormalize.cpp +++ b/lib/PTO/Transforms/VPTOPtrNormalize.cpp @@ -68,39 +68,16 @@ static bool hasPtrNormalizeConvertibleType(Type type) { if (isa(type)) return true; auto memrefType = dyn_cast(type); - return memrefType && - static_cast( - getPointerMemorySpace(memrefType.getMemorySpace(), type.getContext())); + return memrefType && static_cast(getPointerMemorySpace( + memrefType.getMemorySpace(), type.getContext())); } static bool hasPtrNormalizeConvertibleType(TypeRange types) { - return llvm::any_of(types, [](Type type) { - return hasPtrNormalizeConvertibleType(type); - }); + return llvm::any_of( + types, [](Type type) { return hasPtrNormalizeConvertibleType(type); }); } -static Value materializePtrCast(OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) { - if (inputs.size() != 1 || !isa(resultType)) - return {}; - - Value input = inputs.front(); - if (input.getType() == resultType) - return input; - - auto inputMemrefType = dyn_cast(input.getType()); - auto resultPtrType = dyn_cast(resultType); - if (!inputMemrefType || !resultPtrType) - return {}; - - auto memorySpace = getPointerMemorySpace(inputMemrefType.getMemorySpace(), - builder.getContext()); - if (!memorySpace || memorySpace != resultPtrType.getMemorySpace() || - inputMemrefType.getElementType() != resultPtrType.getElementType()) - return {}; - - return builder.create(loc, resultPtrType, input); -} +static bool isMemRefType(Type type) { return isa(type); } static Value materializeUnrealizedCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { @@ -215,6 +192,38 @@ struct ConvertPointerCastToCastPtrPattern } }; +struct ConvertCastPtrPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::CastPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getInput(); + Type inputType = input.getType(); + if (isMemRefType(inputType) || isMemRefType(convertedResultType)) + return rewriter.notifyMatchFailure(op, + "memref castptr must be eliminated"); + + if (!isa(inputType) || + !isa(convertedResultType)) + return rewriter.notifyMatchFailure(op, + "expected ptr/int castptr operands"); + + if (inputType == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + + rewriter.replaceOpWithNewOp(op, convertedResultType, input); + return success(); + } +}; + struct ConvertBindTileToPtrPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -359,8 +368,8 @@ struct VPTOPtrNormalizePass target.addLegalDialect(); target.addDynamicallyLegalDialect([](Operation *op) { - return !isa(op); + return !isa(op); }); target.addLegalOp(); target.addDynamicallyLegalOp([&](func::FuncOp op) { @@ -381,6 +390,10 @@ struct VPTOPtrNormalizePass return op.getResult().getType() == typeConverter.convertType(op.getResult().getType()); }); + target.addDynamicallyLegalOp([&](pto::CastPtrOp op) { + return !isMemRefType(op.getInput().getType()) && + !isMemRefType(op.getResult().getType()); + }); target.addDynamicallyLegalOp([&](pto::BindTileOp op) { return op.getResult().getType() == typeConverter.convertType(op.getResult().getType()); @@ -399,7 +412,7 @@ struct VPTOPtrNormalizePass populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); patterns.add Date: Mon, 13 Apr 2026 19:04:46 +0800 Subject: [PATCH 051/192] Refactor tilelang ST framework and document full execution flow - make cases.py the single source of truth for TileLang ST cases - simplify tilelang_st host common usage and reuse shared file helpers - expand the tileop-expand design doc with the full TileLang ST execution and run flow --- docs/designs/ptoas-tileop-expand-design.md | 338 +++++++++ docs/designs/tilelang-st-framework.md | 665 ++++++++++++------ test/tilelang_st/npu/a5/src/st/CMakeLists.txt | 4 +- .../npu/a5/src/st/testcase/CMakeLists.txt | 2 +- .../npu/a5/src/st/testcase/compare.py | 16 + .../npu/a5/src/st/testcase/st_common.py | 147 ++++ .../npu/a5/src/st/testcase/tadd/cases.py | 41 ++ .../npu/a5/src/st/testcase/tadd/compare.py | 80 --- .../npu/a5/src/st/testcase/tadd/gen_data.py | 32 +- .../npu/a5/src/st/testcase/tadd/main.cpp | 117 ++- test/tilelang_st/script/run_st.py | 34 +- 11 files changed, 1073 insertions(+), 403 deletions(-) create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/st_common.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py diff --git a/docs/designs/ptoas-tileop-expand-design.md b/docs/designs/ptoas-tileop-expand-design.md index 8e9967ed6..10c0e8e95 100644 --- a/docs/designs/ptoas-tileop-expand-design.md +++ b/docs/designs/ptoas-tileop-expand-design.md @@ -993,3 +993,341 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - Bisheng 设备侧编译校验。 - 融合场景测试(多个 Tile op 连续使用后的 VF Fusion) - 更新 `PTO_IR_manual.md` 和 TileLang DSL Guide + +#### 4.4.1 ST 精度验证 + +IR 回归测试只能验证"模板展开后 IR 长什么样",无法回答"最终在 simulator / NPU 上跑出来的数值是否正确"。 +`test/tilelang_st` 框架提供了端到端精度验证能力,详细设计参见 [`tilelang-st-framework.md`](tilelang-st-framework.md)。 + +本节面向库开发者,说明在完成一个新 TileLang 库实现(如 `lib/TileOps/_template.py`)后,如何接入 ST 框架验证精度。 + +##### 完整执行链路概览 + +ST 框架的统一入口是: + +```bash +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd +``` + +它不是只做“编译 `.pto`”,而是把编译、生成输入、运行二进制和精度比较串成一条完整流水线: + +```text +run_st.py + ├─ set_env_variables() + │ └─ 配置 simulator / NPU 运行环境 + ├─ build_project() + │ ├─ cmake -DRUN_MODE=... -DSOC_VERSION=... -DTEST_CASE=... -DPTOAS_BIN=... + │ ├─ ptoas: .pto -> _kernel.ll + │ │ flags: + │ │ --pto-arch=a5 + │ │ --pto-backend=vpto + │ │ --enable-insert-sync + │ │ --enable-tile-op-expand + │ │ --vpto-emit-hivm-llvm + │ ├─ bisheng -x ir: _kernel.ll -> _kernel_device.o + │ ├─ repack_tilelang_kernel.sh: + │ │ _kernel_device.o -> _kernel_repack.o + │ ├─ bisheng -xcce: launch.cpp + _kernel_repack.o -> lib_kernel.so + │ └─ bisheng -xc++: main.cpp -> + ├─ run_gen_data() + │ └─ 在 build/testcase// 下生成每个 case 的 input/golden + ├─ run_binary() + │ └─ 在 build/testcase// 下执行 ../../bin/ [case] + └─ run_compare() + └─ 在 build/testcase// 下逐 case 比较 golden/output +``` + +其中编译子链可以单独理解为: + +```text +.pto + ──ptoas──> _kernel.ll (LLVM IR) + ──bisheng -x ir──> _kernel_device.o (device-only 对象) + ──repack_tilelang_kernel.sh──> _kernel_repack.o + (host-linkable fatobj) + ──bisheng -xcce launch.cpp + repack.o──> lib_kernel.so + (共享库) + ──bisheng -xc++ main.cpp + .so──> (host 可执行文件) +``` + +其中 repack 步骤是 TileLang ST 与 pto-isa ST 的核心区别:`ptoas + bisheng -x ir` 产出的 +`*_kernel_device.o` 是 device-only 对象,不能直接作为 host 侧链接输入。repack 脚本从 +`launch.cpp` 中抽取 kernel 声明生成 stub,通过 `-fcce-include-aibinary` 嵌入 device binary, +产出 host 可链接的 fatobj。 + +运行阶段同样是 ST 框架的一部分,而不是“编译完以后开发者手工处理”的额外步骤: + +- `gen_data.py` 会基于 `cases.py` 中的 `CASES` 为每个 case 生成 `input*.bin` 和 `golden.bin` +- host 可执行文件会按 `main.cpp` 中的 case table 逐个读取 `.//input*.bin`,运行 kernel,并写回 `.//output.bin` +- `compare.py` 再基于同一份 `CASES` 定义逐 case 做 `numpy.allclose` 比较 +- 若传入 `-c `,则运行和比较都只针对单个 case + +因此,TileLang ST 的验证对象不是“某一份中间 IR 是否长得对”,而是: + +1. TileLang 模板是否成功展开并编译到可执行产物; +2. 生成的数据、运行时读取的 case 目录、以及 compare 使用的 golden/output 是否保持一致; +3. 最终 simulator / NPU 上的数值结果是否正确。 + +编译子链由 `testcase/CMakeLists.txt` 中的 `pto_tilelang_vec_st()` 宏自动接管,整条执行链路则由 +`run_st.py` 统一调度。 + +##### 新增 testcase 所需文件(七件套) + +以新增 `pto.tsub` 为例,需在 `test/tilelang_st/npu/a5/src/st/testcase/tsub/` 下准备 +6 个文件,并修改 1 个注册文件: + +**1. `CMakeLists.txt`** — 通常只有一行: + +```cmake +pto_tilelang_vec_st(tsub) +``` + +宏自动查找同目录下的 `tsub.pto`、`launch.cpp`、`main.cpp`,串联上述五步编译。 + +**2. `cases.py`** — **case 定义的单一来源**,`gen_data.py` 和 `compare.py` 均从此导入: + +```python +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, +] +``` + +每个 case 必须包含 `name`/`dtype`/`shape`/`valid_shape`/`eps` 五个字段,`valid_shape` 为必填。 + +**3. `tsub.pto`** — kernel 描述,一个文件中放多个 case 对应的函数。每个函数 +代表一种 dtype/shape 组合。以 tadd 为参考,kernel 结构为: + +```mlir +module { + // Case: f32 16x64 + func.func @TSUB_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, + %c_ptr: !pto.ptr) { + // 1. make_tensor_view: 从 !pto.ptr 构造 5D tensor_view (1×1×1×rows×cols) + // 2. partition_view: 提取 tile 区域 + // 3. alloc_tile: 分配 UB 上的 tile_buf + // 4. tload: 从 GM 加载到 UB + // 5. pto.tsub: 执行计算 + // 6. tstore: 从 UB 写回 GM + return + } + // Case: f32 32x32 + func.func @TSUB_f32_32x32(...) { ... } +} +``` + +函数命名约定:`__x`,例如 `TSUB_f32_16x64`、`TSUB_bf16_32x32`。 + +注意:`.pto` 中 `make_tensor_view` 的 shape 维度是 5D(`1×1×1×rows×cols`),strides 需要 +与 shape 一致(最内维 stride=1,逐维累乘)。函数参数顺序决定了后续所有文件的参数顺序。 + +**4. `launch.cpp`** — 为每个 kernel 声明 entry 和 launch wrapper: + +```cpp +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TSUB_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream) { + TSUB_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} +``` + +关键约束: +- `extern "C" __global__ AICORE void ...` 这一声明形态不可改变,repack 脚本用 sed 从中抽取 stub +- kernel 参数类型和顺序必须与 `.pto` 中函数签名一致 +- `<<<1, nullptr, stream>>>` 表示单核启动 + +**5. `main.cpp`** — host driver,核心是 case table 和 `RunCase()` 函数: + +```cpp +#include "acl/acl.h" +#include "test_common.h" // PtoTestCommon::ReadFile / WriteFile + ACL_CHECK + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; // 对应 cases.py 中的 name 和运行时子目录 + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTSUB_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTSUB_f32_32x32, 32, 32, 32, 32, sizeof(float)}, +}; +``` + +注意:`ACL_CHECK` 宏由公共头 `test_common.h` 提供(需在 `acl/acl.h` 之后包含),无需在每个 testcase 中重复定义。 + +`RunCase()` 的职责: +1. 从 `.//input*.bin` 读取输入到 host 内存 +2. `aclrtMemcpy` 拷贝到 device +3. 调用 `tc.launch(...)` 启动 kernel +4. `aclrtSynchronizeStream` 等待完成 +5. 拷贝结果回 host +6. 写 `.//output.bin` + +`main()` 支持可选 `argv[1]` 作为 case filter,实现单 case 执行。 + +**6. `gen_data.py`** — 生成每个 case 的输入和 golden,从 `cases.py` 导入 `CASES`: + +```python +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) # per-case seed,新增 case 不影响已有数据 + dtype, shape = case["dtype"], case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] - input2[:vr, :vc]).astype(dtype, copy=False) # tsub: 减法 + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) +``` + +注意 golden 的计算逻辑必须与 op 语义一致(tadd 是加法,tsub 是减法),且只在 `valid_shape` 区域内计算。 + +`compare.py` 为公共脚本(位于 `testcase/compare.py`),所有 testcase 共享,无需 per-testcase 编写。 +`run_st.py` 运行时将它与 per-testcase 的 `cases.py` 一起拷贝到 build 目录,自动读取 case 列表和阈值, +只比较 `valid_shape` 区域。exit code 2 表示失败。 + +精度阈值参考: + +| dtype | 建议 eps | +|---|---| +| `float32` | `1e-6` | +| `float16` | `1e-3` | +| `bfloat16` | `1e-2` | +| `int8/int16/int32` | `0`(精确匹配) | + +**7. 注册** — 修改 `testcase/CMakeLists.txt`,将新 op 加入 `ALL_TESTCASES`: + +```cmake +set(ALL_TESTCASES + tadd + tsub # ← 新增 +) +``` + +##### 文件间一致性约束 + +新增 testcase 时最容易出错的是以下几处必须严格一致: + +| 约束 | 涉及文件 | 示例 | +|---|---|---| +| kernel 函数名 | `.pto` ↔ `launch.cpp` | `@TSUB_f32_16x64` ↔ `TSUB_f32_16x64` | +| Launch wrapper 名 | `launch.cpp` ↔ `main.cpp` | `LaunchTSUB_f32_16x64` | +| case 名 | `cases.py` ↔ `main.cpp` kCases[] ↔ 运行时目录 | `f32_16x64` | +| 参数顺序 | `.pto` → `launch.cpp` → `main.cpp` 的 launch 调用 | `(a, b) → c` | +| shape / valid_shape | `cases.py` ↔ `.pto` tile shape ↔ `main.cpp` rows/cols/validRows/validCols | `16×64` / `(16, 64)` | + +Python 侧的 case 名、dtype、shape、valid_shape、eps 已通过 `cases.py` 收敛为单一来源。 +但 C++ 侧 `main.cpp` 的 `kCases[]` 和 `.pto` 仍需手动与 `cases.py` 保持一致。 + +任何一处不一致都可能导致:编译成功但运行时 segfault,或运行成功但比较结果错误且难以定位。 + +##### 运行方式 + +统一入口为 `test/tilelang_st/script/run_st.py`。前置条件: +- `ptoas` 已编译(默认路径 `build/tools/ptoas/ptoas`,也可通过 `-p` 指定或 `PTOAS_BIN` 环境变量) +- `ASCEND_HOME_PATH` 已设置 +- 建议先执行 `source scripts/ptoas_env.sh` + +```bash +# simulator 上跑 tsub 全部 case +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub + +# NPU 上跑 tsub 全部 case +python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tsub + +# 只跑单个 case +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub -c f32_16x64 + +# 复用已有 build,跳过重新编译(只重新生成数据、执行、比较) +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub -c f32_16x64 -w +``` + +`run_st.py` 执行顺序:`set_env_variables()` → `build_project()` → `run_gen_data()` → +`run_binary()` → `run_compare()`。产物输出到 +`test/tilelang_st/npu/a5/src/st/build/testcase//` 下: + +```text +build/testcase/tsub/ +├── st_common.py # 从 testcase/ 公共目录拷贝 +├── compare.py # 从 testcase/ 公共目录拷贝 +├── cases.py # 从 testcase/tsub/ 拷贝 +├── gen_data.py # 从 testcase/tsub/ 拷贝 +├── f32_16x64/ +│ ├── input1.bin +│ ├── input2.bin +│ ├── golden.bin +│ └── output.bin +└── f32_32x32/ + └── ... +``` + +##### 建议的开发验证节奏 + +1. **最小 case 先行**:先写一个最小 case(如 `f32_16x64`),在 simulator 上跑通: + ```bash + python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub -c f32_16x64 + ``` + +2. **快速迭代**:修改 `.pto` 或 host 代码后,用 `-w` 跳过 cmake/make 重编译。 + 注意:如果改了 `.pto` 本身,仍需重新编译(不加 `-w`),`-w` 只适合改 `gen_data.py` / + `compare.py` / `main.cpp` 中非编译相关逻辑的情况。 + +3. **扩充 case**:单 case 稳定后,补充更多 shape / dtype 组合。建议覆盖: + - 不同 dtype(f32 / f16 / bf16) + - 不同 tile 形状(正方形、长条形) + - 边界情况(valid 行列不是整 tile 的场景) + +4. **全量验证**:跑全量 case 确认无回归。 + +5. **NPU 验证**:切到 `-r npu` 在真实硬件上验证。simulator 和 NPU 的行为可能存在差异。 + +##### 调试建议 + +| 阶段 | 排查方向 | +|---|---| +| `ptoas` 编译失败 | 检查 `.pto` 语法、TileLang 模板是否匹配、是否缺少 `--enable-tile-op-expand` | +| `bisheng -x ir` 失败 | 检查 `build/testcase//_kernel.ll` 中的 LLVM IR | +| repack 失败 | 检查 `launch.cpp` 中的 kernel 声明是否符合 `extern "C" __global__ AICORE void` 格式 | +| 链接失败 | 检查共享库符号名一致性、ACL 运行时依赖 | +| kernel 执行失败 | 确认 `build/testcase///input*.bin` 是否已生成 | +| compare fail | 先检查 `output.bin` vs `golden.bin` 差异,再检查 `.pto` 语义和参数顺序 | + +##### 已有 testcase 下新增 case + +如果只是在已有 testcase(如 `tadd`)下新增一个 case(如 `f32_8x128`),只需同步修改 4 个文件: + +| 文件 | 修改内容 | +|---|---| +| `cases.py` | 在 `CASES` 中加入 `{"name": "f32_8x128", "dtype": np.float32, "shape": (8, 128), "valid_shape": (8, 128), "eps": 1e-6}` | +| `tadd.pto` | 新增 `func.func @TADD_f32_8x128(...)` 函数体 | +| `launch.cpp` | 新增 `extern "C"` kernel 声明和 `LaunchTADD_f32_8x128` wrapper | +| `main.cpp` | 在 `kCases[]` 中加入 `{"f32_8x128", LaunchTADD_f32_8x128, 8, 128, 8, 128, sizeof(float)}` | + +`gen_data.py` 和 `compare.py` 无需修改,自动从 `cases.py` 读取。 diff --git a/docs/designs/tilelang-st-framework.md b/docs/designs/tilelang-st-framework.md index 8a72a4b08..df411d981 100644 --- a/docs/designs/tilelang-st-framework.md +++ b/docs/designs/tilelang-st-framework.md @@ -1,332 +1,537 @@ # TileLang ST 精度验证框架 -## 概述 +## 1. 文档目标 -TileLang ST(System Test)框架用于在 Ascend NPU 硬件或仿真器上端到端验证 TileLang DSL 模板库生成的 kernel 精度。框架参考 pto-isa 的 ST 目录结构和运行流程,但针对 TileLang 的编译路径(`.pto → LLVM IR → .o`,而非 pto-isa 的 `.cpp → -xcce → .o`)做了适配。 +本文从 TileLang 库开发者的视角介绍当前 `test/tilelang_st` 框架的使用方式。 -### 与 pto-isa ST 的关键差异 +这份框架的目标不是做单纯的 IR 回归,而是回答下面两个更贴近开发的问题: -| | pto-isa ST | TileLang ST | -|--|-----------|-------------| -| kernel 源码 | 手写 C++(`kernel.cpp`) | PTO DSL(`tadd.pto`) | -| kernel 编译 | `bisheng -xcce kernel.cpp` | `ptoas .pto → .ll` + `bisheng -x ir .ll → .o` | -| 精度比较 | C++ `ResultCmp()`(GTest) | Python `np.allclose`(`compare.py`) | -| 多 case 支持 | 单文件多 GTest TEST_F | 单 `.pto` 多 kernel 函数 + case table | +1. 我新写的 TileLang 模板库实现,展开到 PTO / VPTO / LLVM IR 之后,最终在 simulator 或 NPU 上跑出来的数值是否正确。 +2. 如果我要为一个新 op 增加 ST 用例,最少需要准备哪些文件,运行链路会经过哪些阶段。 -### 执行流程 +当前框架已经具备下面这些能力: +- 从 `.pto` 直接驱动 `ptoas`,不需要手写 `kernel.cpp` +- 支持在一个 testcase 下放多个 case +- 支持 `sim` / `npu` 两种运行模式 +- 支持单 case 过滤 +- 支持把输入、golden、output 隔离到 `build/testcase//` 下,避免不同 testcase 之间互相覆盖 + +## 2. 框架定位 + +TileLang ST 参考了 `pto-isa` 的 ST 目录组织方式,但编译链路不同。 + +| 维度 | pto-isa ST | TileLang ST | +|---|---|---| +| kernel 来源 | 手写 `kernel.cpp` | 手写 `.pto`,由 `ptoas` 展开 TileLang DSL 模板 | +| 编译入口 | `bisheng -xcce kernel.cpp` | `ptoas .pto -> .ll`,再 `bisheng -x ir .ll -> device.o` | +| device 对象接入 host | 编译器一步直接生成 fatobj | 先产出 device-only `.o`,再 repack 成 host-linkable fatobj | +| 精度比较 | GTest / C++ 比较逻辑 | `compare.py` + `numpy.allclose` | +| 多 case 组织 | 多个 GTest case | 一个 testcase 下多个 kernel 函数 + host case table | + +换句话说,TileLang ST 更适合验证“库模板展开后的端到端运行正确性”,而不是验证某一段单独的 CCE kernel.cpp。 + +## 3. 当前执行流程 + +统一入口是: + +```bash +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd ``` + +完整链路如下: + +```text run_st.py - ├── set_env # 设置 ASCEND / simulator 环境变量 - ├── cmake + make # ptoas→.ll → bisheng→.o → link .so → build 可执行文件 - ├── gen_data.py # numpy 生成 input + golden(per-case 子目录) - ├── ./tadd [case] # 运行 kernel,写 output.bin(per-case 子目录) - └── compare.py # np.allclose 逐 case 比较 golden vs output + ├─ set_env_variables() + │ └─ 配置 simulator / NPU 运行环境 + ├─ build_project() + │ ├─ cmake -DRUN_MODE=... -DSOC_VERSION=... -DTEST_CASE=... -DPTOAS_BIN=... + │ ├─ ptoas: .pto -> _kernel.ll + │ │ flags: + │ │ --pto-arch=a5 + │ │ --pto-backend=vpto + │ │ --enable-insert-sync + │ │ --enable-tile-op-expand + │ │ --vpto-emit-hivm-llvm + │ ├─ bisheng -x ir: _kernel.ll -> _kernel_device.o + │ ├─ repack_tilelang_kernel.sh: + │ │ _kernel_device.o -> _kernel_repack.o + │ ├─ bisheng -xcce: launch.cpp + _kernel_repack.o -> lib_kernel.so + │ └─ bisheng -xc++: main.cpp -> + ├─ run_gen_data() + │ └─ 在 build/testcase// 下生成每个 case 的 input/golden + ├─ run_binary() + │ └─ 在 build/testcase// 下执行 ../../bin/ [case] + └─ run_compare() + └─ 在 build/testcase// 下逐 case 比较 golden/output ``` -## 目录结构 +### 3.1 关于 repack 步骤 -``` +这是当前 TileLang ST 相比 `pto-isa` 手写 kernel.cpp 路径最大的区别。 + +`ptoas + bisheng -x ir` 产出的 `*_kernel_device.o` 是 device-only 对象,不能直接作为 host 侧共享库链接输入。框架会调用 `test/tilelang_st/npu/a5/src/st/testcase/repack_tilelang_kernel.sh` 做两件事: + +1. 从 `launch.cpp` 中抽取 `extern "C" __global__ AICORE void ...` 声明,生成一个最小 stub +2. 通过 `-fcce-include-aibinary ` 把 device binary 嵌入这个 host stub,最终产出 host 可链接的 fatobj 对象 + +最终 `launch.cpp` 和 `*_kernel_repack.o` 一起链接成 `lib_kernel.so`。 + +如果没有这个 repack 步骤,host 可执行文件无法通过共享库把 LLVM IR 编出来的 device kernel 注册并发射出去。 + +### 3.2 关于 case 的执行和比较顺序 + +默认情况下: + +1. `gen_data.py` 会先为 testcase 下的所有 case 生成输入和 golden +2. `./bin/` 会依次跑完所有 case +3. `compare.py` 再依次比较所有 case 的 `golden.bin` 和 `output.bin` + +如果使用 `-c `,则运行和比较都会只针对这个 case。 + +## 4. 目录结构与职责 + +当前目录结构如下: + +```text test/tilelang_st/ ├── script/ -│ └── run_st.py # 统一入口脚本 +│ └── run_st.py └── npu/ - └── a5/ # SoC 架构 + └── a5/ └── src/st/ - ├── CMakeLists.txt # 顶层 CMake(编译器/环境配置) + ├── CMakeLists.txt └── testcase/ - ├── CMakeLists.txt # pto_tilelang_vec_st() 宏定义 + op 注册 - └── tadd/ # 每个 op 一个目录 - ├── CMakeLists.txt # 一行:pto_tilelang_vec_st(tadd) - ├── tadd.pto # kernel DSL(可包含多个函数) - ├── launch.cpp # kernel 声明 + launch wrapper - ├── main.cpp # host driver(case table 驱动) - ├── gen_data.py # 数据生成 - └── compare.py # 精度比较 + ├── CMakeLists.txt + ├── run_ptoas_to_file.cmake + ├── repack_tilelang_kernel.sh + ├── st_common.py + ├── compare.py + └── tadd/ + ├── CMakeLists.txt + ├── cases.py + ├── tadd.pto + ├── launch.cpp + ├── main.cpp + └── gen_data.py ``` -## 快速上手 +各文件职责如下: + +| 文件 | 职责 | +|---|---| +| `script/run_st.py` | 统一入口,负责编译、生成数据、执行二进制、比较结果 | +| `src/st/CMakeLists.txt` | 顶层 CMake,设置编译器、环境和依赖 | +| `testcase/CMakeLists.txt` | 定义 `pto_tilelang_vec_st()` 宏,并注册所有 testcase | +| `testcase/run_ptoas_to_file.cmake` | 封装 `ptoas` 调用,把 `.pto` 编译成 LLVM IR | +| `testcase/repack_tilelang_kernel.sh` | 把 device-only `.o` 包装成 host 可链接的 fatobj | +| `testcase/st_common.py` | 所有 testcase 共享的 Python 公共模块(case 校验、数据生成辅助、精度比较、终端着色) | +| `testcase/compare.py` | 公共比较脚本,所有 testcase 共享,从 per-testcase 的 `cases.py` 导入 `CASES` 后调用 `st_common.run_compare()` | +| `testcase//cases.py` | **case 定义的单一来源**,`gen_data.py` 和 `compare.py` 均从此导入 | +| `testcase//.pto` | testcase 的 kernel 描述,通常一个文件中放多个 case 对应的函数 | +| `testcase//launch.cpp` | kernel 声明和 launch wrapper | +| `testcase//main.cpp` | host driver,负责分配内存、launch kernel、回写 output(`ACL_CHECK` 宏由公共头 `test_common.h` 提供) | +| `testcase//gen_data.py` | 生成 input 与 golden,从 `cases.py` 读取 case 列表 | + +## 5. 日常使用方式 + +### 5.0 前置条件 -### 运行已有测试 +运行 TileLang ST 之前,建议先确认下面几件事: + +- 仓库里的 `ptoas` 已经编出来,默认路径是 `build/tools/ptoas/ptoas` +- `ASCEND_HOME_PATH` 已经设置正确 +- 如果需要手工跑 `ptoas`、`bisheng` 或 lit,优先先执行: ```bash -# 在 NPU 上跑 tadd 全部 case -python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd +source scripts/ptoas_env.sh +``` -# 在仿真器上跑 +`run_st.py` 会在运行时补充 simulator / NPU 相关环境,但它不会替你构建 `ptoas`。 + +### 5.1 运行已有 testcase + +```bash +# simulator 上跑 tadd 全部 case python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -# 只跑某个 case -python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd -c f32_16x64 +# NPU 上跑 tadd 全部 case +python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd + +# 只跑一个 case +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -c f32_16x64 + +# 复用已有 build 目录,不重新编译 +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -w +``` + +### 5.2 常用参数 + +| 参数 | 含义 | +|---|---| +| `-r, --run-mode` | 运行模式,`sim` 或 `npu` | +| `-v, --soc-version` | SoC 版本,目前只支持 `a5` | +| `-t, --testcase` | testcase 名称,对应 `testcase//` | +| `-c, --case` | 只运行一个 case | +| `-p, --ptoas-bin` | 指定 `ptoas` 路径 | +| `-w, --without-build` | 跳过构建,直接复用已有 `build/` | + +### 5.3 产物在哪 + +testcase 的运行时数据不再写到 `build/` 根目录,而是写到: -# 跳过编译(已有 build 产物时) -python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd -w +```text +test/tilelang_st/npu/a5/src/st/build/testcase// ``` -### run_st.py 参数 +以 `tadd` 为例: + +```text +build/testcase/tadd/ +├── gen_data.py +├── compare.py +├── f32_16x64/ +│ ├── input1.bin +│ ├── input2.bin +│ ├── golden.bin +│ └── output.bin +└── f32_32x32/ + ├── input1.bin + ├── input2.bin + ├── golden.bin + └── output.bin +``` -| 参数 | 说明 | 示例 | -|------|------|------| -| `-r, --run-mode` | 运行模式 | `npu` 或 `sim` | -| `-v, --soc-version` | 架构版本 | `a5` | -| `-t, --testcase` | op 名称 | `tadd` | -| `-c, --case` | 指定单个 case(可选) | `f32_16x64` | -| `-p, --ptoas-bin` | ptoas 路径(可选,默认自动查找) | `/path/to/ptoas` | -| `-w, --without-build` | 跳过编译 | — | +这个布局的好处是: -ptoas 路径查找顺序:`-p` 参数 → `PTOAS_BIN` 环境变量 → 从脚本位置向上遍历 `build/bin/ptoas`。 +- 不同 testcase 之间不会因为 case 同名而互相覆盖 +- 方便开发者直接进入 `build/testcase//` 复查输入、输出和 golden +- 使用 `-w` 时,不容易把旧 testcase 的残留数据误认为当前结果 -## 新增一个 op 测试 +### 5.4 比较输出 -以新增 `tsub` 为例。 +`compare.py` 会对 pass/fail 做明显提示: -### 第 1 步:创建目录和文件 +- pass:粗体绿色 +- fail:粗体红色 -```bash -mkdir test/tilelang_st/npu/a5/src/st/testcase/tsub +比较逻辑目前使用 `numpy.allclose`。建议阈值: + +| dtype | 建议 eps | +|---|---| +| `float32` | `1e-6` | +| `float16` | `1e-3` | +| `bfloat16` | `1e-2` | +| `int8/int16/int32` | `0` | + +## 6. 作为库开发者,如何增加一个新 op testcase + +这一节回答“我开发了一个新的 TileLang 库实现,怎么用 ST 框架验证它”。 + +以新增 `pto.tsub` 为例,最少需要准备下面这些文件: + +| 文件 | 是否新增/修改 | 说明 | +|---|---|---| +| `testcase/tsub/CMakeLists.txt` | 新增 | 一般只有一行 `pto_tilelang_vec_st(tsub)` | +| `testcase/tsub/cases.py` | 新增 | **case 定义的单一来源**:每个 case 必须指定 `name`/`dtype`/`shape`/`valid_shape`/`eps` | +| `testcase/tsub/tsub.pto` | 新增 | 定义一个或多个 case 的 kernel 函数 | +| `testcase/tsub/launch.cpp` | 新增 | 为每个 kernel 函数声明 entry 并提供 launch wrapper | +| `testcase/tsub/main.cpp` | 新增 | host driver,负责 case table、内存拷贝、launch 和 output 落盘 | +| `testcase/tsub/gen_data.py` | 新增 | 生成每个 case 的输入和 golden,从 `cases.py` 导入 `CASES` | +| `testcase/CMakeLists.txt` | 修改 | 把 `tsub` 加入 `ALL_TESTCASES` | + +通常不需要修改: + +- `script/run_st.py` +- `src/st/CMakeLists.txt` +- `testcase/st_common.py` +- `testcase/compare.py`(公共脚本,所有 testcase 共享) +- `testcase/run_ptoas_to_file.cmake` +- `testcase/repack_tilelang_kernel.sh` + +除非你在改框架本身,而不是新增一个 testcase。 + +## 7. 以 `pto.tadd` 为例,需要改哪些文件 + +当前仓库里 `tadd` 已经是一个完整样例。把它当成模板即可。 + +### 7.1 `testcase/tadd/CMakeLists.txt` + +这个文件通常最简单: + +```cmake +pto_tilelang_vec_st(tadd) ``` -需要创建 6 个文件,下面逐一说明。 +含义是让公共宏接管 `tadd.pto -> tadd_kernel.ll -> tadd_kernel_device.o -> tadd_kernel_repack.o -> libtadd_kernel.so -> tadd` 这一整条流水线。 -### 第 2 步:编写 kernel(tsub.pto) +### 7.2 `testcase/tadd/tadd.pto` -单个 `.pto` 文件中包含所有 case 对应的 kernel 函数,函数名格式为 `@TSUB__x`: +这是最核心的文件。你需要在这里写出要验证的 kernel 形态。 -```mlir -module { - // Case 0: f32 16x64 - func.func @TSUB_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - // make_tensor_view × 3, partition_view × 3, alloc_tile × 3 - // tload × 2, tsub, tstore - ... - return - } +当前 `tadd.pto` 的特点是: - // Case 1: f32 32x32 - func.func @TSUB_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - ... - return - } -} +- 一个文件中包含多个 case +- 每个 case 对应一个 `func.func @TADD__x(...)` +- 函数体里显式写出 `make_tensor_view`、`partition_view`、`alloc_tile`、`tload`、`pto.tadd`、`tstore` + +如果你在开发 `pto.tadd` 库实现,最关键的是先把你要覆盖的 case 设计好。例如: + +- `f32` / `f16` / `bf16` +- 不同 tile 形状 +- 边界 valid 行列不是整 tile 的情况 + +这里的函数命名建议统一成: + +```text +TADD__x ``` -### 第 3 步:编写 launch wrapper(launch.cpp) +例如: -每个 kernel 函数需要一个 `__global__` 声明和一个 `Launch*` C++ wrapper: +```text +TADD_f32_16x64 +TADD_f32_32x32 +``` -```cpp -#ifndef __VEC_SCOPE__ -#define __VEC_SCOPE__ -#endif +### 7.3 `testcase/tadd/launch.cpp` +这个文件的职责只有两个: + +1. 声明 kernel entry +2. 为 host driver 提供 `Launch*` wrapper + +当前推荐写法和 `tadd` 一致: + +```cpp #include -#include -#include -#ifndef __CPU_SIM -#include "acl/acl.h" +#ifndef AICORE +#define AICORE [aicore] #endif -__global__ AICORE void TSUB_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); +extern "C" __global__ AICORE void TADD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream) { - TSUB_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream) { + TADD_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } +``` -__global__ AICORE void TSUB_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); +注意点: -void LaunchTSUB_f32_32x32(float *a, float *b, float *c, void *stream) { - TSUB_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} -``` +- `launch.cpp` 不需要包含 PTO 头文件 +- `AICORE` 直接本地定义为 `[aicore]` +- 这里的 kernel 声明会被 repack 脚本抽取出来生成 stub,所以必须保留 `extern "C" __global__ AICORE void ...` 这一形态 +- kernel 参数顺序必须和 `.pto` 中函数签名保持一致 -**注意:** `__global__`、`AICORE`、`__gm__`、`<<<>>>` 是 CCE 扩展语法,本地 clang 会报错,这是预期行为——launch.cpp 由 bisheng `-xcce` 编译。 +### 7.4 `testcase/tadd/main.cpp` -### 第 4 步:编写 host driver(main.cpp) +这个文件负责 host 侧调度。 -使用 case table 驱动,每个 case 从独立子目录读写数据: +你需要做的事主要有三类: -```cpp -#include "test_common.h" -#include "acl/acl.h" -#include -#include -#include -#include -#include - -using namespace PtoTestCommon; - -#define ACL_CHECK(expr) \ - do { \ - const aclError _ret = (expr); \ - if (_ret != ACL_SUCCESS) { \ - std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", \ - #expr, (int)_ret, __FILE__, __LINE__); \ - const char *_recent = aclGetRecentErrMsg(); \ - if (_recent != nullptr && _recent[0] != '\0') \ - std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ - return 1; \ - } \ - } while (0) - -// launch wrappers -void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream); -void LaunchTSUB_f32_32x32(float *a, float *b, float *c, void *stream); - -using LaunchFn = void (*)(float *, float *, float *, void *); +1. 声明所有 `LaunchTADD_*` wrapper +2. 在 `kCases[]` 中列出每个 case 的名字、launch 函数、shape、valid shape、元素大小 +3. 在 `RunCase()` 中完成: + - 从 `.//input*.bin` 读取输入 + - `aclrtMemcpy` 把输入拷到 device + - 调用 `tc.launch(...)` + - `aclrtSynchronizeStream` + - 把输出拷回 host + - 写 `.//output.bin` + +当前 `tadd/main.cpp` 的 case table 形式如下: +```cpp struct TestCase { - const char *name; // 与 gen_data.py / compare.py 的 case name 一致 + const char *name; LaunchFn launch; - size_t rows; - size_t cols; - size_t elemSize; // bytes per element + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; }; static const TestCase kCases[] = { - {"f32_16x64", LaunchTSUB_f32_16x64, 16, 64, sizeof(float)}, - {"f32_32x32", LaunchTSUB_f32_32x32, 32, 32, sizeof(float)}, + {"f32_16x64", LaunchTADD_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTADD_f32_32x32, 32, 32, 32, 32, sizeof(float)}, }; - -// main() 中循环 kCases,对每个 case: -// 1. 从 .//input1.bin, input2.bin 读数据 -// 2. H2D → launch kernel → D2H -// 3. 写 .//output.bin -// 支持 ./tsub [case_name] 过滤单个 case ``` -完整实现参考 `testcase/tadd/main.cpp`。 +注意:`ACL_CHECK` 宏已移至公共头文件 `test_common.h`(需在 `acl/acl.h` 之后包含),不需要在每个 testcase 的 `main.cpp` 中重复定义。 -### 第 5 步:数据生成与精度比较 +你在新增 case 时,必须同步更新这个表,字段需与 `cases.py` 中的 `shape` / `valid_shape` 保持一致。 -**gen_data.py** — 为每个 case 生成独立子目录的 `input1.bin`、`input2.bin`、`golden.bin`: +### 7.5 `testcase/tadd/cases.py` -```python -import os -import numpy as np +这是 case 定义的**单一来源**,`gen_data.py` 和 `compare.py` 均从此导入 `CASES`。 -np.random.seed(42) +每个 case 必须包含以下字段: +```python CASES = [ - {"name": "f32_16x64", "dtype": np.float32, "shape": (16, 64)}, - {"name": "f32_32x32", "dtype": np.float32, "shape": (32, 32)}, + { + "name": "f32_16x64", # case 标识,对应运行时子目录和 main.cpp kCases[] 中的 name + "dtype": np.float32, # numpy dtype + "shape": (16, 64), # 分配的 tile 维度 (rows, cols) + "valid_shape": (16, 64), # 有效计算区域 (valid_rows, valid_cols) + "eps": 1e-6, # numpy.allclose 容差 + }, ] +``` -for case in CASES: - case_dir = case["name"] - os.makedirs(case_dir, exist_ok=True) +`valid_shape` 为必填字段。当 valid shape 等于 tile shape 时也必须显式写出。 - input1 = np.random.randint(1, 10, size=case["shape"]).astype(case["dtype"]) - input2 = np.random.randint(1, 10, size=case["shape"]).astype(case["dtype"]) - golden = (input1 - input2).astype(case["dtype"], copy=False) # tsub: 减法 +### 7.6 `testcase/tadd/gen_data.py` - input1.tofile(os.path.join(case_dir, "input1.bin")) - input2.tofile(os.path.join(case_dir, "input2.bin")) - golden.tofile(os.path.join(case_dir, "golden.bin")) -``` +这个文件负责为每个 case 生成输入和 golden。从 `cases.py` 导入 `CASES`, +从 `st_common.py` 导入辅助函数(`setup_case_rng`、`save_case_data`)。 -**compare.py** — 逐 case 比较,支持 `python compare.py [case_name]` 过滤: +以 `pto.tadd` 为例,每个 case 的核心逻辑: ```python -import sys, os -import numpy as np +golden = np.zeros(shape, dtype=dtype) +vr, vc = case["valid_shape"] +golden[:vr, :vc] = (input1[:vr, :vc] + input2[:vr, :vc]).astype(dtype, copy=False) +``` -CASES = [ - {"name": "f32_16x64", "dtype": np.float32, "eps": 1e-6}, - {"name": "f32_32x32", "dtype": np.float32, "eps": 1e-6}, -] +golden 只在 `valid_shape` 区域内计算,区域外保持零值。 + +每个 case 使用独立的随机 seed(`setup_case_rng` 基于 `hash(case["name"])`), +新增或调整 case 顺序不会影响已有 case 的测试数据。 + +### 7.7 `testcase/compare.py`(公共,无需 per-testcase 修改) + +`compare.py` 位于 `testcase/` 公共目录,所有 testcase 共享同一份: -def compare_bin(golden_path, output_path, dtype, eps): - golden = np.fromfile(golden_path, dtype=dtype) - output = np.fromfile(output_path, dtype=dtype) - if golden.shape != output.shape: - return False - return np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) - -case_filter = sys.argv[1] if len(sys.argv) > 1 else None -all_passed = True -for case in CASES: - if case_filter and case["name"] != case_filter: - continue - ok = compare_bin( - os.path.join(case["name"], "golden.bin"), - os.path.join(case["name"], "output.bin"), - case["dtype"], case["eps"]) - if not ok: - all_passed = False -if not all_passed: - sys.exit(2) +```python +from cases import CASES +from st_common import run_compare + +if __name__ == "__main__": + run_compare(CASES) ``` -### 第 6 步:CMake 注册 +`run_st.py` 运行时会将它和 per-testcase 的 `cases.py` 一起拷贝到 build 目录, +`compare.py` 通过 `from cases import CASES` 获取当前 testcase 的 case 列表。 -**testcase/tsub/CMakeLists.txt**(一行): +`run_compare()` 会: +- 校验所有 case 必填字段 +- 只在 `valid_shape` 区域内比较 `golden.bin` 与 `output.bin` +- 支持 `argv[1]` 作为 case filter +- exit code 2 表示失败 -```cmake -pto_tilelang_vec_st(tsub) +## 8. 如果只是在已有 `tadd` 下新增一个 case + +如果 `tadd` testcase 已经存在,而你只是想加一个新 case,例如 `f32_8x128`,则通常只需要同步修改 4 个文件: + +| 文件 | 必须修改的内容 | +|---|---| +| `testcase/tadd/cases.py` | 在 `CASES` 中加入新条目(含 `name`/`dtype`/`shape`/`valid_shape`/`eps`) | +| `testcase/tadd/tadd.pto` | 新增一个 `func.func @TADD_f32_8x128(...)` | +| `testcase/tadd/launch.cpp` | 新增 `extern "C"` kernel 声明和 `LaunchTADD_f32_8x128` | +| `testcase/tadd/main.cpp` | 在 `kCases[]` 中加入 `{"f32_8x128", LaunchTADD_f32_8x128, 8, 128, 8, 128, sizeof(float)}` | + +不需要改: + +- `testcase/tadd/gen_data.py`(自动从 `cases.py` 读取) +- `testcase/tadd/compare.py`(自动从 `cases.py` 读取) +- `testcase/tadd/CMakeLists.txt` +- `testcase/CMakeLists.txt` +- `run_st.py` + +## 9. 文件之间必须保持一致的约束 + +这是新增 testcase 时最容易出错的地方。 + +### 9.1 命名一致 + +下面这几处名字必须严格一致: + +| 位置 | 示例 | +|---|---| +| `.pto` 中的 kernel 函数名 | `@TADD_f32_16x64` | +| `launch.cpp` 中的 kernel 声明 | `TADD_f32_16x64` | +| `launch.cpp` / `main.cpp` 中的 wrapper 名 | `LaunchTADD_f32_16x64` | +| `main.cpp` 的 case 名 | `f32_16x64` | +| `gen_data.py` / `compare.py` 的 case 名 | `f32_16x64` | +| 运行时目录名 | `build/testcase/tadd/f32_16x64/` | + +### 9.2 参数顺序一致 + +`.pto` 里 kernel 的参数顺序、`launch.cpp` 声明顺序、`main.cpp` 里 launch wrapper 的参数顺序必须一致。 +如果 `tadd` 的语义是 `(a, b) -> c`,那 host 侧和 compare 也都要按这个顺序组织。 + +### 9.3 shape、valid_shape 和 dtype 一致 + +`cases.py` 中的 `shape`/`valid_shape`/`dtype` 是 Python 侧的单一来源,`gen_data.py` 和 `compare.py` 自动从中读取。 +但 C++ 侧的 `main.cpp` `kCases[]`(`rows`/`cols`/`validRows`/`validCols`/`elemSize`)和 `.pto` 中的 tile shape 仍需手动与 `cases.py` 保持一致。 +否则运行能成功,结果也可能是错误的,且定位会很耗时。 + +## 10. 建议的开发验证节奏 + +作为库开发者,建议用下面的节奏迭代: + +1. 先写一个最小 case,例如 `f32_16x64` +2. 在 simulator 上跑单 case: + +```bash +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -c f32_16x64 ``` -**testcase/CMakeLists.txt** 中注册: +3. 改 `.pto` 或 host 代码后,如果确认只是小修改,可以用: -```cmake -set(ALL_TESTCASES - tadd - tsub # <-- 新增 -) +```bash +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -c f32_16x64 -w ``` -## 在已有 op 下新增 case +4. 单 case 稳定后,再补更多 shape / dtype case +5. 再跑全量 `tadd` +6. 最后如果需要,再切到 `-r npu` -不需要修改 CMake。只需要同步修改 4 个文件: +## 11. 调试建议 -| 步骤 | 文件 | 修改内容 | -|------|------|---------| -| 1 | `tadd.pto` | 新增 `func.func @TADD__(...)` | -| 2 | `launch.cpp` | 新增 `__global__` 声明 + `LaunchTADD__` wrapper | -| 3 | `main.cpp` | `kCases[]` 数组新增一行 | -| 4 | `gen_data.py` + `compare.py` | `CASES` 列表各新增一行 | +### 11.1 编译失败看哪里 -### 命名约定 +- `ptoas` 失败:优先看 `.pto` 本身、TileLang 模板实例化、是否缺少 `--enable-insert-sync` +- `bisheng -x ir` 失败:优先看生成的 `*_kernel.ll` +- repack 失败:优先看 `launch.cpp` 中的 kernel 声明是否符合脚本预期 +- `launch.cpp` / `main.cpp` 链接失败:优先看共享库、ACL 运行时依赖和符号名一致性 -- kernel 函数名:`@__x`,例如 `@TADD_f32_16x64` -- launch wrapper:`Launch__x` -- case name / 子目录名:`_x`,例如 `f32_16x64` -- 三处 case 列表(`main.cpp` `kCases[]`、`gen_data.py` `CASES`、`compare.py` `CASES`)必须保持一致 +### 11.2 运行失败看哪里 -## CMake 编译流水线 +- `main.cpp` 报读文件失败:先确认 `build/testcase///input*.bin` 是否存在 +- kernel 能跑但 compare fail:先看 `output.bin` 与 `golden.bin` 的差异,再看 `.pto` 语义和 host 参数顺序 +- 某个 case 单独跑通过、全量跑失败:优先怀疑 case 目录隔离、host 资源释放、或者多 case 共用状态 -`pto_tilelang_vec_st(NAME)` 宏定义在 `testcase/CMakeLists.txt` 中,完成 4 步编译: +### 11.3 典型排查文件 -``` -Step 1: ptoas NAME.pto → NAME_kernel.ll (LLVM IR) - ptoas --pto-arch=a5 --pto-backend=vpto - --enable-tile-op-expand --vpto-emit-hivm-llvm - NAME.pto -o NAME_kernel.ll - -Step 2: bisheng NAME_kernel.ll → NAME_kernel.o (object file) - bisheng --target=hiipu64-hisilicon-cce - -march=dav-c310-vec - --cce-aicore-arch=dav-c310-vec - --cce-aicore-only - -c -x ir NAME_kernel.ll -o NAME_kernel.o - -Step 3: launch.cpp (-xcce) + NAME_kernel.o → libNAME_kernel.so - bisheng -xcce launch.cpp + link NAME_kernel.o - -Step 4: main.cpp (-xc++) → NAME (可执行文件) - link libNAME_kernel.so + ACL runtime libraries +| 文件 | 作用 | +|---|---| +| `build/testcase//_kernel.ll` | 看 `ptoas` 最终生成的 LLVM IR | +| `build/testcase///golden.bin` | 确认 Python 侧 oracle 是否正确 | +| `build/testcase///output.bin` | 确认运行时实际输出 | +| `testcase//main.cpp` | 确认 host 侧参数顺序、shape 和文件路径 | +| `testcase//compare.py` | 确认比较阈值是否合理 | + +## 12. 一句话总结 + +对于库开发者来说,TileLang ST 框架就是一条固定好的端到端验证流水线: + +```text +写 .pto -> 接入 testcase 六件套 -> run_st.py 编译运行 -> 查看 build/testcase// 下的 input/golden/output -> 判断库实现是否正确 ``` -## 精度比较说明 +如果你想验证的是 `pto.tadd`,最重要的是把下面几处保持同步: -使用 `np.allclose(golden, output, atol=eps, rtol=eps)` 进行比较。不同 dtype 建议的 eps 值: +- `cases.py` 中的 case 定义(name/dtype/shape/valid_shape/eps)—— Python 侧的单一来源 +- `tadd.pto` 中的 kernel 函数名和 tile shape +- `launch.cpp` 中的 kernel 声明与 wrapper +- `main.cpp` 中的 `kCases[]`(rows/cols/validRows/validCols 需与 `cases.py` 一致) +- `gen_data.py` 中的 golden 计算逻辑(op 语义相关,如加法/减法) -| dtype | 建议 eps | -|-------|---------| -| float32 | 1e-6 | -| float16 | 1e-3 | -| bfloat16 | 1e-2 | -| int8/int32 | 0(精确匹配) | +`compare.py` 和 `gen_data.py` 的 case 列表、比较阈值均自动从 `cases.py` 读取,不需要单独维护。 -比较失败时会输出 max diff、出错位置和 golden/output 值,便于定位问题。 +这几处一致,框架就能帮助你把 TileLang 库实现的”端到端正确性”稳定地跑起来。 diff --git a/test/tilelang_st/npu/a5/src/st/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt index 20248be01..60a1bf354 100644 --- a/test/tilelang_st/npu/a5/src/st/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt @@ -33,6 +33,8 @@ else() endif() set(PTO_ISA_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../../../../../pto-isa" CACHE PATH "Path to pto-isa repo") +set(PTO_NPU_VALIDATION_COMMON_DIR + "${CMAKE_CURRENT_LIST_DIR}/../../../../../npu_validation/common") set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver) set(CMAKE_COMPILER bisheng) @@ -70,8 +72,6 @@ set(CMAKE_CPP_COMPILE_OPTIONS ) include_directories( - ${PTO_ISA_ROOT}/include - ${PTO_ISA_ROOT}/tests/common ${ASCEND_HOME_PATH}/include ${ASCEND_DRIVER_PATH}/kernel/inc ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index a1116f0b4..f3d2623e2 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -93,7 +93,7 @@ function(pto_tilelang_vec_st NAME) add_executable(${NAME} main.cpp) target_compile_options(${NAME} PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(${NAME} PRIVATE - ${PTO_ISA_ROOT}/tests/common + ${PTO_NPU_VALIDATION_COMMON_DIR} ) target_link_directories(${NAME} PUBLIC diff --git a/test/tilelang_st/npu/a5/src/st/testcase/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/compare.py new file mode 100644 index 000000000..ab6449f1c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/compare.py @@ -0,0 +1,16 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +from cases import CASES +from st_common import run_compare + +if __name__ == "__main__": + run_compare(CASES) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/st_common.py b/test/tilelang_st/npu/a5/src/st/testcase/st_common.py new file mode 100644 index 000000000..8f61fdb8c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/st_common.py @@ -0,0 +1,147 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Shared utilities for TileLang ST test cases. + +Provides: + - Case helpers: get_valid_shape() + - Data helpers: setup_case_rng(), save_case_data() + - Compare: compare_bin(), run_compare() (full compare entry point) + - Styling: supports_color(), style_pass(), style_fail() +""" + +import os +import sys +import numpy as np + + +# --------------------------------------------------------------------------- +# Case helpers +# --------------------------------------------------------------------------- + +REQUIRED_CASE_KEYS = {"name", "dtype", "shape", "valid_shape", "eps"} + + +def validate_cases(cases): + """Check that every case has all required keys.""" + for i, case in enumerate(cases): + missing = REQUIRED_CASE_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +# --------------------------------------------------------------------------- +# Data generation helpers +# --------------------------------------------------------------------------- + +def setup_case_rng(case): + """Set a per-case deterministic random seed. + + Using hash(name) ensures that adding/reordering cases does not change + the random data of existing cases. + """ + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry in data_dict. + + Args: + case_name: subdirectory name (e.g. "f32_16x64"). + data_dict: mapping from file stem to numpy array, + e.g. {"input1": arr1, "input2": arr2, "golden": golden}. + """ + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +# --------------------------------------------------------------------------- +# Terminal styling +# --------------------------------------------------------------------------- + +ANSI_RESET = "\033[0m" +ANSI_BOLD_GREEN = "\033[1;32m" +ANSI_BOLD_RED = "\033[1;31m" + + +def supports_color(): + return sys.stdout.isatty() and os.environ.get("TERM") not in (None, "", "dumb") + + +def style_pass(text): + if not supports_color(): + return text + return f"{ANSI_BOLD_GREEN}{text}{ANSI_RESET}" + + +def style_fail(text): + if not supports_color(): + return text + return f"{ANSI_BOLD_RED}{text}{ANSI_RESET}" + + +# --------------------------------------------------------------------------- +# Comparison +# --------------------------------------------------------------------------- + +def compare_bin(golden_path, output_path, dtype, eps, shape, valid_shape): + """Compare golden and output binary files within the valid region. + + Returns True on pass, False on mismatch. + """ + golden = np.fromfile(golden_path, dtype=dtype).reshape(shape) + output = np.fromfile(output_path, dtype=dtype).reshape(shape) + + vr, vc = valid_shape + g = golden[:vr, :vc].astype(np.float64, copy=False) + o = output[:vr, :vc].astype(np.float64, copy=False) + + if g.shape != o.shape: + print(style_fail(f"[ERROR] Shape mismatch: golden {g.shape} vs output {o.shape}")) + return False + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(style_fail(f"[ERROR] Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at flat idx={idx} " + f"(golden={g.flat[idx]}, output={o.flat[idx]})")) + return False + return True + + +def run_compare(cases): + """Main entry point for per-testcase compare.py scripts. + + Reads an optional case filter from sys.argv[1], iterates over *cases*, + and exits with code 2 if any comparison fails. + """ + validate_cases(cases) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in cases: + if case_filter is not None and case["name"] != case_filter: + continue + case_dir = case["name"] + golden_path = os.path.join(case_dir, "golden.bin") + output_path = os.path.join(case_dir, "output.bin") + ok = compare_bin(golden_path, output_path, case["dtype"], case["eps"], + case["shape"], case["valid_shape"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py new file mode 100644 index 000000000..e14e76c54 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tadd ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py deleted file mode 100644 index 56d41377c..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -import sys -import os -import numpy as np - -ANSI_RESET = "\033[0m" -ANSI_BOLD_GREEN = "\033[1;32m" -ANSI_BOLD_RED = "\033[1;31m" - - -CASES = [ - {"name": "f32_16x64", "dtype": np.float32, "eps": 1e-6}, - {"name": "f32_32x32", "dtype": np.float32, "eps": 1e-6}, -] - - -def supports_color(): - return sys.stdout.isatty() and os.environ.get("TERM") not in (None, "", "dumb") - - -def style_pass(text): - if not supports_color(): - return text - return f"{ANSI_BOLD_GREEN}{text}{ANSI_RESET}" - - -def style_fail(text): - if not supports_color(): - return text - return f"{ANSI_BOLD_RED}{text}{ANSI_RESET}" - - -def compare_bin(golden_path, output_path, dtype, eps): - golden = np.fromfile(golden_path, dtype=dtype) - output = np.fromfile(output_path, dtype=dtype) - if golden.shape != output.shape: - print(style_fail(f"[ERROR] Shape mismatch: golden {golden.shape} vs output {output.shape}")) - return False - if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): - g = golden.astype(np.float64, copy=False) - o = output.astype(np.float64, copy=False) - abs_diff = np.abs(g - o) - idx = int(np.argmax(abs_diff)) - print(style_fail(f"[ERROR] Mismatch: max diff={float(abs_diff[idx])} at idx={idx} " - f"(golden={g[idx]}, output={o[idx]})")) - return False - return True - - -if __name__ == "__main__": - # Optional filter: python compare.py [case_name] - case_filter = sys.argv[1] if len(sys.argv) > 1 else None - - all_passed = True - for case in CASES: - if case_filter is not None and case["name"] != case_filter: - continue - case_dir = case["name"] - golden_path = os.path.join(case_dir, "golden.bin") - output_path = os.path.join(case_dir, "output.bin") - ok = compare_bin(golden_path, output_path, case["dtype"], case["eps"]) - if ok: - print(style_pass(f"[INFO] {case['name']}: compare passed")) - else: - print(style_fail(f"[ERROR] {case['name']}: compare failed")) - all_passed = False - - if not all_passed: - sys.exit(2) - print(style_pass("[INFO] all cases passed")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py index 1d983c64f..c7273286a 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py @@ -9,25 +9,25 @@ # coding=utf-8 -import os import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data -np.random.seed(19) - -CASES = [ - {"name": "f32_16x64", "dtype": np.float32, "shape": (16, 64)}, - {"name": "f32_32x32", "dtype": np.float32, "shape": (32, 32)}, -] +validate_cases(CASES) for case in CASES: - case_dir = case["name"] - os.makedirs(case_dir, exist_ok=True) + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) - input1 = np.random.randint(1, 10, size=case["shape"]).astype(case["dtype"]) - input2 = np.random.randint(1, 10, size=case["shape"]).astype(case["dtype"]) - golden = (input1 + input2).astype(case["dtype"], copy=False) + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] + input2[:vr, :vc]).astype(dtype, copy=False) - input1.tofile(os.path.join(case_dir, "input1.bin")) - input2.tofile(os.path.join(case_dir, "input2.bin")) - golden.tofile(os.path.join(case_dir, "golden.bin")) - print(f"[INFO] gen_data: {case['name']} shape={case['shape']} dtype={case['dtype'].__name__}") + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp index 3c1703d8e..1a010623f 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp @@ -10,8 +10,8 @@ // Each case launches a different kernel variant, reads/writes from per-case subdirectory. // Numerical comparison is done externally by compare.py. -#include "test_common.h" #include "acl/acl.h" +#include "test_common.h" #include #include #include @@ -21,20 +21,6 @@ using namespace PtoTestCommon; -#define ACL_CHECK(expr) \ - do { \ - const aclError _ret = (expr); \ - if (_ret != ACL_SUCCESS) { \ - std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, \ - __FILE__, __LINE__); \ - const char *_recent = aclGetRecentErrMsg(); \ - if (_recent != nullptr && _recent[0] != '\0') { \ - std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ - } \ - return 1; \ - } \ - } while (0) - // Kernel launch wrappers (defined in launch.cpp) void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream); void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream); @@ -44,67 +30,83 @@ using LaunchFn = void (*)(float *, float *, float *, void *); struct TestCase { const char *name; LaunchFn launch; - size_t rows; - size_t cols; - size_t elemSize; // bytes per element + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element }; static const TestCase kCases[] = { - {"f32_16x64", LaunchTADD_f32_16x64, 16, 64, sizeof(float)}, - {"f32_32x32", LaunchTADD_f32_32x32, 32, 32, sizeof(float)}, + {"f32_16x64", LaunchTADD_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTADD_f32_32x32, 32, 32, 32, 32, sizeof(float)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; const size_t elemCount = tc.rows * tc.cols; const size_t fileSize = elemCount * tc.elemSize; - std::printf("[INFO] === case: %s (%zux%zu) ===\n", tc.name, tc.rows, tc.cols); + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); // Per-case data directory std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - ACL_CHECK(aclrtMallocHost((void **)(&src0Host), fileSize)); - ACL_CHECK(aclrtMallocHost((void **)(&src1Host), fileSize)); - ACL_CHECK(aclrtMallocHost((void **)(&dstHost), fileSize)); + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); - ACL_CHECK(aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - size_t src0FileSize = fileSize; - size_t src1FileSize = fileSize; if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); - return 1; + rc = 1; } - if (!ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); - return 1; + rc = 1; } - ACL_CHECK(aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); - ACL_CHECK(aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); - tc.launch(src0Device, src1Device, dstDevice, stream); + tc.launch(src0Device, src1Device, dstDevice, stream); - ACL_CHECK(aclrtSynchronizeStream(stream)); - ACL_CHECK(aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); - - WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } - aclrtFree(src0Device); - aclrtFree(src1Device); - aclrtFree(dstDevice); - aclrtFreeHost(src0Host); - aclrtFreeHost(src1Host); - aclrtFreeHost(dstHost); + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } - std::printf("[INFO] case %s done\n", tc.name); - return 0; + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; } int main(int argc, char *argv[]) { @@ -112,19 +114,15 @@ int main(int argc, char *argv[]) { const char *caseFilter = (argc > 1) ? argv[1] : nullptr; int rc = 0; - bool aclInited = false; - bool deviceSet = false; int deviceId = 0; aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - aclInited = true; + aclInit(nullptr); if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { deviceId = std::atoi(envDevice); } - ACL_CHECK(aclrtSetDevice(deviceId)); - deviceSet = true; - ACL_CHECK(aclrtCreateStream(&stream)); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); for (size_t i = 0; i < kNumCases; ++i) { if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { @@ -138,15 +136,10 @@ int main(int argc, char *argv[]) { } } - if (stream != nullptr) { + if (stream != nullptr) aclrtDestroyStream(stream); - } - if (deviceSet) { - aclrtResetDevice(deviceId); - } - if (aclInited) { - aclFinalize(); - } + aclrtResetDevice(deviceId); + aclFinalize(); return rc; } diff --git a/test/tilelang_st/script/run_st.py b/test/tilelang_st/script/run_st.py index 8ab4b5979..f71fc0a55 100755 --- a/test/tilelang_st/script/run_st.py +++ b/test/tilelang_st/script/run_st.py @@ -139,12 +139,28 @@ def build_project(run_mode, soc_version, testcase, ptoas_bin): raise -def run_gen_data(golden_path, testcase): +def _copy_testcase_scripts(testcase): + """Copy shared and per-testcase Python scripts into the build work directory.""" + work_dir = get_testcase_work_dir(testcase) + os.makedirs(work_dir, exist_ok=True) + # Shared scripts (testcase/ level). + for name in ("st_common.py", "compare.py"): + src = os.path.join("testcase", name) + if os.path.isfile(src): + run_command(["cp", src, os.path.join(work_dir, name)]) + # Per-testcase scripts. + testcase_src = f"testcase/{testcase}" + for name in ("cases.py", "gen_data.py"): + src = os.path.join(testcase_src, name) + if os.path.isfile(src): + run_command(["cp", src, os.path.join(work_dir, name)]) + + +def run_gen_data(testcase): original_dir = os.getcwd() try: work_dir = get_testcase_work_dir(testcase) - os.makedirs(work_dir, exist_ok=True) - run_command(["cp", golden_path, os.path.join(work_dir, "gen_data.py")]) + _copy_testcase_scripts(testcase) os.chdir(work_dir) run_command([sys.executable, "gen_data.py"]) except Exception as e: @@ -169,12 +185,10 @@ def run_binary(testcase, case_filter=None): os.chdir(original_dir) -def run_compare(compare_path, testcase, case_filter=None): +def run_compare(testcase, case_filter=None): original_dir = os.getcwd() try: work_dir = get_testcase_work_dir(testcase) - os.makedirs(work_dir, exist_ok=True) - run_command(["cp", compare_path, os.path.join(work_dir, "compare.py")]) os.chdir(work_dir) cmd = [sys.executable, "compare.py"] if case_filter: @@ -240,13 +254,9 @@ def main(): build_project(args.run_mode, default_soc_version, testcase, ptoas_bin) # gen golden → run binary → compare - golden_path = f"testcase/{testcase}/gen_data.py" - run_gen_data(golden_path, testcase) - + run_gen_data(testcase) run_binary(testcase, args.case) - - compare_path = f"testcase/{testcase}/compare.py" - run_compare(compare_path, testcase, args.case) + run_compare(testcase, args.case) except Exception as e: print(f"run failed: {str(e)}", file=sys.stderr) From 4dc3c6dcb1edbd1714e9cc0055ed051c289dec4c Mon Sep 17 00:00:00 2001 From: qukelin Date: Mon, 13 Apr 2026 21:16:26 +0800 Subject: [PATCH 052/192] Add local tilelang ST test common header --- test/tilelang_st/npu/a5/src/st/CMakeLists.txt | 4 +- .../npu/a5/src/st/common/test_common.h | 54 +++++++++++++++++++ .../npu/a5/src/st/testcase/CMakeLists.txt | 2 +- 3 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 test/tilelang_st/npu/a5/src/st/common/test_common.h diff --git a/test/tilelang_st/npu/a5/src/st/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt index 60a1bf354..4aec35d12 100644 --- a/test/tilelang_st/npu/a5/src/st/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt @@ -33,8 +33,8 @@ else() endif() set(PTO_ISA_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../../../../../pto-isa" CACHE PATH "Path to pto-isa repo") -set(PTO_NPU_VALIDATION_COMMON_DIR - "${CMAKE_CURRENT_LIST_DIR}/../../../../../npu_validation/common") +set(PTO_TILELANG_ST_COMMON_DIR + "${CMAKE_CURRENT_LIST_DIR}/common") set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver) set(CMAKE_COMPILER bisheng) diff --git a/test/tilelang_st/npu/a5/src/st/common/test_common.h b/test/tilelang_st/npu/a5/src/st/common/test_common.h new file mode 100644 index 000000000..ffdf6e373 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/common/test_common.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace PtoTestCommon { + +inline bool ReadFile(const std::string &filePath, size_t &fileSize, void *buffer, + size_t bufferSize) { + struct stat sBuf; + if (stat(filePath.c_str(), &sBuf) == -1) { + return false; + } + if (!S_ISREG(sBuf.st_mode)) { + return false; + } + + std::ifstream file(filePath, std::ios::binary); + if (!file.is_open()) { + return false; + } + + std::filebuf *buf = file.rdbuf(); + size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in); + if (size == 0 || size > bufferSize) { + return false; + } + buf->pubseekpos(0, std::ios::in); + buf->sgetn(static_cast(buffer), size); + fileSize = size; + return true; +} + +inline bool WriteFile(const std::string &filePath, const void *buffer, size_t size) { + if (buffer == nullptr) { + return false; + } + + int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE); + if (fd < 0) { + return false; + } + + ssize_t writeSize = write(fd, buffer, size); + (void)close(fd); + return writeSize == static_cast(size); +} + +} // namespace PtoTestCommon diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index f3d2623e2..74fcc7932 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -93,7 +93,7 @@ function(pto_tilelang_vec_st NAME) add_executable(${NAME} main.cpp) target_compile_options(${NAME} PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(${NAME} PRIVATE - ${PTO_NPU_VALIDATION_COMMON_DIR} + ${PTO_TILELANG_ST_COMMON_DIR} ) target_link_directories(${NAME} PUBLIC From a57b62a02a036f74a1f167295fe9b0291881d645 Mon Sep 17 00:00:00 2001 From: qukelin Date: Tue, 14 Apr 2026 10:13:53 +0800 Subject: [PATCH 053/192] resolve confilct --- tools/ptoas/ptoas.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 176a0668e..40b12e719 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1129,22 +1129,22 @@ static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { // 3. InlineLibCall: inline template function bodies // 4. FoldTileBufIntrinsics: fold tile_buf_addr / tile_valid_rows / // tile_valid_cols to concrete memref/constant values - pm.addPass(pto::createMemrefToTileBufPass()); + backendPM.addPass(pto::createMemrefToTileBufPass()); pto::ExpandTileOpOptions expandOpts; expandOpts.tilelangPath = tilelangPath; expandOpts.tilelangPkgPath = tilelangPkgPath; - pm.addPass(pto::createExpandTileOpPass(expandOpts)); + backendPM.addPass(pto::createExpandTileOpPass(expandOpts)); - pm.addPass(pto::createPTOInlineLibCallPass()); - pm.addNestedPass( + backendPM.addPass(pto::createPTOInlineLibCallPass()); + backendPM.addNestedPass( pto::createFoldTileBufIntrinsicsPass()); // FoldTileBufIntrinsics materializes many constant branch conditions. // Clean them up immediately on the TileOp expansion path before the // authoring-stage VPTO verifier and let the existing CSE passes remove the // resulting dead values later in the pipeline. - pm.addPass(mlir::createSCCPPass()); - pm.addPass(mlir::createCanonicalizerPass()); + backendPM.addPass(mlir::createSCCPPass()); + backendPM.addPass(mlir::createCanonicalizerPass()); } if (failed(backendPM.run(module))) { llvm::errs() << "Error: backend lowering pass execution failed.\n"; From 1f9afeca80185c4f0404db3f53b435785a73e2b4 Mon Sep 17 00:00:00 2001 From: qukelin Date: Tue, 14 Apr 2026 10:52:57 +0800 Subject: [PATCH 054/192] update ptoas option --- docs/designs/ptoas-tileop-expand-design.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/docs/designs/ptoas-tileop-expand-design.md b/docs/designs/ptoas-tileop-expand-design.md index 10c0e8e95..b840f4459 100644 --- a/docs/designs/ptoas-tileop-expand-design.md +++ b/docs/designs/ptoas-tileop-expand-design.md @@ -957,23 +957,20 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): } ``` - Expand TileOp pass 的端到端测试(`pto.tadd` → Vector IR) - 使用以下命令同时观察中间 IR 和最终 LLVM IR: + 使用以下命令生成最终 LLVM IR,并继续交给 Bisheng 做设备侧编译校验: ```bash ./build/tools/ptoas/ptoas test/basic/expand_tile_op_tilelang.pto \ --pto-arch=a5 \ - --print-ir-after-all \ --pto-backend=vpto \ --enable-tile-op-expand \ --vpto-emit-hivm-llvm \ -o - \ - > add.ll \ - 2> /tmp/expand_tile_op_tilelang.mlir + > add.ll ``` 说明: - - `stderr` 中的 `/tmp/expand_tile_op_tilelang.mlir` 保存 `--print-ir-after-all` 打印的各阶段 MLIR/VPTO IR,可用于检查模板是否已经从 `pto.tadd` 展开为向量 IR。 - - `stdout` 中的最终产物是 textual LLVM IR,因此这里使用 `-o - > add.ll` 显式落盘,而不是依赖 `-o ` 与 `--print-ir-after-all` 混用时的输出行为。 + - `stdout` 中的最终产物是 textual LLVM IR,因此这里使用 `-o - > add.ll` 显式落盘。 随后将生成的 `add.ll` 交给 Bisheng: From edf07f499428e8a91a2507212e0dc6a7e720c01d Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 11:09:31 +0800 Subject: [PATCH 055/192] Support scalar mod operation in DSL --- tilelang-dsl/python/tilelang_dsl/frontend_ast.py | 1 + tilelang-dsl/python/tilelang_dsl/lowering.py | 4 ++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 8 +++++++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 9ce2f06eb..c11ec116d 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -620,6 +620,7 @@ def _collect_reachable_inline_procs( ast.Add: "add", ast.Sub: "sub", ast.Mult: "mul", + ast.Mod: "mod", ast.FloorDiv: "floordiv", } _COMPARE_OP_NAMES = { diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index a09bac840..a4bd432e0 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -3428,6 +3428,10 @@ def _render_binary_op(self, op: str, ty: SemanticType) -> str: return "arith.subi" if op == "mul": return "arith.muli" + if op == "mod": + if isinstance(ty, SemanticIndexType): + return "arith.remui" + return "arith.remsi" if op == "floordiv": return "arith.floordivsi" raise NotImplementedError(f"unsupported binary op '{op}' for type {ty!r}") diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 0e0680dca..f34cfad7e 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3419,7 +3419,7 @@ def _binary_type( rhs: SemanticExpr, op: str, ) -> SemanticType: - if op in {"add", "sub", "mul", "floordiv"}: + if op in {"add", "sub", "mul", "mod", "floordiv"}: if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): return SemanticIndexType() raise TypeError("binary expressions currently only support index-typed operands") @@ -5047,6 +5047,12 @@ def _try_static_value(self, expr: SemanticExpr | None) -> Any | None: if isinstance(lhs, int) and isinstance(rhs, int): return lhs * rhs return None + if expr.op == "mod": + if isinstance(lhs, int) and isinstance(rhs, int): + if rhs == 0: + return None + return lhs % rhs + return None if expr.op == "floordiv": if isinstance(lhs, int) and isinstance(rhs, int): if rhs == 0: From 1c9b329554ff0bfbba2d8ce280e8345c6fa3d570 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 12:04:08 +0800 Subject: [PATCH 056/192] Support custom tile padding value --- .../docs/user_guide/05-type-system.md | 63 +++++- tilelang-dsl/python/tilelang_dsl/lowering.py | 16 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 65 +++++- tilelang-dsl/python/tilelang_dsl/types.py | 185 +++++++++++++++++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 83 +++++++- 5 files changed, 392 insertions(+), 20 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 19c28082f..91d7c8845 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -275,6 +275,33 @@ Important notes on shape and valid shape: | `valid_shape` | `tuple[int, ...]` | Actual data dimensions within the tile. Must be less than or equal to `shape` in each dimension | | `config` | `TileConfig` | Layout and padding configuration | +#### Tile Pad Values + +`TileConfig.pad_value` is modeled after the C++ `PadValue : uint64_t` design. + +Standard pad values use small integer encodings: + +| DSL Value | Encoded Value | Meaning | +|-----------|---------------|---------| +| `pto.PadValue.NULL` | `0` | No concrete fill value | +| `pto.PadValue.ZERO` | `1` | Zero fill | +| `pto.PadValue.MAX` | `2` | Maximum finite / integer max for the tile element dtype | +| `pto.PadValue.MIN` | `3` | Minimum finite / integer min for the tile element dtype | + +Custom pad values use the `CustomBase = 0x100000000` convention and are authored with `pto.PadValue.custom_f32(...)`: + +```python +pad0 = pto.PadValue.ZERO +pad1 = pto.PadValue.custom_f32(-1.0) +pad2 = pto.PadValue.custom_f32("0xBF800000") # float32 bit pattern for -1.0f +``` + +Notes: +- `PadValue.value` on the host-side descriptor still exposes the encoded integer payload. +- `PadValue.text` exposes the standard textual spelling for built-ins such as `null` and `zero`. +- Custom pad values currently model an `f32` payload. In DSL v1, materializing a custom pad into a scalar is only supported for floating tile element dtypes. +- `PadValue.NULL` does not denote a usable scalar fill constant. Reading `tile.pad_value.value` or `tile.config.pad_value.value` when the enum is `NULL` is a frontend error. + #### Tile Shape Concepts - `shape` is the static physical allocation size of the tile buffer. @@ -296,12 +323,46 @@ config = tile.config b_layout = config.b_layout # pto.BLayout.ROW_MAJOR s_layout = config.s_layout # pto.SLayout.NONE_BOX s_fractal = config.s_fractal_size # pto.i32(512) -pad = config.pad_value # pto.PadValue.NULL +pad_desc = tile.config.pad_value # PadValue enum bound to the tile element dtype +pad_desc2 = tile.pad_value # direct sugar for the same PadValue enum # Dynamic properties rank = tile.rank # 2 ``` +`tile.config.pad_value` and `tile.pad_value` are enum-typed inside kernel code. Use `.value` to materialize the configured pad descriptor against the tile element dtype: + +- `tile.pad_value.value` with `PadValue.ZERO` becomes `0` / `0.0` +- `tile.pad_value.value` with `PadValue.MAX` becomes dtype-aware max +- `tile.pad_value.value` with `PadValue.MIN` becomes dtype-aware min +- `tile.pad_value.value` with `PadValue.custom_f32(...)` becomes the authored floating scalar +- `tile.pad_value.value` with `PadValue.NULL` raises a frontend error + +Example: reading pad value from a `Tile` + +```python +@pto.vkernel(op="fill_pad_demo", dtypes=[(pto.f16,)]) +def kernel(dst: pto.Tile): + mask, _ = pto.make_mask(pto.f16, 8) + + # Read the Tile-bound PadValue enum. + pad0 = dst.pad_value + + # Equivalent form through TileConfig metadata. + pad1 = dst.config.pad_value + + if pto.constexpr(pad0 != pto.PadValue.NULL): + scalar0 = pad0.value + scalar1 = pad1.value + vec0 = pto.vdup(scalar0, mask) + vec1 = pto.vdup(scalar1, mask) + pto.vsts(vec0, dst[0, 0:], mask) + pto.vsts(vec1, dst[1, 0:], mask) +``` + +If `dst` is specialized with `config=pto.TileConfig.from_mapping({"pad_value": pto.PadValue.ZERO})`, +both `pad0` and `pad1` are `PadValue.ZERO`, and `pad0.value` / `pad1.value` materialize to the scalar `0.0` for an `f16` tile. + #### Conversion Operations Basic mode syntax uses tile element-indexing directly in vector operations: diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index a4bd432e0..78d9e539e 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -37,6 +37,7 @@ SemanticLowLevelCopyStmt, SemanticMaskType, SemanticMetaType, + SemanticPadValueType, SemanticPipeBarrierStmt, SemanticPredicateStoreStmt, SemanticPtrType, @@ -70,7 +71,9 @@ ) from .types import ( MaskPattern, + PadValue, ScalarType, + TileConfig, bytewidth, get_lanes, integer_bitwidth, @@ -586,7 +589,7 @@ def _render_assign( return self._render_tuple_expr_assign(stmt, env, indent=indent) return self._render_multi_result_assign(stmt, env, indent=indent) target = stmt.targets[0] - if isinstance(target.type, SemanticMetaType): + if isinstance(target.type, (SemanticMetaType, SemanticPadValueType)): env[target.name] = _RenderedValue(name=target.ssa_name, type=target.type) return [] lines: list[str] = [] @@ -3525,11 +3528,13 @@ def _render_tile_buf_type(self, ty: SemanticTileType) -> str: valid_shape = ty.valid_shape or ty.shape v_row = valid_shape[0] v_col = 1 if ty.rank == 1 else valid_shape[1] + config = ty.config or TileConfig() return ( f"!pto.tile_buf" + f"blayout={config.b_layout.value}, slayout={config.s_layout.value}, " + f"fractal={config.s_fractal_size}, pad={self._render_tile_buf_pad_value(config.pad_value)}>" ) def _render_tile_buf_loc(self, memory_space: str) -> str: @@ -3542,6 +3547,13 @@ def _render_tile_buf_loc(self, memory_space: str) -> str: def _render_tile_buf_dim(self, dim: int | None) -> str: return "?" if dim is None else str(dim) + def _render_tile_buf_pad_value(self, pad_value: PadValue) -> str: + if pad_value.is_custom: + raise NotImplementedError( + "custom TileConfig.pad_value MLIR type rendering requires PTO tile_buf parser support for custom pad encodings" + ) + return str(pad_value.encoded) + def _dtype_byte_width(self, dtype: ScalarType) -> int: try: return bytewidth(dtype) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index f34cfad7e..0df887657 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -127,7 +127,10 @@ _PAD_MODE_SYMBOLS = {pad_mode.name: pad_mode for pad_mode in PadMode} _B_LAYOUT_SYMBOLS = {layout.name: layout for layout in BLayout} _S_LAYOUT_SYMBOLS = {layout.name: layout for layout in SLayout} -_PAD_VALUE_SYMBOLS = {pad_value.name: pad_value for pad_value in PadValue} +_PAD_VALUE_SYMBOLS = { + pad_value.name: pad_value + for pad_value in (PadValue.NULL, PadValue.ZERO, PadValue.MAX, PadValue.MIN) +} _DEINTERLEAVE_DIST_SYMBOLS = dict(DeinterleaveDist.__members__) _INTERLEAVE_DIST_SYMBOLS = dict(InterleaveDist.__members__) _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} @@ -267,7 +270,7 @@ class SemanticTileType(SemanticType): @dataclass(frozen=True) class SemanticTileConfigType(SemanticType): - pass + element_dtype: ScalarType | None = None @dataclass(frozen=True) @@ -306,6 +309,11 @@ class SemanticMetaType(SemanticType): kind: str +@dataclass(frozen=True) +class SemanticPadValueType(SemanticType): + element_dtype: ScalarType | None = None + + @dataclass(frozen=True) class SemanticAlignType(SemanticType): pass @@ -364,7 +372,7 @@ class SemanticSymbolExpr(SemanticExpr): namespace: str name: str value: Any - type: SemanticMetaType + type: SemanticType @dataclass(frozen=True) @@ -2792,6 +2800,10 @@ def _analyze_expr( return self._attach_expr_source_location(self._rank_expr(base), expr) if expr.attr == "memory_space": return self._attach_expr_source_location(self._memory_space_expr(base), expr) + if expr.attr == "pad_value" and isinstance(base.type, SemanticTileType): + return self._attach_expr_source_location(self._tile_pad_value_expr(base), expr) + if expr.attr == "value" and isinstance(base.type, SemanticPadValueType): + return self._attach_expr_source_location(self._pad_value_value_expr(base), expr) if expr.attr == "config": return self._attach_expr_source_location(self._tile_config_expr(base), expr) if expr.attr == "valid_shape": @@ -3020,7 +3032,7 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: namespace=expr.namespace, name=expr.name, value=pad_value, - type=SemanticMetaType(kind="pad_value"), + type=SemanticPadValueType(), ) if expr.namespace in {"DeinterleaveDist", "pto.DeinterleaveDist"}: dist = _DEINTERLEAVE_DIST_SYMBOLS.get(expr.name) @@ -3151,10 +3163,44 @@ def _tile_config_expr(self, base: SemanticExpr) -> SemanticExpr: if isinstance(base_type, SemanticTileType): return SemanticLiteralExpr( value=base_type.config or TileConfig(), - type=SemanticTileConfigType(), + type=SemanticTileConfigType(element_dtype=base_type.element_dtype), ) raise TypeError("unsupported attribute access 'config' in TileLang DSL v1") + def _tile_pad_value_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if not isinstance(base_type, SemanticTileType): + raise TypeError("unsupported attribute access 'pad_value' in TileLang DSL v1") + config = base_type.config or TileConfig() + return SemanticSymbolExpr( + namespace="pto", + name=config.pad_value.name, + value=config.pad_value, + type=SemanticPadValueType(element_dtype=base_type.element_dtype), + ) + + def _pad_value_value_expr(self, base: SemanticExpr) -> SemanticExpr: + if not isinstance(base.type, SemanticPadValueType): + raise TypeError("unsupported attribute access 'value' in TileLang DSL v1") + if base.type.element_dtype is None: + raise TypeError( + "PadValue.value requires a Tile-bound or TileConfig-bound pad descriptor with an owning " + "Tile element dtype in TileLang DSL v1" + ) + pad_value = self._try_static_value(base) + if not isinstance(pad_value, PadValue): + raise TypeError("PadValue.value expects a statically known PadValue enum in TileLang DSL v1") + pad_scalar = pad_value.materialize_scalar(base.type.element_dtype) + if pad_scalar is None: + raise TypeError( + "PadValue.NULL.value is invalid in TileLang DSL v1; " + "guard it with `pto.constexpr(tile.pad_value != pto.PadValue.NULL)` before reading `.value`" + ) + return SemanticLiteralExpr( + value=pad_scalar, + type=SemanticScalarType(dtype=base.type.element_dtype), + ) + def _tile_config_attr_expr(self, base: SemanticExpr, attr: str) -> SemanticExpr: config = self._try_static_value(base) if not isinstance(config, TileConfig): @@ -3179,11 +3225,15 @@ def _tile_config_attr_expr(self, base: SemanticExpr, attr: str) -> SemanticExpr: type=SemanticScalarType(dtype=i32), ) if attr == "pad_value": + if not isinstance(base.type, SemanticTileConfigType): + raise TypeError( + "TileConfig.pad_value expects a TileConfig value in TileLang DSL v1" + ) return SemanticSymbolExpr( namespace="pto", name=config.pad_value.name, value=config.pad_value, - type=SemanticMetaType(kind="pad_value"), + type=SemanticPadValueType(element_dtype=base.type.element_dtype), ) raise TypeError(f"unsupported TileConfig attribute access '{attr}' in TileLang DSL v1") @@ -3428,6 +3478,8 @@ def _binary_type( return SemanticScalarType(dtype=i1) if isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: return SemanticScalarType(dtype=i1) + if isinstance(lhs.type, SemanticPadValueType) and isinstance(rhs.type, SemanticPadValueType): + return SemanticScalarType(dtype=i1) if isinstance(lhs.type, SemanticMetaType) and lhs.type == rhs.type: return SemanticScalarType(dtype=i1) raise TypeError( @@ -5242,6 +5294,7 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticLiteralExpr", "SemanticMemBarStmt", "SemanticMaskType", + "SemanticPadValueType", "SemanticParameter", "SemanticPipeBarrierStmt", "SemanticPredicateStoreStmt", diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index bd62b70ff..6780c998c 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from enum import Enum +import struct from typing import Any, Mapping @@ -205,11 +206,180 @@ class SLayout(str, Enum): COL_MAJOR = "col_major" -class PadValue(str, Enum): - NULL = "null" - ZERO = "zero" - MAX = "max" - MIN = "min" +def _float32_from_bits(bits: int) -> float: + return struct.unpack(">f", bits.to_bytes(4, byteorder="big", signed=False))[0] + + +_FLOAT_DTYPE_MAX = { + "f16": 65504.0, + "bf16": _float32_from_bits(0x7F7F0000), + "f32": _float32_from_bits(0x7F7FFFFF), +} +_FLOAT_DTYPE_MIN = { + "f16": -65504.0, + "bf16": _float32_from_bits(0xFF7F0000), + "f32": _float32_from_bits(0xFF7FFFFF), +} + + +@dataclass(frozen=True) +class PadValue: + """Tile pad descriptor matching the C++ PadValue design. + + Standard values occupy the low integer range: + - NULL = 0 + - ZERO = 1 + - MAX = 2 + - MIN = 3 + + Custom values use the C++ `CustomBase` convention and carry an f32 bit + pattern authored through `custom_f32(...)`. + """ + + encoded: int + _symbol_name: str | None = None + _float32_bits: int | None = None + + CustomBase = 0x100000000 + _STANDARD_TEXT = { + 0: "null", + 1: "zero", + 2: "max", + 3: "min", + } + + def __post_init__(self) -> None: + if isinstance(self.encoded, bool) or not isinstance(self.encoded, int): + raise TypeError("PadValue.encoded must be a uint64-compatible integer") + if self.encoded < 0 or self.encoded >= (1 << 64): + raise ValueError("PadValue.encoded must be in uint64 range") + if self._float32_bits is not None and not (0 <= self._float32_bits < (1 << 32)): + raise ValueError("PadValue custom float32 payload must be a 32-bit integer") + + @property + def name(self) -> str: + if self._symbol_name is not None: + return self._symbol_name + return "CUSTOM" + + @property + def value(self) -> int: + return self.encoded + + @property + def text(self) -> str: + standard = self._STANDARD_TEXT.get(self.encoded) + if standard is not None: + return standard + return f"0x{self.encoded:016X}" + + @property + def is_custom(self) -> bool: + return self._symbol_name is None and self.encoded >= self.CustomBase + + @property + def float32_bits(self) -> int: + if not self.is_custom: + raise ValueError("only custom PadValue instances carry a float32 payload") + if self._float32_bits is not None: + return self._float32_bits + return (self.encoded >> 32) & 0xFFFFFFFF + + def as_float32(self) -> float: + return _float32_from_bits(self.float32_bits) + + def materialize_scalar(self, dtype: ScalarType) -> int | float | None: + if not isinstance(dtype, ScalarType): + raise TypeError("PadValue.materialize_scalar expects a TileLang scalar dtype") + if self == PadValue.NULL: + return None + if self == PadValue.ZERO: + return 0.0 if is_float_dtype(dtype) else 0 + if self == PadValue.MAX: + if is_float_dtype(dtype): + return _FLOAT_DTYPE_MAX[dtype.name] + width = integer_bitwidth(dtype) + signedness = integer_signedness(dtype) + if width is None or signedness is None: + raise TypeError(f"PadValue.MAX does not support dtype `{dtype.name}`") + if signedness == "unsigned": + return (1 << width) - 1 + return (1 << (width - 1)) - 1 + if self == PadValue.MIN: + if is_float_dtype(dtype): + return _FLOAT_DTYPE_MIN[dtype.name] + width = integer_bitwidth(dtype) + signedness = integer_signedness(dtype) + if width is None or signedness is None: + raise TypeError(f"PadValue.MIN does not support dtype `{dtype.name}`") + if signedness == "unsigned": + return 0 + return -(1 << (width - 1)) + if self.is_custom: + if not is_float_dtype(dtype): + raise TypeError( + "custom Tile pad_value currently only materializes for floating Tile element dtypes" + ) + return self.as_float32() + raise TypeError(f"unsupported PadValue payload {self!r}") + + @classmethod + def from_uint64(cls, value: int) -> "PadValue": + if isinstance(value, bool) or not isinstance(value, int): + raise TypeError("PadValue.from_uint64 expects an integer") + if value == 0: + return cls.NULL + if value == 1: + return cls.ZERO + if value == 2: + return cls.MAX + if value == 3: + return cls.MIN + if value < 0 or value >= (1 << 64): + raise ValueError("PadValue.from_uint64 expects a uint64-compatible integer") + return cls(value) + + @classmethod + def custom_f32(cls, value: float | str | int) -> "PadValue": + bits = cls._normalize_custom_f32_bits(value) + encoded = cls.CustomBase | (bits << 32) + return cls(encoded=encoded, _float32_bits=bits) + + @staticmethod + def _normalize_custom_f32_bits(value: float | str | int) -> int: + if isinstance(value, bool): + raise TypeError("PadValue.custom_f32 does not accept bool") + if isinstance(value, int): + if value < 0 or value >= (1 << 32): + raise ValueError("PadValue.custom_f32 integer payload must fit in 32 bits") + return value + if isinstance(value, str): + text = value.strip() + if text.lower().startswith("0x"): + bits = int(text, 16) + if bits < 0 or bits >= (1 << 32): + raise ValueError("PadValue.custom_f32 hex payload must fit in 32 bits") + return bits + value = float(text) + packed = struct.pack(">f", float(value)) + return int.from_bytes(packed, byteorder="big", signed=False) + + def __repr__(self) -> str: + if self == PadValue.NULL: + return "PadValue.NULL" + if self == PadValue.ZERO: + return "PadValue.ZERO" + if self == PadValue.MAX: + return "PadValue.MAX" + if self == PadValue.MIN: + return "PadValue.MIN" + return f"PadValue.custom_f32(0x{self.float32_bits:08X})" + + +PadValue.NULL = PadValue(0, "NULL") +PadValue.ZERO = PadValue(1, "ZERO") +PadValue.MAX = PadValue(2, "MAX") +PadValue.MIN = PadValue(3, "MIN") class DeinterleaveDist(str, Enum): @@ -338,7 +508,12 @@ def _normalize_s_layout(value: Any) -> SLayout: def _normalize_pad_value(value: Any) -> PadValue: if isinstance(value, PadValue): return value + if isinstance(value, int) and not isinstance(value, bool): + return PadValue.from_uint64(value) if isinstance(value, str): + text = value.strip() + if text.lower().startswith("0x"): + return PadValue.from_uint64(int(text, 16)) normalized = value.strip().upper().replace("-", "_") if normalized == "NULL": return PadValue.NULL diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 0d57af82c..4307adfd7 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -44,6 +44,7 @@ SemanticMemBarStmt, SemanticLowLevelCopyStmt, SemanticMaskType, + SemanticPadValueType, SemanticPipeBarrierStmt, SemanticPtrType, SemanticPredicateStoreStmt, @@ -123,7 +124,11 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PadMode.PadValue.value, "PadValue") self.assertEqual(pto.BLayout.ROW_MAJOR.value, "row_major") self.assertEqual(pto.SLayout.NONE_BOX.value, "none_box") - self.assertEqual(pto.PadValue.NULL.value, "null") + self.assertEqual(pto.PadValue.NULL.value, 0) + self.assertEqual(pto.PadValue.ZERO.value, 1) + self.assertEqual(pto.PadValue.MAX.value, 2) + self.assertEqual(pto.PadValue.MIN.value, 3) + self.assertEqual(pto.PadValue.NULL.text, "null") self.assertEqual(pto.DeinterleaveDist.DINTLV.value, "DINTLV") self.assertEqual(pto.DeinterleaveDist.BDINTLV.value, "BDINTLV") self.assertEqual(pto.InterleaveDist.INTLV.value, "INTLV") @@ -170,6 +175,19 @@ def test_tile_config_exposes_normalized_query_properties(self) -> None: self.assertEqual(config.s_fractal_size, 16) self.assertEqual(config.pad_value, pto.PadValue.MAX) + def test_pad_value_supports_standard_and_custom_payloads(self) -> None: + custom = pto.PadValue.custom_f32(-1.0) + self.assertTrue(custom.is_custom) + self.assertEqual(custom.float32_bits, 0xBF800000) + self.assertEqual(custom.value, pto.PadValue.CustomBase | (0xBF800000 << 32)) + self.assertAlmostEqual(custom.as_float32(), -1.0) + self.assertAlmostEqual(custom.materialize_scalar(pto.f32), -1.0) + self.assertEqual(pto.PadValue.MAX.materialize_scalar(pto.ui16), 0xFFFF) + self.assertEqual(pto.PadValue.MIN.materialize_scalar(pto.ui16), 0) + self.assertEqual(pto.PadValue.MAX.materialize_scalar(pto.i16), 0x7FFF) + self.assertEqual(pto.PadValue.MIN.materialize_scalar(pto.i16), -0x8000) + self.assertIsNone(pto.PadValue.NULL.materialize_scalar(pto.f16)) + class TileLangDSLSupportMatrixTests(unittest.TestCase): def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: @@ -990,7 +1008,12 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): tile=pto.TileSpecialization( shape=(16, 32), memory_space=pto.MemorySpace.UB, - config=pto.TileConfig.from_mapping({"layout": "row_major"}), + config=pto.TileConfig.from_mapping( + { + "layout": "row_major", + "pad_value": pto.PadValue.ZERO, + } + ), ) ) @@ -1000,7 +1023,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): self.assertIn("// tilelang.specialize tile shape=(16, 32) memory_space=ub", text) self.assertIn('module attributes {pto.target_arch = "a5"} {', text) self.assertIn( - "func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance } {", + "func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance } {", text, ) module = specialized.mlir_module() @@ -1693,6 +1716,9 @@ def kernel(tile: pto.Tile): secondary = config.s_layout fractal = config.s_fractal_size pad = config.pad_value + pad_direct = tile.pad_value + pad_scalar = pad.value + pad_direct_scalar = pad_direct.value rank = tile.rank space = tile.memory_space return None @@ -1713,8 +1739,19 @@ def kernel(tile: pto.Tile): ) semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - config_assign, layout_assign, secondary_assign, fractal_assign, pad_assign, rank_assign, space_assign = ( - semantic_kernel.body[:7] + ( + config_assign, + layout_assign, + secondary_assign, + fractal_assign, + pad_assign, + pad_direct_assign, + pad_scalar_assign, + pad_direct_scalar_assign, + rank_assign, + space_assign, + ) = ( + semantic_kernel.body[:10] ) self.assertIsInstance(config_assign, SemanticAssignStmt) @@ -1738,7 +1775,23 @@ def kernel(tile: pto.Tile): self.assertIsInstance(pad_assign.value, SemanticSymbolExpr) self.assertEqual(pad_assign.value.value, pto.PadValue.ZERO) - self.assertEqual(pad_assign.value.type.kind, "pad_value") + self.assertIsInstance(pad_assign.targets[0].type, SemanticPadValueType) + self.assertEqual(pad_assign.targets[0].type.element_dtype, pto.f16) + + self.assertIsInstance(pad_direct_assign.value, SemanticSymbolExpr) + self.assertEqual(pad_direct_assign.value.value, pto.PadValue.ZERO) + self.assertIsInstance(pad_direct_assign.targets[0].type, SemanticPadValueType) + self.assertEqual(pad_direct_assign.targets[0].type.element_dtype, pto.f16) + + self.assertIsInstance(pad_scalar_assign.value, SemanticLiteralExpr) + self.assertEqual(pad_scalar_assign.value.value, 0.0) + self.assertIsInstance(pad_scalar_assign.targets[0].type, SemanticScalarType) + self.assertEqual(pad_scalar_assign.targets[0].type.dtype, pto.f16) + + self.assertIsInstance(pad_direct_scalar_assign.value, SemanticLiteralExpr) + self.assertEqual(pad_direct_scalar_assign.value.value, 0.0) + self.assertIsInstance(pad_direct_scalar_assign.targets[0].type, SemanticScalarType) + self.assertEqual(pad_direct_scalar_assign.targets[0].type.dtype, pto.f16) self.assertEqual(rank_assign.value.value, 2) self.assertIsInstance(rank_assign.targets[0].type, SemanticIndexType) @@ -1747,6 +1800,24 @@ def kernel(tile: pto.Tile): self.assertEqual(space_assign.value.value, pto.MemorySpace.UB) self.assertEqual(space_assign.value.type.kind, "memory_space") + def test_pad_value_value_requires_non_null_enum(self) -> None: + @pto.vkernel(op="tile_pad_value_null_value", dtypes=[(pto.f16,)]) + def kernel(tile: pto.Tile): + scalar = tile.pad_value.value + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + self.assertIn("PadValue.NULL.value is invalid", str(ctx.exception)) + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) From 9c06560f872560e13ae88db09a2f3e104954f5ea Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 12:26:09 +0800 Subject: [PATCH 057/192] Fix missing valid_shape/padvalue info when expand tileop --- lib/PTO/Transforms/ExpandTileOp.cpp | 51 ++++++++++- .../python/tilelang_dsl/expand_helper.py | 20 ++++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 87 +++++++++++++++++++ 3 files changed, 153 insertions(+), 5 deletions(-) diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 4eb830a37..c4d4833ee 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -43,6 +43,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" @@ -92,11 +93,12 @@ struct OperandTypeInfo { // --- Tile-only (TileBufType) --- SmallVector tileShape; + SmallVector tileValidShape; std::string tileMemorySpace; // "ub" or "gm" int32_t blayout = 0; int32_t slayout = 0; int32_t fractal = 0; - int32_t pad = 0; + uint64_t pad = 0; // --- View-only (MemRefType) — for JSON / constraint checking only --- SmallVector viewShape; @@ -109,6 +111,7 @@ struct OperandTypeInfo { return false; if (kind == OperandKind::Tile) return tileShape == rhs.tileShape && + tileValidShape == rhs.tileValidShape && tileMemorySpace == rhs.tileMemorySpace && blayout == rhs.blayout && slayout == rhs.slayout && fractal == rhs.fractal && pad == rhs.pad; @@ -138,9 +141,11 @@ struct SpecKeyInfo : public llvm::DenseMapInfo { h = llvm::hash_combine(h, static_cast(op.kind), op.dtype); if (op.kind == OperandKind::Tile) { h = llvm::hash_combine(h, op.tileMemorySpace, op.blayout, - op.slayout, op.fractal, op.pad); + op.slayout, op.fractal, op.pad); for (int64_t d : op.tileShape) h = llvm::hash_combine(h, d); + for (int64_t d : op.tileValidShape) + h = llvm::hash_combine(h, d); } // View/Scalar: only kind + dtype contribute to hash. } @@ -184,6 +189,20 @@ static std::string getMemorySpaceString(MemRefType mrTy) { return "ub"; } +static std::string getBLayoutString(int32_t blayout) { + if (blayout == static_cast(pto::BLayout::ColMajor)) + return "col_major"; + return "row_major"; +} + +static std::string getSLayoutString(int32_t slayout) { + if (slayout == static_cast(pto::SLayout::RowMajor)) + return "row_major"; + if (slayout == static_cast(pto::SLayout::ColMajor)) + return "col_major"; + return "none_box"; +} + static std::optional buildOperandTypeInfo(Type ty) { // Tile operand — from TileBufType. if (auto tbTy = dyn_cast(ty)) { @@ -193,6 +212,11 @@ static std::optional buildOperandTypeInfo(Type ty) { if (info.dtype.empty()) return std::nullopt; info.tileShape.assign(tbTy.getShape().begin(), tbTy.getShape().end()); + auto validShape = tbTy.getValidShape(); + if (validShape.empty()) + info.tileValidShape.assign(tbTy.getShape().begin(), tbTy.getShape().end()); + else + info.tileValidShape.assign(validShape.begin(), validShape.end()); info.tileMemorySpace = getMemorySpaceString(tbTy); if (auto config = tbTy.getConfigAttr()) { info.blayout = static_cast(config.getBLayout().getValue()); @@ -200,7 +224,7 @@ static std::optional buildOperandTypeInfo(Type ty) { info.fractal = config.getSFractalSize() ? static_cast(config.getSFractalSize().getInt()) : 0; - info.pad = static_cast(config.getPad().getValue()); + info.pad = static_cast(config.getPad().getValue()); } return info; } @@ -295,7 +319,20 @@ static std::string buildOperandSpecsJson(const SpecKey &key) { if (op.kind == OperandKind::Tile) { json += "{\"kind\":\"tile\",\"dtype\":\"" + op.dtype + "\",\"shape\":"; appendJsonIntArray(json, op.tileShape); - json += ",\"memory_space\":\"" + op.tileMemorySpace + "\"}"; + json += ",\"valid_shape\":"; + appendJsonIntArray(json, op.tileValidShape); + json += ",\"memory_space\":\""; + json += op.tileMemorySpace; + json += "\",\"config\":{"; + json += "\"b_layout\":\""; + json += getBLayoutString(op.blayout); + json += "\",\"s_layout\":\""; + json += getSLayoutString(op.slayout); + json += "\",\"s_fractal_size\":"; + json += std::to_string(op.fractal); + json += ",\"pad_value\":\"0x"; + json += llvm::utohexstr(op.pad, /*LowerCase=*/false); + json += "\"}}"; continue; } @@ -470,6 +507,12 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, if (op.kind == OperandKind::Tile) { for (int64_t d : op.tileShape) uniqueName += "_" + std::to_string(d); + for (int64_t d : op.tileValidShape) + uniqueName += "_v" + std::to_string(d); + uniqueName += "_bl" + std::to_string(op.blayout); + uniqueName += "_sl" + std::to_string(op.slayout); + uniqueName += "_fr" + std::to_string(op.fractal); + uniqueName += "_pd" + llvm::utohexstr(op.pad, /*LowerCase=*/false); } } diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index 79d8c82e9..44fed81b3 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -22,7 +22,7 @@ from pathlib import Path from .kernel import VKernelDescriptor, _match_descriptor_dtype_signature -from .types import MemorySpace, ScalarType, TileSpecialization +from .types import MemorySpace, ScalarType, TileConfig, TileSpecialization _DTYPE_MAP: dict[str, ScalarType] = {} @@ -129,16 +129,32 @@ def _parse_operand_specs(spec_text: str) -> list[dict]: shape = raw.get("shape") if not isinstance(shape, list) or not shape: raise ValueError(f"operand-specs[{index}] tile shape must be a non-empty list") + valid_shape = raw.get("valid_shape") + if valid_shape is not None and (not isinstance(valid_shape, list) or not valid_shape): + raise ValueError(f"operand-specs[{index}] tile valid_shape must be a non-empty list") memory_space = _MEMSPACE_MAP.get(raw.get("memory_space")) if memory_space is None: raise ValueError( f"operand-specs[{index}] has unknown memory-space {raw.get('memory_space')!r}" ) + config_raw = raw.get("config") + config = None + if config_raw is not None: + if not isinstance(config_raw, dict): + raise ValueError(f"operand-specs[{index}] tile config must be an object") + try: + config = TileConfig.from_mapping(config_raw) + except (TypeError, ValueError) as exc: + raise ValueError( + f"operand-specs[{index}] has invalid tile config: {exc}" + ) from exc specs.append( { "kind": "tile", "dtype": dtype, "shape": tuple(int(dim) for dim in shape), + "valid_shape": None if valid_shape is None else tuple(int(dim) for dim in valid_shape), + "config": config, "memory_space": memory_space, } ) @@ -262,6 +278,8 @@ def main(argv: list[str] | None = None) -> int: tile_specs[param.name] = TileSpecialization( shape=operand_spec["shape"], memory_space=operand_spec["memory_space"], + config=operand_spec.get("config"), + valid_shape=operand_spec.get("valid_shape"), ) continue if param.kind in ("tensorview", "partition_tensor_view"): diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 4307adfd7..7d9ca4c6a 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -5,6 +5,7 @@ from pathlib import Path import tilelang_dsl as pto +import tilelang_dsl.expand_helper as expand_helper import tilelang_dsl.kernel as kernel_impl from tilelang_dsl.support_matrix import ( ADVANCED_EXPLICIT_VECSCOPE_SURFACES, @@ -189,6 +190,92 @@ def test_pad_value_supports_standard_and_custom_payloads(self) -> None: self.assertIsNone(pto.PadValue.NULL.materialize_scalar(pto.f16)) +class TileLangDSLExpandHelperTests(unittest.TestCase): + def test_operand_specs_preserve_tile_valid_shape_and_pad_value(self) -> None: + source = """ +import tilelang_dsl as pto + +@pto.vkernel(op="pto.expand_helper_tile_config_unique", dtypes=[(pto.f32, pto.f32)]) +def kernel(src: pto.Tile, dst: pto.Tile): + rows, cols = src.valid_shape + pad = dst.pad_value + if pto.constexpr(pad != pto.PadValue.NULL): + scalar = pad.value + return None +""" + with tempfile.TemporaryDirectory() as tmpdir: + module_path = Path(tmpdir) / "expand_helper_tile_config_unique.py" + module_path.write_text(source, encoding="utf-8") + + mod = expand_helper._import_py_file(module_path) + self.assertIsNotNone(mod) + descriptors = expand_helper._find_descriptors(mod) + self.assertTrue(descriptors) + + operand_specs = expand_helper._parse_operand_specs( + """ +[ + { + "kind": "tile", + "dtype": "f32", + "shape": [16, 64], + "valid_shape": [8, 48], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0" + } + }, + { + "kind": "tile", + "dtype": "f32", + "shape": [16, 64], + "valid_shape": [8, 48], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x1" + } + } +] +""" + ) + desc = expand_helper._match_descriptor( + descriptors, + "pto.expand_helper_tile_config_unique", + tuple(spec["dtype"] for spec in operand_specs), + ) + self.assertIsNotNone(desc) + + tile_specs = {} + for param, operand_spec in zip(desc.parameters, operand_specs): + self.assertEqual(param.kind, "tile") + tile_specs[param.name] = pto.TileSpecialization( + shape=operand_spec["shape"], + memory_space=operand_spec["memory_space"], + config=operand_spec["config"], + valid_shape=operand_spec["valid_shape"], + ) + + mlir_text = desc.specialize(**tile_specs).mlir_text() + + self.assertIn("valid_shape=(8, 48)", mlir_text) + self.assertIn( + "!pto.tile_buf", + mlir_text, + ) + self.assertIn( + "!pto.tile_buf", + mlir_text, + ) + + class TileLangDSLSupportMatrixTests(unittest.TestCase): def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_surface_group_tier("TensorView"), BASIC_TIER) From ef243ec6ef41354a5ec3a0c7efac6477bf803a14 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 14:10:40 +0800 Subject: [PATCH 058/192] Use pad_value.eval() instead of pad_value.value --- .../docs/user_guide/05-type-system.md | 22 ++++---- .../python/tilelang_dsl/frontend_ast.py | 19 +++++++ tilelang-dsl/python/tilelang_dsl/kernel.py | 13 +++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 50 ++++++++++++++++--- tilelang-dsl/python/tilelang_dsl/types.py | 5 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 26 +++++----- 6 files changed, 103 insertions(+), 32 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 91d7c8845..47e1d104b 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -297,10 +297,10 @@ pad2 = pto.PadValue.custom_f32("0xBF800000") # float32 bit pattern for -1.0f ``` Notes: -- `PadValue.value` on the host-side descriptor still exposes the encoded integer payload. +- `PadValue.encoded` exposes the host-side uint64 payload. `PadValue.value` is intentionally unavailable to avoid confusion with kernel-side `.eval()`. - `PadValue.text` exposes the standard textual spelling for built-ins such as `null` and `zero`. - Custom pad values currently model an `f32` payload. In DSL v1, materializing a custom pad into a scalar is only supported for floating tile element dtypes. -- `PadValue.NULL` does not denote a usable scalar fill constant. Reading `tile.pad_value.value` or `tile.config.pad_value.value` when the enum is `NULL` is a frontend error. +- `PadValue.NULL` does not denote a usable scalar fill constant. Calling `tile.pad_value.eval()` or `tile.config.pad_value.eval()` when the enum is `NULL` is a frontend error. #### Tile Shape Concepts @@ -330,13 +330,13 @@ pad_desc2 = tile.pad_value # direct sugar for the same PadValue enum rank = tile.rank # 2 ``` -`tile.config.pad_value` and `tile.pad_value` are enum-typed inside kernel code. Use `.value` to materialize the configured pad descriptor against the tile element dtype: +`tile.config.pad_value` and `tile.pad_value` are enum-typed inside kernel code. Use `.eval()` to materialize the configured pad descriptor against the tile element dtype: -- `tile.pad_value.value` with `PadValue.ZERO` becomes `0` / `0.0` -- `tile.pad_value.value` with `PadValue.MAX` becomes dtype-aware max -- `tile.pad_value.value` with `PadValue.MIN` becomes dtype-aware min -- `tile.pad_value.value` with `PadValue.custom_f32(...)` becomes the authored floating scalar -- `tile.pad_value.value` with `PadValue.NULL` raises a frontend error +- `tile.pad_value.eval()` with `PadValue.ZERO` becomes `0` / `0.0` +- `tile.pad_value.eval()` with `PadValue.MAX` becomes dtype-aware max +- `tile.pad_value.eval()` with `PadValue.MIN` becomes dtype-aware min +- `tile.pad_value.eval()` with `PadValue.custom_f32(...)` becomes the authored floating scalar +- `tile.pad_value.eval()` with `PadValue.NULL` raises a frontend error Example: reading pad value from a `Tile` @@ -352,8 +352,8 @@ def kernel(dst: pto.Tile): pad1 = dst.config.pad_value if pto.constexpr(pad0 != pto.PadValue.NULL): - scalar0 = pad0.value - scalar1 = pad1.value + scalar0 = pad0.eval() + scalar1 = pad1.eval() vec0 = pto.vdup(scalar0, mask) vec1 = pto.vdup(scalar1, mask) pto.vsts(vec0, dst[0, 0:], mask) @@ -361,7 +361,7 @@ def kernel(dst: pto.Tile): ``` If `dst` is specialized with `config=pto.TileConfig.from_mapping({"pad_value": pto.PadValue.ZERO})`, -both `pad0` and `pad1` are `PadValue.ZERO`, and `pad0.value` / `pad1.value` materialize to the scalar `0.0` for an `f16` tile. +both `pad0` and `pad1` are `PadValue.ZERO`, and `pad0.eval()` / `pad1.eval()` materialize to the scalar `0.0` for an `f16` tile. #### Conversion Operations diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index c11ec116d..9f2127ff5 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -1019,6 +1019,25 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo node, context, ) + if isinstance(node.func, ast.Attribute): + return _attach_source_location( + FrontendCallExpr( + namespace=None, + name=node.func.attr, + args=( + _build_expr(node.func.value, context), + *(tuple(_build_expr(arg, context) for arg in node.args)), + ), + keywords=_build_call_keywords( + node, + namespace=None, + name=node.func.attr, + context=context, + ), + ), + node, + context, + ) raise context.error( node, f"unsupported expression `{type(node).__name__}` in TileLang DSL v1", diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 3e4fbc93e..56c8d74f9 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -392,6 +392,19 @@ def _validate_call_keywords(self, node: ast.Call) -> None: seen.add(keyword.arg) def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute) and node.func.attr == "eval": + if node.keywords: + raise self.source_info.error( + node, + "`eval` does not support keyword arguments in TileLang DSL v1", + ) + if node.args: + raise self.source_info.error( + node, + "`eval()` does not accept positional arguments in TileLang DSL v1", + ) + return + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): if node.func.attr == "as_ptr": if node.keywords: diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 0df887657..1433d4050 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -2802,8 +2802,6 @@ def _analyze_expr( return self._attach_expr_source_location(self._memory_space_expr(base), expr) if expr.attr == "pad_value" and isinstance(base.type, SemanticTileType): return self._attach_expr_source_location(self._tile_pad_value_expr(base), expr) - if expr.attr == "value" and isinstance(base.type, SemanticPadValueType): - return self._attach_expr_source_location(self._pad_value_value_expr(base), expr) if expr.attr == "config": return self._attach_expr_source_location(self._tile_config_expr(base), expr) if expr.attr == "valid_shape": @@ -2850,6 +2848,33 @@ def _analyze_expr( for arg in expr.args ) return self._analyze_inline_proc_call_expr(expr.name, args) + if expr.namespace is None and expr.name == "eval": + if expr.keywords: + raise TypeError("method call `eval` does not support keyword arguments in TileLang DSL v1") + if not expr.args: + raise TypeError("`eval()` expects a receiver in TileLang DSL v1") + base = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args[1:] + ) + return self._analyze_eval_method(base, args) + if expr.namespace not in {None, "pto"} and expr.name == "eval": + if expr.keywords: + raise TypeError("method call `eval` does not support keyword arguments in TileLang DSL v1") + binding = env.get(expr.namespace) + if binding is None: + if allow_outer_lookup: + raise ValueError(f"unknown name '{expr.namespace}'") + raise ValueError( + f"implicit capture of '{expr.namespace}' is not allowed in pto.strict_vecscope" + ) + base = SemanticBindingRef(binding=binding, type=binding.type) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_eval_method(base, args) if expr.namespace not in {None, "pto"} and expr.name == "as_ptr": if expr.keywords: raise TypeError("method call `as_ptr` does not support keyword arguments in TileLang DSL v1") @@ -3179,28 +3204,37 @@ def _tile_pad_value_expr(self, base: SemanticExpr) -> SemanticExpr: type=SemanticPadValueType(element_dtype=base_type.element_dtype), ) - def _pad_value_value_expr(self, base: SemanticExpr) -> SemanticExpr: + def _pad_value_eval_expr(self, base: SemanticExpr) -> SemanticExpr: if not isinstance(base.type, SemanticPadValueType): - raise TypeError("unsupported attribute access 'value' in TileLang DSL v1") + raise TypeError("`eval()` expects a PadValue descriptor in TileLang DSL v1") if base.type.element_dtype is None: raise TypeError( - "PadValue.value requires a Tile-bound or TileConfig-bound pad descriptor with an owning " + "PadValue.eval() requires a Tile-bound or TileConfig-bound pad descriptor with an owning " "Tile element dtype in TileLang DSL v1" ) pad_value = self._try_static_value(base) if not isinstance(pad_value, PadValue): - raise TypeError("PadValue.value expects a statically known PadValue enum in TileLang DSL v1") + raise TypeError("PadValue.eval() expects a statically known PadValue enum in TileLang DSL v1") pad_scalar = pad_value.materialize_scalar(base.type.element_dtype) if pad_scalar is None: raise TypeError( - "PadValue.NULL.value is invalid in TileLang DSL v1; " - "guard it with `pto.constexpr(tile.pad_value != pto.PadValue.NULL)` before reading `.value`" + "PadValue.NULL.eval() is invalid in TileLang DSL v1; " + "guard it with `pto.constexpr(tile.pad_value != pto.PadValue.NULL)` before calling `.eval()`" ) return SemanticLiteralExpr( value=pad_scalar, type=SemanticScalarType(dtype=base.type.element_dtype), ) + def _analyze_eval_method( + self, + base: SemanticExpr, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if args: + raise TypeError("`eval()` does not accept positional arguments in TileLang DSL v1") + return self._pad_value_eval_expr(base) + def _tile_config_attr_expr(self, base: SemanticExpr, attr: str) -> SemanticExpr: config = self._try_static_value(base) if not isinstance(config, TileConfig): diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 6780c998c..62f9dc964 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -264,7 +264,10 @@ def name(self) -> str: @property def value(self) -> int: - return self.encoded + raise AttributeError( + "PadValue.value is not available; use PadValue.encoded for host-side payload access " + "or pad.eval() for Tile-bound scalar materialization" + ) @property def text(self) -> str: diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 7d9ca4c6a..0f0e68d0d 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -125,10 +125,10 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PadMode.PadValue.value, "PadValue") self.assertEqual(pto.BLayout.ROW_MAJOR.value, "row_major") self.assertEqual(pto.SLayout.NONE_BOX.value, "none_box") - self.assertEqual(pto.PadValue.NULL.value, 0) - self.assertEqual(pto.PadValue.ZERO.value, 1) - self.assertEqual(pto.PadValue.MAX.value, 2) - self.assertEqual(pto.PadValue.MIN.value, 3) + self.assertEqual(pto.PadValue.NULL.encoded, 0) + self.assertEqual(pto.PadValue.ZERO.encoded, 1) + self.assertEqual(pto.PadValue.MAX.encoded, 2) + self.assertEqual(pto.PadValue.MIN.encoded, 3) self.assertEqual(pto.PadValue.NULL.text, "null") self.assertEqual(pto.DeinterleaveDist.DINTLV.value, "DINTLV") self.assertEqual(pto.DeinterleaveDist.BDINTLV.value, "BDINTLV") @@ -180,7 +180,7 @@ def test_pad_value_supports_standard_and_custom_payloads(self) -> None: custom = pto.PadValue.custom_f32(-1.0) self.assertTrue(custom.is_custom) self.assertEqual(custom.float32_bits, 0xBF800000) - self.assertEqual(custom.value, pto.PadValue.CustomBase | (0xBF800000 << 32)) + self.assertEqual(custom.encoded, pto.PadValue.CustomBase | (0xBF800000 << 32)) self.assertAlmostEqual(custom.as_float32(), -1.0) self.assertAlmostEqual(custom.materialize_scalar(pto.f32), -1.0) self.assertEqual(pto.PadValue.MAX.materialize_scalar(pto.ui16), 0xFFFF) @@ -188,6 +188,8 @@ def test_pad_value_supports_standard_and_custom_payloads(self) -> None: self.assertEqual(pto.PadValue.MAX.materialize_scalar(pto.i16), 0x7FFF) self.assertEqual(pto.PadValue.MIN.materialize_scalar(pto.i16), -0x8000) self.assertIsNone(pto.PadValue.NULL.materialize_scalar(pto.f16)) + with self.assertRaises(AttributeError): + _ = pto.PadValue.ZERO.value class TileLangDSLExpandHelperTests(unittest.TestCase): @@ -200,7 +202,7 @@ def kernel(src: pto.Tile, dst: pto.Tile): rows, cols = src.valid_shape pad = dst.pad_value if pto.constexpr(pad != pto.PadValue.NULL): - scalar = pad.value + scalar = pad.eval() return None """ with tempfile.TemporaryDirectory() as tmpdir: @@ -1804,8 +1806,8 @@ def kernel(tile: pto.Tile): fractal = config.s_fractal_size pad = config.pad_value pad_direct = tile.pad_value - pad_scalar = pad.value - pad_direct_scalar = pad_direct.value + pad_scalar = pad.eval() + pad_direct_scalar = pad_direct.eval() rank = tile.rank space = tile.memory_space return None @@ -1887,10 +1889,10 @@ def kernel(tile: pto.Tile): self.assertEqual(space_assign.value.value, pto.MemorySpace.UB) self.assertEqual(space_assign.value.type.kind, "memory_space") - def test_pad_value_value_requires_non_null_enum(self) -> None: - @pto.vkernel(op="tile_pad_value_null_value", dtypes=[(pto.f16,)]) + def test_pad_value_eval_requires_non_null_enum(self) -> None: + @pto.vkernel(op="tile_pad_value_null_eval", dtypes=[(pto.f16,)]) def kernel(tile: pto.Tile): - scalar = tile.pad_value.value + scalar = tile.pad_value.eval() return None specialized = kernel.specialize( @@ -1903,7 +1905,7 @@ def kernel(tile: pto.Tile): with self.assertRaises(TypeError) as ctx: analyze_frontend_kernel(build_frontend_kernel_node(specialized)) - self.assertIn("PadValue.NULL.value is invalid", str(ctx.exception)) + self.assertIn("PadValue.NULL.eval() is invalid", str(ctx.exception)) def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: From 7986fc4f3d3f3f971340ef56d70d43c9f7100719 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 16:07:23 +0800 Subject: [PATCH 059/192] Fix semantics of vdup op --- .../11-vector-arithmetic-operations.md | 18 +++++++------- tilelang-dsl/python/tilelang_dsl/lowering.py | 19 ++++++++++----- tilelang-dsl/python/tilelang_dsl/semantic.py | 23 ++++++++++++++---- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 24 ++++++++++++++++++- 4 files changed, 64 insertions(+), 20 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 99b70f9b7..b4d304458 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -893,12 +893,8 @@ rowmax_seed_f32 = pto.vbr(pto.f32("-inf")) rowmax_seed_f16 = pto.vbr(pto.f16("0xFC00")) ``` -**Position Mode Enum**: The `PositionMode` enum provides type-safe source-lane -selection for `pto.vdup`. `LOWEST` selects the lowest-index element of the -source vector and `HIGHEST` selects the highest-index element. When the input is -a scalar, the duplicated scalar value is independent of `position`. - -#### `pto.vdup(input: ScalarType | VRegType, mask: MaskType, position: PositionMode = PositionMode.LOWEST) -> VRegType` +#### `pto.vdup(input: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vdup(input: VRegType, mask: MaskType, position: PositionMode = PositionMode.LOWEST) -> VRegType` **Description**: Duplicate a scalar value or one selected vector element into the active lanes of a destination vector. @@ -908,7 +904,12 @@ the active lanes of a destination vector. |-----------|------|-------------| | `input` | `ScalarType` or `VRegType` | Input scalar or source vector | | `mask` | `MaskType` | Predicate mask controlling which lanes are written | -| `position` | `PositionMode` | Optional enum selecting the source vector element to duplicate (default: `PositionMode.LOWEST`) | +| `position` | `PositionMode` | Optional enum for the vector-input overload, selecting the source vector element to duplicate (default: `PositionMode.LOWEST`) | + +**Position Mode Enum**: The `PositionMode` enum provides type-safe source-lane +selection for `pto.vdup`. `LOWEST` selects the lowest-index element of the +source vector and `HIGHEST` selects the highest-index element. The enum is only +used by the vector-input overload. **Returns**: | Return Value | Type | Description | @@ -921,6 +922,7 @@ the active lanes of a destination vector. - When `input` is a scalar, the scalar value is duplicated to every active lane. - When `input` is a vector, `position` selects a single source element and that value is duplicated to every active lane. +- The scalar overload does not accept `position`. - Inactive lanes follow VPTO predicate semantics and are not guaranteed to carry meaningful values for subsequent masked-off use. - Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. @@ -932,7 +934,7 @@ the active lanes of a destination vector. mask32 = pto.make_mask(pto.f32, pto.PAT.ALL) # Duplicate a scalar into all active lanes. -broadcast = pto.vdup(3.14, mask32) # position defaults to "LOWEST" +broadcast = pto.vdup(3.14, mask32) # Use dtype constructors for floating-point special values. seed = pto.vdup(pto.f32("-inf"), mask32) diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 78d9e539e..24953d7ff 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2545,12 +2545,19 @@ def _lower_call_expr( if expr.name == "vdup": value = self._lower_expr(expr.args[0], env, indent=indent, into=into) mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) - position = self._render_string_literal(expr.args[2]) - into.append( - self._indent(indent) - + f"{result_name} = pto.vdup {value.name}, {mask.name} {{position = {position}}} : " - + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" - ) + if len(expr.args) == 3: + position = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vdup {value.name}, {mask.name} {{position = {position}}} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + else: + into.append( + self._indent(indent) + + f"{result_name} = pto.vdup {value.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) return _RenderedValue(name=result_name, type=expr.type) if expr.name == "vci": diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 1433d4050..403e0d2b1 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -4032,16 +4032,29 @@ def _analyze_broadcast_vector_op( value = args[0] if isinstance(value.type, SemanticVRegType): vec_type = value.type - else: - vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vdup input") + mask = args[1] + self._require_mask_for_vreg(mask, vec_type, "pto.vdup") + position_arg = args[2] if len(args) == 3 else None + position = self._normalize_position_mode(position_arg, "pto.vdup position") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(value, mask, position), + type=vec_type, + ) + + if len(args) == 3: + raise TypeError( + "pto.vdup scalar input does not accept `position`; use `pto.vdup(input, mask)` " + "in TileLang DSL v1" + ) + vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vdup input") mask = args[1] self._require_mask_for_vreg(mask, vec_type, "pto.vdup") - position_arg = args[2] if len(args) == 3 else None - position = self._normalize_position_mode(position_arg, "pto.vdup position") return SemanticCallExpr( namespace="pto", name=name, - args=(value, mask, position), + args=(value, mask), type=vec_type, ) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 0f0e68d0d..a2f394f28 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2517,8 +2517,9 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): ) self.assertRegex( text, - r'pto\.vdup\s+%[^\s]+,\s+%[^\s]+\s+\{position = "LOWEST"\}\s+:', + r'pto\.vdup\s+%[^\s]+,\s+%[^\s]+\s+:', ) + self.assertNotIn('position = "LOWEST"', text) self.assertNotIn('position = "POS_LOWEST"', text) self.assertRegex( text, @@ -2529,6 +2530,27 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): r'pto\.vci\s+%[^\s]+,\s*"ORDER_ASC"\s+:', ) + def test_vdup_scalar_input_rejects_position_argument(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vdup_scalar_reject_position_unique", + dtypes=[(pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, seed: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + out = pto.vdup(seed, all_mask, pto.PositionMode.HIGHEST) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("pto.vdup scalar input does not accept `position`", str(ctx.exception)) + def test_signed_and_unsigned_integer_dtypes_lower_distinctly(self) -> None: @pto.vkernel( op="signed_unsigned_integer_types_unique", From 9baecde9c47175bd90b21071d324d19b12f4ae33 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 15:05:16 +0800 Subject: [PATCH 060/192] Fix vecscope inference for inline_proc --- tilelang-dsl/python/tilelang_dsl/semantic.py | 28 +++++++++++++------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 403e0d2b1..8b4daedeb 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -893,16 +893,20 @@ def _analyze_block( env: dict[str, SemanticBinding], *, allow_outer_lookup: bool, + allow_inferred_vecscope: bool = True, ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: current_env = dict(env) semantic_statements = [] index = 0 while index < len(statements): - if self._should_infer_vecscope(statements[index], allow_outer_lookup=allow_outer_lookup): + if self._should_infer_vecscope( + statements[index], + allow_inferred_vecscope=allow_inferred_vecscope, + ): end = index + 1 while end < len(statements) and self._should_infer_vecscope( statements[end], - allow_outer_lookup=allow_outer_lookup, + allow_inferred_vecscope=allow_inferred_vecscope, ): end += 1 run = statements[index:end] @@ -982,13 +986,13 @@ def _should_infer_vecscope( self, stmt: FrontendStmtNode, *, - allow_outer_lookup: bool, + allow_inferred_vecscope: bool, ) -> bool: if self._has_explicit_vecscope: return False if self._disable_inference_depth > 0: return False - if not allow_outer_lookup: + if not allow_inferred_vecscope: return False if isinstance(stmt, FrontendForStmt): return self._block_can_live_in_inferred_vecscope(stmt.body) @@ -1365,6 +1369,7 @@ def _materialize_inline_proc_specialization( inline_proc_node.body, helper_env, allow_outer_lookup=False, + allow_inferred_vecscope=True, ) finally: self._inline_proc_active_stack.pop() @@ -2693,11 +2698,16 @@ def _analyze_strict_vecscope( block_binding = self._make_binding(name, capture.type, "strict_vecscope_arg") scope_env[name] = block_binding block_arguments.append(block_binding) - body, _ = self._analyze_block( - stmt.body, - scope_env, - allow_outer_lookup=False, - ) + self._disable_inference_depth += 1 + try: + body, _ = self._analyze_block( + stmt.body, + scope_env, + allow_outer_lookup=False, + allow_inferred_vecscope=False, + ) + finally: + self._disable_inference_depth -= 1 return ( SemanticStrictVecscopeStmt( captures=captures, From a0afcd3bbfeca92c6edbe2152b2e7a7db01bd426 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 15:28:15 +0800 Subject: [PATCH 061/192] Fix tshrs/tshls scalar type legality check --- .../docs/user_guide/11-vector-arithmetic-operations.md | 8 ++++---- tilelang-dsl/python/tilelang_dsl/semantic.py | 5 ++++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 4 ++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index b4d304458..185dec6ca 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -759,7 +759,7 @@ scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) |--------------|------|-------------| | `result` | `VRegType` | Leaky ReLU activated values | -#### `pto.vshls(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vshls(vec: VRegType, shift: i16, mask: MaskType) -> VRegType` **Description**: Vector shift left by scalar (uniform shift). @@ -767,7 +767,7 @@ scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) | Parameter | Type | Description | |-----------|------|-------------| | `vec` | `VRegType` | Input vector | -| `shift` | `ScalarType` | Shift amount (same for all elements) | +| `shift` | `i16` | Shift amount (same for all elements) | | `mask` | `MaskType` | Predicate mask | **Returns**: @@ -775,7 +775,7 @@ scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) |--------------|------|-------------| | `result` | `VRegType` | Shifted values | -#### `pto.vshrs(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vshrs(vec: VRegType, shift: i16, mask: MaskType) -> VRegType` **Description**: Vector shift right by scalar (uniform shift). @@ -783,7 +783,7 @@ scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) | Parameter | Type | Description | |-----------|------|-------------| | `vec` | `VRegType` | Input vector | -| `shift` | `ScalarType` | Shift amount (same for all elements) | +| `shift` | `i16` | Shift amount (same for all elements) | | `mask` | `MaskType` | Predicate mask | **Returns**: diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 8b4daedeb..cafc87868 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -4130,7 +4130,10 @@ def _analyze_vector_scalar_op( vector_expr, scalar_expr, mask = args vreg = self._require_vreg_expr(vector_expr, f"pto.{name} vector") scalar = self._require_scalar_expr(scalar_expr, f"pto.{name} scalar") - if scalar.dtype != vreg.element_dtype: + if name in {"vshls", "vshrs"}: + if scalar.dtype != i16: + raise TypeError(f"pto.{name} scalar dtype must be i16") + elif scalar.dtype != vreg.element_dtype: raise TypeError(f"pto.{name} scalar dtype must match vector element dtype") self._require_mask_for_vreg(mask, vreg, f"pto.{name}") self._validate_vector_scalar_dtype(name, vreg.element_dtype) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index a2f394f28..eec3049bb 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2374,10 +2374,10 @@ def kernel(dst: pto.Tile, src: pto.Tile): def test_extended_integer_vector_ops_surface_lowers(self) -> None: @pto.vkernel( op="extended_integer_vector_ops_unique", - dtypes=[(pto.i32, pto.i32, pto.i32)], + dtypes=[(pto.i32, pto.i32, pto.i16)], advanced=True, ) - def kernel(dst: pto.Tile, src: pto.Tile, shift: pto.i32): + def kernel(dst: pto.Tile, src: pto.Tile, shift: pto.i16): all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) vec0 = pto.vlds(src, 0) vec1 = pto.vlds(src, 64) From 6b82096b414d33310daa23e4e8a3f9922428af78 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 19:07:00 +0800 Subject: [PATCH 062/192] fix(dsl): extend scalar arithmetic/bitwise lowering and tests (#71) --- .../python/tilelang_dsl/frontend_ast.py | 5 + tilelang-dsl/python/tilelang_dsl/lowering.py | 39 +++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 70 +++++++++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 123 ++++++++++++++++++ 4 files changed, 228 insertions(+), 9 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 9f2127ff5..06a7956d3 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -622,6 +622,11 @@ def _collect_reachable_inline_procs( ast.Mult: "mul", ast.Mod: "mod", ast.FloorDiv: "floordiv", + ast.BitAnd: "bitand", + ast.BitOr: "bitor", + ast.BitXor: "bitxor", + ast.LShift: "lshift", + ast.RShift: "rshift", } _COMPARE_OP_NAMES = { ast.Eq: "eq", diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 24953d7ff..c9a2611c3 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -78,6 +78,7 @@ get_lanes, integer_bitwidth, integer_signedness, + is_float_dtype, is_integer_dtype, ) @@ -3431,7 +3432,7 @@ def _format_constant(self, value: object, ty: SemanticType) -> str: raise NotImplementedError(f"unsupported constant type {ty!r}") def _render_binary_op(self, op: str, ty: SemanticType) -> str: - if isinstance(ty, (SemanticIndexType, SemanticScalarType)): + if isinstance(ty, SemanticIndexType): if op == "add": return "arith.addi" if op == "sub": @@ -3441,9 +3442,41 @@ def _render_binary_op(self, op: str, ty: SemanticType) -> str: if op == "mod": if isinstance(ty, SemanticIndexType): return "arith.remui" - return "arith.remsi" if op == "floordiv": - return "arith.floordivsi" + return "arith.divui" + if isinstance(ty, SemanticScalarType): + dtype = ty.dtype + if is_float_dtype(dtype): + if op == "add": + return "arith.addf" + if op == "sub": + return "arith.subf" + if op == "mul": + return "arith.mulf" + if is_integer_dtype(dtype): + if op == "add": + return "arith.addi" + if op == "sub": + return "arith.subi" + if op == "mul": + return "arith.muli" + if op == "mod": + sign = integer_signedness(dtype) + return "arith.remui" if sign == "unsigned" else "arith.remsi" + if op == "floordiv": + sign = integer_signedness(dtype) + return "arith.divui" if sign == "unsigned" else "arith.floordivsi" + if op == "bitand": + return "arith.andi" + if op == "bitor": + return "arith.ori" + if op == "bitxor": + return "arith.xori" + if op == "lshift": + return "arith.shli" + if op == "rshift": + sign = integer_signedness(dtype) + return "arith.shrui" if sign == "unsigned" else "arith.shrsi" raise NotImplementedError(f"unsupported binary op '{op}' for type {ty!r}") def _render_type(self, ty: SemanticType) -> str: diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index cafc87868..963598d16 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3513,10 +3513,23 @@ def _binary_type( rhs: SemanticExpr, op: str, ) -> SemanticType: - if op in {"add", "sub", "mul", "mod", "floordiv"}: + if op in {"add", "sub", "mul", "mod", "floordiv", "bitand", "bitor", "bitxor", "lshift", "rshift"}: if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): - return SemanticIndexType() - raise TypeError("binary expressions currently only support index-typed operands") + if op in {"add", "sub", "mul", "mod", "floordiv"}: + return SemanticIndexType() + if isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: + dtype = lhs.type.dtype + if op in {"add", "sub", "mul"} and (is_integer_dtype(dtype) or is_float_dtype(dtype)): + return SemanticScalarType(dtype=dtype) + if op in {"mod", "floordiv"} and is_integer_dtype(dtype): + return SemanticScalarType(dtype=dtype) + if op in {"bitand", "bitor", "bitxor", "lshift", "rshift"} and is_integer_dtype(dtype): + return SemanticScalarType(dtype=dtype) + raise TypeError( + "binary expressions currently require matching index operands, " + "or matching scalar operands (add/sub/mul for integer/float; " + "mod/floordiv/bitwise/shift for integer)" + ) if op in {"eq", "ne"}: if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): return SemanticScalarType(dtype=i1) @@ -5148,15 +5161,30 @@ def _try_static_value(self, expr: SemanticExpr | None) -> Any | None: if lhs is None or rhs is None: return None if expr.op == "add": - if isinstance(lhs, int) and isinstance(rhs, int): + if ( + isinstance(lhs, (int, float)) + and isinstance(rhs, (int, float)) + and not isinstance(lhs, bool) + and not isinstance(rhs, bool) + ): return lhs + rhs return None if expr.op == "sub": - if isinstance(lhs, int) and isinstance(rhs, int): + if ( + isinstance(lhs, (int, float)) + and isinstance(rhs, (int, float)) + and not isinstance(lhs, bool) + and not isinstance(rhs, bool) + ): return lhs - rhs return None if expr.op == "mul": - if isinstance(lhs, int) and isinstance(rhs, int): + if ( + isinstance(lhs, (int, float)) + and isinstance(rhs, (int, float)) + and not isinstance(lhs, bool) + and not isinstance(rhs, bool) + ): return lhs * rhs return None if expr.op == "mod": @@ -5171,6 +5199,36 @@ def _try_static_value(self, expr: SemanticExpr | None) -> Any | None: return None return lhs // rhs return None + if expr.op == "bitand": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool): + return None + return lhs & rhs + return None + if expr.op == "bitor": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool): + return None + return lhs | rhs + return None + if expr.op == "bitxor": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool): + return None + return lhs ^ rhs + return None + if expr.op == "lshift": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool) or rhs < 0: + return None + return lhs << rhs + return None + if expr.op == "rshift": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool) or rhs < 0: + return None + return lhs >> rhs + return None if expr.op == "eq": return lhs == rhs if expr.op == "ne": diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index eec3049bb..47f4b230c 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2040,6 +2040,129 @@ def kernel( self.assertIn("pto.plt_b32", text) self.assertIn("pto.vadd", text) + def test_scalar_binary_arithmetic_supports_float_and_integer_paths(self) -> None: + @pto.vkernel( + op="scalar_binary_arithmetic_unique", + dtypes=[(pto.f32, pto.f32, pto.i32)], + advanced=True, + ) + def kernel(dst_tile: pto.Tile, src_tile: pto.Tile, gate: pto.i32): + rows = src_tile.shape[0] + cols = src_tile.shape[1] + with pto.strict_vecscope( + src_tile, + dst_tile, + gate, + rows, + cols, + 0, + rows, + 1, + ) as (src, dst, in_gate, valid_rows, valid_cols, row_lb, row_ub, row_step): + for row in range(row_lb, row_ub, row_step): + for lane in range(0, valid_cols, 64): + half = in_gate // pto.i32(2) + remain = in_gate % pto.i32(7) + factor = pto.f32(half) + pto.f32(remain) * pto.f32(0.5) + mask, _ = pto.make_mask(pto.f32, valid_cols - lane) + vec = pto.vlds(src, lane) + vec = pto.vmuls(vec, factor, mask) + pto.vsts(vec, dst, lane, mask) + return None + + specialized = kernel.specialize( + dst_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"= arith\.floordivsi %in_gate_\d+, %c2_i32 : i32") + self.assertRegex(text, r"= arith\.remsi %in_gate_\d+, %c7_i32 : i32") + self.assertRegex(text, r"= arith\.mulf %tmp_\d+, %c0\.5_f32 : f32") + self.assertRegex(text, r"= arith\.addf %tmp_\d+, %tmp_\d+ : f32") + + def test_index_floordiv_lowers_to_divui_instead_of_floordivsi(self) -> None: + @pto.vkernel( + op="index_floordiv_lowering_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + ) + def kernel( + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, + ): + rows = lhs_tile.shape[0] + cols = lhs_tile.shape[1] + with pto.strict_vecscope( + lhs_tile, + rhs_tile, + dst_tile, + rows, + cols, + 0, + rows, + 1, + ) as (lhs, rhs, dst, valid_rows, valid_cols, row_lb, row_ub, row_step): + for row in range(row_lb, row_ub, row_step): + for lane in range(0, valid_cols, 64): + row_bucket = row // valid_cols + offset = row_bucket * valid_cols + lane + mask, _ = pto.make_mask(pto.f32, valid_cols - lane) + summed = pto.vadd(pto.vlds(lhs, offset), pto.vlds(rhs, offset), mask) + pto.vsts(summed, dst, offset, mask) + return None + + specialized = kernel.specialize( + lhs_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + rhs_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"= arith\.divui %row_\d+, %valid_cols_\d+ : index") + self.assertNotRegex(text, r"arith\.floordivsi .*: index") + + def test_scalar_bitwise_and_shift_ops_lower_for_signed_and_unsigned(self) -> None: + @pto.vkernel( + op="scalar_bitwise_shift_unique", + dtypes=[(pto.i32, pto.ui32)], + advanced=True, + ) + def kernel(signed_val: pto.i32, unsigned_val: pto.ui32): + signed_mix = (signed_val & pto.i32(15)) | pto.i32(1) + signed_mix = signed_mix ^ pto.i32(2) + signed_mix = signed_mix >> pto.i32(1) + signed_mix = signed_mix << pto.i32(3) + + unsigned_mix = unsigned_val & pto.ui32(31) + unsigned_mix = unsigned_mix >> pto.ui32(2) + unsigned_mix = unsigned_mix << pto.ui32(1) + unsigned_mix = unsigned_mix ^ pto.ui32(7) + return None + + specialized = kernel.specialize() + text = specialized.mlir_text() + + self.assertIn("arith.andi", text) + self.assertIn("arith.ori", text) + self.assertIn("arith.xori", text) + self.assertRegex(text, r"= arith\.shrsi %\w+_\d+, %c1_i32 : i32") + self.assertRegex(text, r"= arith\.shli %\w+_\d+, %c3_i32 : i32") + self.assertRegex(text, r"= arith\.shrui %\w+_\d+, %c2_ui32 : ui32") + self.assertRegex(text, r"= arith\.shli %\w+_\d+, %c1_ui32 : ui32") + + def test_scalar_bitwise_rejects_float_operands(self) -> None: + @pto.vkernel(op="scalar_bitwise_float_reject_unique", dtypes=[(pto.f32,)]) + def kernel(value: pto.f32): + _ = value & pto.f32(1.0) + return None + + specialized = kernel.specialize() + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + self.assertIn("mod/floordiv/bitwise/shift for integer", str(ctx.exception)) + def test_stable_mode_infers_vecscope_and_lowers_tile_vector_sugar(self) -> None: @pto.vkernel(op="tadd_stable", dtypes=[(pto.f32, pto.f32, pto.f32)]) def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): From 91a7926608641a95d6ecef99f76a1cb93c5916c1 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 20:19:37 +0800 Subject: [PATCH 063/192] fix(dsl): align vbitsort and vmrgsort4 with vpto (#67) --- .../11-vector-arithmetic-operations.md | 46 ++++++++++++------- tilelang-dsl/python/tilelang_dsl/lowering.py | 37 +++++++++++---- tilelang-dsl/python/tilelang_dsl/semantic.py | 46 +++++++++++++------ .../python/tilelang_dsl/support_matrix.py | 4 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 46 +++++++++++++++++-- 5 files changed, 133 insertions(+), 46 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 185dec6ca..67ade98d2 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1395,38 +1395,50 @@ vec_f16_narrow = pto.vcvt( ) ``` -#### `pto.vbitsort(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vbitsort(dest: ptr, src: ptr, indices: ptr, repeat_times: index) -> None` [Advanced Tier] -**Description**: Bitonic sort of vector elements. +**Description**: Sort 32 region proposals by score and materialize sorted proposal +records into UB memory. This is a UB helper and not a `vreg -> vreg` operation. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src` | `ptr` | Source score pointer in UB memory space | +| `indices` | `ptr` | Source index pointer in UB memory space | +| `repeat_times` | `index` | Repeat count; each repeat processes the next adjacent group of 32 scores and 32 indices | **Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Sorted vector | +None. The op writes UB memory directly. + +**Constraints**: +- `dest`, `src`, and `indices` must be UB-backed pointers +- Scores are sorted in descending order +- Equal-score ties preserve the earlier input proposal first +- Output records occupy 8 bytes each: upper 4 bytes for the index and lower 4 bytes for the score -#### `pto.vmrgsort4(vec1: VRegType, vec2: VRegType, vec3: VRegType, vec4: VRegType, mask: MaskType) -> VRegType` +#### `pto.vmrgsort4(dest: ptr, src0: ptr, src1: ptr, src2: ptr, src3: ptr, count: pto.i64, config: pto.i64) -> None` [Advanced Tier] -**Description**: 4-way merge sort of vectors. +**Description**: Merge-sort 4 pre-sorted UB inputs. This op writes UB memory +directly and does not return a vector SSA value. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `vec1` | `VRegType` | First input vector | -| `vec2` | `VRegType` | Second input vector | -| `vec3` | `VRegType` | Third input vector | -| `vec4` | `VRegType` | Fourth input vector | -| `mask` | `MaskType` | Predicate mask | +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src0` | `ptr` | First pre-sorted input pointer in UB memory space | +| `src1` | `ptr` | Second pre-sorted input pointer in UB memory space | +| `src2` | `ptr` | Third pre-sorted input pointer in UB memory space | +| `src3` | `ptr` | Fourth pre-sorted input pointer in UB memory space | +| `count` | `pto.i64` | Number of valid input elements participating in the merge | +| `config` | `pto.i64` | Operation control word encoding sort behavior | **Returns**: -| Return Value | Type | Description | -|--------------|------|-------------| -| `result` | `VRegType` | Merged and sorted vector | +None. The op writes UB memory directly. + +**Constraints**: +- `dest` and `src0` through `src3` must be UB-backed pointers +- Inputs must already be sorted according to the order encoded by `config` **Order Mode Enum**: The `OrderMode` enum provides type-safe order selection for `pto.vci` operations. Currently only `ASC` (ascending order) is supported, with more order options planned for future releases. diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index c9a2611c3..749744685 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2741,19 +2741,37 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vbitsort": + destination = self._lower_expr(expr.args[0], env, indent=indent, into=into) + source = self._lower_expr(expr.args[1], env, indent=indent, into=into) + indices = self._lower_expr(expr.args[2], env, indent=indent, into=into) + repeat_times = self._lower_expr(expr.args[3], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"pto.vbitsort {destination.name}, {source.name}, {indices.name}, {repeat_times.name} : " + + f"{self._render_type(destination.type)}, {self._render_type(source.type)}, " + + f"{self._render_type(indices.type)}, {self._render_type(repeat_times.type)}" + ) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + if expr.name == "vmrgsort4": - vec0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) - vec1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) - vec2 = self._lower_expr(expr.args[2], env, indent=indent, into=into) - vec3 = self._lower_expr(expr.args[3], env, indent=indent, into=into) - mask = self._lower_expr(expr.args[4], env, indent=indent, into=into) + destination = self._lower_expr(expr.args[0], env, indent=indent, into=into) + source0 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + source1 = self._lower_expr(expr.args[2], env, indent=indent, into=into) + source2 = self._lower_expr(expr.args[3], env, indent=indent, into=into) + source3 = self._lower_expr(expr.args[4], env, indent=indent, into=into) + count = self._lower_expr(expr.args[5], env, indent=indent, into=into) + config = self._lower_expr(expr.args[6], env, indent=indent, into=into) + count = self._coerce_rendered_value(count, _I64_TYPE, indent=indent, into=into) + config = self._coerce_rendered_value(config, _I64_TYPE, indent=indent, into=into) into.append( self._indent(indent) - + f"{result_name} = pto.vmrgsort4 {vec0.name}, {vec1.name}, {vec2.name}, {vec3.name}, {mask.name} : " - + f"{self._render_type(vec0.type)}, {self._render_type(vec1.type)}, {self._render_type(vec2.type)}, " - + f"{self._render_type(vec3.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + + f"pto.vmrgsort4 {destination.name}, {source0.name}, {source1.name}, {source2.name}, {source3.name}, " + + f"{count.name}, {config.name} : {self._render_type(destination.type)}, {self._render_type(source0.type)}, " + + f"{self._render_type(source1.type)}, {self._render_type(source2.type)}, {self._render_type(source3.type)}, " + + f"{self._render_type(count.type)}, {self._render_type(config.type)}" ) - return _RenderedValue(name=result_name, type=expr.type) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) if expr.name in { "vabs", @@ -2777,7 +2795,6 @@ def _lower_call_expr( "vsqz", "vexpdiff", "vtrc", - "vbitsort", "vcgadd", "vcgmax", "vcgmin", diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 963598d16..5eb6e6cfd 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -161,7 +161,6 @@ "vsqz", "vexpdiff", "vtrc", - "vbitsort", "vcgadd", "vcgmax", "vcgmin", @@ -229,7 +228,7 @@ | _PREDICATE_MOVEMENT_OPS | _CARRY_OPS | _REARRANGEMENT_OPS - | {"vcvt", "vmrgsort4"} + | {"vcvt", "vbitsort", "vmrgsort4"} ) _TENSORVIEW_RANK = 5 @@ -3633,6 +3632,8 @@ def _analyze_call_expr( return self._analyze_rearrangement_op(name, args) if name == "vcvt": return self._analyze_vcvt(args) + if name == "vbitsort": + return self._analyze_vbitsort(args) if name == "vmrgsort4": return self._analyze_vmrgsort4(args) if name in _BROADCAST_VECTOR_OPS: @@ -4423,17 +4424,36 @@ def _analyze_vcvt( type=self._vreg_type_for_dtype(target_dtype), ) + def _analyze_vbitsort(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 4: + raise TypeError("pto.vbitsort expects exactly 4 positional arguments in TileLang DSL v1") + destination = self._require_pointer_expr(args[0], "pto.vbitsort destination", memory_space="ub") + source = self._require_pointer_expr(args[1], "pto.vbitsort source", memory_space="ub") + indices = self._require_pointer_expr(args[2], "pto.vbitsort indices", memory_space="ub") + self._require_index_typed_expr(args[3]) + return SemanticCallExpr( + namespace="pto", + name="vbitsort", + args=(destination, source, indices, args[3]), + type=None, + ) + def _analyze_vmrgsort4(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - if len(args) != 5: - raise TypeError("pto.vmrgsort4 expects exactly 5 positional arguments in TileLang DSL") - vec0 = self._require_vreg_expr(args[0], "pto.vmrgsort4 vec0") - vec1 = self._require_vreg_expr(args[1], "pto.vmrgsort4 vec1") - vec2 = self._require_vreg_expr(args[2], "pto.vmrgsort4 vec2") - vec3 = self._require_vreg_expr(args[3], "pto.vmrgsort4 vec3") - if not (vec0 == vec1 == vec2 == vec3): - raise TypeError("pto.vmrgsort4 requires all vector operands to use the same vector type") - self._require_mask_for_vreg(args[4], vec0, "pto.vmrgsort4") - return SemanticCallExpr(namespace="pto", name="vmrgsort4", args=args, type=vec0) + if len(args) != 7: + raise TypeError("pto.vmrgsort4 expects exactly 7 positional arguments in TileLang DSL v1") + destination = self._require_pointer_expr(args[0], "pto.vmrgsort4 destination", memory_space="ub") + source0 = self._require_pointer_expr(args[1], "pto.vmrgsort4 src0", memory_space="ub") + source1 = self._require_pointer_expr(args[2], "pto.vmrgsort4 src1", memory_space="ub") + source2 = self._require_pointer_expr(args[3], "pto.vmrgsort4 src2", memory_space="ub") + source3 = self._require_pointer_expr(args[4], "pto.vmrgsort4 src3", memory_space="ub") + self._require_i64_like_expr(args[5], "pto.vmrgsort4 count") + self._require_i64_like_expr(args[6], "pto.vmrgsort4 config") + return SemanticCallExpr( + namespace="pto", + name="vmrgsort4", + args=(destination, source0, source1, source2, source3, args[5], args[6]), + type=None, + ) def _require_dtype_symbol(self, expr: SemanticExpr, context: str) -> ScalarType: if not ( @@ -4912,7 +4932,7 @@ def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} ): raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") - if name in {"vabs", "vneg", "vmov", "vtrc", "vbitsort", "vcadd", "vcmax", "vcmin"} and not ( + if name in {"vabs", "vneg", "vmov", "vtrc", "vcadd", "vcmax", "vcmin"} and not ( (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} ): raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index e574d9bad..61f17edc1 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -96,7 +96,6 @@ "vsqz", "vexpdiff", "vtrc", - "vbitsort", "vbr", "vdup", "vadd", @@ -141,7 +140,6 @@ "vsort32", "vmrgsort", "vcvt", - "vmrgsort4", "vci", } ) @@ -165,6 +163,8 @@ "vdintlv", "vintlvv2", "vdintlvv2", + "vbitsort", + "vmrgsort4", } ) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 47f4b230c..d379e74a7 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -323,6 +323,8 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vsort32"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vldsx2"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vstsx2"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vbitsort"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.vmrgsort4"), ADVANCED_TIER) self.assertEqual(get_feature_tier("PadMode"), BASIC_TIER) self.assertEqual(get_feature_tier("VRegType"), BASIC_TIER) self.assertEqual(get_feature_tier("MaskType"), BASIC_TIER) @@ -2401,11 +2403,9 @@ def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): out = pto.vcmin(out, all_mask) out = pto.vmov(out, all_mask) out = pto.vtrc(out, all_mask) - out = pto.vbitsort(out, all_mask) out = pto.vprelu(out, vec1, all_mask) out = pto.vlrelu(out, alpha, all_mask) out = pto.vcvt(out, pto.f32, all_mask) - out = pto.vmrgsort4(out, vec1, vec2, vec3, all_mask) pto.vsts(out, dst, 0, all_mask) return None @@ -2425,11 +2425,9 @@ def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): self.assertIn("pto.vcmin", text) self.assertIn("pto.vmov", text) self.assertIn("pto.vtrc", text) - self.assertIn("pto.vbitsort", text) self.assertIn("pto.vprelu", text) self.assertIn("pto.vlrelu", text) self.assertIn("pto.vcvt", text) - self.assertIn("pto.vmrgsort4", text) def test_vcvt_supports_keyword_attrs_with_enums(self) -> None: @pto.vkernel( @@ -2463,6 +2461,46 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn('sat = "SAT"', text) self.assertIn('part = "ODD"', text) + def test_advanced_sort_memory_ops_surface_lower(self) -> None: + @pto.vkernel( + op="advanced_sort_memory_ops_unique", + dtypes=[(pto.f32, pto.f32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, idx: pto.Tile): + dst_ptr = dst.as_ptr() + src_ptr = src.as_ptr() + idx_ptr = idx.as_ptr() + + pto.vbitsort(dst_ptr, src_ptr, idx_ptr, 1) + pto.vmrgsort4( + dst_ptr, + src_ptr, + pto.addptr(src_ptr, 64), + pto.addptr(src_ptr, 128), + pto.addptr(src_ptr, 192), + pto.i64(64), + pto.i64(0), + ) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + idx=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"pto\.vbitsort %dst_ptr_\d+, %src_ptr_\d+, %idx_ptr_\d+, %c1 : !pto\.ptr, !pto\.ptr, !pto\.ptr, index", + ) + self.assertRegex( + text, + r"pto\.vmrgsort4 %dst_ptr_\d+, %src_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %c\d+_i64, %c\d+_i64 : " + r"!pto\.ptr, !pto\.ptr, !pto\.ptr, !pto\.ptr, !pto\.ptr, i64, i64", + ) + def test_vcvt_rejects_legacy_string_spellings(self) -> None: with self.assertRaises(TypeError) as ctx: From caa47c3cab624f8d1835ee7adf1949d42734f231 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 21:17:25 +0800 Subject: [PATCH 064/192] [PTOAS] wire MLIR IR printing through ptoas pipelines --- include/PTO/Transforms/Passes.h | 2 ++ lib/PTO/Transforms/InferPTOLayout.cpp | 6 ++++++ lib/PTO/Transforms/PTOViewToMemref.cpp | 1 + lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 5 +++++ test/basic/mlir_print_ir_debug.pto | 30 ++++++++++++++++++++++++++ tools/ptoas/ptoas.cpp | 25 +++++++++++++++++++++ 6 files changed, 69 insertions(+) create mode 100644 test/basic/mlir_print_ir_debug.pto diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 120401299..2994168a0 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -78,6 +78,8 @@ std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options) std::unique_ptr createFoldTileBufIntrinsicsPass(); std::unique_ptr createPTOInlineLibCallPass(const PTOInlineLibCallOptions &options = {}); +void registerPTOViewToMemrefPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/lib/PTO/Transforms/InferPTOLayout.cpp b/lib/PTO/Transforms/InferPTOLayout.cpp index 7d5ea8a2c..2611f52ef 100644 --- a/lib/PTO/Transforms/InferPTOLayout.cpp +++ b/lib/PTO/Transforms/InferPTOLayout.cpp @@ -534,6 +534,12 @@ struct InferPTOLayoutPass : public mlir::pto::impl::InferPTOLayoutBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InferPTOLayoutPass) + StringRef getArgument() const final { return "pto-infer-layout"; } + + StringRef getDescription() const final { + return "Infer GlobalTensor layout (ND/DN/NZ) for make_tensor_view"; + } + void runOnOperation() override { func::FuncOp func = getOperation(); // ------------------------------------------------------------------ diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 7a7869e1d..62cd8a3be 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" namespace mlir { namespace pto { diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 1aa877ef6..b54898fe6 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -4873,6 +4873,11 @@ static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, pm.addPass(createConvertFuncToLLVMPass()); pm.addPass(createConvertControlFlowToLLVMPass()); pm.addPass(createReconcileUnrealizedCastsPass()); + if (failed(mlir::applyPassManagerCLOptions(pm))) { + diagOS << "VPTO LLVM emission failed: unable to apply MLIR pass manager " + "command-line options\n"; + return failure(); + } if (failed(pm.run(clonedModule))) { diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; return failure(); diff --git a/test/basic/mlir_print_ir_debug.pto b/test/basic/mlir_print_ir_debug.pto new file mode 100644 index 000000000..4234b3563 --- /dev/null +++ b/test/basic/mlir_print_ir_debug.pto @@ -0,0 +1,30 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=pto-resolve-reserved-buffers %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=MAIN +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-ptr-normalize %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=VPTO + +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%arg0: i32) attributes { pto.tilelang.instance } { + %0 = func.call @__tl_inline_add1_i32(%arg0) : (i32) -> i32 + func.call @__tl_inline_sink_i32(%0) : (i32) -> () + return + } + + func.func private @__tl_inline_add1_i32(%x: i32) -> i32 attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %v = arith.addi %x, %c1 : i32 + return %v : i32 + } + + func.func private @__tl_inline_sink_i32(%x: i32) attributes { pto.tilelang.inline_proc } { + %c2 = arith.constant 2 : i32 + %t = arith.addi %x, %c2 : i32 + return + } +} + +// MAIN: // -----// IR Dump After +// MAIN-SAME: (pto-resolve-reserved-buffers) +// MAIN: func.func @kernel + +// VPTO: // -----// IR Dump After +// VPTO-SAME: (vpto-ptr-normalize) +// VPTO: func.func @kernel diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 40b12e719..49e5397b4 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -59,6 +59,16 @@ static void printPTOASVersion(llvm::raw_ostream &os) { os << "ptoas " << PTOAS_RELEASE_VERSION << "\n"; } +static LogicalResult applyConfiguredPassManagerCLOptions( + PassManager &pm, llvm::StringRef pipelineName, + llvm::raw_ostream &diagOS = llvm::errs()) { + if (succeeded(mlir::applyPassManagerCLOptions(pm))) + return success(); + diagOS << "Error: failed to apply MLIR pass manager command-line options for " + << pipelineName << ".\n"; + return failure(); +} + static LogicalResult reorderEmitCFunctions(ModuleOp module) { SmallVector declarations; SmallVector definitions; @@ -1091,6 +1101,8 @@ static LogicalResult prepareVPTOForEmission(ModuleOp module) { cleanupPM.enableVerifier(); cleanupPM.addPass(createCanonicalizerPass()); cleanupPM.addPass(createCSEPass()); + if (failed(applyConfiguredPassManagerCLOptions(cleanupPM, "VPTO cleanup"))) + return failure(); if (failed(cleanupPM.run(module))) { llvm::errs() << "Error: VPTO pre-emission cleanup failed.\n"; return failure(); @@ -1101,6 +1113,9 @@ static LogicalResult prepareVPTOForEmission(ModuleOp module) { boundaryPM.addPass(pto::createVPTOPtrNormalizePass()); boundaryPM.addPass(pto::createVPTOPtrCastCleanupPass()); boundaryPM.addPass(createReconcileUnrealizedCastsPass()); + if (failed(applyConfiguredPassManagerCLOptions(boundaryPM, + "VPTO ptr normalization"))) + return failure(); if (failed(boundaryPM.run(module))) { llvm::errs() << "Error: VPTO ptr normalization failed.\n"; return failure(); @@ -1111,6 +1126,9 @@ static LogicalResult prepareVPTOForEmission(ModuleOp module) { prepPM.addNestedPass(createPTOVPTOExpandBridgeOpsPass()); prepPM.addPass(createCSEPass()); prepPM.addPass(pto::createPTOValidateVPTOEmissionIRPass()); + if (failed(applyConfiguredPassManagerCLOptions(prepPM, + "VPTO emission preparation"))) + return failure(); if (failed(prepPM.run(module))) { llvm::errs() << "Error: VPTO emission preparation failed.\n"; return failure(); @@ -1146,6 +1164,9 @@ static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { backendPM.addPass(mlir::createSCCPPass()); backendPM.addPass(mlir::createCanonicalizerPass()); } + if (failed(applyConfiguredPassManagerCLOptions(backendPM, + "VPTO backend lowering"))) + return failure(); if (failed(backendPM.run(module))) { llvm::errs() << "Error: backend lowering pass execution failed.\n"; return failure(); @@ -1221,6 +1242,8 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); mlir::registerAllPasses(); + ::registerPTOPasses(); + mlir::pto::registerPTOViewToMemrefPass(); ::registerPTOInlineLibCall(); ::registerFoldTileBufIntrinsics(); ::registerExpandTileOp(); @@ -1497,6 +1520,8 @@ int main(int argc, char **argv) { } pm.addPass(createCSEPass()); + if (failed(applyConfiguredPassManagerCLOptions(pm, "main PTOAS pipeline"))) + return 1; module->getOperation()->setAttr("pto.target_arch", mlir::StringAttr::get(&context, arch)); From dde729398eb5483000b19b1e5840f5adbfa38226 Mon Sep 17 00:00:00 2001 From: qukelin Date: Tue, 14 Apr 2026 23:05:07 +0800 Subject: [PATCH 065/192] Use select_kernel for expand helper selection --- docs/designs/ptoas-tileop-expand-design.md | 22 ++- lib/PTO/Transforms/ExpandTileOp.cpp | 28 ++- .../python/tilelang_dsl/expand_helper.py | 178 +++++++++++++----- tilelang-dsl/python/tilelang_dsl/kernel.py | 44 +++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 145 +++++++++++++- 5 files changed, 343 insertions(+), 74 deletions(-) diff --git a/docs/designs/ptoas-tileop-expand-design.md b/docs/designs/ptoas-tileop-expand-design.md index b840f4459..607159d40 100644 --- a/docs/designs/ptoas-tileop-expand-design.md +++ b/docs/designs/ptoas-tileop-expand-design.md @@ -475,30 +475,32 @@ Step 4: 生成调用并替换原 Tile Op | 操作数类型 | IR 类型 | 参与 SpecKey 的字段 | 不参与 SpecKey 但传给 Python DSL 的字段 | |-----------|---------|--------------------|-----------------------------------------| -| **Tile** | `TileBufType` | `dtype` + `shape` + `memorySpace` + `config`(blayout/slayout/fractal/pad) | — | +| **Tile** | `TileBufType` | `dtype` + `shape` + `valid_shape` + `memorySpace` + `config`(blayout/slayout/fractal/pad) | — | | **View** | `MemRefType`(降级后的 `PartitionTensorViewType`) | `dtype` | `shape`、`strides`、`memorySpace`(仅用于约束检查) | | **Scalar** | 标量类型 | `dtype` | — | -**View 操作数的特化策略**:View 对应的模板参数类型为 `!pto.partition_tensor_view`,维度全部动态,shape/strides 通过 intrinsic 在运行时查询。因此不同 view shape 的 Tile op 可以共享同一份模板实例——`shape`/`strides`/`memorySpace` 不参与 SpecKey 的判等和 hash。这些字段通过 `--operand-specs` JSON 传给 Python DSL 的 `expand_helper`,注入到约束上下文中(如 `src.strides[4] == 1`),但不影响模板代码生成。 +**View 操作数的特化策略**:View 对应的模板参数类型为 `!pto.partition_tensor_view`,维度全部动态,shape/strides 通过 intrinsic 在运行时查询。因此不同 view shape 的 Tile op 可以共享同一份模板实例——`shape`/`strides`/`memorySpace` 不参与 SpecKey 的判等和 hash。这些字段通过 `--operand-specs` JSON 传给 Python DSL 的 `expand_helper`,先按操作数位置构造成 `arg0_*`、`arg1_*` 一类的位置化上下文,再在 constraint evaluation 阶段按模板参数顺序映射到当前参数名(如 `src` / `dst`)后参与约束检查;它们不直接影响模板代码生成。 -**Tile 操作数的排除字段**:`valid_shape` 不参与 SpecKey——因为它可能是动态的,作为运行时值在 inline 后通过 `pto.tile_valid_rows`/`pto.tile_valid_cols` 提取。相同 `(op, operand_types)` 但不同 `valid_shape` 的 Tile op 可以共享同一份实例化结果。 +**Tile 操作数的特化策略**:当前实现中,`valid_shape` 参与 SpecKey,并与 `shape`、`memorySpace`、`config` 一起决定模板实例和缓存 key。也就是说,相同 `(op, operand_types)` 但不同 `valid_shape` 的 Tile op 当前会生成不同的实例化结果。约束检查和缓存命名都基于这一实现语义。 #### 3.2.2 模板实例化过程 Expand TileOp 通过调用 Python 子进程来实例化模板。具体流程: -1. **调用 Python helper**:`python3 -m tilelang_dsl.expand_helper --op pto. --operand-specs `,其中 JSON 描述每个操作数的类型信息。 +1. **调用 Python helper**:`python3 -m tilelang_dsl.expand_helper --target --op pto. --operand-specs `,其中 JSON 描述每个操作数的类型信息。 2. **Python 端处理**: - 扫描模板目录下的 `.py` 文件,查找标注了 `@pto.vkernel` 装饰器的模板函数 - - 按 `op` 名称和 `dtype` 签名匹配模板 - - 对 `pto.Tile` 参数使用给定的 shape 和 memory_space 进行特化 - - 对 `pto.PartitionTensorView` 参数,将 shape/strides 注入约束上下文用于前置条件检查,但不影响模板特化(参数类型保持全动态) + - 先按操作数个数和参数种类(`tile` / `view` / `scalar`)做 schema 预过滤 + - 基于 `operand_specs` 构造按位置组织的上下文属性(如 `arg0_shape`、`arg0_strides`、`arg1_config`) + - 调用 `pto.select_kernel(target, concrete_op, operand_types, context_attrs, registry)` 按 `target → op → dtypes → constraints → priority` 规则选择模板 + - 对 `pto.Tile` 参数使用给定的 shape / valid_shape / memory_space / config 进行特化 + - 对 `pto.PartitionTensorView` 参数,不做 `specialize()`,而是通过位置化上下文把 shape/strides/memorySpace 提供给前置条件检查(参数类型保持全动态) - 输出特化后的 MLIR 文本 3. **C++ 端处理**: - 解析 MLIR 文本为 `ModuleOp` - 提取 `func.func`,克隆到目标 Module 末尾 - - 重命名为 `__pto_tilelang__tile____view__...`(Tile 操作数拼 shape,View/Scalar 只拼 dtype),设为 `private` 可见性 - - 存入 specCache + - 重命名为 `__pto_tilelang___tile____view__...`(Tile 操作数拼 shape/valid_shape/config,View/Scalar 只拼 dtype),设为 `private` 可见性 + - 按 `target + op + operand schema` 存入 specCache **关键约束**:Python DSL 实例化输出的函数需要满足以下要求: @@ -826,7 +828,7 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): 算子,这意味着 `ins` 操作数在前、`outs` 在后。例如 `pto.tadd ins(%a, %b) outs(%c)` 的操作数顺序为 `(src0, src1, dst)`,模板参数必须为 `(src0, src1, dst)`。 -`expand_helper.py` 自动扫描目录下所有 `.py` 文件,按 `op` 名称和 `dtype` 签名匹配模板。 +`expand_helper.py` 自动扫描目录下所有 `.py` 文件,先按参数 schema 过滤候选,再通过 `select_kernel()` 按 `target`、`op`、`dtype`、`constraints` 和 `priority` 选择模板。模板约束读取的位置化上下文由 `argN_*` 键提供,并在 constraint evaluation 阶段按参数顺序映射到模板自己的参数名。 ## 第四章 前置工作 diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index c4d4833ee..b05a652b6 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -33,6 +33,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" @@ -125,18 +126,20 @@ struct OperandTypeInfo { // ============================================================================ struct SpecKey { std::string opName; + std::string targetArch; SmallVector operands; bool operator==(const SpecKey &rhs) const { - return opName == rhs.opName && operands == rhs.operands; + return opName == rhs.opName && targetArch == rhs.targetArch && + operands == rhs.operands; } }; struct SpecKeyInfo : public llvm::DenseMapInfo { - static inline SpecKey getEmptyKey() { return {"", {}}; } - static inline SpecKey getTombstoneKey() { return {"__tombstone__", {}}; } + static inline SpecKey getEmptyKey() { return {"", "", {}}; } + static inline SpecKey getTombstoneKey() { return {"__tombstone__", "", {}}; } static unsigned getHashValue(const SpecKey &key) { - unsigned h = llvm::hash_value(key.opName); + unsigned h = llvm::hash_combine(key.opName, key.targetArch); for (const auto &op : key.operands) { h = llvm::hash_combine(h, static_cast(op.kind), op.dtype); if (op.kind == OperandKind::Tile) { @@ -175,6 +178,15 @@ static StringRef getTileOpName(Operation *op) { return op->getName().stripDialect(); } +static std::string getTargetArchString(ModuleOp mod) { + if (!mod) + return ""; + auto targetAttr = mod->getAttrOfType("pto.target_arch"); + if (!targetAttr) + return ""; + return targetAttr.getValue().str(); +} + static std::string getMemorySpaceString(pto::TileBufType tbTy) { auto msAttr = dyn_cast_or_null(tbTy.getMemorySpace()); if (!msAttr) return "ub"; @@ -257,6 +269,7 @@ static std::optional buildOperandTypeInfo(Type ty) { static std::optional buildSpecKey(Operation *op) { SpecKey key; key.opName = getTileOpName(op).str(); + key.targetArch = getTargetArchString(op->getParentOfType()); for (unsigned i = 0; i < op->getNumOperands(); ++i) { auto info = buildOperandTypeInfo(op->getOperand(i).getType()); @@ -382,6 +395,10 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, // 2. Build operand schema JSON for mixed tile/scalar specialization. std::string operandSpecsJson = buildOperandSpecsJson(key); + if (key.targetArch.empty()) { + llvm::errs() << "ExpandTileOp: missing pto.target_arch module attribute\n"; + return nullptr; + } // 3. Create temp file for stdout redirect. SmallString<128> tmpPath; @@ -399,6 +416,7 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, SmallVector args = { *pythonPath, "-m", "tilelang_dsl.expand_helper", "--template-dir", tilelangPath, + "--target", key.targetArch, "--op", opName, "--operand-specs", operandSpecsJson, }; @@ -498,7 +516,7 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, llvm::StringMap renamedSymbols; // Build a unique name from the spec-key-relevant operand fields. - std::string uniqueName = "__pto_tilelang_" + key.opName; + std::string uniqueName = "__pto_tilelang_" + key.targetArch + "_" + key.opName; for (const auto &op : key.operands) { uniqueName += op.kind == OperandKind::Tile ? "_tile" : op.kind == OperandKind::View ? "_view" diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index 44fed81b3..d2c776126 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -3,6 +3,7 @@ Usage: python3 -m tilelang_dsl.expand_helper \ --template-dir /path/to/templates \ + --target a5 \ --op pto.tadd \ --dtype f32 \ --shape 16,64 \ @@ -21,7 +22,12 @@ import sys from pathlib import Path -from .kernel import VKernelDescriptor, _match_descriptor_dtype_signature +from .kernel import ( + KernelRegistry, + VKernelDescriptor, + _match_descriptor_dtype_signature, + select_kernel, +) from .types import MemorySpace, ScalarType, TileConfig, TileSpecialization @@ -85,22 +91,33 @@ def _import_py_file(path: Path): return mod +def _bind_descriptor_for_query( + descriptor: VKernelDescriptor, + target: str, + op_name: str, + operand_types: tuple[ScalarType, ...], +) -> VKernelDescriptor | None: + if descriptor.target != target or op_name not in descriptor.match_ops: + return None + op_bound = descriptor._bind_selected_op(op_name) + matched_signature = _match_descriptor_dtype_signature(op_bound, operand_types) + if matched_signature is None: + return None + if op_bound._selected_dtype_signature == matched_signature: + return op_bound + return op_bound._bind_selected_dtype_signature(matched_signature) + + def _match_descriptor( descriptors: list[VKernelDescriptor], op_name: str, operand_types: tuple[ScalarType, ...], ) -> VKernelDescriptor | None: - """Find and bind the first descriptor matching (op, dtype).""" + """Legacy helper: find and bind the first descriptor matching (op, dtype).""" for desc in descriptors: - if op_name not in desc.match_ops: - continue - op_bound = desc._bind_selected_op(op_name) - matched_signature = _match_descriptor_dtype_signature(op_bound, operand_types) - if matched_signature is None: - continue - if op_bound._selected_dtype_signature == matched_signature: - return op_bound - return op_bound._bind_selected_dtype_signature(matched_signature) + bound = _bind_descriptor_for_query(desc, "a5", op_name, operand_types) + if bound is not None: + return bound return None @@ -186,9 +203,101 @@ def _parse_operand_specs(spec_text: str) -> list[dict]: return specs +def _operand_spec_matches_param_kind(param_kind: str, operand_kind: str) -> bool: + if operand_kind == "tile": + return param_kind == "tile" + if operand_kind == "view": + return param_kind in ("tensorview", "partition_tensor_view") + if operand_kind == "scalar": + return param_kind == "scalar" + return False + + +def _filter_descriptors_by_operand_schema( + descriptors: list[VKernelDescriptor], + *, + target: str, + op_name: str, + operand_specs: list[dict], +) -> list[VKernelDescriptor]: + operand_types = tuple(spec["dtype"] for spec in operand_specs) + filtered: list[VKernelDescriptor] = [] + for descriptor in descriptors: + bound = _bind_descriptor_for_query(descriptor, target, op_name, operand_types) + if bound is None: + continue + parameters = bound.parameters + if len(parameters) != len(operand_specs): + continue + if all( + _operand_spec_matches_param_kind(param.kind, operand_spec["kind"]) + for param, operand_spec in zip(parameters, operand_specs) + ): + filtered.append(bound) + return filtered + + +def _build_positional_context_attrs(operand_specs: list[dict]) -> dict[str, object]: + attrs: dict[str, object] = {} + for index, operand_spec in enumerate(operand_specs): + prefix = f"arg{index}" + attrs[f"{prefix}_kind"] = operand_spec["kind"] + attrs[f"{prefix}_dtype"] = operand_spec["dtype"] + if operand_spec["kind"] == "scalar": + continue + shape = tuple(operand_spec["shape"]) + attrs[f"{prefix}_shape"] = shape + attrs[f"{prefix}_rank"] = len(shape) + memory_space = operand_spec.get("memory_space") + if isinstance(memory_space, MemorySpace): + attrs[f"{prefix}_memory_space"] = memory_space.value + elif memory_space is not None: + attrs[f"{prefix}_memory_space"] = memory_space + if operand_spec["kind"] == "tile": + valid_shape = operand_spec.get("valid_shape") + effective_valid_shape = shape if valid_shape is None else tuple(valid_shape) + attrs[f"{prefix}_valid_shape"] = effective_valid_shape + if operand_spec.get("config") is not None: + attrs[f"{prefix}_config"] = operand_spec["config"] + continue + if "strides" in operand_spec: + attrs[f"{prefix}_strides"] = tuple(operand_spec["strides"]) + return attrs + + +def _select_descriptor( + descriptors: list[VKernelDescriptor], + *, + target: str, + op_name: str, + operand_specs: list[dict], +) -> VKernelDescriptor: + filtered_descriptors = _filter_descriptors_by_operand_schema( + descriptors, + target=target, + op_name=op_name, + operand_specs=operand_specs, + ) + operand_types = tuple(spec["dtype"] for spec in operand_specs) + if not filtered_descriptors: + raise LookupError( + "expand_helper found no registered kernel after operand schema filtering for " + f"target={target!r}, op={op_name!r}, operand_types={operand_types!r}" + ) + registry = KernelRegistry(tuple(filtered_descriptors)) + return select_kernel( + target, + op_name, + operand_types, + context_attrs=_build_positional_context_attrs(operand_specs), + registry=registry, + ) + + def main(argv: list[str] | None = None) -> int: parser = argparse.ArgumentParser(description="TileLang DSL expand helper") parser.add_argument("--template-dir", required=True, help="Directory of .py templates") + parser.add_argument("--target", default="a5", help="Target architecture, e.g. a5") parser.add_argument("--op", required=True, help="Tile op name, e.g. pto.tadd") parser.add_argument("--dtype", help="Element dtype, e.g. f32") parser.add_argument("--shape", help="Tile shape, e.g. 16,64") @@ -206,11 +315,11 @@ def main(argv: list[str] | None = None) -> int: operand_specs: list[dict] | None = None if args.operand_specs: - try: - operand_specs = _parse_operand_specs(args.operand_specs) - except ValueError as exc: - print(f"expand_helper: error: {exc}", file=sys.stderr) - return 1 + try: + operand_specs = _parse_operand_specs(args.operand_specs) + except ValueError as exc: + print(f"expand_helper: error: {exc}", file=sys.stderr) + return 1 else: if args.dtype is None or args.shape is None: print( @@ -243,30 +352,19 @@ def main(argv: list[str] | None = None) -> int: print(f"expand_helper: error: no @vkernel descriptors found in {template_dir}", file=sys.stderr) return 1 - # Match. - operand_types = tuple(spec["dtype"] for spec in operand_specs) - desc = _match_descriptor(all_descriptors, args.op, operand_types) - if desc is None: - print( - f"expand_helper: error: no template matches op={args.op} operand_types={operand_types!r}", - file=sys.stderr, - ) - return 1 - - if len(desc.parameters) != len(operand_specs): - print( - "expand_helper: error: descriptor parameter count does not match operand-specs", - file=sys.stderr, + try: + desc = _select_descriptor( + all_descriptors, + target=args.target, + op_name=args.op, + operand_specs=operand_specs, ) + except Exception as exc: + print(f"expand_helper: error: {exc}", file=sys.stderr) return 1 # Specialize Tile parameters positionally from operand-specs. - # View operands match tensorview/partition_tensor_view parameters without - # specialization — shape/strides are resolved dynamically via intrinsics. - # However, their shape/strides are injected into the constraint context so - # that precondition checks (e.g. src.strides[4] == 1) can evaluate. tile_specs = {} - view_context_attrs: dict[str, object] = {} for param, operand_spec in zip(desc.parameters, operand_specs): if param.kind == "tile": if operand_spec["kind"] != "tile": @@ -290,10 +388,6 @@ def main(argv: list[str] | None = None) -> int: file=sys.stderr, ) return 1 - # Inject shape/strides for constraint evaluation. - view_context_attrs[f"{param.name}_shape"] = operand_spec["shape"] - if "strides" in operand_spec: - view_context_attrs[f"{param.name}_strides"] = operand_spec["strides"] continue if param.kind == "scalar" and operand_spec["kind"] != "scalar": print( @@ -302,12 +396,6 @@ def main(argv: list[str] | None = None) -> int: ) return 1 - # Bind view context attrs so constraint checking has access to shape/strides. - if view_context_attrs: - desc = desc._bind_constraint_context_attrs( - {**desc.constraint_context_attrs, **view_context_attrs} - ) - specialized = desc.specialize(**tile_specs) # Emit MLIR to stdout. diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 56c8d74f9..eb1b086d7 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -996,21 +996,43 @@ def _constraint_context_for_evaluation( attrs.setdefault("op", self._selected_op) attrs.setdefault("selected_op", self._selected_op) - for spec in self._parameter_specs: + for index, spec in enumerate(self._parameter_specs): existing = attrs.get(spec.name) param_attrs = {} if not isinstance(existing, dict) else dict(existing) + positional_prefix = f"arg{index}" param_attrs.setdefault("kind", spec.kind) attrs.setdefault(f"{spec.name}_kind", spec.kind) - if f"{spec.name}_shape" in attrs: - param_attrs.setdefault("shape", tuple(attrs[f"{spec.name}_shape"])) - if f"{spec.name}_valid_shape" in attrs: - param_attrs.setdefault("valid_shape", tuple(attrs[f"{spec.name}_valid_shape"])) - if f"{spec.name}_strides" in attrs: - param_attrs.setdefault("strides", tuple(attrs[f"{spec.name}_strides"])) - if f"{spec.name}_rank" in attrs: - param_attrs.setdefault("rank", attrs[f"{spec.name}_rank"]) - if f"{spec.name}_memory_space" in attrs: - param_attrs.setdefault("memory_space", attrs[f"{spec.name}_memory_space"]) + + def set_sequence_attr(attr_name: str) -> None: + named_key = f"{spec.name}_{attr_name}" + positional_key = f"{positional_prefix}_{attr_name}" + if named_key in attrs: + value = tuple(attrs[named_key]) + elif positional_key in attrs: + value = tuple(attrs[positional_key]) + attrs.setdefault(named_key, value) + else: + return + param_attrs.setdefault(attr_name, value) + + def set_scalar_attr(attr_name: str) -> None: + named_key = f"{spec.name}_{attr_name}" + positional_key = f"{positional_prefix}_{attr_name}" + if named_key in attrs: + value = attrs[named_key] + elif positional_key in attrs: + value = attrs[positional_key] + attrs.setdefault(named_key, value) + else: + return + param_attrs.setdefault(attr_name, value) + + set_sequence_attr("shape") + set_sequence_attr("valid_shape") + set_sequence_attr("strides") + set_scalar_attr("rank") + set_scalar_attr("memory_space") + set_scalar_attr("config") if spec.kind in ("tensorview", "partition_tensor_view"): # TensorView authoring form is normalized to 5D in the current DSL spec. diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index d379e74a7..cae4d60df 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -246,10 +246,11 @@ def kernel(src: pto.Tile, dst: pto.Tile): ] """ ) - desc = expand_helper._match_descriptor( + desc = expand_helper._select_descriptor( descriptors, - "pto.expand_helper_tile_config_unique", - tuple(spec["dtype"] for spec in operand_specs), + target="a5", + op_name="pto.expand_helper_tile_config_unique", + operand_specs=operand_specs, ) self.assertIsNotNone(desc) @@ -277,6 +278,81 @@ def kernel(src: pto.Tile, dst: pto.Tile): mlir_text, ) + def test_select_descriptor_uses_positional_context_for_named_constraints(self) -> None: + source = """ +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.expand_helper_positional_constraints_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda src: src.rank == 5, + lambda src: src.strides[4] == 1, + lambda dst: dst.config.b_layout == pto.BLayout.ROW_MAJOR, + ], +) +def template_nd(src: pto.TensorView, dst: pto.Tile): + return None + +@pto.vkernel( + target="a5", + op="pto.expand_helper_positional_constraints_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda inp: inp.rank == 5, + lambda out: out.config.b_layout == pto.BLayout.COL_MAJOR, + ], + priority=9, +) +def template_dn(inp: pto.TensorView, out: pto.Tile): + return None +""" + with tempfile.TemporaryDirectory() as tmpdir: + module_path = Path(tmpdir) / "expand_helper_positional_constraints_unique.py" + module_path.write_text(source, encoding="utf-8") + + mod = expand_helper._import_py_file(module_path) + self.assertIsNotNone(mod) + descriptors = expand_helper._find_descriptors(mod) + self.assertTrue(descriptors) + + operand_specs = expand_helper._parse_operand_specs( + """ +[ + { + "kind": "view", + "dtype": "f32", + "shape": [1, 1, 1, 16, 64], + "strides": [1024, 1024, 1024, 64, 1], + "memory_space": "gm" + }, + { + "kind": "tile", + "dtype": "f32", + "shape": [16, 64], + "valid_shape": [16, 64], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0" + } + } +] +""" + ) + + selected = expand_helper._select_descriptor( + descriptors, + target="a5", + op_name="pto.expand_helper_positional_constraints_unique", + operand_specs=operand_specs, + ) + + self.assertEqual(selected.name, "template_nd") + class TileLangDSLSupportMatrixTests(unittest.TestCase): def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: @@ -754,6 +830,69 @@ def kernel(src: pto.TensorView, dst: pto.Tile): with self.assertRaises(LookupError): rejected.mlir_text() + def test_select_kernel_supports_positional_context_attrs(self) -> None: + @pto.vkernel( + op="matcher_positional_context_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda src: src.rank == 5, + lambda src: src.strides[4] == 1, + lambda dst: dst.config.b_layout == pto.BLayout.ROW_MAJOR, + ], + ) + def template_nd(src: pto.TensorView, dst: pto.Tile): + return None + + @pto.vkernel( + op="matcher_positional_context_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda inp: inp.rank == 5, + lambda out: out.config.b_layout == pto.BLayout.COL_MAJOR, + ], + priority=9, + ) + def template_dn(inp: pto.TensorView, out: pto.Tile): + return None + + operand_specs = expand_helper._parse_operand_specs( + """ +[ + { + "kind": "view", + "dtype": "f32", + "shape": [1, 1, 1, 16, 64], + "strides": [1024, 1024, 1024, 64, 1], + "memory_space": "gm" + }, + { + "kind": "tile", + "dtype": "f32", + "shape": [16, 64], + "valid_shape": [16, 64], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0" + } + } +] +""" + ) + + registry = pto.KernelRegistry((template_nd, template_dn)) + selected = pto.select_kernel( + "a5", + "matcher_positional_context_unique", + (pto.f32, pto.f32), + context_attrs=expand_helper._build_positional_context_attrs(operand_specs), + registry=registry, + ) + + self.assertEqual(selected.name, "template_nd") + def test_select_kernel_binds_selected_op_for_multi_op_descriptor(self) -> None: @pto.vkernel( ops=["matcher_multi_op_bind_add_unique", "matcher_multi_op_bind_sub_unique"], From b3091a931c8bad20b0b8308142cefb8261370e5e Mon Sep 17 00:00:00 2001 From: qukelin Date: Tue, 14 Apr 2026 23:19:41 +0800 Subject: [PATCH 066/192] Recover static view strides in ExpandTileOp --- lib/PTO/Transforms/ExpandTileOp.cpp | 121 +++++++++++++++++++++++++--- 1 file changed, 111 insertions(+), 10 deletions(-) diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index b05a652b6..829b49f58 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -43,6 +43,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -79,11 +80,11 @@ namespace { // Three kinds of operands: // Tile — from TileBufType. dtype + shape + memorySpace + config // all participate in the specialization key (SpecKey). -// View — from MemRefType (lowered PartitionTensorViewType). Only dtype +// View — from MemRefType (lowered PartitionTensorViewType). Only dtype // participates in SpecKey — the template is fully dynamic so -// shape/strides/memorySpace don't affect code generation. They -// are carried here solely for JSON serialization to the Python -// DSL for constraint checking. +// shape/strides/memorySpace don't affect code generation. They are +// carried here solely for JSON serialization to the Python DSL for +// constraint checking. // Scalar — from a scalar element type. Only dtype participates in SpecKey. // ============================================================================ enum class OperandKind { Tile, View, Scalar }; @@ -215,7 +216,103 @@ static std::string getSLayoutString(int32_t slayout) { return "none_box"; } -static std::optional buildOperandTypeInfo(Type ty) { +static bool getStaticIntFromValue(Value value, int64_t &out) { + if (auto cOp = value.getDefiningOp()) { + out = cOp.value(); + return true; + } + if (auto cInt = value.getDefiningOp()) { + out = cInt.value(); + return true; + } + return false; +} + +static int64_t getStaticIntOrDynamic(OpFoldResult ofr) { + if (auto attr = ofr.dyn_cast()) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + return ShapedType::kDynamic; + } + auto value = llvm::cast(ofr); + int64_t result = ShapedType::kDynamic; + if (getStaticIntFromValue(value, result)) + return result; + return ShapedType::kDynamic; +} + +static void recordStaticSizes(ArrayRef inputs, + SmallVectorImpl &out) { + out.clear(); + out.reserve(inputs.size()); + for (OpFoldResult ofr : inputs) + out.push_back(getStaticIntOrDynamic(ofr)); +} + +static SmallVector combineSubviewStrides(ArrayRef baseStrides, + ArrayRef steps) { + SmallVector result; + result.reserve(baseStrides.size()); + for (auto [baseStride, step] : llvm::zip(baseStrides, steps)) { + int64_t stepValue = getStaticIntOrDynamic(step); + if (baseStride == ShapedType::kDynamic || + stepValue == ShapedType::kDynamic) { + result.push_back(ShapedType::kDynamic); + continue; + } + result.push_back(baseStride * stepValue); + } + return result; +} + +static void populateViewShapeAndStrides(Value value, + SmallVectorImpl &shape, + SmallVectorImpl &strides) { + if (!value) + return; + + if (auto subview = value.getDefiningOp()) { + populateViewShapeAndStrides(subview.getSource(), shape, strides); + SmallVector subviewShape; + recordStaticSizes(subview.getMixedSizes(), subviewShape); + if (!subviewShape.empty()) + shape = subviewShape; + if (!strides.empty()) + strides = combineSubviewStrides(strides, subview.getMixedStrides()); + return; + } + + if (auto reinterpret = value.getDefiningOp()) { + if (shape.empty()) { + SmallVector reinterpretShape; + recordStaticSizes(reinterpret.getMixedSizes(), reinterpretShape); + if (!reinterpretShape.empty()) + shape = reinterpretShape; + } + if (strides.empty()) + recordStaticSizes(reinterpret.getMixedStrides(), strides); + return; + } + + if (auto cast = value.getDefiningOp()) { + populateViewShapeAndStrides(cast.getSource(), shape, strides); + return; + } + + if (auto memrefTy = dyn_cast(value.getType())) { + if (shape.empty()) + shape.assign(memrefTy.getShape().begin(), memrefTy.getShape().end()); + if (strides.empty()) { + int64_t offset = ShapedType::kDynamic; + if (succeeded(getStridesAndOffset(memrefTy, strides, offset))) { + // strides populated — dynamic dims remain ShapedType::kDynamic. + } + } + } +} + +static std::optional buildOperandTypeInfo(Value value) { + Type ty = value.getType(); // Tile operand — from TileBufType. if (auto tbTy = dyn_cast(ty)) { OperandTypeInfo info; @@ -248,11 +345,15 @@ static std::optional buildOperandTypeInfo(Type ty) { info.dtype = getDtypeString(mrTy.getElementType()); if (info.dtype.empty()) return std::nullopt; - info.viewShape.assign(mrTy.getShape().begin(), mrTy.getShape().end()); info.viewMemorySpace = getMemorySpaceString(mrTy); - int64_t offset = ShapedType::kDynamic; - if (succeeded(getStridesAndOffset(mrTy, info.viewStrides, offset))) { - // strides populated — dynamic dims remain ShapedType::kDynamic. + populateViewShapeAndStrides(value, info.viewShape, info.viewStrides); + if (info.viewShape.empty()) + info.viewShape.assign(mrTy.getShape().begin(), mrTy.getShape().end()); + if (info.viewStrides.empty()) { + int64_t offset = ShapedType::kDynamic; + if (succeeded(getStridesAndOffset(mrTy, info.viewStrides, offset))) { + // strides populated — dynamic dims remain ShapedType::kDynamic. + } } return info; } @@ -272,7 +373,7 @@ static std::optional buildSpecKey(Operation *op) { key.targetArch = getTargetArchString(op->getParentOfType()); for (unsigned i = 0; i < op->getNumOperands(); ++i) { - auto info = buildOperandTypeInfo(op->getOperand(i).getType()); + auto info = buildOperandTypeInfo(op->getOperand(i)); if (!info) return std::nullopt; key.operands.push_back(*info); From 34d037818f28f7a5d7cba3185c8d028c299b867a Mon Sep 17 00:00:00 2001 From: qukelin Date: Wed, 15 Apr 2026 00:14:12 +0800 Subject: [PATCH 067/192] Refine tilelang ST compare flow --- docs/designs/ptoas-tileop-expand-design.md | 43 ++++++-- docs/designs/tilelang-st-framework.md | 102 +++++++++++++----- .../npu/a5/src/st/testcase/compare.py | 16 --- .../npu/a5/src/st/testcase/st_common.py | 76 +++++++------ .../npu/a5/src/st/testcase/tadd/compare.py | 49 +++++++++ test/tilelang_st/script/run_st.py | 4 +- 6 files changed, 195 insertions(+), 95 deletions(-) delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py diff --git a/docs/designs/ptoas-tileop-expand-design.md b/docs/designs/ptoas-tileop-expand-design.md index 607159d40..1ad4ad63d 100644 --- a/docs/designs/ptoas-tileop-expand-design.md +++ b/docs/designs/ptoas-tileop-expand-design.md @@ -1058,7 +1058,7 @@ run_st.py - `gen_data.py` 会基于 `cases.py` 中的 `CASES` 为每个 case 生成 `input*.bin` 和 `golden.bin` - host 可执行文件会按 `main.cpp` 中的 case table 逐个读取 `.//input*.bin`,运行 kernel,并写回 `.//output.bin` -- `compare.py` 再基于同一份 `CASES` 定义逐 case 做 `numpy.allclose` 比较 +- `compare.py` 再基于同一份 `CASES` 定义逐 case 读取并裁剪需要比较的数据,最后调用公共 `result_cmp()` - 若传入 `-c `,则运行和比较都只针对单个 case 因此,TileLang ST 的验证对象不是“某一份中间 IR 是否长得对”,而是: @@ -1070,10 +1070,10 @@ run_st.py 编译子链由 `testcase/CMakeLists.txt` 中的 `pto_tilelang_vec_st()` 宏自动接管,整条执行链路则由 `run_st.py` 统一调度。 -##### 新增 testcase 所需文件(七件套) +##### 新增 testcase 所需文件(七个文件 + 一个注册修改) 以新增 `pto.tsub` 为例,需在 `test/tilelang_st/npu/a5/src/st/testcase/tsub/` 下准备 -6 个文件,并修改 1 个注册文件: +7 个文件,并修改 1 个注册文件: **1. `CMakeLists.txt`** — 通常只有一行: @@ -1099,7 +1099,9 @@ CASES = [ ] ``` -每个 case 必须包含 `name`/`dtype`/`shape`/`valid_shape`/`eps` 五个字段,`valid_shape` 为必填。 +常规 case 必须包含 `name`/`dtype`/`shape`/`valid_shape`/`eps` 五个字段,`valid_shape` 为必填。 +如果输出 shape 与输入不同(如 `trowsum`),再额外补 `dst_shape`/`dst_valid_shape`,供 +`compare.py` 和 `gen_data.py` 使用。 **3. `tsub.pto`** — kernel 描述,一个文件中放多个 case 对应的函数。每个函数 代表一种 dtype/shape 组合。以 tadd 为参考,kernel 结构为: @@ -1208,9 +1210,27 @@ for case in CASES: 注意 golden 的计算逻辑必须与 op 语义一致(tadd 是加法,tsub 是减法),且只在 `valid_shape` 区域内计算。 -`compare.py` 为公共脚本(位于 `testcase/compare.py`),所有 testcase 共享,无需 per-testcase 编写。 -`run_st.py` 运行时将它与 per-testcase 的 `cases.py` 一起拷贝到 build 目录,自动读取 case 列表和阈值, -只比较 `valid_shape` 区域。exit code 2 表示失败。 +**7. `compare.py`** — 每个 testcase 自己维护比较脚本。公共层只提供 +`st_common.result_cmp(golden, output, eps)`,具体比较哪些数据由 testcase 自己决定。 + +以 `tsub` 这种输入输出 shape 一致的 case 为例,核心逻辑通常是: + +```python +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +validate_cases(CASES) + +for case in CASES: + shape = case["shape"] + vr, vc = case["valid_shape"] + golden = np.fromfile(os.path.join(case["name"], "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case["name"], "output.bin"), dtype=case["dtype"]).reshape(shape) + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) +``` + +如果是 `trowsum` 这类输出 shape 不同的 op,则 `compare.py` 可以自己按 `dst_shape` reshape, +并只比较 `dst_valid_shape` 对应的有效区域。exit code 2 表示失败。 精度阈值参考: @@ -1221,7 +1241,7 @@ for case in CASES: | `bfloat16` | `1e-2` | | `int8/int16/int32` | `0`(精确匹配) | -**7. 注册** — 修改 `testcase/CMakeLists.txt`,将新 op 加入 `ALL_TESTCASES`: +**8. 注册** — 修改 `testcase/CMakeLists.txt`,将新 op 加入 `ALL_TESTCASES`: ```cmake set(ALL_TESTCASES @@ -1242,8 +1262,9 @@ set(ALL_TESTCASES | 参数顺序 | `.pto` → `launch.cpp` → `main.cpp` 的 launch 调用 | `(a, b) → c` | | shape / valid_shape | `cases.py` ↔ `.pto` tile shape ↔ `main.cpp` rows/cols/validRows/validCols | `16×64` / `(16, 64)` | -Python 侧的 case 名、dtype、shape、valid_shape、eps 已通过 `cases.py` 收敛为单一来源。 -但 C++ 侧 `main.cpp` 的 `kCases[]` 和 `.pto` 仍需手动与 `cases.py` 保持一致。 +Python 侧的 case 名、dtype、shape、valid_shape、eps(以及必要时的 `dst_shape` / +`dst_valid_shape`)已通过 `cases.py` 收敛为单一来源。但 C++ 侧 `main.cpp` 的 `kCases[]` +和 `.pto` 仍需手动与 `cases.py` 保持一致。 任何一处不一致都可能导致:编译成功但运行时 segfault,或运行成功但比较结果错误且难以定位。 @@ -1275,9 +1296,9 @@ python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tsub -c f32_16x64 -w ```text build/testcase/tsub/ ├── st_common.py # 从 testcase/ 公共目录拷贝 -├── compare.py # 从 testcase/ 公共目录拷贝 ├── cases.py # 从 testcase/tsub/ 拷贝 ├── gen_data.py # 从 testcase/tsub/ 拷贝 +├── compare.py # 从 testcase/tsub/ 拷贝 ├── f32_16x64/ │ ├── input1.bin │ ├── input2.bin diff --git a/docs/designs/tilelang-st-framework.md b/docs/designs/tilelang-st-framework.md index df411d981..a869c4056 100644 --- a/docs/designs/tilelang-st-framework.md +++ b/docs/designs/tilelang-st-framework.md @@ -15,6 +15,7 @@ - 支持在一个 testcase 下放多个 case - 支持 `sim` / `npu` 两种运行模式 - 支持单 case 过滤 +- 支持 `src` / `dst` 逻辑 shape 不一致的 testcase(例如 `trowsum` 这类 reduction) - 支持把输入、golden、output 隔离到 `build/testcase//` 下,避免不同 testcase 之间互相覆盖 ## 2. 框架定位 @@ -107,14 +108,14 @@ test/tilelang_st/ ├── run_ptoas_to_file.cmake ├── repack_tilelang_kernel.sh ├── st_common.py - ├── compare.py └── tadd/ ├── CMakeLists.txt ├── cases.py ├── tadd.pto ├── launch.cpp ├── main.cpp - └── gen_data.py + ├── gen_data.py + └── compare.py ``` 各文件职责如下: @@ -126,13 +127,13 @@ test/tilelang_st/ | `testcase/CMakeLists.txt` | 定义 `pto_tilelang_vec_st()` 宏,并注册所有 testcase | | `testcase/run_ptoas_to_file.cmake` | 封装 `ptoas` 调用,把 `.pto` 编译成 LLVM IR | | `testcase/repack_tilelang_kernel.sh` | 把 device-only `.o` 包装成 host 可链接的 fatobj | -| `testcase/st_common.py` | 所有 testcase 共享的 Python 公共模块(case 校验、数据生成辅助、精度比较、终端着色) | -| `testcase/compare.py` | 公共比较脚本,所有 testcase 共享,从 per-testcase 的 `cases.py` 导入 `CASES` 后调用 `st_common.run_compare()` | -| `testcase//cases.py` | **case 定义的单一来源**,`gen_data.py` 和 `compare.py` 均从此导入 | +| `testcase/st_common.py` | 所有 testcase 共享的 Python 公共模块(case 校验、数据生成辅助、`result_cmp`、终端着色) | +| `testcase//cases.py` | **case 定义的单一来源**,`gen_data.py` 和 `compare.py` 均从此导入;默认使用 `shape`/`valid_shape`,像 `trowsum` 这类输出 shape 不同的 op 再额外补 `dst_shape`/`dst_valid_shape` | | `testcase//.pto` | testcase 的 kernel 描述,通常一个文件中放多个 case 对应的函数 | | `testcase//launch.cpp` | kernel 声明和 launch wrapper | | `testcase//main.cpp` | host driver,负责分配内存、launch kernel、回写 output(`ACL_CHECK` 宏由公共头 `test_common.h` 提供) | | `testcase//gen_data.py` | 生成 input 与 golden,从 `cases.py` 读取 case 列表 | +| `testcase//compare.py` | 每个 testcase 自己的比较脚本,决定读取哪些 bin、reshape 成什么形状、裁哪一块数据,再调用公共 `result_cmp()` | ## 5. 日常使用方式 @@ -234,11 +235,12 @@ build/testcase/tadd/ | 文件 | 是否新增/修改 | 说明 | |---|---|---| | `testcase/tsub/CMakeLists.txt` | 新增 | 一般只有一行 `pto_tilelang_vec_st(tsub)` | -| `testcase/tsub/cases.py` | 新增 | **case 定义的单一来源**:每个 case 必须指定 `name`/`dtype`/`shape`/`valid_shape`/`eps` | +| `testcase/tsub/cases.py` | 新增 | **case 定义的单一来源**:每个 case 必须指定 `name`/`dtype`/`shape`/`valid_shape`/`eps`;如果输出 shape 不同,再额外补 `dst_shape`/`dst_valid_shape` | | `testcase/tsub/tsub.pto` | 新增 | 定义一个或多个 case 的 kernel 函数 | | `testcase/tsub/launch.cpp` | 新增 | 为每个 kernel 函数声明 entry 并提供 launch wrapper | | `testcase/tsub/main.cpp` | 新增 | host driver,负责 case table、内存拷贝、launch 和 output 落盘 | | `testcase/tsub/gen_data.py` | 新增 | 生成每个 case 的输入和 golden,从 `cases.py` 导入 `CASES` | +| `testcase/tsub/compare.py` | 新增 | testcase 自己决定比较哪些输出数据,再调用公共 `result_cmp()` | | `testcase/CMakeLists.txt` | 修改 | 把 `tsub` 加入 `ALL_TESTCASES` | 通常不需要修改: @@ -246,7 +248,6 @@ build/testcase/tadd/ - `script/run_st.py` - `src/st/CMakeLists.txt` - `testcase/st_common.py` -- `testcase/compare.py`(公共脚本,所有 testcase 共享) - `testcase/run_ptoas_to_file.cmake` - `testcase/repack_tilelang_kernel.sh` @@ -332,7 +333,7 @@ void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream) { 你需要做的事主要有三类: 1. 声明所有 `LaunchTADD_*` wrapper -2. 在 `kCases[]` 中列出每个 case 的名字、launch 函数、shape、valid shape、元素大小 +2. 在 `kCases[]` 中列出每个 case 的名字、launch 函数、输入/输出 shape、valid shape、元素大小 3. 在 `RunCase()` 中完成: - 从 `.//input*.bin` 读取输入 - `aclrtMemcpy` 把输入拷到 device @@ -362,7 +363,10 @@ static const TestCase kCases[] = { 注意:`ACL_CHECK` 宏已移至公共头文件 `test_common.h`(需在 `acl/acl.h` 之后包含),不需要在每个 testcase 的 `main.cpp` 中重复定义。 -你在新增 case 时,必须同步更新这个表,字段需与 `cases.py` 中的 `shape` / `valid_shape` 保持一致。 +你在新增 case 时,必须同步更新这个表。 + +- 对 `tadd` 这类同 shape op,字段需与 `cases.py` 的 `shape` / `valid_shape` 保持一致。 +- 对 `trowsum` 这类输出 shape 不同的 op,host 侧需要把输入大小和输出大小分开计算。 ### 7.5 `testcase/tadd/cases.py` @@ -370,6 +374,14 @@ static const TestCase kCases[] = { 每个 case 必须包含以下字段: +```python +"name" +"dtype" +"shape" +"valid_shape" +"eps" +``` + ```python CASES = [ { @@ -384,6 +396,24 @@ CASES = [ `valid_shape` 为必填字段。当 valid shape 等于 tile shape 时也必须显式写出。 +如果输出 shape 不同,可以额外补下面两个字段: + +```python +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), # 输入 tensor shape + "valid_shape": (16, 64), # 输入有效区域 + "dst_shape": (16, 1), # 输出 tensor shape(GM 可见形状) + "dst_valid_shape": (16, 1), # 输出有效区域 + "eps": 1e-5, + }, +] +``` + +这也是 `trowsum` 推荐使用的写法。注意 `dst_shape` 描述的是写回 GM 后的实际结果形状,而不是片上 tile 的物理展开形状。 + ### 7.6 `testcase/tadd/gen_data.py` 这个文件负责为每个 case 生成输入和 golden。从 `cases.py` 导入 `CASES`, @@ -399,29 +429,45 @@ golden[:vr, :vc] = (input1[:vr, :vc] + input2[:vr, :vc]).astype(dtype, copy=Fals golden 只在 `valid_shape` 区域内计算,区域外保持零值。 +如果是 `trowsum` 这类输出 shape 不同的 op,则 `gen_data.py` 应该按 `dst_shape` 生成 `golden`,按 `valid_shape` 完成规约计算。例如: + +```python +shape = case["shape"] +valid_shape = case["valid_shape"] +dst_shape = case["dst_shape"] +dst_valid_shape = case["dst_valid_shape"] +input1 = np.random.randint(1, 10, size=shape).astype(dtype) +golden = np.zeros(dst_shape, dtype=dtype) +golden[:dst_valid_shape[0], 0] = np.sum( + input1[:valid_shape[0], :valid_shape[1]], axis=1 +).astype(dtype, copy=False)[:dst_valid_shape[0]] +``` + +比较阶段也会按 `dst_shape` / `dst_valid_shape` 读取和 reshape `golden.bin`、`output.bin`。 + 每个 case 使用独立的随机 seed(`setup_case_rng` 基于 `hash(case["name"])`), 新增或调整 case 顺序不会影响已有 case 的测试数据。 -### 7.7 `testcase/compare.py`(公共,无需 per-testcase 修改) +### 7.7 `testcase//compare.py` -`compare.py` 位于 `testcase/` 公共目录,所有 testcase 共享同一份: +比较脚本不再放在公共目录,而是每个 testcase 自己维护一份。 -```python -from cases import CASES -from st_common import run_compare +这样做的目的很直接: + +- 公共层只提供 `result_cmp(golden, output, eps)` 这种“比已经准备好的数据”的接口 +- 具体读取哪些 bin、reshape 成什么形状、裁哪一块 valid 区域,由 testcase 自己决定 + +以 `tadd` 为例,`compare.py` 的核心逻辑就是: -if __name__ == "__main__": - run_compare(CASES) +```python +golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) +output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) +ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) ``` -`run_st.py` 运行时会将它和 per-testcase 的 `cases.py` 一起拷贝到 build 目录, -`compare.py` 通过 `from cases import CASES` 获取当前 testcase 的 case 列表。 +如果是 `trowsum`,则可以自己改成按 `dst_shape` reshape,并只比较 `rows x 1` 的有效区域。 -`run_compare()` 会: -- 校验所有 case 必填字段 -- 只在 `valid_shape` 区域内比较 `golden.bin` 与 `output.bin` -- 支持 `argv[1]` 作为 case filter -- exit code 2 表示失败 +这种拆法更接近 `pto-isa` 的 `ResultCmp` 思路:公共层只负责“怎么比”,不负责“该比哪块数据”。 ## 8. 如果只是在已有 `tadd` 下新增一个 case @@ -464,10 +510,14 @@ if __name__ == "__main__": `.pto` 里 kernel 的参数顺序、`launch.cpp` 声明顺序、`main.cpp` 里 launch wrapper 的参数顺序必须一致。 如果 `tadd` 的语义是 `(a, b) -> c`,那 host 侧和 compare 也都要按这个顺序组织。 -### 9.3 shape、valid_shape 和 dtype 一致 +### 9.3 shape、valid_shape、dst_shape 和 dtype 一致 + +`cases.py` 中的 shape 信息和 `dtype` 是 Python 侧的单一来源,`gen_data.py` 和 `compare.py` 自动从中读取。 + +- 对大多数 op,`shape`/`valid_shape` 就够了。 +- 对 `trowsum` 这类输出 shape 不同的 op,再额外维护 `dst_shape`/`dst_valid_shape`。 -`cases.py` 中的 `shape`/`valid_shape`/`dtype` 是 Python 侧的单一来源,`gen_data.py` 和 `compare.py` 自动从中读取。 -但 C++ 侧的 `main.cpp` `kCases[]`(`rows`/`cols`/`validRows`/`validCols`/`elemSize`)和 `.pto` 中的 tile shape 仍需手动与 `cases.py` 保持一致。 +但 C++ 侧的 `main.cpp` `kCases[]` 和 `.pto` 中的 tensor/tile shape 仍需手动与 `cases.py` 保持一致。 否则运行能成功,结果也可能是错误的,且定位会很耗时。 ## 10. 建议的开发验证节奏 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/compare.py deleted file mode 100644 index ab6449f1c..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/compare.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -from cases import CASES -from st_common import run_compare - -if __name__ == "__main__": - run_compare(CASES) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/st_common.py b/test/tilelang_st/npu/a5/src/st/testcase/st_common.py index 8f61fdb8c..d0401b202 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/st_common.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/st_common.py @@ -12,9 +12,8 @@ """Shared utilities for TileLang ST test cases. Provides: - - Case helpers: get_valid_shape() - Data helpers: setup_case_rng(), save_case_data() - - Compare: compare_bin(), run_compare() (full compare entry point) + - Compare: result_cmp() - Styling: supports_color(), style_pass(), style_fail() """ @@ -30,12 +29,42 @@ REQUIRED_CASE_KEYS = {"name", "dtype", "shape", "valid_shape", "eps"} +def _to_shape_tuple(shape): + if not isinstance(shape, (tuple, list)): + raise ValueError(f"shape must be tuple/list, got {type(shape).__name__}: {shape!r}") + if not shape: + raise ValueError("shape must not be empty") + dims = tuple(int(dim) for dim in shape) + if any(dim <= 0 for dim in dims): + raise ValueError(f"shape dimensions must be > 0, got {dims}") + return dims + + +def _validate_shape_pair(shape, valid_shape, label): + shape = _to_shape_tuple(shape) + valid_shape = _to_shape_tuple(valid_shape) + if len(shape) != len(valid_shape): + raise ValueError(f"{label}: shape rank mismatch: {shape} vs {valid_shape}") + if any(valid_dim > dim for dim, valid_dim in zip(shape, valid_shape)): + raise ValueError(f"{label}: valid shape {valid_shape} exceeds shape {shape}") + return shape, valid_shape + + def validate_cases(cases): """Check that every case has all required keys.""" for i, case in enumerate(cases): missing = REQUIRED_CASE_KEYS - case.keys() if missing: raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + _validate_shape_pair(case["shape"], case["valid_shape"], "shape") + has_dst_shape = "dst_shape" in case + has_dst_valid_shape = "dst_valid_shape" in case + if has_dst_shape != has_dst_valid_shape: + raise ValueError( + f"cases[{i}] ({case.get('name', '?')}) must define both dst_shape and dst_valid_shape" + ) + if has_dst_shape: + _validate_shape_pair(case["dst_shape"], case["dst_valid_shape"], "dst") # --------------------------------------------------------------------------- @@ -93,17 +122,13 @@ def style_fail(text): # Comparison # --------------------------------------------------------------------------- -def compare_bin(golden_path, output_path, dtype, eps, shape, valid_shape): - """Compare golden and output binary files within the valid region. +def result_cmp(golden, output, eps): + """Compare already prepared golden/output arrays. - Returns True on pass, False on mismatch. + The caller is responsible for loading, reshaping and slicing data. """ - golden = np.fromfile(golden_path, dtype=dtype).reshape(shape) - output = np.fromfile(output_path, dtype=dtype).reshape(shape) - - vr, vc = valid_shape - g = golden[:vr, :vc].astype(np.float64, copy=False) - o = output[:vr, :vc].astype(np.float64, copy=False) + g = np.asarray(golden).astype(np.float64, copy=False) + o = np.asarray(output).astype(np.float64, copy=False) if g.shape != o.shape: print(style_fail(f"[ERROR] Shape mismatch: golden {g.shape} vs output {o.shape}")) @@ -116,32 +141,3 @@ def compare_bin(golden_path, output_path, dtype, eps, shape, valid_shape): f"(golden={g.flat[idx]}, output={o.flat[idx]})")) return False return True - - -def run_compare(cases): - """Main entry point for per-testcase compare.py scripts. - - Reads an optional case filter from sys.argv[1], iterates over *cases*, - and exits with code 2 if any comparison fails. - """ - validate_cases(cases) - case_filter = sys.argv[1] if len(sys.argv) > 1 else None - - all_passed = True - for case in cases: - if case_filter is not None and case["name"] != case_filter: - continue - case_dir = case["name"] - golden_path = os.path.join(case_dir, "golden.bin") - output_path = os.path.join(case_dir, "output.bin") - ok = compare_bin(golden_path, output_path, case["dtype"], case["eps"], - case["shape"], case["valid_shape"]) - if ok: - print(style_pass(f"[INFO] {case['name']}: compare passed")) - else: - print(style_fail(f"[ERROR] {case['name']}: compare failed")) - all_passed = False - - if not all_passed: - sys.exit(2) - print(style_pass("[INFO] all cases passed")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/script/run_st.py b/test/tilelang_st/script/run_st.py index f71fc0a55..996135490 100755 --- a/test/tilelang_st/script/run_st.py +++ b/test/tilelang_st/script/run_st.py @@ -144,13 +144,13 @@ def _copy_testcase_scripts(testcase): work_dir = get_testcase_work_dir(testcase) os.makedirs(work_dir, exist_ok=True) # Shared scripts (testcase/ level). - for name in ("st_common.py", "compare.py"): + for name in ("st_common.py",): src = os.path.join("testcase", name) if os.path.isfile(src): run_command(["cp", src, os.path.join(work_dir, name)]) # Per-testcase scripts. testcase_src = f"testcase/{testcase}" - for name in ("cases.py", "gen_data.py"): + for name in ("cases.py", "gen_data.py", "compare.py"): src = os.path.join(testcase_src, name) if os.path.isfile(src): run_command(["cp", src, os.path.join(work_dir, name)]) From d1c6fdd647b33d34bdb76e44b0a2c44a905036f0 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 09:33:38 +0800 Subject: [PATCH 068/192] fix(tilelang-dsl): enforce vcvt attrs by type pair and align vcvt lowering (#63) --- lib/TileOps/trowargmax_template.py | 86 +++++++++++++ tilelang-dsl/python/tilelang_dsl/lowering.py | 14 ++- tilelang-dsl/python/tilelang_dsl/semantic.py | 124 +++++++++++++++++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 84 +++++++++++++ 4 files changed, 303 insertions(+), 5 deletions(-) create mode 100644 lib/TileOps/trowargmax_template.py diff --git a/lib/TileOps/trowargmax_template.py b/lib/TileOps/trowargmax_template.py new file mode 100644 index 000000000..660151b3b --- /dev/null +++ b/lib/TileOps/trowargmax_template.py @@ -0,0 +1,86 @@ +"""TileLang DSL template for pto.trowargmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowargmax", + advanced=True, +) +def template_trowargmax(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + idx_dtype = dst.element_type + lanes = pto.get_lanes(src_dtype) + valid_rows, valid_cols = src.valid_shape + + # Initialize with negative infinity for ROWARGMAX + if pto.constexpr(src_dtype == pto.f32): + init_val = pto.f32(0xFF800000) # Negative infinity + init_zero = pto.f32(0) + elif pto.constexpr(src_dtype == pto.f16): + init_val = pto.f16(0xFC00) # Negative infinity + init_zero = pto.f16(0) + + # Since index is valid in lane 0, we can use mask_1 + mask_1, _ = pto.make_mask(src_dtype, 1) + mask_1_final, _ = pto.make_mask(idx_dtype, 1) + + for row in range(0, valid_rows, 1): + remained = valid_cols + + v_val_acc = pto.vbr(init_val) + v_idx_acc = pto.vbr(init_zero) + v_zero = pto.vbr(init_zero) + + # Process all column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(src_dtype, remained) + v_src = pto.vlds(src[row, col:]) + v_reduced = pto.vcmax(v_src, mask) + + v_val, v_idx = pto.vdintlv(v_reduced, v_zero) + + # Add absolute col offset to the chunk's local index + if pto.constexpr(src_dtype == pto.f32): + v_col = pto.f32(col) + elif pto.constexpr(src_dtype == pto.f16): + v_col = pto.f16(col) + + v_idx = pto.vadds(v_idx, v_col, mask_1) + + # Compare current chunk max with global max so far + # vcmp returns a mask + cmp_mask = pto.vcmp(v_val_acc, v_val, mask_1, "lt") + + # Update global max and global argmax depending on who is greater + v_val_acc = pto.vsel(v_val, v_val_acc, cmp_mask) + v_idx_acc = pto.vsel(v_idx, v_idx_acc, cmp_mask) + + # Store the extracted index into the dst tile + if pto.constexpr(src_dtype != idx_dtype): + # vcvt attrs are type-pair sensitive in VPTO verifier: + # - f32 -> i32 requires rnd + sat + # - f16 -> i32 requires part + if pto.constexpr(src_dtype == pto.f32 and idx_dtype == pto.i32): + v_idx_acc_casted = pto.vcvt( + v_idx_acc, + idx_dtype, + mask_1_final, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + ) + elif pto.constexpr(src_dtype == pto.f16 and idx_dtype == pto.i32): + v_idx_acc_casted = pto.vcvt( + v_idx_acc, + idx_dtype, + mask_1_final, + part=pto.VcvtPartMode.ODD, + ) + else: + v_idx_acc_casted = pto.vcvt(v_idx_acc, idx_dtype, mask_1_final) + pto.vsts(v_idx_acc_casted, dst[row, 0:], mask_1_final) + else: + pto.vsts(v_idx_acc, dst[row, 0:], mask_1_final) + return diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 749744685..670b6fc7a 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2724,8 +2724,6 @@ def _lower_call_expr( if expr.name == "vcvt": value = self._lower_expr(expr.args[0], env, indent=indent, into=into) - target_dtype = self._render_dtype_symbol(expr.args[1], context="pto.vcvt to_type") - mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) attr_parts: list[str] = [] if self._has_optional_string_literal(expr.args[3]): attr_parts.append(f"rnd = {self._render_string_literal(expr.args[3])}") @@ -2736,8 +2734,8 @@ def _lower_call_expr( attr_suffix = f" {{{', '.join(attr_parts)}}}" if attr_parts else "" into.append( self._indent(indent) - + f"{result_name} = pto.vcvt {value.name}, {target_dtype}, {mask.name}{attr_suffix} : " - + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + + f"{result_name} = pto.vcvt {value.name}{attr_suffix} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) @@ -3139,10 +3137,16 @@ def _scalar_int_sign(dtype: ScalarType) -> str: ) return _RenderedValue(name=cast_name, type=target_type) if target_type.dtype.name in {"f16", "bf16", "f32"}: + index_to_int_name = self._new_temp() + index_to_int_op = "arith.index_castui" + into.append( + self._indent(indent) + + f"{index_to_int_name} = {index_to_int_op} {value.name} : index to i64" + ) cast_name = self._new_temp() into.append( self._indent(indent) - + f"{cast_name} = arith.uitofp {value.name} : index to {target_type.dtype.name}" + + f"{cast_name} = arith.uitofp {index_to_int_name} : i64 to {target_type.dtype.name}" ) return _RenderedValue(name=cast_name, type=target_type) if isinstance(value.type, SemanticScalarType) and isinstance(target_type, SemanticScalarType): diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 5eb6e6cfd..9129a0d5c 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -139,6 +139,67 @@ _VCVT_SAT_MODE_SYMBOLS = {mode.name: mode for mode in VcvtSatMode} _VCVT_PART_MODE_SYMBOLS = {mode.name: mode for mode in VcvtPartMode} _POST_UPDATE_MODE_SYMBOLS = {mode.name: mode for mode in PostUpdateMode} +_VCVT_ATTR_CONTRACTS: dict[tuple[str, str], tuple[bool, bool, bool]] = { + # (src_kind, dst_kind): (requires_rnd, requires_sat, requires_part) + ("f32", "f16"): (True, True, True), + ("f32", "bf16"): (True, True, True), + ("f32", "s16"): (True, True, True), + ("f32", "s64"): (True, True, True), + ("f32", "s32"): (True, True, False), + ("f16", "f32"): (False, False, True), + ("f16", "s32"): (False, False, True), + ("f16", "s16"): (True, True, False), + ("f16", "s8"): (True, True, True), + ("f16", "u8"): (True, True, True), + ("bf16", "f32"): (False, False, True), + ("bf16", "s32"): (True, True, True), + ("u8", "f16"): (False, False, True), + ("u8", "u16"): (False, False, True), + ("u8", "u32"): (False, False, True), + ("s8", "f16"): (False, False, True), + ("s8", "s16"): (False, False, True), + ("s8", "s32"): (False, False, True), + ("u16", "u8"): (False, True, True), + ("u16", "u32"): (False, False, True), + ("s16", "f16"): (True, False, False), + ("s16", "f32"): (False, False, True), + ("s16", "u32"): (False, False, True), + ("s16", "s32"): (False, False, True), + ("s16", "u8"): (False, True, True), + ("u32", "u8"): (False, True, True), + ("u32", "u16"): (False, True, True), + ("u32", "s16"): (False, True, True), + ("s32", "f32"): (True, False, False), + ("s32", "u8"): (False, True, True), + ("s32", "u16"): (False, True, True), + ("s32", "s16"): (False, True, True), + ("s32", "s64"): (False, False, True), + ("s64", "f32"): (True, False, True), + ("s64", "s32"): (False, True, True), +} + + +def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: + if dtype == f16: + return "f16" + if dtype == bf16: + return "bf16" + if dtype == f32: + return "f32" + if not is_integer_dtype(dtype): + return None + width = integer_bitwidth(dtype) + sign = integer_signedness(dtype) + is_unsigned = sign == "unsigned" + if width == 8: + return "u8" if is_unsigned else "s8" + if width == 16: + return "u16" if is_unsigned else "s16" + if width == 32: + return "u32" if is_unsigned else "s32" + if width == 64: + return None if is_unsigned else "s64" + return None _UNARY_VECTOR_OPS = { "vabs", "vrelu", @@ -4390,11 +4451,22 @@ def _analyze_vcvt_frontend_call( name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) for name, value in expr.keywords } + allowed_keywords = {"rnd", "sat", "part"} + unexpected_keywords = sorted(set(analyzed_keywords) - allowed_keywords) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.vcvt only accepts keyword attrs `rnd`, `sat`, and `part`; " + f"got unsupported keyword(s): {keyword_text}" + ) return self._analyze_vcvt( args, rnd=self._normalize_vcvt_round_mode(analyzed_keywords.get("rnd")), sat=self._normalize_vcvt_sat_mode(analyzed_keywords.get("sat")), part=self._normalize_vcvt_part_mode(analyzed_keywords.get("part")), + rnd_explicit="rnd" in analyzed_keywords, + sat_explicit="sat" in analyzed_keywords, + part_explicit="part" in analyzed_keywords, ) def _analyze_vcvt( @@ -4404,12 +4476,27 @@ def _analyze_vcvt( rnd: SemanticExpr | None = None, sat: SemanticExpr | None = None, part: SemanticExpr | None = None, + rnd_explicit: bool = False, + sat_explicit: bool = False, + part_explicit: bool = False, ) -> SemanticExpr: if len(args) != 3: raise TypeError("pto.vcvt expects exactly 3 positional arguments in TileLang DSL") vector = self._require_vreg_expr(args[0], "pto.vcvt vector") target_dtype = self._require_dtype_symbol(args[1], "pto.vcvt to_type") self._require_mask_for_vreg(args[2], vector, "pto.vcvt") + contract = self._lookup_vcvt_attr_contract(vector.element_dtype, target_dtype) + if contract is not None: + self._require_explicit_vcvt_attrs( + src_dtype=vector.element_dtype, + dst_dtype=target_dtype, + rnd_required=contract[0], + sat_required=contract[1], + part_required=contract[2], + rnd_explicit=rnd_explicit, + sat_explicit=sat_explicit, + part_explicit=part_explicit, + ) return SemanticCallExpr( namespace="pto", name="vcvt", @@ -4724,6 +4811,43 @@ def _normalize_vcvt_part_mode(self, expr: SemanticExpr | None) -> SemanticExpr | ) return SemanticLiteralExpr(value=part_mode, type=SemanticMetaType(kind="string")) + def _lookup_vcvt_attr_contract( + self, src_dtype: ScalarType, dst_dtype: ScalarType + ) -> tuple[bool, bool, bool] | None: + src_kind = _classify_vcvt_elem_kind(src_dtype) + dst_kind = _classify_vcvt_elem_kind(dst_dtype) + if src_kind is None or dst_kind is None: + return None + return _VCVT_ATTR_CONTRACTS.get((src_kind, dst_kind)) + + def _require_explicit_vcvt_attrs( + self, + *, + src_dtype: ScalarType, + dst_dtype: ScalarType, + rnd_required: bool, + sat_required: bool, + part_required: bool, + rnd_explicit: bool, + sat_explicit: bool, + part_explicit: bool, + ) -> None: + pair = f"{src_dtype.name}->{dst_dtype.name}" + + def _check(attr_name: str, required: bool, explicit: bool) -> None: + if required and not explicit: + raise TypeError( + f"pto.vcvt {pair} requires explicit `{attr_name}=` in TileLang DSL v1" + ) + if not required and explicit: + raise TypeError( + f"pto.vcvt {pair} does not accept `{attr_name}=` for this type pair in TileLang DSL v1" + ) + + _check("rnd", rnd_required, rnd_explicit) + _check("sat", sat_required, sat_explicit) + _check("part", part_required, part_explicit) + def _require_mask_expr(self, expr: SemanticExpr, context: str) -> SemanticMaskType: if not isinstance(expr.type, SemanticMaskType): raise TypeError(f"{context} must be a mask value in TileLang DSL") diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index cae4d60df..ce5b4379e 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2599,6 +2599,10 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn('rnd = "R"', text) self.assertIn('sat = "SAT"', text) self.assertIn('part = "ODD"', text) + self.assertRegex( + text, + r"= pto\.vcvt %[^,\s]+(?: \{[^}]+\})? : !pto\.vreg<[^>]+> -> !pto\.vreg<[^>]+>", + ) def test_advanced_sort_memory_ops_surface_lower(self) -> None: @pto.vkernel( @@ -2671,6 +2675,86 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn("pto.vcvt rnd must be a VcvtRoundMode enum", str(ctx.exception)) + def test_vcvt_requires_explicit_required_attrs_for_type_pair(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_missing_required_attrs_unique", + dtypes=[(pto.i32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt(vec, pto.i32, src_mask, rnd=pto.VcvtRoundMode.R) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `sat=`", str(ctx.exception)) + + def test_vcvt_rejects_disallowed_attrs_for_type_pair(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_disallowed_attr_unique", + dtypes=[(pto.i32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.i32, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("does not accept `part=`", str(ctx.exception)) + + def test_index_to_float_scalar_cast_lowers_via_integer_bridge(self) -> None: + @pto.vkernel( + op="index_to_float_scalar_cast_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask, _ = pto.make_mask(pto.f32, 1) + vec = pto.vlds(src, 0) + for col in range(0, 1, 1): + scalar = pto.f32(col) + out = pto.vadds(vec, scalar, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_castui", text) + self.assertRegex(text, r"arith\.uitofp %\w+ : i64 to f32") + self.assertNotRegex(text, r"arith\.uitofp %\w+ : index to f32") + def test_extended_integer_vector_ops_surface_lowers(self) -> None: @pto.vkernel( op="extended_integer_vector_ops_unique", From 69b154eb490d5ad1a610375e08cf5afa84dee846 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 09:43:15 +0800 Subject: [PATCH 069/192] chore: drop trowargmax template from issue #63 PR --- lib/TileOps/trowargmax_template.py | 86 ------------------------------ 1 file changed, 86 deletions(-) delete mode 100644 lib/TileOps/trowargmax_template.py diff --git a/lib/TileOps/trowargmax_template.py b/lib/TileOps/trowargmax_template.py deleted file mode 100644 index 660151b3b..000000000 --- a/lib/TileOps/trowargmax_template.py +++ /dev/null @@ -1,86 +0,0 @@ -"""TileLang DSL template for pto.trowargmax""" - -import sys -from pathlib import Path -import tilelang_dsl as pto - -@pto.vkernel( - target="a5", - op="pto.trowargmax", - advanced=True, -) -def template_trowargmax(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): - src_dtype = src.element_type - idx_dtype = dst.element_type - lanes = pto.get_lanes(src_dtype) - valid_rows, valid_cols = src.valid_shape - - # Initialize with negative infinity for ROWARGMAX - if pto.constexpr(src_dtype == pto.f32): - init_val = pto.f32(0xFF800000) # Negative infinity - init_zero = pto.f32(0) - elif pto.constexpr(src_dtype == pto.f16): - init_val = pto.f16(0xFC00) # Negative infinity - init_zero = pto.f16(0) - - # Since index is valid in lane 0, we can use mask_1 - mask_1, _ = pto.make_mask(src_dtype, 1) - mask_1_final, _ = pto.make_mask(idx_dtype, 1) - - for row in range(0, valid_rows, 1): - remained = valid_cols - - v_val_acc = pto.vbr(init_val) - v_idx_acc = pto.vbr(init_zero) - v_zero = pto.vbr(init_zero) - - # Process all column chunks - for col in range(0, valid_cols, lanes): - mask, remained = pto.make_mask(src_dtype, remained) - v_src = pto.vlds(src[row, col:]) - v_reduced = pto.vcmax(v_src, mask) - - v_val, v_idx = pto.vdintlv(v_reduced, v_zero) - - # Add absolute col offset to the chunk's local index - if pto.constexpr(src_dtype == pto.f32): - v_col = pto.f32(col) - elif pto.constexpr(src_dtype == pto.f16): - v_col = pto.f16(col) - - v_idx = pto.vadds(v_idx, v_col, mask_1) - - # Compare current chunk max with global max so far - # vcmp returns a mask - cmp_mask = pto.vcmp(v_val_acc, v_val, mask_1, "lt") - - # Update global max and global argmax depending on who is greater - v_val_acc = pto.vsel(v_val, v_val_acc, cmp_mask) - v_idx_acc = pto.vsel(v_idx, v_idx_acc, cmp_mask) - - # Store the extracted index into the dst tile - if pto.constexpr(src_dtype != idx_dtype): - # vcvt attrs are type-pair sensitive in VPTO verifier: - # - f32 -> i32 requires rnd + sat - # - f16 -> i32 requires part - if pto.constexpr(src_dtype == pto.f32 and idx_dtype == pto.i32): - v_idx_acc_casted = pto.vcvt( - v_idx_acc, - idx_dtype, - mask_1_final, - rnd=pto.VcvtRoundMode.R, - sat=pto.VcvtSatMode.SAT, - ) - elif pto.constexpr(src_dtype == pto.f16 and idx_dtype == pto.i32): - v_idx_acc_casted = pto.vcvt( - v_idx_acc, - idx_dtype, - mask_1_final, - part=pto.VcvtPartMode.ODD, - ) - else: - v_idx_acc_casted = pto.vcvt(v_idx_acc, idx_dtype, mask_1_final) - pto.vsts(v_idx_acc_casted, dst[row, 0:], mask_1_final) - else: - pto.vsts(v_idx_acc, dst[row, 0:], mask_1_final) - return From 1a4d80855ffc87f3c674854d75b9397018922a5b Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 10:44:01 +0800 Subject: [PATCH 070/192] fix(dsl): support module-level literal constants in kernels (#62) --- .../python/tilelang_dsl/frontend_ast.py | 135 ++++++++++++++++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 24 ++++ 2 files changed, 159 insertions(+) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 06a7956d3..3cd23a407 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -208,6 +208,8 @@ class _FrontendBuildContext: selected_op: str | None advanced_enabled: bool inline_procs: dict[str, _FrontendInlineProc] + global_literal_constants: dict[str, Any] + local_bindings: frozenset[str] active_inline_proc_stack: tuple[str, ...] = () vecscope_depth: int = 0 @@ -223,6 +225,8 @@ def nested_vecscope(self) -> "_FrontendBuildContext": selected_op=self.selected_op, advanced_enabled=self.advanced_enabled, inline_procs=self.inline_procs, + global_literal_constants=self.global_literal_constants, + local_bindings=self.local_bindings, active_inline_proc_stack=self.active_inline_proc_stack, vecscope_depth=self.vecscope_depth + 1, ) @@ -234,11 +238,125 @@ def enter_inline_proc(self, name: str, source_info: Any) -> "_FrontendBuildConte selected_op=self.selected_op, advanced_enabled=self.advanced_enabled, inline_procs=self.inline_procs, + global_literal_constants=self.global_literal_constants, + local_bindings=_collect_source_local_bindings(source_info), active_inline_proc_stack=(*self.active_inline_proc_stack, name), vecscope_depth=self.vecscope_depth, ) +_UNSUPPORTED_GLOBAL_LITERAL = object() +_LOCAL_BINDINGS_CACHE: dict[tuple[str, int, str], frozenset[str]] = {} +_GLOBAL_NAME_READS_CACHE: dict[tuple[str, int, str], frozenset[str]] = {} + + +def _iter_target_names(node: ast.AST) -> tuple[str, ...]: + if isinstance(node, ast.Name): + return (node.id,) + if isinstance(node, (ast.Tuple, ast.List)): + names: list[str] = [] + for elt in node.elts: + names.extend(_iter_target_names(elt)) + return tuple(names) + return () + + +def _collect_source_global_name_reads( + source_info: Any, + local_bindings: frozenset[str], +) -> frozenset[str]: + if source_info is None: + return frozenset() + function_def = source_info.function_def + cache_key = ( + source_info.path, + source_info.start_line, + function_def.name, + ) + cached = _GLOBAL_NAME_READS_CACHE.get(cache_key) + if cached is not None: + return cached + + global_reads: set[str] = set() + for node in ast.walk(function_def): + if not isinstance(node, ast.Name) or not isinstance(node.ctx, ast.Load): + continue + if node.id in local_bindings: + continue + if node.id.startswith("__"): + continue + global_reads.add(node.id) + + frozen = frozenset(global_reads) + _GLOBAL_NAME_READS_CACHE[cache_key] = frozen + return frozen + + +def _collect_function_local_bindings(function_def: ast.FunctionDef) -> set[str]: + bindings: set[str] = set() + for arg in function_def.args.posonlyargs: + bindings.add(arg.arg) + for arg in function_def.args.args: + bindings.add(arg.arg) + for arg in function_def.args.kwonlyargs: + bindings.add(arg.arg) + if function_def.args.vararg is not None: + bindings.add(function_def.args.vararg.arg) + if function_def.args.kwarg is not None: + bindings.add(function_def.args.kwarg.arg) + + for node in ast.walk(function_def): + if isinstance(node, ast.Assign): + for target in node.targets: + bindings.update(_iter_target_names(target)) + continue + if isinstance(node, ast.AnnAssign): + bindings.update(_iter_target_names(node.target)) + continue + if isinstance(node, ast.For): + bindings.update(_iter_target_names(node.target)) + continue + if isinstance(node, ast.With): + for item in node.items: + if item.optional_vars is not None: + bindings.update(_iter_target_names(item.optional_vars)) + continue + return bindings + + +def _collect_source_local_bindings(source_info: Any) -> frozenset[str]: + if source_info is None: + return frozenset() + function_def = source_info.function_def + cache_key = ( + source_info.path, + source_info.start_line, + function_def.name, + ) + cached = _LOCAL_BINDINGS_CACHE.get(cache_key) + if cached is not None: + return cached + collected = frozenset(_collect_function_local_bindings(function_def)) + _LOCAL_BINDINGS_CACHE[cache_key] = collected + return collected + + +def _collect_module_literal_constants( + source_info: Any, + *, + module_globals: dict[str, Any] | None, + local_bindings: frozenset[str], +) -> dict[str, Any]: + if source_info is None or module_globals is None: + return {} + literal_constants: dict[str, Any] = {} + for name in _collect_source_global_name_reads(source_info, local_bindings): + value = module_globals.get(name, _UNSUPPORTED_GLOBAL_LITERAL) + if isinstance(value, (bool, int, float, str)): + literal_constants[name] = value + return literal_constants + + def _attach_source_location( frontend_node: FrontendExprNode | FrontendStmtNode, ast_node: ast.AST, @@ -774,6 +892,15 @@ def _build_call_keywords( def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNode: if isinstance(node, ast.Name): + if ( + node.id in context.global_literal_constants + and node.id not in context.local_bindings + ): + return _attach_source_location( + FrontendConstantExpr(value=context.global_literal_constants[node.id]), + node, + context, + ) return _attach_source_location(FrontendNameExpr(name=node.id), node, context) if isinstance(node, ast.Constant): return _attach_source_location(FrontendConstantExpr(value=node.value), node, context) @@ -1241,6 +1368,12 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: for name, spec in descriptor.specializations ) source_info = descriptor._source_info + local_bindings = _collect_source_local_bindings(source_info) + global_literal_constants = _collect_module_literal_constants( + source_info, + module_globals=getattr(descriptor._py_fn, "__globals__", None), + local_bindings=local_bindings, + ) sorted_inline_procs = tuple(sorted(descriptor.inline_procs.items(), key=lambda item: item[0])) context = _FrontendBuildContext( source_info=source_info, @@ -1255,6 +1388,8 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: ) for name, proc in sorted_inline_procs }, + global_literal_constants=global_literal_constants, + local_bindings=local_bindings, ) body = () if source_info is not None: diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index ce5b4379e..fd01fa6da 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -71,6 +71,8 @@ analyze_frontend_kernel, ) +GLOBAL_TILELANG_LITERAL_BLOCK_SIZE = 32 + class TileLangDSLPackageTests(unittest.TestCase): def test_package_exports_surface(self) -> None: @@ -2988,6 +2990,28 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn("= arith.constant 0.0 : f32", text) self.assertIn("pto.vbr", text) + def test_kernel_accepts_module_level_literal_constant_reference(self) -> None: + @pto.vkernel( + op="module_level_literal_constant_reference_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask, _ = pto.make_mask(pto.f32, GLOBAL_TILELANG_LITERAL_BLOCK_SIZE) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.f32(0.0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"arith\.constant 32 : (index|i64)") + self.assertIn("pto.plt_b32", text) + def test_scalar_constructor_call_surfaces_lower(self) -> None: @pto.vkernel( op="scalar_constructor_call_surfaces_unique", From b9782edd3ffcf64e2dc99803673f16f75fc19c34 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 10:59:58 +0800 Subject: [PATCH 071/192] feat(dsl): allow inline_proc capture of module literal globals --- .../python/tilelang_dsl/frontend_ast.py | 16 +++++++++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 28 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 3cd23a407..8aca7fa19 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -204,6 +204,7 @@ class _FrontendInlineProc: @dataclass(frozen=True) class _FrontendBuildContext: source_info: Any + module_globals: dict[str, Any] | None templates: dict[str, dict[str, str]] selected_op: str | None advanced_enabled: bool @@ -221,6 +222,7 @@ def error(self, node: ast.AST, message: str) -> Exception: def nested_vecscope(self) -> "_FrontendBuildContext": return _FrontendBuildContext( source_info=self.source_info, + module_globals=self.module_globals, templates=self.templates, selected_op=self.selected_op, advanced_enabled=self.advanced_enabled, @@ -232,14 +234,21 @@ def nested_vecscope(self) -> "_FrontendBuildContext": ) def enter_inline_proc(self, name: str, source_info: Any) -> "_FrontendBuildContext": + local_bindings = _collect_source_local_bindings(source_info) + global_literal_constants = _collect_module_literal_constants( + source_info, + module_globals=self.module_globals, + local_bindings=local_bindings, + ) return _FrontendBuildContext( source_info=source_info, + module_globals=self.module_globals, templates=self.templates, selected_op=self.selected_op, advanced_enabled=self.advanced_enabled, inline_procs=self.inline_procs, - global_literal_constants=self.global_literal_constants, - local_bindings=_collect_source_local_bindings(source_info), + global_literal_constants=global_literal_constants, + local_bindings=local_bindings, active_inline_proc_stack=(*self.active_inline_proc_stack, name), vecscope_depth=self.vecscope_depth, ) @@ -500,7 +509,7 @@ def _validate_inline_capture( *, context: _FrontendBuildContext, ) -> None: - allowed = param_names | assigned_names + allowed = param_names | assigned_names | set(context.global_literal_constants) if isinstance(stmt, FrontendAssignStmt): missing = _collect_name_reads(stmt.value) - allowed if missing: @@ -1377,6 +1386,7 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: sorted_inline_procs = tuple(sorted(descriptor.inline_procs.items(), key=lambda item: item[0])) context = _FrontendBuildContext( source_info=source_info, + module_globals=getattr(descriptor._py_fn, "__globals__", None), templates=descriptor.templates, selected_op=descriptor.selected_op, advanced_enabled=descriptor.advanced_enabled, diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index fd01fa6da..ecc3534f0 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -72,6 +72,7 @@ ) GLOBAL_TILELANG_LITERAL_BLOCK_SIZE = 32 +INLINE_PROC_GLOBAL_LANE = 0 class TileLangDSLPackageTests(unittest.TestCase): @@ -4417,6 +4418,13 @@ def _inline_capture(dst: pto.Tile): pto.vlds(dst, lane) return None + @pto.inline_proc + def _inline_capture_global_literal(dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(dst, INLINE_PROC_GLOBAL_LANE) + pto.vsts(vec, dst, INLINE_PROC_GLOBAL_LANE, mask) + return None + def test_inline_proc_exports_from_package_surface(self) -> None: self.assertTrue(hasattr(pto, "inline_proc")) self.assertTrue(hasattr(pto, "InlineProcDescriptor")) @@ -4528,6 +4536,26 @@ def kernel(dst: pto.Tile): self.assertIn("implicit capture of 'lane' is not allowed in inline_proc", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) + def test_inline_proc_allows_module_level_literal_capture(self) -> None: + @pto.vkernel(op="inline_proc_global_literal_capture_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + _inline_capture_global_literal(dst) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertIn( + "_inline_capture_global_literal", + {proc.name for proc in frontend_kernel.inline_procs}, + ) + + text = specialized.mlir_text() + self.assertIn("func.call", text) + self.assertIn("arith.constant 0 : index", text) + def test_inline_proc_rejects_kw_only_vararg_and_kwargs(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as kw_only_ctx: From 9c9ba5c24d74d27ed7963635940aef814ca7a066 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 11:21:50 +0800 Subject: [PATCH 072/192] Update dsl user guide --- tilelang-dsl/docs/user_guide/06-control-flow.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang-dsl/docs/user_guide/06-control-flow.md b/tilelang-dsl/docs/user_guide/06-control-flow.md index 1b1944600..f4793acc9 100644 --- a/tilelang-dsl/docs/user_guide/06-control-flow.md +++ b/tilelang-dsl/docs/user_guide/06-control-flow.md @@ -106,7 +106,7 @@ Important semantics: - Helper calls support positional arguments and keyword arguments. - Helper calls can appear in statement and expression positions. - Helper definitions can use trailing `return ` to return values. -- Implicit capture is rejected; pass required values as explicit arguments. +- Implicit capture is rejected except module-level globals whose current bound value is `bool`/`int`/`float`/`str`; pass other required values as explicit arguments. - Recursive/mutually-recursive helper call graphs are rejected. - `*args`, `**kwargs`, and keyword-only parameters are unsupported in current version. From 639498be021f71158483159ddc9a5d426b202cdb Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 12:53:53 +0800 Subject: [PATCH 073/192] tilelang-dsl: support pass as frontend no-op --- .../python/tilelang_dsl/frontend_ast.py | 12 +++++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 6 +++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 31 +++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 8aca7fa19..0db887de1 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -120,6 +120,11 @@ class FrontendStmtNode: """Base class for lowered frontend statements.""" +@dataclass(frozen=True) +class FrontendNoOpStmt(FrontendStmtNode): + pass + + @dataclass(frozen=True) class FrontendAssignStmt(FrontendStmtNode): target: FrontendTargetNode @@ -510,6 +515,8 @@ def _validate_inline_capture( context: _FrontendBuildContext, ) -> None: allowed = param_names | assigned_names | set(context.global_literal_constants) + if isinstance(stmt, FrontendNoOpStmt): + return if isinstance(stmt, FrontendAssignStmt): missing = _collect_name_reads(stmt.value) - allowed if missing: @@ -638,6 +645,8 @@ def _collect_inline_proc_calls_stmt( inline_proc_names: set[str], into: set[str], ) -> None: + if isinstance(stmt, FrontendNoOpStmt): + return if isinstance(stmt, FrontendAssignStmt): _collect_inline_proc_calls_expr(stmt.value, inline_proc_names, into) return @@ -1206,6 +1215,8 @@ def _build_stmt_list(nodes: list[ast.stmt] | tuple[ast.stmt, ...], context: _Fro def _build_stmt(node: ast.stmt, context: _FrontendBuildContext) -> FrontendStmtNode: + if isinstance(node, ast.Pass): + return _attach_source_location(FrontendNoOpStmt(), node, context) if isinstance(node, ast.Assign): if len(node.targets) != 1: raise context.error(node, "multiple assignment targets are not supported in TileLang DSL v1") @@ -1516,6 +1527,7 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: "FrontendKernelNode", "FrontendNameExpr", "FrontendNameTarget", + "FrontendNoOpStmt", "FrontendParameterNode", "FrontendReturnStmt", "FrontendSliceExpr", diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 9129a0d5c..558357db4 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -29,6 +29,7 @@ FrontendKernelNode, FrontendNameExpr, FrontendNameTarget, + FrontendNoOpStmt, FrontendReturnStmt, FrontendSliceExpr, FrontendSourceLocation, @@ -1006,6 +1007,9 @@ def _analyze_stmt_or_inline( *, allow_outer_lookup: bool, ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + if isinstance(stmt, FrontendNoOpStmt): + # Python `pass` lowers to a frontend no-op and does not materialize semantic IR. + return tuple(), dict(env) if ( isinstance(stmt, FrontendExprStmt) and isinstance(stmt.expr, FrontendConstantExpr) @@ -1132,7 +1136,7 @@ def _frontend_stmt_is_scalar_vecscope_stmt( self, stmt: FrontendStmtNode, ) -> bool: - return isinstance(stmt, FrontendAssignStmt) or ( + return isinstance(stmt, FrontendNoOpStmt) or isinstance(stmt, FrontendAssignStmt) or ( isinstance(stmt, FrontendExprStmt) and isinstance(stmt.expr, FrontendCallExpr) and stmt.expr.namespace == "pto" diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index ecc3534f0..fe514c6c7 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -24,8 +24,10 @@ FrontendCallExpr, FrontendExprStmt, FrontendForStmt, + FrontendIfStmt, FrontendStrictVecscopeStmt, FrontendVecscopeStmt, + FrontendNoOpStmt, build_frontend_kernel_node, ) from tilelang_dsl.lowering import AuthoringModule, lower_semantic_kernel @@ -4763,6 +4765,35 @@ def kernel(x: pto.TensorView): self.assertIn("unsupported Python syntax `while`", str(ctx.exception)) self.assertIn(f"{__file__}:", str(ctx.exception)) + def test_pass_statement_builds_frontend_noop_and_compiles(self) -> None: + @pto.vkernel(op="pass_statement_frontend_noop_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + pass + if pto.constexpr(True): + pass + else: + pass + return None + + selected = pto.select_kernel( + "a5", + "pass_statement_frontend_noop_unique", + (pto.f32,), + ) + frontend_kernel = build_frontend_kernel_node(selected) + self.assertIsInstance(frontend_kernel.body[0], FrontendNoOpStmt) + self.assertIsInstance(frontend_kernel.body[1], FrontendIfStmt) + if_stmt = frontend_kernel.body[1] + self.assertTrue(if_stmt.is_constexpr) + self.assertIsInstance(if_stmt.then_body[0], FrontendNoOpStmt) + self.assertIsInstance(if_stmt.else_body[0], FrontendNoOpStmt) + + text = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ).mlir_text() + self.assertIn("return", text) + self.assertNotIn("scf.if", text) + def test_vreg_annotated_assignment_rejects_mismatched_dtype(self) -> None: with self.assertRaises(TypeError) as ctx: From 708d79a1c9d0b4d7bc5d404e0717c8bacb13ad0d Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 11:46:49 +0800 Subject: [PATCH 074/192] align(tilelang-dsl): support PartitionTensorView slice binding and docs --- .../docs/user_guide/05-type-system.md | 39 ++++++- tilelang-dsl/python/tilelang_dsl/lowering.py | 101 +++++++++++++++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 23 +++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 41 +++++++ 4 files changed, 199 insertions(+), 5 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 47e1d104b..820683c5b 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -155,11 +155,12 @@ This replaces string literals (`MemorySpace.GM`/`MemorySpace.UB`) with compile-t ### Public Buffer Types -TileLang uses two public buffer-facing type names in kernel signatures: +TileLang uses three public buffer-facing type names in kernel signatures: | Public Type | Description | |-------------|-------------| | `pto.TensorView` | GM-facing tensor view descriptor used for DMA-oriented data access | +| `pto.PartitionTensorView` | Logical GM partition (slice) descriptor, corresponding to `!pto.partition_tensor_view<...>` | | `pto.Tile` | UB-facing tile buffer value used for tiled compute | ### TensorView Types @@ -243,7 +244,7 @@ partition_5d = tensor_view[ ``` Constraints: -- Slicing returns a new TensorView representing the logical partition. +- Slicing returns a new `pto.PartitionTensorView` representing the logical partition. - The partition must be within the original tensor bounds. - When fewer than 5 slice axes are written, they are right-aligned to the trailing physical axes of the 5D descriptor. - `stop` must be explicit on all dimensions. @@ -252,6 +253,40 @@ Constraints: - Dimension 0 may use `step > 1`. - Dimension 1 must keep `step == 1` in the current DMA-oriented implementation. +### PartitionTensorView Types + +`pto.PartitionTensorView` models a logical partition of GM tensor data and maps to +`!pto.partition_tensor_view` in PTO IR. +Like `TensorView`, it is a descriptor type and does not own storage. + +#### PartitionTensorView Type Definition + +```python +@pto.vkernel(target="a5", op="custom_partition", dtypes=[(pto.f32, pto.f32)]) +def kernel(inp: pto.TensorView, out: pto.TensorView): + part: pto.PartitionTensorView = inp[0:16, 0:16] + p_rows, p_cols = part.shape + s_row, s_col = part.strides + return None +``` + +Important notes: +- A `PartitionTensorView` carries partition `shape` and `strides` metadata in element units. +- Element dtype is inherited from the source tensor view. +- Memory space remains GM. +- Rank handling follows the same right-aligned 5D contract as `TensorView`. +- `PartitionTensorView` can be used where DMA-oriented TensorView-like descriptors are accepted. +- Prefer direct indexing or tuple unpacking for `shape`/`strides` metadata values in current DSL v1 lowering. + +#### PartitionTensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Partition dimensions | +| `element_type` | `Type` | Element data type inherited from source tensor view | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from the base tensor pointer (internal) | + ### Tile Types Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 670b6fc7a..99c1542ff 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -1178,6 +1178,99 @@ def _tensor_slice_extents(self, expr: SemanticTensorSliceExpr) -> tuple[int, int raise NotImplementedError("TileLang DSL v1 DMA lowering currently only supports rank-2 TensorView slices") return expr.type.extents + def _materialize_tensor_slice_axis_size( + self, + slice_axis: object, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if slice_axis.extent is not None: + return _RenderedValue( + name=self._materialize_constant(slice_axis.extent, SemanticIndexType()), + type=SemanticIndexType(), + ) + distance = self._emit_binary_value( + "sub", + self._lower_expr(slice_axis.stop, env, indent=indent, into=into), + self._lower_expr(slice_axis.start, env, indent=indent, into=into), + SemanticIndexType(), + indent=indent, + into=into, + ) + step_value = self._static_expr_value(slice_axis.step, default=1) + if not isinstance(step_value, int) or step_value <= 0: + raise NotImplementedError( + "partition_view lowering currently expects a static positive slice step in TileLang DSL v1" + ) + if step_value == 1: + return distance + numerator = self._emit_binary_value( + "add", + distance, + _RenderedValue( + name=self._materialize_constant(step_value - 1, SemanticIndexType()), + type=SemanticIndexType(), + ), + SemanticIndexType(), + indent=indent, + into=into, + ) + return self._emit_binary_value( + "floordiv", + numerator, + _RenderedValue( + name=self._materialize_constant(step_value, SemanticIndexType()), + type=SemanticIndexType(), + ), + SemanticIndexType(), + indent=indent, + into=into, + ) + + def _lower_tensor_slice_expr( + self, + expr: SemanticTensorSliceExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None, + into: list[str] | None, + ) -> _RenderedValue: + if into is None: + into = [] + tensor_base = self._lower_expr(expr.base, env, indent=indent, into=into) + if not isinstance(tensor_base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + raise NotImplementedError("partition_view lowering expects a TensorView/PartitionTensorView source") + + offsets: list[_RenderedValue] = [] + sizes: list[_RenderedValue] = [] + for axis_slice in expr.slices: + offsets.append(self._lower_expr(axis_slice.start, env, indent=indent, into=into)) + sizes.append( + self._materialize_tensor_slice_axis_size( + axis_slice, + env, + indent=indent, + into=into, + ) + ) + + result_name = desired_name or self._new_temp() + result_type_text = self._render_partition_tensor_view_type( + element_dtype=expr.type.element_dtype.name, + shape=tuple("?" if dim is None else dim for dim in expr.type.extents), + ) + into.append( + self._indent(indent) + + f"{result_name} = pto.partition_view {tensor_base.name}, " + + f"offsets = [{', '.join(value.name for value in offsets)}], " + + f"sizes = [{', '.join(value.name for value in sizes)}] : " + + f"{self._render_type(tensor_base.type)} -> {result_type_text}" + ) + return _RenderedValue(name=result_name, type=_RenderedTextualType(result_type_text)) + def _resolve_dma_load_padding_profile(self, options: object) -> _DmaLoadPaddingProfile: pad_mode_name = self._static_pad_mode_name(getattr(options, "pad_mode", None)) or "PadNull" left_padding = self._static_expr_value(getattr(options, "left_padding", None), default=0) @@ -2365,7 +2458,13 @@ def _lower_expr( if isinstance(expr, SemanticAttributeAccess): raise NotImplementedError("bare shape attribute values are not materialized directly") if isinstance(expr, SemanticTensorSliceExpr): - raise NotImplementedError("TensorView slices are only lowered through DMA statements in TileLang DSL v1") + return self._lower_tensor_slice_expr( + expr, + env, + indent=indent, + desired_name=desired_name, + into=into, + ) if isinstance(expr, SemanticSymbolExpr): raise NotImplementedError("symbol expressions are only lowered through specialized TileLang DSL ops") raise NotImplementedError(f"unsupported semantic expression {type(expr).__name__}") diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 558357db4..8ce40fa6b 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -2442,10 +2442,17 @@ def _bind_assignment_target( if isinstance(target, FrontendNameTarget): if isinstance(value.type, SemanticTupleType): raise ValueError("multi-result call assignment requires tuple binding in TileLang DSL v1") - annotated_type = self._annotation_type(annotation, value.type, env) + inferred_type: SemanticType = value.type + if isinstance(value.type, SemanticTensorSliceType): + # Tensor slicing materializes a logical partition descriptor value in IR. + inferred_type = SemanticPartitionTensorViewType( + element_dtype=value.type.element_dtype, + rank=value.type.rank, + ) + annotated_type = self._annotation_type(annotation, inferred_type, env) binding = self._make_binding( target.name, - annotated_type if annotated_type is not None else value.type, + annotated_type if annotated_type is not None else inferred_type, "ssa", value=self._binding_value_for_expr(value), ) @@ -2547,6 +2554,11 @@ def _annotation_type( return inferred_type if annotation_expr.type.kind == "align_type" and isinstance(inferred_type, SemanticAlignType): return inferred_type + if ( + annotation_expr.type.kind == "partition_tensor_view_type" + and isinstance(inferred_type, SemanticPartitionTensorViewType) + ): + return inferred_type raise TypeError("unsupported annotated assignment type in TileLang DSL v1") def _analyze_annotation_expr( @@ -3049,6 +3061,13 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=align, type=SemanticMetaType(kind="align_type"), ) + if expr.name == "PartitionTensorView": + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=expr.name, + type=SemanticMetaType(kind="partition_tensor_view_type"), + ) if expr.namespace in {"PAT", "pto.PAT", "pto.MaskPattern"}: pattern = _PATTERN_SYMBOLS.get(expr.name) if pattern is None and expr.name.startswith("PAT_"): diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index fe514c6c7..2cb8c6ab1 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -48,6 +48,7 @@ SemanticLowLevelCopyStmt, SemanticMaskType, SemanticPadValueType, + SemanticPartitionTensorViewType, SemanticPipeBarrierStmt, SemanticPtrType, SemanticPredicateStoreStmt, @@ -1887,6 +1888,46 @@ def kernel(inp: pto.TensorView): self.assertEqual(slice_assign.value.type.extents, (16, 32)) self.assertEqual(slice_assign.value.type.physical_axes, (3, 4)) + def test_tensorview_slice_binding_lowers_to_partition_tensor_view_descriptor(self) -> None: + @pto.vkernel(op="tensorview_slice_partition_binding_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + part = inp[0:16, 0:32] + rows, cols = part.shape + s0, s1 = part.strides + if rows != 0 and cols != 0: + rows = s0 + s1 + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertIsInstance(slice_assign.targets[0].type, SemanticPartitionTensorViewType) + self.assertEqual(slice_assign.targets[0].type.rank, 2) + + text = kernel.mlir_text() + self.assertIn(" = pto.partition_view %arg0, offsets = [%c0, %c0], sizes = [%c16, %c32] : ", text) + self.assertIn("-> !pto.partition_tensor_view<16x32xf32>", text) + self.assertEqual(text.count("pto.get_tensor_view_dim"), 2) + self.assertEqual(text.count("pto.get_tensor_view_stride"), 2) + + def test_partition_tensor_view_annotation_accepts_tensorview_slice_binding(self) -> None: + @pto.vkernel(op="partition_tensor_view_annotation_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + part: pto.PartitionTensorView = inp[0:8, 0:8] + r0, r1 = part.shape + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertIsInstance(slice_assign.targets[0].type, SemanticPartitionTensorViewType) + self.assertEqual(slice_assign.targets[0].type.rank, 2) + + text = kernel.mlir_text() + self.assertIn(" = pto.partition_view %arg0, offsets = [%c0, %c0], sizes = [%c8, %c8] : ", text) + self.assertIn("-> !pto.partition_tensor_view<8x8xf32>", text) + self.assertEqual(text.count("pto.get_tensor_view_dim"), 2) + def test_dynamic_tensorview_shape_profile_supports_runtime_bound_without_high_level_dma(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)]) def kernel(inp: pto.TensorView, tile: pto.Tile): From bc501d2dc39a53d18a710b8876942f75823ef8c4 Mon Sep 17 00:00:00 2001 From: chenjinlin Date: Tue, 14 Apr 2026 12:36:12 +0800 Subject: [PATCH 075/192] udpate tload and tstore. --- build.sh | 11 +++ lib/TileOps/tload_template.py | 164 ++++++++++++++++++++++++++++++--- lib/TileOps/tstore_template.py | 151 +++++++++++++++++++++++++++--- run.sh | 1 + 4 files changed, 299 insertions(+), 28 deletions(-) create mode 100755 build.sh create mode 100755 run.sh diff --git a/build.sh b/build.sh new file mode 100755 index 000000000..083c5c8b3 --- /dev/null +++ b/build.sh @@ -0,0 +1,11 @@ +cmake -G Ninja \ + -S . \ + -B build \ + -DLLVM_DIR=/opt/llvm/lib/cmake/llvm \ + -DMLIR_DIR=/opt/llvm/lib/cmake/mlir \ + -DPython3_EXECUTABLE=$(which python3) \ + -DPython3_FIND_STRATEGY=LOCATION \ + -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DMLIR_PYTHON_PACKAGE_DIR=${PTO_INSTALL_DIR} \ + -DCMAKE_INSTALL_PREFIX="$PTO_INSTALL_DIR" && ninja -C build && cmake --install build && cp build ../ -fr diff --git a/lib/TileOps/tload_template.py b/lib/TileOps/tload_template.py index a12ef6675..4c7466fcf 100644 --- a/lib/TileOps/tload_template.py +++ b/lib/TileOps/tload_template.py @@ -1,21 +1,14 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - """`pto.tload` 的 TileLang DSL 模板""" import tilelang_dsl as pto - -def _tload_preconditions(src, dst) -> bool: +def _tload_preconditions_nd2nd(src, dst) -> bool: logical_rows = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[3] - logical_cols = src.shape[4] - return ( - src.rank == 5 + logical_cols = src.shape[4] + return ( + dst.config.b_layout == pto.BLayout.ROW_MAJOR + and dst.config.s_layout == pto.SLayout.NONE_BOX + and src.rank == 5 and src.strides[4] == 1 and dst.valid_shape[0] <= logical_rows and dst.valid_shape[1] <= logical_cols @@ -26,14 +19,44 @@ def _tload_preconditions(src, dst) -> bool: ) +def _tload_preconditions_dn2dn(src, dst) -> bool: + logical_rows = src.shape[3] + logical_cols = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[4] + return ( + dst.config.b_layout != pto.BLayout.ROW_MAJOR + and dst.config.s_layout == pto.SLayout.NONE_BOX + and src.rank == 5 + and src.strides[3] == 1 + and dst.valid_shape[0] <= logical_rows + and dst.valid_shape[1] <= logical_cols + and logical_rows <= dst.shape[0] + and logical_cols <= dst.shape[1] + and dst.valid_shape[0] <= dst.shape[0] + and dst.valid_shape[1] <= dst.shape[1] + ) + +def _tload_preconditions_nz2nz(src, dst) -> bool: + logical_rows = src.shape[2] + return ( + dst.config.b_layout != pto.BLayout.ROW_MAJOR + and dst.config.s_layout == pto.SLayout.ROW_MAJOR + and src.rank == 5 + and dst.valid_shape[0] <= logical_rows + and logical_rows <= dst.shape[0] + and dst.valid_shape[1] <= dst.shape[1] + and dst.valid_shape[0] <= dst.shape[0] + ) + + @pto.vkernel( target="a5", op="pto.tload", advanced=True, - constraints=[_tload_preconditions], + constraints=[_tload_preconditions_nd2nd], ) -def template_tload(src: pto.PartitionTensorView, dst: pto.Tile): +def template_tload_nd2nd(src: pto.TensorView, dst: pto.Tile): dtype = dst.element_type + b_layout = dst.config.b_layout elem_bytes = pto.bytewidth(dtype) g0, g1, g2, g3, g4 = src.shape @@ -89,3 +112,114 @@ def template_tload(src: pto.PartitionTensorView, dst: pto.Tile): if loop1 != 1 or loop2 != 1: pto.set_loop_size_outtoub(loop1=1, loop2=1) return + +@pto.vkernel( + target="a5", + op="pto.tload", + advanced=True, + constraints=[_tload_preconditions_dn2dn], +) +def template_tload_dn2dn(src: pto.TensorView, dst: pto.Tile): + dtype = dst.element_type + elem_bytes = pto.bytewidth(dtype) + + # rank-5 partition view 元信息。 + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + + tile_rows, tile_cols = dst.shape + valid_rows, valid_cols = dst.valid_shape + + n_burst = g4 + len_burst = valid_rows * elem_bytes + gm_stride = s4 * elem_bytes + ub_stride = tile_rows * elem_bytes + + # UB 目标 tile 是列高为 `tile_rows` 的紧凑 col-major 布局, + # 从最内层 `g4 × tile_rows` 块递推出三层阶梯 stride。 + dst_stride2 = g4 * tile_rows + dst_stride1 = g2 * dst_stride2 + dst_stride0 = g1 * dst_stride1 + + # loop1 ↔ g2(内层),loop2 ↔ g1(外层),软件 for ↔ g0。 + loop1 = g2 + loop2 = g1 + loop1_src_stride = s2 * elem_bytes + loop1_dst_stride = dst_stride2 * elem_bytes + loop2_src_stride = s1 * elem_bytes + loop2_dst_stride = dst_stride1 * elem_bytes + + gm_ptr = src.as_ptr() + ub_ptr = dst.as_ptr() + + if loop1 != 1 or loop2 != 1: + pto.set_loop2_stride_outtoub( + src_stride=loop2_src_stride, dst_stride=loop2_dst_stride + ) + pto.set_loop1_stride_outtoub( + src_stride=loop1_src_stride, dst_stride=loop1_dst_stride + ) + pto.set_loop_size_outtoub(loop1=loop1, loop2=loop2) + + for i in range(0, g0, 1): + src_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + dst_i = pto.addptr(ub_ptr, i * dst_stride0 * elem_bytes) + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) + + if loop1 != 1 or loop2 != 1: + pto.set_loop_size_outtoub(loop1=1, loop2=1) + return + +@pto.vkernel( + target="a5", + op="pto.tload", + advanced=True, + constraints=[_tload_preconditions_nz2nz], +) +def template_tload_nz2nz(src: pto.TensorView, dst: pto.Tile): + dtype = dst.element_type + elem_bytes = pto.bytewidth(dtype) + + # rank-5 partition view 元信息。NZ 静态分块约束(g3/g4 与 dtype 的关系) + # 由更高层 schema/static-check 保证,这里只保留运行时搬运公式。 + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + + tile_rows, tile_cols = dst.shape + valid_rows, valid_cols = dst.valid_shape + + c0_size_bytes = 32 + n_burst = g1 + len_burst = valid_rows * c0_size_bytes + gm_stride = s1 * elem_bytes + ub_stride = tile_rows * c0_size_bytes + + # 每个 g0 block 在 UB 中包含 `g1` 个 NZ 小块;每块的列宽是 `g4` elems。 + tile_stride = g1 * tile_rows * g4 + + gm_ptr = src.as_ptr() + ub_ptr = dst.as_ptr() + + # NZ2NZ 对应实现始终走 normal mode,不复用 loop1/loop2 寄存器。 + pto.set_loop_size_outtoub(loop1=1, loop2=1) + for i in range(0, g0, 1): + src_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + dst_i = pto.addptr(ub_ptr, i * tile_stride * elem_bytes) + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) + return diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py index a8d6cc4f6..ceeb46368 100644 --- a/lib/TileOps/tstore_template.py +++ b/lib/TileOps/tstore_template.py @@ -1,21 +1,14 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - """`pto.tstore` 的 TileLang DSL 模板""" import tilelang_dsl as pto - -def _tstore_preconditions(src, dst) -> bool: +def _tstore_preconditions_nd(src, dst) -> bool: logical_rows = dst.shape[0] * dst.shape[1] * dst.shape[2] * dst.shape[3] logical_cols = dst.shape[4] return ( - dst.rank == 5 + src.config.b_layout == pto.BLayout.ROW_MAJOR + and src.config.s_layout == pto.SLayout.NONE_BOX + and src.rank == 5 and dst.strides[4] == 1 and src.valid_shape[0] == logical_rows and src.valid_shape[1] == logical_cols @@ -23,14 +16,40 @@ def _tstore_preconditions(src, dst) -> bool: and src.valid_shape[1] <= src.shape[1] ) +def _tstore_preconditions_dn(src, dst) -> bool: + logical_rows = dst.shape[3] + logical_cols = dst.shape[0] * dst.shape[1] * dst.shape[2] * dst.shape[4] + return ( + src.config.b_layout != pto.BLayout.ROW_MAJOR + and src.config.s_layout == pto.SLayout.NONE_BOX + and src.rank == 5 + and dst.strides[3] == 1 + and src.valid_shape[0] == logical_rows + and src.valid_shape[1] == logical_cols + and src.valid_shape[0] <= src.shape[0] + and src.valid_shape[1] <= src.shape[1] + ) + +def _tstore_preconditions_nz(src, dst) -> bool: + logical_rows = dst.shape[2] * dst.shape[3] + logical_cols = dst.shape[0] * dst.shape[1] * dst.shape[4] + return ( + src.config.b_layout != pto.BLayout.ROW_MAJOR + and src.config.s_layout == pto.SLayout.ROW_MAJOR + and src.rank == 5 + and src.valid_shape[0] == logical_rows + and src.valid_shape[1] == logical_cols + and src.valid_shape[0] <= src.shape[0] + and src.valid_shape[1] <= src.shape[1] + ) @pto.vkernel( target="a5", op="pto.tstore", advanced=True, - constraints=[_tstore_preconditions], + constraints=[_tstore_preconditions_nd], ) -def template_tstore(src: pto.Tile, dst: pto.PartitionTensorView): +def template_tstore_nd(src: pto.Tile, dst: pto.TensorView): dtype = src.element_type elem_bytes = pto.bytewidth(dtype) @@ -86,3 +105,109 @@ def template_tstore(src: pto.Tile, dst: pto.PartitionTensorView): if loop1 != 1 or loop2 != 1: pto.set_loop_size_ubtoout(loop1=1, loop2=1) return + +@pto.vkernel( + target="a5", + op="pto.tstore", + advanced=True, + constraints=[_tstore_preconditions_dn], +) +def template_tstore_dn(src: pto.Tile, dst: pto.TensorView): + dtype = src.element_type + elem_bytes = pto.bytewidth(dtype) + + g0, g1, g2, g3, g4 = dst.shape + s0, s1, s2, s3, s4 = dst.strides + + valid_rows, valid_cols = src.valid_shape + ub_rows, ub_cols = src.shape + + n_burst = g4 + len_burst = valid_rows * elem_bytes + gm_stride = s4 * elem_bytes + ub_stride = ub_rows * elem_bytes + + # UB 源 tile 是列高 `ub_rows` 的紧凑 col-major 布局, + # 与 `TStoreVecDN` 一样由 `g4` / `g2` / `g1` 递推出三级 stride。 + src_stride2 = ub_rows * g4 + src_stride1 = g2 * src_stride2 + src_stride0 = g1 * src_stride1 + + loop1 = g2 + loop2 = g1 + loop1_src_stride = src_stride2 * elem_bytes + loop1_dst_stride = s2 * elem_bytes + loop2_src_stride = src_stride1 * elem_bytes + loop2_dst_stride = s1 * elem_bytes + + ub_ptr = src.as_ptr() + gm_ptr = dst.as_ptr() + + if loop1 != 1 or loop2 != 1: + pto.set_loop2_stride_ubtoout( + src_stride=loop2_src_stride, dst_stride=loop2_dst_stride + ) + pto.set_loop1_stride_ubtoout( + src_stride=loop1_src_stride, dst_stride=loop1_dst_stride + ) + pto.set_loop_size_ubtoout(loop1=loop1, loop2=loop2) + + for i in range(0, g0, 1): + src_i = pto.addptr(ub_ptr, i * src_stride0 * elem_bytes) + dst_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + pto.copy_ubuf_to_gm( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + ) + + if loop1 != 1 or loop2 != 1: + pto.set_loop_size_ubtoout(loop1=1, loop2=1) + return + +@pto.vkernel( + target="a5", + op="pto.tstore", + advanced=True, + constraints=[_tstore_preconditions_nz], +) +def template_tstore_nz(src: pto.Tile, dst: pto.TensorView): + dtype = src.element_type + elem_bytes = pto.bytewidth(dtype) + + g0, g1, g2, g3, g4 = dst.shape + s0, s1, s2, s3, s4 = dst.strides + + valid_rows, valid_cols = src.valid_shape + ub_rows, ub_cols = src.shape + + # 对应 C++ `C0_SIZE_BYTE`。NZ 每个 burst 始终写一个完整 C0 block。 + c0_size_bytes = 32 + n_burst = g1 + len_burst = valid_rows * c0_size_bytes + gm_stride = s1 * elem_bytes + ub_stride = ub_rows * c0_size_bytes + + # 每个 g0 block 在 UB 中由 `g1` 个 NZ block 串接组成。 + tile_stride = g1 * ub_rows * g4 + + ub_ptr = src.as_ptr() + gm_ptr = dst.as_ptr() + + # NZ path 本身不使用 loop1/loop2,主动切回 normal mode 避免继承旧状态。 + pto.set_loop_size_ubtoout(loop1=1, loop2=1) + for i in range(0, g0, 1): + src_i = pto.addptr(ub_ptr, i * tile_stride * elem_bytes) + dst_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + pto.copy_ubuf_to_gm( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + ) + return diff --git a/run.sh b/run.sh new file mode 100755 index 000000000..4dd3579e9 --- /dev/null +++ b/run.sh @@ -0,0 +1 @@ +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -c f32_16x64 From 5fcbdd6cb9f5174f3a1ea56b66c36703468ddc17 Mon Sep 17 00:00:00 2001 From: chenjinlin Date: Wed, 15 Apr 2026 12:25:01 +0800 Subject: [PATCH 076/192] update tload, tstore and related ST. --- build.sh | 11 -- lib/TileOps/tload_template.py | 82 ++++++----- lib/TileOps/tstore_template.py | 71 ++++++---- run.sh | 1 - .../npu/a5/src/st/testcase/CMakeLists.txt | 1 + .../a5/src/st/testcase/tload/CMakeLists.txt | 1 + .../npu/a5/src/st/testcase/tload/cases.py | 34 +++++ .../npu/a5/src/st/testcase/tload/compare.py | 47 +++++++ .../npu/a5/src/st/testcase/tload/gen_data.py | 26 ++++ .../npu/a5/src/st/testcase/tload/launch.cpp | 29 ++++ .../npu/a5/src/st/testcase/tload/main.cpp | 127 +++++++++++++++++ .../npu/a5/src/st/testcase/tload/tload.pto | 128 ++++++++++++++++++ 12 files changed, 481 insertions(+), 77 deletions(-) delete mode 100755 build.sh delete mode 100755 run.sh create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tload/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto diff --git a/build.sh b/build.sh deleted file mode 100755 index 083c5c8b3..000000000 --- a/build.sh +++ /dev/null @@ -1,11 +0,0 @@ -cmake -G Ninja \ - -S . \ - -B build \ - -DLLVM_DIR=/opt/llvm/lib/cmake/llvm \ - -DMLIR_DIR=/opt/llvm/lib/cmake/mlir \ - -DPython3_EXECUTABLE=$(which python3) \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DMLIR_PYTHON_PACKAGE_DIR=${PTO_INSTALL_DIR} \ - -DCMAKE_INSTALL_PREFIX="$PTO_INSTALL_DIR" && ninja -C build && cmake --install build && cp build ../ -fr diff --git a/lib/TileOps/tload_template.py b/lib/TileOps/tload_template.py index 4c7466fcf..b04f37840 100644 --- a/lib/TileOps/tload_template.py +++ b/lib/TileOps/tload_template.py @@ -2,49 +2,57 @@ import tilelang_dsl as pto +def _match_tile_layout(dst, *, row_major: bool, s_layout) -> bool: + b_layout_ok = ( + dst.config.b_layout == pto.BLayout.ROW_MAJOR + if row_major + else dst.config.b_layout != pto.BLayout.ROW_MAJOR + ) + return b_layout_ok and dst.config.s_layout == s_layout + + +def _check_load_bounds(src, dst, *, logical_rows, logical_cols=None, stride_axis=None) -> bool: + if src.rank != 5: + return False + if stride_axis is not None and src.strides[stride_axis] != 1: + return False + if dst.valid_shape[0] > logical_rows or logical_rows > dst.shape[0]: + return False + if dst.valid_shape[0] > dst.shape[0]: + return False + if logical_cols is not None: + if dst.valid_shape[1] > logical_cols or logical_cols > dst.shape[1]: + return False + if dst.valid_shape[1] > dst.shape[1]: + return False + return True + + def _tload_preconditions_nd2nd(src, dst) -> bool: logical_rows = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[3] - logical_cols = src.shape[4] - return ( - dst.config.b_layout == pto.BLayout.ROW_MAJOR - and dst.config.s_layout == pto.SLayout.NONE_BOX - and src.rank == 5 - and src.strides[4] == 1 - and dst.valid_shape[0] <= logical_rows - and dst.valid_shape[1] <= logical_cols - and logical_rows <= dst.shape[0] - and logical_cols <= dst.shape[1] - and dst.valid_shape[0] <= dst.shape[0] - and dst.valid_shape[1] <= dst.shape[1] + logical_cols = src.shape[4] + return _match_tile_layout( + dst, row_major=True, s_layout=pto.SLayout.NONE_BOX + ) and _check_load_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols, stride_axis=4 ) def _tload_preconditions_dn2dn(src, dst) -> bool: logical_rows = src.shape[3] logical_cols = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[4] - return ( - dst.config.b_layout != pto.BLayout.ROW_MAJOR - and dst.config.s_layout == pto.SLayout.NONE_BOX - and src.rank == 5 - and src.strides[3] == 1 - and dst.valid_shape[0] <= logical_rows - and dst.valid_shape[1] <= logical_cols - and logical_rows <= dst.shape[0] - and logical_cols <= dst.shape[1] - and dst.valid_shape[0] <= dst.shape[0] - and dst.valid_shape[1] <= dst.shape[1] + return _match_tile_layout( + dst, row_major=False, s_layout=pto.SLayout.NONE_BOX + ) and _check_load_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols, stride_axis=3 ) def _tload_preconditions_nz2nz(src, dst) -> bool: logical_rows = src.shape[2] - return ( - dst.config.b_layout != pto.BLayout.ROW_MAJOR - and dst.config.s_layout == pto.SLayout.ROW_MAJOR - and src.rank == 5 - and dst.valid_shape[0] <= logical_rows - and logical_rows <= dst.shape[0] - and dst.valid_shape[1] <= dst.shape[1] - and dst.valid_shape[0] <= dst.shape[0] + return _match_tile_layout( + dst, row_major=False, s_layout=pto.SLayout.ROW_MAJOR + ) and _check_load_bounds( + src, dst, logical_rows=logical_rows ) @@ -54,9 +62,8 @@ def _tload_preconditions_nz2nz(src, dst) -> bool: advanced=True, constraints=[_tload_preconditions_nd2nd], ) -def template_tload_nd2nd(src: pto.TensorView, dst: pto.Tile): +def template_tload_nd2nd(src: pto.PartitionTensorView, dst: pto.Tile): dtype = dst.element_type - b_layout = dst.config.b_layout elem_bytes = pto.bytewidth(dtype) g0, g1, g2, g3, g4 = src.shape @@ -119,7 +126,7 @@ def template_tload_nd2nd(src: pto.TensorView, dst: pto.Tile): advanced=True, constraints=[_tload_preconditions_dn2dn], ) -def template_tload_dn2dn(src: pto.TensorView, dst: pto.Tile): +def template_tload_dn2dn(src: pto.PartitionTensorView, dst: pto.Tile): dtype = dst.element_type elem_bytes = pto.bytewidth(dtype) @@ -184,10 +191,15 @@ def template_tload_dn2dn(src: pto.TensorView, dst: pto.Tile): advanced=True, constraints=[_tload_preconditions_nz2nz], ) -def template_tload_nz2nz(src: pto.TensorView, dst: pto.Tile): +def template_tload_nz2nz(src: pto.PartitionTensorView, dst: pto.Tile): dtype = dst.element_type elem_bytes = pto.bytewidth(dtype) + # set padding value for ub tile if needed + # enable_ub_pad = dst.config.pad_value is not pto.PadValue.NULL + # if enable_ub_pad: + # pto.set_mov_pad_val(pad_value=dst.config.pad_value.eval()) + # rank-5 partition view 元信息。NZ 静态分块约束(g3/g4 与 dtype 的关系) # 由更高层 schema/static-check 保证,这里只保留运行时搬运公式。 g0, g1, g2, g3, g4 = src.shape diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py index ceeb46368..7857a1544 100644 --- a/lib/TileOps/tstore_template.py +++ b/lib/TileOps/tstore_template.py @@ -2,45 +2,56 @@ import tilelang_dsl as pto +def _match_store_tile_layout(src, *, row_major: bool, s_layout) -> bool: + b_layout_ok = ( + src.config.b_layout == pto.BLayout.ROW_MAJOR + if row_major + else src.config.b_layout != pto.BLayout.ROW_MAJOR + ) + return b_layout_ok and src.config.s_layout == s_layout + + +def _check_store_bounds(src, dst, *, logical_rows, logical_cols, stride_axis=None) -> bool: + if dst.rank != 5: + return False + if stride_axis is not None and dst.strides[stride_axis] != 1: + return False + if src.valid_shape[0] != logical_rows: + return False + if src.valid_shape[1] != logical_cols: + return False + if src.valid_shape[0] > src.shape[0]: + return False + if src.valid_shape[1] > src.shape[1]: + return False + return True + + def _tstore_preconditions_nd(src, dst) -> bool: logical_rows = dst.shape[0] * dst.shape[1] * dst.shape[2] * dst.shape[3] logical_cols = dst.shape[4] - return ( - src.config.b_layout == pto.BLayout.ROW_MAJOR - and src.config.s_layout == pto.SLayout.NONE_BOX - and src.rank == 5 - and dst.strides[4] == 1 - and src.valid_shape[0] == logical_rows - and src.valid_shape[1] == logical_cols - and src.valid_shape[0] <= src.shape[0] - and src.valid_shape[1] <= src.shape[1] + return _match_store_tile_layout( + src, row_major=True, s_layout=pto.SLayout.NONE_BOX + ) and _check_store_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols, stride_axis=4 ) - + def _tstore_preconditions_dn(src, dst) -> bool: logical_rows = dst.shape[3] logical_cols = dst.shape[0] * dst.shape[1] * dst.shape[2] * dst.shape[4] - return ( - src.config.b_layout != pto.BLayout.ROW_MAJOR - and src.config.s_layout == pto.SLayout.NONE_BOX - and src.rank == 5 - and dst.strides[3] == 1 - and src.valid_shape[0] == logical_rows - and src.valid_shape[1] == logical_cols - and src.valid_shape[0] <= src.shape[0] - and src.valid_shape[1] <= src.shape[1] + return _match_store_tile_layout( + src, row_major=False, s_layout=pto.SLayout.NONE_BOX + ) and _check_store_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols, stride_axis=3 ) def _tstore_preconditions_nz(src, dst) -> bool: logical_rows = dst.shape[2] * dst.shape[3] logical_cols = dst.shape[0] * dst.shape[1] * dst.shape[4] - return ( - src.config.b_layout != pto.BLayout.ROW_MAJOR - and src.config.s_layout == pto.SLayout.ROW_MAJOR - and src.rank == 5 - and src.valid_shape[0] == logical_rows - and src.valid_shape[1] == logical_cols - and src.valid_shape[0] <= src.shape[0] - and src.valid_shape[1] <= src.shape[1] + return _match_store_tile_layout( + src, row_major=False, s_layout=pto.SLayout.ROW_MAJOR + ) and _check_store_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols ) @pto.vkernel( @@ -49,7 +60,7 @@ def _tstore_preconditions_nz(src, dst) -> bool: advanced=True, constraints=[_tstore_preconditions_nd], ) -def template_tstore_nd(src: pto.Tile, dst: pto.TensorView): +def template_tstore_nd(src: pto.Tile, dst: pto.PartitionTensorView): dtype = src.element_type elem_bytes = pto.bytewidth(dtype) @@ -112,7 +123,7 @@ def template_tstore_nd(src: pto.Tile, dst: pto.TensorView): advanced=True, constraints=[_tstore_preconditions_dn], ) -def template_tstore_dn(src: pto.Tile, dst: pto.TensorView): +def template_tstore_dn(src: pto.Tile, dst: pto.PartitionTensorView): dtype = src.element_type elem_bytes = pto.bytewidth(dtype) @@ -174,7 +185,7 @@ def template_tstore_dn(src: pto.Tile, dst: pto.TensorView): advanced=True, constraints=[_tstore_preconditions_nz], ) -def template_tstore_nz(src: pto.Tile, dst: pto.TensorView): +def template_tstore_nz(src: pto.Tile, dst: pto.PartitionTensorView): dtype = src.element_type elem_bytes = pto.bytewidth(dtype) diff --git a/run.sh b/run.sh deleted file mode 100755 index 4dd3579e9..000000000 --- a/run.sh +++ /dev/null @@ -1 +0,0 @@ -python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -c f32_16x64 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 74fcc7932..c0c522221 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -114,6 +114,7 @@ endfunction() # -------------------------------------------------------------------------- set(ALL_TESTCASES tadd + tload ) if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tload/CMakeLists.txt new file mode 100644 index 000000000..43d8d8260 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/CMakeLists.txt @@ -0,0 +1 @@ +pto_tilelang_vec_st(tload) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py new file mode 100644 index 000000000..d1117e24e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py @@ -0,0 +1,34 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np + +CASES = [ + { + "name": "nd_f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "dn_f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "nz_f32_128x128", + "dtype": np.float32, + "shape": (128, 128), + "valid_shape": (128, 128), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py new file mode 100644 index 000000000..4ceccdde0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py new file mode 100644 index 000000000..fc1b88759 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py @@ -0,0 +1,26 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + + input_arr = np.random.randint(1, 17, size=shape).astype(dtype) + golden = input_arr.copy() + + save_case_data(case["name"], {"input": input_arr, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp new file mode 100644 index 000000000..bab1b88d8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TLOAD_ND_f32_16x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_DN_f32_16x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_NZ_f32_128x128(__gm__ float *src, __gm__ float *dst); + +void LaunchTLOAD_ND_f32_16x64(float *src, float *dst, void *stream) { + TLOAD_ND_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_DN_f32_16x64(float *src, float *dst, void *stream) { + TLOAD_DN_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_NZ_f32_128x128(float *src, float *dst, void *stream) { + TLOAD_NZ_f32_128x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp new file mode 100644 index 000000000..3984d9150 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp @@ -0,0 +1,127 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tload/tstore ST. +// Each case performs a GM -> Tile -> GM round trip and compare.py checks that +// output.bin matches input.bin exactly for the requested layout. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTLOAD_ND_f32_16x64(float *src, float *dst, void *stream); +void LaunchTLOAD_DN_f32_16x64(float *src, float *dst, void *stream); +void LaunchTLOAD_NZ_f32_128x128(float *src, float *dst, void *stream); + +using LaunchFn = void (*)(float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; + size_t cols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"nd_f32_16x64", LaunchTLOAD_ND_f32_16x64, 16, 64, sizeof(float)}, + {"dn_f32_16x64", LaunchTLOAD_DN_f32_16x64, 16, 64, sizeof(float)}, + {"nz_f32_128x128", LaunchTLOAD_NZ_f32_128x128, 128, 128, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (%zux%zu) ===\n", tc.name, tc.rows, tc.cols); + + std::string caseDir = std::string("./") + tc.name; + size_t inputFileSize = fileSize; + + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), inputFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(srcDevice, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto b/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto new file mode 100644 index 000000000..a612f12ea --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tload + pto.tstore round-trip coverage. +// Each kernel only performs GM -> Tile -> GM, so the testcase validates the +// DMA layout path directly for ND, DN, and NZ vector tiles. + +module { + func.func @TLOAD_ND_f32_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + func.func @TLOAD_DN_f32_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c1, %c16] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c1, %c16] + : !pto.tensor_view<1x1x1x16x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + func.func @TLOAD_NZ_f32_128x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c16, %c1, %c128, %c1, %c8], + strides = [%c1024, %c1024, %c8, %c8, %c1] + : !pto.tensor_view<16x1x128x1x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c16, %c1, %c128, %c1, %c8], + strides = [%c1024, %c1024, %c8, %c8, %c1] + : !pto.tensor_view<16x1x128x1x8xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c16, %c1, %c128, %c1, %c8] + : !pto.tensor_view<16x1x128x1x8xf32> -> !pto.partition_tensor_view<16x1x128x1x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c16, %c1, %c128, %c1, %c8] + : !pto.tensor_view<16x1x128x1x8xf32> -> !pto.partition_tensor_view<16x1x128x1x8xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<16x1x128x1x8xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<16x1x128x1x8xf32>) + return + } +} From e623e6352021a46d424eb0a325f2c34dec4c7d31 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 14 Apr 2026 20:27:53 +0800 Subject: [PATCH 077/192] Add a skill to resolve DSL issue --- .codex/skills/resolve-dsl-issue/SKILL.md | 263 +++++++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 .codex/skills/resolve-dsl-issue/SKILL.md diff --git a/.codex/skills/resolve-dsl-issue/SKILL.md b/.codex/skills/resolve-dsl-issue/SKILL.md new file mode 100644 index 000000000..52752c0db --- /dev/null +++ b/.codex/skills/resolve-dsl-issue/SKILL.md @@ -0,0 +1,263 @@ +--- +name: resolve-dsl-issue +description: 根据用户提供的 issue 链接,提取 DSL 与 PTO IR 复现最小用例,运行 PTOAS 复现并分析日志,在用户指导下完成修复、提交并自动创建关联 issue 的 PR。 +--- + +# Resolve DSL Issue + +当任务满足以下任一条件时使用本 skill: +- 用户明确提供了要处理的 issue 链接 +- 用户希望“按 issue 内容复现 DSL 问题并定位根因” + +不建议作为主入口的场景: +- 仅做编译/构建,不涉及 issue 复现 +- 仅做 NPU 运行验证,不涉及 DSL/PTO IR 复现 + +## 目标 + +从 issue 中抽取可执行复现输入(DSL + PTO IR),在仓库内构造最小复现并定位根因;在用户确认修复方向后完成代码修复、验证、提交,并自动创建关联原始 issue 的 PR。 + +## 前置条件 + +- 当前目录是 PTOAS 仓库根目录 +- `build/` 目录可写 +- `ptoas` 可执行(已在 PATH 或有明确绝对路径) +- 能访问 issue 内容(网页、API、或用户粘贴) +- 若需要自动创建 PR:`gh` CLI 已安装并登录(`gh auth status` 成功) + +## 标准流程 + +### 1. 解析 issue,提取两个代码片段 + +必须提取到两类片段: +- DSL 代码片段(`.py`) +- PTO IR 代码片段(`.pto`)(如果是纯DSL前端问题,可以不需要 PTO IR) + +推荐提取顺序: +1. issue 正文 +2. issue 评论 +3. issue 附件/粘贴内容 + +如果任一片段缺失,停止后续复现,直接在 issue 请求补充(模板见“评论模板”)。 + +### 2. 在仓库中落盘复现文件 + +文件位置固定为: +- DSL: `lib/TileOps/` +- PTO IR: `test/dsl/` + +命名建议使用 issue 编号,避免冲突,例如: +- `lib/TileOps/issue__repro.py` +- `test/dsl/issue__repro.pto` + +要求: +- 原样写入,避免“自动修复”代码导致偏离用户输入 +- 保留 issue 中的关键注释和输入形状信息 + +### 3. 执行编译并保存日志 + +标准命令: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --vpto-emit-hivm-llvm &> +``` + +推荐日志路径: +- `build/issue__repro.log` + +示例: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --vpto-emit-hivm-llvm \ + test/dsl/issue_1234_repro.pto \ + &> build/issue_1234_repro.log +``` + +### 4. 分析日志,判断是否复现 + +先定位关键错误信息(error/fatal/assert/traceback),再判断是否与 issue 描述一致。 + +推荐快速检索: + +```bash +rg -n "error|fatal|assert|traceback|failed" build/issue__repro.log +``` + +分支处理: +- 未复现:在 issue 中请求更完整的复现信息(环境、命令、输入、预期/实际) +- 已复现:进入根因定位 + +### 5. 根因定位与方案建议 + +定位输出应至少包含: +- 触发错误的阶段(前端解析/TileOp 展开/Lowering/LLVM 发射等) +- 直接触发点(具体报错行、pass、或输入约束不满足) +- 根因判断(1-2 条最可能原因,标注置信度) +- 修复建议(最小改动优先) + +如果无法在当前上下文完成修复实现,也需要给出: +- 建议修改文件范围 +- 建议新增/补充的测试用例 + +### 6. 与用户确认修复方向(必须) + +在进入代码修改前,先向用户同步: +- 复现文件路径 +- 复现命令 +- 关键报错摘要 +- 根因与建议 +- 待确认项(如环境差异) + +只有在用户明确同意修复方向后,才进入第 7 步。 + +### 7. 实施修复并本地验证 + +修复要求: +- 仅改动与该 issue 直接相关的最小文件集合 +- 优先补充或更新回归测试(如 `test/dsl` 相关用例) +- 保留复现输入,避免把“复现文件”误删 + +验证要求: +- 至少重新执行一次复现命令,确认错误消失或行为符合预期 +- 将关键验证日志保存到 `build/issue__fix_verify.log` +- 跑一次完整的dsl测试集,确认无其他回归 + +### 8. 提交代码(在用户确认后执行) + +分支命名建议: +- `fix/issue--dsl` + +提交信息建议(至少包含 issue 编号): +- `fix(dsl): <简要修复描述> (#)` + +示例命令: + +```bash +git checkout -b fix/issue_1234_dsl +git add +git commit -m "fix(dsl): handle in (#1234)" +git push -u origin fix/issue_1234_dsl +``` + +### 9. 自动创建 PR 并关联原始 issue + +目标仓库:https://github.com/mouliangyu/PTOAS/ +目标分支:feature-vpto-backend + +关联规则(GitHub): +- 在 PR 描述中包含 `Closes #` 或 `Fixes #` +- 若是跨仓库 issue,使用 `Closes /#` +- 合并后删除分支 + +推荐使用 `gh pr create`: + +```bash +gh pr create \ + --base main \ + --head fix/issue_1234_dsl \ + --title "fix(dsl): <简要修复标题>" \ + --body "$(cat <<'EOF' +## Summary +- <修改点1> +- <修改点2> + +## Repro +- issue: #1234 +- repro cmd: `ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --vpto-emit-hivm-llvm test/dsl/issue_1234_repro.pto` + +## Validation +- <验证命令/结果> + +Closes #1234 +EOF +)" +``` + +创建 PR 后需要回填: +- PR 链接 +- 关联 issue 语句是否生效(是否显示 “linked issues”) + +### 10. 结果同步并等待 review + +向用户同步: +- 修复文件列表 +- 提交 hash +- PR 链接 +- 关联 issue 状态 +- 后续待办(例如 reviewer 关注点) + +## 评论模板 + +### 模板 A:缺少 DSL 或 PTO IR 片段 + +```text +为了准确复现该问题,还需要完整的最小复现输入。请补充以下两段代码: +1) DSL Python 片段(可直接运行到生成该 PTO 的部分) +2) 对应的 PTO IR 片段(完整函数/入口,不要省略关键上下文) + +建议同时提供:执行命令、实际报错、期望行为。 +``` + +### 模板 B:当前未复现 + +```text +我已按当前 issue 信息完成复现尝试,但暂未在本地复现相同报错。 +请补充以下信息以便继续定位: +1) 完整执行命令(含所有 flags) +2) 运行环境(分支/commit、CANN 版本、是否自定义环境变量) +3) 实际报错全文(建议粘贴日志片段) +4) 期望结果与当前结果差异 +``` + +### 模板 C:已复现并给出建议 + +```text +已使用 issue 中输入复现成功,关键报错位于:<阶段/文件/日志行号>。 +初步根因:<根因描述>。 +建议修复:<最小修复方案>。 + +如果你同意该方向,我会继续补充对应测试并提交修复实现供 review。 +``` + +### 模板 D:修复完成,准备提交与开 PR + +```text +修复已完成并通过本地验证。 +计划执行: +1) 提交分支:fix/issue--dsl +2) 创建 PR 并在描述中添加 `Closes #` 自动关联 issue + +请确认是否按该方案提交并创建 PR。 +``` + +### 模板 E:PR 已创建并关联 issue + +```text +PR 已创建: +已在 PR 描述中添加 `Closes #`,原始 issue 已自动关联。 + +本次提交: +- Commit: +- 关键修改: +- 验证结果: +``` + +## 执行注意事项 + +- 不要在未确认复现之前改动用户原始输入语义 +- 优先保留最小复现,不做无关重构 +- 若 issue 信息不完整,先补信息再继续,不要猜测输入 +- 日志分析时优先使用首次错误点,不要只看最后一行报错 +- 未经用户确认,不要直接执行 `git commit`、`git push`、`gh pr create` +- PR 关联语句建议统一放在 PR body 末尾,避免被模板覆盖 +- 若 `gh` 未登录或无权限,输出完整 PR 标题/body 草稿供用户手动创建 + +## 最终输出格式(给用户) + +建议按以下顺序输出: +1. 是否复现成功 +2. 复现文件路径与命令 +3. 日志关键错误(1-3 条) +4. 根因判断 +5. 修复建议与下一步计划 +6.(若完成修复)提交信息、PR 链接、issue 关联状态 From b962d93f30074c80d477fb7bc82bd460f5703b49 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 20:57:35 +0800 Subject: [PATCH 078/192] fix(tilelang-dsl): emit stable float bit-pattern constants --- tilelang-dsl/python/tilelang_dsl/lowering.py | 54 ++++++++++++++++++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 23 +++++++-- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 99c1542ff..3c361dccf 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -10,7 +10,9 @@ from __future__ import annotations +import math import re +import struct from dataclasses import dataclass from .semantic import ( @@ -3546,11 +3548,63 @@ def _format_constant(self, value: object, ty: SemanticType) -> str: if isinstance(ty, SemanticIndexType): return str(value) if isinstance(ty, SemanticScalarType): + if ty.dtype.name in {"f16", "bf16", "f32"} and isinstance( + value, (bool, int, float) + ): + return self._format_float_constant(float(value), ty.dtype.name) if ty.dtype.name == "i1" and isinstance(value, bool): return "1" if value else "0" return str(value) raise NotImplementedError(f"unsupported constant type {ty!r}") + def _format_float_constant(self, value: float, dtype_name: str) -> str: + # Emit stable bit-pattern literals for values that are parse-sensitive + # (`inf`/`nan`) or sign-sensitive (`-0.0`). + if math.isnan(value): + return self._float_nan_bit_pattern(dtype_name) + if math.isinf(value): + sign_bit = value < 0.0 + return self._float_inf_bit_pattern(dtype_name, sign_bit=sign_bit) + if value == 0.0 and math.copysign(1.0, value) < 0.0: + return self._float_to_bit_pattern_literal(value, dtype_name) + return str(value) + + def _float_nan_bit_pattern(self, dtype_name: str) -> str: + if dtype_name == "f16": + return "0x7E00" + if dtype_name == "bf16": + return "0x7FC0" + if dtype_name == "f32": + return "0x7FC00000" + raise NotImplementedError( + f"unsupported float dtype {dtype_name!r} for NaN constant emission" + ) + + def _float_inf_bit_pattern(self, dtype_name: str, *, sign_bit: bool) -> str: + if dtype_name == "f16": + return "0xFC00" if sign_bit else "0x7C00" + if dtype_name == "bf16": + return "0xFF80" if sign_bit else "0x7F80" + if dtype_name == "f32": + return "0xFF800000" if sign_bit else "0x7F800000" + raise NotImplementedError( + f"unsupported float dtype {dtype_name!r} for inf constant emission" + ) + + def _float_to_bit_pattern_literal(self, value: float, dtype_name: str) -> str: + if dtype_name == "f16": + bits = struct.unpack(">H", struct.pack(">e", value))[0] + return f"0x{bits:04X}" + if dtype_name == "bf16": + bits = struct.unpack(">I", struct.pack(">f", value))[0] >> 16 + return f"0x{bits:04X}" + if dtype_name == "f32": + bits = struct.unpack(">I", struct.pack(">f", value))[0] + return f"0x{bits:08X}" + raise NotImplementedError( + f"unsupported float dtype {dtype_name!r} for bit-pattern emission" + ) + def _render_binary_op(self, op: str, ty: SemanticType) -> str: if isinstance(ty, SemanticIndexType): if op == "add": diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 2cb8c6ab1..a3885edc4 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3116,11 +3116,24 @@ def kernel(inp: pto.TensorView): return None text = kernel.mlir_text() - self.assertIn("= arith.constant -inf : f16", text) - self.assertIn("= arith.constant inf : bf16", text) - self.assertIn("= arith.constant nan : f32", text) - self.assertIn("= arith.constant -inf : bf16", text) - self.assertIn("= arith.constant -inf : f32", text) + self.assertIn("= arith.constant 0xFC00 : f16", text) + self.assertIn("= arith.constant 0x7F80 : bf16", text) + self.assertIn("= arith.constant 0x7FC00000 : f32", text) + self.assertIn("= arith.constant 0xFF80 : bf16", text) + self.assertIn("= arith.constant 0xFF800000 : f32", text) + + def test_scalar_constructor_emits_negative_zero_as_stable_bit_pattern(self) -> None: + @pto.vkernel(op="scalar_constructor_negative_zero_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + a = pto.f16(-0.0) + b = pto.bf16(-0.0) + c = pto.f32(-0.0) + return None + + text = kernel.mlir_text() + self.assertIn("= arith.constant 0x8000 : f16", text) + self.assertIn("= arith.constant 0x8000 : bf16", text) + self.assertIn("= arith.constant 0x80000000 : f32", text) def test_scalar_constructor_rejects_bad_arity(self) -> None: @pto.vkernel(op="scalar_constructor_bad_arity_no_arg_unique", dtypes=[(pto.f32,)]) From f98ed53b54ef6c39ac24cd6113893af56eb33d1f Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 15 Apr 2026 23:14:47 +0800 Subject: [PATCH 079/192] fix(tilelang-dsl): enforce explicit pointer surface for psts --- .../user_guide/09-vector-memory-operations.md | 38 +++++++-------- .../user_guide/10-predicate-operations.md | 21 ++++++--- tilelang-dsl/python/tilelang_dsl/semantic.py | 47 ++++++++----------- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 19 ++++++++ 4 files changed, 72 insertions(+), 53 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md index 3355f9932..e1699f22d 100644 --- a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -599,32 +599,32 @@ def generic_store(src: pto.Tile, dst: pto.Tile): pto.vsts(vec, dst[i, j:], all_mask) # No manual offset calculation ``` -#### `pto.psts(mask: MaskType, buf: ptr, offset: Index) -> None` [Advanced Tier] -#### `pto.psts(mask: MaskType, tile[row, col:]) -> None` -#### `pto.psts(mask: MaskType, tile[start:]) -> None` +#### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] -**Description**: Predicate store to buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. +**Description**: Predicate store (`pto.psts`) writes the packed payload represented by +`MaskType` to UB memory. This is the dynamic-offset form of the VPTO predicate-store +family (`psts` vs `psti`): the payload semantics are identical, and only the offset +delivery form differs. -**Parameters (byte-offset syntax)**: +**Parameters (advanced byte-offset syntax)**: | Parameter | Type | Description | |-----------|------|-------------| -| `mask` | `MaskType` | Mask to store | -| `buf` | `ptr` | Pointer to destination buffer (Advanced mode only - requires explicit pointer) | -| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate payload to store | +| `buf` | `ptr` | Pointer to destination UB buffer (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Runtime offset (`index`) | +| `dist` | `PredicateDist` | Predicate distribution mode. Use `PredicateDist.NORM` or `PredicateDist.PK` (default: `PredicateDist.NORM`). | -**Parameters (element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Mask to store | -| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +**Returns**: None (side-effect operation) -**Parameters (1D element-indexing syntax)**: -| Parameter | Type | Description | -|-----------|------|-------------| -| `mask` | `MaskType` | Mask to store | -| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +**DIST semantics (VPTO-aligned)**: +- `"NORM"`: store packed predicate payload into a normal destination space of size `VL/8`. +- `"PK"`: store packed predicate payload into a destination space of size `VL/16`, keeping one bit out of every two bits. -**Returns**: None (side-effect operation) +**Notes**: +- `pto.psts` is intentionally documented as explicit `buf + offset` surface in DSL v1. +- Packed predicate payload layout is bit-level (`VL/8` or `VL/16`), so tile element-indexing is not part of the stable Basic Tier contract. +- The pointer + offset form maps directly to explicit `base[offset]`. +- Authoritative predicate-memory-family semantics are documented in `10-predicate-operations.md`. #### `pto.vsst(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] #### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` diff --git a/tilelang-dsl/docs/user_guide/10-predicate-operations.md b/tilelang-dsl/docs/user_guide/10-predicate-operations.md index c0b959943..d0d64f0c0 100644 --- a/tilelang-dsl/docs/user_guide/10-predicate-operations.md +++ b/tilelang-dsl/docs/user_guide/10-predicate-operations.md @@ -412,22 +412,31 @@ mask = pto.pld(buf, offset, PredicateDist.NORM) mask = pto.pldi(buf, 0, PredicateDist.NORM) ``` -#### `pto.psts(mask: MaskType, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] -**Description**: Stores a predicate mask to UB memory using scalar-offset form. +**Description**: Stores a predicate mask to UB memory using the VPTO dynamic-offset +`psts` form. This is the dynamic counterpart of `psti`: both encode the same +predicate payload semantics, while offset delivery differs (runtime `index` vs +constant immediate). -**Parameters**: +**Parameters (Advanced Tier: explicit pointer surface)**: | Parameter | Type | Description | |-----------|------|-------------| | `mask` | `MaskType` | Predicate mask to store | -| `buf` | `ptr` | Pointer to destination buffer | -| `offset` | `Index` | Scalar/index-style offset | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `offset` | `Index` | Runtime offset (`index`) | +| `dist` | `PredicateDist` | Distribution mode. Use `PredicateDist.NORM` or `PredicateDist.PK` (default: `PredicateDist.NORM`). | + +**DIST semantics (VPTO-aligned)**: +- `NORM`: stores packed predicate payload into destination space of size `VL/8`. +- `PK`: stores packed predicate payload into destination space of size `VL/16`, + keeping one bit out of every two bits. **Returns**: None (side-effect operation) **Example**: ```python -pto.psts(mask, buf, offset) +pto.psts(mask, buf, offset, PredicateDist.NORM) ``` #### `pto.pst(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 8ce40fa6b..0b9c2f1b5 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -1649,37 +1649,28 @@ def _analyze_vector_store_stmt( allow_outer_lookup: bool, ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: if expr.name == "psts": - dist_expr: SemanticExpr | None = None - if len(expr.args) == 2: - value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) - destination, indices = self._analyze_tile_vector_access( - expr.args[1], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.psts destination", - ) - elif len(expr.args) == 3 and isinstance(expr.args[1], FrontendSubscriptExpr): - value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) - destination, indices = self._analyze_tile_vector_access( - expr.args[1], - env, - allow_outer_lookup=allow_outer_lookup, - context="pto.psts destination", + if len(expr.args) in {2, 3} and isinstance(expr.args[1], FrontendSubscriptExpr): + raise TypeError( + "pto.psts does not support Tile element-indexing syntax in TileLang DSL v1; " + "use explicit pointer form `pto.psts(mask, buf, offset[, dist])`" ) - dist_expr = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + dist_expr: SemanticExpr | None = None + if len(args) == 3: + value, destination, offset = args + indices = (offset,) + elif len(args) == 4: + value, destination, offset, dist_expr = args + indices = (offset,) else: - args = tuple( - self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) - for arg in expr.args + raise TypeError( + "pto.psts expects 3 or 4 positional arguments in TileLang DSL v1: " + "`pto.psts(mask, buf, offset[, dist])`" ) - if len(args) == 3: - value, destination, offset = args - indices = (offset,) - elif len(args) == 4: - value, destination, offset, dist_expr = args - indices = (offset,) - else: - raise TypeError("pto.psts expects Tile element-indexing syntax or 3/4 positional arguments") self._require_mask_expr(value, "pto.psts value") self._require_vector_pointer_expr(destination, "pto.psts destination") for index in indices: diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index a3885edc4..bc58f2181 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -4437,6 +4437,25 @@ def kernel( self.assertNotIn("pto.vsst", text) self.assertNotIn("pto.vsta ", text) + def test_psts_rejects_tile_indexing_surface(self) -> None: + @pto.vkernel( + op="predicate_store_tile_indexing_reject", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.psts(mask, mask_dst[0, 0:]) + return None + + specialized = kernel.specialize( + mask_dst=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + ) + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn("does not support Tile element-indexing syntax", str(ctx.exception)) + self.assertIn("pto.psts(mask, buf, offset", str(ctx.exception)) + def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): From 0798ead3f61e48b9351d72e8f18ef3f3fddb8e4c Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 16 Apr 2026 09:50:38 +0800 Subject: [PATCH 080/192] fix(dsl): handle valid_shape subscript and guard unsupported tuple subscripts (#90) --- tilelang-dsl/python/tilelang_dsl/lowering.py | 14 +++ tilelang-dsl/python/tilelang_dsl/semantic.py | 35 +++++--- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 92 ++++++++++++++++++++ 3 files changed, 131 insertions(+), 10 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 3c361dccf..6e73e997b 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -3342,6 +3342,20 @@ def _lower_subscript_access( desired_name: str | None, into: list[str] | None, ) -> _RenderedValue: + if isinstance(expr.base, SemanticTupleExpr): + if not isinstance(expr.index, SemanticLiteralExpr) or not isinstance(expr.index.value, int): + raise NotImplementedError("tuple indices must be integer literals in TileLang DSL v1 lowering") + if expr.index.value < 0 or expr.index.value >= len(expr.base.elements): + raise NotImplementedError( + f"tuple subscript index {expr.index.value} is out of bounds for tuple length {len(expr.base.elements)}" + ) + return self._lower_expr( + expr.base.elements[expr.index.value], + env, + indent=indent, + desired_name=desired_name, + into=into, + ) if ( into is not None and isinstance(expr.base, SemanticAttributeAccess) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 0b9c2f1b5..6190a73dc 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3436,17 +3436,32 @@ def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticTy if isinstance(base.type, SemanticShapeType): if not isinstance(index.type, SemanticIndexType): raise TypeError("shape subscript index must be an index value in TileLang DSL v1") - if ( - isinstance(base, SemanticAttributeAccess) - and isinstance(base.base, SemanticBindingRef) - and isinstance(index, SemanticLiteralExpr) - and isinstance(index.value, int) - ): - if index.value < 0 or index.value >= base.type.rank: - raise TypeError( - f"shape subscript index {index.value} is out of bounds for rank {base.type.rank}" - ) + if not isinstance(index, SemanticLiteralExpr) or not isinstance(index.value, int): + raise TypeError( + "shape/stride/valid_shape subscript index must be an integer literal in TileLang DSL v1" + ) + if index.value < 0 or index.value >= base.type.rank: + raise TypeError( + f"shape subscript index {index.value} is out of bounds for rank {base.type.rank}" + ) return SemanticIndexType() + if isinstance(base.type, SemanticTupleType): + if not isinstance(index.type, SemanticIndexType): + raise TypeError("tuple subscript index must be an index value in TileLang DSL v1") + if not isinstance(base, SemanticTupleExpr): + raise TypeError( + "tuple subscripting currently requires a shape-like tuple expression in TileLang DSL v1" + ) + if not base.type.elements: + raise TypeError("cannot subscript an empty tuple in TileLang DSL v1") + if not isinstance(index, SemanticLiteralExpr) or not isinstance(index.value, int): + raise TypeError("tuple subscript index must be an integer literal in TileLang DSL v1") + + if index.value < 0 or index.value >= len(base.type.elements): + raise TypeError( + f"tuple subscript index {index.value} is out of bounds for tuple length {len(base.type.elements)}" + ) + return base.type.elements[index.value] if isinstance(base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): if not isinstance(index, SemanticTupleExpr): raise TypeError("TensorView slicing expects a tuple of slices in TileLang DSL v1") diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index bc58f2181..a5bcd186e 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3411,6 +3411,45 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %mask_\d+ : !pto\.vreg<128xf16>, memref<\?x\?xf16, strided<\[\?, \?\], offset: \?>, #pto\.address_space>, !pto\.mask", ) + def test_tile_valid_shape_subscript_profile_lowers_to_runtime_bounds_in_advanced_mode(self) -> None: + @pto.vkernel(op="tile_valid_shape_subscript_unique", dtypes=[(pto.f16,)], advanced=True) + def kernel(dst: pto.Tile): + valid_rows = dst.valid_shape[0] + valid_cols = dst.valid_shape[1] + area = valid_rows * valid_cols + if area == 0: + area = 1 + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [ + ("dst", "tile"), + ("__valid_shape_dst_0", "tile_valid_shape"), + ("__valid_shape_dst_1", "tile_valid_shape"), + ], + ) + valid_rows_assign = semantic_kernel.body[0] + valid_cols_assign = semantic_kernel.body[1] + self.assertIsInstance(valid_rows_assign, SemanticAssignStmt) + self.assertIsInstance(valid_cols_assign, SemanticAssignStmt) + self.assertIsInstance(valid_rows_assign.targets[0].type, SemanticIndexType) + self.assertIsInstance(valid_cols_assign.targets[0].type, SemanticIndexType) + + text = specialized.mlir_text() + self.assertIn("valid_shape=(?, ?)", text) + self.assertRegex(text, r"%valid_rows_\d+ = pto\.tile_valid_rows %arg0") + self.assertRegex(text, r"%valid_cols_\d+ = pto\.tile_valid_cols %arg0") + def test_tile_partial_dynamic_valid_shape_profile_tracks_dynamic_axes_only(self) -> None: elem = pto.TypeVar("Elem") @@ -4040,6 +4079,59 @@ def kernel(src: pto.TensorView, dst: pto.Tile): self.assertRegex(text, r"%ub_rows_\d+ = arith\.constant 8 : index") self.assertRegex(text, r"%ub_cols_\d+ = arith\.constant 64 : index") + def test_shape_subscript_rejects_non_literal_index_in_semantic(self) -> None: + @pto.vkernel(op="shape_dynamic_subscript_reject_unique", dtypes=[(pto.f32,)]) + def kernel(src: pto.TensorView): + axis = src.shape[0] + value = src.shape[axis] + return None + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIn( + "shape/stride/valid_shape subscript index must be an integer literal in TileLang DSL v1", + str(ctx.exception), + ) + + def test_valid_shape_subscript_rejects_non_literal_index_in_semantic(self) -> None: + @pto.vkernel(op="valid_shape_dynamic_subscript_reject_unique", dtypes=[(pto.f16,)], advanced=True) + def kernel(dst: pto.Tile): + axis = 0 + value = dst.valid_shape[axis] + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ) + ) + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn( + "tuple subscript index must be an integer literal in TileLang DSL v1", + str(ctx.exception), + ) + + def test_tuple_call_result_subscript_rejects_in_semantic(self) -> None: + @pto.vkernel(op="tuple_call_result_subscript_reject_unique", dtypes=[(pto.f16,)], advanced=True) + def kernel(dst: pto.Tile): + mask = pto.make_mask(dst.element_type, pto.i32(64))[0] + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn( + "tuple subscripting currently requires a shape-like tuple expression in TileLang DSL v1", + str(ctx.exception), + ) + def test_advanced_mode_lowers_compare_predicate_carry_and_rearrangement_families(self) -> None: @pto.vkernel(op="advanced_family", dtypes=[(pto.i32, pto.i32, pto.i32, pto.i32)], advanced=True) def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): From 56ab05814a1f36a237f76a73bb928ffc2c67833e Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Wed, 15 Apr 2026 16:51:37 +0800 Subject: [PATCH 081/192] support pto.set_mov_pad_val --- docs/isa/02-dma-copy.md | 32 +++++++++++ docs/vpto-spec.md | 4 +- include/PTO/IR/VPTOOps.td | 11 ++++ lib/PTO/IR/VPTO.cpp | 19 +++++++ lib/PTO/Transforms/HIVMIntrinsicNaming.cpp | 5 +- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 66 +++++++++++++++++++++- tools/ptoas/ptoas.cpp | 1 + 7 files changed, 135 insertions(+), 3 deletions(-) diff --git a/docs/isa/02-dma-copy.md b/docs/isa/02-dma-copy.md index 8d867af08..6e2fdc4f3 100644 --- a/docs/isa/02-dma-copy.md +++ b/docs/isa/02-dma-copy.md @@ -107,6 +107,36 @@ Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM s --- +## Pad Value Configuration + +### `pto.set_mov_pad_val` + +- **syntax:** `pto.set_mov_pad_val %value : T` +- **supported `T`:** `i8`, `i16`, `i32`, `f16`, `bf16`, `f32` +- **semantics:** Configure the pad fill value used by GM→UB DMA when `data_select_bit = true`. + +This op programs the hardware pad register consumed by `pto.copy_gm_to_ubuf`. The operand is a typed scalar. Its raw bit pattern is encoded into the underlying hardware configuration payload: + +- integer inputs use their zero-extended bit pattern +- floating-point inputs use their bitcast-to-integer bit pattern, then zero-extend to `i64` + +This configuration affects only the GM→UB padding path. UB→GM DMA ignores the pad value. + +**Parameter Table:** + +| Parameter | Description | +|-----------|-------------| +| `%value` | Pad fill scalar. Must be one of `i8/i16/i32/f16/bf16/f32`. | + +**Example:** + +```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 +``` + +--- + ## DMA Transfer Execution ### `pto.copy_gm_to_ubuf` @@ -445,6 +475,8 @@ UB (128 cols wide, 32B-aligned, padded): ``` ```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 5d7758f21..638238c8b 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -246,6 +246,8 @@ pto.strict_vecscope(%ub, %ub_out, %lane) { ### Example: VecScope ```mlir +%pad = arith.constant 0 : i32 +pto.set_mov_pad_val %pad : i32 pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 @@ -879,7 +881,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | # | Group | Description | Count | Details | |---|-------|-------------|-------|---------| | 1 | [Pipeline Sync](isa/01-pipeline-sync.md) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | -| 2 | [DMA Copy Programming](isa/02-dma-copy.md) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 2 | [DMA Copy Programming](isa/02-dma-copy.md) | DMA configuration and transfer between GM↔UB | 10 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.set_mov_pad_val`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | | 3 | [Vector Load/Store](isa/03-vector-load-store.md) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | | 4 | [Predicate Load/Store](isa/04-predicate-load-store.md) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | | 5 | [Materialization & Predicate Ops](isa/05-materialization-predicate.md) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 3daf2092d..015500acd 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -127,6 +127,17 @@ class PTO_BinaryI64ConfigOp : PTO_Op { }]; } +def PTO_SetMovPadValOp : PTO_Op<"set_mov_pad_val"> { + let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat], + "integer/float scalar">:$value); + let results = (outs); + let hasVerifier = 1; + + let assemblyFormat = [{ + $value attr-dict `:` type($value) + }]; +} + def PTO_SetLoop2StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_outtoub">; def PTO_SetLoop1StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_outtoub">; def PTO_SetLoopSizeOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop_size_outtoub">; diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 1d8853675..359934536 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -120,6 +120,16 @@ static bool isSupportedVdupPosition(std::optional position) { return !position || *position == "LOWEST" || *position == "HIGHEST"; } +static bool isSupportedMovPadScalarType(Type type) { + if (auto intType = dyn_cast(type)) + return intType.isSignless() && + (intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32); + if (auto floatType = dyn_cast(type)) + return floatType.isF16() || floatType.isBF16() || floatType.isF32(); + return false; +} + static std::optional getVdupMaskGranularity(Type elementType) { if (auto intType = dyn_cast(elementType)) { switch (intType.getWidth()) { @@ -1196,6 +1206,15 @@ LogicalResult CopyGmToUbufOp::verify() { return verifyCopyGmToUbufOp(*this, true); } +LogicalResult SetMovPadValOp::verify() { + Type valueType = getValue().getType(); + if (isSupportedMovPadScalarType(valueType)) + return success(); + return emitOpError() + << "expects i8/i16/i32 or f16/bf16/f32 scalar operand, but got " + << valueType; +} + LogicalResult VbrOp::verify() { if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) return failure(); diff --git a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp index d87cb6867..406ec5fe5 100644 --- a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp +++ b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp @@ -187,6 +187,8 @@ static FailureOr selectConfigLike(Operation *op) { ""); if (isa(op)) return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.MOV.PAD.VAL", usedFields, ""); llvm::SmallVector missingFields = {"confirmed_hivm_name"}; return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, @@ -539,7 +541,8 @@ FailureOr selectIntrinsic(Operation *op) { if (isa(op)) + pto::SetLoop1StrideUbToOutOp, pto::SetLoopSizeUbToOutOp, + pto::SetMovPadValOp>(op)) return selectConfigLike(op); if (succeeded(selectLoadIntrinsic(op))) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index b54898fe6..96dee0e78 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -1576,6 +1576,9 @@ static FailureOr buildVcvtContract(pto::VcvtOp op) { template static StringRef buildSetLoopCallee(MLIRContext *context); +template +static StringRef buildUnaryConfigCallee(MLIRContext *context); + template <> StringRef buildSetLoopCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB") @@ -1612,6 +1615,34 @@ StringRef buildSetLoopCallee(MLIRContext *context) { .getValue(); } +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.MOV.PAD.VAL").getValue(); +} + +static FailureOr encodeMovPadValue(Location loc, Value value, + ConversionPatternRewriter &rewriter) { + Type type = value.getType(); + Value payload = value; + unsigned bitWidth = 0; + + if (auto intType = dyn_cast(type)) { + bitWidth = intType.getWidth(); + } else if (auto floatType = dyn_cast(type)) { + bitWidth = floatType.getWidth(); + auto intType = rewriter.getIntegerType(bitWidth); + payload = rewriter.create(loc, intType, value); + } else { + return failure(); + } + + if (bitWidth != 8 && bitWidth != 16 && bitWidth != 32) + return failure(); + + return rewriter.create(loc, rewriter.getI64Type(), payload) + .getResult(); +} + template static StringRef buildSyncCallee(MLIRContext *context); @@ -4187,6 +4218,37 @@ class LowerSetLoopConfigOpPattern final : public OpConversionPattern { LoweringState &state; }; +template +class LowerUnaryConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerUnaryConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ConfigOp op, typename ConfigOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr encoded = + encodeMovPadValue(op.getLoc(), adaptor.getValue(), rewriter); + if (failed(encoded)) + return rewriter.notifyMatchFailure( + op, "expected 8/16/32-bit integer or float mov-pad payload"); + + StringRef calleeName = buildUnaryConfigCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*encoded}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + template class LowerPipeEventSyncOpPattern final : public OpConversionPattern { public: @@ -4626,6 +4688,7 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, + LowerUnaryConfigOpPattern, LowerPipeEventSyncOpPattern, LowerPipeEventSyncOpPattern, LowerBarrierOpPattern, @@ -4670,7 +4733,8 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, pto::GetBlockNumOp, pto::GetSubBlockNumOp>(); target.addIllegalOp(); + pto::SetLoop1StrideUbToOutOp, pto::SetLoopSizeUbToOutOp, + pto::SetMovPadValOp>(); target.addIllegalOp Date: Tue, 14 Apr 2026 21:28:33 +0800 Subject: [PATCH 082/192] Add VPTO membar support --- 3rdparty/PTO-Gym | 2 +- include/PTO/IR/PTOAttrs.td | 49 ++++++++++++++++ include/PTO/IR/VPTOOps.td | 14 +++++ lib/PTO/IR/PTO.cpp | 42 ++++++++++++++ lib/PTO/IR/VPTO.cpp | 4 -- lib/PTO/Transforms/HIVMIntrinsicNaming.cpp | 45 ++++++++++++++- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 65 +++++++++++++++++++++- 7 files changed, 213 insertions(+), 8 deletions(-) diff --git a/3rdparty/PTO-Gym b/3rdparty/PTO-Gym index 8a186eae3..30ab9a22f 160000 --- a/3rdparty/PTO-Gym +++ b/3rdparty/PTO-Gym @@ -1 +1 @@ -Subproject commit 8a186eae3befc4f1417f4618addbd9e942339acd +Subproject commit 30ab9a22f3d9a8488bceefa4dd771d1018a1ac5a diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index cec8cb40e..61233c714 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -120,6 +120,55 @@ def PTO_PipeAttr : PTO_Attr<"Pipe", "pipe"> { }]; } +//===----------------------------------------------------------------------===// +// MemBar +//===----------------------------------------------------------------------===// + +def PTO_MEMBAR_VV_ALL : I32EnumAttrCase<"VV_ALL", 0>; +def PTO_MEMBAR_VST_VLD : I32EnumAttrCase<"VST_VLD", 1>; +def PTO_MEMBAR_VLD_VST : I32EnumAttrCase<"VLD_VST", 2>; +def PTO_MEMBAR_VST_VST : I32EnumAttrCase<"VST_VST", 3>; +def PTO_MEMBAR_VS_ALL : I32EnumAttrCase<"VS_ALL", 4>; +def PTO_MEMBAR_VST_LD : I32EnumAttrCase<"VST_LD", 5>; +def PTO_MEMBAR_VLD_ST : I32EnumAttrCase<"VLD_ST", 6>; +def PTO_MEMBAR_VST_ST : I32EnumAttrCase<"VST_ST", 7>; +def PTO_MEMBAR_SV_ALL : I32EnumAttrCase<"SV_ALL", 8>; +def PTO_MEMBAR_ST_VLD : I32EnumAttrCase<"ST_VLD", 9>; +def PTO_MEMBAR_LD_VST : I32EnumAttrCase<"LD_VST", 10>; +def PTO_MEMBAR_ST_VST : I32EnumAttrCase<"ST_VST", 11>; +def PTO_MEMBAR_SS_ALL : I32EnumAttrCase<"SS_ALL", 12>; +def PTO_MEMBAR_ST_LD : I32EnumAttrCase<"ST_LD", 13>; +def PTO_MEMBAR_LD_ST : I32EnumAttrCase<"LD_ST", 14>; +def PTO_MEMBAR_ST_ST : I32EnumAttrCase<"ST_ST", 15>; + +def PTO_MemBarEnum : PTO_I32Enum< + "MemBarKind", "PTO low-level memory barrier kind", [ + PTO_MEMBAR_VV_ALL, + PTO_MEMBAR_VST_VLD, + PTO_MEMBAR_VLD_VST, + PTO_MEMBAR_VST_VST, + PTO_MEMBAR_VS_ALL, + PTO_MEMBAR_VST_LD, + PTO_MEMBAR_VLD_ST, + PTO_MEMBAR_VST_ST, + PTO_MEMBAR_SV_ALL, + PTO_MEMBAR_ST_VLD, + PTO_MEMBAR_LD_VST, + PTO_MEMBAR_ST_VST, + PTO_MEMBAR_SS_ALL, + PTO_MEMBAR_ST_LD, + PTO_MEMBAR_LD_ST, + PTO_MEMBAR_ST_ST + ]>; + +def PTO_MemBarAttr : PTO_Attr<"MemBar", "membar"> { + let parameters = (ins EnumParameter:$kind); + let assemblyFormat = "`<` params `>`"; + let description = [{ + Low-level memory barrier kind for VPTO `pto.mem_bar`. + }]; +} + //===----------------------------------------------------------------------===// // Sync Op Type (High Level Abstraction) //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 015500acd..a5ce95eb0 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -114,6 +114,20 @@ def StrictVecScopeOp : PTO_Op<"strict_vecscope", [SingleBlock, NoTerminator, }]; } +def PTO_MemBarOp : PTO_Op<"mem_bar"> { + let summary = "Low-level VPTO memory barrier"; + let description = [{ + Low-level memory ordering barrier that lowers to one of the + `llvm.hivm.mem.bar.*` intrinsics exposed by Bisheng. + }]; + + let arguments = (ins PTO_MemBarAttr:$kind); + let results = (outs); + + let hasCustomAssemblyFormat = 1; +} + + class PTO_BinaryI64ConfigOp : PTO_Op { let arguments = (ins I64:$first, diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 1737234e3..a54733991 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -5689,6 +5689,34 @@ static ParseResult parseQuotedEventToken(OpAsmParser &parser, EventAttr &attr) { return success(); } +static ParseResult parseLegacyOrAttrMemBar(OpAsmParser &parser, + MemBarAttr &attr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + auto kind = symbolizeMemBarKind(token); + if (!kind) + return parser.emitError(loc) << "invalid membar token: " << token; + attr = MemBarAttr::get(parser.getContext(), *kind); + return success(); + } + + Attribute parsed; + if (failed(parser.parseAttribute(parsed))) + return failure(); + auto memBarAttr = dyn_cast(parsed); + if (!memBarAttr) + return parser.emitError(loc, "expected membar attribute"); + attr = memBarAttr; + return success(); +} + +static void printLegacyOrAttrMemBar(OpAsmPrinter &p, MemBarAttr kind, + ArrayRef attrs) { + p << ' ' << '"' << stringifyMemBarKind(kind.getKind()) << '"'; + p.printOptionalAttrDict(attrs, {"kind"}); +} + static ParseResult parseLegacyOrAttrPipe(OpAsmParser &parser, PipeAttr &attr) { auto loc = parser.getCurrentLocation(); std::string token; @@ -5818,6 +5846,20 @@ void WaitFlagOp::print(OpAsmPrinter &p) { (*this)->getAttrs()); } +ParseResult MemBarOp::parse(OpAsmParser &parser, OperationState &result) { + MemBarAttr kind; + if (parseLegacyOrAttrMemBar(parser, kind)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("kind", kind); + return success(); +} + +void MemBarOp::print(OpAsmPrinter &p) { + printLegacyOrAttrMemBar(p, getKind(), (*this)->getAttrs()); +} + static ParseResult parseLegacyOrAttrOpType(OpAsmParser &parser, Attribute &opTypeAttr) { auto loc = parser.getCurrentLocation(); diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 359934536..c1ff9e4a5 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1810,8 +1810,6 @@ LogicalResult PldsOp::verify() { return emitOpError("requires index offset"); if (!isSupportedPredicateLoadDist(getDist())) return emitOpError("requires predicate load dist to be NORM, US, or DS"); - if (failed(verifyEnclosingLoopLike(*this, "pto.plds"))) - return failure(); return success(); } @@ -1832,8 +1830,6 @@ LogicalResult PldiOp::verify() { return emitOpError("requires offset to be a constant index immediate"); if (!isSupportedPredicateLoadDist(getDist())) return emitOpError("requires predicate load dist to be NORM, US, or DS"); - if (failed(verifyEnclosingLoopLike(*this, "pto.pldi"))) - return failure(); return success(); } diff --git a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp index 406ec5fe5..aae68d33b 100644 --- a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp +++ b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp @@ -145,6 +145,44 @@ static IntrinsicSelection makeUnresolved(Operation *op, return selection; } +static StringRef getMemBarIntrinsicName(MemBarKind kind) { + switch (kind) { + case MemBarKind::VV_ALL: + return "llvm.hivm.mem.bar.vv.all"; + case MemBarKind::VST_VLD: + return "llvm.hivm.mem.bar.vst.vld"; + case MemBarKind::VLD_VST: + return "llvm.hivm.mem.bar.vld.vst"; + case MemBarKind::VST_VST: + return "llvm.hivm.mem.bar.vst.vst"; + case MemBarKind::VS_ALL: + return "llvm.hivm.mem.bar.vs.all"; + case MemBarKind::VST_LD: + return "llvm.hivm.mem.bar.vst.ld"; + case MemBarKind::VLD_ST: + return "llvm.hivm.mem.bar.vld.st"; + case MemBarKind::VST_ST: + return "llvm.hivm.mem.bar.vst.st"; + case MemBarKind::SV_ALL: + return "llvm.hivm.mem.bar.sv.all"; + case MemBarKind::ST_VLD: + return "llvm.hivm.mem.bar.st.vld"; + case MemBarKind::LD_VST: + return "llvm.hivm.mem.bar.ld.vst"; + case MemBarKind::ST_VST: + return "llvm.hivm.mem.bar.st.vst"; + case MemBarKind::SS_ALL: + return "llvm.hivm.mem.bar.ss.all"; + case MemBarKind::ST_LD: + return "llvm.hivm.mem.bar.st.ld"; + case MemBarKind::LD_ST: + return "llvm.hivm.mem.bar.ld.st"; + case MemBarKind::ST_ST: + return "llvm.hivm.mem.bar.st.st"; + } + llvm_unreachable("unexpected membar kind"); +} + static FailureOr selectSyncLike(Operation *op) { llvm::SmallVector usedFields; usedFields.push_back("op=" + getOpMnemonic(op)); @@ -162,6 +200,10 @@ static FailureOr selectSyncLike(Operation *op) { } else if (auto barrier = dyn_cast(op)) { usedFields.push_back("pipe=" + printAttrText(barrier.getPipe())); return makeResolved(op, "llvm.hivm.BARRIER", usedFields, ""); + } else if (auto membar = dyn_cast(op)) { + usedFields.push_back("kind=" + printAttrText(membar.getKind())); + return makeResolved(op, getMemBarIntrinsicName(membar.getKind().getKind()), + usedFields, ""); } llvm::SmallVector missingFields = {"confirmed_hivm_name"}; @@ -536,7 +578,8 @@ FailureOr selectStoreIntrinsic(Operation *op) { } FailureOr selectIntrinsic(Operation *op) { - if (isa(op)) + if (isa(op)) return selectSyncLike(op); if (isa(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.BARRIER").getValue(); } +static StringRef buildMemBarCallee(MemBarKind kind, MLIRContext *context) { + switch (kind) { + case MemBarKind::VV_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.vv.all").getValue(); + case MemBarKind::VST_VLD: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.vld").getValue(); + case MemBarKind::VLD_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vld.vst").getValue(); + case MemBarKind::VST_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.vst").getValue(); + case MemBarKind::VS_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.vs.all").getValue(); + case MemBarKind::VST_LD: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.ld").getValue(); + case MemBarKind::VLD_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vld.st").getValue(); + case MemBarKind::VST_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.st").getValue(); + case MemBarKind::SV_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.sv.all").getValue(); + case MemBarKind::ST_VLD: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.vld").getValue(); + case MemBarKind::LD_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.ld.vst").getValue(); + case MemBarKind::ST_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.vst").getValue(); + case MemBarKind::SS_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.ss.all").getValue(); + case MemBarKind::ST_LD: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.ld").getValue(); + case MemBarKind::LD_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.ld.st").getValue(); + case MemBarKind::ST_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.st").getValue(); + } + llvm_unreachable("unexpected membar kind"); +} + template <> StringRef buildSyncCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.GET.BUFI.mode").getValue(); @@ -4316,6 +4354,29 @@ class LowerBarrierOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerMemBarOpPattern final : public OpConversionPattern { +public: + explicit LowerMemBarOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::MemBarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + StringRef calleeName = buildMemBarCallee(op.getKind().getKind(), op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, ValueRange{}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + template class LowerBufSyncOpPattern final : public OpConversionPattern { public: @@ -4691,7 +4752,7 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerUnaryConfigOpPattern, LowerPipeEventSyncOpPattern, LowerPipeEventSyncOpPattern, - LowerBarrierOpPattern, + LowerBarrierOpPattern, LowerMemBarOpPattern, LowerBufSyncOpPattern, LowerBufSyncOpPattern, LowerRuntimeQueryOpPattern, @@ -4728,7 +4789,7 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, func::FuncDialect, scf::SCFDialect>(); target.addLegalOp(); target.addIllegalOp(); + pto::MemBarOp, pto::GetBufOp, pto::RlsBufOp>(); target.addIllegalOp(); target.addIllegalOp Date: Wed, 15 Apr 2026 11:33:17 +0800 Subject: [PATCH 083/192] warning the channel vsts --- lib/PTO/IR/VPTO.cpp | 47 +++++++++--- lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 91 ++++++++++++++++++++++-- 2 files changed, 124 insertions(+), 14 deletions(-) diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index c1ff9e4a5..5aee3cc6c 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -874,16 +874,32 @@ static LogicalResult verifyVstsDistWidth(Operation *op, StringRef dist, return *width == 32 ? success() : op->emitOpError("dist PK4 only supports 32-bit elements"); + if (dist == "MRG4CHN") { + if (*width != 8) + return op->emitOpError("dist MRG4CHN only supports 8-bit elements"); + return success(); + } + if (dist == "MRG2CHN") { + if (!matchesWidthFamily(dist, *width, {8, 16})) + return op->emitOpError("dist MRG2CHN only supports 8/16-bit elements"); + return success(); + } + + return op->emitOpError("requires a supported store distribution token"); +} + +static std::optional +getVstsMaskGranularityOverride(StringRef dist, Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return std::nullopt; + if (dist == "MRG4CHN") - return *width == 8 - ? success() - : op->emitOpError("dist MRG4CHN only supports 8-bit elements"); + return StringRef("b32"); if (dist == "MRG2CHN") - return matchesWidthFamily(dist, *width, {8, 16}) - ? success() - : op->emitOpError("dist MRG2CHN only supports 8/16-bit elements"); + return *width == 8 ? StringRef("b16") : StringRef("b32"); - return op->emitOpError("requires a supported store distribution token"); + return std::nullopt; } static LogicalResult verifyVstsx2DistWidth(Operation *op, StringRef dist, @@ -2660,8 +2676,6 @@ template static LogicalResult verifyVstsCommon(StoreOp op) { if (failed(verifyVRegTypeLike(op, op.getValue().getType(), "value type"))) return failure(); - if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) - return failure(); if (!isBufferLike(op.getDestination().getType())) return op.emitOpError("requires a pointer-like destination"); @@ -2681,6 +2695,21 @@ static LogicalResult verifyVstsCommon(StoreOp op) { cast(op.getValue().getType()).getElementType()))) return failure(); + if (std::optional dist = op.getDist()) { + if (std::optional granularity = getVstsMaskGranularityOverride( + *dist, cast(op.getValue().getType()).getElementType())) { + if (failed(verifyMaskTypeWithGranularityLike(op, op.getMask().getType(), + "mask type", *granularity))) + return failure(); + } else if (failed(verifyMaskTypeLike(op, op.getMask().getType(), + "mask type"))) { + return failure(); + } + } else if (failed(verifyMaskTypeLike(op, op.getMask().getType(), + "mask type"))) { + return failure(); + } + return success(); } diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp index a81b4e81f..92b57f769 100644 --- a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -308,6 +308,45 @@ class VPTOLegalityValidator { << formatExpectedMaskType(*expected); } + static std::optional + inferVstsMaskGranularityOverride(Operation *op) { + Value value; + if (auto vsts = dyn_cast(op)) + value = vsts.getValue(); + else if (auto vstsPost = dyn_cast(op)) + value = vstsPost.getValue(); + else + return std::nullopt; + + auto valueType = dyn_cast(value.getType()); + if (!valueType) + return std::nullopt; + + auto elementType = valueType.getElementType(); + auto elementIntType = dyn_cast(elementType); + if (!elementIntType) + return std::nullopt; + + auto distAttr = op->getAttrOfType("dist"); + if (!distAttr) + return std::nullopt; + + StringRef dist = distAttr.getValue(); + unsigned width = elementIntType.getWidth(); + if (dist == "MRG4CHN") { + if (width == 8) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + if (dist == "MRG2CHN") { + if (width == 8) + return VPTOMaskGranularity::B16; + if (width == 16) + return VPTOMaskGranularity::B32; + } + return std::nullopt; + } + static LogicalResult validateSameMaskGranularity(Operation *op, Type lhsType, StringRef lhsRole, Type rhsType, @@ -338,11 +377,51 @@ class VPTOLegalityValidator { template static LogicalResult validateValueMaskVectorConsumer(OpTy op) { + if constexpr (std::is_same_v || + std::is_same_v) { + if (std::optional expected = + inferVstsMaskGranularityOverride(op.getOperation())) { + auto actual = + VPTOLegalityHelper::getMaskGranularity(op.getMask().getType()); + if (!actual || *actual == *expected) + return success(); + return op.emitOpError() + << "mask type " << op.getMask().getType() + << " does not match value vector type " + << op.getValue().getType() << "; expected " + << formatExpectedMaskType(*expected); + } + } return validateMaskMatchesVectorFamily(op, op.getMask().getType(), "mask type", op.getValue().getType(), "value vector type"); } + void emitHardwareSupportWarnings(Operation *op) const { + auto emitForStore = [&](auto storeOp) { + Operation *store = storeOp.getOperation(); + auto distAttr = store->getAttrOfType("dist"); + if (!distAttr) + return; + + StringRef dist = distAttr.getValue(); + if (dist == "MRG4CHN" || dist == "MRG2CHN") + writeDiagnostic((Twine("warning: ") + store->getName().getStringRef() + + " dist " + dist + + " is not supported on the current hardware\n") + .str()); + }; + + if (auto vsts = dyn_cast(op)) { + emitForStore(vsts); + return; + } + if (auto vstsPost = dyn_cast(op)) { + emitForStore(vstsPost); + return; + } + } + template static LogicalResult validateResultMaskVectorConsumer(OpTy op) { return validateMaskMatchesVectorFamily(op, op.getMask().getType(), @@ -607,11 +686,13 @@ class VPTOLegalityValidator { if (!VPTOLegalityHelper::requiresVecScope(op)) return WalkResult::advance(); - if (VPTOLegalityHelper::getEnclosingVectorScopeCarrier(op)) - return (failed(validateFamilySuffixMaskContracts(op)) || - failed(validateMaskGranularityContracts(op))) - ? WalkResult::interrupt() - : WalkResult::advance(); + if (VPTOLegalityHelper::getEnclosingVectorScopeCarrier(op)) { + if (failed(validateFamilySuffixMaskContracts(op)) || + failed(validateMaskGranularityContracts(op))) + return WalkResult::interrupt(); + emitHardwareSupportWarnings(op); + return WalkResult::advance(); + } op->emitOpError() << "requires enclosing scf.for with '" From 09ea9112038a74cc9913bd5fae48430538d4ad52 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Wed, 15 Apr 2026 11:46:01 +0800 Subject: [PATCH 084/192] update docs for channel vsts --- docs/isa/03-vector-load-store.md | 6 +++--- docs/vpto-spec.md | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/isa/03-vector-load-store.md b/docs/isa/03-vector-load-store.md index bb840e44b..870b991bc 100644 --- a/docs/isa/03-vector-load-store.md +++ b/docs/isa/03-vector-load-store.md @@ -86,7 +86,7 @@ Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VS | `NORM` | **9** cycles (`RV_VSTI`) | | `PK` | **9** cycles | | `INTLV` (`pto.vstx2`) | **12** cycles | -| `MRG4CHN`, `MRG2CHN` | **9** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles (surface retained; current A5 hardware still reports them unsupported at validation time) | ### Gather, scatter, and special addressing @@ -386,8 +386,8 @@ for (int blk = 0; blk < VL / 32; ++blk) { | `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | | `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | | `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | -| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | -| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout. VPTO currently requires `!pto.mask` for this family and emits a hardware-unsupported warning on A5. | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout. VPTO currently requires `!pto.mask` for `b8` input and `!pto.mask` for `b16` input, and emits a hardware-unsupported warning on A5. | **9** cycles | **Example — Contiguous store:** ```mlir diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 638238c8b..0cb136c32 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -1047,6 +1047,6 @@ pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto ### Part 3C -2. **Store dist family completeness:** `vsts` currently covers `NORM`, `1PT`, `PK`, `PK4`, `MRG4CHN`, and `MRG2CHN`, while `vstsx2` covers `INTLV`. Confirm whether the surface constraints for these families are already sufficiently clear and complete. +2. **Store dist family completeness:** `vsts` currently covers `NORM`, `1PT`, `PK`, `PK4`, `MRG4CHN`, and `MRG2CHN`, while `vstsx2` covers `INTLV`. `MRG4CHN` / `MRG2CHN` are preserved in the VPTO surface, but the current hardware still reports them as unsupported via verifier warning and they are not expected to validate at runtime on A5 today. 3. **vcvt width-changing pattern:** The even/odd + `vor` pattern for forms such as `f32 -> f16` is the standard compiler lowering. Confirm this is the intended representation in the spec. 4. **Stateful store ops (Section 14):** These are complex with SSA state threading. Are they all needed for A5, or can some be simplified? From 905154f53e50bcef679b0f593626f42a6e5ded52 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 16 Apr 2026 12:49:42 +0800 Subject: [PATCH 085/192] fix(dsl): support explicit set_mov_pad_val for DMA padding --- .../docs/user_guide/01-introduction.md | 4 +- .../docs/user_guide/05-type-system.md | 1 + .../docs/user_guide/08-sync-dma-operations.md | 162 +- .../docs/vpto_spec/vpto-spec-current.md | 34 +- tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md | 5349 +++++++++++++++++ .../python/tilelang_dsl/frontend_ast.py | 1 + tilelang-dsl/python/tilelang_dsl/lowering.py | 18 + tilelang-dsl/python/tilelang_dsl/semantic.py | 29 +- .../python/tilelang_dsl/support_matrix.py | 2 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 55 + 10 files changed, 5647 insertions(+), 8 deletions(-) create mode 100644 tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md diff --git a/tilelang-dsl/docs/user_guide/01-introduction.md b/tilelang-dsl/docs/user_guide/01-introduction.md index 1ef148b77..26012f781 100644 --- a/tilelang-dsl/docs/user_guide/01-introduction.md +++ b/tilelang-dsl/docs/user_guide/01-introduction.md @@ -15,7 +15,7 @@ The DSL surface is organized into multiple maturity tiers, reflecting the stabil | Base vector ops (`make_mask`, `vlds`, `vsts`, `vadd`, `vmuls`, etc.) | `basic` | Default compute skeleton for starter kernels. | | `strict_vecscope` | `advanced` | Explicit vector-scope management for expert authoring. | | Raw pointer family (`ptr(...)`, `castptr`, `addptr`) | `advanced` | For expert authoring and migration; not required for Quick Start. | -| DMA family (`copy_*`, `set_loop*_stride_*`, `set_loop_size_*`) | `advanced` | Direct DMA engine control for expert authoring. | +| DMA family (`copy_*`, `set_loop*_stride_*`, `set_loop_size_*`, pad-fill control) | `advanced` | Direct DMA engine control for expert authoring, including GM→UB padding behavior. | | Tile pointer helper (`tile.as_ptr()`) | `advanced` | Expert-only helper when advanced authoring needs explicit typed pointers. | For the authoritative tier classification, consult `tilelang-dsl/python/tilelang_dsl/support_matrix.py`. For known implementation gaps, refer to `tilelang-dsl/docs/unsupported-features.md`. @@ -35,7 +35,7 @@ The TileLang DSL provides two distinct authoring modes: - Uses **raw pointer semantics** for explicit memory management - Direct pointer operations correspond to `pto.ptr` types in MLIR - Explicit pointer arithmetic: `ptr(...)`, `castptr`, `addptr` -- Manual DMA engine control with low-level copy operations +- Manual DMA engine control with low-level copy operations and explicit GM→UB padding behavior - Requires explicit buffer management and pointer arithmetic - Intended for expert users and performance-critical optimizations diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 820683c5b..49239a82b 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -336,6 +336,7 @@ Notes: - `PadValue.text` exposes the standard textual spelling for built-ins such as `null` and `zero`. - Custom pad values currently model an `f32` payload. In DSL v1, materializing a custom pad into a scalar is only supported for floating tile element dtypes. - `PadValue.NULL` does not denote a usable scalar fill constant. Calling `tile.pad_value.eval()` or `tile.config.pad_value.eval()` when the enum is `NULL` is a frontend error. +- **DMA padding**: When performing GM→UB DMA transfers with padding enabled (via `enable_ub_pad=True` in `pto.copy_gm_to_ubuf`), the pad value must be configured explicitly using `pto.set_mov_pad_val`. Tile `PadValue` descriptors are not automatically translated to hardware register configurations in TileLang DSL v1. See [Pad Fill Semantics](08-sync-dma-operations.md#pad-fill-semantics) for usage details. #### Tile Shape Concepts diff --git a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md index fc8982da0..023304715 100644 --- a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md +++ b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md @@ -237,7 +237,19 @@ pto.wait_intra_core(0, Event.ID0) ### DMA Programming [Advanced Tier] -This section contains both DMA configuration operations (setting loop strides and sizes) and DMA execution operations (copying data). +This section covers Direct Memory Access (DMA) operations for transferring data between Global Memory (GM) and Unified Buffer (UB). DMA operations are performance-critical and require careful configuration of stride parameters and transfer sizes. + +**Key Concepts:** +- **DMA Configuration**: Set stride parameters and loop sizes using `set_loop*_stride_*` and `set_loop_size_*` operations. +- **DMA Execution**: Perform transfers using `copy_gm_to_ubuf`, `copy_ubuf_to_gm`, and `copy_ubuf_to_ubuf` operations. +- **GM→UB Padding**: Optionally fill out-of-bounds regions with a specified value when copying from GM to UB. See [Pad Fill Semantics](#pad-fill-semantics) for details. + +**Usage Flow:** +1. Configure DMA parameters (strides, loop sizes) +2. Execute the DMA transfer operation +3. Optionally enable padding for GM→UB transfers + +**Note**: All DMA operations in this section are part of the **Advanced Tier** and require explicit buffer management and pointer arithmetic. For basic tile-based authoring, refer to the [Basic Authoring Mode](01-introduction.md#basic-vs-advanced-authoring-modes) documentation. #### Manual Configuration Example @@ -250,6 +262,129 @@ pto.copy_gm_to_ubuf(src=gm_ptr, dst=ub_ptr, n_burst=16, len_burst=128, gm_stride ``` +#### Pad Fill Semantics + +When copying data from Global Memory (GM) to Unified Buffer (UB), you can enable padding to fill out-of-bounds regions with a specified value. This is useful when the source data dimensions don't perfectly match the destination tile allocation, or when you need to handle boundary conditions in tiled computations. + +##### How Padding Works + +1. **Configure the hardware pad register**: Call `pto.set_mov_pad_val` to set the pad value in the hardware register. This must be done before any `pto.copy_gm_to_ubuf` operation with padding enabled. + +2. **Enable padding in the DMA operation**: Set `enable_ub_pad=True` in the `pto.copy_gm_to_ubuf` call to activate the padded transfer path. The pad value from the hardware register will be used for filling out-of-bounds regions. + +3. **Hardware mapping**: The `pto.set_mov_pad_val` operation corresponds directly to the low-level VPTO instruction that configures the hardware pad register. There is no automatic translation from tile `PadValue` descriptors—you must explicitly set the pad register before padded DMA transfers. + +##### Example Workflow + +Configure the hardware pad register using `pto.set_mov_pad_val`, then perform the DMA transfer with padding enabled: + +```python +# First, configure the hardware pad register with a scalar value +# For zero fill, use an appropriate scalar type based on your data +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float32 data + +# Then perform the DMA transfer with padding enabled +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, # Enable padded transfer +) +``` + +##### Accessing Pad Values in Kernel Code + +Tile `PadValue` descriptors can be used within kernel code for computation purposes (e.g., initializing vectors with a specific fill value). However, note that **these descriptors are not automatically used for DMA padding**—you must still call `pto.set_mov_pad_val` explicitly to configure the hardware pad register for GM→UB transfers. + +To access a pad value from a tile descriptor in kernel code: + +```python +# Get the pad descriptor from the destination tile +pad_desc = dst.pad_value + +# Check if a valid pad value is configured +if pto.constexpr(pad_desc != pto.PadValue.NULL): + # Materialize the scalar value + pad_scalar = pad_desc.eval() + + # Use the scalar value (e.g., for vector duplication) + mask = pto.make_mask(pto.f32, PAT.ALL) + pad_vector = pto.vdup(pad_scalar, mask) +``` + +##### Important Notes + +- The `PadValue.NULL` descriptor indicates no pad value is configured. Attempting to call `.eval()` on `PadValue.NULL` will raise a frontend error. +- Custom pad values currently support only 32-bit float payloads (`PadValue.custom_f32(...)`). +- Padding only affects GM→UB transfers (`pto.copy_gm_to_ubuf`). UB→GM and UB→UB transfers do not support padding. +- The padded region is determined by the difference between the tile's `valid_shape` and its full `shape`. Ensure your tile is configured with appropriate dimensions. +- Tile `PadValue` descriptors are not automatically used for DMA padding. You must call `pto.set_mov_pad_val` explicitly to configure the hardware pad register for padded GM→UB transfers. + +##### `pto.set_mov_pad_val` Operation [Advanced Tier] + +The `pto.set_mov_pad_val` operation configures the hardware pad register used for GM→UB transfers when padding is enabled. This operation must be called explicitly before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`, as the TileLang DSL v1 does not automatically translate tile `PadValue` descriptors to hardware register configurations. + +**Operation Signature**: +```python +pto.set_mov_pad_val(pad_value: ScalarType) -> None +``` + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pad_value` | `ScalarType` | Scalar value used for padding. Supported types: `pto.i8`, `pto.i16`, `pto.i32`, `pto.f16`, `pto.bf16`, `pto.f32`. The value's bit pattern is encoded into the hardware pad register. For standard pad values, use `PadValue.eval()` to obtain the appropriate scalar: `0` or `0.0` for `PadValue.ZERO`, dtype-aware maximum for `PadValue.MAX`, dtype-aware minimum for `PadValue.MIN`. | + +**Returns**: None (side-effect operation) + +**Example**: + +Using a scalar value directly: +```python +# Configure the hardware pad register for zero fill using an integer scalar +pto.set_mov_pad_val(pto.i32(0)) # Zero fill for integer types + +# Or using a float scalar for floating-point padding +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float types + +# Perform DMA transfer with padding enabled +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, +) +``` + +Using a tile's pad value descriptor: +```python +# Get the pad value from a tile configuration +pad_desc = tile.pad_value # PadValue enum +if pto.constexpr(pad_desc != pto.PadValue.NULL): + pad_scalar = pad_desc.eval() # Materializes to a scalar value + pto.set_mov_pad_val(pad_scalar) + + # Perform padded DMA transfer + pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, + ) +``` + +**Important**: You are responsible for ensuring the pad register is properly configured before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`. The pad register configuration persists until changed by another `pto.set_mov_pad_val` call. + +**Future Improvement**: Future versions of TileLang DSL may provide an implicit approach that automatically translates `PadValue` descriptors from tile configurations to hardware register configurations, similar to DMA syntax sugar features. + #### `pto.set_loop2_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] **Description**: Configures DMA stride parameters for GM → UB transfers (loop2). @@ -390,8 +525,10 @@ The following operations provide direct control over DMA transfers but require m **Returns**: None (side-effect operation) **Notes**: -- In TileLang DSL, the keyword form above is the recommended public surface. -- The lowering still maps to the underlying low-level PTO operand ABI in positional order. +- **Keyword arguments**: The keyword form shown above is the recommended public API surface. Use named arguments for clarity. +- **Padding control**: Set `enable_ub_pad=True` to enable padded GM→UB transfers. The pad value must be configured separately using `pto.set_mov_pad_val` before the DMA operation (see [Pad Fill Semantics](#pad-fill-semantics) for details). +- **Pad value source**: When padding is enabled, the fill scalar comes from the hardware pad register configured by `pto.set_mov_pad_val`. You must call this operation explicitly before the DMA transfer. +- **ABI compatibility**: The lowering preserves the underlying PTO operand order while providing a more ergonomic keyword interface. **Example**: ```python @@ -406,6 +543,23 @@ pto.copy_gm_to_ubuf( ) ``` +**Padding Example**: +```python +# First configure the hardware pad register with a scalar value +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float32 data + +# Then perform padded DMA transfer +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, +) +``` + #### `pto.copy_ubuf_to_ubuf(src: UBPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` [Advanced Tier] **Description**: Copies data within Unified Buffer (UB → UB). @@ -457,4 +611,4 @@ pto.copy_ubuf_to_gm( gm_stride=128, ub_stride=128, ) -``` +``` \ No newline at end of file diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md index 8de281795..b92097043 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md @@ -883,7 +883,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | # | Group | Description | Count | Details | |---|-------|-------------|-------|---------| | 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | -| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 10 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.set_mov_pad_val`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | | 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | | 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | | 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | @@ -1609,6 +1609,36 @@ Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM s --- +#### Pad Value Configuration + +##### `pto.set_mov_pad_val` + +- **syntax:** `pto.set_mov_pad_val %value : T` +- **supported `T`:** `i8`, `i16`, `i32`, `f16`, `bf16`, `f32` +- **semantics:** Configure the pad fill value used by GM→UB DMA when `data_select_bit = true`. + +This op programs the hardware pad register consumed by `pto.copy_gm_to_ubuf`. The operand is a typed scalar. Its raw bit pattern is encoded into the underlying hardware configuration payload: + +- integer inputs use their zero-extended bit pattern +- floating-point inputs use their bitcast-to-integer bit pattern, then zero-extend to `i64` + +This configuration affects only the GM→UB padding path. UB→GM DMA ignores the pad value. + +**Parameter Table:** + +| Parameter | Description | +|-----------|-------------| +| `%value` | Pad fill scalar. Must be one of `i8/i16/i32/f16/bf16/f32`. | + +**Example:** + +```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 +``` + +--- + #### DMA Transfer Execution ##### `pto.copy_gm_to_ubuf` @@ -1947,6 +1977,8 @@ UB (128 cols wide, 32B-aligned, padded): ``` ```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md new file mode 100644 index 000000000..8de281795 --- /dev/null +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md @@ -0,0 +1,5349 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.3: Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +##### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +##### 2. No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +##### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +##### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +##### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | +| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | +| `INTLV` | `b8`, `b16`, `b32` | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +##### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +##### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +#### Movement + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. `%result` + uses an integer element type, and the scalar `%index` type matches that + result element type. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. This is typically used in even/odd placement forms such +as `32 -> 16` or `16 -> 32` style conversions. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | + +--- + +##### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +###### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +###### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +##### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | | Y | | + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### Sorting Operations + +##### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +##### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 0db887de1..bbed8b7d8 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -778,6 +778,7 @@ def _collect_reachable_inline_procs( } _DMA_CALL_KEYWORDS: dict[str, frozenset[str]] = { + "set_mov_pad_val": frozenset({"pad_value"}), "set_loop2_stride_outtoub": frozenset({"src_stride", "dst_stride"}), "set_loop1_stride_outtoub": frozenset({"src_stride", "dst_stride"}), "set_loop_size_outtoub": frozenset({"loop1", "loop2"}), diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 6e73e997b..a5c99d692 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -24,6 +24,7 @@ SemanticBindingRef, SemanticCallExpr, SemanticDmaConfigStmt, + SemanticDmaUnaryConfigStmt, SemanticDmaLoadStmt, SemanticDmaStoreStmt, SemanticExpr, @@ -456,6 +457,8 @@ def _render_stmt( return self._render_i64_pair_stmt("wait_flag_dev", stmt.core_id, stmt.event_id, env, indent=indent) if isinstance(stmt, SemanticWaitIntraCoreStmt): return self._render_i64_pair_stmt("wait_intra_core", stmt.block_id, stmt.event_id, env, indent=indent) + if isinstance(stmt, SemanticDmaUnaryConfigStmt): + return self._render_dma_unary_config(stmt, env, indent=indent) if isinstance(stmt, SemanticDmaConfigStmt): return self._render_dma_config(stmt, env, indent=indent) if isinstance(stmt, SemanticLowLevelCopyStmt): @@ -493,6 +496,21 @@ def _render_dma_config( ) return lines + def _render_dma_unary_config( + self, + stmt: SemanticDmaUnaryConfigStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.{stmt.name} {value.name} : {self._render_type(value.type)}" + ) + return lines + def _render_buffer_sync_stmt( self, name: str, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 6190a73dc..dbcde25db 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -268,6 +268,7 @@ def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: _TERNARY_VECTOR_OPS = {"vaxpy", "vmula"} _MULTI_RESULT_VECTOR_OPS = {"vmull", "vldsx2", "vldus", "pstu"} _BROADCAST_VECTOR_OPS = {"vbr", "vdup", "vci"} +_LOW_LEVEL_DMA_UNARY_CONFIG_OPS = {"set_mov_pad_val"} _LOW_LEVEL_DMA_CONFIG_OPS = { "set_loop2_stride_outtoub", "set_loop1_stride_outtoub", @@ -281,6 +282,9 @@ def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: "copy_ubuf_to_gm", "copy_ubuf_to_ubuf", } +_MOV_PAD_SUPPORTED_SCALAR_DTYPES = frozenset( + dtype.name for dtype in (i8, i16, i32, f16, bf16, f32) +) _COMPARE_SELECT_OPS = {"vcmp", "vcmps", "vsel", "vselr", "vselrv2"} _PREDICATE_MOVEMENT_OPS = {"pnot", "psel", "ppack", "punpack"} _CARRY_OPS = {"vaddc", "vsubc", "vaddcs", "vsubcs"} @@ -655,6 +659,12 @@ class SemanticDmaConfigStmt(SemanticStmt): second: SemanticExpr +@dataclass(frozen=True) +class SemanticDmaUnaryConfigStmt(SemanticStmt): + name: str + value: SemanticExpr + + @dataclass(frozen=True) class SemanticLowLevelCopyStmt(SemanticStmt): name: str @@ -1559,7 +1569,7 @@ def _is_low_level_dma_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) and expr.namespace == "pto" - and expr.name in _LOW_LEVEL_DMA_CONFIG_OPS | _LOW_LEVEL_DMA_COPY_OPS + and expr.name in _LOW_LEVEL_DMA_UNARY_CONFIG_OPS | _LOW_LEVEL_DMA_CONFIG_OPS | _LOW_LEVEL_DMA_COPY_OPS ) def _analyze_dma_stmt( @@ -1986,6 +1996,21 @@ def _analyze_low_level_dma_stmt( env, allow_outer_lookup=allow_outer_lookup, ) + if expr.name in _LOW_LEVEL_DMA_UNARY_CONFIG_OPS: + if len(args) != 1: + raise TypeError(f"pto.{expr.name} expects exactly 1 positional argument in TileLang DSL") + scalar = self._require_scalar_expr(args[0], f"pto.{expr.name} pad_value") + if scalar.dtype.name not in _MOV_PAD_SUPPORTED_SCALAR_DTYPES: + raise TypeError( + "pto.set_mov_pad_val pad_value must be one of i8, i16, i32, f16, bf16, or f32 in TileLang DSL v1" + ) + return ( + SemanticDmaUnaryConfigStmt( + name=expr.name, + value=args[0], + ), + dict(env), + ) if expr.name in _LOW_LEVEL_DMA_CONFIG_OPS: if len(args) != 2: raise TypeError(f"pto.{expr.name} expects exactly 2 positional arguments in TileLang DSL") @@ -2103,6 +2128,8 @@ def index_literal(value: int) -> SemanticLiteralExpr: def bool_literal(value: bool) -> SemanticLiteralExpr: return SemanticLiteralExpr(value=value, type=SemanticScalarType(dtype=i1)) + if expr.name == "set_mov_pad_val": + return (analyzed_keywords["pad_value"],) if expr.name in { "set_loop2_stride_outtoub", "set_loop1_stride_outtoub", diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 61f17edc1..763171014 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -181,6 +181,7 @@ { "strict_vecscope", "store_scalar", + "set_mov_pad_val", "copy_gm_to_ubuf", "copy_ubuf_to_gm", "copy_ubuf_to_ubuf", @@ -221,6 +222,7 @@ ) ADVANCED_LOW_LEVEL_DMA_SURFACES = frozenset( { + "pto.set_mov_pad_val", "pto.copy_gm_to_ubuf", "pto.copy_ubuf_to_gm", "pto.copy_ubuf_to_ubuf", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index a5bcd186e..9e08a48ac 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -38,6 +38,7 @@ SemanticBinaryExpr, SemanticCallExpr, SemanticDmaConfigStmt, + SemanticDmaUnaryConfigStmt, SemanticExprStmt, SemanticForStmt, SemanticGetBufStmt, @@ -431,6 +432,7 @@ def test_non_stable_surface_groups_keep_advanced_boundaries(self) -> None: self.assertIn("pto.strict_vecscope", ADVANCED_EXPLICIT_VECSCOPE_SURFACES) self.assertIn("pto.ptr", ADVANCED_RAW_POINTER_SURFACES) self.assertIn("pto.castptr", ADVANCED_RAW_POINTER_SURFACES) + self.assertIn("pto.set_mov_pad_val", ADVANCED_LOW_LEVEL_DMA_SURFACES) self.assertIn("pto.copy_ubuf_to_ubuf", ADVANCED_LOW_LEVEL_DMA_SURFACES) self.assertIn("pto.tile_with_strides", ADVANCED_TILE_HELPER_SURFACES) @@ -440,6 +442,7 @@ def test_non_stable_surface_groups_keep_advanced_boundaries(self) -> None: self.assertEqual(get_feature_tier("pto.castptr"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.load_scalar"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.store_scalar"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.set_mov_pad_val"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.copy_ubuf_to_ubuf"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.tile_with_strides"), ADVANCED_TIER) @@ -3870,6 +3873,41 @@ def kernel(inp: pto.TensorView, dst: pto.Tile): r"pto\.copy_gm_to_ubuf %gm_ptr_\d+, %ub_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %false, %tmp_\d+, %tmp_\d+, %tmp_\d+", ) + def test_set_mov_pad_val_lowers_in_advanced_mode(self) -> None: + @pto.vkernel(op="set_mov_pad_val_dma_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(inp: pto.TensorView, dst: pto.Tile): + gm_ptr = inp.as_ptr() + ub_ptr = dst.as_ptr() + + pto.set_mov_pad_val(pad_value=pto.f32(0.0)) + pto.set_loop2_stride_outtoub(src_stride=4096, dst_stride=2048) + pto.set_loop1_stride_outtoub(src_stride=1024, dst_stride=512) + pto.set_loop_size_outtoub(loop1=1, loop2=1) + pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=1, + len_burst=64, + gm_stride=128, + ub_stride=128, + enable_ub_pad=True, + ) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertTrue(any(isinstance(stmt, SemanticDmaUnaryConfigStmt) for stmt in semantic_kernel.body)) + + text = specialized.mlir_text() + self.assertRegex(text, r"pto\.set_mov_pad_val %[^ ]+ : f32") + self.assertRegex( + text, + r"pto\.copy_gm_to_ubuf %gm_ptr_\d+, %ub_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %true, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + def test_copy_ubuf_to_gm_keyword_surface_lowers_in_advanced_mode(self) -> None: @pto.vkernel(op="tile_to_tensorview_dma_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) def kernel(src: pto.Tile, dst: pto.TensorView): @@ -4919,6 +4957,23 @@ def kernel(x: pto.Tile): self.assertIn("advanced family surface `pto.vreduce`", str(ctx.exception)) + def test_set_mov_pad_val_rejects_unsupported_scalar_dtype(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel(op="set_mov_pad_val_bad_dtype_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + pto.set_mov_pad_val(pto.i64(0)) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ).mlir_text() + + self.assertIn( + "pto.set_mov_pad_val pad_value must be one of i8, i16, i32, f16, bf16, or f32", + str(ctx.exception), + ) + def test_unsupported_python_syntax_reports_source_location(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as ctx: From e9932bbc8ef01a5dcf53ae6799f8f9e9b545c6ec Mon Sep 17 00:00:00 2001 From: qukelin Date: Wed, 15 Apr 2026 20:12:35 +0800 Subject: [PATCH 086/192] fix: support unsigned dtypes in ExpandTileOp --- lib/PTO/Transforms/ExpandTileOp.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 829b49f58..a398896bd 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -168,6 +168,14 @@ static std::string getDtypeString(Type elemTy) { if (elemTy.isF32()) return "f32"; if (elemTy.isF16()) return "f16"; if (elemTy.isBF16()) return "bf16"; + if (elemTy.isUnsignedInteger(64)) return "ui64"; + if (elemTy.isUnsignedInteger(32)) return "ui32"; + if (elemTy.isUnsignedInteger(16)) return "ui16"; + if (elemTy.isUnsignedInteger(8)) return "ui8"; + if (elemTy.isSignedInteger(64)) return "si64"; + if (elemTy.isSignedInteger(32)) return "si32"; + if (elemTy.isSignedInteger(16)) return "si16"; + if (elemTy.isSignedInteger(8)) return "si8"; if (elemTy.isSignlessInteger(64)) return "i64"; if (elemTy.isSignlessInteger(32)) return "i32"; if (elemTy.isSignlessInteger(16)) return "i16"; From 3762b68eaf6509651d5ea01ed3f1261e3e4575c6 Mon Sep 17 00:00:00 2001 From: qukelin Date: Thu, 16 Apr 2026 14:06:44 +0800 Subject: [PATCH 087/192] fix(tilelang-dsl): keep vbitsort and vmrgsort4 out of inferred vecscope --- lib/PTO/IR/VPTO.cpp | 2 ++ tilelang-dsl/python/tilelang_dsl/semantic.py | 16 ++++++++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 31 ++++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 5aee3cc6c..420873d21 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1413,6 +1413,8 @@ LogicalResult Vmrgsort4Op::verify() { classifyMemoryRole(getSource2().getType()) != MemoryRole::UB || classifyMemoryRole(getSource3().getType()) != MemoryRole::UB) return emitOpError("requires UB-backed destination and sources"); + if (failed(verifyNotNestedInVecScope(*this, "pto.vmrgsort4"))) + return failure(); return success(); } diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index dbcde25db..52fa40551 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -289,12 +289,13 @@ def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: _PREDICATE_MOVEMENT_OPS = {"pnot", "psel", "ppack", "punpack"} _CARRY_OPS = {"vaddc", "vsubc", "vaddcs", "vsubcs"} _REARRANGEMENT_OPS = {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"} +_UB_HELPER_OPS = {"vbitsort", "vmrgsort4"} _ADVANCED_VECTOR_ACTIVITY_OPS = ( _COMPARE_SELECT_OPS | _PREDICATE_MOVEMENT_OPS | _CARRY_OPS | _REARRANGEMENT_OPS - | {"vcvt", "vbitsort", "vmrgsort4"} + | {"vcvt"} ) _TENSORVIEW_RANK = 5 @@ -1124,7 +1125,11 @@ def _frontend_stmt_is_vecscope_boundary(self, stmt: FrontendStmtNode) -> bool: return not stmt.is_constexpr return ( isinstance(stmt, FrontendExprStmt) - and (self._is_dma_call(stmt.expr) or self._is_sync_call(stmt.expr)) + and ( + self._is_dma_call(stmt.expr) + or self._is_sync_call(stmt.expr) + or self._is_ub_helper_call(stmt.expr) + ) ) def _constexpr_if_contains_vector_activity(self, stmt: FrontendIfStmt) -> bool: @@ -1565,6 +1570,13 @@ def _is_sync_call(self, expr: FrontendExprNode) -> bool: } ) + def _is_ub_helper_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in _UB_HELPER_OPS + ) + def _is_low_level_dma_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 9e08a48ac..37f0e1512 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2693,6 +2693,37 @@ def kernel(dst: pto.Tile, src: pto.Tile, idx: pto.Tile): r"!pto\.ptr, !pto\.ptr, !pto\.ptr, !pto\.ptr, !pto\.ptr, i64, i64", ) + def test_vbitsort_helper_stays_outside_inferred_vecscope(self) -> None: + @pto.vkernel( + op="vbitsort_vecscope_boundary_unique", + dtypes=[(pto.f32, pto.f32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, idx: pto.Tile): + dst_ptr = dst.as_ptr() + src_ptr = src.as_ptr() + idx_ptr = idx.as_ptr() + + pto.vbitsort(dst_ptr, src_ptr, idx_ptr, 1) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + idx=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(vecscope_stmts, []) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"pto\.vbitsort %dst_ptr_\d+, %src_ptr_\d+, %idx_ptr_\d+, %c1 : !pto\.ptr, !pto\.ptr, !pto\.ptr, index", + ) + self.assertNotIn("pto.vecscope {", text) + def test_vcvt_rejects_legacy_string_spellings(self) -> None: with self.assertRaises(TypeError) as ctx: From 10dad242c41bd69f4491e945cb52bc415605dc07 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 16 Apr 2026 17:03:48 +0800 Subject: [PATCH 088/192] feat(tilelang-dsl): add select_kernel diagnostics report Introduce opt-in select_kernel metadata reporting with per-candidate stage diagnostics, MLIR materialization visibility, updated docs and regression coverage.\n\nAlso sync and archive the related OpenSpec change and disable CMake linker-generated dependency files in tilelang_st so bisheng-based sim builds remain compatible.\n\nCloses #89 --- .../.openspec.yaml | 2 + .../design.md | 137 ++++ .../proposal.md | 35 + .../specs/tilelang-dsl-diagnostics/spec.md | 32 + .../specs/tilelang-dsl-kernel-matcher/spec.md | 100 +++ .../tasks.md | 31 + .../specs/tilelang-dsl-diagnostics/spec.md | 31 + .../specs/tilelang-dsl-kernel-matcher/spec.md | 84 ++- test/tilelang_st/npu/a5/src/st/CMakeLists.txt | 7 + .../matcher-and-advanced-surface-migration.md | 27 + .../docs/user_guide/03-kernel-declaration.md | 44 ++ tilelang-dsl/python/tilelang_dsl/__init__.py | 4 + .../python/tilelang_dsl/expand_helper.py | 1 + tilelang-dsl/python/tilelang_dsl/kernel.py | 641 ++++++++++++++++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 240 +++++++ 15 files changed, 1357 insertions(+), 59 deletions(-) create mode 100644 openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/.openspec.yaml create mode 100644 openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/design.md create mode 100644 openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/proposal.md create mode 100644 openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/specs/tilelang-dsl-diagnostics/spec.md create mode 100644 openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/specs/tilelang-dsl-kernel-matcher/spec.md create mode 100644 openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/tasks.md diff --git a/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/.openspec.yaml b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/.openspec.yaml new file mode 100644 index 000000000..3a54a172b --- /dev/null +++ b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-16 diff --git a/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/design.md b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/design.md new file mode 100644 index 000000000..1d9347715 --- /dev/null +++ b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/design.md @@ -0,0 +1,137 @@ +## Context + +本 change 只覆盖 TileLang DSL matcher 的“可观测性”增强,不改变既有 kernel 选择主线的 deterministic 语义。 +当前实现中,`select_kernel(...)` 先做 `target/op/dtype` 过滤,再用布尔式 constraint evaluation 继续筛选;一旦所有候选都在 constraint 阶段被移除,调用方只能看到通用的 `"after constraint evaluation"` 报错。 +对于依赖 `constraints`、`priority`、多 signature `dtypes` 和模板复用的 kernel 作者来说,这个输出不足以定位失败根因。 + +实现约束: + +- 默认 `select_kernel(...)` 路径必须保持兼容,不能破坏现有返回值和回归测试。 +- selection 顺序仍然固定为 `target -> op -> dtype signature -> constraints -> priority -> tie error`。 +- 诊断输出必须围绕“通过 `target/op` 的候选”展开,避免把整个 registry 全量转储成噪声。 +- report 模式不得吞掉已有失败信息;相反,它需要把 dtype/constraint/materialization 失败显式结构化。 + +## Goals / Non-Goals + +**Goals:** + +- 让 `select_kernel(...)` 在 opt-in 模式下返回逐候选 metadata/report,而不是只给最终 winner 或通用失败。 +- 让 report 明确区分 `dtype_mismatch`、`constraint_failed`、`constraint_error`、`priority_shadowed`、`selected` 和 `mlir_error` 等阶段结果。 +- 在不改变默认 API 兼容性的前提下,为成功穿过 constraint 阶段的候选补齐 MLIR 可见性。 +- 为后续 `expand_helper`、tests 和文档提供稳定的诊断契约。 + +**Non-Goals:** + +- 不改变 matcher 的核心选择顺序,也不引入新的隐式 tiebreak 规则。 +- 不把 report 范围扩展到所有 registry descriptor;`target/op` 都不匹配的 kernel 继续不进入报告。 +- 不在本 change 中引入新的 kernel authoring surface、constraint 语法或额外 matcher capability。 +- 不要求默认异常路径立即改写为详细长报文;重点是新增结构化 opt-in 诊断通道。 + +## Decisions + +### 1. 保持默认 `select_kernel(...)` 兼容,新增 opt-in report 模式 + +决策: + +- `select_kernel(...)` 继续保留默认“返回单个 `VKernelDescriptor` / 抛异常”的行为。 +- 新增 opt-in 参数,使调用方可以请求结构化 `KernelSelectionReport`。 +- report 模式下,no-candidate 和 priority-tie 不再通过第一时间抛出 `LookupError` 丢失上下文,而是收敛为 `final_status` / `final_error`。 + +原因: + +- 这能最大限度保护现有 tests、examples 和调用栈。 +- 与直接改写 `select_kernel` 返回值相比,开关模式更容易渐进接入。 + +备选方案: + +- 直接把 `select_kernel(...)` 改成始终返回 metadata 列表。 + 放弃原因:breaking change 风险太高。 +- 另起一个 `explain_select_kernel(...)` API。 + 放弃原因:会复制大部分选择逻辑,容易让默认路径和诊断路径再度漂移。 + +### 2. 报告范围固定为“通过 `target/op` 过滤的候选”,并显式暴露 `dtype` 不匹配 + +决策: + +- 只对通过 `target` 和 concrete `op` 过滤的 descriptor 生成候选 metadata。 +- 即使某个候选在 `dtype` 阶段失败,也必须进入 report,并标记为 `dtype_mismatch`。 + +原因: + +- 对 kernel 作者来说,`target/op` 通过但 `dtype` 没命中,是最需要被看见的失败之一。 +- 如果把整个 registry 的 target/op miss 也塞进来,报告很快就会变成噪声。 + +### 3. 将选择流程拆成结构化阶段结果,而不是继续复用布尔筛选 + +决策: + +- 把当前“匹配即保留、不匹配即丢弃”的 helper 改成可返回阶段性结果的内部结构。 +- constraint 评估不再只返回 `bool`;它需要产出: + - 是否通过 + - 失败 constraint 的索引 + - 失败 callable 名称或 `qualname` + - `False` 失败与异常失败的区分 + - 异常类型与消息摘要 +- 顶层 `select_kernel(...)` 在 report 模式下汇总每个候选的阶段状态和最终决策结果。 + +原因: + +- 只有把选择阶段显式建模,才能稳定输出“挂在哪一步”的 metadata。 +- 这也能避免未来再把 default path 和 diagnostics path 写成两套不一致逻辑。 + +### 4. 对通过 constraint 阶段的候选尝试 materialization,并把 MLIR 成功/失败都保留下来 + +决策: + +- report 模式启用 MLIR 采集时,对所有通过 `constraints` 的候选尝试 `mlir_text()`。 +- 成功时记录 `mlir_text`。 +- 若因为 specialization/context 不完整或其他 materialization 问题失败,记录 `mlir_error`,但该候选仍保留在 report 中。 + +原因: + +- 用户要的不是“猜测哪个 kernel 理论上可用”,而是看到“匹配成功的 kernel 最终会产出什么 MLIR,或者为什么连 MLIR 都拿不到”。 +- 把 materialization 失败单独结构化,能避免它被误解成 matcher 失败。 + +### 5. report 结果必须保留最终决策摘要 + +决策: + +- 顶层 report 统一包含: + - `selected` + - `candidates` + - `final_status` + - `final_error` +- `final_status` 至少覆盖: + - `selected` + - `no_candidate` + - `priority_tie` + +原因: + +- 这让调用方既能读逐候选细节,也能快速知道这次 query 的总结论。 +- `expand_helper` 和未来 CLI/debug tooling 也更容易消费统一结构。 + +## Risks / Trade-offs + +- [Risk] report 模式需要更多内部数据结构,增加 matcher 代码复杂度 + Mitigation:把阶段结果抽成小而稳定的内部 helper,避免把 `select_kernel(...)` 主逻辑写成大段分支。 + +- [Risk] 为多个候选尝试 `mlir_text()` 可能增加开销 + Mitigation:将 MLIR 采集保留为 opt-in 行为,并只对通过 constraint 阶段的候选执行。 + +- [Risk] dual-mode API 可能让调用方混淆返回类型 + Mitigation:显式参数命名、补齐 public docs/examples,并在类型导出中给出清晰的数据模型名。 + +- [Risk] constraint callable 可能是匿名 `lambda`,名字信息不稳定 + Mitigation:至少保证 `constraint` 索引稳定可见;名字信息在可解析时补充,但不作为唯一定位手段。 + +## Migration Plan + +1. 先落 OpenSpec delta,冻结 report 模式和 selector diagnostics 契约。 +2. 在 `kernel.py` 中引入 report 数据模型与阶段结果 helper,保留默认路径兼容。 +3. 补齐 tests 和 docs,再让上游调用方按需接入新 report 模式。 +4. 如实现过程中发现 payload 过重,可保留 report 结构不变,只把 MLIR 采集降为可选。 + +## Open Questions + +- 当前无必须阻断本 change 的开放问题;若后续需要把 source location 也纳入 selector diagnostics,可作为本 change 的增量修订,而不是前置阻断项。 diff --git a/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/proposal.md b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/proposal.md new file mode 100644 index 000000000..ef990001c --- /dev/null +++ b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/proposal.md @@ -0,0 +1,35 @@ +## Why + +当前 `tilelang_dsl.select_kernel(...)` 在匹配多个 kernel descriptor 时,只会返回最终选中的 descriptor,或者在候选全部失败时抛出通用 `LookupError`。 +当失败发生在 `dtype` 过滤、`constraints` 评估或 materialization 预检查阶段时,kernel 作者很难看出“是哪一个候选、挂在第几个 constraint、还是 MLIR 构造本身失败”,这已经成为 matcher/模板 kernel 迭代的直接阻碍。 + +## What Changes + +- 为 `pto.select_kernel(...)` 增加 opt-in 的 selection metadata/report 模式,同时保持默认返回 `VKernelDescriptor` 的兼容行为不变。 +- 在 report 模式下,覆盖所有通过 `target/op` 过滤的候选,并显式记录 `dtype` 不匹配、constraint 失败、constraint 异常、priority 落败和最终选中状态。 +- 对通过 `constraints` 的候选,在启用 MLIR 采集时尝试生成 `mlir_text()`,成功时返回 MLIR 文本,失败时返回结构化 `mlir_error`,而不是丢失候选信息。 +- 为 selector diagnostics 补齐“失败在第几个 constraint、可调用名是什么、失败原因是什么”的正式契约,并要求 no-candidate / priority-tie 结果保留完整候选上下文。 + +## Capabilities + +### New Capabilities + +None. + +### Modified Capabilities + +- `tilelang-dsl-kernel-matcher`: 扩展 `select_kernel(...)` 的公共返回契约,允许调用方以 opt-in 方式获得逐候选 selection report,同时不改变现有 deterministic 选择顺序。 +- `tilelang-dsl-diagnostics`: 为 matcher/selector 补充结构化诊断契约,明确 `dtype` 不匹配、constraint false、constraint exception 和 materialization error 的可见性。 + +## Impact + +- 受影响源码: + - `tilelang-dsl/python/tilelang_dsl/kernel.py` + - `tilelang-dsl/python/tilelang_dsl/__init__.py` + - `tilelang-dsl/python/tilelang_dsl/expand_helper.py`(如需消费新 report 结果) +- 受影响测试与文档: + - `tilelang-dsl/tests/test_tilelang_dsl_v1.py` + - `tilelang-dsl/docs/` + - `openspec/specs/` +- 受影响 public API: + - `pto.select_kernel(...)` diff --git a/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/specs/tilelang-dsl-diagnostics/spec.md b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/specs/tilelang-dsl-diagnostics/spec.md new file mode 100644 index 000000000..612a71ea0 --- /dev/null +++ b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/specs/tilelang-dsl-diagnostics/spec.md @@ -0,0 +1,32 @@ +## ADDED Requirements + +### Requirement: selector diagnostics MUST identify the failing stage and failing constraint for each reported candidate + +当调用方启用 `select_kernel(...)` 的 report/metadata 模式时,TileLang DSL diagnostics MUST 为每个候选明确指出其失败或胜出的阶段。 +对于 constraint 阶段失败的候选,诊断 MUST 至少包含: + +- 失败的 constraint 索引 +- 若可解析则包含 callable 名称或 `qualname` +- 区分“predicate 返回 `False`”与“constraint 执行抛异常” +- 可读的失败原因文本 + +对于 `dtype` 不匹配、priority 落败和 materialization 失败,diagnostics 也 MUST 使用不同的 kind/status 表达,而不是统一折叠成同一种通用错误。 +selector diagnostics MUST 让调用方能仅凭 report 判定候选究竟挂在 `dtype`、`constraints`、`priority` 还是 materialization。 + +#### Scenario: false-returning constraint is reported with index and callable identity + +- **WHEN** 某个候选在 constraint evaluation 中命中第 `N` 个 constraint,且该 callable 返回 `False` +- **THEN** selection diagnostics MUST 报告该候选失败在第 `N` 个 constraint +- **AND** diagnostics MUST 在可解析时包含该 constraint 的 callable 名称或 `qualname` + +#### Scenario: raising constraint is reported as a constraint error + +- **WHEN** 某个 constraint 在执行时抛出异常 +- **THEN** selection diagnostics MUST 将该候选标记为 `constraint_error` +- **AND** diagnostics MUST 包含异常类型与消息摘要,而不是只给通用失败 + +#### Scenario: materialization failure remains distinguishable from matcher failure + +- **WHEN** 某个候选通过 `dtype` 与 `constraints`,但在 `mlir_text()` materialization 时失败 +- **THEN** selection diagnostics MUST 将该候选标记为 materialization 相关失败 +- **AND** MUST NOT 把该失败重新表述为 `dtype` mismatch 或 constraint failure diff --git a/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/specs/tilelang-dsl-kernel-matcher/spec.md b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/specs/tilelang-dsl-kernel-matcher/spec.md new file mode 100644 index 000000000..fb9566237 --- /dev/null +++ b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/specs/tilelang-dsl-kernel-matcher/spec.md @@ -0,0 +1,100 @@ +## MODIFIED Requirements + +### Requirement: TileLang DSL MUST provide an explicit kernel registry and selection API + +当同一 `target/op` 下存在多个 `@pto.vkernel` descriptor 时,TileLang DSL MUST 将它们注册到显式、可查询的 `KernelRegistry`。 +默认 registry MUST 是 module-level 对象;调用方 MAY 传入自定义 registry 以获得隔离的候选集合。 +系统 MUST 提供显式 selection API `pto.select_kernel(target, op, operand_types, context_attrs=None, registry=None, return_metadata=False, include_mlir=True)`,用于在给定 `target`、concrete `op`、operand type 信息和上下文属性时选择 kernel。 +descriptor MUST 支持两种互斥的 matcher 元数据: + +- `op=""` +- `ops=["", "", ...]` + +descriptor MUST 至少提供其中一种,且实现 MUST NOT 同时接受两者。 +当 selector 命中一个 `ops=[...]` descriptor 时,无论是默认路径还是 report 路径,结果都 MUST 绑定当前 query 对应的唯一 concrete `selected_op`,再进入后续 `specialize()` / `mlir_text()` / `verify()` 流程。 +当 `return_metadata=False` 时,`pto.select_kernel(...)` MUST 继续返回唯一选中的 descriptor,并保持现有异常行为兼容。 +当 `return_metadata=True` 时,`pto.select_kernel(...)` MUST 返回结构化 selection report;该 report MUST 暴露最终 winner、逐候选状态和最终决策摘要,而不是只暴露一个 descriptor 或通用失败字符串。 +实现 MUST NOT 依赖扫描 Python globals、locals 或导入顺序来隐式发现候选。 + +#### Scenario: selector returns the unique best kernel in default mode + +- **WHEN** registry 中存在多个针对同一 `target/op` 的 kernel descriptor,且其中一个在全部匹配步骤后成为唯一最佳候选 +- **THEN** `pto.select_kernel(..., return_metadata=False)` MUST 返回该 descriptor +- **AND** 返回结果 MUST 可继续走 `specialize()` / `mlir_text()` / `verify()` 流程 + +#### Scenario: custom registry restricts the candidate set explicitly + +- **WHEN** 调用方显式传入一个只含局部 kernel 的 `KernelRegistry` +- **THEN** selector MUST 只在该 registry 的候选集合内做匹配和决策 +- **AND** MUST NOT 回退去查询 module-level 默认 registry + +#### Scenario: selector binds the concrete op for a multi-op descriptor + +- **WHEN** 一个 descriptor 通过 `ops=["tadd", "tsub", "tmul", "tdiv"]` 注册,且调用方以 `pto.select_kernel(..., op="tmul", ...)` 查询命中该 descriptor +- **THEN** selector MUST 返回已经绑定 `selected_op="tmul"` 的 descriptor,或在 report 模式下返回绑定了该 `selected_op` 的候选记录 +- **AND** 后续 materialization MUST 基于该 concrete `selected_op` 而不是未绑定的原始 matcher 集合 + +#### Scenario: selector can return a structured selection report + +- **WHEN** 调用方以 `pto.select_kernel(..., return_metadata=True)` 查询 kernel +- **THEN** selector MUST 返回包含 `selected`、`candidates`、`final_status` 和 `final_error` 的结构化 report +- **AND** 该 report MUST 保持与默认路径一致的 winner 决策语义 + +## ADDED Requirements + +### Requirement: selection report MUST preserve per-candidate stage results for all target/op-matched descriptors + +当调用方启用 `return_metadata=True` 时,selector MUST 为所有通过 `target` 与 concrete `op` 过滤的 descriptor 生成逐候选 metadata。 +每个候选记录 MUST 至少包含: + +- kernel identity(如 `name`、`priority`、`match_ops`) +- 当前 query 下绑定后的 concrete `selected_op` +- 匹配到的 concrete dtype signature,或明确的 dtype mismatch 信息 +- 一个稳定的阶段状态,至少覆盖: + - `dtype_mismatch` + - `constraint_failed` + - `constraint_error` + - `priority_shadowed` + - `selected` + - `mlir_error` + +候选在 `dtype` 阶段失败时,selector MUST 保留该候选并显式标记为 `dtype_mismatch`,而不是直接从 report 中丢弃。 +若没有任何候选通过后续选择,顶层 report MUST 通过 `final_status` / `final_error` 明确表达 `no_candidate` 或 `priority_tie`,同时保留全部候选记录。 +report 模式 MUST NOT 改变 matcher 的既有选择顺序或 winner 决策结果。 + +#### Scenario: dtype-mismatched descriptor still appears in the report + +- **WHEN** 某个 descriptor 通过 `target/op` 过滤,但没有任何 `dtypes` signature 能匹配当前 operand types +- **THEN** selection report MUST 仍包含该 descriptor 的候选记录 +- **AND** 该候选 MUST 标记为 `dtype_mismatch` + +#### Scenario: report preserves no-candidate outcome without losing candidate context + +- **WHEN** 所有通过 `target/op` 的候选最终都在 `dtype`、`constraints` 或 materialization 阶段失败 +- **THEN** selection report MUST 将 `final_status` 标记为 `no_candidate` +- **AND** `candidates` 中 MUST 保留每个候选的失败阶段和失败原因 + +#### Scenario: report preserves priority tie outcome without dropping tied winners + +- **WHEN** 多个候选在 `target/op/dtype/constraints` 全部通过后拥有相同最高 `priority` +- **THEN** selection report MUST 将 `final_status` 标记为 `priority_tie` +- **AND** report MUST 保留所有 tie 候选,而不是隐式只保留其中一个 + +### Requirement: selection report MUST optionally include MLIR materialization results for successful candidates + +当 `return_metadata=True` 且 `include_mlir=True` 时,selector MUST 对所有通过 `constraints` 阶段、并进入 priority 决策的候选尝试 materialization。 +若 `mlir_text()` 成功,候选记录 MUST 暴露 `mlir_text`。 +若 materialization 因 specialization/context 不完整或其他 frontend 失败而未成功,候选记录 MUST 暴露 `mlir_error`,并继续保留其余阶段信息。 +当 `include_mlir=False` 时,selector MUST 允许调用方跳过该 materialization 尝试,但其余候选 metadata 仍 MUST 可用。 + +#### Scenario: successful candidate carries MLIR text in the report + +- **WHEN** 某个候选通过选择阶段并且 `mlir_text()` materialization 成功 +- **THEN** 该候选的 selection metadata MUST 包含 `mlir_text` +- **AND** 其阶段状态 MUST 继续反映它是 `selected` 或 `priority_shadowed` + +#### Scenario: materialization failure is captured without losing the candidate + +- **WHEN** 某个候选通过 `dtype` 与 `constraints`,但在 `mlir_text()` 时因为 specialization/context 不完整而失败 +- **THEN** selection report MUST 保留该候选 +- **AND** 该候选 MUST 包含 `mlir_error`,而不是被重新归类为 `dtype` 或 constraint 失败 diff --git a/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/tasks.md b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/tasks.md new file mode 100644 index 000000000..d36a96fd3 --- /dev/null +++ b/openspec/changes/archive/2026-04-16-improve-tilelang-dsl-kernel-selection-diagnostics/tasks.md @@ -0,0 +1,31 @@ +## 1. OpenSpec 契约落定 + +- [x] 1.1 完成 `specs/tilelang-dsl-kernel-matcher/spec.md` delta,固定 `select_kernel(...)` 的 opt-in report 模式、候选覆盖范围与 `final_status` 契约。 +- [x] 1.2 完成 `specs/tilelang-dsl-diagnostics/spec.md` delta,固定 `dtype` / constraint / materialization 失败的结构化诊断字段。 +- [x] 1.3 在 `proposal.md` 和 `design.md` 中明确本 change 不改变 matcher 选择顺序,只增强可观测性与调试输出。 + +## 2. Selection report 数据模型 + +- [x] 2.1 在 `tilelang-dsl/python/tilelang_dsl/kernel.py` 中引入公共或半公共的 selection report / candidate metadata 数据模型,并保持默认返回 `VKernelDescriptor` 的兼容路径。 +- [x] 2.2 把 `target/op`、`dtype`、`constraints`、`priority` 的内部求值拆成可记录阶段结果的 helper,而不是只做布尔筛选。 +- [x] 2.3 让 constraint 评估返回结构化结果,至少记录失败 constraint 索引、可调用名和异常摘要。 + +## 3. API 接线与 materialization 可见性 + +- [x] 3.1 为 `pto.select_kernel(...)` 接线 opt-in report 参数,统一填充 `selected`、`candidates`、`final_status` 与 `final_error`。 +- [x] 3.2 为通过 constraint 阶段的候选接线 MLIR 采集,成功时返回 `mlir_text`,失败时返回 `mlir_error`。 +- [x] 3.3 评估并更新 `tilelang-dsl/python/tilelang_dsl/__init__.py` 与 `expand_helper.py` 的导出/消费边界,确保新 report 类型可被稳定访问。 + +## 4. 回归、文档与验证 + +- [x] 4.1 在 `tilelang-dsl/tests/test_tilelang_dsl_v1.py` 中增加回归,覆盖默认兼容路径、`dtype_mismatch`、constraint false、constraint exception、priority tie 和 no-candidate report。 +- [x] 4.2 增加 MLIR 可见性回归,覆盖成功候选附带 `mlir_text` 与 materialization 失败候选附带 `mlir_error`。 +- [x] 4.3 更新 `tilelang-dsl/docs/` 中与 matcher 相关的用户文档,说明何时使用 report 模式以及如何读取失败原因。 +- [x] 4.4 运行并记录最小验证命令,至少覆盖相关 `unittest` 子集与 `openspec validate improve-tilelang-dsl-kernel-selection-diagnostics --type change --strict --json --no-interactive`。 + +### 验证记录 + +- `python3 -m py_compile tilelang-dsl/python/tilelang_dsl/kernel.py tilelang-dsl/python/tilelang_dsl/__init__.py tilelang-dsl/python/tilelang_dsl/expand_helper.py tilelang-dsl/tests/test_tilelang_dsl_v1.py` +- `PYTHONPATH=tilelang-dsl/python python3 -m unittest discover -s tilelang-dsl/tests -p 'test_tilelang_dsl_v1.py' -k 'select_kernel'` +- `PYTHONPATH=tilelang-dsl/python python3 -m unittest discover -s tilelang-dsl/tests -p 'test_tilelang_dsl_v1.py' -k 'select_kernel_report_mode'` +- `openspec validate improve-tilelang-dsl-kernel-selection-diagnostics --type change --strict --json --no-interactive` diff --git a/openspec/specs/tilelang-dsl-diagnostics/spec.md b/openspec/specs/tilelang-dsl-diagnostics/spec.md index 2d7d1988b..c741ab267 100644 --- a/openspec/specs/tilelang-dsl-diagnostics/spec.md +++ b/openspec/specs/tilelang-dsl-diagnostics/spec.md @@ -61,3 +61,34 @@ TileLang DSL v1 的 frontend diagnostics MUST 包含 DSL 源位置和语义原 - **WHEN** frontend 因 unsupported feature、unsupported syntax、type binding failure 或 specialization error 拒绝一个 kernel - **THEN** 诊断 MUST 至少包含 DSL 源文件位置、行列号或等价的 source span - **AND** MUST 明确指出失败原因属于哪一层 frontend 语义,而不是只给出底层 verifier 或 parser 的通用报错 + +### Requirement: selector diagnostics MUST identify the failing stage and failing constraint for each reported candidate + +当调用方启用 `select_kernel(...)` 的 report/metadata 模式时,TileLang DSL diagnostics MUST 为每个候选明确指出其失败或胜出的阶段。 +对于 constraint 阶段失败的候选,诊断 MUST 至少包含: + +- 失败的 constraint 索引 +- 若可解析则包含 callable 名称或 `qualname` +- 区分“predicate 返回 `False`”与“constraint 执行抛异常” +- 可读的失败原因文本 + +对于 `dtype` 不匹配、priority 落败和 materialization 失败,diagnostics 也 MUST 使用不同的 kind/status 表达,而不是统一折叠成同一种通用错误。 +selector diagnostics MUST 让调用方能仅凭 report 判定候选究竟挂在 `dtype`、`constraints`、`priority` 还是 materialization。 + +#### Scenario: false-returning constraint is reported with index and callable identity + +- **WHEN** 某个候选在 constraint evaluation 中命中第 `N` 个 constraint,且该 callable 返回 `False` +- **THEN** selection diagnostics MUST 报告该候选失败在第 `N` 个 constraint +- **AND** diagnostics MUST 在可解析时包含该 constraint 的 callable 名称或 `qualname` + +#### Scenario: raising constraint is reported as a constraint error + +- **WHEN** 某个 constraint 在执行时抛出异常 +- **THEN** selection diagnostics MUST 将该候选标记为 `constraint_error` +- **AND** diagnostics MUST 包含异常类型与消息摘要,而不是只给通用失败 + +#### Scenario: materialization failure remains distinguishable from matcher failure + +- **WHEN** 某个候选通过 `dtype` 与 `constraints`,但在 `mlir_text()` materialization 时失败 +- **THEN** selection diagnostics MUST 将该候选标记为 materialization 相关失败 +- **AND** MUST NOT 把该失败重新表述为 `dtype` mismatch 或 constraint failure diff --git a/openspec/specs/tilelang-dsl-kernel-matcher/spec.md b/openspec/specs/tilelang-dsl-kernel-matcher/spec.md index e024ae38a..b7dc7e68c 100644 --- a/openspec/specs/tilelang-dsl-kernel-matcher/spec.md +++ b/openspec/specs/tilelang-dsl-kernel-matcher/spec.md @@ -1,25 +1,32 @@ # tilelang-dsl-kernel-matcher Specification -## MODIFIED Requirements +## Purpose +Define how TileLang DSL registers kernel descriptors, matches candidates, and +selects a final kernel while preserving deterministic behavior and actionable +selection diagnostics. + +## Requirements ### Requirement: TileLang DSL MUST provide an explicit kernel registry and selection API 当同一 `target/op` 下存在多个 `@pto.vkernel` descriptor 时,TileLang DSL MUST 将它们注册到显式、可查询的 `KernelRegistry`。 默认 registry MUST 是 module-level 对象;调用方 MAY 传入自定义 registry 以获得隔离的候选集合。 -系统 MUST 提供显式 selection API `pto.select_kernel(target, op, operand_types, context_attrs, registry=None)`,用于在给定 `target`、concrete `op`、operand type 信息和上下文属性时选择唯一 kernel。 +系统 MUST 提供显式 selection API `pto.select_kernel(target, op, operand_types, context_attrs=None, registry=None, return_metadata=False, include_mlir=True)`,用于在给定 `target`、concrete `op`、operand type 信息和上下文属性时选择 kernel。 descriptor MUST 支持两种互斥的 matcher 元数据: - `op=""` - `ops=["", "", ...]` descriptor MUST 至少提供其中一种,且实现 MUST NOT 同时接受两者。 -当 selector 命中一个 `ops=[...]` descriptor 时,返回结果 MUST 绑定当前 query 对应的唯一 concrete `selected_op`,再进入后续 `specialize()` / `mlir_text()` / `verify()` 流程。 +当 selector 命中一个 `ops=[...]` descriptor 时,无论是默认路径还是 report 路径,结果都 MUST 绑定当前 query 对应的唯一 concrete `selected_op`,再进入后续 `specialize()` / `mlir_text()` / `verify()` 流程。 +当 `return_metadata=False` 时,`pto.select_kernel(...)` MUST 继续返回唯一选中的 descriptor,并保持现有异常行为兼容。 +当 `return_metadata=True` 时,`pto.select_kernel(...)` MUST 返回结构化 selection report;该 report MUST 暴露最终 winner、逐候选状态和最终决策摘要,而不是只暴露一个 descriptor 或通用失败字符串。 实现 MUST NOT 依赖扫描 Python globals、locals 或导入顺序来隐式发现候选。 -#### Scenario: selector returns the unique best kernel +#### Scenario: selector returns the unique best kernel in default mode - **WHEN** registry 中存在多个针对同一 `target/op` 的 kernel descriptor,且其中一个在全部匹配步骤后成为唯一最佳候选 -- **THEN** `pto.select_kernel(...)` MUST 返回该 descriptor +- **THEN** `pto.select_kernel(..., return_metadata=False)` MUST 返回该 descriptor - **AND** 返回结果 MUST 可继续走 `specialize()` / `mlir_text()` / `verify()` 流程 #### Scenario: custom registry restricts the candidate set explicitly @@ -31,9 +38,72 @@ descriptor MUST 至少提供其中一种,且实现 MUST NOT 同时接受两者 #### Scenario: selector binds the concrete op for a multi-op descriptor - **WHEN** 一个 descriptor 通过 `ops=["tadd", "tsub", "tmul", "tdiv"]` 注册,且调用方以 `pto.select_kernel(..., op="tmul", ...)` 查询命中该 descriptor -- **THEN** selector MUST 返回已经绑定 `selected_op="tmul"` 的 descriptor +- **THEN** selector MUST 返回已经绑定 `selected_op="tmul"` 的 descriptor,或在 report 模式下返回绑定了该 `selected_op` 的候选记录 - **AND** 后续 materialization MUST 基于该 concrete `selected_op` 而不是未绑定的原始 matcher 集合 +#### Scenario: selector can return a structured selection report + +- **WHEN** 调用方以 `pto.select_kernel(..., return_metadata=True)` 查询 kernel +- **THEN** selector MUST 返回包含 `selected`、`candidates`、`final_status` 和 `final_error` 的结构化 report +- **AND** 该 report MUST 保持与默认路径一致的 winner 决策语义 + +### Requirement: selection report MUST preserve per-candidate stage results for all target/op-matched descriptors + +当调用方启用 `return_metadata=True` 时,selector MUST 为所有通过 `target` 与 concrete `op` 过滤的 descriptor 生成逐候选 metadata。 +每个候选记录 MUST 至少包含: + +- kernel identity(如 `name`、`priority`、`match_ops`) +- 当前 query 下绑定后的 concrete `selected_op` +- 匹配到的 concrete dtype signature,或明确的 dtype mismatch 信息 +- 一个稳定的阶段状态,至少覆盖: + - `dtype_mismatch` + - `constraint_failed` + - `constraint_error` + - `priority_shadowed` + - `selected` + - `mlir_error` + +候选在 `dtype` 阶段失败时,selector MUST 保留该候选并显式标记为 `dtype_mismatch`,而不是直接从 report 中丢弃。 +若没有任何候选通过后续选择,顶层 report MUST 通过 `final_status` / `final_error` 明确表达 `no_candidate` 或 `priority_tie`,同时保留全部候选记录。 +report 模式 MUST NOT 改变 matcher 的既有选择顺序或 winner 决策结果。 + +#### Scenario: dtype-mismatched descriptor still appears in the report + +- **WHEN** 某个 descriptor 通过 `target/op` 过滤,但没有任何 `dtypes` signature 能匹配当前 operand types +- **THEN** selection report MUST 仍包含该 descriptor 的候选记录 +- **AND** 该候选 MUST 标记为 `dtype_mismatch` + +#### Scenario: report preserves no-candidate outcome without losing candidate context + +- **WHEN** 所有通过 `target/op` 的候选最终都在 `dtype`、`constraints` 或 materialization 阶段失败 +- **THEN** selection report MUST 将 `final_status` 标记为 `no_candidate` +- **AND** `candidates` 中 MUST 保留每个候选的失败阶段和失败原因 + +#### Scenario: report preserves priority tie outcome without dropping tied winners + +- **WHEN** 多个候选在 `target/op/dtype/constraints` 全部通过后拥有相同最高 `priority` +- **THEN** selection report MUST 将 `final_status` 标记为 `priority_tie` +- **AND** report MUST 保留所有 tie 候选,而不是隐式只保留其中一个 + +### Requirement: selection report MUST optionally include MLIR materialization results for successful candidates + +当 `return_metadata=True` 且 `include_mlir=True` 时,selector MUST 对所有通过 `constraints` 阶段、并进入 priority 决策的候选尝试 materialization。 +若 `mlir_text()` 成功,候选记录 MUST 暴露 `mlir_text`。 +若 materialization 因 specialization/context 不完整或其他 frontend 失败而未成功,候选记录 MUST 暴露 `mlir_error`,并继续保留其余阶段信息。 +当 `include_mlir=False` 时,selector MUST 允许调用方跳过该 materialization 尝试,但其余候选 metadata 仍 MUST 可用。 + +#### Scenario: successful candidate carries MLIR text in the report + +- **WHEN** 某个候选通过选择阶段并且 `mlir_text()` materialization 成功 +- **THEN** 该候选的 selection metadata MUST 包含 `mlir_text` +- **AND** 其阶段状态 MUST 继续反映它是 `selected` 或 `priority_shadowed` + +#### Scenario: materialization failure is captured without losing the candidate + +- **WHEN** 某个候选通过 `dtype` 与 `constraints`,但在 `mlir_text()` 时因为 specialization/context 不完整而失败 +- **THEN** selection report MUST 保留该候选 +- **AND** 该候选 MUST 包含 `mlir_error`,而不是被重新归类为 `dtype` 或 constraint 失败 + ### Requirement: matcher MUST support concrete types, `Any*`, and `TypeVar` across multiple signatures matcher MUST 支持: @@ -109,8 +179,6 @@ matcher MUST 支持: - **AND** 错误消息 MUST 指出发生 tie 的 kernel 集合 - **AND** MUST NOT 静默选择第一个已注册 kernel -## ADDED Requirements - ### Requirement: multi-op descriptors MUST require concrete op binding before IR materialization 当 descriptor 使用 `ops=[...]` 覆盖多个 concrete PTO op 时,系统 MUST 在 materialization 前先绑定唯一 `selected_op`。 diff --git a/test/tilelang_st/npu/a5/src/st/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt index 4aec35d12..a3e167a84 100644 --- a/test/tilelang_st/npu/a5/src/st/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt @@ -13,6 +13,13 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +# CMake 3.27+ may ask the linker to emit dependency files via +# `--dependency-file`. bisheng/cce-ld does not support that flag, so disable +# linker-generated link dependencies for this standalone ST build. +if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.27) + set(CMAKE_LINK_DEPENDS_USE_LINKER FALSE) +endif() + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) diff --git a/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md b/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md index 4117cf1b2..56a249c97 100644 --- a/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md +++ b/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md @@ -114,6 +114,33 @@ Matcher rules in the implemented package: - when a multi-op descriptor matches, the returned descriptor is already bound to one concrete `selected_op` +Matcher diagnostics are also available through the opt-in report path: + +```python +report = pto.select_kernel( + "a5", + "eltwise", + (pto.f32, pto.f32), + context_attrs={"enabled": False}, + return_metadata=True, + include_mlir=False, +) +``` + +In report mode: + +- `report.final_status` summarizes the overall outcome +- `report.candidates` keeps one record per `target/op`-matched descriptor +- constraint failures expose `failed_constraint_index`, + `failed_constraint_name`, and `failed_constraint_location` +- `include_mlir=True` additionally collects `mlir_text` or `mlir_error` for + candidates that pass constraint evaluation + +For clearer diagnostics, prefer writing multiple small constraint entries over a +single compound Python predicate. Report mode can identify which constraint +callable failed, but it does not decompose `cond0 and cond1` inside one +callable. + For explicit single-op kernels that already map 1:1 to one real PTO op, you do not need to migrate anything. Keep `op="..."` and keep authoring explicit real `pto.*` calls in the kernel body. diff --git a/tilelang-dsl/docs/user_guide/03-kernel-declaration.md b/tilelang-dsl/docs/user_guide/03-kernel-declaration.md index 348bf170c..349e2e0d4 100644 --- a/tilelang-dsl/docs/user_guide/03-kernel-declaration.md +++ b/tilelang-dsl/docs/user_guide/03-kernel-declaration.md @@ -241,6 +241,50 @@ selected = pto.select_kernel( ) ``` +`pto.select_kernel(...)` also supports an opt-in diagnostics path for matcher debugging: + +```python +report = pto.select_kernel( + "a5", + "matmul", + (pto.f16, pto.f16, pto.f32), + context_attrs={"k_aligned": False}, + return_metadata=True, + include_mlir=False, +) +``` + +When `return_metadata=True`, the result is a `KernelSelectionReport` instead of one +selected descriptor. + +- `report.selected` carries the winner when one candidate is selected. +- `report.final_status` is one of `selected`, `no_candidate`, or `priority_tie`. +- `report.final_error` summarizes the final selection outcome. +- `report.candidates` contains one `KernelSelectionCandidateMetadata` per + `target/op`-matched descriptor, including `dtype_mismatch`, + `constraint_failed`, `constraint_error`, `priority_shadowed`, `selected`, and + `priority_tie` states. + +Constraint diagnostics in report mode include: + +- `failed_constraint_index` +- `failed_constraint_name` +- `failed_constraint_location` as `file:line` + +For best diagnostics, prefer splitting compound predicates into multiple +constraint entries instead of writing one large `cond0 and cond1 and cond2` +callable. Report mode can precisely identify which constraint entry failed, but +it does not introspect which sub-expression inside one Python boolean +expression returned `False`. + +When `include_mlir=True`, report mode also attempts `mlir_text()` for candidates +that pass constraint evaluation. + +- On success, the candidate carries `mlir_text`. +- On materialization failure such as missing `specialize()` bindings, the + candidate carries `mlir_error`. +- Use `include_mlir=False` to skip this extra materialization attempt. + #### Examples ##### Matmul with Multiple Implementations diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index d3f159845..2aedb8083 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -12,6 +12,8 @@ BoundKernelParameter, InlineProcDescriptor, KernelRegistry, + KernelSelectionCandidateMetadata, + KernelSelectionReport, MaterializedMLIRModule, TileLangFrontendError, VKernelDescriptor, @@ -89,6 +91,8 @@ "BoundKernelParameter", "InlineProcDescriptor", "KernelRegistry", + "KernelSelectionCandidateMetadata", + "KernelSelectionReport", "MaterializedMLIRModule", "TileLangFrontendError", "VKernelDescriptor", diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index d2c776126..a31af37db 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -291,6 +291,7 @@ def _select_descriptor( operand_types, context_attrs=_build_positional_context_attrs(operand_specs), registry=registry, + return_metadata=False, ) diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index eb1b086d7..e239bded4 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -1084,7 +1084,9 @@ def _validate_materialization_constraints(self, api_name: str) -> None: if not self.constraints: return context_attrs = self._constraint_context_for_evaluation() - if _evaluate_constraints(self, context_attrs): + evaluation = _evaluate_constraints(self, context_attrs) + _raise_constraint_evaluation_error(evaluation) + if evaluation.passed: return raise LookupError( f"{api_name}() constraint evaluation rejected kernel {self.name!r} " @@ -1122,6 +1124,124 @@ def emit(self, path: str | Path) -> None: output_path.write_text(self.mlir_text(), encoding="utf-8") +@dataclass(frozen=True) +class KernelSelectionCandidateMetadata: + """Structured selection diagnostics for one target/op-matched kernel candidate.""" + + descriptor: VKernelDescriptor + status: str + selected_op: str | None = None + matched_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None + reason: str | None = None + failed_constraint_index: int | None = None + failed_constraint_name: str | None = None + failed_constraint_location: str | None = None + error_type: str | None = None + error_message: str | None = None + mlir_text: str | None = None + mlir_error: str | None = None + + @property + def name(self) -> str: + return self.descriptor.name + + @property + def priority(self) -> int: + return self.descriptor.priority + + @property + def match_ops(self) -> tuple[str, ...]: + return self.descriptor.match_ops + + @property + def dtype_signatures(self) -> tuple[tuple[Any, ...], ...]: + return self.descriptor.dtypes + + +@dataclass(frozen=True) +class KernelSelectionReport: + """Structured selector result returned by the opt-in metadata path.""" + + target: str + op: str + operand_types: tuple[ScalarType | MaskType, ...] + selected: VKernelDescriptor | None + candidates: tuple[KernelSelectionCandidateMetadata, ...] = () + final_status: str = "no_candidate" + final_error: str | None = None + _context_attrs: tuple[tuple[str, Any], ...] = field(default=(), repr=False) + + @property + def context_attrs(self) -> dict[str, Any]: + return dict(self._context_attrs) + + @property + def ok(self) -> bool: + return self.final_status == "selected" and self.selected is not None + + +@dataclass(frozen=True) +class _TargetOpSelectionCandidate: + descriptor: VKernelDescriptor + + +@dataclass(frozen=True) +class _DtypeSelectionCandidate: + descriptor: VKernelDescriptor + matched_descriptor: VKernelDescriptor | None = None + matched_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None + + @property + def matched(self) -> bool: + return self.matched_descriptor is not None + + +@dataclass(frozen=True) +class _ConstraintSelectionCandidate: + descriptor: VKernelDescriptor + passed: bool + evaluation: "_ConstraintEvaluationResult" + bound_descriptor: VKernelDescriptor | None = None + + +@dataclass(frozen=True) +class _PrioritySelectionResult: + candidates: tuple[VKernelDescriptor, ...] + highest_priority: int | None + winners: tuple[VKernelDescriptor, ...] + + @property + def has_tie(self) -> bool: + return len(self.winners) > 1 + + @property + def winner(self) -> VKernelDescriptor | None: + if len(self.winners) != 1: + return None + return self.winners[0] + + +@dataclass(frozen=True) +class _MaterializationSelectionCandidate: + descriptor: VKernelDescriptor + mlir_text: str | None = None + mlir_error: str | None = None + + +@dataclass(frozen=True) +class _ConstraintEvaluationResult: + passed: bool + failed_constraint_index: int | None = None + failed_constraint_name: str | None = None + failed_constraint_location: str | None = None + error_type: str | None = None + error_message: str | None = None + + @property + def raised_error(self) -> bool: + return self.error_type is not None + + class KernelRegistry: """Explicit registry for TileLang kernel descriptors.""" @@ -1950,7 +2070,7 @@ def _build_descriptor( def _evaluate_constraints( descriptor: VKernelDescriptor, context_attrs: Mapping[str, Any], -) -> bool: +) -> _ConstraintEvaluationResult: named_context: dict[str, Any] = { "target": context_attrs.get("target"), "op": context_attrs.get("op"), @@ -1963,6 +2083,8 @@ def _evaluate_constraints( named_context[spec.name] = _ConstraintParamView(spec.name, param_attrs) for index, constraint in enumerate(descriptor.constraints): + constraint_name = _constraint_callable_name(constraint) + constraint_location = _constraint_callable_location(constraint) try: signature = inspect.signature(constraint) parameters = list(signature.parameters.values()) @@ -1990,12 +2112,61 @@ def _evaluate_constraints( ) result = constraint(**kwargs) except Exception as exc: - raise TypeError( - f"constraint {index} for kernel {descriptor.name!r} raised {type(exc).__name__}: {exc}" - ) from exc + return _ConstraintEvaluationResult( + passed=False, + failed_constraint_index=index, + failed_constraint_name=constraint_name, + failed_constraint_location=constraint_location, + error_type=type(exc).__name__, + error_message=( + f"constraint {index} for kernel {descriptor.name!r} " + f"raised {type(exc).__name__}: {exc}" + f"{_format_constraint_location_suffix(constraint_location)}" + ), + ) if not result: - return False - return True + return _ConstraintEvaluationResult( + passed=False, + failed_constraint_index=index, + failed_constraint_name=constraint_name, + failed_constraint_location=constraint_location, + error_message=( + f"constraint {index} for kernel {descriptor.name!r} returned False" + f"{_format_constraint_location_suffix(constraint_location)}" + ), + ) + return _ConstraintEvaluationResult(passed=True) + + +def _constraint_callable_name(constraint: Callable[..., Any]) -> str | None: + qualname = getattr(constraint, "__qualname__", None) + if isinstance(qualname, str) and qualname: + return qualname + name = getattr(constraint, "__name__", None) + if isinstance(name, str) and name: + return name + return None + + +def _constraint_callable_location(constraint: Callable[..., Any]) -> str | None: + code = getattr(constraint, "__code__", None) + filename = getattr(code, "co_filename", None) + firstlineno = getattr(code, "co_firstlineno", None) + if isinstance(filename, str) and filename and isinstance(firstlineno, int) and firstlineno > 0: + return f"{filename}:{firstlineno}" + return None + + +def _format_constraint_location_suffix(location: str | None) -> str: + if location is None: + return "" + return f" at {location}" + + +def _raise_constraint_evaluation_error(result: _ConstraintEvaluationResult) -> None: + if not result.raised_error or result.error_message is None: + return + raise TypeError(result.error_message) def _format_descriptor_identity(descriptor: VKernelDescriptor) -> str: @@ -2005,25 +2176,323 @@ def _format_descriptor_identity(descriptor: VKernelDescriptor) -> str: return f"{descriptor.name}(priority={descriptor.priority}, dtypes={dtype_signature!r})" -def _match_descriptor_query( +def _bind_descriptor_for_target_op( descriptor: VKernelDescriptor, *, target: str, op: str, - operand_types: tuple[ScalarType | MaskType, ...], ) -> VKernelDescriptor | None: if descriptor.target != target: return None if op not in descriptor.match_ops: return None + return descriptor._bind_selected_op(op) - op_bound_descriptor = descriptor._bind_selected_op(op) - matched_signature = _match_descriptor_dtype_signature(op_bound_descriptor, operand_types) + +def _collect_target_op_candidates( + registry: KernelRegistry, + *, + target: str, + op: str, +) -> tuple[_TargetOpSelectionCandidate, ...]: + candidates: list[_TargetOpSelectionCandidate] = [] + for descriptor in registry: + op_bound_descriptor = _bind_descriptor_for_target_op( + descriptor, + target=target, + op=op, + ) + if op_bound_descriptor is None: + continue + candidates.append(_TargetOpSelectionCandidate(descriptor=op_bound_descriptor)) + return tuple(candidates) + + +def _evaluate_dtype_candidate( + candidate: _TargetOpSelectionCandidate, + *, + operand_types: tuple[ScalarType | MaskType, ...], +) -> _DtypeSelectionCandidate: + matched_signature = _match_descriptor_dtype_signature(candidate.descriptor, operand_types) if matched_signature is None: + return _DtypeSelectionCandidate(descriptor=candidate.descriptor) + if candidate.descriptor._selected_dtype_signature == matched_signature: + return _DtypeSelectionCandidate( + descriptor=candidate.descriptor, + matched_descriptor=candidate.descriptor, + matched_dtype_signature=matched_signature, + ) + return _DtypeSelectionCandidate( + descriptor=candidate.descriptor, + matched_descriptor=candidate.descriptor._bind_selected_dtype_signature(matched_signature), + matched_dtype_signature=matched_signature, + ) + + +def _evaluate_dtype_candidates( + candidates: tuple[_TargetOpSelectionCandidate, ...], + *, + operand_types: tuple[ScalarType | MaskType, ...], +) -> tuple[_DtypeSelectionCandidate, ...]: + return tuple( + _evaluate_dtype_candidate( + candidate, + operand_types=operand_types, + ) + for candidate in candidates + ) + + +def _match_descriptor_query( + descriptor: VKernelDescriptor, + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], +) -> VKernelDescriptor | None: + op_bound_descriptor = _bind_descriptor_for_target_op( + descriptor, + target=target, + op=op, + ) + if op_bound_descriptor is None: return None - if op_bound_descriptor._selected_dtype_signature == matched_signature: - return op_bound_descriptor - return op_bound_descriptor._bind_selected_dtype_signature(matched_signature) + dtype_result = _evaluate_dtype_candidate( + _TargetOpSelectionCandidate(descriptor=op_bound_descriptor), + operand_types=operand_types, + ) + return dtype_result.matched_descriptor + + +def _evaluate_constraint_candidate( + descriptor: VKernelDescriptor, + *, + context_attrs: Mapping[str, Any], +) -> _ConstraintSelectionCandidate: + evaluation = _evaluate_constraints( + descriptor, + descriptor._constraint_context_for_evaluation(context_attrs), + ) + if not evaluation.passed: + return _ConstraintSelectionCandidate( + descriptor=descriptor, + passed=False, + evaluation=evaluation, + ) + return _ConstraintSelectionCandidate( + descriptor=descriptor, + passed=True, + evaluation=evaluation, + bound_descriptor=descriptor._bind_constraint_context_attrs(context_attrs), + ) + + +def _evaluate_constraint_candidates( + descriptors: tuple[VKernelDescriptor, ...], + *, + context_attrs: Mapping[str, Any], +) -> tuple[_ConstraintSelectionCandidate, ...]: + return tuple( + _evaluate_constraint_candidate( + descriptor, + context_attrs=context_attrs, + ) + for descriptor in descriptors + ) + + +def _resolve_priority_candidates( + descriptors: tuple[VKernelDescriptor, ...], +) -> _PrioritySelectionResult: + if not descriptors: + return _PrioritySelectionResult( + candidates=(), + highest_priority=None, + winners=(), + ) + highest_priority = max(descriptor.priority for descriptor in descriptors) + winners = tuple( + descriptor + for descriptor in descriptors + if descriptor.priority == highest_priority + ) + return _PrioritySelectionResult( + candidates=descriptors, + highest_priority=highest_priority, + winners=winners, + ) + + +def _materialize_selection_candidate( + descriptor: VKernelDescriptor, +) -> _MaterializationSelectionCandidate: + try: + return _MaterializationSelectionCandidate( + descriptor=descriptor, + mlir_text=descriptor.mlir_text(), + ) + except Exception as exc: + return _MaterializationSelectionCandidate( + descriptor=descriptor, + mlir_error=str(exc), + ) + + +def _collect_materialization_candidates( + descriptors: tuple[VKernelDescriptor, ...], +) -> tuple[_MaterializationSelectionCandidate, ...]: + return tuple( + _materialize_selection_candidate(descriptor) + for descriptor in descriptors + ) + + +def _select_kernel_no_candidate_error( + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], +) -> str: + return ( + "select_kernel() found no registered kernel for " + f"target={target!r}, op={op!r}, operand_types={operand_types!r}" + ) + + +def _select_kernel_constraint_error( + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], +) -> str: + return ( + "select_kernel() found no registered kernel after constraint evaluation for " + f"target={target!r}, op={op!r}, operand_types={operand_types!r}" + ) + + +def _select_kernel_priority_tie_error( + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], + winners: tuple[VKernelDescriptor, ...], +) -> str: + winner_set = ", ".join(sorted(_format_descriptor_identity(descriptor) for descriptor in winners)) + return ( + "select_kernel() found multiple highest-priority kernels for " + f"target={target!r}, op={op!r}, operand_types={operand_types!r}: " + f"{winner_set}" + ) + + +def _build_selection_report( + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], + context_attrs: Mapping[str, Any], + dtype_results: tuple[_DtypeSelectionCandidate, ...], + constraint_results: tuple[_ConstraintSelectionCandidate, ...], + materialization_results: tuple[_MaterializationSelectionCandidate, ...], + priority_result: _PrioritySelectionResult, + final_status: str, + final_error: str | None, +) -> KernelSelectionReport: + constraint_by_descriptor_id = { + id(result.descriptor): result + for result in constraint_results + } + materialization_by_descriptor_id = { + id(result.descriptor): result + for result in materialization_results + } + winner_ids = {id(descriptor) for descriptor in priority_result.winners} + highest_priority = priority_result.highest_priority + candidates: list[KernelSelectionCandidateMetadata] = [] + + for dtype_result in dtype_results: + if dtype_result.matched_descriptor is None: + candidates.append( + KernelSelectionCandidateMetadata( + descriptor=dtype_result.descriptor, + status="dtype_mismatch", + selected_op=dtype_result.descriptor.selected_op, + reason=( + "no dtype signature matched " + f"operand_types={operand_types!r}" + ), + ) + ) + continue + + constraint_result = constraint_by_descriptor_id.get(id(dtype_result.matched_descriptor)) + if constraint_result is None: + continue + evaluation = constraint_result.evaluation + candidate_descriptor = constraint_result.bound_descriptor or dtype_result.matched_descriptor + materialization_result = materialization_by_descriptor_id.get(id(candidate_descriptor)) + base_kwargs = { + "descriptor": candidate_descriptor, + "selected_op": candidate_descriptor.selected_op, + "matched_dtype_signature": dtype_result.matched_dtype_signature, + "failed_constraint_index": evaluation.failed_constraint_index, + "failed_constraint_name": evaluation.failed_constraint_name, + "failed_constraint_location": evaluation.failed_constraint_location, + "error_type": evaluation.error_type, + "error_message": evaluation.error_message, + "mlir_text": None if materialization_result is None else materialization_result.mlir_text, + "mlir_error": None if materialization_result is None else materialization_result.mlir_error, + } + + if evaluation.raised_error: + candidates.append( + KernelSelectionCandidateMetadata( + status="constraint_error", + reason=evaluation.error_message, + **base_kwargs, + ) + ) + continue + if not evaluation.passed: + candidates.append( + KernelSelectionCandidateMetadata( + status="constraint_failed", + reason=evaluation.error_message, + **base_kwargs, + ) + ) + continue + if id(candidate_descriptor) in winner_ids: + status = "selected" if final_status == "selected" else "priority_tie" + reason = None if status == "selected" else final_error + else: + status = "priority_shadowed" + if highest_priority is None: + reason = "not selected" + else: + reason = f"shadowed by higher-priority candidate priority={highest_priority}" + candidates.append( + KernelSelectionCandidateMetadata( + status=status, + reason=reason, + **base_kwargs, + ) + ) + + frozen_context_attrs = tuple( + sorted(dict(context_attrs).items(), key=lambda item: item[0]) + ) + return KernelSelectionReport( + target=target, + op=op, + operand_types=operand_types, + selected=priority_result.winner if final_status == "selected" else None, + candidates=tuple(candidates), + final_status=final_status, + final_error=final_error, + _context_attrs=frozen_context_attrs, + ) def select_kernel( @@ -2032,7 +2501,10 @@ def select_kernel( operand_types: Any, context_attrs: Mapping[str, Any] | None = None, registry: KernelRegistry | None = None, -) -> VKernelDescriptor: + *, + return_metadata: bool = False, + include_mlir: bool = True, +) -> VKernelDescriptor | KernelSelectionReport: """Select one registered kernel descriptor for the given query.""" normalized_target = _validate_target(target) @@ -2049,55 +2521,120 @@ def select_kernel( active_registry = _DEFAULT_KERNEL_REGISTRY if registry is None else registry if not isinstance(active_registry, KernelRegistry): raise TypeError("registry must be a KernelRegistry or None") + if not isinstance(return_metadata, bool): + raise TypeError("return_metadata must be a bool") + if not isinstance(include_mlir, bool): + raise TypeError("include_mlir must be a bool") + + target_op_candidates = _collect_target_op_candidates( + active_registry, + target=normalized_target, + op=normalized_op, + ) + dtype_results = _evaluate_dtype_candidates( + target_op_candidates, + operand_types=normalized_operand_types, + ) + type_matched_candidates = tuple( + result.matched_descriptor + for result in dtype_results + if result.matched_descriptor is not None + ) - type_matched_candidates = [ - matched_descriptor - for descriptor in active_registry - for matched_descriptor in ( - _match_descriptor_query( - descriptor, + if not type_matched_candidates: + no_candidate_error = _select_kernel_no_candidate_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + ) + if return_metadata: + return _build_selection_report( target=normalized_target, op=normalized_op, operand_types=normalized_operand_types, - ), - ) - if matched_descriptor is not None - ] + context_attrs=normalized_context_attrs, + dtype_results=dtype_results, + constraint_results=(), + materialization_results=(), + priority_result=_PrioritySelectionResult(candidates=(), highest_priority=None, winners=()), + final_status="no_candidate", + final_error=no_candidate_error, + ) + raise LookupError(no_candidate_error) - if not type_matched_candidates: - raise LookupError( - "select_kernel() found no registered kernel for " - f"target={normalized_target!r}, op={normalized_op!r}, operand_types={normalized_operand_types!r}" + constraint_results = _evaluate_constraint_candidates( + type_matched_candidates, + context_attrs=normalized_context_attrs, + ) + constrained_candidates = tuple( + result.bound_descriptor + for result in constraint_results + if result.bound_descriptor is not None + ) + if return_metadata: + priority_result = _resolve_priority_candidates(constrained_candidates) + materialization_results = ( + _collect_materialization_candidates(constrained_candidates) + if include_mlir + else () ) - - constrained_candidates = [ - descriptor - for descriptor in type_matched_candidates - if _evaluate_constraints( - descriptor, - descriptor._constraint_context_for_evaluation(normalized_context_attrs), + final_status = "selected" + final_error: str | None = None + if not constrained_candidates: + final_status = "no_candidate" + error_messages = [ + result.evaluation.error_message + for result in constraint_results + if result.evaluation.error_message is not None + ] + final_error = error_messages[0] if error_messages else _select_kernel_constraint_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + ) + elif priority_result.has_tie: + final_status = "priority_tie" + final_error = _select_kernel_priority_tie_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + winners=priority_result.winners, + ) + return _build_selection_report( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + context_attrs=normalized_context_attrs, + dtype_results=dtype_results, + constraint_results=constraint_results, + materialization_results=materialization_results, + priority_result=priority_result, + final_status=final_status, + final_error=final_error, ) - ] + for result in constraint_results: + _raise_constraint_evaluation_error(result.evaluation) if not constrained_candidates: raise LookupError( - "select_kernel() found no registered kernel after constraint evaluation for " - f"target={normalized_target!r}, op={normalized_op!r}, operand_types={normalized_operand_types!r}" + _select_kernel_constraint_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + ) ) - highest_priority = max(descriptor.priority for descriptor in constrained_candidates) - winners = [ - descriptor - for descriptor in constrained_candidates - if descriptor.priority == highest_priority - ] - if len(winners) > 1: - winner_set = ", ".join(sorted(_format_descriptor_identity(descriptor) for descriptor in winners)) + priority_result = _resolve_priority_candidates(constrained_candidates) + if priority_result.has_tie: raise LookupError( - "select_kernel() found multiple highest-priority kernels for " - f"target={normalized_target!r}, op={normalized_op!r}, operand_types={normalized_operand_types!r}: " - f"{winner_set}" + _select_kernel_priority_tie_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + winners=priority_result.winners, + ) ) - return winners[0]._bind_constraint_context_attrs(normalized_context_attrs) + assert priority_result.winner is not None + return priority_result.winner def vkernel( @@ -2146,6 +2683,8 @@ def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: "BoundKernelParameter", "InlineProcDescriptor", "KernelRegistry", + "KernelSelectionCandidateMetadata", + "KernelSelectionReport", "MaterializedMLIRModule", "TileLangFrontendError", "VKernelDescriptor", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 37f0e1512..c70a45e68 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -768,6 +768,246 @@ def kernel(inp: pto.TensorView, out: pto.TensorView): ) self.assertIn("after constraint evaluation", str(ctx.exception)) + def test_select_kernel_report_mode_keeps_default_descriptor_path_compatible(self) -> None: + @pto.vkernel(op="matcher_report_default_compat_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_report_default_compat_unique", + (pto.f32, pto.f32), + return_metadata=False, + include_mlir=False, + ) + + self.assertIsInstance(selected, pto.VKernelDescriptor) + self.assertIs(selected.py_fn, kernel.py_fn) + self.assertEqual(selected.dtype_signature, (pto.f32, pto.f32)) + + def test_select_kernel_report_mode_records_dtype_mismatch_candidates(self) -> None: + @pto.vkernel( + op="matcher_report_dtype_mismatch_unique", + dtypes=[(pto.f32, pto.f32)], + priority=5, + ) + def mismatch(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_report_dtype_mismatch_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + priority=10, + ) + def fallback(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_dtype_mismatch_unique", + (pto.bf16, pto.bf16), + return_metadata=True, + include_mlir=False, + ) + + self.assertIsInstance(report, pto.KernelSelectionReport) + self.assertEqual(report.final_status, "selected") + self.assertIsNotNone(report.selected) + assert report.selected is not None + self.assertEqual(report.selected.py_fn, fallback.py_fn) + self.assertEqual( + [(candidate.name, candidate.status) for candidate in report.candidates], + [("mismatch", "dtype_mismatch"), ("fallback", "selected")], + ) + + def test_select_kernel_report_mode_records_constraint_failure_candidates(self) -> None: + constrained_check = lambda enabled=False: enabled + + @pto.vkernel( + op="matcher_report_constraint_failure_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + constraints=[constrained_check], + priority=20, + ) + def constrained(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_report_constraint_failure_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + priority=5, + ) + def fallback(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_constraint_failure_unique", + (pto.f32, pto.f32), + context_attrs={"enabled": False}, + return_metadata=True, + include_mlir=False, + ) + + self.assertEqual(report.final_status, "selected") + self.assertIsNotNone(report.selected) + assert report.selected is not None + self.assertEqual(report.selected.py_fn, fallback.py_fn) + expected_location = ( + f"{constrained_check.__code__.co_filename}:{constrained_check.__code__.co_firstlineno}" + ) + self.assertEqual( + [ + ( + candidate.name, + candidate.status, + candidate.failed_constraint_index, + candidate.failed_constraint_location, + ) + for candidate in report.candidates + ], + [ + ("constrained", "constraint_failed", 0, expected_location), + ("fallback", "selected", None, None), + ], + ) + self.assertIn(expected_location, report.candidates[0].reason) + + def test_select_kernel_report_mode_records_constraint_exceptions(self) -> None: + bad_constraint = lambda missing: missing + + @pto.vkernel( + op="matcher_report_constraint_exception_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[bad_constraint], + ) + def bad(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_constraint_exception_unique", + (pto.f32, pto.f32), + return_metadata=True, + include_mlir=False, + ) + + self.assertEqual(report.final_status, "no_candidate") + self.assertIsNone(report.selected) + self.assertIn("requires unsupported parameter", report.final_error) + self.assertEqual(len(report.candidates), 1) + expected_location = ( + f"{bad_constraint.__code__.co_filename}:{bad_constraint.__code__.co_firstlineno}" + ) + candidate = report.candidates[0] + self.assertEqual(candidate.name, "bad") + self.assertEqual(candidate.status, "constraint_error") + self.assertEqual(candidate.failed_constraint_index, 0) + self.assertEqual(candidate.failed_constraint_location, expected_location) + self.assertEqual(candidate.error_type, "TypeError") + self.assertIn("requires unsupported parameter", candidate.error_message) + self.assertIn(expected_location, candidate.error_message) + + def test_select_kernel_report_mode_reports_priority_ties(self) -> None: + @pto.vkernel( + op="matcher_report_priority_tie_unique", + dtypes=[(pto.f32, pto.f32)], + priority=33, + ) + def lhs(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_report_priority_tie_unique", + dtypes=[(pto.f32, pto.f32)], + priority=33, + ) + def rhs(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_priority_tie_unique", + (pto.f32, pto.f32), + return_metadata=True, + include_mlir=False, + ) + + self.assertEqual(report.final_status, "priority_tie") + self.assertIsNone(report.selected) + self.assertIn("multiple highest-priority kernels", report.final_error) + self.assertEqual( + [(candidate.name, candidate.status) for candidate in report.candidates], + [("lhs", "priority_tie"), ("rhs", "priority_tie")], + ) + + def test_select_kernel_report_mode_reports_no_candidate_without_candidates(self) -> None: + empty_registry = pto.KernelRegistry() + + report = pto.select_kernel( + "a5", + "matcher_report_empty_registry_unique", + (pto.f32,), + registry=empty_registry, + return_metadata=True, + include_mlir=False, + ) + + self.assertEqual(report.final_status, "no_candidate") + self.assertIsNone(report.selected) + self.assertEqual(report.candidates, ()) + self.assertIn("found no registered kernel", report.final_error) + + def test_select_kernel_report_mode_includes_mlir_text_for_materializable_candidate(self) -> None: + @pto.vkernel( + op="matcher_report_mlir_text_unique", + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_mlir_text_unique", + (pto.f32, pto.f32), + return_metadata=True, + include_mlir=True, + ) + + self.assertEqual(report.final_status, "selected") + self.assertEqual(len(report.candidates), 1) + candidate = report.candidates[0] + self.assertEqual(candidate.status, "selected") + self.assertIsNotNone(candidate.mlir_text) + self.assertIsNone(candidate.mlir_error) + self.assertIn("module attributes", candidate.mlir_text) + self.assertIn("@kernel", candidate.mlir_text) + self.assertIn("!pto.tensor_view", candidate.mlir_text) + + def test_select_kernel_report_mode_includes_mlir_error_for_unspecialized_tile_candidate(self) -> None: + @pto.vkernel( + op="matcher_report_mlir_error_unique", + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.Tile): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_mlir_error_unique", + (pto.f32, pto.f32), + return_metadata=True, + include_mlir=True, + ) + + self.assertEqual(report.final_status, "selected") + self.assertEqual(len(report.candidates), 1) + candidate = report.candidates[0] + self.assertEqual(candidate.status, "selected") + self.assertIsNone(candidate.mlir_text) + self.assertIsNotNone(candidate.mlir_error) + self.assertIn("requires specialize() bindings for bare Tile parameters", candidate.mlir_error) + def test_materialization_constraints_can_see_specializations_and_selected_context_attrs(self) -> None: @pto.vkernel( op="matcher_materialization_constraint_unique", From e927a6d731c194359983bb76893619e72b7cdcc0 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Thu, 16 Apr 2026 17:27:04 +0800 Subject: [PATCH 089/192] specify dist of vsts/vlds --- docs/isa/03-vector-load-store.md | 75 +++---- docs/isa/10-reduction-ops.md | 2 +- docs/isa/13-dsa-sfu-ops.md | 2 +- docs/vpto-spec.md | 12 +- lib/PTO/IR/VPTO.cpp | 187 ++++++++++++------ lib/PTO/Transforms/PTOToVPTOLowering.cpp | 4 +- lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 9 +- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 128 ++++++------ .../kernels/online-softmax-update/kernel.pto | 10 +- 9 files changed, 252 insertions(+), 177 deletions(-) diff --git a/docs/isa/03-vector-load-store.md b/docs/isa/03-vector-load-store.md index 870b991bc..7b9abc634 100644 --- a/docs/isa/03-vector-load-store.md +++ b/docs/isa/03-vector-load-store.md @@ -36,7 +36,7 @@ Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Ve |-------------|-------------|------------------------------| | `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | | `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | -| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM_B8` / `NORM_B16` / `NORM_B32` | **9** | | `RV_VGATHER2` | `Dtype: B32` | **27–28** | | `RV_VGATHERB` | indexed byte gather | **~21** | | `RV_VSCATTER` | `Dtype: B16` | **~17** | @@ -44,17 +44,17 @@ Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Ve ### `dist:` tokens (issue→retire) -Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_B8` / `INTLV_B16` / `INTLV_B32`** on **`RV_VSTI`** are **12** cycles. | `dist:` (as in log) | RV op | issue→retire (cycles) | |---------------------|-------|----------------------| -| `DINTLV` | `RV_VLDI` | **9** | -| `BRC` | `RV_VLD` | **9** | +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | `RV_VLDI` | **9** | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | `RV_VLD` | **9** | | `BRC_BLK` | `RV_VLD` | **9** | -| `INTLV` | `RV_VSTI` | **12** | -| `UNPK` | `RV_VLD` | **9** | -| `NORM` | `RV_VSTI` | **9** | -| `PK` | `RV_VSTI` | **9** | +| `INTLV_B8` / `INTLV_B16` / `INTLV_B32` | `RV_VSTI` | **12** | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | `RV_VLD` | **9** | +| `NORM_B8` / `NORM_B16` / `NORM_B32` | `RV_VSTI` | **9** | +| `PK_B16` / `PK_B32` / `PK_B64` / `PK4_B32` | `RV_VSTI` | **9** | | `NORMAL` / `NORAML` | `RV_VLD` | **9** | **Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). @@ -72,21 +72,21 @@ Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VS | PTO `dist` (load) | Latency | |-------------------|-------------------| | `NORM` | **9** cycles | -| `UNPK` | **9** cycles | -| `DINTLV` | **9** cycles (`RV_VLDI`) | -| `BRC` | **9** cycles (`RV_VLD`) | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | **9** cycles | +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | **9** cycles (`RV_VLDI`) | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | **9** cycles (`RV_VLD`) | | `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | | `BDINTLV` | **9** cycles | -| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | +| `US_B8` / `US_B16`, `DS_B8` / `DS_B16`, `SPLT4CHN`, `SPLT2CHN_B8` / `SPLT2CHN_B16` | **9** cycles | ### PTO `dist` summary (stores) | PTO `dist` (store) | Latency | |--------------------|-------------------| -| `NORM` | **9** cycles (`RV_VSTI`) | -| `PK` | **9** cycles | -| `INTLV` (`pto.vstx2`) | **12** cycles | -| `MRG4CHN`, `MRG2CHN` | **9** cycles (surface retained; current A5 hardware still reports them unsupported at validation time) | +| `NORM_B8` / `NORM_B16` / `NORM_B32` | **9** cycles (`RV_VSTI`) | +| `PK_B16` / `PK_B32` / `PK_B64` / `PK4_B32` | **9** cycles | +| `INTLV_B8` / `INTLV_B16` / `INTLV_B32` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN_B8`, `MRG2CHN_B8`, `MRG2CHN_B16` | **9** cycles (surface retained; current A5 hardware still reports them unsupported at validation time) | ### Gather, scatter, and special addressing @@ -131,20 +131,21 @@ DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD | Family | Allowed element widths | C semantics | Latency | |------|-------------|-------------|-------------| | `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | -| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | -| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | -| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | -| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US_B8` / `US_B16` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS_B8` / `DS_B16` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | | `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | -| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `E2B_B16` / `E2B_B32` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | | `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | | `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | -| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN_B8` / `SPLT2CHN_B16` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | `pto.vlds` currently covers only single-result load families. Dual-result deinterleave forms are modeled separately in PTO surface as [`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while -`DINTLV` is the element-width-sensitive deinterleave family. +`DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` are the element-width-sensitive +deinterleave forms. **Example — Contiguous load:** ```mlir @@ -153,7 +154,7 @@ deinterleave forms are modeled separately in PTO surface as **Example — Broadcast scalar to all lanes:** ```mlir -%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> ``` --- @@ -233,19 +234,21 @@ deinterleave forms are modeled separately in PTO surface as This family is only legal for interleave/deinterleave style distributions. The two outputs form an ordered pair, and that pairing MUST be preserved. PTO surface accepts deinterleave families. `BDINTLV` is element-width - agnostic, while `DINTLV` supports only the element widths listed in the + agnostic, while `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` support only the + element widths listed in the table. -- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. +- **latency:** `BDINTLV` / `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` are all + **9** cycles. **Distribution families:** | Family | Allowed element widths | C semantics | Latency | |------|-------------|-------------|-------------| | `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | -| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | ```c -// DINTLV family on 32-bit elements: deinterleave 32-bit elements +// DINTLV_B32 family on 32-bit elements: deinterleave 32-bit elements for (int i = 0; i < 64; i++) { low[i] = UB[base + 8*i]; // even elements high[i] = UB[base + 8*i + 4]; // odd elements @@ -254,7 +257,7 @@ for (int i = 0; i < 64; i++) { **Example — Load interleaved XY pairs into separate X/Y vectors:** ```mlir -%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> ``` ### `pto.vsldb` @@ -382,16 +385,16 @@ for (int blk = 0; blk < VL / 32; ++blk) { | Family | Allowed element widths | C semantics | Latency | |------|-------------|-------------|-------------| -| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | -| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | -| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | -| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | -| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout. VPTO currently requires `!pto.mask` for this family and emits a hardware-unsupported warning on A5. | **9** cycles | -| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout. VPTO currently requires `!pto.mask` for `b8` input and `!pto.mask` for `b16` input, and emits a hardware-unsupported warning on A5. | **9** cycles | +| `NORM_B8` / `NORM_B16` / `NORM_B32` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT_B8` / `1PT_B16` / `1PT_B32` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK_B16` / `PK_B32` / `PK_B64` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store. | **9** cycles | +| `PK4_B32` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN_B8` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout. VPTO currently requires `!pto.mask` for this family and emits a hardware-unsupported warning on A5. | **9** cycles | +| `MRG2CHN_B8` / `MRG2CHN_B16` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout. VPTO currently requires `!pto.mask` for `MRG2CHN_B8` and `!pto.mask` for `MRG2CHN_B16`, and emits a hardware-unsupported warning on A5. | **9** cycles | **Example — Contiguous store:** ```mlir -pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask ``` --- diff --git a/docs/isa/10-reduction-ops.md b/docs/isa/10-reduction-ops.md index b2fb20894..dec8ccbab 100644 --- a/docs/isa/10-reduction-ops.md +++ b/docs/isa/10-reduction-ops.md @@ -229,7 +229,7 @@ for (int i = 1; i < N; i++) // Softmax: find max for numerical stability %max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // max is in lane 0, broadcast it -%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> // Row-wise sum using vcgadd (for 8-row tile) %row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md index 731fa71b2..32fc75b82 100644 --- a/docs/isa/13-dsa-sfu-ops.md +++ b/docs/isa/13-dsa-sfu-ops.md @@ -218,7 +218,7 @@ for (int i = 0; i < N; i++) ```mlir // Softmax with fused expdiff -%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> %exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // Leaky ReLU activation diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 0cb136c32..11e8a6ee8 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -910,7 +910,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | Contiguous Load | 3 | `pto.vlds` with `NORM` dist | | Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | | Gather | 3 | `pto.vgather2`, `pto.vgatherb` | -| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Contiguous Store | 3 | `pto.vsts` with `NORM_B8` / `NORM_B16` / `NORM_B32` dist | | Scatter | 3 | `pto.vscatter` | ### Compute Operations @@ -979,7 +979,7 @@ Group 14 covers the full scalar `arith` surface. The rows below list common PTO // 1. Find max %max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask -%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> // 2. exp(x - max) using fused op %exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> @@ -987,7 +987,7 @@ pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, ! // 3. Sum %sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask -%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> // 4. Divide %softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> @@ -1011,10 +1011,10 @@ pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto. ```mlir // AoS → SoA (deinterleave) -%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // SoA → AoS (interleave) -pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask ``` --- @@ -1047,6 +1047,6 @@ pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto ### Part 3C -2. **Store dist family completeness:** `vsts` currently covers `NORM`, `1PT`, `PK`, `PK4`, `MRG4CHN`, and `MRG2CHN`, while `vstsx2` covers `INTLV`. `MRG4CHN` / `MRG2CHN` are preserved in the VPTO surface, but the current hardware still reports them as unsupported via verifier warning and they are not expected to validate at runtime on A5 today. +2. **Store dist family completeness:** `vsts` currently covers `NORM_B8`, `NORM_B16`, `NORM_B32`, `1PT_B8`, `1PT_B16`, `1PT_B32`, `PK_B16`, `PK_B32`, `PK_B64`, `PK4_B32`, `MRG4CHN_B8`, `MRG2CHN_B8`, and `MRG2CHN_B16`, while `vstsx2` covers `INTLV_B8` / `INTLV_B16` / `INTLV_B32`. `MRG4CHN_B8` / `MRG2CHN_B8` / `MRG2CHN_B16` are preserved in the VPTO surface, but the current hardware still reports them as unsupported via verifier warning and they are not expected to validate at runtime on A5 today. 3. **vcvt width-changing pattern:** The even/odd + `vor` pattern for forms such as `f32 -> f16` is the standard compiler lowering. Confirm this is the intended representation in the spec. 4. **Stateful store ops (Section 14):** These are complex with SSA state threading. Are they all needed for A5, or can some be simplified? diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 420873d21..9a9f6e447 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -774,22 +774,29 @@ static bool matchesWidthFamily(StringRef dist, unsigned width, } static bool isSupportedVldx2DistToken(StringRef dist) { - return dist == "BDINTLV" || dist == "DINTLV"; + return dist == "BDINTLV" || dist == "DINTLV_B8" || dist == "DINTLV_B16" || + dist == "DINTLV_B32"; } static bool isSupportedVldsDistToken(StringRef dist) { - return dist == "NORM" || dist == "BRC" || dist == "US" || dist == "DS" || - dist == "UNPK" || dist == "BRC_BLK" || dist == "E2B" || - dist == "UNPK4" || dist == "SPLT4CHN" || dist == "SPLT2CHN"; + return dist == "NORM" || dist == "BRC_B8" || dist == "BRC_B16" || + dist == "BRC_B32" || dist == "US_B8" || dist == "US_B16" || + dist == "DS_B8" || dist == "DS_B16" || dist == "UNPK_B8" || + dist == "UNPK_B16" || dist == "UNPK_B32" || dist == "BRC_BLK" || + dist == "E2B_B16" || dist == "E2B_B32" || dist == "UNPK4" || + dist == "SPLT4CHN" || dist == "SPLT2CHN_B8" || dist == "SPLT2CHN_B16"; } static bool isSupportedVstsDistToken(StringRef dist) { - return dist == "NORM" || dist == "1PT" || dist == "PK" || - dist == "PK4" || dist == "MRG4CHN" || dist == "MRG2CHN"; + return dist == "NORM_B8" || dist == "NORM_B16" || dist == "NORM_B32" || + dist == "1PT_B8" || dist == "1PT_B16" || dist == "1PT_B32" || + dist == "PK_B16" || dist == "PK_B32" || dist == "PK_B64" || + dist == "PK4_B32" || dist == "MRG4CHN_B8" || dist == "MRG2CHN_B8" || + dist == "MRG2CHN_B16"; } static bool isSupportedVstsx2DistToken(StringRef dist) { - return dist == "INTLV"; + return dist == "INTLV_B8" || dist == "INTLV_B16" || dist == "INTLV_B32"; } static LogicalResult verifyVldsDistWidth(Operation *op, StringRef dist, @@ -800,26 +807,42 @@ static LogicalResult verifyVldsDistWidth(Operation *op, StringRef dist, if (dist == "NORM" || dist == "BRC_BLK") return success(); - if (dist == "BRC") - return matchesWidthFamily(dist, *width, {8, 16, 32}) - ? success() - : op->emitOpError("dist BRC only supports 8/16/32-bit elements"); - if (dist == "US") - return matchesWidthFamily(dist, *width, {8, 16}) - ? success() - : op->emitOpError("dist US only supports 8/16-bit elements"); - if (dist == "DS") - return matchesWidthFamily(dist, *width, {8, 16}) - ? success() - : op->emitOpError("dist DS only supports 8/16-bit elements"); - if (dist == "UNPK") - return matchesWidthFamily(dist, *width, {8, 16, 32}) - ? success() - : op->emitOpError("dist UNPK only supports 8/16/32-bit elements"); - if (dist == "E2B") - return matchesWidthFamily(dist, *width, {16, 32}) - ? success() - : op->emitOpError("dist E2B only supports 16/32-bit elements"); + if (dist == "BRC_B8") + return *width == 8 ? success() + : op->emitOpError("dist BRC_B8 only supports 8-bit elements"); + if (dist == "BRC_B16") + return *width == 16 ? success() + : op->emitOpError("dist BRC_B16 only supports 16-bit elements"); + if (dist == "BRC_B32") + return *width == 32 ? success() + : op->emitOpError("dist BRC_B32 only supports 32-bit elements"); + if (dist == "US_B8") + return *width == 8 ? success() + : op->emitOpError("dist US_B8 only supports 8-bit elements"); + if (dist == "US_B16") + return *width == 16 ? success() + : op->emitOpError("dist US_B16 only supports 16-bit elements"); + if (dist == "DS_B8") + return *width == 8 ? success() + : op->emitOpError("dist DS_B8 only supports 8-bit elements"); + if (dist == "DS_B16") + return *width == 16 ? success() + : op->emitOpError("dist DS_B16 only supports 16-bit elements"); + if (dist == "UNPK_B8") + return *width == 8 ? success() + : op->emitOpError("dist UNPK_B8 only supports 8-bit elements"); + if (dist == "UNPK_B16") + return *width == 16 ? success() + : op->emitOpError("dist UNPK_B16 only supports 16-bit elements"); + if (dist == "UNPK_B32") + return *width == 32 ? success() + : op->emitOpError("dist UNPK_B32 only supports 32-bit elements"); + if (dist == "E2B_B16") + return *width == 16 ? success() + : op->emitOpError("dist E2B_B16 only supports 16-bit elements"); + if (dist == "E2B_B32") + return *width == 32 ? success() + : op->emitOpError("dist E2B_B32 only supports 32-bit elements"); if (dist == "UNPK4") return *width == 8 ? success() @@ -828,10 +851,14 @@ static LogicalResult verifyVldsDistWidth(Operation *op, StringRef dist, return *width == 8 ? success() : op->emitOpError("dist SPLT4CHN only supports 8-bit elements"); - if (dist == "SPLT2CHN") - return matchesWidthFamily(dist, *width, {8, 16}) + if (dist == "SPLT2CHN_B8") + return *width == 8 + ? success() + : op->emitOpError("dist SPLT2CHN_B8 only supports 8-bit elements"); + if (dist == "SPLT2CHN_B16") + return *width == 16 ? success() - : op->emitOpError("dist SPLT2CHN only supports 8/16-bit elements"); + : op->emitOpError("dist SPLT2CHN_B16 only supports 16-bit elements"); return op->emitOpError("requires a supported load distribution token"); } @@ -844,10 +871,15 @@ static LogicalResult verifyVldsx2DistWidth(Operation *op, StringRef dist, "requires x2 load element type with a concrete bit width"); if (dist == "BDINTLV") return success(); - if (dist == "DINTLV") - return matchesWidthFamily(dist, *width, {8, 16, 32}) - ? success() - : op->emitOpError("dist DINTLV only supports 8/16/32-bit elements"); + if (dist == "DINTLV_B8") + return *width == 8 ? success() + : op->emitOpError("dist DINTLV_B8 only supports 8-bit elements"); + if (dist == "DINTLV_B16") + return *width == 16 ? success() + : op->emitOpError("dist DINTLV_B16 only supports 16-bit elements"); + if (dist == "DINTLV_B32") + return *width == 32 ? success() + : op->emitOpError("dist DINTLV_B32 only supports 32-bit elements"); return op->emitOpError("requires a supported x2 load distribution token"); } @@ -858,30 +890,49 @@ static LogicalResult verifyVstsDistWidth(Operation *op, StringRef dist, return op->emitOpError( "requires store element type with a concrete bit width"); - if (dist == "NORM") - return matchesWidthFamily(dist, *width, {8, 16, 32}) - ? success() - : op->emitOpError("dist NORM only supports 8/16/32-bit elements"); - if (dist == "1PT") - return matchesWidthFamily(dist, *width, {8, 16, 32}) - ? success() - : op->emitOpError("dist 1PT only supports 8/16/32-bit elements"); - if (dist == "PK") - return matchesWidthFamily(dist, *width, {16, 32, 64}) - ? success() - : op->emitOpError("dist PK only supports 16/32/64-bit elements"); - if (dist == "PK4") - return *width == 32 - ? success() - : op->emitOpError("dist PK4 only supports 32-bit elements"); - if (dist == "MRG4CHN") { + if (dist == "NORM_B8") + return *width == 8 ? success() + : op->emitOpError("dist NORM_B8 only supports 8-bit elements"); + if (dist == "NORM_B16") + return *width == 16 ? success() + : op->emitOpError("dist NORM_B16 only supports 16-bit elements"); + if (dist == "NORM_B32") + return *width == 32 ? success() + : op->emitOpError("dist NORM_B32 only supports 32-bit elements"); + if (dist == "1PT_B8") + return *width == 8 ? success() + : op->emitOpError("dist 1PT_B8 only supports 8-bit elements"); + if (dist == "1PT_B16") + return *width == 16 ? success() + : op->emitOpError("dist 1PT_B16 only supports 16-bit elements"); + if (dist == "1PT_B32") + return *width == 32 ? success() + : op->emitOpError("dist 1PT_B32 only supports 32-bit elements"); + if (dist == "PK_B16") + return *width == 16 ? success() + : op->emitOpError("dist PK_B16 only supports 16-bit elements"); + if (dist == "PK_B32") + return *width == 32 ? success() + : op->emitOpError("dist PK_B32 only supports 32-bit elements"); + if (dist == "PK_B64") + return *width == 64 ? success() + : op->emitOpError("dist PK_B64 only supports 64-bit elements"); + if (dist == "PK4_B32") + return *width == 32 ? success() + : op->emitOpError("dist PK4_B32 only supports 32-bit elements"); + if (dist == "MRG4CHN_B8") { if (*width != 8) - return op->emitOpError("dist MRG4CHN only supports 8-bit elements"); + return op->emitOpError("dist MRG4CHN_B8 only supports 8-bit elements"); return success(); } - if (dist == "MRG2CHN") { - if (!matchesWidthFamily(dist, *width, {8, 16})) - return op->emitOpError("dist MRG2CHN only supports 8/16-bit elements"); + if (dist == "MRG2CHN_B8") { + if (*width != 8) + return op->emitOpError("dist MRG2CHN_B8 only supports 8-bit elements"); + return success(); + } + if (dist == "MRG2CHN_B16") { + if (*width != 16) + return op->emitOpError("dist MRG2CHN_B16 only supports 16-bit elements"); return success(); } @@ -894,10 +945,12 @@ getVstsMaskGranularityOverride(StringRef dist, Type elementType) { if (!width) return std::nullopt; - if (dist == "MRG4CHN") + if (dist == "MRG4CHN_B8") + return StringRef("b32"); + if (dist == "MRG2CHN_B8") + return StringRef("b16"); + if (dist == "MRG2CHN_B16") return StringRef("b32"); - if (dist == "MRG2CHN") - return *width == 8 ? StringRef("b16") : StringRef("b32"); return std::nullopt; } @@ -908,10 +961,15 @@ static LogicalResult verifyVstsx2DistWidth(Operation *op, StringRef dist, if (!width) return op->emitOpError( "requires x2 store element type with a concrete bit width"); - if (dist == "INTLV") - return matchesWidthFamily(dist, *width, {8, 16, 32}) - ? success() - : op->emitOpError("dist INTLV only supports 8/16/32-bit elements"); + if (dist == "INTLV_B8") + return *width == 8 ? success() + : op->emitOpError("dist INTLV_B8 only supports 8-bit elements"); + if (dist == "INTLV_B16") + return *width == 16 ? success() + : op->emitOpError("dist INTLV_B16 only supports 16-bit elements"); + if (dist == "INTLV_B32") + return *width == 32 ? success() + : op->emitOpError("dist INTLV_B32 only supports 32-bit elements"); return op->emitOpError("requires a supported x2 store distribution token"); } @@ -1462,8 +1520,9 @@ static LogicalResult verifyVldsCommon(LoadOp op) { StringRef dist = *op.getDist(); if (!isSupportedVldsDistToken(dist)) return op.emitOpError( - "supports only NORM, BRC, US, DS, UNPK, BRC_BLK, E2B, UNPK4, " - "and SPLT2CHN/SPLT4CHN load distributions"); + "supports only NORM, BRC_B8/B16/B32, US_B8/B16, DS_B8/B16, " + "UNPK_B8/B16/B32, BRC_BLK, E2B_B16/B32, UNPK4, SPLT4CHN, and " + "SPLT2CHN_B8/B16 load distributions"); if (failed(verifyVldsDistWidth( op.getOperation(), dist, cast(op.getResult().getType()).getElementType()))) diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp index 89d4f33e7..ff8df34f9 100644 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -2355,9 +2355,9 @@ LogicalResult buildRowReduceVecScope(StringRef family, auto getRowReduceStoreDist = [&]() -> StringAttr { if (contract.elementType.isF16() || contract.elementType.isBF16()) - return rewriter.getStringAttr("1PT"); + return rewriter.getStringAttr("1PT_B16"); if (contract.elementType.isF32()) - return rewriter.getStringAttr("1PT"); + return rewriter.getStringAttr("1PT_B32"); return {}; }; StringAttr storeDist = getRowReduceStoreDist(); diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp index 92b57f769..ae384fcdc 100644 --- a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -333,14 +333,17 @@ class VPTOLegalityValidator { StringRef dist = distAttr.getValue(); unsigned width = elementIntType.getWidth(); - if (dist == "MRG4CHN") { + if (dist == "MRG4CHN_B8") { if (width == 8) return VPTOMaskGranularity::B32; return std::nullopt; } - if (dist == "MRG2CHN") { + if (dist == "MRG2CHN_B8") { if (width == 8) return VPTOMaskGranularity::B16; + return std::nullopt; + } + if (dist == "MRG2CHN_B16") { if (width == 16) return VPTOMaskGranularity::B32; } @@ -405,7 +408,7 @@ class VPTOLegalityValidator { return; StringRef dist = distAttr.getValue(); - if (dist == "MRG4CHN" || dist == "MRG2CHN") + if (dist == "MRG4CHN_B8" || dist == "MRG2CHN_B8" || dist == "MRG2CHN_B16") writeDiagnostic((Twine("warning: ") + store->getName().getStringRef() + " dist " + dist + " is not supported on the current hardware\n") diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 4c895896e..2650f732b 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -586,38 +586,40 @@ static std::optional parseLoadDistImmediate(StringRef dist, return 0; if (!width) return std::nullopt; - if (dist == "BRC") - return *width == 8 ? std::optional(1) - : *width == 16 ? std::optional(2) - : *width == 32 ? std::optional(3) - : std::nullopt; - if (dist == "US") - return *width == 8 ? std::optional(6) - : *width == 16 ? std::optional(7) - : std::nullopt; - if (dist == "DS") - return *width == 8 ? std::optional(8) - : *width == 16 ? std::optional(9) - : std::nullopt; - if (dist == "UNPK") - return *width == 8 ? std::optional(13) - : *width == 16 ? std::optional(14) - : *width == 32 ? std::optional(18) - : std::nullopt; + if (dist == "BRC_B8") + return std::optional(1); + if (dist == "BRC_B16") + return std::optional(2); + if (dist == "BRC_B32") + return std::optional(3); + if (dist == "US_B8") + return std::optional(6); + if (dist == "US_B16") + return std::optional(7); + if (dist == "DS_B8") + return std::optional(8); + if (dist == "DS_B16") + return std::optional(9); + if (dist == "UNPK_B8") + return std::optional(13); + if (dist == "UNPK_B16") + return std::optional(14); + if (dist == "UNPK_B32") + return std::optional(18); if (dist == "BRC_BLK") return 15; - if (dist == "E2B") - return *width == 16 ? std::optional(16) - : *width == 32 ? std::optional(17) - : std::nullopt; + if (dist == "E2B_B16") + return std::optional(16); + if (dist == "E2B_B32") + return std::optional(17); if (dist == "UNPK4") return *width == 8 ? std::optional(20) : std::nullopt; if (dist == "SPLT4CHN") return *width == 8 ? std::optional(21) : std::nullopt; - if (dist == "SPLT2CHN") - return *width == 8 ? std::optional(22) - : *width == 16 ? std::optional(23) - : std::nullopt; + if (dist == "SPLT2CHN_B8") + return std::optional(22); + if (dist == "SPLT2CHN_B16") + return std::optional(23); return std::nullopt; } @@ -628,18 +630,19 @@ static std::optional parseLoadX2DistImmediate(StringRef dist, return 10; if (!width) return std::nullopt; - if (dist == "DINTLV") - return *width == 8 ? std::optional(11) - : *width == 16 ? std::optional(12) - : *width == 32 ? std::optional(19) - : std::nullopt; + if (dist == "DINTLV_B8") + return std::optional(11); + if (dist == "DINTLV_B16") + return std::optional(12); + if (dist == "DINTLV_B32") + return std::optional(19); return std::nullopt; } static std::optional parseStoreDistImmediate(StringRef dist, Type elementType) { auto width = getDistElementWidth(elementType); - if (dist.empty() || dist == "NORM") { + if (dist.empty()) { if (!width) return std::nullopt; if (*width == 8) @@ -650,26 +653,32 @@ static std::optional parseStoreDistImmediate(StringRef dist, return 2; return std::nullopt; } - if (!width) - return std::nullopt; - if (dist == "1PT") - return *width == 8 ? std::optional(3) - : *width == 16 ? std::optional(4) - : *width == 32 ? std::optional(5) - : std::nullopt; - if (dist == "PK") - return *width == 16 ? std::optional(6) - : *width == 32 ? std::optional(7) - : *width == 64 ? std::optional(10) - : std::nullopt; - if (dist == "PK4") - return *width == 32 ? std::optional(12) : std::nullopt; - if (dist == "MRG4CHN") - return *width == 8 ? std::optional(13) : std::nullopt; - if (dist == "MRG2CHN") - return *width == 8 ? std::optional(14) - : *width == 16 ? std::optional(15) - : std::nullopt; + if (dist == "NORM_B8") + return std::optional(0); + if (dist == "NORM_B16") + return std::optional(1); + if (dist == "NORM_B32") + return std::optional(2); + if (dist == "1PT_B8") + return std::optional(3); + if (dist == "1PT_B16") + return std::optional(4); + if (dist == "1PT_B32") + return std::optional(5); + if (dist == "PK_B16") + return std::optional(6); + if (dist == "PK_B32") + return std::optional(7); + if (dist == "PK_B64") + return std::optional(10); + if (dist == "PK4_B32") + return std::optional(12); + if (dist == "MRG4CHN_B8") + return std::optional(13); + if (dist == "MRG2CHN_B8") + return std::optional(14); + if (dist == "MRG2CHN_B16") + return std::optional(15); return std::nullopt; } @@ -678,11 +687,12 @@ static std::optional parseStoreX2DistImmediate(StringRef dist, auto width = getDistElementWidth(elementType); if (!width) return std::nullopt; - if (dist == "INTLV") - return *width == 8 ? std::optional(8) - : *width == 16 ? std::optional(9) - : *width == 32 ? std::optional(11) - : std::nullopt; + if (dist == "INTLV_B8") + return std::optional(8); + if (dist == "INTLV_B16") + return std::optional(9); + if (dist == "INTLV_B32") + return std::optional(11); return std::nullopt; } @@ -3236,7 +3246,7 @@ class LowerVstsOpPattern final : public OpConversionPattern { convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); auto basePtr = dyn_cast(adaptor.getDestination().getType()); auto dist = - parseStoreDistImmediate(op.getDist().value_or("NORM"), elementType); + parseStoreDistImmediate(op.getDist().value_or(""), elementType); if (failed(offsetBytes) || !basePtr || !dist) return rewriter.notifyMatchFailure(op, "failed to materialize vsts operands"); @@ -3320,7 +3330,7 @@ class LowerVstsPostOpPattern final auto basePtr = dyn_cast(adaptor.getDestination().getType()); auto dist = - parseStoreDistImmediate(op.getDist().value_or("NORM"), elementType); + parseStoreDistImmediate(op.getDist().value_or(""), elementType); if (failed(offsetBytes) || !basePtr || !dist) { return rewriter.notifyMatchFailure(op, "failed to materialize vsts_post operands"); diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel.pto b/test/vpto/cases/kernels/online-softmax-update/kernel.pto index 9d49bc6cb..a2dab4d2e 100644 --- a/test/vpto/cases/kernels/online-softmax-update/kernel.pto +++ b/test/vpto/cases/kernels/online-softmax-update/kernel.pto @@ -88,8 +88,8 @@ module attributes {pto.target_arch = "a5"} { %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 scf.for %row = %c0 to %row_count step %c1 { %row_qk = arith.muli %row, %c128 : index - %oldmax_bc = pto.vlds %ub_oldmax[%row] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> - %oldsum_bc = pto.vlds %ub_oldsum[%row] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + %oldmax_bc = pto.vlds %ub_oldmax[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + %oldsum_bc = pto.vlds %ub_oldsum[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) @@ -120,9 +120,9 @@ module attributes {pto.target_arch = "a5"} { %raw_expmax = pto.vexpdiff %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask %zero = pto.vsub %final_max, %final_max, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> scf.for %chunk = %c0 to %c128 step %c64 { From c6524926e0c2848ae7ed5438bb8b631c739c6a7d Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 16 Apr 2026 20:16:42 +0800 Subject: [PATCH 090/192] fix(dsl): materialize vtrc round mode in tilelang --- .../11-vector-arithmetic-operations.md | 18 ++++- .../python/tilelang_dsl/frontend_ast.py | 1 + tilelang-dsl/python/tilelang_dsl/lowering.py | 12 ++- tilelang-dsl/python/tilelang_dsl/semantic.py | 79 +++++++++++++++++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 69 ++++++++++++++++ 5 files changed, 175 insertions(+), 4 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 67ade98d2..f935ca8ca 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1312,20 +1312,32 @@ pto.vtranspose(dst_ub_ptr, src_ub_ptr, config_word) Type conversion and specialized operations. -#### `pto.vtrc(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vtrc(vec: VRegType, mask: MaskType, rnd: pto.VcvtRoundMode | None = None) -> VRegType` -**Description**: Truncate vector elements. +**Description**: Truncate/round float to integer-valued float (stays in float type). This is the TileLang DSL surface for the VPTO `pto.vtrc` operation. + +**Attribute Enums**: +- `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` (note: `vtrc` does not support `O`) **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `vec` | `VRegType` | Input vector | | `mask` | `MaskType` | Predicate mask | +| `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `round_mode`. Defaults to `R` if not specified. | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `VRegType` | Truncated vector | +| `result` | `VRegType` | Truncated vector with integer-valued float elements | + +**Constraints**: +- Current TileLang DSL v1 accepts exactly two positional arguments: `pto.vtrc(vec, mask)`. Optional `rnd` attribute is exposed as keyword argument: `rnd=...`. +- The underlying VPTO op syntax is `pto.vtrc %input, %mask, "RND"`. +- Supported rounding modes are `R` (round to nearest), `A` (round away from zero), `F` (floor), `C` (ceil), `Z` (truncate toward zero). +- The enum form is preferred. For compatibility, canonical strings such as `"R"`, `"A"`, `"F"`, `"C"`, `"Z"` are also accepted. +- This op does not change the element type; input and output have the same vector type. +- Only floating-point element types are supported: `f16`, `bf16`, `f32`. #### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType, rnd: pto.VcvtRoundMode | None = None, sat: pto.VcvtSatMode | None = None, part: pto.VcvtPartMode | None = None) -> VRegType` diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index bbed8b7d8..a78d7b55f 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -816,6 +816,7 @@ def _collect_reachable_inline_procs( } ), "vcvt": frozenset({"rnd", "sat", "part"}), + "vtrc": frozenset({"rnd"}), } diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index a5c99d692..5e260d79c 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2890,6 +2890,17 @@ def _lower_call_expr( ) return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + if expr.name == "vtrc": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) + rnd = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vtrc {value.name}, {mask.name}, {rnd} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name in { "vabs", "vrelu", @@ -2911,7 +2922,6 @@ def _lower_call_expr( "vusqz", "vsqz", "vexpdiff", - "vtrc", "vcgadd", "vcgmax", "vcgmin", diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 52fa40551..be4c24ac9 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3053,6 +3053,12 @@ def _analyze_expr( env, allow_outer_lookup=allow_outer_lookup, ) + if expr.namespace == "pto" and expr.name == "vtrc": + return self._analyze_vtrc_frontend_call( + expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) if expr.keywords: raise TypeError( f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " @@ -3761,6 +3767,8 @@ def _analyze_call_expr( return self._analyze_rearrangement_op(name, args) if name == "vcvt": return self._analyze_vcvt(args) + if name == "vtrc": + return self._analyze_vtrc(args) if name == "vbitsort": return self._analyze_vbitsort(args) if name == "vmrgsort4": @@ -4537,6 +4545,39 @@ def _analyze_vcvt_frontend_call( part_explicit="part" in analyzed_keywords, ) + def _analyze_vtrc_frontend_call( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + if len(expr.args) != 2: + raise TypeError( + "pto.vtrc expects exactly 2 positional operands `(vec, mask)` " + "before optional keyword attrs in TileLang DSL v1" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + allowed_keywords = {"rnd"} + unexpected_keywords = sorted(set(analyzed_keywords) - allowed_keywords) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.vtrc only accepts keyword attr `rnd`; " + f"got unsupported keyword(s): {keyword_text}" + ) + return self._analyze_vtrc( + args, + rnd=self._normalize_vtrc_round_mode(analyzed_keywords.get("rnd")), + ) + def _analyze_vcvt( self, args: tuple[SemanticExpr, ...], @@ -4579,6 +4620,31 @@ def _analyze_vcvt( type=self._vreg_type_for_dtype(target_dtype), ) + def _analyze_vtrc( + self, + args: tuple[SemanticExpr, ...], + *, + rnd: SemanticExpr | None = None, + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.vtrc expects exactly 2 positional arguments in TileLang DSL v1") + vector = self._require_vreg_expr(args[0], "pto.vtrc vector") + self._require_mask_for_vreg(args[1], vector, "pto.vtrc") + if vector.element_dtype not in {f16, bf16, f32}: + raise TypeError("pto.vtrc only supports f16/bf16/f32 vector element types in TileLang DSL v1") + return SemanticCallExpr( + namespace="pto", + name="vtrc", + args=( + args[0], + args[1], + rnd + if rnd is not None + else SemanticLiteralExpr(value="R", type=SemanticMetaType(kind="string")), + ), + type=vector, + ) + def _analyze_vbitsort(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 4: raise TypeError("pto.vbitsort expects exactly 4 positional arguments in TileLang DSL v1") @@ -4827,6 +4893,19 @@ def _normalize_vcvt_round_mode(self, expr: SemanticExpr | None) -> SemanticExpr ) return SemanticLiteralExpr(value=round_mode, type=SemanticMetaType(kind="string")) + def _normalize_vtrc_round_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: + normalized = self._normalize_vcvt_round_mode(expr) + if normalized is None: + return None + round_mode = self._require_string_expr(normalized, "pto.vtrc rnd") + if round_mode == VcvtRoundMode.O.value: + raise TypeError( + "pto.vtrc rnd must be one of " + '`"R"`, `"A"`, `"F"`, `"C"`, `"Z"` or a matching ' + "VcvtRoundMode enum in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=round_mode, type=SemanticMetaType(kind="string")) + def _normalize_vcvt_sat_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: if expr is None: return None diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index c70a45e68..78ba57be5 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2893,6 +2893,75 @@ def kernel(dst: pto.Tile, src: pto.Tile): r"= pto\.vcvt %[^,\s]+(?: \{[^}]+\})? : !pto\.vreg<[^>]+> -> !pto\.vreg<[^>]+>", ) + def test_vtrc_defaults_to_round_nearest(self) -> None: + @pto.vkernel( + op="vtrc_default_rnd_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vtrc(vec, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vtrc", text) + self.assertIn(', "R" :', text) + self.assertRegex( + text, + r"= pto\.vtrc %[^,\s]+, %[^,\s]+, \"R\" : !pto\.vreg<[^>]+>, !pto\.mask<[^>]+> -> !pto\.vreg<[^>]+>", + ) + + def test_vtrc_supports_keyword_rnd_with_enums(self) -> None: + @pto.vkernel( + op="vtrc_keyword_rnd_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vtrc(vec, all_mask, rnd=pto.VcvtRoundMode.F) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vtrc", text) + self.assertIn(', "F" :', text) + + def test_vtrc_rejects_round_mode_o(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="vtrc_round_mode_o_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vtrc(vec, all_mask, rnd=pto.VcvtRoundMode.O) + pto.vsts(out, dst, 0, all_mask) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ).mlir_text() + + self.assertIn("pto.vtrc rnd must be one of", str(ctx.exception)) + def test_advanced_sort_memory_ops_surface_lower(self) -> None: @pto.vkernel( op="advanced_sort_memory_ops_unique", From 4a9c7f4fe29d011d000f57239fbf33f37a9adf1a Mon Sep 17 00:00:00 2001 From: qukelin Date: Thu, 16 Apr 2026 17:21:30 +0800 Subject: [PATCH 091/192] Add tcvt TileLib support and ST coverage --- ...4-15-tcvt-tilelib-sample-and-work-items.md | 365 ++++++++++++++++++ lib/PTO/Transforms/ExpandTileOp.cpp | 59 ++- lib/TileOps/tcvt_template.py | 79 ++++ test/basic/expand_tile_op_tilelang_tcvt.pto | 76 ++++ .../npu/a5/src/st/testcase/CMakeLists.txt | 1 + .../a5/src/st/testcase/tcvt/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tcvt/cases.py | 55 +++ .../npu/a5/src/st/testcase/tcvt/compare.py | 50 +++ .../npu/a5/src/st/testcase/tcvt/gen_data.py | 112 ++++++ .../npu/a5/src/st/testcase/tcvt/launch.cpp | 29 ++ .../npu/a5/src/st/testcase/tcvt/main.cpp | 130 +++++++ .../npu/a5/src/st/testcase/tcvt/tcvt.pto | 157 ++++++++ tilelang-dsl/python/tilelang_dsl/__init__.py | 2 + .../python/tilelang_dsl/expand_helper.py | 30 +- .../python/tilelang_dsl/frontend_ast.py | 4 + tilelang-dsl/python/tilelang_dsl/semantic.py | 54 ++- .../python/tilelang_dsl/support_matrix.py | 1 + tilelang-dsl/python/tilelang_dsl/types.py | 7 + 18 files changed, 1217 insertions(+), 3 deletions(-) create mode 100644 docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md create mode 100644 lib/TileOps/tcvt_template.py create mode 100644 test/basic/expand_tile_op_tilelang_tcvt.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcvt/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcvt/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto diff --git a/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md b/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md new file mode 100644 index 000000000..bc4de825b --- /dev/null +++ b/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md @@ -0,0 +1,365 @@ +# `pto.tcvt` TileLib 模板库设计与工作项 + +## 1. 目标 + +参考 `pto-isa` 在 A5 上现有的 `TCVT_IMPL` 实现,用 TileLang DSL 在 TileLib 中补齐 `pto.tcvt` 模板库。 + +当前 PTOAS 的 `pto.tcvt` 只显式携带 `rmode`,没有单独暴露 `sat_mode`。因此模板库不能只做一层简单透传,而是要在内部按 `(src_dtype, dst_dtype)` 复现 A5 的默认 `sat_mode` 选择,再为不同类型对走到正确的 VPTO 路径。 + + +## 2. 当前语义 + +### 2.1 PTOAS 侧 + +当前 PTOAS 的 `pto.tcvt` 只有 `rmode` attribute,没有 `sat_mode` attribute。 + +这意味着从 PTOAS 传到 TileLib 的静态信息只有 round mode。默认饱和策略需要模板库自己补齐。 + +### 2.2 A5 `pto-isa` 侧 + +A5 侧已有多组 `TCVT_IMPL` 重载,包括: + +- `TCVT_IMPL(dst, src, mode)` +- `TCVT_IMPL(dst, src, mode, satMode)` +- `TCVT_IMPL(dst, src, tmp, mode)` +- `TCVT_IMPL(dst, src, tmp, mode, satMode)` + +其中 `TCVT_IMPL(dst, src, tmp, mode)` 在 A5 上只是转调无 `tmp` 的版本,`tmp` 本身不参与实现。这里保留 `tmp`,主要是为了和 A2/A3 的接口形态保持兼容。 + +如果只聚焦当前 `pto.tcvt` 真正需要对齐的那条入口,也就是: + +```cpp +TCVT_IMPL(dst, src, mode) +``` + +那么 A5 `pto-isa` 里的主要过程可以概括成下面这条链路: + +1. 先按 `(src_dtype, dst_dtype)` 选默认 `satMode` + 也就是这条入口本身先做一层类型分派,把当前 type pair 映射成默认 + `satMode=ON` 或 `OFF`。 + +2. 再转调到显式 `satMode` 的主实现入口 + +```cpp +TCVT_IMPL(dst, src, mode, satMode) +``` + +3. 在显式 `satMode` 入口里,先根据 `(src_dtype, dst_dtype, satMode)` 计算当前需要设置哪些 CTRL 位 + 这里会调用 `determineSaturationCtrlBits(...)`,然后再调用 + `applySaturationCtrlBits(...)` 把这些 CTRL 位写进去。 + +4. CTRL 位设置完成后,再按 `round_mode` 做一层 switch 分派 + 例如分到 `RoundRType` / `RoundAType` / `RoundFType` / `RoundCType` / + `RoundZType` / `RoundOType`,最后统一调用: + +```cpp +implTCVT(...) +``` + +5. `implTCVT(...)` 内部再按 type pair 落到具体 helper + 例如: + - `cast32to32` + - `cast32to16` + - `cast16to32` + - `cast16to16` + - `cast16to8` + - 以及 `NonSatTorch` 那几条专门 helper + +6. 最后恢复之前改过的 CTRL 位 + 也就是在主实现入口的尾部调用 `restoreSaturationCtrlBits(...)`。 + +把这段代码实现压成一条线来看,就是: + +```text +TCVT_IMPL(dst, src, mode) + -> 按类型对选默认 satMode + -> TCVT_IMPL(dst, src, mode, satMode) + -> determineSaturationCtrlBits(...) + -> applySaturationCtrlBits(...) + -> switch(round_mode) + -> implTCVT(...) + -> cast helper / NonSatTorch helper + -> restoreSaturationCtrlBits(...) +``` + +对 TileLib 来说,真正需要复现的就是这条框架,而不是只把 `rmode` 直接透传给某一个 +`vcvt` 就结束。 + +因此,对当前 A5 来说,`pto.tcvt` 需要对齐的真实语义是: + +1. 外部只显式给 `rmode` +2. 库内部按类型对选择默认 `sat_mode` +3. 再按类型对和 `sat_mode` 进入具体实现路径 + +## 3. A5 实现要点 + +### 3.1 默认 `sat_mode` + +A5 的 round-only `TCVT_IMPL(dst, src, mode)` 对下面这些类型对默认使用 `sat_mode=OFF`: + +| 源类型 | 目标类型 | 默认 `sat_mode` | 说明 | +|---|---|---|---| +| `f16` | `u8` | `OFF` | A5 现有默认行为 | +| `f16` | `i8` | `OFF` | A5 现有默认行为 | +| `f32` | `i16` | `OFF` | A5 现有默认行为 | +| `f16` | `i16` | `OFF` | A5 现有默认行为 | +| `i64` | `i32` | `OFF` | A5 现有默认行为 | +| `i32` | `i16` | `OFF` | A5 现有默认行为 | + +除上表外,其余类型对默认使用 `sat_mode=ON`。 + +这部分规则应直接在 TileLib 模板内部复现,不应依赖 PTOAS 额外传参。 + +### 3.2 A5 `TCVT` 整体支持表 + +按三个实现维度分类: + +- 是否受 `round_mode` 影响 +- 是否受 `sat_mode` 影响 +- 是否需要 `NonSatTorch` 对齐 + +这里根据 `pto-isa/include/pto/npu/a5/TCvt.hpp` 整理,不等于当前 +PTOAS + TileLib 已经全部打通。 + +下面各表最后一列 `TileLib是否支持` 以当前 +`PTOAS/lib/TileOps/tcvt_template.py` 实际实现为准。当前已打通的先标 `已支持`, +其余暂时留空。 + +#### 3.2.1 不受 `round_mode` / `sat_mode` 影响,也不需要 `NonSatTorch` + +这组最适合优先实现,基本都是 expand / unpack 路径。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 备注 | TileLib是否支持 | +|---|---|---|---|---| +| `f16` | `f32` | 1D+2D,`vcvt + part` | type expand | | +| `bf16` | `f32` | 1D+2D,`vcvt + part` | type expand | | +| `i16` | `f32` / `i32` / `u32` | 1D+2D,expand helper | widening path | | +| `i32` | `i64` | 1D+2D,expand helper | | | +| `u8` | `f16` / `u16` | 1D only,expand helper | 当前只看到 1D helper | | +| `i8` | `f16` / `i16` / `i32` | 1D only,expand helper | 当前只看到 1D helper | | +| `fp8_e4m3` / `fp8_e5m2` / `h8` | `f32` | 1D+2D,expand helper | source 8-bit float | | +| `fp4_e1m2x2` / `fp4_e2m1x2` | `bf16` | 1D+2D,专用 unpack helper | 4-bit packed source | | + +#### 3.2.2 受 `round_mode` 影响,不受 `sat_mode` 影响,也不需要 `NonSatTorch` + +这组属于 round-only 路径。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 备注 | TileLib是否支持 | +|---|---|---|---|---| +| `f32` | `f32` | 1D+2D,`vtrc` | 保持 `f32`,做 integer-valued float rounding | | +| `f16` | `i32` | 1D+2D,`vcvt + part` | | | +| `i16` | `f16` | 1D+2D,`vcvt` | | | +| `i32` | `f32` | 1D+2D,`vcvt` | | `已支持` | +| `i64` | `f32` | 1D+2D,`vcvt + part` | | | +| `bf16` | `fp4_e1m2x2` / `fp4_e2m1x2` | 1D+2D,专用 packed helper | 不是普通 `vcvt` 套餐,但不吃 `sat_mode` | | + +#### 3.2.3 不受 `round_mode` 影响,受 `sat_mode` 影响,不需要 `NonSatTorch` + +这组主要是整数窄化。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | 备注 | TileLib是否支持 | +|---|---|---|---|---|---| +| `i16` | `u8` | 1D+2D,`vcvt + part` | `ON` | | | +| `i32` | `i16` | 1D+2D,`vcvt + part` | `OFF` | | | +| `i32` | `u16` / `u8` | 1D+2D,`vcvt + part` | `ON` | | | +| `u32` | `i16` / `u16` / `u8` | 1D+2D,`vcvt + part` | `ON` | | | +| `i64` | `i32` | 1D+2D,`vcvt + part` | `OFF` | | | + +#### 3.2.4 同时受 `round_mode` 和 `sat_mode` 影响,但不需要 `NonSatTorch` + +这组是常规 `tcvt` 主干路径。当前先打通的 `f32 -> i32` 就属于这一类。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | 备注 | TileLib是否支持 | +|---|---|---|---|---|---| +| `f32` | `f16` / `bf16` | 1D+2D,`vcvt + part` | `ON` | 窄化 float | | +| `f32` | `i32` | 1D+2D,`vcvt` | `ON` | 当前已先打通这一类普通路径 | `已支持` | +| `f32` | `i64` | 1D+2D,`vcvt + part` | `ON` | | | +| `f32` | `fp8_e4m3` / `fp8_e5m2` | 1D+2D,`vcvt + part` | `ON` | | | +| `f16` | `u8` | 1D+2D,`vcvt + part` | `OFF` | | | +| `bf16` | `i32` | 1D+2D,`vcvt + part` | `ON` | | | +| `bf16` | `f16` | 1D+2D,`vcvt` | `ON` | helper 内部是 `SAT_ROUND` 顺序 | | + +#### 3.2.5 同时受 `round_mode` 和 `sat_mode` 影响,且需要 `NonSatTorch` + +这组后面要单独收口。不能把它们直接等价成普通 `vcvt(..., sat=NOSAT)`。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | `NonSatTorch` | 备注 | TileLib是否支持 | +|---|---|---|---|---|---|---| +| `f32` | `i16` | 1D+2D,`vcvt + part` | `OFF` | 是 | `OFF` 时走 `NonSatTorch` | | +| `f16` | `i16` | 1D+2D,`vcvt` | `OFF` | 是 | | | +| `f16` | `i8` | 1D+2D,`vcvt + part` | `OFF` | 是 | | | + +#### 3.2.6 专用 helper,`round_mode` 受限 + +这组不建议和普通路径一起排第一批。A5 helper 虽然形式上带模板参数,但当前实现实际固定在特定 round 行为上。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | 备注 | TileLib是否支持 | +|---|---|---|---|---|---| +| `f32` | `h8` | 1D+2D,专用 helper | `ON` | helper 实际固定 `ROUND_A` | | +| `f16` | `h8` | 1D+2D,专用 helper | `ON` | helper 实际固定 `ROUND_A` | | + +这里再记三点: + +- `f16 -> fp8_e4m3/e5m2` 当前 A5 `pto-isa` 明确未实现;`f16` 这边只提供了 `h8` 专用 helper。 +- `h8`、`fp4` 这类路径不是普通 `vcvt` 套餐,后面做 TileLib 时不建议和常规 `f32/f16/bf16/int` 主干混在第一批一起做。 +- 这里说“受 / 不受 `round_mode` 影响”指的是该 pair 的 A5 helper 是否真的消费 round 语义,不是说 PTOAS 这层拿不到 `rmode`。 + +### 3.3 `round_mode` 映射表 + +当前 `pto.tcvt` 这条链路里,round mode 至少会经过四层名字: + +1. PTOAS op attr:`#pto` +2. `ExpandTileOp` 传给 TileLang 的上下文字符串:`round_mode` +3. TileLang DSL 前端:`pto.VcvtRoundMode.*` +4. VPTO / A5 lowering:`rnd = "R"` 这一类 token,或 `RoundMode::CAST_*` + +建议文档和实现都按下面这张表统一,不要在不同层写不同别名。 + +| PTOAS `rmode` | `ExpandTileOp` 传值 | DSL 前端 | VPTO token | A5 / EmitC | 语义 | +|---|---|---|---|---|---| +| `NONE` | `RINT` | `pto.VcvtRoundMode.R` | `R` / `ROUND_R` | `RoundMode::CAST_RINT` | round to nearest, ties to even | +| `RINT` | `RINT` | `pto.VcvtRoundMode.R` | `R` / `ROUND_R` | `RoundMode::CAST_RINT` | round to nearest, ties to even | +| `CAST_RINT` | `RINT` | `pto.VcvtRoundMode.R` | `R` / `ROUND_R` | `RoundMode::CAST_RINT` | round to nearest, ties to even | +| `ROUND` | `ROUND` | `pto.VcvtRoundMode.A` | `A` / `ROUND_A` | `RoundMode::CAST_ROUND` | round away from zero | +| `FLOOR` | `FLOOR` | `pto.VcvtRoundMode.F` | `F` / `ROUND_F` | `RoundMode::CAST_FLOOR` | round toward negative infinity | +| `CEIL` | `CEIL` | `pto.VcvtRoundMode.C` | `C` / `ROUND_C` | `RoundMode::CAST_CEIL` | round toward positive infinity | +| `TRUNC` | `TRUNC` | `pto.VcvtRoundMode.Z` | `Z` / `ROUND_Z` | `RoundMode::CAST_TRUNC` | round toward zero | +| `ODD` | `ODD` | `pto.VcvtRoundMode.O` | `O` / `ROUND_O` | `RoundMode::CAST_ODD` | round to odd | + +这里再补三条实现上要注意的点: + +- `ExpandTileOp` 当前应把 `NONE` / `RINT` / `CAST_RINT` 统一归一成 `RINT`,这样模板内部只需要处理一套默认 round-to-nearest 语义。 +- `PTO_IR_manual` 里对 `ROUND` 的描述偏旧,当前实现和 VPTO 规格应按 “away from zero” 理解。 +- `f32 -> f32` 这条 `vtrc` 路径不能直接照抄上表全部 token。当前 VPTO `vtrc` 规格只明确列了 `R/A/F/C/Z`,`ODD` 需要单独看目标语义,不应默认跟 `vcvt` 完全等价。 + +### 3.4 不同类型对的处理路径 + +从模板实现角度看,更重要的不是 A5 内部怎么切 CTRL 位,而是不同类型对最终该走哪条路径。建议按下面这张表组织 TileLib 逻辑: + +| 类型对 | 默认路径 | 备注 | +|---|---|---| +| `f32 -> f32` | `vtrc` | 这是 round-to-int-valued-float,不应走 `vcvt` | +| `f32 -> i16` 且 `sat_mode=OFF` | `NonSatTorch` helper | 需要对齐 A5 现有边界值行为 | +| `f16 -> i16` 且 `sat_mode=OFF` | `NonSatTorch` helper | 需要对齐 A5 现有边界值行为 | +| `f16 -> i8` 且 `sat_mode=OFF` | `NonSatTorch` helper | 需要对齐 A5 现有边界值行为 | +| 其余合法类型对 | `vcvt` | 具体带哪些 attr 取决于 VPTO contract | + +`NonSatTorch` 这三条路径不能简单等价成普通 `vcvt(..., sat=NOSAT)`。A5 这里保留了专门实现,是为了在 `inf`、`nan`、`overflow` 这些边界值上对齐当前行为。 + +### 3.5 `vcvt` 的 attr 约束 + +TileLib 侧即使已经推导出了 `sat_mode`,也不能无条件给 `vcvt` 传 `rnd/sat/part`。这些 attr 是否应该出现,仍然要服从 VPTO `vcvt` 的 verifier 约束。 + +下面列几个模板里一定会碰到的典型路径: + +| 类型对 | `rnd` | `sat` | `part` | 建议路径 | +|---|---|---|---|---| +| `f32 -> i32` | 需要 | 需要 | 不需要 | `vcvt` | +| `i32 -> f32` | 需要 | 不需要 | 不需要 | `vcvt` | +| `f32 -> f16/bf16` | 需要 | 需要 | 需要 | `vcvt` | +| `f16/bf16 -> f32` | 不需要 | 不需要 | 需要 | `vcvt` | +| `f32 -> f32` | 不适用 | 不适用 | 不适用 | `vtrc` | + +因此,模板里最好把“默认 `sat_mode` 推导”和“`vcvt` attr 组织”拆成两层,不要混在一起写。 + +## 4. TileLib 设计建议 + +### 4.1 模板主流程 + +TileLib 中的 `pto.tcvt` 模板建议保持下面这个结构: + +```python +@pto.vkernel(target="a5", op="pto.tcvt") +def template_tcvt(src: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + dst_dtype = dst.element_type + + round_mode = pto.get_op_attr("round_mode", "RINT") + sat_mode = _a5_default_tcvt_sat_mode(src_dtype, dst_dtype) + + if _needs_nonsat_torch(src_dtype, dst_dtype, sat_mode): + return _emit_nonsat_torch_tcvt(src, dst, round_mode) + + return _emit_regular_tcvt(src, dst, round_mode, sat_mode) +``` + +这里建议把逻辑拆成三个内部 helper: + +- `_a5_default_tcvt_sat_mode(src_dtype, dst_dtype)` +- `_needs_nonsat_torch(src_dtype, dst_dtype, sat_mode)` +- `_emit_regular_tcvt(...)` + +这样写更容易和 A5 `pto-isa` 的现有规则对齐,也方便后面做单测。 + +### 4.2 普通路径的分派原则 + +`_emit_regular_tcvt(...)` 里建议只做两件事: + +1. 判断当前类型对应该走 `vtrc` 还是 `vcvt` +2. 如果走 `vcvt`,按 VPTO contract 决定是否附带 `rnd`、`sat`、`part` + +不要直接按 A5 C++ helper 名称去分派 TileLang DSL。TileLib 需要对齐的是最终语义,而不是逐个复刻底层 helper 名。 + +### 4.3 `NonSatTorch` 的定位 + +`NonSatTorch` 在这里应视为模板内部实现细节,不是新的对外接口。 + +可以先完成普通路径,再补 `NonSatTorch`。如果目标是和当前 A5 行为严格对齐,这三条特殊路径需要在第一版就一起补上。 + +## 5. 工作项 + +### 5.1 TileLib 模板库 + +需要补一份 `pto.tcvt` TileLib 模板,实现以下逻辑: + +| 工作项 | 说明 | +|---|---| +| 读取 `round_mode` | 通过 `pto.get_op_attr("round_mode", "RINT")` 获取 | +| 推导默认 `sat_mode` | 严格按 A5 类型对规则实现 | +| 支持 `vtrc` 路径 | 至少覆盖 `f32 -> f32` | +| 支持普通 `vcvt` 路径 | 并满足 VPTO verifier 对 attr 的要求 | +| 支持 `NonSatTorch` 路径 | 至少覆盖 `f32 -> i16`、`f16 -> i16`、`f16 -> i8` 且默认 `OFF` 的场景 | + +### 5.2 DSL / ExpandHelper / `ExpandTileOp` + +除了模板本身,还需要把下面几处配套能力接上: + +| 模块 | 工作项 | +|---|---| +| TileLang DSL | 支持 `pto.get_op_attr("round_mode", ...)` | +| TileLang DSL | 为 `pto.vtrc` 补 round-mode surface,避免 `f32 -> f32` 卡住 | +| ExpandHelper | 传递 `round_mode` 到模板上下文 | +| `ExpandTileOp` | `SpecKey` 纳入 `round_mode`,避免不同 `rmode` 错误复用实例 | + +当前没有必要把 `sat_mode` 加进 `SpecKey`,因为在现有语义下,它完全由 `(src_dtype, dst_dtype)` 决定,而这部分已经包含在操作数 specialization 里。 + +### 5.3 测试 + +建议测试按三类准备: + +| 测试类型 | 关注点 | +|---|---| +| 模板选择与缓存 | 相同类型对、不同 `rmode` 不应复用同一实例 | +| 模板展开 | `round_mode` 能正确进入 `vtrc` / `vcvt` | +| 数值行为 | 默认 `OFF` 类型对、`NonSatTorch` 特殊路径、`f32 -> f32` 路径 | + +最少应覆盖下面这些代表性 case: + +- `f32 -> f32` +- `f32 -> i16` +- `f16 -> i16` +- `f16 -> i8` +- `f32 -> i32` +- `i32 -> f32` + +## 6. 结论 + +这项工作的关键不是“把 `rmode` 传给一个 `vcvt`”这么简单,而是把当前 A5 `pto-isa` 在 round-only `TCVT_IMPL` 里隐含的默认 `sat_mode` 规则和类型分派规则一起带到 TileLib。 + +对当前 PTOAS `pto.tcvt` 而言,模板库应复现下面这条主线: + +1. 从 PTOAS 读取 `round_mode` +2. 在模板内部按 `(src_dtype, dst_dtype)` 推导默认 `sat_mode` +3. 按类型对分派到 `vtrc`、普通 `vcvt` 或 `NonSatTorch` helper + +这样实现出来的 TileLib 模板库,才能和 A5 `pto-isa` 现有行为保持一致。 diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index a398896bd..6c009e556 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -129,10 +129,11 @@ struct SpecKey { std::string opName; std::string targetArch; SmallVector operands; + SmallVector, 4> contextAttrs; bool operator==(const SpecKey &rhs) const { return opName == rhs.opName && targetArch == rhs.targetArch && - operands == rhs.operands; + operands == rhs.operands && contextAttrs == rhs.contextAttrs; } }; @@ -153,6 +154,8 @@ struct SpecKeyInfo : public llvm::DenseMapInfo { } // View/Scalar: only kind + dtype contribute to hash. } + for (const auto &[attrName, attrValue] : key.contextAttrs) + h = llvm::hash_combine(h, attrName, attrValue); return h; } static bool isEqual(const SpecKey &lhs, const SpecKey &rhs) { @@ -224,6 +227,36 @@ static std::string getSLayoutString(int32_t slayout) { return "none_box"; } +static std::optional getTCvtRoundModeString(pto::TCvtOp op) { + switch (op.getRmode()) { + case pto::RoundMode::NONE: + case pto::RoundMode::RINT: + case pto::RoundMode::CAST_RINT: + return "RINT"; + case pto::RoundMode::ROUND: + return "ROUND"; + case pto::RoundMode::FLOOR: + return "FLOOR"; + case pto::RoundMode::CEIL: + return "CEIL"; + case pto::RoundMode::TRUNC: + return "TRUNC"; + case pto::RoundMode::ODD: + return "ODD"; + } + return std::nullopt; +} + +static void appendOpContextAttrs( + Operation *op, + SmallVectorImpl> &attrs) { + if (auto tcvt = dyn_cast(op)) { + std::optional roundMode = getTCvtRoundModeString(tcvt); + if (roundMode) + attrs.emplace_back("round_mode", *roundMode); + } +} + static bool getStaticIntFromValue(Value value, int64_t &out) { if (auto cOp = value.getDefiningOp()) { out = cOp.value(); @@ -389,6 +422,7 @@ static std::optional buildSpecKey(Operation *op) { if (key.operands.empty()) return std::nullopt; + appendOpContextAttrs(op, key.contextAttrs); return key; } @@ -484,6 +518,22 @@ static std::string buildOperandSpecsJson(const SpecKey &key) { return json; } +static std::string buildContextAttrsJson(const SpecKey &key) { + std::string json = "{"; + for (size_t i = 0; i < key.contextAttrs.size(); ++i) { + const auto &[attrName, attrValue] = key.contextAttrs[i]; + if (i > 0) + json += ","; + json += "\""; + json += attrName; + json += "\":\""; + json += attrValue; + json += "\""; + } + json += "}"; + return json; +} + // ============================================================================ // Invoke Python DSL helper to generate a specialized template function. // ============================================================================ @@ -504,6 +554,7 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, // 2. Build operand schema JSON for mixed tile/scalar specialization. std::string operandSpecsJson = buildOperandSpecsJson(key); + std::string contextAttrsJson = buildContextAttrsJson(key); if (key.targetArch.empty()) { llvm::errs() << "ExpandTileOp: missing pto.target_arch module attribute\n"; return nullptr; @@ -529,6 +580,10 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, "--op", opName, "--operand-specs", operandSpecsJson, }; + if (!key.contextAttrs.empty()) { + args.push_back("--context-attrs"); + args.push_back(contextAttrsJson); + } // 5. Set up environment with PYTHONPATH. std::optional redirects[] = {std::nullopt, StringRef(tmpPath), @@ -642,6 +697,8 @@ func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, uniqueName += "_pd" + llvm::utohexstr(op.pad, /*LowerCase=*/false); } } + for (const auto &[attrName, attrValue] : key.contextAttrs) + uniqueName += "_ctx_" + attrName + "_" + attrValue; for (auto [index, fn] : llvm::enumerate(parsedFuncs)) { IRMapping mapping; diff --git a/lib/TileOps/tcvt_template.py b/lib/TileOps/tcvt_template.py new file mode 100644 index 000000000..0f6dfb640 --- /dev/null +++ b/lib/TileOps/tcvt_template.py @@ -0,0 +1,79 @@ +"""TileLang DSL template for pto.tcvt.""" + +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f32, pto.i32), + ], +) +def template_tcvt_f32_to_i32(src: pto.Tile, dst: pto.Tile): + dst_dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dst_dtype)): + mask, remained = pto.make_mask(dst_dtype, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + dst_dtype, + mask, + rnd=rnd, + sat=pto.VcvtSatMode.SAT, + ) + pto.vsts(converted, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i32, pto.f32), + ], +) +def template_tcvt_i32_to_f32(src: pto.Tile, dst: pto.Tile): + dst_dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dst_dtype)): + mask, remained = pto.make_mask(dst_dtype, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + dst_dtype, + mask, + rnd=rnd, + ) + pto.vsts(converted, dst[row, col:], mask) + return diff --git a/test/basic/expand_tile_op_tilelang_tcvt.pto b/test/basic/expand_tile_op_tilelang_tcvt.pto new file mode 100644 index 000000000..ee1ae4060 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tcvt.pto @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test the first regular tcvt paths through ExpandTileOp. +// Current scope is intentionally narrow: +// - f32 -> i32: regular vcvt with rnd + sat +// - i32 -> f32: regular vcvt with rnd only +// - round_mode must reach the template so different rmode values materialize +// different vcvt attrs +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// CHECK-LABEL: func.func @TCVT_F32_TO_I32 +// CHECK-NOT: pto.tcvt ins +// CHECK: pto.vecscope +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "A", sat = "SAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + +// CHECK-LABEL: func.func @TCVT_I32_TO_F32 +// CHECK-NOT: pto.tcvt ins +// CHECK: pto.vecscope +// CHECK: pto.vcvt {{.*}} {rnd = "R"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {rnd = "A"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +module { + func.func @TCVT_F32_TO_I32() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst_r = pto.alloc_tile + : !pto.tile_buf + %dst_a = pto.alloc_tile + : !pto.tile_buf + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst_r : !pto.tile_buf) + + pto.tcvt ins(%src {rmode = #pto} : !pto.tile_buf) + outs(%dst_a : !pto.tile_buf) + return + } + + func.func @TCVT_I32_TO_F32() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst_r = pto.alloc_tile + : !pto.tile_buf + %dst_a = pto.alloc_tile + : !pto.tile_buf + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst_r : !pto.tile_buf) + + pto.tcvt ins(%src {rmode = #pto} : !pto.tile_buf) + outs(%dst_a : !pto.tile_buf) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index c0c522221..db614307b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -114,6 +114,7 @@ endfunction() # -------------------------------------------------------------------------- set(ALL_TESTCASES tadd + tcvt tload ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/CMakeLists.txt new file mode 100644 index 000000000..b117e9a27 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcvt) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py new file mode 100644 index 000000000..0674eadbb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcvt ST test cases. + +Current TileLib tcvt support covered by this testcase: + - f32 -> i32 + - i32 -> f32 + +`dtype` is kept for shared validation compatibility. +Actual data generation and comparison use `src_dtype` / `dst_dtype`. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_to_i32_rint_16x64", + "dtype": np.int32, + "src_dtype": np.float32, + "dst_dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "RINT", + "eps": 0.0, + }, + { + "name": "f32_to_i32_round_16x64", + "dtype": np.int32, + "src_dtype": np.float32, + "dst_dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "ROUND", + "eps": 0.0, + }, + { + "name": "i32_to_f32_rint_16x64", + "dtype": np.float32, + "src_dtype": np.int32, + "dst_dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "RINT", + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/compare.py new file mode 100644 index 000000000..e49344aaa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + dst_dtype = case["dst_dtype"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dst_dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dst_dtype).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py new file mode 100644 index 000000000..e13ae1237 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py @@ -0,0 +1,112 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +from cases import CASES +from st_common import save_case_data, setup_case_rng, validate_cases + + +def make_f32_input(shape): + total = int(np.prod(shape)) + base = (np.arange(total, dtype=np.float32) % 17) - 8.0 + frac_table = np.array([0.2, 0.5, 0.8, -0.2, -0.5, -0.8], dtype=np.float32) + frac = frac_table[np.arange(total) % frac_table.size] + return (base + frac).reshape(shape) + + +def make_i32_input(shape): + total = int(np.prod(shape)) + return (((np.arange(total, dtype=np.int32) * 37) % 257) - 128).reshape(shape) + + +def round_half_away_from_zero(values): + return np.copysign(np.floor(np.abs(values) + 0.5), values) + + +def default_saturation_off(src_dtype, dst_dtype): + """Mirror the current A5 default saturation policy for supported pairs.""" + return ( + (src_dtype is np.float16 and dst_dtype is np.uint8) + or (src_dtype is np.float16 and dst_dtype is np.int8) + or (src_dtype is np.float32 and dst_dtype is np.int16) + or (src_dtype is np.float16 and dst_dtype is np.int16) + or (src_dtype is np.int64 and dst_dtype is np.int32) + or (src_dtype is np.int32 and dst_dtype is np.int16) + ) + + +def apply_round_mode(values, round_mode): + rounding_funcs = { + "RINT": np.rint, + "ROUND": round_half_away_from_zero, + "FLOOR": np.floor, + "CEIL": np.ceil, + "TRUNC": np.trunc, + } + return rounding_funcs.get(round_mode, np.rint)(values) + + +def convert_with_default_saturation(values, src_dtype, dst_dtype): + if np.issubdtype(dst_dtype, np.integer): + if default_saturation_off(src_dtype, dst_dtype): + # For currently supported ST cases this branch is not taken, but keep + # the structure aligned with pto-isa's A5 tcvt golden generator. + if dst_dtype is np.int32: + widened = values.astype(np.int64, copy=False) + wrapped = np.where(widened < 0, (widened + (1 << 32)) & 0xFFFFFFFF, widened & 0xFFFFFFFF) + signed = np.where(wrapped < (1 << 31), wrapped, wrapped - (1 << 32)) + return signed.astype(np.int32, copy=False) + return values.astype(dst_dtype, copy=False) + info = np.iinfo(dst_dtype) + widened = values.astype(np.float64, copy=False) + return np.clip(widened, info.min, info.max).astype(dst_dtype) + + if np.issubdtype(dst_dtype, np.floating): + info = np.finfo(dst_dtype) + return np.clip(values.astype(np.float64, copy=False), info.min, info.max).astype(dst_dtype) + + return values.astype(dst_dtype, copy=False) + + +def generate_golden(case): + src_dtype = case["src_dtype"] + dst_dtype = case["dst_dtype"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + if src_dtype is np.float32: + input_arr = make_f32_input(shape).astype(src_dtype) + rounded = apply_round_mode(input_arr[:vr, :vc], case["round_mode"]) + converted = convert_with_default_saturation(rounded, src_dtype, dst_dtype) + elif src_dtype is np.int32: + input_arr = make_i32_input(shape).astype(src_dtype) + converted = convert_with_default_saturation(input_arr[:vr, :vc], src_dtype, dst_dtype) + else: + raise TypeError(f"unsupported tcvt ST source dtype: {src_dtype}") + + golden = np.zeros(shape, dtype=dst_dtype) + golden[:vr, :vc] = converted + return input_arr, golden + + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + input_arr, golden = generate_golden(case) + + save_case_data(case["name"], {"input": input_arr, "golden": golden}) + print( + f"[INFO] gen_data: {case['name']} shape={case['shape']} " + f"src_dtype={case['src_dtype'].__name__} dst_dtype={case['dst_dtype'].__name__} " + f"round_mode={case['round_mode']}" + ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp new file mode 100644 index 000000000..738be1bfb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TCVT_f32_to_i32_rint_16x64(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_round_16x64(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_rint_16x64(__gm__ int32_t *src, __gm__ float *dst); + +void LaunchTCVT_f32_to_i32_rint_16x64(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_rint_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i32_round_16x64(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_round_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i32_to_f32_rint_16x64(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_rint_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp new file mode 100644 index 000000000..5d080c29d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTCVT_f32_to_i32_rint_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_round_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_rint_16x64(void *src, void *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; + size_t srcCols; + size_t dstRows; + size_t dstCols; + size_t srcElemSize; + size_t dstElemSize; +}; + +static const TestCase kCases[] = { + {"f32_to_i32_rint_16x64", LaunchTCVT_f32_to_i32_rint_16x64, 16, 64, 16, 64, sizeof(float), sizeof(int32_t)}, + {"f32_to_i32_round_16x64", LaunchTCVT_f32_to_i32_round_16x64, 16, 64, 16, 64, sizeof(float), sizeof(int32_t)}, + {"i32_to_f32_rint_16x64", LaunchTCVT_i32_to_f32_rint_16x64, 16, 64, 16, 64, sizeof(int32_t), sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + size_t srcFileSize = srcElemCount * tc.srcElemSize; + size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr; + void *dstHost = nullptr; + void *srcDevice = nullptr; + void *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(srcDevice, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto new file mode 100644 index 000000000..f11a13912 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcvt. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_rint_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: f32 -> i32, explicit ROUND + func.func @TCVT_f32_to_i32_round_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src {rmode = #pto} : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 2: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_rint_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } +} diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 2aedb8083..7a5d2c6d2 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -65,6 +65,7 @@ elements_per_vreg, f16, f32, + get_op_attr, get_lanes, i1, i8, @@ -159,6 +160,7 @@ "mask_b16", "mask_b32", "constexpr", + "get_op_attr", "bytewidth", "get_lanes", "elements_per_vreg", diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index a31af37db..e8dd5f8a0 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -271,6 +271,7 @@ def _select_descriptor( target: str, op_name: str, operand_specs: list[dict], + extra_context_attrs: dict[str, object] | None = None, ) -> VKernelDescriptor: filtered_descriptors = _filter_descriptors_by_operand_schema( descriptors, @@ -285,16 +286,30 @@ def _select_descriptor( f"target={target!r}, op={op_name!r}, operand_types={operand_types!r}" ) registry = KernelRegistry(tuple(filtered_descriptors)) + context_attrs = _build_positional_context_attrs(operand_specs) + if extra_context_attrs: + context_attrs.update(extra_context_attrs) return select_kernel( target, op_name, operand_types, - context_attrs=_build_positional_context_attrs(operand_specs), + context_attrs=context_attrs, registry=registry, return_metadata=False, ) +def _parse_context_attrs(spec_text: str) -> dict[str, object]: + try: + raw = json.loads(spec_text) + except json.JSONDecodeError as exc: + raise ValueError(f"invalid context-attrs JSON: {exc}") from exc + + if not isinstance(raw, dict): + raise ValueError("context-attrs must be a JSON object") + return dict(raw) + + def main(argv: list[str] | None = None) -> int: parser = argparse.ArgumentParser(description="TileLang DSL expand helper") parser.add_argument("--template-dir", required=True, help="Directory of .py templates") @@ -307,6 +322,10 @@ def main(argv: list[str] | None = None) -> int: "--operand-specs", help="JSON array describing each operand (tile/scalar schema)", ) + parser.add_argument( + "--context-attrs", + help="JSON object describing static op/context attrs visible to the template", + ) args = parser.parse_args(argv) template_dir = Path(args.template_dir) @@ -315,6 +334,7 @@ def main(argv: list[str] | None = None) -> int: return 1 operand_specs: list[dict] | None = None + extra_context_attrs: dict[str, object] = {} if args.operand_specs: try: operand_specs = _parse_operand_specs(args.operand_specs) @@ -341,6 +361,13 @@ def main(argv: list[str] | None = None) -> int: {"kind": "tile", "dtype": target_dtype, "shape": shape, "memory_space": mem_space} ] + if args.context_attrs: + try: + extra_context_attrs = _parse_context_attrs(args.context_attrs) + except ValueError as exc: + print(f"expand_helper: error: {exc}", file=sys.stderr) + return 1 + # Scan all .py files for descriptors. all_descriptors: list[VKernelDescriptor] = [] for py_path in sorted(template_dir.glob("*.py")): @@ -359,6 +386,7 @@ def main(argv: list[str] | None = None) -> int: target=args.target, op_name=args.op, operand_specs=operand_specs, + extra_context_attrs=extra_context_attrs, ) except Exception as exc: print(f"expand_helper: error: {exc}", file=sys.stderr) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index a78d7b55f..2fcc25639 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -196,6 +196,7 @@ class FrontendKernelNode: parameters: tuple[FrontendParameterNode, ...] tile_specializations: tuple[FrontendTileSpecializationNode, ...] body: tuple[FrontendStmtNode, ...] + context_attrs: tuple[tuple[str, Any], ...] = () inline_procs: tuple[FrontendInlineProcNode, ...] = () @@ -1510,6 +1511,9 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: parameters=parameters, tile_specializations=tile_specializations, body=body, + context_attrs=tuple( + sorted(descriptor.constraint_context_attrs.items(), key=lambda item: item[0]) + ), inline_procs=reachable_inline_proc_nodes, ) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index be4c24ac9..8c67b0a64 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -749,6 +749,7 @@ class SemanticKernel: class _SemanticAnalyzer: def __init__(self, node: FrontendKernelNode): self.node = node + self._context_attrs = dict(node.context_attrs) self._counter = 0 self._disable_inference_depth = 0 self._has_explicit_vecscope = self._contains_explicit_vecscope(node.body) @@ -3722,6 +3723,8 @@ def _analyze_call_expr( return self._analyze_bytewidth(args) if name in {"get_lanes", "elements_per_vreg"}: return self._analyze_get_lanes(args, call_name=name) + if name == "get_op_attr": + return self._analyze_get_op_attr(args) if name == "constexpr": raise TypeError( "pto.constexpr(...) is only supported as an if-condition wrapper in TileLang DSL v1" @@ -3814,6 +3817,50 @@ def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: ), ) + def _literal_expr_from_context_value(self, value: object, context: str) -> SemanticExpr: + if isinstance(value, bool): + return SemanticLiteralExpr(value=value, type=SemanticScalarType(dtype=i1)) + if isinstance(value, int) and not isinstance(value, bool): + return SemanticLiteralExpr(value=value, type=SemanticIndexType()) + if isinstance(value, float): + return SemanticLiteralExpr(value=value, type=SemanticScalarType(dtype=f32)) + if isinstance(value, str): + return SemanticLiteralExpr(value=value, type=SemanticMetaType(kind="string")) + if isinstance(value, ScalarType): + return SemanticSymbolExpr( + namespace="pto", + name=value.name, + value=value, + type=SemanticMetaType(kind="dtype"), + ) + if isinstance(value, MemorySpace): + return SemanticSymbolExpr( + namespace="pto", + name=value.name, + value=value, + type=SemanticMetaType(kind="memory_space"), + ) + raise TypeError( + f"{context} resolved to unsupported static value {value!r} in TileLang DSL v1" + ) + + def _analyze_get_op_attr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) not in {1, 2}: + raise TypeError( + "pto.get_op_attr expects 1 or 2 positional arguments `(name, default?)` in TileLang DSL v1" + ) + attr_name = self._require_string_expr(args[0], "pto.get_op_attr name") + if attr_name in self._context_attrs: + return self._literal_expr_from_context_value( + self._context_attrs[attr_name], + f"pto.get_op_attr({attr_name!r})", + ) + if len(args) == 2: + return args[1] + raise TypeError( + f"pto.get_op_attr could not resolve attribute {attr_name!r} and no default was provided" + ) + def _analyze_scalar_constructor( self, name: str, @@ -5215,10 +5262,15 @@ def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: raise TypeError("pto.vprelu only supports f16/f32 in TileLang DSL v1") if name in {"vaddreluconv", "vmulconv"} and dtype.name not in {"f16", "bf16", "f32"}: raise TypeError(f"pto.{name} only supports f16/bf16/f32 in TileLang DSL v1") - if name in {"vand", "vor", "vxor"} and not ( + if name in {"vand", "vxor"} and not ( is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} ): raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name == "vor" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) + or dtype.name in {"f16", "bf16", "f32"} + ): + raise TypeError("pto.vor only supports integer vector dtypes and f16/bf16/f32 in TileLang DSL v1") if name in {"vshl", "vshr"} and not ( is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} ): diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 763171014..7b7597866 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -21,6 +21,7 @@ "bytewidth", "get_lanes", "elements_per_vreg", + "get_op_attr", "vreg", "i1", "i8", diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 62f9dc964..1485013e1 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -649,6 +649,12 @@ def constexpr(value: bool) -> bool: return value +def get_op_attr(name: str, default: Any = None) -> Any: + if not isinstance(name, str) or not name: + raise TypeError("get_op_attr expects a non-empty string attribute name") + return default + + __all__ = [ "ScalarType", "WildcardType", @@ -705,6 +711,7 @@ def constexpr(value: bool) -> bool: "mask_b16", "mask_b32", "constexpr", + "get_op_attr", "bytewidth", "get_lanes", "elements_per_vreg", From 39e00c852b61a1f4b085c8b32039d260f78e4b07 Mon Sep 17 00:00:00 2001 From: qukelin Date: Thu, 16 Apr 2026 20:57:00 +0800 Subject: [PATCH 092/192] Add f16-to-f32 tcvt support --- ...4-15-tcvt-tilelib-sample-and-work-items.md | 2 +- lib/TileOps/tcvt_template.py | 56 +++++++++++ test/basic/expand_tile_op_tilelang_tcvt.pto | 21 ++++ .../npu/a5/src/st/testcase/tcvt/cases.py | 11 +++ .../npu/a5/src/st/testcase/tcvt/gen_data.py | 3 + .../npu/a5/src/st/testcase/tcvt/launch.cpp | 5 + .../npu/a5/src/st/testcase/tcvt/main.cpp | 2 + .../npu/a5/src/st/testcase/tcvt/tcvt.pto | 48 +++++++++ .../python/tilelang_dsl/frontend_ast.py | 1 + tilelang-dsl/python/tilelang_dsl/lowering.py | 13 ++- tilelang-dsl/python/tilelang_dsl/semantic.py | 98 +++++++++++++++++-- 11 files changed, 248 insertions(+), 12 deletions(-) diff --git a/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md b/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md index bc4de825b..7a70859ce 100644 --- a/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md +++ b/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md @@ -131,7 +131,7 @@ PTOAS + TileLib 已经全部打通。 | 源类型 | 目标类型 | A5 helper 覆盖 | 备注 | TileLib是否支持 | |---|---|---|---|---| -| `f16` | `f32` | 1D+2D,`vcvt + part` | type expand | | +| `f16` | `f32` | 1D+2D,`vcvt + part` | type expand | `已支持` | | `bf16` | `f32` | 1D+2D,`vcvt + part` | type expand | | | `i16` | `f32` / `i32` / `u32` | 1D+2D,expand helper | widening path | | | `i32` | `i64` | 1D+2D,expand helper | | | diff --git a/lib/TileOps/tcvt_template.py b/lib/TileOps/tcvt_template.py index 0f6dfb640..4241419cd 100644 --- a/lib/TileOps/tcvt_template.py +++ b/lib/TileOps/tcvt_template.py @@ -2,12 +2,40 @@ import tilelang_dsl as pto + +def _supports_basic_rowwise_tcvt( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None, +): + if tuple(src_shape) != tuple(dst_shape): + return False + if tuple(src_valid_shape) != tuple(dst_valid_shape): + return False + if len(src_shape) != 2 or len(dst_shape) != 2: + return False + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + return True + @pto.vkernel( target="a5", op="pto.tcvt", dtypes=[ (pto.f32, pto.i32), ], + constraints=[_supports_basic_rowwise_tcvt], ) def template_tcvt_f32_to_i32(src: pto.Tile, dst: pto.Tile): dst_dtype = dst.element_type @@ -41,12 +69,40 @@ def template_tcvt_f32_to_i32(src: pto.Tile, dst: pto.Tile): return +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f16, pto.f32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f16_to_f32(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + store_mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:], dist="UNPK_B16") + converted = pto.vcvt( + vec, + pto.f32, + full_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + @pto.vkernel( target="a5", op="pto.tcvt", dtypes=[ (pto.i32, pto.f32), ], + constraints=[_supports_basic_rowwise_tcvt], ) def template_tcvt_i32_to_f32(src: pto.Tile, dst: pto.Tile): dst_dtype = dst.element_type diff --git a/test/basic/expand_tile_op_tilelang_tcvt.pto b/test/basic/expand_tile_op_tilelang_tcvt.pto index ee1ae4060..a897611fa 100644 --- a/test/basic/expand_tile_op_tilelang_tcvt.pto +++ b/test/basic/expand_tile_op_tilelang_tcvt.pto @@ -27,6 +27,12 @@ // CHECK: pto.vcvt {{.*}} {rnd = "R"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {rnd = "A"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> +// CHECK-LABEL: func.func @TCVT_F16_TO_F32 +// CHECK-NOT: pto.tcvt ins +// CHECK: pto.vecscope +// CHECK: pto.vlds {{.*}} {dist = "UNPK"} : {{.*}} -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + module { func.func @TCVT_F32_TO_I32() { %src = pto.alloc_tile @@ -73,4 +79,19 @@ module { blayout=row_major, slayout=none_box, fractal=512, pad=0>) return } + + func.func @TCVT_F16_TO_F32() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } } diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py index 0674eadbb..7752e17e3 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py @@ -13,6 +13,7 @@ Current TileLib tcvt support covered by this testcase: - f32 -> i32 + - f16 -> f32 - i32 -> f32 `dtype` is kept for shared validation compatibility. @@ -52,4 +53,14 @@ "round_mode": "RINT", "eps": 1e-6, }, + { + "name": "f16_to_f32_rint_16x64", + "dtype": np.float32, + "src_dtype": np.float16, + "dst_dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "RINT", + "eps": 1e-6, + }, ] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py index e13ae1237..01885d9d7 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py @@ -90,6 +90,9 @@ def generate_golden(case): elif src_dtype is np.int32: input_arr = make_i32_input(shape).astype(src_dtype) converted = convert_with_default_saturation(input_arr[:vr, :vc], src_dtype, dst_dtype) + elif src_dtype is np.float16: + input_arr = make_f32_input(shape).astype(src_dtype) + converted = convert_with_default_saturation(input_arr[:vr, :vc], src_dtype, dst_dtype) else: raise TypeError(f"unsupported tcvt ST source dtype: {src_dtype}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp index 738be1bfb..4515f1f4c 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp @@ -15,6 +15,7 @@ extern "C" __global__ AICORE void TCVT_f32_to_i32_rint_16x64(__gm__ float *src, __gm__ int32_t *dst); extern "C" __global__ AICORE void TCVT_f32_to_i32_round_16x64(__gm__ float *src, __gm__ int32_t *dst); extern "C" __global__ AICORE void TCVT_i32_to_f32_rint_16x64(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_rint_16x64(__gm__ half *src, __gm__ float *dst); void LaunchTCVT_f32_to_i32_rint_16x64(void *src, void *dst, void *stream) { TCVT_f32_to_i32_rint_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); @@ -27,3 +28,7 @@ void LaunchTCVT_f32_to_i32_round_16x64(void *src, void *dst, void *stream) { void LaunchTCVT_i32_to_f32_rint_16x64(void *src, void *dst, void *stream) { TCVT_i32_to_f32_rint_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); } + +void LaunchTCVT_f16_to_f32_rint_16x64(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_rint_16x64<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ float *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp index 5d080c29d..002bbb53d 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp @@ -19,6 +19,7 @@ using namespace PtoTestCommon; void LaunchTCVT_f32_to_i32_rint_16x64(void *src, void *dst, void *stream); void LaunchTCVT_f32_to_i32_round_16x64(void *src, void *dst, void *stream); void LaunchTCVT_i32_to_f32_rint_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_rint_16x64(void *src, void *dst, void *stream); using LaunchFn = void (*)(void *, void *, void *); @@ -37,6 +38,7 @@ static const TestCase kCases[] = { {"f32_to_i32_rint_16x64", LaunchTCVT_f32_to_i32_rint_16x64, 16, 64, 16, 64, sizeof(float), sizeof(int32_t)}, {"f32_to_i32_round_16x64", LaunchTCVT_f32_to_i32_round_16x64, 16, 64, 16, 64, sizeof(float), sizeof(int32_t)}, {"i32_to_f32_rint_16x64", LaunchTCVT_i32_to_f32_rint_16x64, 16, 64, 16, 64, sizeof(int32_t), sizeof(float)}, + {"f16_to_f32_rint_16x64", LaunchTCVT_f16_to_f32_rint_16x64, 16, 64, 16, 64, sizeof(uint16_t), sizeof(float)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto index f11a13912..cd1457eb4 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto @@ -154,4 +154,52 @@ module { outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) return } + + // Case 3: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_rint_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } } diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 2fcc25639..e19825645 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -818,6 +818,7 @@ def _collect_reachable_inline_procs( ), "vcvt": frozenset({"rnd", "sat", "part"}), "vtrc": frozenset({"rnd"}), + "vlds": frozenset({"dist"}), } diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 5e260d79c..fc492929f 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2556,25 +2556,30 @@ def _lower_call_expr( source = self._lower_expr(expr.args[0], env, indent=indent, into=into) if isinstance(source.type, SemanticTileType): source = self._materialize_tile_memref(source, indent=indent, into=into) + index_args = expr.args[1:] + dist_suffix = "" + if index_args and self._has_optional_string_literal(index_args[-1]): + dist_suffix = f" {{dist = {self._render_string_literal(index_args[-1])}}}" + index_args = index_args[:-1] if ( isinstance(expr.args[0].type, SemanticTileType) and expr.args[0].type.rank == 2 - and len(expr.args[1:]) == 2 + and len(index_args) == 2 ): source = self._materialize_rank2_tile_subview( source, expr.args[0].type, - expr.args[1:], + index_args, env, indent=indent, into=into, ) rendered_indices = self._materialize_constant(0, SemanticIndexType()) else: - rendered_indices = self._render_index_list(expr.args[1:], env, indent=indent, into=into) + rendered_indices = self._render_index_list(index_args, env, indent=indent, into=into) into.append( self._indent(indent) - + f"{result_name} = pto.vlds {source.name}[{rendered_indices}] : " + + f"{result_name} = pto.vlds {source.name}[{rendered_indices}]{dist_suffix} : " + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 8c67b0a64..8d2ea475c 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3004,14 +3004,12 @@ def _analyze_expr( ) base = SemanticBindingRef(binding=binding, type=binding.type) return self._analyze_as_ptr_method(base) - if expr.namespace == "pto" and expr.name == "vlds" and len(expr.args) == 1: - base, indices = self._analyze_tile_vector_access( - expr.args[0], + if expr.namespace == "pto" and expr.name == "vlds": + return self._analyze_vlds_frontend_call( + expr, env, allow_outer_lookup=allow_outer_lookup, - context="pto.vlds source", ) - return self._analyze_vlds((base, *indices)) if ( expr.namespace == "pto" and expr.name == "vldas" @@ -4047,7 +4045,12 @@ def _analyze_init_align(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: raise TypeError("pto.init_align does not accept positional arguments in TileLang DSL v1") return SemanticCallExpr(namespace="pto", name="init_align", args=(), type=SemanticAlignType()) - def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + def _analyze_vlds( + self, + args: tuple[SemanticExpr, ...], + *, + dist: SemanticExpr | None = None, + ) -> SemanticExpr: if len(args) < 2: raise TypeError("pto.vlds expects at least 2 positional arguments in TileLang DSL v1") source, *indices = args @@ -4058,13 +4061,52 @@ def _analyze_vlds(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: source = self._require_pointer_expr(source, "pto.vlds source", memory_space="ub") for index in indices: self._require_index_typed_expr(index) + lowered_args: tuple[SemanticExpr, ...] + if dist is not None: + lowered_args = (source, *indices, dist) + else: + lowered_args = (source, *indices) return SemanticCallExpr( namespace="pto", name="vlds", - args=(source, *indices), + args=lowered_args, type=self._vreg_type_for_dtype(source.type.element_dtype), ) + def _analyze_vlds_frontend_call( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + unexpected_keywords = sorted(set(analyzed_keywords) - {"dist"}) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.vlds only accepts keyword attr `dist`; " + f"got unsupported keyword(s): {keyword_text}" + ) + dist = self._normalize_vlds_dist(analyzed_keywords.get("dist"), "pto.vlds dist") + if len(expr.args) == 1 and isinstance(expr.args[0], FrontendSubscriptExpr): + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vlds source", + ) + return self._analyze_vlds((base, *indices), dist=dist) + + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_vlds(args, dist=dist) + def _analyze_vldas(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) not in {1, 2, 3}: raise TypeError("pto.vldas expects 1 positional source or Tile[start:]/Tile[row, col:] in TileLang DSL v1") @@ -5110,6 +5152,48 @@ def _normalize_predicate_store_dist( raise TypeError("predicate store dist must be \"NORM\" or \"PK\" in TileLang DSL v1") return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) + def _normalize_vlds_dist( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr | None: + if expr is None: + return None + dist = self._require_string_expr(expr, context) + legacy_map = { + "BRC_B8": "BRC", + "BRC_B16": "BRC", + "BRC_B32": "BRC", + "US_B8": "US", + "US_B16": "US", + "DS_B8": "DS", + "DS_B16": "DS", + "UNPK_B8": "UNPK", + "UNPK_B16": "UNPK", + "UNPK_B32": "UNPK", + "E2B_B16": "E2B", + "E2B_B32": "E2B", + } + normalized = legacy_map.get(dist, dist) + if normalized not in { + "NORM", + "BRC", + "US", + "DS", + "UNPK", + "BRC_BLK", + "E2B", + "UNPK4", + "SPLT4CHN", + "SPLT2CHN", + }: + raise TypeError( + "pto.vlds dist must be one of " + "\"NORM\", \"BRC\", \"US\", \"DS\", \"UNPK\", \"BRC_BLK\", " + "\"E2B\", \"UNPK4\", \"SPLT4CHN\", or \"SPLT2CHN\" in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) + def _require_i1_expr(self, expr: SemanticExpr, context: str) -> None: scalar = self._require_scalar_expr(expr, context) if scalar.dtype != i1: From 731865a23f839aedbfb26d274e43afd23c7e2828 Mon Sep 17 00:00:00 2001 From: qukelin Date: Thu, 16 Apr 2026 23:00:56 +0800 Subject: [PATCH 093/192] Add f32-to-f16 tcvt support --- ...4-15-tcvt-tilelib-sample-and-work-items.md | 3 +- lib/PTO/IR/VPTO.cpp | 16 ++- lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 30 ++++- lib/TileOps/tcvt_template.py | 41 ++++++ test/basic/expand_tile_op_tilelang_tcvt.pto | 25 +++- .../npu/a5/src/st/testcase/tcvt/cases.py | 11 ++ .../npu/a5/src/st/testcase/tcvt/gen_data.py | 5 +- .../npu/a5/src/st/testcase/tcvt/launch.cpp | 5 + .../npu/a5/src/st/testcase/tcvt/main.cpp | 2 + .../npu/a5/src/st/testcase/tcvt/tcvt.pto | 48 +++++++ .../python/tilelang_dsl/frontend_ast.py | 1 + tilelang-dsl/python/tilelang_dsl/lowering.py | 6 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 120 ++++++++++++++---- 13 files changed, 273 insertions(+), 40 deletions(-) diff --git a/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md b/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md index 7a70859ce..0fecba0c8 100644 --- a/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md +++ b/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md @@ -171,7 +171,7 @@ PTOAS + TileLib 已经全部打通。 | 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | 备注 | TileLib是否支持 | |---|---|---|---|---|---| -| `f32` | `f16` / `bf16` | 1D+2D,`vcvt + part` | `ON` | 窄化 float | | +| `f32` | `f16` / `bf16` | 1D+2D,`vcvt + part` | `ON` | 窄化 float | `f32 -> f16` 已支持 | | `f32` | `i32` | 1D+2D,`vcvt` | `ON` | 当前已先打通这一类普通路径 | `已支持` | | `f32` | `i64` | 1D+2D,`vcvt + part` | `ON` | | | | `f32` | `fp8_e4m3` / `fp8_e5m2` | 1D+2D,`vcvt + part` | `ON` | | | @@ -350,6 +350,7 @@ def template_tcvt(src: pto.Tile, dst: pto.Tile): - `f16 -> i16` - `f16 -> i8` - `f32 -> i32` +- `f32 -> f16` - `i32 -> f32` ## 6. 结论 diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 9a9f6e447..9a5759d87 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -909,14 +909,14 @@ static LogicalResult verifyVstsDistWidth(Operation *op, StringRef dist, return *width == 32 ? success() : op->emitOpError("dist 1PT_B32 only supports 32-bit elements"); if (dist == "PK_B16") - return *width == 16 ? success() - : op->emitOpError("dist PK_B16 only supports 16-bit elements"); + return *width == 8 ? success() + : op->emitOpError("dist PK_B16 only supports 8-bit packed elements"); if (dist == "PK_B32") - return *width == 32 ? success() - : op->emitOpError("dist PK_B32 only supports 32-bit elements"); + return *width == 16 ? success() + : op->emitOpError("dist PK_B32 only supports 16-bit packed elements"); if (dist == "PK_B64") - return *width == 64 ? success() - : op->emitOpError("dist PK_B64 only supports 64-bit elements"); + return *width == 32 ? success() + : op->emitOpError("dist PK_B64 only supports 32-bit packed elements"); if (dist == "PK4_B32") return *width == 32 ? success() : op->emitOpError("dist PK4_B32 only supports 32-bit elements"); @@ -951,6 +951,10 @@ getVstsMaskGranularityOverride(StringRef dist, Type elementType) { return StringRef("b16"); if (dist == "MRG2CHN_B16") return StringRef("b32"); + if (dist == "PK_B16") + return StringRef("b16"); + if (dist == "PK_B32") + return StringRef("b32"); return std::nullopt; } diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp index ae384fcdc..e5ba23c39 100644 --- a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -322,17 +322,35 @@ class VPTOLegalityValidator { if (!valueType) return std::nullopt; - auto elementType = valueType.getElementType(); - auto elementIntType = dyn_cast(elementType); - if (!elementIntType) - return std::nullopt; - auto distAttr = op->getAttrOfType("dist"); if (!distAttr) return std::nullopt; StringRef dist = distAttr.getValue(); - unsigned width = elementIntType.getWidth(); + auto elementType = valueType.getElementType(); + unsigned width = 0; + if (auto elementIntType = dyn_cast(elementType)) { + width = elementIntType.getWidth(); + } else if (elementType.isF16() || elementType.isBF16()) { + width = 16; + } else if (elementType.isF32()) { + width = 32; + } else if (elementType.isF64()) { + width = 64; + } else { + return std::nullopt; + } + + if (dist == "PK_B16") { + if (width == 8) + return VPTOMaskGranularity::B16; + return std::nullopt; + } + if (dist == "PK_B32") { + if (width == 16) + return VPTOMaskGranularity::B32; + return std::nullopt; + } if (dist == "MRG4CHN_B8") { if (width == 8) return VPTOMaskGranularity::B32; diff --git a/lib/TileOps/tcvt_template.py b/lib/TileOps/tcvt_template.py index 4241419cd..8a37e00cb 100644 --- a/lib/TileOps/tcvt_template.py +++ b/lib/TileOps/tcvt_template.py @@ -29,6 +29,47 @@ def _supports_basic_rowwise_tcvt( return False return True +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f32, pto.f16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f32_to_f16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + store_mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.f16, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist="PK_B32") + return + + @pto.vkernel( target="a5", op="pto.tcvt", diff --git a/test/basic/expand_tile_op_tilelang_tcvt.pto b/test/basic/expand_tile_op_tilelang_tcvt.pto index a897611fa..d3b2b890f 100644 --- a/test/basic/expand_tile_op_tilelang_tcvt.pto +++ b/test/basic/expand_tile_op_tilelang_tcvt.pto @@ -9,7 +9,9 @@ // Test the first regular tcvt paths through ExpandTileOp. // Current scope is intentionally narrow: // - f32 -> i32: regular vcvt with rnd + sat +// - f32 -> f16: cast32to16_2D_NoPostUpdate-style vcvt(part=EVEN) + PK_B32 store // - i32 -> f32: regular vcvt with rnd only +// - f16 -> f32: cast16to32-style UNPK_B16 load + vcvt(part=EVEN) // - round_mode must reach the template so different rmode values materialize // different vcvt attrs // @@ -30,9 +32,15 @@ // CHECK-LABEL: func.func @TCVT_F16_TO_F32 // CHECK-NOT: pto.tcvt ins // CHECK: pto.vecscope -// CHECK: pto.vlds {{.*}} {dist = "UNPK"} : {{.*}} -> !pto.vreg<128xf16> +// CHECK: pto.vlds {{.*}} {dist = "UNPK_B16"} : {{.*}} -> !pto.vreg<128xf16> // CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> +// CHECK-LABEL: func.func @TCVT_F32_TO_F16 +// CHECK-NOT: pto.tcvt ins +// CHECK: pto.vecscope +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +// CHECK: pto.vsts {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, {{.*}}, !pto.mask + module { func.func @TCVT_F32_TO_I32() { %src = pto.alloc_tile @@ -94,4 +102,19 @@ module { blayout=row_major, slayout=none_box, fractal=512, pad=0>) return } + + func.func @TCVT_F32_TO_F16() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } } diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py index 7752e17e3..947a13c7f 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py @@ -13,6 +13,7 @@ Current TileLib tcvt support covered by this testcase: - f32 -> i32 + - f32 -> f16 - f16 -> f32 - i32 -> f32 @@ -53,6 +54,16 @@ "round_mode": "RINT", "eps": 1e-6, }, + { + "name": "f32_to_f16_rint_16x64", + "dtype": np.float16, + "src_dtype": np.float32, + "dst_dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "RINT", + "eps": 1e-3, + }, { "name": "f16_to_f32_rint_16x64", "dtype": np.float32, diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py index 01885d9d7..464381275 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py @@ -83,10 +83,13 @@ def generate_golden(case): shape = case["shape"] vr, vc = case["valid_shape"] - if src_dtype is np.float32: + if src_dtype is np.float32 and dst_dtype is np.int32: input_arr = make_f32_input(shape).astype(src_dtype) rounded = apply_round_mode(input_arr[:vr, :vc], case["round_mode"]) converted = convert_with_default_saturation(rounded, src_dtype, dst_dtype) + elif src_dtype is np.float32 and dst_dtype is np.float16: + input_arr = make_f32_input(shape).astype(src_dtype) + converted = convert_with_default_saturation(input_arr[:vr, :vc], src_dtype, dst_dtype) elif src_dtype is np.int32: input_arr = make_i32_input(shape).astype(src_dtype) converted = convert_with_default_saturation(input_arr[:vr, :vc], src_dtype, dst_dtype) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp index 4515f1f4c..b7ab5efd2 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp @@ -15,6 +15,7 @@ extern "C" __global__ AICORE void TCVT_f32_to_i32_rint_16x64(__gm__ float *src, __gm__ int32_t *dst); extern "C" __global__ AICORE void TCVT_f32_to_i32_round_16x64(__gm__ float *src, __gm__ int32_t *dst); extern "C" __global__ AICORE void TCVT_i32_to_f32_rint_16x64(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_rint_16x64(__gm__ float *src, __gm__ half *dst); extern "C" __global__ AICORE void TCVT_f16_to_f32_rint_16x64(__gm__ half *src, __gm__ float *dst); void LaunchTCVT_f32_to_i32_rint_16x64(void *src, void *dst, void *stream) { @@ -29,6 +30,10 @@ void LaunchTCVT_i32_to_f32_rint_16x64(void *src, void *dst, void *stream) { TCVT_i32_to_f32_rint_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); } +void LaunchTCVT_f32_to_f16_rint_16x64(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_rint_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ half *)dst); +} + void LaunchTCVT_f16_to_f32_rint_16x64(void *src, void *dst, void *stream) { TCVT_f16_to_f32_rint_16x64<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ float *)dst); } diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp index 002bbb53d..9207a4120 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp @@ -19,6 +19,7 @@ using namespace PtoTestCommon; void LaunchTCVT_f32_to_i32_rint_16x64(void *src, void *dst, void *stream); void LaunchTCVT_f32_to_i32_round_16x64(void *src, void *dst, void *stream); void LaunchTCVT_i32_to_f32_rint_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_rint_16x64(void *src, void *dst, void *stream); void LaunchTCVT_f16_to_f32_rint_16x64(void *src, void *dst, void *stream); using LaunchFn = void (*)(void *, void *, void *); @@ -38,6 +39,7 @@ static const TestCase kCases[] = { {"f32_to_i32_rint_16x64", LaunchTCVT_f32_to_i32_rint_16x64, 16, 64, 16, 64, sizeof(float), sizeof(int32_t)}, {"f32_to_i32_round_16x64", LaunchTCVT_f32_to_i32_round_16x64, 16, 64, 16, 64, sizeof(float), sizeof(int32_t)}, {"i32_to_f32_rint_16x64", LaunchTCVT_i32_to_f32_rint_16x64, 16, 64, 16, 64, sizeof(int32_t), sizeof(float)}, + {"f32_to_f16_rint_16x64", LaunchTCVT_f32_to_f16_rint_16x64, 16, 64, 16, 64, sizeof(float), sizeof(uint16_t)}, {"f16_to_f32_rint_16x64", LaunchTCVT_f16_to_f32_rint_16x64, 16, 64, 16, 64, sizeof(uint16_t), sizeof(float)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto index cd1457eb4..c5dc4bb0b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto @@ -202,4 +202,52 @@ module { outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) return } + + // Case 4: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_rint_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } } diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index e19825645..a0d659ed2 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -819,6 +819,7 @@ def _collect_reachable_inline_procs( "vcvt": frozenset({"rnd", "sat", "part"}), "vtrc": frozenset({"rnd"}), "vlds": frozenset({"dist"}), + "vsts": frozenset({"dist"}), } diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index fc492929f..35cecb607 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -991,10 +991,14 @@ def _render_vector_store( else: rendered_indices = self._render_index_list(stmt.indices, env, indent=indent, into=lines) mask = self._lower_expr(stmt.mask, env, indent=indent, into=lines) + attrs = "" + if stmt.dist is not None: + dist = self._render_string_literal(stmt.dist) + attrs = f" {{dist = {dist}}}" lines.append( self._indent(indent) + "pto.vsts " - + f"{value.name}, {destination.name}[{rendered_indices}], {mask.name} : " + + f"{value.name}, {destination.name}[{rendered_indices}], {mask.name}{attrs} : " + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(mask.type)}" ) return lines diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 8d2ea475c..c3ca9fcf5 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -544,6 +544,7 @@ class SemanticVectorStoreStmt(SemanticStmt): value: SemanticExpr destination: SemanticExpr indices: tuple[SemanticExpr, ...] + dist: SemanticExpr | None mask: SemanticExpr @@ -1830,6 +1831,18 @@ def _analyze_vector_store_stmt( ) if expr.name == "vsts": + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + unexpected_keywords = sorted(set(analyzed_keywords) - {"dist"}) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.vsts only accepts keyword attr `dist`; " + f"got unsupported keyword(s): {keyword_text}" + ) + dist = self._normalize_vsts_dist(analyzed_keywords.get("dist"), "pto.vsts dist") if len(expr.args) == 3: value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) destination, indices = self._analyze_tile_vector_access( @@ -1852,13 +1865,14 @@ def _analyze_vector_store_stmt( self._require_vector_pointer_expr(destination, "pto.vsts destination") for index in indices: self._require_index_typed_expr(index) - self._require_mask_for_vreg(mask, value.type, "pto.vsts") + self._require_mask_for_vsts(mask, value.type, dist, "pto.vsts") self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") return ( SemanticVectorStoreStmt( value=value, destination=destination, indices=indices, + dist=dist, mask=mask, ), dict(env), @@ -5160,37 +5174,67 @@ def _normalize_vlds_dist( if expr is None: return None dist = self._require_string_expr(expr, context) - legacy_map = { - "BRC_B8": "BRC", - "BRC_B16": "BRC", - "BRC_B32": "BRC", - "US_B8": "US", - "US_B16": "US", - "DS_B8": "DS", - "DS_B16": "DS", - "UNPK_B8": "UNPK", - "UNPK_B16": "UNPK", - "UNPK_B32": "UNPK", - "E2B_B16": "E2B", - "E2B_B32": "E2B", - } - normalized = legacy_map.get(dist, dist) + normalized = dist if normalized not in { "NORM", - "BRC", - "US", - "DS", - "UNPK", + "BRC_B8", + "BRC_B16", + "BRC_B32", + "US_B8", + "US_B16", + "DS_B8", + "DS_B16", + "UNPK_B8", + "UNPK_B16", + "UNPK_B32", "BRC_BLK", - "E2B", + "E2B_B16", + "E2B_B32", "UNPK4", "SPLT4CHN", - "SPLT2CHN", + "SPLT2CHN_B8", + "SPLT2CHN_B16", }: raise TypeError( "pto.vlds dist must be one of " - "\"NORM\", \"BRC\", \"US\", \"DS\", \"UNPK\", \"BRC_BLK\", " - "\"E2B\", \"UNPK4\", \"SPLT4CHN\", or \"SPLT2CHN\" in TileLang DSL v1" + "\"NORM\", \"BRC_B8\", \"BRC_B16\", \"BRC_B32\", " + "\"US_B8\", \"US_B16\", \"DS_B8\", \"DS_B16\", " + "\"UNPK_B8\", \"UNPK_B16\", \"UNPK_B32\", \"BRC_BLK\", " + "\"E2B_B16\", \"E2B_B32\", \"UNPK4\", \"SPLT4CHN\", " + "\"SPLT2CHN_B8\", or \"SPLT2CHN_B16\" in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) + + def _normalize_vsts_dist( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr | None: + if expr is None: + return None + dist = self._require_string_expr(expr, context) + normalized = dist + if normalized not in { + "NORM_B8", + "NORM_B16", + "NORM_B32", + "1PT_B8", + "1PT_B16", + "1PT_B32", + "PK_B16", + "PK_B32", + "PK_B64", + "PK4_B32", + "MRG4CHN_B8", + "MRG2CHN_B8", + "MRG2CHN_B16", + }: + raise TypeError( + "pto.vsts dist must be one of " + "\"NORM_B8\", \"NORM_B16\", \"NORM_B32\", " + "\"1PT_B8\", \"1PT_B16\", \"1PT_B32\", " + "\"PK_B16\", \"PK_B32\", \"PK_B64\", \"PK4_B32\", " + "\"MRG4CHN_B8\", \"MRG2CHN_B8\", or \"MRG2CHN_B16\" in TileLang DSL v1" ) return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) @@ -5234,6 +5278,34 @@ def _require_mask_for_vreg( f"{context} requires mask granularity {expected} for vector dtype {vreg_type.element_dtype!r}" ) + def _require_mask_for_vsts( + self, + mask_expr: SemanticExpr, + vreg_type: SemanticVRegType, + dist_expr: SemanticExpr | None, + context: str, + ) -> None: + if not isinstance(mask_expr.type, SemanticMaskType): + raise TypeError(f"{context} requires a mask operand in TileLang DSL v1") + expected = self._mask_granularity_for_dtype(vreg_type.element_dtype) + if dist_expr is not None: + dist = self._require_string_expr(dist_expr, f"{context} dist") + if dist == "PK_B16": + expected = "b16" + elif dist == "PK_B32": + expected = "b32" + elif dist == "PK_B64": + expected = "b32" + elif dist == "MRG4CHN_B8": + expected = "b32" + elif dist in {"MRG2CHN_B8", "MRG2CHN_B16"}: + expected = "b16" if dist == "MRG2CHN_B8" else "b32" + if mask_expr.type.granularity != expected: + raise TypeError( + f"{context} requires mask granularity {expected} for store dist " + f"{self._require_string_expr(dist_expr, f'{context} dist') if dist_expr is not None else 'default'}" + ) + def _require_matching_vector_pointer( self, vreg_type: SemanticVRegType, From 36dd795d28292eaf7d1c96eb3d1e0a12c0b28cb3 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Fri, 17 Apr 2026 12:37:01 +0800 Subject: [PATCH 094/192] let the dist of store op more flexiable --- docs/isa/03-vector-load-store.md | 24 ++-- lib/PTO/IR/VPTO.cpp | 177 ------------------------- test/vpto/dist_width_agnostic_llvm.pto | 25 ++++ 3 files changed, 37 insertions(+), 189 deletions(-) create mode 100644 test/vpto/dist_width_agnostic_llvm.pto diff --git a/docs/isa/03-vector-load-store.md b/docs/isa/03-vector-load-store.md index 7b9abc634..4c199c286 100644 --- a/docs/isa/03-vector-load-store.md +++ b/docs/isa/03-vector-load-store.md @@ -85,7 +85,7 @@ Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_B8` / `INTLV_B |--------------------|-------------------| | `NORM_B8` / `NORM_B16` / `NORM_B32` | **9** cycles (`RV_VSTI`) | | `PK_B16` / `PK_B32` / `PK_B64` / `PK4_B32` | **9** cycles | -| `INTLV_B8` / `INTLV_B16` / `INTLV_B32` (`pto.vstx2`) | **12** cycles | +| `INTLV_B8` / `INTLV_B16` / `INTLV_B32` (`pto.vstsx2`) | **12** cycles | | `MRG4CHN_B8`, `MRG2CHN_B8`, `MRG2CHN_B16` | **9** cycles (surface retained; current A5 hardware still reports them unsupported at validation time) | ### Gather, scatter, and special addressing @@ -369,7 +369,7 @@ for (int blk = 0; blk < VL / 32; ++blk) { - **semantics:** Vector store with distribution mode. - **inputs:** `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is - the displacement, `%mask` selects the active lanes or sub-elements, and + the displacement, `%mask` is the predicate operand, and `DIST` selects the store distribution. - **outputs:** This op has no SSA result; it writes to UB memory. @@ -386,11 +386,13 @@ for (int blk = 0; blk < VL / 32; ++blk) { | Family | Allowed element widths | C semantics | Latency | |------|-------------|-------------|-------------| | `NORM_B8` / `NORM_B16` / `NORM_B32` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | -| `1PT_B8` / `1PT_B16` / `1PT_B32` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | -| `PK_B16` / `PK_B32` / `PK_B64` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store. | **9** cycles | -| `PK4_B32` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | -| `MRG4CHN_B8` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout. VPTO currently requires `!pto.mask` for this family and emits a hardware-unsupported warning on A5. | **9** cycles | -| `MRG2CHN_B8` / `MRG2CHN_B16` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout. VPTO currently requires `!pto.mask` for `MRG2CHN_B8` and `!pto.mask` for `MRG2CHN_B16`, and emits a hardware-unsupported warning on A5. | **9** cycles | +| `1PT_B8` / `1PT_B16` / `1PT_B32` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint; the predicate register is ignored. | **9** cycles | +| `PK_B16` | `b16` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 16-bit data. | **9** cycles | +| `PK_B32` | `b32` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 32-bit data. | **9** cycles | +| `PK_B64` | `b64` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 64-bit data. | **9** cycles | +| `PK4_B32` | `b32` | Pack the source vector, extract the lower 8 bits of all elements, and only store the active elements. The predicate is interpreted for 32-bit data. | **9** cycles | +| `MRG4CHN_B8` | `b8` | Merge 4 interleaved 8-bit channels within each 32B block; the predicate is interpreted for 32-bit data and applies after channel merge. | **9** cycles | +| `MRG2CHN_B8` / `MRG2CHN_B16` | `b8`, `b16` | Merge 2 interleaved channels within each 32B block; for `MRG2CHN_B8` the predicate is interpreted for 16-bit data, and for `MRG2CHN_B16` it is interpreted for 32-bit data; in both cases it applies after channel merge. | **9** cycles | **Example — Contiguous store:** ```mlir @@ -408,16 +410,15 @@ pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.p - **inputs:** `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, `%offset` is the displacement, `DIST` selects the interleave layout, and - `%mask` gates the participating elements. + `%mask` is the predicate operand. - **outputs:** This op has no SSA result; it writes an interleaved stream to UB. - **constraints and limitations:** This family is only legal for interleave distributions. The two source vectors form an ordered pair, and the interleave semantics of that pair MUST be preserved. PTO surface accepts the `INTLV` family, which only supports the - element widths listed below. - be preserved. PTO surface accepts the `INTLV` family, which only supports the - element widths listed below. + element widths listed below. For all `INTLV_*` distributions, the predicate + register is ignored. - **latency:** `INTLV` is **12** cycles。 **Distribution families:** @@ -425,7 +426,6 @@ pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.p | Family | Allowed element widths | C semantics | Latency | |------|-------------|-------------|-------------| | `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | -| `INTLV` | `b8`, `b16`, `b32` | ```c // INTLV family on 32-bit elements: diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 9a5759d87..7a3e1a4d5 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -799,146 +799,6 @@ static bool isSupportedVstsx2DistToken(StringRef dist) { return dist == "INTLV_B8" || dist == "INTLV_B16" || dist == "INTLV_B32"; } -static LogicalResult verifyVldsDistWidth(Operation *op, StringRef dist, - Type elementType) { - auto width = getDistElementWidth(elementType); - if (!width) - return op->emitOpError("requires load element type with a concrete bit width"); - - if (dist == "NORM" || dist == "BRC_BLK") - return success(); - if (dist == "BRC_B8") - return *width == 8 ? success() - : op->emitOpError("dist BRC_B8 only supports 8-bit elements"); - if (dist == "BRC_B16") - return *width == 16 ? success() - : op->emitOpError("dist BRC_B16 only supports 16-bit elements"); - if (dist == "BRC_B32") - return *width == 32 ? success() - : op->emitOpError("dist BRC_B32 only supports 32-bit elements"); - if (dist == "US_B8") - return *width == 8 ? success() - : op->emitOpError("dist US_B8 only supports 8-bit elements"); - if (dist == "US_B16") - return *width == 16 ? success() - : op->emitOpError("dist US_B16 only supports 16-bit elements"); - if (dist == "DS_B8") - return *width == 8 ? success() - : op->emitOpError("dist DS_B8 only supports 8-bit elements"); - if (dist == "DS_B16") - return *width == 16 ? success() - : op->emitOpError("dist DS_B16 only supports 16-bit elements"); - if (dist == "UNPK_B8") - return *width == 8 ? success() - : op->emitOpError("dist UNPK_B8 only supports 8-bit elements"); - if (dist == "UNPK_B16") - return *width == 16 ? success() - : op->emitOpError("dist UNPK_B16 only supports 16-bit elements"); - if (dist == "UNPK_B32") - return *width == 32 ? success() - : op->emitOpError("dist UNPK_B32 only supports 32-bit elements"); - if (dist == "E2B_B16") - return *width == 16 ? success() - : op->emitOpError("dist E2B_B16 only supports 16-bit elements"); - if (dist == "E2B_B32") - return *width == 32 ? success() - : op->emitOpError("dist E2B_B32 only supports 32-bit elements"); - if (dist == "UNPK4") - return *width == 8 - ? success() - : op->emitOpError("dist UNPK4 only supports 8-bit elements"); - if (dist == "SPLT4CHN") - return *width == 8 - ? success() - : op->emitOpError("dist SPLT4CHN only supports 8-bit elements"); - if (dist == "SPLT2CHN_B8") - return *width == 8 - ? success() - : op->emitOpError("dist SPLT2CHN_B8 only supports 8-bit elements"); - if (dist == "SPLT2CHN_B16") - return *width == 16 - ? success() - : op->emitOpError("dist SPLT2CHN_B16 only supports 16-bit elements"); - - return op->emitOpError("requires a supported load distribution token"); -} - -static LogicalResult verifyVldsx2DistWidth(Operation *op, StringRef dist, - Type elementType) { - auto width = getDistElementWidth(elementType); - if (!width) - return op->emitOpError( - "requires x2 load element type with a concrete bit width"); - if (dist == "BDINTLV") - return success(); - if (dist == "DINTLV_B8") - return *width == 8 ? success() - : op->emitOpError("dist DINTLV_B8 only supports 8-bit elements"); - if (dist == "DINTLV_B16") - return *width == 16 ? success() - : op->emitOpError("dist DINTLV_B16 only supports 16-bit elements"); - if (dist == "DINTLV_B32") - return *width == 32 ? success() - : op->emitOpError("dist DINTLV_B32 only supports 32-bit elements"); - return op->emitOpError("requires a supported x2 load distribution token"); -} - -static LogicalResult verifyVstsDistWidth(Operation *op, StringRef dist, - Type elementType) { - auto width = getDistElementWidth(elementType); - if (!width) - return op->emitOpError( - "requires store element type with a concrete bit width"); - - if (dist == "NORM_B8") - return *width == 8 ? success() - : op->emitOpError("dist NORM_B8 only supports 8-bit elements"); - if (dist == "NORM_B16") - return *width == 16 ? success() - : op->emitOpError("dist NORM_B16 only supports 16-bit elements"); - if (dist == "NORM_B32") - return *width == 32 ? success() - : op->emitOpError("dist NORM_B32 only supports 32-bit elements"); - if (dist == "1PT_B8") - return *width == 8 ? success() - : op->emitOpError("dist 1PT_B8 only supports 8-bit elements"); - if (dist == "1PT_B16") - return *width == 16 ? success() - : op->emitOpError("dist 1PT_B16 only supports 16-bit elements"); - if (dist == "1PT_B32") - return *width == 32 ? success() - : op->emitOpError("dist 1PT_B32 only supports 32-bit elements"); - if (dist == "PK_B16") - return *width == 8 ? success() - : op->emitOpError("dist PK_B16 only supports 8-bit packed elements"); - if (dist == "PK_B32") - return *width == 16 ? success() - : op->emitOpError("dist PK_B32 only supports 16-bit packed elements"); - if (dist == "PK_B64") - return *width == 32 ? success() - : op->emitOpError("dist PK_B64 only supports 32-bit packed elements"); - if (dist == "PK4_B32") - return *width == 32 ? success() - : op->emitOpError("dist PK4_B32 only supports 32-bit elements"); - if (dist == "MRG4CHN_B8") { - if (*width != 8) - return op->emitOpError("dist MRG4CHN_B8 only supports 8-bit elements"); - return success(); - } - if (dist == "MRG2CHN_B8") { - if (*width != 8) - return op->emitOpError("dist MRG2CHN_B8 only supports 8-bit elements"); - return success(); - } - if (dist == "MRG2CHN_B16") { - if (*width != 16) - return op->emitOpError("dist MRG2CHN_B16 only supports 16-bit elements"); - return success(); - } - - return op->emitOpError("requires a supported store distribution token"); -} - static std::optional getVstsMaskGranularityOverride(StringRef dist, Type elementType) { auto width = getDistElementWidth(elementType); @@ -959,24 +819,6 @@ getVstsMaskGranularityOverride(StringRef dist, Type elementType) { return std::nullopt; } -static LogicalResult verifyVstsx2DistWidth(Operation *op, StringRef dist, - Type elementType) { - auto width = getDistElementWidth(elementType); - if (!width) - return op->emitOpError( - "requires x2 store element type with a concrete bit width"); - if (dist == "INTLV_B8") - return *width == 8 ? success() - : op->emitOpError("dist INTLV_B8 only supports 8-bit elements"); - if (dist == "INTLV_B16") - return *width == 16 ? success() - : op->emitOpError("dist INTLV_B16 only supports 16-bit elements"); - if (dist == "INTLV_B32") - return *width == 32 ? success() - : op->emitOpError("dist INTLV_B32 only supports 32-bit elements"); - return op->emitOpError("requires a supported x2 store distribution token"); -} - static bool isSupportedPostMode(StringRef mode) { return mode == "NO_POST_UPDATE" || mode == "POST_UPDATE"; } @@ -1527,10 +1369,6 @@ static LogicalResult verifyVldsCommon(LoadOp op) { "supports only NORM, BRC_B8/B16/B32, US_B8/B16, DS_B8/B16, " "UNPK_B8/B16/B32, BRC_BLK, E2B_B16/B32, UNPK4, SPLT4CHN, and " "SPLT2CHN_B8/B16 load distributions"); - if (failed(verifyVldsDistWidth( - op.getOperation(), dist, - cast(op.getResult().getType()).getElementType()))) - return failure(); } return success(); @@ -2723,10 +2561,6 @@ LogicalResult Vldsx2Op::verify() { return emitOpError("requires low/high results to share one vector type"); if (!isSupportedVldx2DistToken(getDist())) return emitOpError("requires a supported x2 load distribution token"); - if (failed(verifyVldsx2DistWidth( - getOperation(), getDist(), - cast(getLow().getType()).getElementType()))) - return failure(); return success(); } @@ -2753,13 +2587,6 @@ static LogicalResult verifyVstsCommon(StoreOp op) { dist && !isSupportedVstsDistToken(*dist)) { return op.emitOpError("requires a supported store distribution token"); } - if (std::optional dist = op.getDist(); - dist && - failed(verifyVstsDistWidth( - op.getOperation(), *dist, - cast(op.getValue().getType()).getElementType()))) - return failure(); - if (std::optional dist = op.getDist()) { if (std::optional granularity = getVstsMaskGranularityOverride( *dist, cast(op.getValue().getType()).getElementType())) { @@ -2825,10 +2652,6 @@ LogicalResult Vstsx2Op::verify() { return emitOpError("requires index offset"); if (!isSupportedVstsx2DistToken(getDist())) return emitOpError("requires a supported x2 store distribution token"); - if (failed(verifyVstsx2DistWidth( - getOperation(), getDist(), - cast(getLow().getType()).getElementType()))) - return failure(); return success(); } diff --git a/test/vpto/dist_width_agnostic_llvm.pto b/test/vpto/dist_width_agnostic_llvm.pto new file mode 100644 index 000000000..0a03219a3 --- /dev/null +++ b/test/vpto/dist_width_agnostic_llvm.pto @@ -0,0 +1,25 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --vpto-emit-hivm-llvm %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @kernel() attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %brc = pto.vlds %ub[%c0] {dist = "BRC_B8"} : !pto.ptr -> !pto.vreg<64xf32> + %low, %high = pto.vldsx2 %ub[%c0], "DINTLV_B8" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vsts %brc, %ub[%c0], %mask {dist = "NORM_B8"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vstsx2 %low, %high, %ub[%c0], "INTLV_B8", %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask + } + return + } +} + +// CHECK-LABEL: define void @kernel() +// CHECK: call <256 x i1> @llvm.hivm.pset.b32(i32 0) +// CHECK: call <64 x float> @llvm.hivm.vldsx1.v64f32(ptr addrspace(6) null, i32 0, i32 1, i32 0) +// CHECK: call { <64 x float>, <64 x float> } @llvm.hivm.vldsx2.v64f32(ptr addrspace(6) null, i32 0, i32 11, i32 0) +// CHECK: call void @llvm.hivm.vstsx1.v64f32(<64 x float> {{.*}}, ptr addrspace(6) null, i32 0, i32 0, i32 0, <256 x i1> {{.*}}) +// CHECK: call void @llvm.hivm.vstsx2.v64f32(<64 x float> {{.*}}, <64 x float> {{.*}}, ptr addrspace(6) null, i32 0, i32 8, i32 0, <256 x i1> {{.*}}) From 25dd553ce636cbd6c2d69fd242fbd7fd9443c27b Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 17 Apr 2026 11:36:55 +0800 Subject: [PATCH 095/192] fix(dsl): fix vecscope inference --- tilelang-dsl/python/tilelang_dsl/semantic.py | 72 +++++++++++++++++--- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 2 +- 2 files changed, 64 insertions(+), 10 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index c3ca9fcf5..4b601dd17 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -848,7 +848,30 @@ def _analyze_kernel_body( self, env: dict[str, SemanticBinding], ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: - return self._analyze_block(self.node.body, env, allow_outer_lookup=True) + return self._analyze_body_with_inferred_kernel_vecscope( + self.node.body, + env, + allow_outer_lookup=True, + has_explicit_vecscope=self._has_explicit_vecscope, + ) + + def _analyze_body_with_inferred_kernel_vecscope( + self, + statements: tuple[FrontendStmtNode, ...], + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + has_explicit_vecscope: bool, + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + # Inferred vecscope is built incrementally from left to right, splitting + # at hard boundaries (DMA/sync/UB-helper/control-flow) instead of + # wrapping the whole kernel body in one fallback region. + return self._analyze_block( + statements, + env, + allow_outer_lookup=allow_outer_lookup, + allow_inferred_vecscope=not has_explicit_vecscope, + ) def _parameter_type(self, param: Any) -> SemanticType: if param.kind == "tensorview": @@ -973,12 +996,12 @@ def _analyze_block( semantic_statements = [] index = 0 while index < len(statements): - if self._should_infer_vecscope( + if self._stmt_can_participate_in_inferred_vecscope( statements[index], allow_inferred_vecscope=allow_inferred_vecscope, ): end = index + 1 - while end < len(statements) and self._should_infer_vecscope( + while end < len(statements) and self._stmt_can_participate_in_inferred_vecscope( statements[end], allow_inferred_vecscope=allow_inferred_vecscope, ): @@ -1013,6 +1036,22 @@ def _analyze_block( index += 1 return tuple(semantic_statements), current_env + def _stmt_can_participate_in_inferred_vecscope( + self, + stmt: FrontendStmtNode, + *, + allow_inferred_vecscope: bool, + ) -> bool: + if self._has_explicit_vecscope: + return False + if self._disable_inference_depth > 0: + return False + if not allow_inferred_vecscope: + return False + if self._frontend_stmt_is_vecscope_boundary(stmt): + return False + return self._frontend_stmt_can_live_in_inferred_vecscope(stmt) + def _analyze_stmt_or_inline( self, stmt: FrontendStmtNode, @@ -1129,6 +1168,7 @@ def _frontend_stmt_is_vecscope_boundary(self, stmt: FrontendStmtNode) -> bool: isinstance(stmt, FrontendExprStmt) and ( self._is_dma_call(stmt.expr) + or self._is_low_level_dma_call(stmt.expr) or self._is_sync_call(stmt.expr) or self._is_ub_helper_call(stmt.expr) ) @@ -1213,10 +1253,11 @@ def _run_contains_vector_op(self, statements: tuple[FrontendStmtNode, ...]) -> b if self._constexpr_if_contains_vector_activity(stmt): return True continue - name = self._frontend_vector_call_name(stmt) - if name is None or name == "make_mask": - continue - return True + if self._frontend_stmt_contains_vector_activity(stmt): + name = self._frontend_vector_call_name(stmt) + if name == "make_mask": + continue + return True return False def _frontend_vector_call_name(self, stmt: FrontendStmtNode) -> str | None: @@ -1277,8 +1318,18 @@ def _semantic_block_contains_vector_activity( return True if isinstance(stmt, SemanticStrictVecscopeStmt): return True + if isinstance(stmt, SemanticDmaLoadStmt): + return True + if isinstance(stmt, SemanticDmaStoreStmt): + return True if isinstance(stmt, SemanticVectorStoreStmt): return True + if isinstance(stmt, SemanticDmaConfigStmt): + return True + if isinstance(stmt, SemanticDmaUnaryConfigStmt): + return True + if isinstance(stmt, SemanticLowLevelCopyStmt): + return True if isinstance(stmt, SemanticAssignStmt) and self._expr_contains_vector_activity(stmt.value): return True if isinstance(stmt, SemanticExprStmt) and self._expr_contains_vector_activity(stmt.expr): @@ -1446,11 +1497,13 @@ def _materialize_inline_proc_specialization( self._hidden_parameters = [] self._inline_proc_active_stack.append(key) try: - body, _ = self._analyze_block( + body, _ = self._analyze_body_with_inferred_kernel_vecscope( inline_proc_node.body, helper_env, allow_outer_lookup=False, - allow_inferred_vecscope=True, + has_explicit_vecscope=self._contains_explicit_vecscope( + inline_proc_node.body + ), ) finally: self._inline_proc_active_stack.pop() @@ -1825,6 +1878,7 @@ def _analyze_vector_store_stmt( value=value, destination=destination, indices=indices, + dist=None, mask=mask, ), dict(env), diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 78ba57be5..acf236152 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -1685,7 +1685,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): return None self.assertIn( - "`pto.vlds` does not support keyword arguments in TileLang DSL v1", + "unsupported keyword `offset` for `pto.vlds` in TileLang DSL v1", str(ctx.exception), ) self.assertIn(f"{__file__}:", str(ctx.exception)) From 849e57b35edc3e08cf404b7c435c31717628831f Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Mon, 20 Apr 2026 03:05:32 +0800 Subject: [PATCH 096/192] support ctrl reg configure ops --- include/PTO/IR/VPTOOps.td | 35 ++++++++++ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 97 +++++++++++++++++++++++++- test/basic/ctrl_ops.pto | 18 +++++ 3 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 test/basic/ctrl_ops.pto diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index a5ce95eb0..15967393b 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -141,6 +141,37 @@ class PTO_BinaryI64ConfigOp : PTO_Op { }]; } +class PTO_BinaryI64PureOp : PTO_Op { + let arguments = (ins + I64:$first, + I64:$second + ); + + let results = (outs I64:$result); + + let assemblyFormat = [{ + $first `,` $second attr-dict `:` type($first) `,` type($second) `->` type($result) + }]; +} + +class PTO_UnaryI64ConfigOp : PTO_Op { + let arguments = (ins I64:$value); + let results = (outs); + + let assemblyFormat = [{ + $value attr-dict `:` type($value) + }]; +} + +class PTO_NullaryI64PureOp : PTO_Op { + let arguments = (ins); + let results = (outs I64:$result); + + let assemblyFormat = [{ + attr-dict `:` type($result) + }]; +} + def PTO_SetMovPadValOp : PTO_Op<"set_mov_pad_val"> { let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat], "integer/float scalar">:$value); @@ -158,6 +189,10 @@ def PTO_SetLoopSizeOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop_size_outtoub">; def PTO_SetLoop2StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_ubtoout">; def PTO_SetLoop1StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_ubtoout">; def PTO_SetLoopSizeUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop_size_ubtoout">; +def PTO_GetCtrlOp : PTO_NullaryI64PureOp<"get_ctrl">; +def PTO_SetCtrlOp : PTO_UnaryI64ConfigOp<"set_ctrl">; +def PTO_Sbitset0Op : PTO_BinaryI64PureOp<"sbitset0">; +def PTO_Sbitset1Op : PTO_BinaryI64PureOp<"sbitset1">; def PTO_CopyGmToUbufOp : PTO_Op<"copy_gm_to_ubuf", [ DeclareOpInterfaceMethods diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 2650f732b..31fdbd11e 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -1080,10 +1080,26 @@ static StringRef buildInitAlignCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.init.vector.align.data").getValue(); } +template +static StringRef buildRuntimeQueryCallee(MLIRContext *context); + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.CTRL").getValue(); +} + static StringRef buildSprclrCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.sprclr").getValue(); } +template +static StringRef buildUnaryConfigCallee(MLIRContext *context); + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.CTRL").getValue(); +} + static StringRef buildVstarCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.vstar").getValue(); } @@ -1092,6 +1108,19 @@ static StringRef buildVstasCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.vstas").getValue(); } +template +static StringRef buildBinaryI64PureCallee(MLIRContext *context); + +template <> +StringRef buildBinaryI64PureCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SBITSET0").getValue(); +} + +template <> +StringRef buildBinaryI64PureCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SBITSET1").getValue(); +} + static FailureOr buildVldsPostCallee(MLIRContext *context, Type resultType) { std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); @@ -4297,6 +4326,32 @@ class LowerUnaryConfigOpPattern final : public OpConversionPattern { LoweringState &state; }; +template +class LowerUnaryI64ConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerUnaryI64ConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ConfigOp op, typename ConfigOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef calleeName = buildUnaryConfigCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{adaptor.getValue().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{adaptor.getValue()}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + template class LowerPipeEventSyncOpPattern final : public OpConversionPattern { public: @@ -4467,6 +4522,38 @@ class LowerRuntimeQueryOpPattern final : public OpConversionPattern { LoweringState &state; }; +template +class LowerBinaryI64PureOpPattern final : public OpConversionPattern { +public: + explicit LowerBinaryI64PureOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BinaryOp op, typename BinaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + StringRef calleeName = buildBinaryI64PureCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{adaptor.getFirst().getType(), + adaptor.getSecond().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), calleeName, TypeRange{resultType}, + ValueRange{adaptor.getFirst(), adaptor.getSecond()}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + class ConvertVPTOUnrealizedCastOp final : public OpConversionPattern { public: @@ -4753,12 +4840,16 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerPgeOpPattern, LowerPgeOpPattern, LowerPgeOpPattern, + LowerRuntimeQueryOpPattern, + LowerBinaryI64PureOpPattern, + LowerBinaryI64PureOpPattern, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, + LowerUnaryI64ConfigOpPattern, LowerUnaryConfigOpPattern, LowerPipeEventSyncOpPattern, LowerPipeEventSyncOpPattern, @@ -4801,11 +4892,13 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, target.addIllegalOp(); target.addIllegalOp(); + pto::GetBlockNumOp, pto::GetSubBlockNumOp, + pto::GetCtrlOp>(); target.addIllegalOp(); + pto::SetCtrlOp, pto::SetMovPadValOp>(); + target.addIllegalOp(); target.addIllegalOp i64 + %ctrl1 = pto.sbitset1 %ctrl0, %bit45 : i64, i64 -> i64 + pto.set_ctrl %ctrl1 : i64 + return + } +} + +// CHECK: call i64 @llvm.hivm.GET.CTRL() +// CHECK: call i64 @llvm.hivm.SBITSET0(i64 +// CHECK: call i64 @llvm.hivm.SBITSET1(i64 +// CHECK: call void @llvm.hivm.SET.CTRL(i64 From 51bcc3f8a59a3d38ceb1c57827a8e4281c830800 Mon Sep 17 00:00:00 2001 From: chenjinlin Date: Mon, 20 Apr 2026 10:51:12 +0800 Subject: [PATCH 097/192] udpate tload and tstore, and add testcase. --- lib/TileOps/tload_template.py | 112 +++++++++++------ lib/TileOps/tstore_template.py | 12 +- .../npu/a5/src/st/testcase/tload/cases.py | 89 +++++++++++++ .../npu/a5/src/st/testcase/tload/compare.py | 9 +- .../npu/a5/src/st/testcase/tload/gen_data.py | 10 +- .../npu/a5/src/st/testcase/tload/launch.cpp | 15 +++ .../npu/a5/src/st/testcase/tload/main.cpp | 18 +++ .../npu/a5/src/st/testcase/tload/tload.pto | 118 ++++++++++++++++++ tilelang-dsl/python/tilelang_dsl/lowering.py | 8 ++ 9 files changed, 339 insertions(+), 52 deletions(-) diff --git a/lib/TileOps/tload_template.py b/lib/TileOps/tload_template.py index b04f37840..3366ac2cd 100644 --- a/lib/TileOps/tload_template.py +++ b/lib/TileOps/tload_template.py @@ -65,6 +65,8 @@ def _tload_preconditions_nz2nz(src, dst) -> bool: def template_tload_nd2nd(src: pto.PartitionTensorView, dst: pto.Tile): dtype = dst.element_type elem_bytes = pto.bytewidth(dtype) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.set_mov_pad_val(dst.pad_value.eval()) g0, g1, g2, g3, g4 = src.shape s0, s1, s2, s3, s4 = src.strides @@ -72,9 +74,6 @@ def template_tload_nd2nd(src: pto.PartitionTensorView, dst: pto.Tile): valid_rows, valid_cols = dst.valid_shape ub_rows, ub_cols = dst.shape - # These preconditions are expressed through the descriptor-level constraint - # callable above, using direct `src.shape[i]` / `dst.shape[i]` syntax. - n_burst = g3 len_burst = g4 * elem_bytes gm_stride = s3 * elem_bytes @@ -104,17 +103,28 @@ def template_tload_nd2nd(src: pto.PartitionTensorView, dst: pto.Tile): pto.set_loop_size_outtoub(loop1=loop1, loop2=loop2) for i in range(0, g0, 1): - src_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) - dst_i = pto.addptr(ub_ptr, i * dst_stride0 * elem_bytes) - pto.copy_gm_to_ubuf( - dst=dst_i, - src=src_i, - n_burst=n_burst, - len_burst=len_burst, - gm_stride=gm_stride, - ub_stride=ub_stride, - enable_ub_pad=False, - ) + src_i = pto.addptr(gm_ptr, i * s0) + dst_i = pto.addptr(ub_ptr, i * dst_stride0) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=True, + ) + else: + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) if loop1 != 1 or loop2 != 1: pto.set_loop_size_outtoub(loop1=1, loop2=1) @@ -129,6 +139,8 @@ def template_tload_nd2nd(src: pto.PartitionTensorView, dst: pto.Tile): def template_tload_dn2dn(src: pto.PartitionTensorView, dst: pto.Tile): dtype = dst.element_type elem_bytes = pto.bytewidth(dtype) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.set_mov_pad_val(dst.pad_value.eval()) # rank-5 partition view 元信息。 g0, g1, g2, g3, g4 = src.shape @@ -169,17 +181,28 @@ def template_tload_dn2dn(src: pto.PartitionTensorView, dst: pto.Tile): pto.set_loop_size_outtoub(loop1=loop1, loop2=loop2) for i in range(0, g0, 1): - src_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) - dst_i = pto.addptr(ub_ptr, i * dst_stride0 * elem_bytes) - pto.copy_gm_to_ubuf( - dst=dst_i, - src=src_i, - n_burst=n_burst, - len_burst=len_burst, - gm_stride=gm_stride, - ub_stride=ub_stride, - enable_ub_pad=False, - ) + src_i = pto.addptr(gm_ptr, i * s0) + dst_i = pto.addptr(ub_ptr, i * dst_stride0) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=True, + ) + else: + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) if loop1 != 1 or loop2 != 1: pto.set_loop_size_outtoub(loop1=1, loop2=1) @@ -195,10 +218,8 @@ def template_tload_nz2nz(src: pto.PartitionTensorView, dst: pto.Tile): dtype = dst.element_type elem_bytes = pto.bytewidth(dtype) - # set padding value for ub tile if needed - # enable_ub_pad = dst.config.pad_value is not pto.PadValue.NULL - # if enable_ub_pad: - # pto.set_mov_pad_val(pad_value=dst.config.pad_value.eval()) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.set_mov_pad_val(dst.pad_value.eval()) # rank-5 partition view 元信息。NZ 静态分块约束(g3/g4 与 dtype 的关系) # 由更高层 schema/static-check 保证,这里只保留运行时搬运公式。 @@ -223,15 +244,26 @@ def template_tload_nz2nz(src: pto.PartitionTensorView, dst: pto.Tile): # NZ2NZ 对应实现始终走 normal mode,不复用 loop1/loop2 寄存器。 pto.set_loop_size_outtoub(loop1=1, loop2=1) for i in range(0, g0, 1): - src_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) - dst_i = pto.addptr(ub_ptr, i * tile_stride * elem_bytes) - pto.copy_gm_to_ubuf( - dst=dst_i, - src=src_i, - n_burst=n_burst, - len_burst=len_burst, - gm_stride=gm_stride, - ub_stride=ub_stride, - enable_ub_pad=False, - ) + src_i = pto.addptr(gm_ptr, i * s0) + dst_i = pto.addptr(ub_ptr, i * tile_stride) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=True, + ) + else: + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) return diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py index 7857a1544..2850a1651 100644 --- a/lib/TileOps/tstore_template.py +++ b/lib/TileOps/tstore_template.py @@ -102,8 +102,8 @@ def template_tstore_nd(src: pto.Tile, dst: pto.PartitionTensorView): pto.set_loop_size_ubtoout(loop1=loop1, loop2=loop2) for i in range(0, g0, 1): - src_i = pto.addptr(ub_ptr, i * src_stride0 * elem_bytes) - dst_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + src_i = pto.addptr(ub_ptr, i * src_stride0) + dst_i = pto.addptr(gm_ptr, i * s0) pto.copy_ubuf_to_gm( dst=dst_i, src=src_i, @@ -164,8 +164,8 @@ def template_tstore_dn(src: pto.Tile, dst: pto.PartitionTensorView): pto.set_loop_size_ubtoout(loop1=loop1, loop2=loop2) for i in range(0, g0, 1): - src_i = pto.addptr(ub_ptr, i * src_stride0 * elem_bytes) - dst_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + src_i = pto.addptr(ub_ptr, i * src_stride0) + dst_i = pto.addptr(gm_ptr, i * s0) pto.copy_ubuf_to_gm( dst=dst_i, src=src_i, @@ -211,8 +211,8 @@ def template_tstore_nz(src: pto.Tile, dst: pto.PartitionTensorView): # NZ path 本身不使用 loop1/loop2,主动切回 normal mode 避免继承旧状态。 pto.set_loop_size_ubtoout(loop1=1, loop2=1) for i in range(0, g0, 1): - src_i = pto.addptr(ub_ptr, i * tile_stride * elem_bytes) - dst_i = pto.addptr(gm_ptr, i * s0 * elem_bytes) + src_i = pto.addptr(ub_ptr, i * tile_stride) + dst_i = pto.addptr(gm_ptr, i * s0) pto.copy_ubuf_to_gm( dst=dst_i, src=src_i, diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py index d1117e24e..a8a77ba80 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py @@ -31,4 +31,93 @@ "valid_shape": (128, 128), "eps": 1e-6, }, + { + "name": "nd_pad_zero_f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 63), + "eps": 1e-6, + "golden_fill": 0.0, + }, + { + "name": "dn_pad_max_f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (15, 64), + "eps": 1e-6, + "golden_fill": np.finfo(np.float32).max, + }, + { + "name": "nz_pad_min_f32_128x128", + "dtype": np.float32, + "shape": (128, 128), + "valid_shape": (64, 128), + "eps": 1e-6, + "golden_fill": np.finfo(np.float32).min, + }, ] + + +def build_expected_output(case, input_arr): + shape = case["shape"] + vr, vc = case["valid_shape"] + dtype = case["dtype"] + + if "golden_fill" in case: + golden = np.full(shape, case["golden_fill"], dtype=dtype) + else: + golden = np.empty(shape, dtype=dtype) + + if case["name"].startswith("dn_pad_"): + flat_in = np.asarray(input_arr, dtype=dtype).reshape(-1) + flat_golden = golden.reshape(-1) + physical_rows = shape[0] + for col in range(vc): + start = physical_rows * col + flat_golden[start : start + vr] = flat_in[start : start + vr] + return golden + + if case["name"].startswith("nz_pad_"): + flat_in = np.asarray(input_arr, dtype=dtype).reshape(-1) + flat_golden = golden.reshape(-1) + block_rows = 8 + block_size = block_rows * shape[1] + num_blocks = shape[0] // block_rows + valid_rows_per_block = vr // num_blocks + for block in range(num_blocks): + base = block * block_size + valid_elems = valid_rows_per_block * shape[1] + flat_golden[base : base + valid_elems] = flat_in[base : base + valid_elems] + return golden + + if "golden_fill" in case: + golden[:vr, :vc] = input_arr[:vr, :vc] + return golden + + return np.asarray(input_arr, dtype=dtype).copy() + + +def select_compared_region(case, arr): + vr, vc = case["valid_shape"] + + if case["name"].startswith("dn_pad_"): + flat = np.asarray(arr).reshape(-1) + physical_rows = case["shape"][0] + pieces = [flat[physical_rows * col : physical_rows * col + vr] for col in range(vc)] + return np.concatenate(pieces) if pieces else flat[:0] + + if case["name"].startswith("nz_pad_"): + flat = np.asarray(arr).reshape(-1) + shape = case["shape"] + block_rows = 8 + block_size = block_rows * shape[1] + num_blocks = shape[0] // block_rows + valid_rows_per_block = vr // num_blocks + pieces = [] + for block in range(num_blocks): + base = block * block_size + valid_elems = valid_rows_per_block * shape[1] + pieces.append(flat[base : base + valid_elems]) + return np.concatenate(pieces) if pieces else flat[:0] + + return np.asarray(arr)[:vr, :vc] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py index 4ceccdde0..6adc9c9fe 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py @@ -11,7 +11,7 @@ import sys import numpy as np -from cases import CASES +from cases import CASES, select_compared_region from st_common import result_cmp, style_fail, style_pass, validate_cases @@ -26,12 +26,15 @@ def main(): case_dir = case["name"] shape = case["shape"] - vr, vc = case["valid_shape"] golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) - ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + ok = result_cmp( + select_compared_region(case, golden), + select_compared_region(case, output), + case["eps"], + ) if ok: print(style_pass(f"[INFO] {case['name']}: compare passed")) else: diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py index fc1b88759..449291f26 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py @@ -8,7 +8,7 @@ # See LICENSE in the root of the software repository for the full text of the License. import numpy as np -from cases import CASES +from cases import CASES, build_expected_output from st_common import validate_cases, setup_case_rng, save_case_data validate_cases(CASES) @@ -18,9 +18,13 @@ dtype = case["dtype"] shape = case["shape"] + vr, vc = case["valid_shape"] input_arr = np.random.randint(1, 17, size=shape).astype(dtype) - golden = input_arr.copy() + golden = build_expected_output(case, input_arr) save_case_data(case["name"], {"input": input_arr, "golden": golden}) - print(f"[INFO] gen_data: {case['name']} shape={shape} dtype={dtype.__name__}") + print( + f"[INFO] gen_data: {case['name']} shape={shape} " + f"valid_shape={(vr, vc)} dtype={dtype.__name__}" + ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp index bab1b88d8..70453d6b9 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp @@ -15,6 +15,9 @@ extern "C" __global__ AICORE void TLOAD_ND_f32_16x64(__gm__ float *src, __gm__ float *dst); extern "C" __global__ AICORE void TLOAD_DN_f32_16x64(__gm__ float *src, __gm__ float *dst); extern "C" __global__ AICORE void TLOAD_NZ_f32_128x128(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_ND_PAD_ZERO_f32_16x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_DN_PAD_MAX_f32_16x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_NZ_PAD_MIN_f32_128x128(__gm__ float *src, __gm__ float *dst); void LaunchTLOAD_ND_f32_16x64(float *src, float *dst, void *stream) { TLOAD_ND_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); @@ -27,3 +30,15 @@ void LaunchTLOAD_DN_f32_16x64(float *src, float *dst, void *stream) { void LaunchTLOAD_NZ_f32_128x128(float *src, float *dst, void *stream) { TLOAD_NZ_f32_128x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); } + +void LaunchTLOAD_ND_PAD_ZERO_f32_16x64(float *src, float *dst, void *stream) { + TLOAD_ND_PAD_ZERO_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_DN_PAD_MAX_f32_16x64(float *src, float *dst, void *stream) { + TLOAD_DN_PAD_MAX_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_NZ_PAD_MIN_f32_128x128(float *src, float *dst, void *stream) { + TLOAD_NZ_PAD_MIN_f32_128x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp index 3984d9150..7c0b66d14 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp @@ -23,6 +23,9 @@ using namespace PtoTestCommon; void LaunchTLOAD_ND_f32_16x64(float *src, float *dst, void *stream); void LaunchTLOAD_DN_f32_16x64(float *src, float *dst, void *stream); void LaunchTLOAD_NZ_f32_128x128(float *src, float *dst, void *stream); +void LaunchTLOAD_ND_PAD_ZERO_f32_16x64(float *src, float *dst, void *stream); +void LaunchTLOAD_DN_PAD_MAX_f32_16x64(float *src, float *dst, void *stream); +void LaunchTLOAD_NZ_PAD_MIN_f32_128x128(float *src, float *dst, void *stream); using LaunchFn = void (*)(float *, float *, void *); @@ -38,6 +41,9 @@ static const TestCase kCases[] = { {"nd_f32_16x64", LaunchTLOAD_ND_f32_16x64, 16, 64, sizeof(float)}, {"dn_f32_16x64", LaunchTLOAD_DN_f32_16x64, 16, 64, sizeof(float)}, {"nz_f32_128x128", LaunchTLOAD_NZ_f32_128x128, 128, 128, sizeof(float)}, + {"nd_pad_zero_f32_16x64", LaunchTLOAD_ND_PAD_ZERO_f32_16x64, 16, 64, sizeof(float)}, + {"dn_pad_max_f32_16x64", LaunchTLOAD_DN_PAD_MAX_f32_16x64, 16, 64, sizeof(float)}, + {"nz_pad_min_f32_128x128", LaunchTLOAD_NZ_PAD_MIN_f32_128x128, 128, 128, sizeof(float)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); @@ -96,6 +102,7 @@ int main(int argc, char *argv[]) { const char *caseFilter = (argc > 1) ? argv[1] : nullptr; int rc = 0; + bool matchedCase = (caseFilter == nullptr); int deviceId = 0; aclrtStream stream = nullptr; @@ -110,6 +117,7 @@ int main(int argc, char *argv[]) { if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { continue; } + matchedCase = true; int ret = RunCase(kCases[i], stream); if (ret != 0) { std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); @@ -118,6 +126,16 @@ int main(int argc, char *argv[]) { } } + if (!matchedCase) { + std::fprintf(stderr, "[ERROR] unknown case filter: %s\n", caseFilter); + std::fprintf(stderr, "[ERROR] supported cases:"); + for (size_t i = 0; i < kNumCases; ++i) { + std::fprintf(stderr, " %s", kCases[i].name); + } + std::fprintf(stderr, "\n"); + rc = 1; + } + if (stream != nullptr) aclrtDestroyStream(stream); aclrtResetDevice(deviceId); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto b/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto index a612f12ea..9b116a34e 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto @@ -125,4 +125,122 @@ module { outs(%dst_part : !pto.partition_tensor_view<16x1x128x1x8xf32>) return } + + func.func @TLOAD_ND_PAD_ZERO_f32_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c63 = arith.constant 63 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c63], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x63xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c63], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x63xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x63xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x63xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + return + } + + func.func @TLOAD_DN_PAD_MAX_f32_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c64], + strides = [%c1024, %c1024, %c1024, %c1, %c16] + : !pto.tensor_view<1x1x1x15x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c64], + strides = [%c1024, %c1024, %c1024, %c1, %c16] + : !pto.tensor_view<1x1x1x15x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c64] + : !pto.tensor_view<1x1x1x15x64xf32> -> !pto.partition_tensor_view<1x1x1x15x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c64] + : !pto.tensor_view<1x1x1x15x64xf32> -> !pto.partition_tensor_view<1x1x1x15x64xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x64xf32>) + return + } + + func.func @TLOAD_NZ_PAD_MIN_f32_128x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c16, %c1, %c64, %c1, %c8], + strides = [%c1024, %c1024, %c8, %c8, %c1] + : !pto.tensor_view<16x1x64x1x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c16, %c1, %c64, %c1, %c8], + strides = [%c1024, %c1024, %c8, %c8, %c1] + : !pto.tensor_view<16x1x64x1x8xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c16, %c1, %c64, %c1, %c8] + : !pto.tensor_view<16x1x64x1x8xf32> -> !pto.partition_tensor_view<16x1x64x1x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c16, %c1, %c64, %c1, %c8] + : !pto.tensor_view<16x1x64x1x8xf32> -> !pto.partition_tensor_view<16x1x64x1x8xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<16x1x64x1x8xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<16x1x64x1x8xf32>) + return + } } diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 35cecb607..6ed1b550b 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -3586,6 +3586,14 @@ def _constant_name(self, value: object, ty: SemanticType) -> str: stem = f"c{value}_{ty.dtype.name}" else: stem = "cst" + # Keep generated SSA names MLIR-safe for constants whose textual value + # contains punctuation such as decimal points or scientific-notation + # exponents (for example f32 max -> `3.4028235e+38`). + stem = re.sub(r"[^0-9A-Za-z_]", "_", stem) + stem = re.sub(r"_+", "_", stem).strip("_") or "cst" + if stem[0].isdigit(): + stem = f"c_{stem}" + name = f"%{stem}" existing = {line.split(" = ", 1)[0].strip() for line in self._constant_lines} if name not in existing: From d7b4524d7c7cece60463755210cf49690f704541 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 17 Apr 2026 18:12:06 +0800 Subject: [PATCH 098/192] feat(dsl): complete predicate ops and enforce PredicateDist enum Implement missing predicate-op surfaces end-to-end in TileLang DSL v1, including semantic/lowering/test coverage and docs updates. Use PredicateDist enum for predicate load/store dist operands instead of raw strings. Closes #124 --- .../user_guide/09-vector-memory-operations.md | 4 +- .../user_guide/10-predicate-operations.md | 14 +- tilelang-dsl/python/tilelang_dsl/__init__.py | 2 + .../python/tilelang_dsl/frontend_ast.py | 1 + tilelang-dsl/python/tilelang_dsl/lowering.py | 93 ++++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 261 +++++++++++++++++- .../python/tilelang_dsl/support_matrix.py | 19 ++ tilelang-dsl/python/tilelang_dsl/types.py | 8 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 143 ++++++++++ 9 files changed, 510 insertions(+), 35 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md index e1699f22d..57def372d 100644 --- a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -617,8 +617,8 @@ delivery form differs. **Returns**: None (side-effect operation) **DIST semantics (VPTO-aligned)**: -- `"NORM"`: store packed predicate payload into a normal destination space of size `VL/8`. -- `"PK"`: store packed predicate payload into a destination space of size `VL/16`, keeping one bit out of every two bits. +- `PredicateDist.NORM`: store packed predicate payload into a normal destination space of size `VL/8`. +- `PredicateDist.PK`: store packed predicate payload into a destination space of size `VL/16`, keeping one bit out of every two bits. **Notes**: - `pto.psts` is intentionally documented as explicit `buf + offset` surface in DSL v1. diff --git a/tilelang-dsl/docs/user_guide/10-predicate-operations.md b/tilelang-dsl/docs/user_guide/10-predicate-operations.md index d0d64f0c0..21f8d4879 100644 --- a/tilelang-dsl/docs/user_guide/10-predicate-operations.md +++ b/tilelang-dsl/docs/user_guide/10-predicate-operations.md @@ -8,7 +8,7 @@ Operations for creating and manipulating typed masks. **Predicate Part Enum**: `pto.ppack` and `pto.punpack` require the `PredicatePart` enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`; these lower to the VPTO canonical `PART` tokens `"LOWER"` and `"HIGHER"`. -**Predicate Dist Enum**: The `PredicateDist` enum provides type-safe distribution mode selection for predicate memory families. Load families (`plds`, `pld`, `pldi`) use `NORM`, `US`, and `DS`. Store families (`pst`, `psti`) use `NORM` and `PK`. +**Predicate Dist Enum**: The `PredicateDist` enum provides type-safe distribution mode selection for predicate memory families. Load families (`plds`, `pld`, `pldi`) use `NORM`, `US`, and `DS`. Store families (`psts`, `pst`, `psti`) use `NORM` and `PK`. **Pattern coverage**: The VPTO canonical predicate-generation families use `PAT_*` tokens such as `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, `PAT_VL*`, `PAT_M3`, and `PAT_M4`. The Python DSL surface may expose only a subset through `pto.MaskPattern`; check the enum for currently available values. @@ -367,7 +367,7 @@ unpacked = pto.punpack(mask, pto.PredicatePart.HIGHER) **Example**: ```python -mask = pto.plds(buf, offset, PredicateDist.NORM) +mask = pto.plds(buf, offset, pto.PredicateDist.NORM) ``` #### `pto.pld(buf: ptr, offset: Index, dist: PredicateDist) -> MaskType` [Advanced Tier] @@ -388,7 +388,7 @@ mask = pto.plds(buf, offset, PredicateDist.NORM) **Example**: ```python -mask = pto.pld(buf, offset, PredicateDist.NORM) +mask = pto.pld(buf, offset, pto.PredicateDist.NORM) ``` #### `pto.pldi(buf: ptr, imm_offset: pto.i32, dist: PredicateDist) -> MaskType` [Advanced Tier] @@ -409,7 +409,7 @@ mask = pto.pld(buf, offset, PredicateDist.NORM) **Example**: ```python -mask = pto.pldi(buf, 0, PredicateDist.NORM) +mask = pto.pldi(buf, 0, pto.PredicateDist.NORM) ``` #### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] @@ -436,7 +436,7 @@ constant immediate). **Example**: ```python -pto.psts(mask, buf, offset, PredicateDist.NORM) +pto.psts(mask, buf, offset, pto.PredicateDist.NORM) ``` #### `pto.pst(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] @@ -455,7 +455,7 @@ pto.psts(mask, buf, offset, PredicateDist.NORM) **Example**: ```python -pto.pst(mask, buf, offset, PredicateDist.NORM) +pto.pst(mask, buf, offset, pto.PredicateDist.NORM) ``` #### `pto.psti(mask: MaskType, buf: ptr, imm_offset: pto.i32, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] @@ -474,7 +474,7 @@ pto.pst(mask, buf, offset, PredicateDist.NORM) **Example**: ```python -pto.psti(mask, buf, pto.i32(8), PredicateDist.PK) +pto.psti(mask, buf, pto.i32(8), pto.PredicateDist.PK) ``` #### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 7a5d2c6d2..eca98c8ba 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -48,6 +48,7 @@ PointerType, PostUpdateMode, Pipe, + PredicateDist, ScalarType, SLayout, TensorView, @@ -120,6 +121,7 @@ "PIPE", "EVENT", "MaskPattern", + "PredicateDist", "PAT", "BarrierType", "BLayout", diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index a0d659ed2..130dcc165 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -978,6 +978,7 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo "PIPE", "EVENT", "MaskPattern", + "PredicateDist", "Pipe", "Event", "BarrierType", diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 6ed1b550b..5f2adf04d 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -679,19 +679,24 @@ def _render_multi_result_assign( if not isinstance(stmt.value.type, SemanticTupleType) or len(stmt.value.type.elements) != 2: raise NotImplementedError("multi-result lowering expects a two-result tuple type") - if stmt.value.name == "make_mask": - dtype_expr, remaining_expr = stmt.value.args - if not self._is_dtype_meta_expr(dtype_expr): - raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") - + if stmt.value.name in {"make_mask", "plt_b8", "plt_b16", "plt_b32"}: lines: list[str] = [] - remaining = self._lower_remaining_to_i32(remaining_expr, env, indent=indent, into=lines) + if stmt.value.name == "make_mask": + dtype_expr, remaining_expr = stmt.value.args + if not self._is_dtype_meta_expr(dtype_expr): + raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") + remaining = self._lower_remaining_to_i32(remaining_expr, env, indent=indent, into=lines) + op_name = None + else: + remaining = self._lower_remaining_to_i32(stmt.value.args[0], env, indent=indent, into=lines) + op_name = stmt.value.name mask_target, remaining_target = stmt.targets mask_type, remaining_type = stmt.value.type.elements suffix = self._mask_suffix(mask_type) + lowered_op = op_name or f"plt_{suffix}" lines.append( self._indent(indent) - + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = pto.plt_{suffix} {remaining.name} : " + + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = pto.{lowered_op} {remaining.name} : " + f"i32 -> {self._render_type(mask_type)}, {self._render_type(remaining_type)}" ) env[mask_target.name] = _RenderedValue(name=mask_target.ssa_name, type=mask_type) @@ -752,6 +757,22 @@ def _render_multi_result_assign( env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) return lines + if stmt.value.name in {"pdintlv_b8", "pintlv_b16"}: + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + low_target, high_target = stmt.targets + low_type, high_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{low_target.ssa_name}, {high_target.ssa_name} = pto.{stmt.value.name} " + + f"{lhs.name}, {rhs.name} : {self._render_type(lhs.type)}, {self._render_type(rhs.type)} " + + f"-> {self._render_type(low_type)}, {self._render_type(high_type)}" + ) + env[low_target.name] = _RenderedValue(name=low_target.ssa_name, type=low_type) + env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) + return lines + if stmt.value.name == "vmull": lines = [] lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) @@ -1070,11 +1091,14 @@ def _render_predicate_store( ) rendered_offset = self._materialize_constant(0, SemanticIndexType()) else: - rendered_offset = self._lower_expr(stmt.indices[0], env, indent=indent, into=lines) + if stmt.op_name == "psti": + rendered_offset = self._lower_to_index(stmt.indices[0], env, indent=indent, into=lines) + else: + rendered_offset = self._lower_expr(stmt.indices[0], env, indent=indent, into=lines) dist = self._render_string_literal(stmt.dist) lines.append( self._indent(indent) - + f"pto.psts {value.name}, {destination.name}[{rendered_offset.name}], {dist} : " + + f"pto.{stmt.op_name} {value.name}, {destination.name}[{rendered_offset.name}], {dist} : " + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(rendered_offset.type)}" ) return lines @@ -2549,6 +2573,17 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name in {"pset_b8", "pset_b16", "pset_b32", "pge_b8", "pge_b16", "pge_b32"}: + if not isinstance(expr.args[0], SemanticSymbolExpr) or not isinstance(expr.args[0].value, MaskPattern): + raise NotImplementedError(f"{expr.name} lowering expects a MaskPattern symbol") + pattern_token = expr.args[0].value.value.replace("\\", "\\\\").replace('"', '\\"') + pattern = f'"{pattern_token}"' + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {pattern} : {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "init_align": into.append( self._indent(indent) @@ -2588,6 +2623,20 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name in {"plds", "pldi"}: + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + if expr.name == "pldi": + offset = self._lower_to_index(expr.args[1], env, indent=indent, into=into) + else: + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + dist = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {source.name}[{offset.name}], {dist} : " + + f"{self._render_type(source.type)}, {self._render_type(offset.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vldas": source = self._lower_expr(expr.args[0], env, indent=indent, into=into) index_args = expr.args[1:] @@ -2779,13 +2828,13 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) - if expr.name == "psel": + if expr.name in {"psel", "pand", "por", "pxor"}: src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) into.append( self._indent(indent) - + f"{result_name} = pto.psel {src0.name}, {src1.name}, {mask.name} : " + + f"{result_name} = pto.{expr.name} {src0.name}, {src1.name}, {mask.name} : " + f"{self._render_type(src0.type)}, {self._render_type(src1.type)}, {self._render_type(mask.type)} " + f"-> {self._render_type(expr.type)}" ) @@ -3171,6 +3220,28 @@ def _lower_to_i32( return _RenderedValue(name=cast_name, type=_I32_TYPE) raise NotImplementedError("expected an i32 or index operand during TileLang DSL v1 lowering") + def _lower_to_index( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticIndexType): + return value + if isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype): + bits = integer_bitwidth(value.type.dtype) + if bits in {32, 64}: + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.index_cast {value.name} : {value.type.dtype.name} to index" + ) + return _RenderedValue(name=cast_name, type=SemanticIndexType()) + raise NotImplementedError("expected an i32/i64/index operand during TileLang DSL v1 lowering") + def _coerce_rendered_to_i64( self, value: _RenderedValue, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 4b601dd17..51e11194e 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -61,6 +61,7 @@ OrderMode, PadMode, PadValue, + PredicateDist, Pipe, PostUpdateMode, PositionMode, @@ -132,6 +133,7 @@ pad_value.name: pad_value for pad_value in (PadValue.NULL, PadValue.ZERO, PadValue.MAX, PadValue.MIN) } +_PREDICATE_DIST_SYMBOLS = {dist.name: dist for dist in PredicateDist} _DEINTERLEAVE_DIST_SYMBOLS = dict(DeinterleaveDist.__members__) _INTERLEAVE_DIST_SYMBOLS = dict(InterleaveDist.__members__) _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} @@ -286,7 +288,33 @@ def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: dtype.name for dtype in (i8, i16, i32, f16, bf16, f32) ) _COMPARE_SELECT_OPS = {"vcmp", "vcmps", "vsel", "vselr", "vselrv2"} -_PREDICATE_MOVEMENT_OPS = {"pnot", "psel", "ppack", "punpack"} +_PREDICATE_MOVEMENT_OPS = { + "pset_b8", + "pset_b16", + "pset_b32", + "pge_b8", + "pge_b16", + "pge_b32", + "plt_b8", + "plt_b16", + "plt_b32", + "plds", + "pld", + "pldi", + "psts", + "pst", + "psti", + "pstu", + "pnot", + "psel", + "pand", + "por", + "pxor", + "ppack", + "punpack", + "pdintlv_b8", + "pintlv_b16", +} _CARRY_OPS = {"vaddc", "vsubc", "vaddcs", "vsubcs"} _REARRANGEMENT_OPS = {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"} _UB_HELPER_OPS = {"vbitsort", "vmrgsort4"} @@ -560,6 +588,7 @@ class SemanticVectorPairStoreStmt(SemanticStmt): @dataclass(frozen=True) class SemanticPredicateStoreStmt(SemanticStmt): + op_name: str value: SemanticExpr destination: SemanticExpr indices: tuple[SemanticExpr, ...] @@ -1120,6 +1149,7 @@ def _should_infer_vecscope( "vlds", "vldas", "vldus", + "plds", "psts", "pstu", "vsst", @@ -1219,6 +1249,7 @@ def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> boo "vlds", "vldas", "vldus", + "plds", "psts", "pstu", "vsst", @@ -1324,6 +1355,10 @@ def _semantic_block_contains_vector_activity( return True if isinstance(stmt, SemanticVectorStoreStmt): return True + if isinstance(stmt, SemanticPredicateStoreStmt): + return True + if isinstance(stmt, SemanticAlignStoreStmt): + return True if isinstance(stmt, SemanticDmaConfigStmt): return True if isinstance(stmt, SemanticDmaUnaryConfigStmt): @@ -1595,7 +1630,7 @@ def _is_vector_store_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) and expr.namespace == "pto" - and expr.name in {"psts", "vsst", "vsta", "vstas", "vstar", "vsts", "vstsx2"} + and expr.name in {"psts", "pst", "psti", "vsst", "vsta", "vstas", "vstar", "vsts", "vstsx2"} ) def _is_scalar_store_call(self, expr: FrontendExprNode) -> bool: @@ -1725,11 +1760,12 @@ def _analyze_vector_store_stmt( *, allow_outer_lookup: bool, ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: - if expr.name == "psts": + if expr.name in {"psts", "pst", "psti"}: + canonical_name = "psts" if expr.name == "pst" else expr.name if len(expr.args) in {2, 3} and isinstance(expr.args[1], FrontendSubscriptExpr): raise TypeError( - "pto.psts does not support Tile element-indexing syntax in TileLang DSL v1; " - "use explicit pointer form `pto.psts(mask, buf, offset[, dist])`" + f"pto.{expr.name} does not support Tile element-indexing syntax in TileLang DSL v1; " + f"use explicit pointer form `pto.{expr.name}(mask, buf, offset[, dist])`" ) args = tuple( @@ -1745,16 +1781,20 @@ def _analyze_vector_store_stmt( indices = (offset,) else: raise TypeError( - "pto.psts expects 3 or 4 positional arguments in TileLang DSL v1: " - "`pto.psts(mask, buf, offset[, dist])`" + f"pto.{expr.name} expects 3 or 4 positional arguments in TileLang DSL v1: " + f"`pto.{expr.name}(mask, buf, offset[, dist])`" ) - self._require_mask_expr(value, "pto.psts value") - self._require_vector_pointer_expr(destination, "pto.psts destination") + self._require_mask_expr(value, f"pto.{expr.name} value") + self._require_vector_pointer_expr(destination, f"pto.{expr.name} destination") for index in indices: - self._require_index_typed_expr(index) - dist = self._normalize_predicate_store_dist(dist_expr, "pto.psts dist") + if expr.name == "psti": + self._require_i32_like_expr(index, "pto.psti offset") + else: + self._require_index_typed_expr(index) + dist = self._normalize_predicate_store_dist(dist_expr, f"pto.{expr.name} dist") return ( SemanticPredicateStoreStmt( + op_name=canonical_name, value=value, destination=destination, indices=indices, @@ -2588,7 +2628,12 @@ def _bind_assignment_target( for axis in range(value.type.rank) ) elif isinstance(value, SemanticCallExpr): - tuple_values = value.args + if len(value.args) == len(element_types): + tuple_values = value.args + else: + tuple_values = tuple( + SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types + ) else: tuple_values = tuple( SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types @@ -3255,6 +3300,15 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=pad_value, type=SemanticPadValueType(), ) + if expr.namespace in {"PredicateDist", "pto.PredicateDist"}: + predicate_dist = _PREDICATE_DIST_SYMBOLS.get(expr.name) + if predicate_dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=predicate_dist, + type=SemanticMetaType(kind="predicate_dist"), + ) if expr.namespace in {"DeinterleaveDist", "pto.DeinterleaveDist"}: dist = _DEINTERLEAVE_DIST_SYMBOLS.get(expr.name) if dist is not None: @@ -3814,6 +3868,12 @@ def _analyze_call_expr( return self._analyze_vldus(args) if name == "vldsx2": return self._analyze_vldsx2(args) + if name in {"pset_b8", "pset_b16", "pset_b32", "pge_b8", "pge_b16", "pge_b32"}: + return self._analyze_predicate_pattern_op(name, args) + if name in {"plt_b8", "plt_b16", "plt_b32"}: + return self._analyze_predicate_tail_op(name, args) + if name in {"plds", "pld", "pldi"}: + return self._analyze_predicate_load_op(name, args) if name == "pstu": return self._analyze_pstu(args) if name == "vstus": @@ -3824,8 +3884,10 @@ def _analyze_call_expr( return self._analyze_load_scalar(args) if name in {"ppack", "punpack"}: return self._analyze_mask_part_op(name, args) - if name in {"pnot", "psel"}: + if name in {"pnot", "psel", "pand", "por", "pxor"}: return self._analyze_mask_logic_op(name, args) + if name in {"pdintlv_b8", "pintlv_b16"}: + return self._analyze_predicate_reorder_op(name, args) if name in {"vcmp", "vcmps"}: return self._analyze_compare_op(name, args) if name in {"vsel", "vselr", "vselrv2"}: @@ -3883,6 +3945,46 @@ def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: ), ) + def _analyze_predicate_pattern_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") + pattern = args[0] + if not ( + isinstance(pattern, SemanticSymbolExpr) + and isinstance(pattern.type, SemanticMetaType) + and pattern.type.kind == "mask_pattern" + and isinstance(pattern.value, MaskPattern) + ): + raise TypeError(f"pto.{name} pattern must be a MaskPattern symbol such as `pto.PAT.ALL`") + granularity = name.rsplit("_", 1)[-1] + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticMaskType(granularity=granularity), + ) + + def _analyze_predicate_tail_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") + self._require_tail_remaining_expr(args[0], f"pto.{name} scalar") + granularity = name.rsplit("_", 1)[-1] + mask_type = SemanticMaskType(granularity=granularity) + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(mask_type, _I32_TYPE)), + ) + def _literal_expr_from_context_value(self, value: object, context: str) -> SemanticExpr: if isinstance(value, bool): return SemanticLiteralExpr(value=value, type=SemanticScalarType(dtype=i1)) @@ -4244,6 +4346,48 @@ def _analyze_vldsx2(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: type=SemanticTupleType(elements=(vreg_type, vreg_type)), ) + def _analyze_predicate_load_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + expects_i32_immediate = name == "pldi" + canonical_name = "plds" if name == "pld" else name + if len(args) not in {2, 3}: + raise TypeError( + f"pto.{name} expects 2 or 3 positional arguments in TileLang DSL v1: " + f"`pto.{name}(buf, offset[, dist])`" + ) + + source, offset = args[:2] + source = self._require_pointer_expr(source, f"pto.{name} source", memory_space="ub") + if expects_i32_immediate: + self._require_i32_like_expr(offset, "pto.pldi offset") + else: + self._require_index_typed_expr(offset) + dist = self._normalize_predicate_load_dist( + args[2] if len(args) == 3 else None, + f"pto.{name} dist", + ) + + if source.type.element_dtype == ui8: + granularity = "b8" + elif source.type.element_dtype == ui16: + granularity = "b16" + elif source.type.element_dtype == ui32: + granularity = "b32" + else: + raise TypeError( + f"pto.{name} source must be !pto.ptr in TileLang DSL v1" + ) + + return SemanticCallExpr( + namespace="pto", + name=canonical_name, + args=(source, offset, dist), + type=SemanticMaskType(granularity=granularity), + ) + def _analyze_pstu(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 3: raise TypeError("pto.pstu expects exactly 3 positional arguments in TileLang DSL v1") @@ -4525,6 +4669,15 @@ def _analyze_mask_logic_op( mask = self._require_mask_expr(args[1], "pto.pnot mask") self._require_matching_mask_types(value, mask, "pto.pnot") return SemanticCallExpr(namespace="pto", name=name, args=args, type=value) + if name in {"pand", "por", "pxor"}: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL") + src0 = self._require_mask_expr(args[0], f"pto.{name} src0") + src1 = self._require_mask_expr(args[1], f"pto.{name} src1") + mask = self._require_mask_expr(args[2], f"pto.{name} mask") + self._require_matching_mask_types(src0, src1, f"pto.{name}") + self._require_matching_mask_types(src0, mask, f"pto.{name}") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) if len(args) != 3: raise TypeError("pto.psel expects exactly 3 positional arguments in TileLang DSL") src0 = self._require_mask_expr(args[0], "pto.psel src0") @@ -4534,6 +4687,30 @@ def _analyze_mask_logic_op( self._require_matching_mask_types(src0, mask, "pto.psel") return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + def _analyze_predicate_reorder_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") + lhs = self._require_mask_expr(args[0], f"pto.{name} src0") + rhs = self._require_mask_expr(args[1], f"pto.{name} src1") + expected_granularity = "b8" if name == "pdintlv_b8" else "b16" + if lhs.granularity != expected_granularity or rhs.granularity != expected_granularity: + raise TypeError(f"pto.{name} expects !pto.mask<{expected_granularity}> operands") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType( + elements=( + SemanticMaskType(granularity=expected_granularity), + SemanticMaskType(granularity=expected_granularity), + ) + ), + ) + def _analyze_compare_op( self, name: str, @@ -5215,9 +5392,63 @@ def _normalize_predicate_store_dist( ) -> SemanticExpr: if expr is None: return SemanticLiteralExpr(value="NORM", type=SemanticMetaType(kind="string")) - dist = self._require_string_expr(expr, context) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.value, PredicateDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.binding.value, PredicateDist) + ): + dist = expr.binding.value.value + else: + raise TypeError( + "predicate store dist must be a PredicateDist enum such as " + "`pto.PredicateDist.NORM` or `pto.PredicateDist.PK` in TileLang DSL v1" + ) if dist not in {"NORM", "PK"}: - raise TypeError("predicate store dist must be \"NORM\" or \"PK\" in TileLang DSL v1") + raise TypeError( + "predicate store dist must be one of " + "`pto.PredicateDist.NORM` or `pto.PredicateDist.PK` in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) + + def _normalize_predicate_load_dist( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value="NORM", type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.value, PredicateDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.binding.value, PredicateDist) + ): + dist = expr.binding.value.value + else: + raise TypeError( + "predicate load dist must be a PredicateDist enum such as " + "`pto.PredicateDist.NORM`, `pto.PredicateDist.US`, or `pto.PredicateDist.DS` in TileLang DSL v1" + ) + if dist not in {"NORM", "US", "DS"}: + raise TypeError( + "predicate load dist must be one of " + "`pto.PredicateDist.NORM`, `pto.PredicateDist.US`, or `pto.PredicateDist.DS` in TileLang DSL v1" + ) return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) def _normalize_vlds_dist( diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 7b7597866..7619a4965 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -66,6 +66,7 @@ "vldas", "vldus", "vldsx2", + "plds", "psts", "pstu", "vsst", @@ -152,10 +153,28 @@ "vsel", "vselr", "vselrv2", + "pset_b8", + "pset_b16", + "pset_b32", + "pge_b8", + "pge_b16", + "pge_b32", + "plt_b8", + "plt_b16", + "plt_b32", "pnot", "psel", + "pand", + "por", + "pxor", "ppack", "punpack", + "pld", + "pldi", + "pst", + "psti", + "pdintlv_b8", + "pintlv_b16", "vaddc", "vsubc", "vaddcs", diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 1485013e1..87072cee2 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -189,6 +189,13 @@ class MaskPattern(str, Enum): VL32 = "PAT_VL32" +class PredicateDist(str, Enum): + NORM = "NORM" + US = "US" + DS = "DS" + PK = "PK" + + class PadMode(str, Enum): PadNull = "PadNull" PadFirstElem = "PadFirstElem" @@ -674,6 +681,7 @@ def get_op_attr(name: str, default: Any = None) -> Any: "PIPE", "EVENT", "MaskPattern", + "PredicateDist", "PAT", "BarrierType", "PadMode", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index acf236152..1b880ad03 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -103,6 +103,7 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "get_lanes")) self.assertTrue(hasattr(pto, "elements_per_vreg")) self.assertTrue(hasattr(pto, "PAT")) + self.assertTrue(hasattr(pto, "PredicateDist")) self.assertTrue(hasattr(pto, "PadMode")) self.assertTrue(hasattr(pto, "BarrierType")) self.assertTrue(hasattr(pto, "BLayout")) @@ -143,6 +144,10 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PositionMode.LOWEST.value, "LOWEST") self.assertEqual(pto.PositionMode.HIGHEST.value, "HIGHEST") self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") + self.assertEqual(pto.PredicateDist.NORM.value, "NORM") + self.assertEqual(pto.PredicateDist.US.value, "US") + self.assertEqual(pto.PredicateDist.DS.value, "DS") + self.assertEqual(pto.PredicateDist.PK.value, "PK") self.assertEqual(pto.VcvtRoundMode.R.value, "R") self.assertEqual(pto.VcvtSatMode.SAT.value, "SAT") self.assertEqual(pto.VcvtPartMode.ODD.value, "ODD") @@ -4926,6 +4931,144 @@ def kernel(mask_dst: pto.Tile): self.assertIn("does not support Tile element-indexing syntax", str(ctx.exception)) self.assertIn("pto.psts(mask, buf, offset", str(ctx.exception)) + def test_plds_load_lower_to_supported_op(self) -> None: + @pto.vkernel( + op="predicate_load_from_ub_buffer", + dtypes=[(pto.ui32, pto.ui32)], + advanced=True, + ) + def kernel( + mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB), + mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB), + ): + mask = pto.plds(mask_src, 0) + pto.psts(mask, mask_dst, 0) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)) + load_assign = next( + stmt + for stmt in vecscope.body + if isinstance(stmt, SemanticAssignStmt) + and isinstance(stmt.value, SemanticCallExpr) + and stmt.value.name == "plds" + ) + self.assertIsInstance(load_assign.value.type, SemanticMaskType) + self.assertEqual(load_assign.value.type.granularity, "b32") + + text = specialized.mlir_text() + self.assertIn("pto.plds", text) + self.assertIn('"NORM"', text) + self.assertIn("pto.psts", text) + + def test_plds_rejects_unsupported_dist_token(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="predicate_load_invalid_dist", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB)): + _mask = pto.plds(mask_src, 0, pto.PredicateDist.PK) + return None + + kernel.specialize().mlir_text() + + self.assertIn("predicate load dist must be one of", str(ctx.exception)) + self.assertIn("pto.PredicateDist.DS", str(ctx.exception)) + + def test_predicate_generation_and_logic_families_lower_to_supported_ops(self) -> None: + @pto.vkernel( + op="predicate_generation_and_logic_families", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB)): + mask8 = pto.pset_b8(pto.PAT.ALL) + mask16 = pto.pge_b16(pto.PAT.VL16) + mask32, _next = pto.plt_b32(64) + and_mask = pto.pand(mask32, mask32, mask32) + or_mask = pto.por(and_mask, mask32, mask32) + xor_mask = pto.pxor(or_mask, mask32, mask32) + pto.psts(xor_mask, mask_dst, 0) + _ = mask8 + _ = mask16 + return None + + text = kernel.specialize().mlir_text() + self.assertIn("pto.pset_b8", text) + self.assertIn("pto.pge_b16", text) + self.assertIn("pto.plt_b32", text) + self.assertIn("pto.pand", text) + self.assertIn("pto.por", text) + self.assertIn("pto.pxor", text) + + def test_predicate_load_store_alias_and_immediate_forms_lower_to_supported_ops(self) -> None: + @pto.vkernel( + op="predicate_load_store_alias_and_immediate_forms", + dtypes=[(pto.ui32, pto.ui32)], + advanced=True, + ) + def kernel( + mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB), + mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB), + ): + mask0 = pto.pld(mask_src, 0, pto.PredicateDist.NORM) + mask1 = pto.pldi(mask_src, pto.i32(8), pto.PredicateDist.US) + pto.pst(mask0, mask_dst, 0) + pto.psti(mask1, mask_dst, pto.i32(8), pto.PredicateDist.PK) + return None + + text = kernel.specialize().mlir_text() + self.assertIn("pto.plds", text) + self.assertIn("pto.pldi", text) + self.assertIn("pto.psts", text) + self.assertIn("pto.psti", text) + self.assertIn("arith.index_cast", text) + + def test_predicate_reorder_families_lower_to_supported_ops(self) -> None: + @pto.vkernel( + op="predicate_reorder_families", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB)): + mask8 = pto.pset_b8(pto.PAT.ALL) + mask16 = pto.pset_b16(pto.PAT.ALL) + low8, high8 = pto.pdintlv_b8(mask8, mask8) + low16, high16 = pto.pintlv_b16(mask16, mask16) + _ = low8 + _ = high8 + _ = low16 + _ = high16 + all32 = pto.make_mask(pto.ui32, pto.PAT.ALL) + pto.psts(all32, mask_dst, 0) + return None + + text = kernel.specialize().mlir_text() + self.assertIn("pto.pdintlv_b8", text) + self.assertIn("pto.pintlv_b16", text) + + def test_pdintlv_b8_rejects_wrong_mask_granularity(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="predicate_reorder_wrong_mask_granularity", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB)): + mask32 = pto.plds(mask_src, 0) + _low, _high = pto.pdintlv_b8(mask32, mask32) + return None + + kernel.specialize().mlir_text() + + self.assertIn("expects !pto.mask operands", str(ctx.exception)) + def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): From a76956e60ca03eda1bee390c5a8158fef7af873b Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 17 Apr 2026 10:44:04 +0800 Subject: [PATCH 099/192] Add vbitcast implementation --- docs/isa/09-conversion-ops.md | 54 +++++++ docs/vpto-spec.md | 4 +- include/PTO/IR/VPTOOps.td | 13 ++ lib/PTO/IR/VPTO.cpp | 28 ++++ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 23 +++ test/basic/vbitcast_vpto_llvm.pto | 44 ++++++ .../docs/user_guide/05-type-system.md | 50 +++++++ .../python/tilelang_dsl/frontend_ast.py | 1 + tilelang-dsl/python/tilelang_dsl/kernel.py | 13 ++ tilelang-dsl/python/tilelang_dsl/lowering.py | 9 ++ tilelang-dsl/python/tilelang_dsl/semantic.py | 62 ++++++++ .../python/tilelang_dsl/support_matrix.py | 1 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 132 ++++++++++++++++++ 13 files changed, 432 insertions(+), 2 deletions(-) create mode 100644 test/basic/vbitcast_vpto_llvm.pto diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md index efb3a9ed4..bcad3f1af 100644 --- a/docs/isa/09-conversion-ops.md +++ b/docs/isa/09-conversion-ops.md @@ -250,3 +250,57 @@ for (int i = 0; i < N; i++) %int_div = pto.vcvt %floored, %mask {rnd = "Z"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> ``` + +--- + +## `pto.vbitcast` + +- **syntax:** `%result = pto.vbitcast %input : !pto.vreg -> !pto.vreg` +- **semantics:** Bitwise reinterpretation of a vreg vector without changing the underlying bit pattern. This operation performs a pure type cast that preserves the exact bits of each element, changing only their interpretation (e.g., from floating-point to integer). + +- **inputs:** + `%input` is the source vector register value. +- **outputs:** + `%result` is the reinterpreted vector register value. +- **constraints and limitations:** + 1. Both source and result must be `!pto.vreg<...>` types. + 2. Source and result vectors must have the same total bit width (currently 2048 bits). + 3. Only integer and floating-point element types are supported. + +**Element bit-width equality examples:** +- `f32<64>` → `i32<64>` (both 32-bit elements, total 2048 bits) +- `f16<128>` → `i16<128>` (both 16-bit elements, total 2048 bits) +- `bf16<128>` → `ui16<128>` (both 16-bit elements, total 2048 bits) +- `si32<64>` → `ui32<64>` (both 32-bit elements, total 2048 bits) +- `f32<64>` → `i16<128>` (32-bit/16-bit elements, total 2048 bits) + +**Verification:** The operation verifies that: +1. Both input and result are `!pto.vreg<...>` types. +2. Total bit width equals 2048 (the fixed vreg size). + +**Comparison with `pto.vcvt`:** +- `pto.vcvt` performs value conversion with rounding, saturation, and lane placement control. +- `pto.vbitcast` performs bitwise reinterpretation without changing the underlying bit pattern. + +**Example: Reinterpreting float as integer for bit manipulation** +```mlir +// Prepare a vector of float values +%fvec = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + +// Reinterpret as integer for bitwise operations +%ivec = pto.vbitcast %fvec : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + +// Extract sign bit (bit 31) +%sign_bits = pto.vand %ivec, %sign_mask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + +// Reinterpret back to float +%fvec_without_sign = pto.vbitcast %sign_bits : !pto.vreg<64xi32> -> !pto.vreg<64xf32> +``` + +**Example: Type punning between signed and unsigned integer** +```mlir +// Convert signed to unsigned without changing bits +%signed = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xsi32> +%unsigned = pto.vbitcast %signed : !pto.vreg<64xsi32> -> !pto.vreg<64xui32> +// Bits are identical; interpretation changes from signed to unsigned +``` diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 11e8a6ee8..1c8e27203 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -888,7 +888,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | 6 | [Unary Vector Ops](isa/06-unary-vector-ops.md) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | | 7 | [Binary Vector Ops](isa/07-binary-vector-ops.md) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | | 8 | [Vec-Scalar Ops](isa/08-vec-scalar-ops.md) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | -| 9 | [Conversion Ops](isa/09-conversion-ops.md) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 9 | [Conversion Ops](isa/09-conversion-ops.md) | Type conversion with rounding/saturation control | 3 | `pto.vcvt`, `pto.vtrc`, `pto.vbitcast` | | 10 | [Reduction Ops](isa/10-reduction-ops.md) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | | 11 | [Compare & Select](isa/11-compare-select.md) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | | 12 | [Data Rearrangement](isa/12-data-rearrangement.md) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | @@ -928,7 +928,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | Operation | Group | Description | |-----------|-------|-------------| -| Type Conversion | 9 | `pto.vcvt` | +| Type Conversion | 9 | `pto.vcvt`, `pto.vbitcast` | | Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | | Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 15967393b..909786c6c 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -935,6 +935,19 @@ def PTO_VcvtOp : PTO_Op<"vcvt", [Pure]> { let hasCustomAssemblyFormat = 1; } +def PTO_VbitcastOp : PTO_Op<"vbitcast", [Pure]> { + let arguments = (ins + PTO_VectorType:$input + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + def PTO_VciOp : PTO_Op<"vci", [Pure]> { let arguments = (ins AnyInteger:$index, diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 7a3e1a4d5..90e8b384b 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -2331,6 +2331,34 @@ LogicalResult VcvtOp::verify() { return success(); } +LogicalResult VbitcastOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + + auto getStorageBits = [](VRegType type) -> std::optional { + Type elementType = type.getElementType(); + if (auto intType = dyn_cast(elementType)) + return type.getElementCount() * static_cast(intType.getWidth()); + if (auto floatType = dyn_cast(elementType)) + return type.getElementCount() * + static_cast(floatType.getWidth()); + return std::nullopt; + }; + + auto inputBits = getStorageBits(inputType); + auto resultBits = getStorageBits(resultType); + if (!inputBits || !resultBits) + return emitOpError("requires integer or floating-point vreg element type"); + if (*inputBits != *resultBits) { + return emitOpError("requires source and result vectors to carry the same " + "total number of bits"); + } + + return success(); +} + LogicalResult PdintlvB8Op::verify() { if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), "lhs type", "b8")) || diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 31fdbd11e..2505a4e48 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -4117,6 +4117,27 @@ class LowerVcvtOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerVbitcastOpPattern final + : public OpConversionPattern { +public: + explicit LowerVbitcastOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult + matchAndRewrite(pto::VbitcastOp op, pto::VbitcastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert vbitcast result type"); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getInput()); + return success(); + } +}; + class LowerVtrcOpPattern final : public OpConversionPattern { public: explicit LowerVtrcOpPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -4872,6 +4893,7 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerVpreluOpPattern, LowerVaxpyOpPattern, LowerVciOpPattern, LowerVexpdiffOpPattern, LowerVbitsortOpPattern, LowerVtrcOpPattern, LowerVcvtOpPattern, + LowerVbitcastOpPattern, LowerPredicateLoadOpPattern, LowerPredicateLoadOpPattern, LowerPredicateStoreOpPattern, @@ -4930,6 +4952,7 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, pto::VintlvOp, pto::VdintlvOp, pto::VpreluOp, pto::VaxpyOp, pto::VciOp, pto::VexpdiffOp, pto::VbitsortOp, pto::VtrcOp, pto::VcvtOp, + pto::VbitcastOp, pto::VcmpOp, pto::VcmpsOp, pto::CopyGmToUbufOp, pto::CopyUbufToGmOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); diff --git a/test/basic/vbitcast_vpto_llvm.pto b/test/basic/vbitcast_vpto_llvm.pto new file mode 100644 index 000000000..8302e0f2a --- /dev/null +++ b/test/basic/vbitcast_vpto_llvm.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --vpto-emit-hivm-llvm %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @vbitcast_f32_to_i32_store(%value: f32, %dst: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %vec = pto.vdup %value, %mask : f32, !pto.mask -> !pto.vreg<64xf32> + %cast = pto.vbitcast %vec : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + pto.vsts %cast, %dst[%c0], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + return + } + + func.func @vbitcast_f32_to_i16x128_store(%value: f32, %dst: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask_f32 = pto.pset_b32 "PAT_ALL" : !pto.mask + %mask_i16 = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec = pto.vdup %value, %mask_f32 : f32, !pto.mask -> !pto.vreg<64xf32> + %cast = pto.vbitcast %vec : !pto.vreg<64xf32> -> !pto.vreg<128xi16> + pto.vsts %cast, %dst[%c0], %mask_i16 : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: define void @vbitcast_f32_to_i32_store( +// CHECK: %[[VDUP0:[^ ]+]] = call <64 x float> @llvm.hivm.vdups.v64f32.z( +// CHECK: %[[CAST0:[^ ]+]] = bitcast <64 x float> %[[VDUP0]] to <64 x i32> +// CHECK: call void @llvm.hivm.vstsx1.v64s32( + +// CHECK-LABEL: define void @vbitcast_f32_to_i16x128_store( +// CHECK: %[[VDUP1:[^ ]+]] = call <64 x float> @llvm.hivm.vdups.v64f32.z( +// CHECK: %[[CAST1:[^ ]+]] = bitcast <64 x float> %[[VDUP1]] to <128 x i16> +// CHECK: call void @llvm.hivm.vstsx1.v128s16( diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 49239a82b..28c35cf9c 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -102,6 +102,56 @@ lanes1 = pto.elements_per_vreg(pto.f32) # 64 Current TileLang DSL v1 vector lowering supports the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32` element types. +### Vector Type Reinterpretation (vbitcast) + +Vector registers support bitwise type reinterpretation via `pto.vbitcast`: + +```python +result = pto.vbitcast(vector, to_type) +``` + +Interface summary: +- `vector`: a vector register value of type `!pto.vreg` +- `to_type`: target element dtype such as `pto.i32`, `pto.ui32`, `pto.f16`, `pto.bf16`, `pto.f32` +- return: a new vector register `!pto.vreg` whose element count is inferred from the fixed 256-byte vreg width + +Constraints: +- `vector` must be a vreg value; scalar values, pointers, `Tile`, and `TensorView` are rejected +- `to_type` must be a DSL-supported vreg element dtype +- `vbitcast` preserves the total register storage size, so only reinterpretations with the same total bit count are allowed +- the operation has no mask, rounding, saturation, or lane-placement parameters + +Lane count is recomputed from `to_type`: +- `!pto.vreg<64xf32> + pto.i32 -> !pto.vreg<64xi32>` +- `!pto.vreg<64xf32> + pto.f16 -> !pto.vreg<128xf16>` +- `!pto.vreg<128xbf16> + pto.ui16 -> !pto.vreg<128xui16>` + +```python +# Float to integer bitwise reinterpretation +fvec = pto.vlds(ub_ptr, lane) # !pto.vreg<64xf32> +ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> + +# Signed to unsigned integer reinterpretation +signed_vec = pto.vlds(ptr, lane) # !pto.vreg<64xsi32> +unsigned_vec = pto.vbitcast(signed_vec, pto.ui32) # !pto.vreg<64xui32> + +# Element size change (32-bit to 16-bit) +f32_vec = pto.vlds(ptr, lane) # !pto.vreg<64xf32> +f16_vec = pto.vbitcast(f32_vec, pto.f16) # !pto.vreg<128xf16> +``` + +Pythonic syntax sugar via `astype()` method: + +```python +ivec = fvec.astype(pto.i32) # Float to integer +unsigned_vec = signed_vec.astype(pto.ui32) # Signed to unsigned +f16_vec = f32_vec.astype(pto.f16) # 32-bit to 16-bit +``` + +`astype()` on a vector register is syntax sugar for `pto.vbitcast(...)`. In other words, it is a bit reinterpretation API, not a numeric conversion API. + +**Note**: `vbitcast` preserves the exact bit pattern (type punning), unlike `vcvt` which performs value conversion with rounding/saturation. Use `vcvt` when you want numeric conversion semantics; use `vbitcast` when you want the bits to stay unchanged. + ### Typed Masks Masks are typed by their bit granularity: diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 130dcc165..19fad19bd 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -820,6 +820,7 @@ def _collect_reachable_inline_procs( "vtrc": frozenset({"rnd"}), "vlds": frozenset({"dist"}), "vsts": frozenset({"dist"}), + "vbitcast": frozenset(), } diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index e239bded4..7eba04f19 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -423,6 +423,19 @@ def visit_Call(self, node: ast.Call) -> None: node, "surface `as_ptr` requires advanced=True in TileLang DSL v1", ) + if node.func.attr == "astype": + if node.keywords: + raise self.source_info.error( + node, + "`astype` does not support keyword arguments in TileLang DSL v1", + ) + if len(node.args) != 1: + raise self.source_info.error( + node, + "`astype()` expects exactly 1 positional argument (target dtype) in TileLang DSL v1", + ) + # Type checking will be done during semantic analysis + return if node.func.value.id == "pto" and node.func.attr == "tpl": self._validate_call_keywords(node) return diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 5f2adf04d..a57a35572 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2916,6 +2916,15 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vbitcast": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vbitcast {value.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vbitsort": destination = self._lower_expr(expr.args[0], env, indent=indent, into=into) source = self._lower_expr(expr.args[1], env, indent=indent, into=into) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 51e11194e..6133867c7 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3089,6 +3089,17 @@ def _analyze_expr( for arg in expr.args[1:] ) return self._analyze_eval_method(base, args) + if expr.namespace is None and expr.name == "astype": + if expr.keywords: + raise TypeError("method call `astype` does not support keyword arguments in TileLang DSL v1") + if not expr.args: + raise TypeError("`astype()` expects a receiver in TileLang DSL v1") + base = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args[1:] + ) + return self._analyze_astype_method(base, args) if expr.namespace not in {None, "pto"} and expr.name == "eval": if expr.keywords: raise TypeError("method call `eval` does not support keyword arguments in TileLang DSL v1") @@ -3117,6 +3128,22 @@ def _analyze_expr( ) base = SemanticBindingRef(binding=binding, type=binding.type) return self._analyze_as_ptr_method(base) + if expr.namespace not in {None, "pto"} and expr.name == "astype": + if expr.keywords: + raise TypeError("method call `astype` does not support keyword arguments in TileLang DSL v1") + binding = env.get(expr.namespace) + if binding is None: + if allow_outer_lookup: + raise ValueError(f"unknown name '{expr.namespace}'") + raise ValueError( + f"implicit capture of '{expr.namespace}' is not allowed in pto.strict_vecscope" + ) + base = SemanticBindingRef(binding=binding, type=binding.type) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_astype_method(base, args) if expr.namespace == "pto" and expr.name == "vlds": return self._analyze_vlds_frontend_call( expr, @@ -3545,6 +3572,23 @@ def _analyze_as_ptr_method(self, base: SemanticExpr) -> SemanticExpr: ) raise TypeError("`as_ptr()` expects a TensorView/PartitionTensorView or Tile value in TileLang DSL v1") + def _analyze_astype_method(self, base: SemanticExpr, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + """Analyze vreg.astype(dtype) method call.""" + if len(args) != 1: + raise TypeError("`astype()` expects exactly 1 positional argument (target dtype) in TileLang DSL v1") + # Verify target dtype is a valid dtype symbol + target_dtype = self._require_dtype_symbol(args[0], "astype target dtype") + # Verify base is a vector register + if not isinstance(base.type, SemanticVRegType): + raise TypeError("`astype()` expects a vector register value in TileLang DSL v1") + # Convert to pto.vbitcast call, pass original dtype expression as second argument + return SemanticCallExpr( + namespace="pto", + name="vbitcast", + args=(base, args[0]), + type=self._vreg_type_for_dtype(target_dtype), + ) + def _valid_shape_expr(self, base: SemanticExpr) -> SemanticExpr: base_type = base.type if not isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): @@ -3898,6 +3942,8 @@ def _analyze_call_expr( return self._analyze_rearrangement_op(name, args) if name == "vcvt": return self._analyze_vcvt(args) + if name == "vbitcast": + return self._analyze_vbitcast(args) if name == "vtrc": return self._analyze_vtrc(args) if name == "vbitsort": @@ -4979,6 +5025,22 @@ def _analyze_vtrc( type=vector, ) + def _analyze_vbitcast(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.vbitcast expects exactly 2 positional arguments in TileLang DSL") + vector = self._require_vreg_expr(args[0], "pto.vbitcast vector") + target_dtype = self._require_dtype_symbol(args[1], "pto.vbitcast to_type") + # No mask for vbitcast (pure type conversion) + return SemanticCallExpr( + namespace="pto", + name="vbitcast", + args=( + args[0], + args[1], + ), + type=self._vreg_type_for_dtype(target_dtype), + ) + def _analyze_vbitsort(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 4: raise TypeError("pto.vbitsort expects exactly 4 positional arguments in TileLang DSL v1") diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 7619a4965..a82a2d3f1 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -142,6 +142,7 @@ "vsort32", "vmrgsort", "vcvt", + "vbitcast", "vci", } ) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 1b880ad03..d607e849c 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3124,6 +3124,138 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn("does not accept `part=`", str(ctx.exception)) + def test_vbitcast_supports_direct_interface(self) -> None: + @pto.vkernel( + op="vbitcast_direct_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + # Load float vector + fvec = pto.vlds(src, 0) # !pto.vreg<64xf32> + # Convert to integer via vbitcast + ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> + # Convert back to float + fvec2 = pto.vbitcast(ivec, pto.f32) # !pto.vreg<64xf32> + # Store result + pto.vsts(fvec2, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbitcast", text) + self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<64xf32> -> !pto\.vreg<64xi32>") + self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<64xi32> -> !pto\.vreg<64xf32>") + + def test_vbitcast_supports_astype_syntax_sugar(self) -> None: + @pto.vkernel( + op="vbitcast_astype_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + # Load float vector + fvec = pto.vlds(src, 0) # !pto.vreg<64xf32> + # Convert to integer via astype syntax sugar + ivec = fvec.astype(pto.i32) # !pto.vreg<64xi32> + # Convert back to float + fvec2 = ivec.astype(pto.f32) # !pto.vreg<64xf32> + # Store result + pto.vsts(fvec2, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbitcast", text) + # astype calls should be lowered to vbitcast + count = text.count("pto.vbitcast") + self.assertGreaterEqual(count, 2) + + def test_vbitcast_rejects_non_vreg_input(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="vbitcast_non_vreg_input_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + # Try to vbitcast a non-vector value + scalar = pto.f32(1.0) + ivec = pto.vbitcast(scalar, pto.i32) + pto.vsts(ivec, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("vector register value", str(ctx.exception)) + + def test_astype_requires_vector_register(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="astype_non_vreg_input_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + # Try to call astype on a non-vector value + scalar = pto.f32(1.0) + ivec = scalar.astype(pto.i32) + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vsts(ivec, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("vector register value", str(ctx.exception)) + + def test_vbitcast_supports_element_size_change(self) -> None: + @pto.vkernel( + op="vbitcast_element_size_change_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + f32_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + f16_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + # Load f32 vector (64 elements) + f32_vec = pto.vlds(src, 0) # !pto.vreg<64xf32> + # Convert to f16 (128 elements) + f16_vec = pto.vbitcast(f32_vec, pto.f16) # !pto.vreg<128xf16> + # Convert back to f32 + f32_vec2 = pto.vbitcast(f16_vec, pto.f32) # !pto.vreg<64xf32> + # Store result + pto.vsts(f32_vec2, dst, 0, f32_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbitcast", text) + self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<64xf32> -> !pto\.vreg<128xf16>") + self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<128xf16> -> !pto\.vreg<64xf32>") + def test_index_to_float_scalar_cast_lowers_via_integer_bridge(self) -> None: @pto.vkernel( op="index_to_float_scalar_cast_unique", From 0e034043f8811a5223395d1ae6bca7857e4e259c Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 20 Apr 2026 14:52:34 +0800 Subject: [PATCH 100/192] Fix TileLang enum support for issue 131 --- tilelang-dsl/python/tilelang_dsl/__init__.py | 4 + .../python/tilelang_dsl/frontend_ast.py | 2 + tilelang-dsl/python/tilelang_dsl/semantic.py | 95 +++++++++++++++++-- tilelang-dsl/python/tilelang_dsl/types.py | 16 ++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 22 +++-- 5 files changed, 127 insertions(+), 12 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index eca98c8ba..8f036b476 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -37,8 +37,10 @@ MaskType, MemorySpace, MaskPattern, + CmpMode, PAT, PadMode, + PredicatePart, PositionMode, OrderMode, PadValue, @@ -122,6 +124,8 @@ "EVENT", "MaskPattern", "PredicateDist", + "PredicatePart", + "CmpMode", "PAT", "BarrierType", "BLayout", diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 19fad19bd..dcd58a9e4 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -980,6 +980,8 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo "EVENT", "MaskPattern", "PredicateDist", + "PredicatePart", + "CmpMode", "Pipe", "Event", "BarrierType", diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 6133867c7..462c33db6 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -52,6 +52,7 @@ AlignType, BarrierType, BLayout, + CmpMode, DeinterleaveDist, Event, InterleaveDist, @@ -62,6 +63,7 @@ PadMode, PadValue, PredicateDist, + PredicatePart, Pipe, PostUpdateMode, PositionMode, @@ -134,6 +136,8 @@ for pad_value in (PadValue.NULL, PadValue.ZERO, PadValue.MAX, PadValue.MIN) } _PREDICATE_DIST_SYMBOLS = {dist.name: dist for dist in PredicateDist} +_PREDICATE_PART_SYMBOLS = {part.name: part for part in PredicatePart} +_CMP_MODE_SYMBOLS = {mode.name: mode for mode in CmpMode} _DEINTERLEAVE_DIST_SYMBOLS = dict(DeinterleaveDist.__members__) _INTERLEAVE_DIST_SYMBOLS = dict(InterleaveDist.__members__) _POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} @@ -3336,6 +3340,24 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=predicate_dist, type=SemanticMetaType(kind="predicate_dist"), ) + if expr.namespace in {"PredicatePart", "pto.PredicatePart"}: + predicate_part = _PREDICATE_PART_SYMBOLS.get(expr.name) + if predicate_part is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=predicate_part, + type=SemanticMetaType(kind="predicate_part"), + ) + if expr.namespace in {"CmpMode", "pto.CmpMode"}: + cmp_mode = _CMP_MODE_SYMBOLS.get(expr.name) + if cmp_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=cmp_mode, + type=SemanticMetaType(kind="cmp_mode"), + ) if expr.namespace in {"DeinterleaveDist", "pto.DeinterleaveDist"}: dist = _DEINTERLEAVE_DIST_SYMBOLS.get(expr.name) if dist is not None: @@ -4054,6 +4076,20 @@ def _literal_expr_from_context_value(self, value: object, context: str) -> Seman value=value, type=SemanticMetaType(kind="memory_space"), ) + if isinstance(value, CmpMode): + return SemanticSymbolExpr( + namespace="pto", + name=value.name, + value=value, + type=SemanticMetaType(kind="cmp_mode"), + ) + if isinstance(value, PredicatePart): + return SemanticSymbolExpr( + namespace="pto", + name=value.name, + value=value, + type=SemanticMetaType(kind="predicate_part"), + ) raise TypeError( f"{context} resolved to unsupported static value {value!r} in TileLang DSL v1" ) @@ -4700,8 +4736,8 @@ def _analyze_mask_part_op( if len(args) != 2: raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") mask = self._require_mask_expr(args[0], f"pto.{name} mask") - self._require_string_expr(args[1], f"pto.{name} part") - return SemanticCallExpr(namespace="pto", name=name, args=args, type=mask) + part = self._normalize_predicate_part(args[1], f"pto.{name} part") + return SemanticCallExpr(namespace="pto", name=name, args=(args[0], part), type=mask) def _analyze_mask_logic_op( self, @@ -4771,11 +4807,11 @@ def _analyze_compare_op( raise TypeError("pto.vcmp requires lhs/rhs vector types to match") seed = self._require_mask_expr(args[2], "pto.vcmp seed mask") self._require_mask_for_vreg(args[2], lhs, "pto.vcmp") - self._require_string_expr(args[3], "pto.vcmp compare mode") + cmp_mode = self._normalize_cmp_mode(args[3], "pto.vcmp compare mode") return SemanticCallExpr( namespace="pto", name=name, - args=args, + args=(args[0], args[1], args[2], cmp_mode), type=SemanticMaskType(granularity=seed.granularity), ) @@ -4787,11 +4823,11 @@ def _analyze_compare_op( raise TypeError("pto.vcmps scalar dtype must match vector element dtype") seed = self._require_mask_expr(args[2], "pto.vcmps seed mask") self._require_mask_for_vreg(args[2], vector, "pto.vcmps") - self._require_string_expr(args[3], "pto.vcmps compare mode") + cmp_mode = self._normalize_cmp_mode(args[3], "pto.vcmps compare mode") return SemanticCallExpr( namespace="pto", name=name, - args=args, + args=(args[0], args[1], args[2], cmp_mode), type=SemanticMaskType(granularity=seed.granularity), ) @@ -5421,6 +5457,53 @@ def _require_string_expr(self, expr: SemanticExpr, context: str) -> str: return expr.binding.value raise TypeError(f"{context} must be a string literal in TileLang DSL") + def _normalize_cmp_mode(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "cmp_mode" + and isinstance(expr.value, CmpMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "cmp_mode" + and isinstance(expr.binding.value, CmpMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + cmp_mode = self._require_string_expr(expr, context) + if cmp_mode not in {mode.value for mode in CmpMode}: + raise TypeError( + f"{context} must be a CmpMode enum such as `pto.CmpMode.LT`, " + 'or one of the canonical strings `"eq"`, `"ne"`, `"lt"`, `"le"`, `"gt"`, `"ge"` ' + "in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=cmp_mode, type=SemanticMetaType(kind="string")) + + def _normalize_predicate_part(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_part" + and isinstance(expr.value, PredicatePart) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_part" + and isinstance(expr.binding.value, PredicatePart) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + part = self._require_string_expr(expr, context) + if part not in {token.value for token in PredicatePart}: + raise TypeError( + f"{context} must be a PredicatePart enum such as `pto.PredicatePart.LOWER`, " + 'or one of the canonical strings `"LOWER"`, `"HIGHER"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=part, type=SemanticMetaType(kind="string")) + def _normalize_post_update_mode( self, expr: SemanticExpr | None, diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 87072cee2..a572325c6 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -196,6 +196,20 @@ class PredicateDist(str, Enum): PK = "PK" +class PredicatePart(str, Enum): + LOWER = "LOWER" + HIGHER = "HIGHER" + + +class CmpMode(str, Enum): + EQ = "eq" + NE = "ne" + LT = "lt" + LE = "le" + GT = "gt" + GE = "ge" + + class PadMode(str, Enum): PadNull = "PadNull" PadFirstElem = "PadFirstElem" @@ -682,6 +696,8 @@ def get_op_attr(name: str, default: Any = None) -> Any: "EVENT", "MaskPattern", "PredicateDist", + "PredicatePart", + "CmpMode", "PAT", "BarrierType", "PadMode", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index d607e849c..a7f652c0f 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -148,6 +148,16 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.PredicateDist.US.value, "US") self.assertEqual(pto.PredicateDist.DS.value, "DS") self.assertEqual(pto.PredicateDist.PK.value, "PK") + self.assertTrue(hasattr(pto, "PredicatePart")) + self.assertEqual(pto.PredicatePart.LOWER.value, "LOWER") + self.assertEqual(pto.PredicatePart.HIGHER.value, "HIGHER") + self.assertTrue(hasattr(pto, "CmpMode")) + self.assertEqual(pto.CmpMode.EQ.value, "eq") + self.assertEqual(pto.CmpMode.NE.value, "ne") + self.assertEqual(pto.CmpMode.LT.value, "lt") + self.assertEqual(pto.CmpMode.LE.value, "le") + self.assertEqual(pto.CmpMode.GT.value, "gt") + self.assertEqual(pto.CmpMode.GE.value, "ge") self.assertEqual(pto.VcvtRoundMode.R.value, "R") self.assertEqual(pto.VcvtSatMode.SAT.value, "SAT") self.assertEqual(pto.VcvtPartMode.ODD.value, "ODD") @@ -4653,12 +4663,12 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) lhs = pto.vlds(src0[0, 0:]) rhs = pto.vlds(src1[0, 0:]) - cmp_mask = pto.vcmp(lhs, rhs, all_mask, "lt") - cmp_scalar_mask = pto.vcmps(lhs, scalar, all_mask, "gt") + cmp_mask = pto.vcmp(lhs, rhs, all_mask, pto.CmpMode.LT) + cmp_scalar_mask = pto.vcmps(lhs, scalar, all_mask, pto.CmpMode.GT) negated = pto.pnot(cmp_mask, all_mask) picked = pto.psel(cmp_mask, negated, cmp_scalar_mask) - packed = pto.ppack(picked, "PART_EVEN") - unpacked = pto.punpack(packed, "PART_ODD") + packed = pto.ppack(picked, pto.PredicatePart.LOWER) + unpacked = pto.punpack(packed, pto.PredicatePart.HIGHER) sum_vec, carry_mask = pto.vaddc(lhs, rhs, all_mask) diff_vec, borrow_mask = pto.vsubc(lhs, rhs, all_mask) sum_with_carry, carry_mask2 = pto.vaddcs(sum_vec, diff_vec, carry_mask, all_mask) @@ -4692,9 +4702,9 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): self.assertIn(" = pto.pnot ", text) self.assertIn(" = pto.psel ", text) self.assertIn(' = pto.ppack ', text) - self.assertIn('"PART_EVEN"', text) + self.assertIn('"LOWER"', text) self.assertIn(' = pto.punpack ', text) - self.assertIn('"PART_ODD"', text) + self.assertIn('"HIGHER"', text) self.assertRegex( text, r"%sum_vec_\d+, %carry_mask_\d+ = pto\.vaddc %lhs_\d+, %rhs_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", From 4d07cb352625150c3c44e149fa37ab6be8ab41f5 Mon Sep 17 00:00:00 2001 From: OmarZohir Date: Fri, 17 Apr 2026 16:59:59 +0800 Subject: [PATCH 101/192] Adding testing for TCI, assuring correctness Made-with: Cursor --- lib/PTO/Transforms/PTOToVPTOLowering.cpp | 2 +- lib/TileOps/tci_template.py | 64 +++ lib/TileOps/tstore_template.py | 1 + test/basic/expand_tile_op_tilelang_tci.pto | 37 ++ .../npu/a5/src/st/testcase/CMakeLists.txt | 1 + .../npu/a5/src/st/testcase/tci/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tci/cases.py | 121 ++++++ .../npu/a5/src/st/testcase/tci/compare.py | 47 +++ .../npu/a5/src/st/testcase/tci/gen_data.py | 36 ++ .../npu/a5/src/st/testcase/tci/launch.cpp | 52 +++ .../npu/a5/src/st/testcase/tci/main.cpp | 141 +++++++ .../npu/a5/src/st/testcase/tci/tci.pto | 363 ++++++++++++++++++ tilelang-dsl/python/tilelang_dsl/types.py | 13 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 6 +- 14 files changed, 881 insertions(+), 12 deletions(-) create mode 100644 lib/TileOps/tci_template.py create mode 100644 test/basic/expand_tile_op_tilelang_tci.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/tci.pto diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp index ff8df34f9..99c121733 100644 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -5252,7 +5252,7 @@ LogicalResult lowerTTRANS(TTransOp op, PatternRewriter &rewriter) { rewriter.create(op.getLoc(), indexElemType, chunkBase); auto indices = rewriter.create(op.getLoc(), indexVecType, chunkBaseI32, - rewriter.getStringAttr("INC_ORDER")); + rewriter.getStringAttr("ASC")); Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), indexElemType); auto scaled = rewriter.create(op.getLoc(), indexVecType, indices.getResult(), srcStrideI32, fullMask); diff --git a/lib/TileOps/tci_template.py b/lib/TileOps/tci_template.py new file mode 100644 index 000000000..3b722765c --- /dev/null +++ b/lib/TileOps/tci_template.py @@ -0,0 +1,64 @@ +"""TileLang DSL template for pto.tci. + +`pto.tci` writes the integer sequence ``dst[i, j] = start + linear_index(i, j)`` +into a vec tile. The A5 lowering path currently requires ``valid_rows == 1``, +so this template only covers the 1xC row-major case. On the hardware path the +work is carried out with ``pto.vci``, which seeds a per-lane index vector +``[seed, seed + 1, ..., seed + lanes - 1]``. That matches `TAdd`'s structure +exactly: tile the valid columns into VReg-sized chunks, load the mask for the +trailing partial chunk, and use ``pto.vsts`` to write each chunk back to the +destination tile. +""" + +import tilelang_dsl as pto + + +def _supports_tci_layout(dst): + return ( + dst.shape[0] == 1 + and dst.valid_shape[0] == 1 + and dst.shape[1] > 1 + ) + + +@pto.vkernel( + target="a5", + op="pto.tci", + dtypes=[(pto.i16, pto.i16)], + constraints=[_supports_tci_layout], + advanced=True, +) +def template_tci_i16(start: pto.i16, dst: pto.Tile): + dtype = dst.element_type + _, valid_cols = dst.valid_shape + + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # DSL v1 lowering cannot directly cast `index` to `i16`, so stage + # through `i32` first (`index -> i32 -> i16`). + col_i16 = pto.i16(pto.i32(col)) + seed = start + col_i16 + indices = pto.vci(seed) + pto.vsts(indices, dst[0, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tci", + dtypes=[(pto.i32, pto.i32)], + constraints=[_supports_tci_layout], + advanced=True, +) +def template_tci_i32(start: pto.i32, dst: pto.Tile): + dtype = dst.element_type + _, valid_cols = dst.valid_shape + + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + seed = start + pto.i32(col) + indices = pto.vci(seed) + pto.vsts(indices, dst[0, col:], mask) + return diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py index 2850a1651..aaec02f33 100644 --- a/lib/TileOps/tstore_template.py +++ b/lib/TileOps/tstore_template.py @@ -54,6 +54,7 @@ def _tstore_preconditions_nz(src, dst) -> bool: src, dst, logical_rows=logical_rows, logical_cols=logical_cols ) + @pto.vkernel( target="a5", op="pto.tstore", diff --git a/test/basic/expand_tile_op_tilelang_tci.pto b/test/basic/expand_tile_op_tilelang_tci.pto new file mode 100644 index 000000000..a142eef70 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tci.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tci via the default TileLang Python DSL template +// lib/TileOps/tci_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tci should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCI +// CHECK-NOT: pto.tci ins +// CHECK: pto.vecscope +// CHECK: pto.vci +// CHECK: pto.vsts + +module { + func.func @TCI() { + %c7_i32 = arith.constant 7 : i32 + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%c7_i32 : i32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index db614307b..ad03b47c5 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -114,6 +114,7 @@ endfunction() # -------------------------------------------------------------------------- set(ALL_TESTCASES tadd + tci tcvt tload ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tci/CMakeLists.txt new file mode 100644 index 000000000..6e998ed86 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tci/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tci) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tci/cases.py new file mode 100644 index 000000000..42e58e81c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tci/cases.py @@ -0,0 +1,121 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np + +# Cases cover the 1D contiguous-integer sequence semantics of pto.tci. The +# shape choices exercise the corner cases of the template's column-chunked +# vector store: +# +# * tiny shapes (partial mask on the first chunk) +# * exact single-VReg multiples (mask covers the full lane count, no tail) +# * one-element tails just past a full VReg (mask of 1 lane on the tail) +# * multi-VReg shapes (two full chunks, no partial tail) +# +# For i32 the vector lane count is 64; for i16 it is 128. +CASES = [ + # ---- i32 (lanes = 64) ---- + { + "name": "i32_1x8", + "dtype": np.int32, + "shape": (1, 8), + "valid_shape": (1, 8), + "start": -5, + "eps": 0.0, + }, + { + "name": "i32_1x32", + "dtype": np.int32, + "shape": (1, 32), + "valid_shape": (1, 32), + "start": 3, + "eps": 0.0, + }, + { + "name": "i32_1x64", + "dtype": np.int32, + "shape": (1, 64), + "valid_shape": (1, 64), + "start": 100, + "eps": 0.0, + }, + { + "name": "i32_1x72", + "dtype": np.int32, + "shape": (1, 72), + "valid_shape": (1, 72), + "start": 0, + "eps": 0.0, + }, + { + "name": "i32_1x80", + "dtype": np.int32, + "shape": (1, 80), + "valid_shape": (1, 80), + "start": 17, + "eps": 0.0, + }, + { + "name": "i32_1x128", + "dtype": np.int32, + "shape": (1, 128), + "valid_shape": (1, 128), + "start": -1000, + "eps": 0.0, + }, + # ---- i16 (lanes = 128) ---- + { + "name": "i16_1x16", + "dtype": np.int16, + "shape": (1, 16), + "valid_shape": (1, 16), + "start": 1000, + "eps": 0.0, + }, + { + "name": "i16_1x64", + "dtype": np.int16, + "shape": (1, 64), + "valid_shape": (1, 64), + "start": 11, + "eps": 0.0, + }, + { + "name": "i16_1x128", + "dtype": np.int16, + "shape": (1, 128), + "valid_shape": (1, 128), + "start": -100, + "eps": 0.0, + }, + { + "name": "i16_1x144", + "dtype": np.int16, + "shape": (1, 144), + "valid_shape": (1, 144), + "start": 0, + "eps": 0.0, + }, + { + "name": "i16_1x160", + "dtype": np.int16, + "shape": (1, 160), + "valid_shape": (1, 160), + "start": -23, + "eps": 0.0, + }, + { + "name": "i16_1x256", + "dtype": np.int16, + "shape": (1, 256), + "valid_shape": (1, 256), + "start": 30000, + "eps": 0.0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tci/compare.py new file mode 100644 index 000000000..4ceccdde0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tci/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tci/gen_data.py new file mode 100644 index 000000000..b3ad632c5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tci/gen_data.py @@ -0,0 +1,36 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + start = np.array([case["start"]], dtype=dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + base = int(start[0]) + golden[:vr, :vc] = np.asarray( + [base + index for index in range(vr * vc)], + dtype=dtype, + ).reshape(valid_shape) + + save_case_data(case["name"], {"start": start, "golden": golden}) + print( + f"[INFO] gen_data: {case['name']} shape={shape} " + f"valid_shape={valid_shape} start={base} dtype={dtype.__name__}" + ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tci/launch.cpp new file mode 100644 index 000000000..29bfc48c7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tci/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TCI_i32_1x8(int32_t start, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCI_i32_1x32(int32_t start, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCI_i32_1x64(int32_t start, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCI_i32_1x72(int32_t start, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCI_i32_1x80(int32_t start, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCI_i32_1x128(int32_t start, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCI_i16_1x16(int16_t start, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCI_i16_1x64(int16_t start, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCI_i16_1x128(int16_t start, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCI_i16_1x144(int16_t start, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCI_i16_1x160(int16_t start, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCI_i16_1x256(int16_t start, __gm__ int16_t *dst); + +#define DEFINE_LAUNCH_I32(name) \ + void Launch##name(const void *start, void *dst, void *stream) { \ + const int32_t scalar = *reinterpret_cast(start); \ + name<<<1, nullptr, stream>>>(scalar, (__gm__ int32_t *)dst); \ + } + +#define DEFINE_LAUNCH_I16(name) \ + void Launch##name(const void *start, void *dst, void *stream) { \ + const int16_t scalar = *reinterpret_cast(start); \ + name<<<1, nullptr, stream>>>(scalar, (__gm__ int16_t *)dst); \ + } + +DEFINE_LAUNCH_I32(TCI_i32_1x8) +DEFINE_LAUNCH_I32(TCI_i32_1x32) +DEFINE_LAUNCH_I32(TCI_i32_1x64) +DEFINE_LAUNCH_I32(TCI_i32_1x72) +DEFINE_LAUNCH_I32(TCI_i32_1x80) +DEFINE_LAUNCH_I32(TCI_i32_1x128) + +DEFINE_LAUNCH_I16(TCI_i16_1x16) +DEFINE_LAUNCH_I16(TCI_i16_1x64) +DEFINE_LAUNCH_I16(TCI_i16_1x128) +DEFINE_LAUNCH_I16(TCI_i16_1x144) +DEFINE_LAUNCH_I16(TCI_i16_1x160) +DEFINE_LAUNCH_I16(TCI_i16_1x256) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tci/main.cpp new file mode 100644 index 000000000..a13a898d9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tci/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTCI_i32_1x8(const void *start, void *dst, void *stream); +void LaunchTCI_i32_1x32(const void *start, void *dst, void *stream); +void LaunchTCI_i32_1x64(const void *start, void *dst, void *stream); +void LaunchTCI_i32_1x72(const void *start, void *dst, void *stream); +void LaunchTCI_i32_1x80(const void *start, void *dst, void *stream); +void LaunchTCI_i32_1x128(const void *start, void *dst, void *stream); +void LaunchTCI_i16_1x16(const void *start, void *dst, void *stream); +void LaunchTCI_i16_1x64(const void *start, void *dst, void *stream); +void LaunchTCI_i16_1x128(const void *start, void *dst, void *stream); +void LaunchTCI_i16_1x144(const void *start, void *dst, void *stream); +void LaunchTCI_i16_1x160(const void *start, void *dst, void *stream); +void LaunchTCI_i16_1x256(const void *start, void *dst, void *stream); + +using LaunchFn = void (*)(const void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; + size_t elemSize; + size_t scalarSize; +}; + +#define CASE_I32(cols) \ + {"i32_1x" #cols, LaunchTCI_i32_1x##cols, 1, (cols), 1, (cols), sizeof(int32_t), sizeof(int32_t)} +#define CASE_I16(cols) \ + {"i16_1x" #cols, LaunchTCI_i16_1x##cols, 1, (cols), 1, (cols), sizeof(int16_t), sizeof(int16_t)} + +static const TestCase kCases[] = { + CASE_I32(8), + CASE_I32(32), + CASE_I32(64), + CASE_I32(72), + CASE_I32(80), + CASE_I32(128), + CASE_I16(16), + CASE_I16(64), + CASE_I16(128), + CASE_I16(144), + CASE_I16(160), + CASE_I16(256), +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t startFileSize = tc.scalarSize; + std::vector startHost(tc.scalarSize); + void *dstHost = nullptr; + void *dstDevice = nullptr; + + aclrtMallocHost(&dstHost, fileSize); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/start.bin").c_str(), startFileSize, startHost.data(), tc.scalarSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/start.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + tc.launch(startHost.data(), dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/tci.pto b/test/tilelang_st/npu/a5/src/st/testcase/tci/tci.pto new file mode 100644 index 000000000..e5c02263e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tci/tci.pto @@ -0,0 +1,363 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tci: generate an integer sequence in UB and +// store it back to GM through the standard 5D tensor-view / partition-view +// path. Compiled by ptoas --enable-insert-sync --enable-tile-op-expand +// --vpto-emit-hivm-llvm to produce LLVM IR. +// +// Each kernel materializes a 1xC tile, runs pto.tci against the scalar +// seed, then streams the tile out to GM via pto.tstore. Cases exercise +// the vectorised template's corner cases (partial masks, exact VReg +// multiples, single-lane tails, etc.). + +module { + // ===================================================================== + // i32 cases (lane count = 64) + // ===================================================================== + + // Partial mask on the only chunk (8 active of 64 lanes). + func.func @TCI_i32_1x8(%start: i32, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c8], + strides = [%c8, %c8, %c8, %c8, %c1] + : !pto.tensor_view<1x1x1x1x8xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8] + : !pto.tensor_view<1x1x1x1x8xi32> -> !pto.partition_tensor_view<1x1x1x1x8xi32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i32) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x8xi32>) + return + } + + // Partial mask on the only chunk (32 active of 64 lanes). + func.func @TCI_i32_1x32(%start: i32, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i32) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + return + } + + // Exact single VReg: one full-mask chunk, no tail. + func.func @TCI_i32_1x64(%start: i32, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi32> -> !pto.partition_tensor_view<1x1x1x1x64xi32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i32) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x64xi32>) + return + } + + // One full VReg plus an 8-lane tail (tail mask = 8). + func.func @TCI_i32_1x72(%start: i32, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c72 = arith.constant 72 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c72], + strides = [%c72, %c72, %c72, %c72, %c1] + : !pto.tensor_view<1x1x1x1x72xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c72] + : !pto.tensor_view<1x1x1x1x72xi32> -> !pto.partition_tensor_view<1x1x1x1x72xi32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i32) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x72xi32>) + return + } + + // Two chunks: one full VReg + partial tail of 16 lanes. + func.func @TCI_i32_1x80(%start: i32, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c80 = arith.constant 80 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c80], + strides = [%c80, %c80, %c80, %c80, %c1] + : !pto.tensor_view<1x1x1x1x80xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c80] + : !pto.tensor_view<1x1x1x1x80xi32> -> !pto.partition_tensor_view<1x1x1x1x80xi32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i32) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x80xi32>) + return + } + + // Exact two VRegs: two full-mask chunks, no tail. + func.func @TCI_i32_1x128(%start: i32, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i32) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + return + } + + // ===================================================================== + // i16 cases (lane count = 128) + // ===================================================================== + + // Partial mask on the only chunk (16 active of 128 lanes). + func.func @TCI_i16_1x16(%start: i16, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c16], + strides = [%c16, %c16, %c16, %c16, %c1] + : !pto.tensor_view<1x1x1x1x16xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16] + : !pto.tensor_view<1x1x1x1x16xi16> -> !pto.partition_tensor_view<1x1x1x1x16xi16> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i16) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x16xi16>) + return + } + + // Partial mask on the only chunk (64 active of 128 lanes). + func.func @TCI_i16_1x64(%start: i16, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i16) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + return + } + + // Exact single VReg: one full-mask chunk, no tail. + func.func @TCI_i16_1x128(%start: i16, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i16) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + return + } + + // One full VReg plus a 16-lane tail (tail mask = 16). + func.func @TCI_i16_1x144(%start: i16, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c144 = arith.constant 144 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c144], + strides = [%c144, %c144, %c144, %c144, %c1] + : !pto.tensor_view<1x1x1x1x144xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c144] + : !pto.tensor_view<1x1x1x1x144xi16> -> !pto.partition_tensor_view<1x1x1x1x144xi16> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i16) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x144xi16>) + return + } + + // Two chunks: one full VReg + partial tail of 32 lanes. + func.func @TCI_i16_1x160(%start: i16, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c160 = arith.constant 160 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c160], + strides = [%c160, %c160, %c160, %c160, %c1] + : !pto.tensor_view<1x1x1x1x160xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c160] + : !pto.tensor_view<1x1x1x1x160xi16> -> !pto.partition_tensor_view<1x1x1x1x160xi16> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i16) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x160xi16>) + return + } + + // Exact two VRegs: two full-mask chunks, no tail. + func.func @TCI_i16_1x256(%start: i16, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x256xi16> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tci ins(%start : i16) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xi16>) + return + } +} diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index a572325c6..3f8762980 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -428,7 +428,11 @@ class PositionMode(str, Enum): class OrderMode(str, Enum): - ASC = "ORDER_ASC" + # The serialized value is emitted into the MLIR `order` attribute of + # `pto.vci` and must match the canonical spelling documented by the VPTO + # spec (see `docs/vpto_spec/vpto-spec-current.md`). Hand-written VPTO IR + # and `VPTOLLVMEmitter::parseOrderImmediate` also use this short form. + ASC = "ASC" class VcvtRoundMode(str, Enum): @@ -670,12 +674,6 @@ def constexpr(value: bool) -> bool: return value -def get_op_attr(name: str, default: Any = None) -> Any: - if not isinstance(name, str) or not name: - raise TypeError("get_op_attr expects a non-empty string attribute name") - return default - - __all__ = [ "ScalarType", "WildcardType", @@ -735,7 +733,6 @@ def get_op_attr(name: str, default: Any = None) -> Any: "mask_b16", "mask_b32", "constexpr", - "get_op_attr", "bytewidth", "get_lanes", "elements_per_vreg", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index a7f652c0f..a1ef17d07 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -1700,7 +1700,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): return None self.assertIn( - "unsupported keyword `offset` for `pto.vlds` in TileLang DSL v1", + "`pto.vlds` does not support keyword arguments in TileLang DSL v1", str(ctx.exception), ) self.assertIn(f"{__file__}:", str(ctx.exception)) @@ -3443,11 +3443,11 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): self.assertNotIn('position = "POS_LOWEST"', text) self.assertRegex( text, - r'pto\.vci\s+%[^\s]+\s+\{order = "ORDER_ASC"\}\s+:', + r'pto\.vci\s+%[^\s]+\s+\{order = "ASC"\}\s+:', ) self.assertNotRegex( text, - r'pto\.vci\s+%[^\s]+,\s*"ORDER_ASC"\s+:', + r'pto\.vci\s+%[^\s]+,\s*"ASC"\s+:', ) def test_vdup_scalar_input_rejects_position_argument(self) -> None: From fbe08ce7afad143efec7fb48a95f10fe53a129c9 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 20 Apr 2026 18:18:24 +0800 Subject: [PATCH 102/192] test(tilelang_st): add self-hosted runner batch entrypoint --- test/tilelang_st/script/run_all_st.py | 187 ++++++++++++++++++++++++++ test/tilelang_st/script/run_ci.sh | 22 +++ 2 files changed, 209 insertions(+) create mode 100755 test/tilelang_st/script/run_all_st.py create mode 100755 test/tilelang_st/script/run_ci.sh diff --git a/test/tilelang_st/script/run_all_st.py b/test/tilelang_st/script/run_all_st.py new file mode 100755 index 000000000..a939aa6bf --- /dev/null +++ b/test/tilelang_st/script/run_all_st.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Batch runner for TileLang ST, suitable for CI/self-hosted runner usage.""" + +import argparse +import os +import sys +import traceback + +import run_st + + +SOC_VERSION_MAP = { + "a5": "Ascend950PR_9599", +} + + +def discover_testcases(testcase_root): + testcases = [] + for entry in sorted(os.listdir(testcase_root)): + testcase_dir = os.path.join(testcase_root, entry) + if not os.path.isdir(testcase_dir): + continue + pto_file = os.path.join(testcase_dir, f"{entry}.pto") + if os.path.isfile(pto_file): + testcases.append(entry) + return testcases + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run all TileLang ST testcases for CI or local batch validation." + ) + parser.add_argument( + "-r", "--run-mode", default="sim", + help="Run mode: sim or npu (default: sim)", + ) + parser.add_argument( + "-v", "--soc-version", default="a5", + help="SoC version: a5 (default: a5)", + ) + parser.add_argument( + "-p", "--ptoas-bin", default=None, + help="Path to ptoas binary (auto-detected if omitted)", + ) + parser.add_argument( + "-t", "--testcase", action="append", default=[], + help="Run only selected testcase(s). Can be passed multiple times.", + ) + parser.add_argument( + "-w", "--without-build", action="store_true", + help="Skip build and reuse the existing build directory.", + ) + parser.add_argument( + "--fail-fast", action="store_true", + help="Stop immediately after the first failed testcase.", + ) + parser.add_argument( + "--list", action="store_true", + help="List discovered testcases and exit.", + ) + return parser.parse_args() + + +def resolve_selected_testcases(all_testcases, requested): + if not requested: + return all_testcases + + requested_set = [] + seen = set() + for testcase in requested: + if testcase not in seen: + requested_set.append(testcase) + seen.add(testcase) + + missing = [testcase for testcase in requested_set if testcase not in all_testcases] + if missing: + raise ValueError( + f"Unsupported testcase(s): {', '.join(missing)}; " + f"supported: {', '.join(all_testcases)}" + ) + return requested_set + + +def main(): + args = parse_args() + + if args.soc_version not in SOC_VERSION_MAP: + print( + f"[ERROR] Unsupported soc-version: {args.soc_version}, " + f"supported: {', '.join(sorted(SOC_VERSION_MAP))}", + file=sys.stderr, + ) + sys.exit(1) + + script_path = os.path.abspath(__file__) + tilelang_st_root = os.path.dirname(os.path.dirname(script_path)) + testcase_root = os.path.join( + tilelang_st_root, "npu", args.soc_version, "src", "st", "testcase" + ) + target_dir = os.path.dirname(testcase_root) + + if not os.path.isdir(testcase_root): + print(f"[ERROR] Testcase root not found: {testcase_root}", file=sys.stderr) + sys.exit(1) + + all_testcases = discover_testcases(testcase_root) + if not all_testcases: + print(f"[ERROR] No testcases found in: {testcase_root}", file=sys.stderr) + sys.exit(1) + + if args.list: + for testcase in all_testcases: + print(testcase) + return + + try: + selected_testcases = resolve_selected_testcases(all_testcases, args.testcase) + except ValueError as exc: + print(f"[ERROR] {exc}", file=sys.stderr) + sys.exit(1) + + ptoas_bin = args.ptoas_bin or run_st.find_ptoas_bin() + if not ptoas_bin: + print( + "[ERROR] Cannot find ptoas binary. Set PTOAS_BIN env or use -p flag.", + file=sys.stderr, + ) + sys.exit(1) + ptoas_bin = os.path.abspath(ptoas_bin) + + default_soc_version = SOC_VERSION_MAP[args.soc_version] + print(f"[INFO] run_mode={args.run_mode}") + print(f"[INFO] soc_version={args.soc_version} ({default_soc_version})") + print(f"[INFO] ptoas={ptoas_bin}") + print(f"[INFO] target_dir={target_dir}") + print(f"[INFO] selected_testcases={', '.join(selected_testcases)}") + + original_dir = os.getcwd() + failures = [] + try: + os.chdir(target_dir) + run_st.set_env_variables(args.run_mode, default_soc_version) + + if not args.without_build: + build_target = "all" if selected_testcases == all_testcases else ";".join(selected_testcases) + print(f"[INFO] build requested for {build_target}") + run_st.build_project(args.run_mode, default_soc_version, "all", ptoas_bin) + + total = len(selected_testcases) + for index, testcase in enumerate(selected_testcases, start=1): + print(f"[INFO] [{index}/{total}] running testcase: {testcase}") + try: + run_st.run_gen_data(testcase) + run_st.run_binary(testcase) + run_st.run_compare(testcase) + except Exception as exc: # pragma: no cover - CI-side aggregation path + failures.append((testcase, str(exc))) + print(f"[ERROR] testcase failed: {testcase}") + traceback.print_exc() + if args.fail_fast: + break + + except Exception as exc: + print(f"[ERROR] batch run failed: {exc}", file=sys.stderr) + sys.exit(1) + finally: + os.chdir(original_dir) + + passed = len(selected_testcases) - len(failures) + print("[INFO] TileLang ST summary") + print(f"[INFO] passed={passed} failed={len(failures)} total={len(selected_testcases)}") + if failures: + for testcase, reason in failures: + print(f"[INFO] failed testcase: {testcase} ({reason})") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/script/run_ci.sh b/test/tilelang_st/script/run_ci.sh new file mode 100755 index 000000000..385a7bda9 --- /dev/null +++ b/test/tilelang_st/script/run_ci.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." && pwd)" + +if [[ -f "${REPO_ROOT}/scripts/ptoas_env.sh" ]]; then + # shellcheck source=/dev/null + source "${REPO_ROOT}/scripts/ptoas_env.sh" +fi + +export PYTHONUNBUFFERED=1 + +exec python3 "${SCRIPT_DIR}/run_all_st.py" "$@" From 236c8c73569ef7ca268712a517f36df6704e34b2 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 20 Apr 2026 20:01:13 +0800 Subject: [PATCH 103/192] fix(tilelang-dsl): emit signless arith constants for unsigned scalars (#143) --- .../docs/user_guide/05-type-system.md | 20 ++++++++++++++ tilelang-dsl/python/tilelang_dsl/lowering.py | 17 +++++++++--- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 27 ++++++++++++++++++- 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 28c35cf9c..71c916498 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -424,6 +424,26 @@ rank = tile.rank # 2 - `tile.pad_value.eval()` with `PadValue.custom_f32(...)` becomes the authored floating scalar - `tile.pad_value.eval()` with `PadValue.NULL` raises a frontend error +For dtype-dependent fill seeds, prefer `tile.pad_value.eval()` over handwritten +`if dtype == ...` ladders. + +```python +@pto.vkernel(op="fill_pad_value", dtypes=[(pto.AnyType,)]) +def fill_pad_value(dst: pto.Tile): + pad_scalar = dst.pad_value.eval() + pad_vec = pto.vbr(pad_scalar) + # ... +``` + +Typical materialized values: + +- `PadValue.ZERO` -> `0` / `0.0` +- `PadValue.MAX` -> dtype-aware max, for example `4294967295` for `pto.ui32` +- `PadValue.MIN` -> dtype-aware min, for example `-2147483648` for `pto.i32` and `0` for `pto.ui32` + +This is usually simpler than spelling every dtype case manually with +`pto.constexpr(dst.element_type == ...)`. + Example: reading pad value from a `Tile` ```python diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index a57a35572..9c011a170 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2455,7 +2455,7 @@ def _lower_expr( into.append( self._indent(indent) + f"{desired_name} = arith.constant {self._format_constant(expr.value, expr.type)} : " - f"{self._render_type(expr.type)}" + f"{self._render_arith_constant_type(expr.type)}" ) return _RenderedValue(name=desired_name, type=expr.type) return _RenderedValue( @@ -3525,7 +3525,7 @@ def _lower_subscript_access( into.append( self._indent(indent) + f"{desired_name} = arith.constant {self._format_constant(value, expr.type)} : " - f"{self._render_type(expr.type)}" + f"{self._render_arith_constant_type(expr.type)}" ) return _RenderedValue(name=desired_name, type=expr.type) return _RenderedValue( @@ -3652,7 +3652,8 @@ def _materialize_constant(self, value: object, ty: SemanticType) -> str: self._constant_cache[cache_key] = name self._constant_lines.append( self._indent(4) - + f"{name} = arith.constant {self._format_constant(value, ty)} : {self._render_type(ty)}" + + f"{name} = arith.constant {self._format_constant(value, ty)} : " + f"{self._render_arith_constant_type(ty)}" ) return name @@ -3696,6 +3697,16 @@ def _format_constant(self, value: object, ty: SemanticType) -> str: return str(value) raise NotImplementedError(f"unsupported constant type {ty!r}") + def _render_arith_constant_type(self, ty: SemanticType) -> str: + if isinstance(ty, SemanticScalarType) and is_integer_dtype(ty.dtype): + width = integer_bitwidth(ty.dtype) + if width is None: + raise NotImplementedError( + f"unsupported integer dtype {ty.dtype.name!r} for arith.constant emission" + ) + return f"i{width}" + return self._render_type(ty) + def _format_float_constant(self, value: float, dtype_name: str) -> str: # Emit stable bit-pattern literals for values that are parse-sensitive # (`inf`/`nan`) or sign-sensitive (`-0.0`). diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index a1ef17d07..e993ef90f 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2352,6 +2352,30 @@ def kernel(tile: pto.Tile): self.assertIn("PadValue.NULL.eval() is invalid", str(ctx.exception)) + def test_unsigned_integer_constants_lower_with_signless_arith_types(self) -> None: + @pto.vkernel(op="tile_pad_value_ui32_max_eval_unique", dtypes=[(pto.ui32,)]) + def kernel(tile: pto.Tile): + scalar = tile.pad_value.eval() + explicit = pto.ui32(4294967295) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "pad_value": pto.PadValue.MAX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("dtype=ui32", text) + self.assertIn("arith.constant 4294967295 : i32", text) + self.assertNotIn("arith.constant 4294967295 : ui32", text) + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) @@ -2523,7 +2547,8 @@ def kernel(dst_tile: pto.Tile, src_tile: pto.Tile, gate: pto.i32): text = specialized.mlir_text() self.assertRegex(text, r"= arith\.floordivsi %in_gate_\d+, %c2_i32 : i32") self.assertRegex(text, r"= arith\.remsi %in_gate_\d+, %c7_i32 : i32") - self.assertRegex(text, r"= arith\.mulf %tmp_\d+, %c0\.5_f32 : f32") + self.assertRegex(text, r"%c0_5_f32 = arith\.constant 0\.5 : f32") + self.assertRegex(text, r"= arith\.mulf %tmp_\d+, %c0_5_f32 : f32") self.assertRegex(text, r"= arith\.addf %tmp_\d+, %tmp_\d+ : f32") def test_index_floordiv_lowers_to_divui_instead_of_floordivsi(self) -> None: From 658e71e3b7e4200d8140449528a465d0f3ada814 Mon Sep 17 00:00:00 2001 From: Zhang Zhendong Date: Mon, 20 Apr 2026 20:44:00 +0800 Subject: [PATCH 104/192] Revert "Adding testing for TCI, assuring correctness" This reverts commit d32497cbfb0ad477b16693d44042db3573c23d03. --- lib/PTO/Transforms/PTOToVPTOLowering.cpp | 2 +- lib/TileOps/tci_template.py | 64 --- lib/TileOps/tstore_template.py | 1 - test/basic/expand_tile_op_tilelang_tci.pto | 37 -- .../npu/a5/src/st/testcase/CMakeLists.txt | 1 - .../npu/a5/src/st/testcase/tci/CMakeLists.txt | 9 - .../npu/a5/src/st/testcase/tci/cases.py | 121 ------ .../npu/a5/src/st/testcase/tci/compare.py | 47 --- .../npu/a5/src/st/testcase/tci/gen_data.py | 36 -- .../npu/a5/src/st/testcase/tci/launch.cpp | 52 --- .../npu/a5/src/st/testcase/tci/main.cpp | 141 ------- .../npu/a5/src/st/testcase/tci/tci.pto | 363 ------------------ tilelang-dsl/python/tilelang_dsl/types.py | 13 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 6 +- 14 files changed, 12 insertions(+), 881 deletions(-) delete mode 100644 lib/TileOps/tci_template.py delete mode 100644 test/basic/expand_tile_op_tilelang_tci.pto delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/CMakeLists.txt delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/cases.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/compare.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/gen_data.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/launch.cpp delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/main.cpp delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tci/tci.pto diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp index 99c121733..ff8df34f9 100644 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -5252,7 +5252,7 @@ LogicalResult lowerTTRANS(TTransOp op, PatternRewriter &rewriter) { rewriter.create(op.getLoc(), indexElemType, chunkBase); auto indices = rewriter.create(op.getLoc(), indexVecType, chunkBaseI32, - rewriter.getStringAttr("ASC")); + rewriter.getStringAttr("INC_ORDER")); Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), indexElemType); auto scaled = rewriter.create(op.getLoc(), indexVecType, indices.getResult(), srcStrideI32, fullMask); diff --git a/lib/TileOps/tci_template.py b/lib/TileOps/tci_template.py deleted file mode 100644 index 3b722765c..000000000 --- a/lib/TileOps/tci_template.py +++ /dev/null @@ -1,64 +0,0 @@ -"""TileLang DSL template for pto.tci. - -`pto.tci` writes the integer sequence ``dst[i, j] = start + linear_index(i, j)`` -into a vec tile. The A5 lowering path currently requires ``valid_rows == 1``, -so this template only covers the 1xC row-major case. On the hardware path the -work is carried out with ``pto.vci``, which seeds a per-lane index vector -``[seed, seed + 1, ..., seed + lanes - 1]``. That matches `TAdd`'s structure -exactly: tile the valid columns into VReg-sized chunks, load the mask for the -trailing partial chunk, and use ``pto.vsts`` to write each chunk back to the -destination tile. -""" - -import tilelang_dsl as pto - - -def _supports_tci_layout(dst): - return ( - dst.shape[0] == 1 - and dst.valid_shape[0] == 1 - and dst.shape[1] > 1 - ) - - -@pto.vkernel( - target="a5", - op="pto.tci", - dtypes=[(pto.i16, pto.i16)], - constraints=[_supports_tci_layout], - advanced=True, -) -def template_tci_i16(start: pto.i16, dst: pto.Tile): - dtype = dst.element_type - _, valid_cols = dst.valid_shape - - remained = valid_cols - for col in range(0, valid_cols, pto.get_lanes(dtype)): - mask, remained = pto.make_mask(dtype, remained) - # DSL v1 lowering cannot directly cast `index` to `i16`, so stage - # through `i32` first (`index -> i32 -> i16`). - col_i16 = pto.i16(pto.i32(col)) - seed = start + col_i16 - indices = pto.vci(seed) - pto.vsts(indices, dst[0, col:], mask) - return - - -@pto.vkernel( - target="a5", - op="pto.tci", - dtypes=[(pto.i32, pto.i32)], - constraints=[_supports_tci_layout], - advanced=True, -) -def template_tci_i32(start: pto.i32, dst: pto.Tile): - dtype = dst.element_type - _, valid_cols = dst.valid_shape - - remained = valid_cols - for col in range(0, valid_cols, pto.get_lanes(dtype)): - mask, remained = pto.make_mask(dtype, remained) - seed = start + pto.i32(col) - indices = pto.vci(seed) - pto.vsts(indices, dst[0, col:], mask) - return diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py index aaec02f33..2850a1651 100644 --- a/lib/TileOps/tstore_template.py +++ b/lib/TileOps/tstore_template.py @@ -54,7 +54,6 @@ def _tstore_preconditions_nz(src, dst) -> bool: src, dst, logical_rows=logical_rows, logical_cols=logical_cols ) - @pto.vkernel( target="a5", op="pto.tstore", diff --git a/test/basic/expand_tile_op_tilelang_tci.pto b/test/basic/expand_tile_op_tilelang_tci.pto deleted file mode 100644 index a142eef70..000000000 --- a/test/basic/expand_tile_op_tilelang_tci.pto +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline -// expands pto.tci via the default TileLang Python DSL template -// lib/TileOps/tci_template.py. -// -// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics -// -// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s - -// After the full tile-op-expand path on the VPTO backend, the original -// pto.tci should be lowered to vector-style VPTO IR. -// CHECK: func.func @TCI -// CHECK-NOT: pto.tci ins -// CHECK: pto.vecscope -// CHECK: pto.vci -// CHECK: pto.vsts - -module { - func.func @TCI() { - %c7_i32 = arith.constant 7 : i32 - %tile_buf = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%c7_i32 : i32) - outs(%tile_buf : !pto.tile_buf) - return - } -} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index ad03b47c5..db614307b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -114,7 +114,6 @@ endfunction() # -------------------------------------------------------------------------- set(ALL_TESTCASES tadd - tci tcvt tload ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tci/CMakeLists.txt deleted file mode 100644 index 6e998ed86..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tci/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -pto_tilelang_vec_st(tci) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tci/cases.py deleted file mode 100644 index 42e58e81c..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tci/cases.py +++ /dev/null @@ -1,121 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import numpy as np - -# Cases cover the 1D contiguous-integer sequence semantics of pto.tci. The -# shape choices exercise the corner cases of the template's column-chunked -# vector store: -# -# * tiny shapes (partial mask on the first chunk) -# * exact single-VReg multiples (mask covers the full lane count, no tail) -# * one-element tails just past a full VReg (mask of 1 lane on the tail) -# * multi-VReg shapes (two full chunks, no partial tail) -# -# For i32 the vector lane count is 64; for i16 it is 128. -CASES = [ - # ---- i32 (lanes = 64) ---- - { - "name": "i32_1x8", - "dtype": np.int32, - "shape": (1, 8), - "valid_shape": (1, 8), - "start": -5, - "eps": 0.0, - }, - { - "name": "i32_1x32", - "dtype": np.int32, - "shape": (1, 32), - "valid_shape": (1, 32), - "start": 3, - "eps": 0.0, - }, - { - "name": "i32_1x64", - "dtype": np.int32, - "shape": (1, 64), - "valid_shape": (1, 64), - "start": 100, - "eps": 0.0, - }, - { - "name": "i32_1x72", - "dtype": np.int32, - "shape": (1, 72), - "valid_shape": (1, 72), - "start": 0, - "eps": 0.0, - }, - { - "name": "i32_1x80", - "dtype": np.int32, - "shape": (1, 80), - "valid_shape": (1, 80), - "start": 17, - "eps": 0.0, - }, - { - "name": "i32_1x128", - "dtype": np.int32, - "shape": (1, 128), - "valid_shape": (1, 128), - "start": -1000, - "eps": 0.0, - }, - # ---- i16 (lanes = 128) ---- - { - "name": "i16_1x16", - "dtype": np.int16, - "shape": (1, 16), - "valid_shape": (1, 16), - "start": 1000, - "eps": 0.0, - }, - { - "name": "i16_1x64", - "dtype": np.int16, - "shape": (1, 64), - "valid_shape": (1, 64), - "start": 11, - "eps": 0.0, - }, - { - "name": "i16_1x128", - "dtype": np.int16, - "shape": (1, 128), - "valid_shape": (1, 128), - "start": -100, - "eps": 0.0, - }, - { - "name": "i16_1x144", - "dtype": np.int16, - "shape": (1, 144), - "valid_shape": (1, 144), - "start": 0, - "eps": 0.0, - }, - { - "name": "i16_1x160", - "dtype": np.int16, - "shape": (1, 160), - "valid_shape": (1, 160), - "start": -23, - "eps": 0.0, - }, - { - "name": "i16_1x256", - "dtype": np.int16, - "shape": (1, 256), - "valid_shape": (1, 256), - "start": 30000, - "eps": 0.0, - }, -] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tci/compare.py deleted file mode 100644 index 4ceccdde0..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tci/compare.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import os -import sys -import numpy as np - -from cases import CASES -from st_common import result_cmp, style_fail, style_pass, validate_cases - - -def main(): - validate_cases(CASES) - case_filter = sys.argv[1] if len(sys.argv) > 1 else None - - all_passed = True - for case in CASES: - if case_filter is not None and case["name"] != case_filter: - continue - - case_dir = case["name"] - shape = case["shape"] - vr, vc = case["valid_shape"] - - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) - output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) - - ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) - if ok: - print(style_pass(f"[INFO] {case['name']}: compare passed")) - else: - print(style_fail(f"[ERROR] {case['name']}: compare failed")) - all_passed = False - - if not all_passed: - sys.exit(2) - print(style_pass("[INFO] all cases passed")) - - -if __name__ == "__main__": - main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tci/gen_data.py deleted file mode 100644 index b3ad632c5..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tci/gen_data.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import numpy as np -from cases import CASES -from st_common import validate_cases, setup_case_rng, save_case_data - -validate_cases(CASES) - -for case in CASES: - setup_case_rng(case) - - dtype = case["dtype"] - shape = case["shape"] - valid_shape = case["valid_shape"] - start = np.array([case["start"]], dtype=dtype) - - golden = np.zeros(shape, dtype=dtype) - vr, vc = valid_shape - base = int(start[0]) - golden[:vr, :vc] = np.asarray( - [base + index for index in range(vr * vc)], - dtype=dtype, - ).reshape(valid_shape) - - save_case_data(case["name"], {"start": start, "golden": golden}) - print( - f"[INFO] gen_data: {case['name']} shape={shape} " - f"valid_shape={valid_shape} start={base} dtype={dtype.__name__}" - ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tci/launch.cpp deleted file mode 100644 index 29bfc48c7..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tci/launch.cpp +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include - -#ifndef AICORE -#define AICORE [aicore] -#endif - -extern "C" __global__ AICORE void TCI_i32_1x8(int32_t start, __gm__ int32_t *dst); -extern "C" __global__ AICORE void TCI_i32_1x32(int32_t start, __gm__ int32_t *dst); -extern "C" __global__ AICORE void TCI_i32_1x64(int32_t start, __gm__ int32_t *dst); -extern "C" __global__ AICORE void TCI_i32_1x72(int32_t start, __gm__ int32_t *dst); -extern "C" __global__ AICORE void TCI_i32_1x80(int32_t start, __gm__ int32_t *dst); -extern "C" __global__ AICORE void TCI_i32_1x128(int32_t start, __gm__ int32_t *dst); -extern "C" __global__ AICORE void TCI_i16_1x16(int16_t start, __gm__ int16_t *dst); -extern "C" __global__ AICORE void TCI_i16_1x64(int16_t start, __gm__ int16_t *dst); -extern "C" __global__ AICORE void TCI_i16_1x128(int16_t start, __gm__ int16_t *dst); -extern "C" __global__ AICORE void TCI_i16_1x144(int16_t start, __gm__ int16_t *dst); -extern "C" __global__ AICORE void TCI_i16_1x160(int16_t start, __gm__ int16_t *dst); -extern "C" __global__ AICORE void TCI_i16_1x256(int16_t start, __gm__ int16_t *dst); - -#define DEFINE_LAUNCH_I32(name) \ - void Launch##name(const void *start, void *dst, void *stream) { \ - const int32_t scalar = *reinterpret_cast(start); \ - name<<<1, nullptr, stream>>>(scalar, (__gm__ int32_t *)dst); \ - } - -#define DEFINE_LAUNCH_I16(name) \ - void Launch##name(const void *start, void *dst, void *stream) { \ - const int16_t scalar = *reinterpret_cast(start); \ - name<<<1, nullptr, stream>>>(scalar, (__gm__ int16_t *)dst); \ - } - -DEFINE_LAUNCH_I32(TCI_i32_1x8) -DEFINE_LAUNCH_I32(TCI_i32_1x32) -DEFINE_LAUNCH_I32(TCI_i32_1x64) -DEFINE_LAUNCH_I32(TCI_i32_1x72) -DEFINE_LAUNCH_I32(TCI_i32_1x80) -DEFINE_LAUNCH_I32(TCI_i32_1x128) - -DEFINE_LAUNCH_I16(TCI_i16_1x16) -DEFINE_LAUNCH_I16(TCI_i16_1x64) -DEFINE_LAUNCH_I16(TCI_i16_1x128) -DEFINE_LAUNCH_I16(TCI_i16_1x144) -DEFINE_LAUNCH_I16(TCI_i16_1x160) -DEFINE_LAUNCH_I16(TCI_i16_1x256) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tci/main.cpp deleted file mode 100644 index a13a898d9..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tci/main.cpp +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include "acl/acl.h" -#include "test_common.h" -#include -#include -#include -#include -#include -#include - -using namespace PtoTestCommon; - -void LaunchTCI_i32_1x8(const void *start, void *dst, void *stream); -void LaunchTCI_i32_1x32(const void *start, void *dst, void *stream); -void LaunchTCI_i32_1x64(const void *start, void *dst, void *stream); -void LaunchTCI_i32_1x72(const void *start, void *dst, void *stream); -void LaunchTCI_i32_1x80(const void *start, void *dst, void *stream); -void LaunchTCI_i32_1x128(const void *start, void *dst, void *stream); -void LaunchTCI_i16_1x16(const void *start, void *dst, void *stream); -void LaunchTCI_i16_1x64(const void *start, void *dst, void *stream); -void LaunchTCI_i16_1x128(const void *start, void *dst, void *stream); -void LaunchTCI_i16_1x144(const void *start, void *dst, void *stream); -void LaunchTCI_i16_1x160(const void *start, void *dst, void *stream); -void LaunchTCI_i16_1x256(const void *start, void *dst, void *stream); - -using LaunchFn = void (*)(const void *, void *, void *); - -struct TestCase { - const char *name; - LaunchFn launch; - size_t rows; - size_t cols; - size_t validRows; - size_t validCols; - size_t elemSize; - size_t scalarSize; -}; - -#define CASE_I32(cols) \ - {"i32_1x" #cols, LaunchTCI_i32_1x##cols, 1, (cols), 1, (cols), sizeof(int32_t), sizeof(int32_t)} -#define CASE_I16(cols) \ - {"i16_1x" #cols, LaunchTCI_i16_1x##cols, 1, (cols), 1, (cols), sizeof(int16_t), sizeof(int16_t)} - -static const TestCase kCases[] = { - CASE_I32(8), - CASE_I32(32), - CASE_I32(64), - CASE_I32(72), - CASE_I32(80), - CASE_I32(128), - CASE_I16(16), - CASE_I16(64), - CASE_I16(128), - CASE_I16(144), - CASE_I16(160), - CASE_I16(256), -}; -static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); - -static int RunCase(const TestCase &tc, aclrtStream stream) { - int rc = 0; - const size_t elemCount = tc.rows * tc.cols; - const size_t fileSize = elemCount * tc.elemSize; - - std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", - tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); - - std::string caseDir = std::string("./") + tc.name; - size_t startFileSize = tc.scalarSize; - std::vector startHost(tc.scalarSize); - void *dstHost = nullptr; - void *dstDevice = nullptr; - - aclrtMallocHost(&dstHost, fileSize); - aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - - if (!ReadFile((caseDir + "/start.bin").c_str(), startFileSize, startHost.data(), tc.scalarSize)) { - std::fprintf(stderr, "[ERROR] failed to read %s/start.bin\n", caseDir.c_str()); - rc = 1; - } - - if (rc == 0) { - tc.launch(startHost.data(), dstDevice, stream); - aclrtSynchronizeStream(stream); - aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); - } - - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { - std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); - rc = 1; - } - - if (dstDevice != nullptr) - aclrtFree(dstDevice); - if (dstHost != nullptr) - aclrtFreeHost(dstHost); - - if (rc == 0) - std::printf("[INFO] case %s done\n", tc.name); - return rc; -} - -int main(int argc, char *argv[]) { - const char *caseFilter = (argc > 1) ? argv[1] : nullptr; - - int rc = 0; - int deviceId = 0; - aclrtStream stream = nullptr; - - aclInit(nullptr); - if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { - deviceId = std::atoi(envDevice); - } - aclrtSetDevice(deviceId); - aclrtCreateStream(&stream); - - for (size_t i = 0; i < kNumCases; ++i) { - if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { - continue; - } - int ret = RunCase(kCases[i], stream); - if (ret != 0) { - std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); - rc = 1; - break; - } - } - - if (stream != nullptr) - aclrtDestroyStream(stream); - aclrtResetDevice(deviceId); - aclFinalize(); - return rc; -} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tci/tci.pto b/test/tilelang_st/npu/a5/src/st/testcase/tci/tci.pto deleted file mode 100644 index e5c02263e..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tci/tci.pto +++ /dev/null @@ -1,363 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// TileLang ST kernels for pto.tci: generate an integer sequence in UB and -// store it back to GM through the standard 5D tensor-view / partition-view -// path. Compiled by ptoas --enable-insert-sync --enable-tile-op-expand -// --vpto-emit-hivm-llvm to produce LLVM IR. -// -// Each kernel materializes a 1xC tile, runs pto.tci against the scalar -// seed, then streams the tile out to GM via pto.tstore. Cases exercise -// the vectorised template's corner cases (partial masks, exact VReg -// multiples, single-lane tails, etc.). - -module { - // ===================================================================== - // i32 cases (lane count = 64) - // ===================================================================== - - // Partial mask on the only chunk (8 active of 64 lanes). - func.func @TCI_i32_1x8(%start: i32, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c8], - strides = [%c8, %c8, %c8, %c8, %c1] - : !pto.tensor_view<1x1x1x1x8xi32> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c8] - : !pto.tensor_view<1x1x1x1x8xi32> -> !pto.partition_tensor_view<1x1x1x1x8xi32> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i32) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x8xi32>) - return - } - - // Partial mask on the only chunk (32 active of 64 lanes). - func.func @TCI_i32_1x32(%start: i32, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c32], - strides = [%c32, %c32, %c32, %c32, %c1] - : !pto.tensor_view<1x1x1x1x32xi32> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c32] - : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i32) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) - return - } - - // Exact single VReg: one full-mask chunk, no tail. - func.func @TCI_i32_1x64(%start: i32, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c64], - strides = [%c64, %c64, %c64, %c64, %c1] - : !pto.tensor_view<1x1x1x1x64xi32> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c64] - : !pto.tensor_view<1x1x1x1x64xi32> -> !pto.partition_tensor_view<1x1x1x1x64xi32> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i32) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x64xi32>) - return - } - - // One full VReg plus an 8-lane tail (tail mask = 8). - func.func @TCI_i32_1x72(%start: i32, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c72 = arith.constant 72 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c72], - strides = [%c72, %c72, %c72, %c72, %c1] - : !pto.tensor_view<1x1x1x1x72xi32> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c72] - : !pto.tensor_view<1x1x1x1x72xi32> -> !pto.partition_tensor_view<1x1x1x1x72xi32> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i32) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x72xi32>) - return - } - - // Two chunks: one full VReg + partial tail of 16 lanes. - func.func @TCI_i32_1x80(%start: i32, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c80 = arith.constant 80 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c80], - strides = [%c80, %c80, %c80, %c80, %c1] - : !pto.tensor_view<1x1x1x1x80xi32> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c80] - : !pto.tensor_view<1x1x1x1x80xi32> -> !pto.partition_tensor_view<1x1x1x1x80xi32> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i32) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x80xi32>) - return - } - - // Exact two VRegs: two full-mask chunks, no tail. - func.func @TCI_i32_1x128(%start: i32, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c128 = arith.constant 128 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c128], - strides = [%c128, %c128, %c128, %c128, %c1] - : !pto.tensor_view<1x1x1x1x128xi32> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c128] - : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i32) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) - return - } - - // ===================================================================== - // i16 cases (lane count = 128) - // ===================================================================== - - // Partial mask on the only chunk (16 active of 128 lanes). - func.func @TCI_i16_1x16(%start: i16, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c16], - strides = [%c16, %c16, %c16, %c16, %c1] - : !pto.tensor_view<1x1x1x1x16xi16> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c16] - : !pto.tensor_view<1x1x1x1x16xi16> -> !pto.partition_tensor_view<1x1x1x1x16xi16> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i16) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x16xi16>) - return - } - - // Partial mask on the only chunk (64 active of 128 lanes). - func.func @TCI_i16_1x64(%start: i16, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c64], - strides = [%c64, %c64, %c64, %c64, %c1] - : !pto.tensor_view<1x1x1x1x64xi16> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c64] - : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i16) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) - return - } - - // Exact single VReg: one full-mask chunk, no tail. - func.func @TCI_i16_1x128(%start: i16, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c128 = arith.constant 128 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c128], - strides = [%c128, %c128, %c128, %c128, %c1] - : !pto.tensor_view<1x1x1x1x128xi16> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c128] - : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i16) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) - return - } - - // One full VReg plus a 16-lane tail (tail mask = 16). - func.func @TCI_i16_1x144(%start: i16, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c144 = arith.constant 144 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c144], - strides = [%c144, %c144, %c144, %c144, %c1] - : !pto.tensor_view<1x1x1x1x144xi16> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c144] - : !pto.tensor_view<1x1x1x1x144xi16> -> !pto.partition_tensor_view<1x1x1x1x144xi16> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i16) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x144xi16>) - return - } - - // Two chunks: one full VReg + partial tail of 32 lanes. - func.func @TCI_i16_1x160(%start: i16, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c160 = arith.constant 160 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c160], - strides = [%c160, %c160, %c160, %c160, %c1] - : !pto.tensor_view<1x1x1x1x160xi16> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c160] - : !pto.tensor_view<1x1x1x1x160xi16> -> !pto.partition_tensor_view<1x1x1x1x160xi16> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i16) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x160xi16>) - return - } - - // Exact two VRegs: two full-mask chunks, no tail. - func.func @TCI_i16_1x256(%start: i16, %dst_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c256 = arith.constant 256 : index - - %dst_view = pto.make_tensor_view %dst_ptr, - shape = [%c1, %c1, %c1, %c1, %c256], - strides = [%c256, %c256, %c256, %c256, %c1] - : !pto.tensor_view<1x1x1x1x256xi16> - %dst_part = pto.partition_view %dst_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c256] - : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x256xi16> - - %tile = pto.alloc_tile - : !pto.tile_buf - - pto.tci ins(%start : i16) - outs(%tile : !pto.tile_buf) - pto.tstore ins(%tile : !pto.tile_buf) - outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xi16>) - return - } -} diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 3f8762980..a572325c6 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -428,11 +428,7 @@ class PositionMode(str, Enum): class OrderMode(str, Enum): - # The serialized value is emitted into the MLIR `order` attribute of - # `pto.vci` and must match the canonical spelling documented by the VPTO - # spec (see `docs/vpto_spec/vpto-spec-current.md`). Hand-written VPTO IR - # and `VPTOLLVMEmitter::parseOrderImmediate` also use this short form. - ASC = "ASC" + ASC = "ORDER_ASC" class VcvtRoundMode(str, Enum): @@ -674,6 +670,12 @@ def constexpr(value: bool) -> bool: return value +def get_op_attr(name: str, default: Any = None) -> Any: + if not isinstance(name, str) or not name: + raise TypeError("get_op_attr expects a non-empty string attribute name") + return default + + __all__ = [ "ScalarType", "WildcardType", @@ -733,6 +735,7 @@ def constexpr(value: bool) -> bool: "mask_b16", "mask_b32", "constexpr", + "get_op_attr", "bytewidth", "get_lanes", "elements_per_vreg", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index e993ef90f..59573d941 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -1700,7 +1700,7 @@ def kernel(inp: pto.TensorView, tile: pto.Tile): return None self.assertIn( - "`pto.vlds` does not support keyword arguments in TileLang DSL v1", + "unsupported keyword `offset` for `pto.vlds` in TileLang DSL v1", str(ctx.exception), ) self.assertIn(f"{__file__}:", str(ctx.exception)) @@ -3468,11 +3468,11 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): self.assertNotIn('position = "POS_LOWEST"', text) self.assertRegex( text, - r'pto\.vci\s+%[^\s]+\s+\{order = "ASC"\}\s+:', + r'pto\.vci\s+%[^\s]+\s+\{order = "ORDER_ASC"\}\s+:', ) self.assertNotRegex( text, - r'pto\.vci\s+%[^\s]+,\s*"ASC"\s+:', + r'pto\.vci\s+%[^\s]+,\s*"ORDER_ASC"\s+:', ) def test_vdup_scalar_input_rejects_position_argument(self) -> None: From 36e4d5517b3725aba64cafb03e804b322c12d04a Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 09:46:11 +0800 Subject: [PATCH 105/192] fix(tilelang-dsl): support i64 vregs in DSL v1 (#151) --- tilelang-dsl/python/tilelang_dsl/lowering.py | 2 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 4 +-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 38 ++++++++++++++++++++ 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 9c011a170..4eaf241af 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -1815,7 +1815,7 @@ def _materialize_tile_window_extent( def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: int_bits = integer_bitwidth(dtype) - if dtype.name == "f32" or int_bits == 32: + if dtype.name == "f32" or int_bits in {32, 64}: return "b32" if dtype.name in {"f16", "bf16"} or int_bits == 16: return "b16" diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 462c33db6..b56422d8b 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -5813,7 +5813,7 @@ def _normalize_vstsx2_dist(self, expr: SemanticExpr) -> SemanticExpr: def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: int_bits = integer_bitwidth(dtype) - if dtype.name == "f32" or int_bits == 32: + if dtype.name == "f32" or int_bits in {32, 64}: return "b32" if dtype.name in {"f16", "bf16"} or int_bits == 16: return "b16" @@ -5823,7 +5823,7 @@ def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: def _vreg_type_for_dtype(self, dtype: ScalarType) -> SemanticVRegType: width = bytewidth(dtype) - if width not in {1, 2, 4}: + if width not in {1, 2, 4, 8}: raise TypeError(f"dtype `{dtype.name}` is not supported by vlds/vsts in TileLang DSL v1") return SemanticVRegType(element_dtype=dtype, lanes=256 // width) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 59573d941..71b2ea37e 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -175,6 +175,7 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.bytewidth(pto.si16), 2) self.assertEqual(pto.bytewidth(pto.ui64), 8) self.assertEqual(pto.get_lanes(pto.ui32), 64) + self.assertEqual(pto.get_lanes(pto.i64), 32) self.assertEqual(pto.elements_per_vreg(pto.si8), 256) self.assertEqual(repr(pto.align), "align") @@ -2933,6 +2934,43 @@ def kernel(dst: pto.Tile, src: pto.Tile): r"= pto\.vcvt %[^,\s]+(?: \{[^}]+\})? : !pto\.vreg<[^>]+> -> !pto\.vreg<[^>]+>", ) + def test_vcvt_i32_to_i64_reuses_b32_mask_and_emits_i64_vreg(self) -> None: + @pto.vkernel( + op="vcvt_i32_to_i64_unique", + dtypes=[(pto.i64, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i64, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist="UNPK_B32") + out = pto.vcvt( + vec, + pto.i64, + src_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 32), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)) + store_stmt = next(stmt for stmt in vecscope.body if isinstance(stmt, SemanticVectorStoreStmt)) + self.assertIsInstance(store_stmt.mask.type, SemanticMaskType) + self.assertEqual(store_stmt.mask.type.granularity, "b32") + + text = specialized.mlir_text() + self.assertIn("!pto.mask", text) + self.assertIn('dist = "UNPK_B32"', text) + self.assertRegex(text, r"!pto\.vreg<32xi64>") + self.assertIn('part = "EVEN"', text) + self.assertIn("pto.vsts", text) + def test_vtrc_defaults_to_round_nearest(self) -> None: @pto.vkernel( op="vtrc_default_rnd_unique", From 5090f588f7c42f7ba06375b38f161193e077bb38 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Tue, 21 Apr 2026 01:46:14 +0800 Subject: [PATCH 106/192] bugfix: i8/i16 vcadd will widen the return type --- docs/isa/10-reduction-ops.md | 12 +++-- docs/vpto-spec.md | 6 +++ lib/PTO/IR/VPTO.cpp | 50 +++++++++++++++--- lib/PTO/Transforms/PTOToVPTOLowering.cpp | 42 ++++++++++++--- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 67 +++++++++++++++++++++++- 5 files changed, 159 insertions(+), 18 deletions(-) diff --git a/docs/isa/10-reduction-ops.md b/docs/isa/10-reduction-ops.md index dec8ccbab..b264d6386 100644 --- a/docs/isa/10-reduction-ops.md +++ b/docs/isa/10-reduction-ops.md @@ -19,8 +19,8 @@ Operations that reduce a vector to a scalar or per-group result. ### `pto.vcadd` -- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` -- **A5 types:** i16-i64, f16, f32 +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 - **semantics:** Sum all elements. Result in lane 0, others zeroed. ```c @@ -35,9 +35,11 @@ for (int i = 1; i < N; i++) - **inputs:** `%input` is the source vector and `%mask` selects participating lanes. - **outputs:** `%result` contains the reduction result in its low element(s). -- **constraints and limitations:** Some narrow integer forms may widen the - internal accumulation or result placement. If all predicate bits are zero, the - result is zero. +- **constraints and limitations:** On A5, `i8/u8` inputs produce widened + `i16/u16` results with half as many lanes (`M = N / 2`), and `i16/u16` inputs + produce widened `i32/u32` results with half as many lanes. For + `i32/u32/f16/f32` inputs, `U = T` and `M = N`. If all predicate bits are + zero, the result is zero. --- diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 1c8e27203..70bc42c55 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -852,6 +852,12 @@ for (int g = 0; g < 8; g++) { } ``` +For A5 reduction result types: + +- `pto.vcadd` widens `i8 -> i16`, `u8 -> u16`, `i16 -> i32`, and `u16 -> u32`, + with the lane count halved in each widening case. +- `pto.vcadd` keeps the same result type for `f16`, `f32`, `i32`, and `u32`. + ### Template Placeholder Conventions | Placeholder | Meaning | diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 90e8b384b..128de940c 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1148,13 +1148,51 @@ LogicalResult VbrOp::verify() { return success(); } -LogicalResult VcaddOp::verify() { - if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || - failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) +template +static LogicalResult verifyWideningReductionVecOp(ReductionOp op, + StringRef opName) { + if (failed(verifyVRegTypeLike(op, op.getInput().getType(), "input")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result"))) return failure(); - if (getInput().getType() != getResult().getType()) - return emitOpError("input and result must have the same vector type"); - return success(); + + auto inputType = dyn_cast(op.getInput().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!inputType || !resultType) + return failure(); + + Type inputElemType = inputType.getElementType(); + Type expectedResultElemType = inputElemType; + int64_t expectedResultLanes = inputType.getElementCount(); + if (auto inputInt = dyn_cast(inputElemType)) { + if (inputInt.getWidth() < 8 || inputInt.getWidth() > 32) + return op.emitOpError( + "requires 8-bit, 16-bit, or 32-bit integer vector element type"); + if (inputInt.getWidth() == 8) { + expectedResultElemType = + IntegerType::get(op.getContext(), 16, inputInt.getSignedness()); + expectedResultLanes = inputType.getElementCount() / 2; + } + if (inputInt.getWidth() == 16) { + expectedResultElemType = + IntegerType::get(op.getContext(), 32, inputInt.getSignedness()); + expectedResultLanes = inputType.getElementCount() / 2; + } + } else if (!inputElemType.isF16() && !inputElemType.isF32()) { + return op.emitOpError("requires i16/i32/f16/f32 vector element type"); + } + + if (resultType.getElementCount() == expectedResultLanes && + resultType.getElementType() == expectedResultElemType) + return success(); + + return op.emitOpError() << opName << " expects result type !pto.vreg<" + << expectedResultLanes << "x" + << expectedResultElemType + << " for input element type " << inputElemType; +} + +LogicalResult VcaddOp::verify() { + return verifyWideningReductionVecOp(*this, "vcadd"); } LogicalResult VcmaxOp::verify() { diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp index ff8df34f9..afdab6f52 100644 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -36,6 +36,28 @@ namespace { constexpr StringLiteral kLoweredLoopScopeAttrName = "llvm.loop.aivector_scope"; +static Type getVcaddResultElementType(MLIRContext *context, Type inputElementType) { + if (auto intType = dyn_cast(inputElementType)) { + if (intType.getWidth() == 8) + return IntegerType::get(context, 16, intType.getSignedness()); + if (intType.getWidth() == 16) + return IntegerType::get(context, 32, intType.getSignedness()); + } + return inputElementType; +} + +static pto::VRegType getVcaddResultVRegType(MLIRContext *context, + pto::VRegType inputType) { + int64_t resultLanes = inputType.getElementCount(); + if (auto intType = dyn_cast(inputType.getElementType())) { + if (intType.getWidth() == 8 || intType.getWidth() == 16) + resultLanes /= 2; + } + return pto::VRegType::get( + context, resultLanes, + getVcaddResultElementType(context, inputType.getElementType())); +} + struct ResolvedTensorView { Value root; Attribute layoutAttr; @@ -2414,7 +2436,9 @@ LogicalResult buildRowReduceVecScope(StringRef family, Value reduced; if (family == "rowsum") - reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + reduced = rewriter.create( + loc, getVcaddResultVRegType(rewriter.getContext(), vecType), srcVec, + srcPredicate); else if (family == "rowmax") reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); else if (family == "rowmin") @@ -2423,9 +2447,11 @@ LogicalResult buildRowReduceVecScope(StringRef family, return emitError(loc) << "unsupported VPTO row-reduce family: " << family; Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); - if (family == "rowsum") + if (family == "rowsum") { + if (reduced.getType() != vecType) + reduced = rewriter.create(loc, vecType, reduced); acc = rewriter.create(loc, vecType, acc, reduced, fullMask); - else if (family == "rowmax") + } else if (family == "rowmax") acc = rewriter.create(loc, vecType, acc, reduced, fullMask); else acc = rewriter.create(loc, vecType, acc, reduced, fullMask); @@ -2463,7 +2489,9 @@ LogicalResult buildRowReduceVecScope(StringRef family, Value reduced; if (family == "rowsum") - reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + reduced = rewriter.create( + loc, getVcaddResultVRegType(rewriter.getContext(), vecType), srcVec, + srcPredicate); else if (family == "rowmax") reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); else if (family == "rowmin") @@ -2472,9 +2500,11 @@ LogicalResult buildRowReduceVecScope(StringRef family, return emitError(loc) << "unsupported VPTO row-reduce family: " << family; Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); - if (family == "rowsum") + if (family == "rowsum") { + if (reduced.getType() != vecType) + reduced = rewriter.create(loc, vecType, reduced); acc = rewriter.create(loc, vecType, acc, reduced, fullMask); - else if (family == "rowmax") + } else if (family == "rowmax") acc = rewriter.create(loc, vecType, acc, reduced, fullMask); else acc = rewriter.create(loc, vecType, acc, reduced, fullMask); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 2505a4e48..948bb33f3 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -145,6 +145,22 @@ static FailureOr buildLaneTypedCallee(MLIRContext *context, .getValue(); } +static FailureOr buildLaneTypedCalleeFromInput(MLIRContext *context, + Type inputType, + StringRef stem, + StringRef suffix) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + auto lanes = getElementCountFromVectorLike(inputType); + if (vec.empty() || !lanes) + return failure(); + + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec + + suffix.str()) + .getValue(); +} + static std::string getElementTypeFragment(Type type) { if (type.isF16()) return "f16"; @@ -2262,6 +2278,55 @@ class LowerReductionUnaryOpPattern final LoweringState &state; }; +template +class LowerWideningReductionUnaryOpPattern final + : public OpConversionPattern { +public: + explicit LowerWideningReductionUnaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ReductionOp op, typename ReductionOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = buildLaneTypedCalleeFromInput( + op.getContext(), op.getInput().getType(), + getReductionUnaryStem(), ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported widening reduction VPTO signature"); + + Type inputType = + this->getTypeConverter()->convertType(op.getInput().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!inputType || !resultType || !maskType) + return rewriter.notifyMatchFailure(op, + "failed to convert widening reduction types"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != inputType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted widening reduction operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + class LowerVselOpPattern final : public OpConversionPattern { public: explicit LowerVselOpPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -4823,7 +4888,7 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerVecScalarMaskedOpPattern, LowerVecScalarMaskedOpPattern, LowerVecScalarMaskedOpPattern, - LowerReductionUnaryOpPattern, + LowerWideningReductionUnaryOpPattern, LowerReductionUnaryOpPattern, LowerReductionUnaryOpPattern, LowerReductionUnaryOpPattern, From 67a7dc4f86825431630e91b86159123e4fa2fa8e Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Tue, 21 Apr 2026 11:03:55 +0800 Subject: [PATCH 107/192] bugfix: i8/i16 vcadd res type --- tilelang-dsl/python/tilelang_dsl/semantic.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index b56422d8b..b20e4b720 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -4631,7 +4631,10 @@ def _analyze_unary_vector_op( vreg = self._require_vreg_expr(value, f"pto.{name} value") self._require_mask_for_vreg(mask, vreg, f"pto.{name}") self._validate_unary_dtype(name, vreg.element_dtype) - return SemanticCallExpr(namespace="pto", name=name, args=args, type=vreg) + result_type = vreg + if name == "vcadd": + result_type = self._vcadd_result_vreg_type(vreg) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=result_type) def _analyze_binary_vector_op( self, @@ -5243,6 +5246,20 @@ def _vreg_type_for_scalar_or_index(self, expr: SemanticExpr, context: str) -> Se return self._vreg_type_for_dtype(value.type.dtype) return self._vreg_type_for_dtype(i32) + def _vcadd_result_vreg_type(self, vreg_type: SemanticVRegType) -> SemanticVRegType: + dtype = vreg_type.element_dtype + if not is_integer_dtype(dtype): + return vreg_type + signedness = integer_signedness(dtype) + bitwidth = integer_bitwidth(dtype) + if bitwidth == 8: + widened_dtype = ui16 if signedness == "unsigned" else i16 + return self._vreg_type_for_dtype(widened_dtype) + if bitwidth == 16: + widened_dtype = ui32 if signedness == "unsigned" else i32 + return self._vreg_type_for_dtype(widened_dtype) + return vreg_type + def _normalize_position_mode( self, expr: SemanticExpr | None, From 696fd6cc5374a5819f481180e56aa479886fa273 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Tue, 21 Apr 2026 11:16:33 +0800 Subject: [PATCH 108/192] bugfix: update i8/i16 vcadd dsl docs --- .../user_guide/11-vector-arithmetic-operations.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index f935ca8ca..8ec00bef4 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -117,18 +117,24 @@ abs_vec = pto.vabs(vec_f32, mask32) #### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` -**Description**: Complex addition of vector elements (treating pairs as complex numbers). +**Description**: Reduction add of vector elements. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `vec` | `VRegType` | Input vector | | `mask` | `MaskType` | Predicate mask | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `VRegType` | Complex addition result | +| `result` | `VRegType` | Reduction result vector | + +**Type Rules**: +- For floating-point inputs and `i32/ui32`, the result vector type matches the input vector type. +- For `i8/ui8` inputs, `pto.vcadd` returns a widened `i16/ui16` vector. +- For `i16/ui16` inputs, `pto.vcadd` returns a widened `i32/ui32` vector. +- The result mask granularity follows the result vector element type. #### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` From 2d4647c41596438b7210c74e5c6d4b9391813153 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 10:40:58 +0800 Subject: [PATCH 109/192] fix(tilelang-dsl): align f16->i32 vcvt contract (#152) --- lib/PTO/IR/VPTO.cpp | 4 +- .../11-vector-arithmetic-operations.md | 16 +++ tilelang-dsl/python/tilelang_dsl/semantic.py | 2 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 119 ++++++++++++++++++ 4 files changed, 139 insertions(+), 2 deletions(-) diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 128de940c..bfcc4892d 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -640,9 +640,11 @@ static std::optional lookupVcvtContract(VcvtElemKind src, case VcvtElemKind::F16: switch (dst) { case VcvtElemKind::F32: - case VcvtElemKind::S32: return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/true}; case VcvtElemKind::S16: return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, /*requiresPart=*/false}; diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 8ec00bef4..47d0401a4 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1388,6 +1388,14 @@ family. A5 `vcvt` type matrix, width-changing packing rules, and attribute-sensitive forms, refer to [`../vpto_spec/vpto-spec-current.md`](../vpto_spec/vpto-spec-current.md). +- Attribute requirements are type-pair specific. The DSL enforces the same + per-form contract as VPTO, so some pairs require attributes while others + reject them. +- Examples: + `f32 -> si32` requires `rnd` and `sat`; + `f16 -> si32` requires `rnd` and `part`, and rejects `sat`; + `f16 -> f32` requires `part`; + `si32 -> f32` requires `rnd`. - VPTO does not define a `mask_b64` form. Conversions that produce `si64` results still use the typed mask granularity of the source vector family. - Width-changing conversions continue to follow VPTO packing semantics even on @@ -1403,6 +1411,14 @@ vec_f32 = pto.vcvt(vec_f16, pto.f32, mask16) mask32 = pto.make_mask(pto.f32, PAT.ALL) vec_i32 = pto.vcvt(vec_f32, pto.si32, mask32) +vec_i32_wide = pto.vcvt( + vec_f16, + pto.si32, + mask16, + rnd=pto.VcvtRoundMode.R, + part=pto.VcvtPartMode.EVEN, +) + vec_f16_narrow = pto.vcvt( vec_f32, pto.f16, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index b20e4b720..b6871fd18 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -154,7 +154,7 @@ ("f32", "s64"): (True, True, True), ("f32", "s32"): (True, True, False), ("f16", "f32"): (False, False, True), - ("f16", "s32"): (False, False, True), + ("f16", "s32"): (True, False, True), ("f16", "s16"): (True, True, False), ("f16", "s8"): (True, True, True), ("f16", "u8"): (True, True, True), diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 71b2ea37e..c1d057566 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3197,6 +3197,125 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn("does not accept `part=`", str(ctx.exception)) + def test_vcvt_f16_to_i32_requires_rnd_and_part(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_f16_to_i32_missing_rnd_unique", + dtypes=[(pto.i32, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist="UNPK_B16") + out = pto.vcvt( + vec, + pto.i32, + src_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `rnd=`", str(ctx.exception)) + + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_f16_to_i32_missing_part_unique", + dtypes=[(pto.i32, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist="UNPK_B16") + out = pto.vcvt( + vec, + pto.i32, + src_mask, + rnd=pto.VcvtRoundMode.R, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `part=`", str(ctx.exception)) + + def test_vcvt_f16_to_i32_rejects_sat(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_f16_to_i32_sat_unique", + dtypes=[(pto.i32, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist="UNPK_B16") + out = pto.vcvt( + vec, + pto.i32, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("does not accept `sat=`", str(ctx.exception)) + + def test_vcvt_f16_to_i32_accepts_rnd_and_part(self) -> None: + @pto.vkernel( + op="vcvt_f16_to_i32_attrs_unique", + dtypes=[(pto.i32, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist="UNPK_B16") + out = pto.vcvt( + vec, + pto.i32, + src_mask, + rnd=pto.VcvtRoundMode.R, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcvt", text) + self.assertIn('rnd = "R"', text) + self.assertIn('part = "EVEN"', text) + self.assertNotIn('sat = "SAT"', text) + def test_vbitcast_supports_direct_interface(self) -> None: @pto.vkernel( op="vbitcast_direct_unique", From d1890168a299fe7a68b304904e73044164d429fc Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 11:23:22 +0800 Subject: [PATCH 110/192] Support bf16->f16 convert --- docs/isa/09-conversion-ops.md | 3 +- lib/PTO/IR/VPTO.cpp | 3 + lib/PTO/Transforms/PTOToVPTOLowering.cpp | 16 ++- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 24 +++- .../11-vector-arithmetic-operations.md | 9 ++ .../docs/vpto_spec/vpto-spec-current.md | 3 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 1 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 119 ++++++++++++++++++ 8 files changed, 172 insertions(+), 6 deletions(-) diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md index bcad3f1af..f3953089b 100644 --- a/docs/isa/09-conversion-ops.md +++ b/docs/isa/09-conversion-ops.md @@ -136,6 +136,7 @@ as `32 -> 16` or `16 -> 32` style conversions. - `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` - `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xf16>` - `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` - `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` @@ -182,7 +183,7 @@ per-form entries above as the source of truth. | `si64` | | | | | | | | | | | | `f16` | Y | Y | | Y | | Y | | | Y | | | `f32` | | | | Y | | Y | Y | Y | | Y | -| `bf16` | | | | | | Y | | | Y | | +| `bf16` | | | | | | Y | | Y | Y | | --- diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index bfcc4892d..5c2bfd264 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -657,6 +657,9 @@ static std::optional lookupVcvtContract(VcvtElemKind src, } case VcvtElemKind::BF16: switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/false}; case VcvtElemKind::F32: return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, /*requiresPart=*/true}; diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp index afdab6f52..d8bf008a8 100644 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -1769,6 +1769,7 @@ enum class VPTOCvtLoweringKind { Vtrc, F32ToBF16, F16ToF32, + BF16ToF16, BF16ToF32, }; @@ -1780,6 +1781,8 @@ static FailureOr classifyA5CvtLowering(Type srcElemType, return VPTOCvtLoweringKind::F32ToBF16; if (srcElemType.isF16() && dstElemType.isF32()) return VPTOCvtLoweringKind::F16ToF32; + if (srcElemType.isBF16() && dstElemType.isF16()) + return VPTOCvtLoweringKind::BF16ToF16; if (srcElemType.isBF16() && dstElemType.isF32()) return VPTOCvtLoweringKind::BF16ToF32; return failure(); @@ -4786,7 +4789,7 @@ LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter) { classifyA5CvtLowering(contract.elementType, dstElementType); if (failed(loweringKind)) return op.emitOpError( - "current tcvt lowering supports only f32->f32, f32->bf16, f16->f32, and bf16->f32"); + "current tcvt lowering supports only f32->f32, f32->bf16, f16->f32, bf16->f16, and bf16->f32"); FailureOr roundMode = stringifyA5RoundMode(op, rewriter); if (failed(roundMode)) @@ -4891,6 +4894,17 @@ LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter) { buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); break; } + case VPTOCvtLoweringKind::BF16ToF16: { + auto loaded = + rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); + Value converted = rewriter.create( + op.getLoc(), dstVecType, loaded.getResult(), *roundMode, + rewriter.getStringAttr("RS_ENABLE"), StringAttr()); + rewriter.create( + op.getLoc(), converted, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } case VPTOCvtLoweringKind::BF16ToF32: { auto loaded = rewriter.create( op.getLoc(), srcVecType, srcBuffer, offset, rewriter.getStringAttr("UNPK_B16")); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 948bb33f3..6cc1c90c3 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -117,6 +117,7 @@ struct VcvtContract { bool requiresSat; bool requiresPart; unsigned maskBitWidth; + bool satBeforeRnd = false; }; static Value getI64Constant(OpBuilder &builder, Location loc, uint64_t value) { @@ -487,6 +488,9 @@ static std::optional lookupVcvtContract(VcvtElemKind src, } case VcvtElemKind::BF16: switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtff.bf162f16.x", true, true, false, 16, + true}; case VcvtElemKind::F32: return VcvtContract{"llvm.hivm.vcvtff.bf162f32.x", false, false, true, 16}; case VcvtElemKind::S32: @@ -4140,7 +4144,7 @@ class LowerVcvtOpPattern final : public OpConversionPattern { callArgs.push_back(*mask); argTypes.push_back((*mask).getType()); - if ((*contract).requiresRnd) { + auto appendRndArg = [&]() -> LogicalResult { auto roundMode = op.getRndAttr() ? parseRoundModeImmediate(*op.getRnd()) : std::nullopt; if (!roundMode) @@ -4148,9 +4152,10 @@ class LowerVcvtOpPattern final : public OpConversionPattern { Value roundValue = getI32Constant(rewriter, op.getLoc(), *roundMode); callArgs.push_back(roundValue); argTypes.push_back(roundValue.getType()); - } + return success(); + }; - if ((*contract).requiresSat) { + auto appendSatArg = [&]() -> LogicalResult { auto saturation = op.getSatAttr() ? parseSaturationImmediate(*op.getSat()) : std::nullopt; if (!saturation) @@ -4158,6 +4163,19 @@ class LowerVcvtOpPattern final : public OpConversionPattern { Value satValue = getI32Constant(rewriter, op.getLoc(), *saturation); callArgs.push_back(satValue); argTypes.push_back(satValue.getType()); + return success(); + }; + + if ((*contract).satBeforeRnd) { + if ((*contract).requiresSat && failed(appendSatArg())) + return failure(); + if ((*contract).requiresRnd && failed(appendRndArg())) + return failure(); + } else { + if ((*contract).requiresRnd && failed(appendRndArg())) + return failure(); + if ((*contract).requiresSat && failed(appendSatArg())) + return failure(); } if ((*contract).requiresPart) { diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 47d0401a4..dd60b3692 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1394,6 +1394,7 @@ family. - Examples: `f32 -> si32` requires `rnd` and `sat`; `f16 -> si32` requires `rnd` and `part`, and rejects `sat`; + `bf16 -> f16` requires `rnd` and `sat`; `f16 -> f32` requires `part`; `si32 -> f32` requires `rnd`. - VPTO does not define a `mask_b64` form. Conversions that produce `si64` @@ -1419,6 +1420,14 @@ vec_i32_wide = pto.vcvt( part=pto.VcvtPartMode.EVEN, ) +vec_f16_from_bf16 = pto.vcvt( + vec_bf16, + pto.f16, + mask16, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, +) + vec_f16_narrow = pto.vcvt( vec_f32, pto.f16, diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md index b92097043..8db80caef 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md @@ -4046,6 +4046,7 @@ as `32 -> 16` or `16 -> 32` style conversions. - `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` - `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xf16>` - `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` - `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` @@ -4092,7 +4093,7 @@ per-form entries above as the source of truth. | `si64` | | | | | | | | | | | | `f16` | Y | Y | | Y | | Y | | | Y | | | `f32` | | | | Y | | Y | Y | Y | | Y | -| `bf16` | | | | | | Y | | | Y | | +| `bf16` | | | | | | Y | | Y | Y | | --- diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index b6871fd18..3d1a941a7 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -158,6 +158,7 @@ ("f16", "s16"): (True, True, False), ("f16", "s8"): (True, True, True), ("f16", "u8"): (True, True, True), + ("bf16", "f16"): (True, True, False), ("bf16", "f32"): (False, False, True), ("bf16", "s32"): (True, True, True), ("u8", "f16"): (False, False, True), diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index c1d057566..c2ae22241 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3316,6 +3316,125 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn('part = "EVEN"', text) self.assertNotIn('sat = "SAT"', text) + def test_vcvt_bf16_to_f16_requires_rnd_and_sat(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_bf16_to_f16_missing_rnd_unique", + dtypes=[(pto.f16, pto.bf16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + sat=pto.VcvtSatMode.SAT, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `rnd=`", str(ctx.exception)) + + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_bf16_to_f16_missing_sat_unique", + dtypes=[(pto.f16, pto.bf16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd=pto.VcvtRoundMode.R, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `sat=`", str(ctx.exception)) + + def test_vcvt_bf16_to_f16_rejects_part(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_bf16_to_f16_part_unique", + dtypes=[(pto.f16, pto.bf16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("does not accept `part=`", str(ctx.exception)) + + def test_vcvt_bf16_to_f16_accepts_rnd_and_sat(self) -> None: + @pto.vkernel( + op="vcvt_bf16_to_f16_attrs_unique", + dtypes=[(pto.f16, pto.bf16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcvt", text) + self.assertIn('rnd = "R"', text) + self.assertIn('sat = "SAT"', text) + self.assertNotIn('part = "EVEN"', text) + def test_vbitcast_supports_direct_interface(self) -> None: @pto.vkernel( op="vbitcast_direct_unique", From 8caf58328b0ec6a9c757f3c45759d864cc3da68c Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 11:52:40 +0800 Subject: [PATCH 111/192] Support materialize PadValue with eval(dtype) interface --- .../docs/user_guide/05-type-system.md | 15 +++++++- .../docs/user_guide/08-sync-dma-operations.md | 10 ++++-- tilelang-dsl/python/tilelang_dsl/kernel.py | 4 +-- tilelang-dsl/python/tilelang_dsl/semantic.py | 32 ++++++++++++----- tilelang-dsl/python/tilelang_dsl/types.py | 6 ++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 34 +++++++++++++++---- 6 files changed, 78 insertions(+), 23 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 71c916498..eadd7d270 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -382,12 +382,19 @@ pad2 = pto.PadValue.custom_f32("0xBF800000") # float32 bit pattern for -1.0f ``` Notes: -- `PadValue.encoded` exposes the host-side uint64 payload. `PadValue.value` is intentionally unavailable to avoid confusion with kernel-side `.eval()`. +- `PadValue.encoded` exposes the host-side uint64 payload. `PadValue.value` is intentionally unavailable to avoid confusion with `.eval(...)` scalar materialization. - `PadValue.text` exposes the standard textual spelling for built-ins such as `null` and `zero`. - Custom pad values currently model an `f32` payload. In DSL v1, materializing a custom pad into a scalar is only supported for floating tile element dtypes. - `PadValue.NULL` does not denote a usable scalar fill constant. Calling `tile.pad_value.eval()` or `tile.config.pad_value.eval()` when the enum is `NULL` is a frontend error. - **DMA padding**: When performing GM→UB DMA transfers with padding enabled (via `enable_ub_pad=True` in `pto.copy_gm_to_ubuf`), the pad value must be configured explicitly using `pto.set_mov_pad_val`. Tile `PadValue` descriptors are not automatically translated to hardware register configurations in TileLang DSL v1. See [Pad Fill Semantics](08-sync-dma-operations.md#pad-fill-semantics) for usage details. +Host-side code can materialize a scalar with an explicit dtype: + +```python +pad_max_f32 = pto.PadValue.MAX.eval(pto.f32) +pad_min_i16 = pto.PadValue.MIN.eval(pto.i16) +``` + #### Tile Shape Concepts - `shape` is the static physical allocation size of the tile buffer. @@ -427,6 +434,12 @@ rank = tile.rank # 2 For dtype-dependent fill seeds, prefer `tile.pad_value.eval()` over handwritten `if dtype == ...` ladders. +For standalone `PadValue` symbols that are not bound to a tile, pass the target dtype explicitly: + +```python +pad_scalar = pto.PadValue.MAX.eval(pto.f32) +``` + ```python @pto.vkernel(op="fill_pad_value", dtypes=[(pto.AnyType,)]) def fill_pad_value(dst: pto.Tile): diff --git a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md index 023304715..e13516d0f 100644 --- a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md +++ b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md @@ -335,7 +335,7 @@ pto.set_mov_pad_val(pad_value: ScalarType) -> None **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `pad_value` | `ScalarType` | Scalar value used for padding. Supported types: `pto.i8`, `pto.i16`, `pto.i32`, `pto.f16`, `pto.bf16`, `pto.f32`. The value's bit pattern is encoded into the hardware pad register. For standard pad values, use `PadValue.eval()` to obtain the appropriate scalar: `0` or `0.0` for `PadValue.ZERO`, dtype-aware maximum for `PadValue.MAX`, dtype-aware minimum for `PadValue.MIN`. | +| `pad_value` | `ScalarType` | Scalar value used for padding. Supported types: `pto.i8`, `pto.i16`, `pto.i32`, `pto.f16`, `pto.bf16`, `pto.f32`. The value's bit pattern is encoded into the hardware pad register. For standard pad values, use `PadValue.eval(...)` to obtain the appropriate scalar: `0` or `0.0` for `PadValue.ZERO`, dtype-aware maximum for `PadValue.MAX`, dtype-aware minimum for `PadValue.MIN`. | **Returns**: None (side-effect operation) @@ -381,6 +381,12 @@ if pto.constexpr(pad_desc != pto.PadValue.NULL): ) ``` +Using a standalone `PadValue` with an explicit dtype: +```python +pad_scalar = pto.PadValue.MAX.eval(pto.f32) +pto.set_mov_pad_val(pto.f32(pad_scalar)) +``` + **Important**: You are responsible for ensuring the pad register is properly configured before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`. The pad register configuration persists until changed by another `pto.set_mov_pad_val` call. **Future Improvement**: Future versions of TileLang DSL may provide an implicit approach that automatically translates `PadValue` descriptors from tile configurations to hardware register configurations, similar to DMA syntax sugar features. @@ -611,4 +617,4 @@ pto.copy_ubuf_to_gm( gm_stride=128, ub_stride=128, ) -``` \ No newline at end of file +``` diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 7eba04f19..743188511 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -398,10 +398,10 @@ def visit_Call(self, node: ast.Call) -> None: node, "`eval` does not support keyword arguments in TileLang DSL v1", ) - if node.args: + if len(node.args) > 1: raise self.source_info.error( node, - "`eval()` does not accept positional arguments in TileLang DSL v1", + "`eval()` accepts at most one positional dtype argument in TileLang DSL v1", ) return diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 3d1a941a7..65b58129c 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3504,18 +3504,32 @@ def _tile_pad_value_expr(self, base: SemanticExpr) -> SemanticExpr: type=SemanticPadValueType(element_dtype=base_type.element_dtype), ) - def _pad_value_eval_expr(self, base: SemanticExpr) -> SemanticExpr: + def _pad_value_eval_expr( + self, + base: SemanticExpr, + dtype_expr: SemanticExpr | None = None, + ) -> SemanticExpr: if not isinstance(base.type, SemanticPadValueType): raise TypeError("`eval()` expects a PadValue descriptor in TileLang DSL v1") - if base.type.element_dtype is None: + element_dtype = base.type.element_dtype + if dtype_expr is not None: + if not ( + isinstance(dtype_expr, SemanticSymbolExpr) + and isinstance(dtype_expr.type, SemanticMetaType) + and dtype_expr.type.kind == "dtype" + and isinstance(dtype_expr.value, ScalarType) + ): + raise TypeError("PadValue.eval(dtype) expects a TileLang scalar dtype symbol in TileLang DSL v1") + element_dtype = dtype_expr.value + if element_dtype is None: raise TypeError( - "PadValue.eval() requires a Tile-bound or TileConfig-bound pad descriptor with an owning " - "Tile element dtype in TileLang DSL v1" + "PadValue.eval() requires either a Tile-bound pad descriptor or an explicit dtype argument " + "in TileLang DSL v1" ) pad_value = self._try_static_value(base) if not isinstance(pad_value, PadValue): raise TypeError("PadValue.eval() expects a statically known PadValue enum in TileLang DSL v1") - pad_scalar = pad_value.materialize_scalar(base.type.element_dtype) + pad_scalar = pad_value.eval(element_dtype) if pad_scalar is None: raise TypeError( "PadValue.NULL.eval() is invalid in TileLang DSL v1; " @@ -3523,7 +3537,7 @@ def _pad_value_eval_expr(self, base: SemanticExpr) -> SemanticExpr: ) return SemanticLiteralExpr( value=pad_scalar, - type=SemanticScalarType(dtype=base.type.element_dtype), + type=SemanticScalarType(dtype=element_dtype), ) def _analyze_eval_method( @@ -3531,9 +3545,9 @@ def _analyze_eval_method( base: SemanticExpr, args: tuple[SemanticExpr, ...], ) -> SemanticExpr: - if args: - raise TypeError("`eval()` does not accept positional arguments in TileLang DSL v1") - return self._pad_value_eval_expr(base) + if len(args) > 1: + raise TypeError("`eval()` accepts at most one positional dtype argument in TileLang DSL v1") + return self._pad_value_eval_expr(base, args[0] if args else None) def _tile_config_attr_expr(self, base: SemanticExpr, attr: str) -> SemanticExpr: config = self._try_static_value(base) diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index a572325c6..1dcf0e374 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -287,7 +287,7 @@ def name(self) -> str: def value(self) -> int: raise AttributeError( "PadValue.value is not available; use PadValue.encoded for host-side payload access " - "or pad.eval() for Tile-bound scalar materialization" + "or pad.eval(...) for scalar materialization" ) @property @@ -312,9 +312,9 @@ def float32_bits(self) -> int: def as_float32(self) -> float: return _float32_from_bits(self.float32_bits) - def materialize_scalar(self, dtype: ScalarType) -> int | float | None: + def eval(self, dtype: ScalarType) -> int | float | None: if not isinstance(dtype, ScalarType): - raise TypeError("PadValue.materialize_scalar expects a TileLang scalar dtype") + raise TypeError("PadValue.eval expects a TileLang scalar dtype") if self == PadValue.NULL: return None if self == PadValue.ZERO: diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index c2ae22241..cc8329529 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -205,12 +205,12 @@ def test_pad_value_supports_standard_and_custom_payloads(self) -> None: self.assertEqual(custom.float32_bits, 0xBF800000) self.assertEqual(custom.encoded, pto.PadValue.CustomBase | (0xBF800000 << 32)) self.assertAlmostEqual(custom.as_float32(), -1.0) - self.assertAlmostEqual(custom.materialize_scalar(pto.f32), -1.0) - self.assertEqual(pto.PadValue.MAX.materialize_scalar(pto.ui16), 0xFFFF) - self.assertEqual(pto.PadValue.MIN.materialize_scalar(pto.ui16), 0) - self.assertEqual(pto.PadValue.MAX.materialize_scalar(pto.i16), 0x7FFF) - self.assertEqual(pto.PadValue.MIN.materialize_scalar(pto.i16), -0x8000) - self.assertIsNone(pto.PadValue.NULL.materialize_scalar(pto.f16)) + self.assertAlmostEqual(custom.eval(pto.f32), -1.0) + self.assertEqual(pto.PadValue.MAX.eval(pto.ui16), 0xFFFF) + self.assertEqual(pto.PadValue.MIN.eval(pto.ui16), 0) + self.assertEqual(pto.PadValue.MAX.eval(pto.i16), 0x7FFF) + self.assertEqual(pto.PadValue.MIN.eval(pto.i16), -0x8000) + self.assertIsNone(pto.PadValue.NULL.eval(pto.f16)) with self.assertRaises(AttributeError): _ = pto.PadValue.ZERO.value @@ -2353,6 +2353,28 @@ def kernel(tile: pto.Tile): self.assertIn("PadValue.NULL.eval() is invalid", str(ctx.exception)) + def test_standalone_pad_value_eval_accepts_explicit_dtype(self) -> None: + @pto.vkernel(op="standalone_pad_value_eval_dtype", dtypes=[(pto.f32,)]) + def kernel(tile: pto.Tile): + scalar = pto.PadValue.MAX.eval(pto.f32) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + scalar_assign = semantic_kernel.body[0] + + self.assertIsInstance(scalar_assign, SemanticAssignStmt) + self.assertIsInstance(scalar_assign.value, SemanticLiteralExpr) + self.assertAlmostEqual(scalar_assign.value.value, pto.PadValue.MAX.eval(pto.f32)) + self.assertIsInstance(scalar_assign.targets[0].type, SemanticScalarType) + self.assertEqual(scalar_assign.targets[0].type.dtype, pto.f32) + def test_unsigned_integer_constants_lower_with_signless_arith_types(self) -> None: @pto.vkernel(op="tile_pad_value_ui32_max_eval_unique", dtypes=[(pto.ui32,)]) def kernel(tile: pto.Tile): From a500dc96f88a70e40a550fc4642f0e132ace71dc Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 14:59:39 +0800 Subject: [PATCH 112/192] fix(dsl): allow PadValue.eval with static dtype bindings --- tilelang-dsl/python/tilelang_dsl/semantic.py | 10 +++----- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 27 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 65b58129c..a88c88cd1 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3513,14 +3513,10 @@ def _pad_value_eval_expr( raise TypeError("`eval()` expects a PadValue descriptor in TileLang DSL v1") element_dtype = base.type.element_dtype if dtype_expr is not None: - if not ( - isinstance(dtype_expr, SemanticSymbolExpr) - and isinstance(dtype_expr.type, SemanticMetaType) - and dtype_expr.type.kind == "dtype" - and isinstance(dtype_expr.value, ScalarType) - ): + explicit_dtype = self._try_static_value(dtype_expr) + if not isinstance(explicit_dtype, ScalarType): raise TypeError("PadValue.eval(dtype) expects a TileLang scalar dtype symbol in TileLang DSL v1") - element_dtype = dtype_expr.value + element_dtype = explicit_dtype if element_dtype is None: raise TypeError( "PadValue.eval() requires either a Tile-bound pad descriptor or an explicit dtype argument " diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index cc8329529..87ac88f3e 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2375,6 +2375,33 @@ def kernel(tile: pto.Tile): self.assertIsInstance(scalar_assign.targets[0].type, SemanticScalarType) self.assertEqual(scalar_assign.targets[0].type.dtype, pto.f32) + def test_standalone_pad_value_eval_accepts_static_dtype_binding(self) -> None: + @pto.vkernel(op="standalone_pad_value_eval_dtype_binding", dtypes=[(pto.f32,)]) + def kernel(tile: pto.Tile): + dtype = tile.element_type + scalar = pto.PadValue.MAX.eval(dtype) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + dtype_assign, scalar_assign = semantic_kernel.body[:2] + + self.assertIsInstance(dtype_assign, SemanticAssignStmt) + self.assertIsInstance(dtype_assign.value, SemanticSymbolExpr) + self.assertEqual(dtype_assign.value.value, pto.f32) + + self.assertIsInstance(scalar_assign, SemanticAssignStmt) + self.assertIsInstance(scalar_assign.value, SemanticLiteralExpr) + self.assertAlmostEqual(scalar_assign.value.value, pto.PadValue.MAX.eval(pto.f32)) + self.assertIsInstance(scalar_assign.targets[0].type, SemanticScalarType) + self.assertEqual(scalar_assign.targets[0].type.dtype, pto.f32) + def test_unsigned_integer_constants_lower_with_signless_arith_types(self) -> None: @pto.vkernel(op="tile_pad_value_ui32_max_eval_unique", dtypes=[(pto.ui32,)]) def kernel(tile: pto.Tile): From 28fa87e6d35e7f51ba261f84633db76e95d7f274 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 20:36:45 +0800 Subject: [PATCH 113/192] fix(tilelang-dsl): auto-cast mov pad scalars for issue 170 (#170) --- .../docs/user_guide/08-sync-dma-operations.md | 6 ++-- tilelang-dsl/python/tilelang_dsl/lowering.py | 24 ++++++++++++++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 15 ++++++---- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 28 ++++++++++++++++++- 4 files changed, 65 insertions(+), 8 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md index e13516d0f..2ee6ec05a 100644 --- a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md +++ b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md @@ -335,7 +335,7 @@ pto.set_mov_pad_val(pad_value: ScalarType) -> None **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `pad_value` | `ScalarType` | Scalar value used for padding. Supported types: `pto.i8`, `pto.i16`, `pto.i32`, `pto.f16`, `pto.bf16`, `pto.f32`. The value's bit pattern is encoded into the hardware pad register. For standard pad values, use `PadValue.eval(...)` to obtain the appropriate scalar: `0` or `0.0` for `PadValue.ZERO`, dtype-aware maximum for `PadValue.MAX`, dtype-aware minimum for `PadValue.MIN`. | +| `pad_value` | `ScalarType` | Scalar value used for padding. Supported types: any 8/16/32-bit integer scalar (`pto.i8`, `pto.si8`, `pto.ui8`, `pto.i16`, `pto.si16`, `pto.ui16`, `pto.i32`, `pto.si32`, `pto.ui32`) plus `pto.f16`, `pto.bf16`, and `pto.f32`. The value's bit pattern is encoded into the hardware pad register. Integer inputs are automatically normalized to the corresponding signless hardware operand width during lowering, so no manual cast is required before calling `pto.set_mov_pad_val`. For standard pad values, use `PadValue.eval(...)` to obtain the appropriate scalar: `0` or `0.0` for `PadValue.ZERO`, dtype-aware maximum for `PadValue.MAX`, dtype-aware minimum for `PadValue.MIN`. | **Returns**: None (side-effect operation) @@ -384,9 +384,11 @@ if pto.constexpr(pad_desc != pto.PadValue.NULL): Using a standalone `PadValue` with an explicit dtype: ```python pad_scalar = pto.PadValue.MAX.eval(pto.f32) -pto.set_mov_pad_val(pto.f32(pad_scalar)) +pto.set_mov_pad_val(pad_scalar) ``` +For integer tile dtypes such as `pto.ui16` or `pto.si32`, `pad_desc.eval()` can be passed directly to `pto.set_mov_pad_val`. TileLang DSL v1 will automatically insert the required same-width bitcast to the signless hardware operand type during lowering. + **Important**: You are responsible for ensuring the pad register is properly configured before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`. The pad register configuration persists until changed by another `pto.set_mov_pad_val` call. **Future Improvement**: Future versions of TileLang DSL may provide an implicit approach that automatically translates `PadValue` descriptors from tile configurations to hardware register configurations, similar to DMA syntax sugar features. diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 4eaf241af..838ce2df8 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -91,6 +91,17 @@ _I64_TYPE = SemanticScalarType(dtype=ScalarType("i64")) +def _signless_mov_pad_scalar_type(dtype: ScalarType) -> SemanticScalarType | None: + bitwidth = integer_bitwidth(dtype) + if bitwidth == 8: + return SemanticScalarType(dtype=ScalarType("i8")) + if bitwidth == 16: + return SemanticScalarType(dtype=ScalarType("i16")) + if bitwidth == 32: + return SemanticScalarType(dtype=ScalarType("i32")) + return None + + def _format_symbol_name(symbol_name: str) -> str: if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$.]*", symbol_name): return f"@{symbol_name}" @@ -505,6 +516,19 @@ def _render_dma_unary_config( ) -> list[str]: lines: list[str] = [] value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + if ( + stmt.name == "set_mov_pad_val" + and isinstance(value.type, SemanticScalarType) + and is_integer_dtype(value.type.dtype) + ): + signless_type = _signless_mov_pad_scalar_type(value.type.dtype) + if signless_type is not None: + value = self._coerce_rendered_value( + value, + signless_type, + indent=indent, + into=lines, + ) lines.append( self._indent(indent) + f"pto.{stmt.name} {value.name} : {self._render_type(value.type)}" diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index a88c88cd1..bcd7a96a1 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -289,9 +289,14 @@ def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: "copy_ubuf_to_gm", "copy_ubuf_to_ubuf", } -_MOV_PAD_SUPPORTED_SCALAR_DTYPES = frozenset( - dtype.name for dtype in (i8, i16, i32, f16, bf16, f32) -) + + +def _is_supported_mov_pad_scalar_dtype(dtype: ScalarType) -> bool: + if is_integer_dtype(dtype): + return integer_bitwidth(dtype) in {8, 16, 32} + return dtype.name in {"f16", "bf16", "f32"} + + _COMPARE_SELECT_OPS = {"vcmp", "vcmps", "vsel", "vselr", "vselrv2"} _PREDICATE_MOVEMENT_OPS = { "pset_b8", @@ -2126,9 +2131,9 @@ def _analyze_low_level_dma_stmt( if len(args) != 1: raise TypeError(f"pto.{expr.name} expects exactly 1 positional argument in TileLang DSL") scalar = self._require_scalar_expr(args[0], f"pto.{expr.name} pad_value") - if scalar.dtype.name not in _MOV_PAD_SUPPORTED_SCALAR_DTYPES: + if not _is_supported_mov_pad_scalar_dtype(scalar.dtype): raise TypeError( - "pto.set_mov_pad_val pad_value must be one of i8, i16, i32, f16, bf16, or f32 in TileLang DSL v1" + "pto.set_mov_pad_val pad_value must be an 8/16/32-bit integer or f16/bf16/f32 in TileLang DSL v1" ) return ( SemanticDmaUnaryConfigStmt( diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 87ac88f3e..52e6abbae 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -4745,6 +4745,32 @@ def kernel(inp: pto.TensorView, dst: pto.Tile): r"pto\.copy_gm_to_ubuf %gm_ptr_\d+, %ub_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %true, %tmp_\d+, %tmp_\d+, %tmp_\d+", ) + def test_set_mov_pad_val_automatically_bitcasts_unsigned_tile_pad_value_to_signless_scalar(self) -> None: + @pto.vkernel(op="set_mov_pad_val_tile_pad_bitcast_unique", dtypes=[(pto.ui16,)], advanced=True) + def kernel(dst: pto.Tile): + pto.set_mov_pad_val(dst.pad_value.eval()) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization( + shape=(260, 32), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x2", + } + ), + valid_shape=(260, 7), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.bitcast", text) + self.assertRegex(text, r"pto\.set_mov_pad_val %[^ ]+ : i16") + def test_copy_ubuf_to_gm_keyword_surface_lowers_in_advanced_mode(self) -> None: @pto.vkernel(op="tile_to_tensorview_dma_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) def kernel(src: pto.Tile, dst: pto.TensorView): @@ -5945,7 +5971,7 @@ def kernel(dst: pto.Tile): ).mlir_text() self.assertIn( - "pto.set_mov_pad_val pad_value must be one of i8, i16, i32, f16, bf16, or f32", + "pto.set_mov_pad_val pad_value must be an 8/16/32-bit integer or f16/bf16/f32", str(ctx.exception), ) From 34af40aabb44e94bc4a9e9f438c51cdf8abf78ad Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Tue, 21 Apr 2026 21:07:51 +0800 Subject: [PATCH 114/192] fixup ci --- .../scripts/generate_release_vpto_spec.py | 8 ++++++++ include/PTO/IR/VPTOOps.td | 8 ++++++++ include/PTO/IR/VPTOTypeDefs.td | 8 ++++++++ include/PTO/Transforms/HIVMIntrinsicNaming.h | 8 ++++++++ include/PTO/Transforms/VPTOLLVMEmitter.h | 8 ++++++++ .../PTO/Transforms/VPTOLLVMEmitterHelper.h | 8 ++++++++ include/PTO/Transforms/VPTOLowering.h | 8 ++++++++ lib/Bindings/Python/CMakeLists.txt | 5 ++++- lib/PTO/IR/PTO.cpp | 20 ++++--------------- lib/PTO/IR/VPTO.cpp | 8 ++++++++ lib/PTO/Transforms/HIVMIntrinsicNaming.cpp | 8 ++++++++ lib/PTO/Transforms/PTOToVPTO.cpp | 8 ++++++++ lib/PTO/Transforms/PTOToVPTOLowering.cpp | 8 ++++++++ lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp | 10 +++++++++- lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp | 8 ++++++++ lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 8 ++++++++ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 8 ++++++++ lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp | 8 ++++++++ python/pto/dialects/pto.py | 17 ++++++++++++++++ scripts/batch_compile_output_cpp.sh | 8 ++++++++ scripts/compile_pto_to_vpto_llvm.sh | 8 ++++++++ scripts/ptoas_env.sh | 8 ++++++++ test/dsl/abs.py | 8 ++++++++ test/dsl/strict_vecscope.py | 8 ++++++++ test/dsl/template_abs.py | 8 ++++++++ test/lit.cfg.py | 8 ++++++++ .../kernels/online-softmax-update/compare.py | 8 ++++++++ .../kernels/online-softmax-update/golden.py | 8 ++++++++ .../kernels/online-softmax-update/launch.cpp | 8 ++++++++ .../kernels/online-softmax-update/main.cpp | 8 ++++++++ .../kernels/online-softmax-update/stub.cpp | 8 ++++++++ 31 files changed, 250 insertions(+), 18 deletions(-) diff --git a/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py b/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py index f03e7d592..be7de527d 100644 --- a/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py +++ b/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py @@ -1,4 +1,12 @@ #!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """Generate merged VPTO release spec from docs/vpto-spec.md and docs/isa/*.md.""" from __future__ import annotations diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 909786c6c..324c45c69 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- VPTOOps.td - PTO low-level operations ----------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/include/PTO/IR/VPTOTypeDefs.td b/include/PTO/IR/VPTOTypeDefs.td index 04e8ac583..ed62e7655 100644 --- a/include/PTO/IR/VPTOTypeDefs.td +++ b/include/PTO/IR/VPTOTypeDefs.td @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- VPTOTypeDefs.td ---------------------------------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/include/PTO/Transforms/HIVMIntrinsicNaming.h b/include/PTO/Transforms/HIVMIntrinsicNaming.h index 7ba956168..f2ea4c899 100644 --- a/include/PTO/Transforms/HIVMIntrinsicNaming.h +++ b/include/PTO/Transforms/HIVMIntrinsicNaming.h @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + #ifndef MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H #define MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H diff --git a/include/PTO/Transforms/VPTOLLVMEmitter.h b/include/PTO/Transforms/VPTOLLVMEmitter.h index dc56f64b2..625d5b2fa 100644 --- a/include/PTO/Transforms/VPTOLLVMEmitter.h +++ b/include/PTO/Transforms/VPTOLLVMEmitter.h @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + #ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H #define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H diff --git a/include/PTO/Transforms/VPTOLLVMEmitterHelper.h b/include/PTO/Transforms/VPTOLLVMEmitterHelper.h index 555bbe274..0db138273 100644 --- a/include/PTO/Transforms/VPTOLLVMEmitterHelper.h +++ b/include/PTO/Transforms/VPTOLLVMEmitterHelper.h @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + #ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H #define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H diff --git a/include/PTO/Transforms/VPTOLowering.h b/include/PTO/Transforms/VPTOLowering.h index 17730ab4e..2c80d2354 100644 --- a/include/PTO/Transforms/VPTOLowering.h +++ b/include/PTO/Transforms/VPTOLowering.h @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- VPTOLowering.h - PTO to VPTO lowering contracts ----------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index 6fe560a06..b7628dd86 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -62,6 +62,10 @@ if(APPLE) target_link_options(_pto PRIVATE "LINKER:-undefined,dynamic_lookup") endif() +if(NOT MLIR_PYTHON_PACKAGE_DIR) + message(FATAL_ERROR "MLIR_PYTHON_PACKAGE_DIR must be set when PTO_ENABLE_PYTHON_BINDING=ON") +endif() + install(TARGETS _pto LIBRARY DESTINATION "${MLIR_PYTHON_PACKAGE_DIR}/mlir/_mlir_libs" ) @@ -103,4 +107,3 @@ add_custom_command(TARGET _pto POST_BUILD "${CMAKE_BINARY_DIR}/python/mlir/_mlir_libs" VERBATIM ) - diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index a54733991..1e258f009 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -5941,23 +5941,11 @@ static void printBufSyncOp(OpAsmPrinter &p, Attribute opTypeAttr, p << " \"" << stringifyPIPE(pipeAttr.getPipe()) << "\", " << bufIdAttr.getInt() << ", " << modeAttr.getInt(); } else if (auto pipeEventType = dyn_cast(opTypeAttr)) { - auto pipe = mapSyncOpTypeToPipe(pipeEventType.getOpType()); - if (isConcreteSyncPipe(pipe)) { - p << " \"" << stringifyPIPE(pipe) << "\", " << bufIdAttr.getInt() - << ", " << modeAttr.getInt(); - } else { - p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " - << modeAttr.getInt() << "]"; - } + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; } else if (auto syncOpType = dyn_cast(opTypeAttr)) { - auto pipe = mapSyncOpTypeToPipe(syncOpType.getOpType()); - if (isConcreteSyncPipe(pipe)) { - p << " \"" << stringifyPIPE(pipe) << "\", " << bufIdAttr.getInt() - << ", " << modeAttr.getInt(); - } else { - p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " - << modeAttr.getInt() << "]"; - } + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; } else { p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " << modeAttr.getInt() << "]"; diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 5c2bfd264..c8f2ad7c3 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- VPTO.cpp - VPTO dialect -------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp index aae68d33b..d17b8ad38 100644 --- a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp +++ b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- HIVMIntrinsicNaming.cpp - HIVM intrinsic selection -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/lib/PTO/Transforms/PTOToVPTO.cpp b/lib/PTO/Transforms/PTOToVPTO.cpp index 1661e1fcf..c73774bf4 100644 --- a/lib/PTO/Transforms/PTOToVPTO.cpp +++ b/lib/PTO/Transforms/PTOToVPTO.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- PTOToVPTO.cpp - PTO to VPTO pass wiring ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp index d8bf008a8..23c673505 100644 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- PTOToVPTOLowering.cpp - PTO to VPTO lowering helpers --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp index 7d2cd0d4d..9b4711982 100644 --- a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp +++ b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" @@ -83,7 +91,7 @@ struct ExpandUvldPattern : public OpRewritePattern { Value align = rewriter.create(op.getLoc(), alignType, loadPtr); auto load = rewriter.create( - op.getLoc(), TypeRange{vecType, alignType, loadPtr.getType()}, + op.getLoc(), TypeRange{vecType, alignType}, ValueRange{loadPtr, align}); rewriter.replaceOp(op, load.getResult()); return success(); diff --git a/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp b/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp index 6aa62259f..f1e52424f 100644 --- a/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp +++ b/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + #include "PTO/IR/PTO.h" #include "PTO/Transforms/VPTOLowering.h" #include "PTO/Transforms/Passes.h" diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp index e5ba23c39..7f1a54a18 100644 --- a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- PTOValidateVPTOIR.cpp - Shared VPTO legality helpers --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 6cc1c90c3..0bd5ae62d 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + #include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/IR/PTO.h" diff --git a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp index 914ce0b53..7d0f3125c 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- VPTOLLVMEmitterHelper.cpp - VPTO LLVM emission helpers ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index b90174f9f..c6f5e6319 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -93,6 +93,23 @@ def get_op_result_or_value(value): QuantType = _pto_mod.QuantType QuantTypeAttr = _pto_mod.QuantTypeAttr + +_ptr_type_get_impl = PtrType.get + + +def _ptr_type_get_compat(cls, element_type, memory_space=None, context=None): + if isinstance(memory_space, _ods_ir.Context): + if context is not None: + raise TypeError("PtrType.get got multiple context arguments") + context = memory_space + memory_space = None + return _ptr_type_get_impl( + element_type, memory_space=memory_space, context=context + ) + + +PtrType.get = classmethod(_ptr_type_get_compat) + __all__ = [ # Dialect utilities "register_dialect", diff --git a/scripts/batch_compile_output_cpp.sh b/scripts/batch_compile_output_cpp.sh index 13426a9f6..8de0fb114 100755 --- a/scripts/batch_compile_output_cpp.sh +++ b/scripts/batch_compile_output_cpp.sh @@ -1,4 +1,12 @@ #!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + set -u diff --git a/scripts/compile_pto_to_vpto_llvm.sh b/scripts/compile_pto_to_vpto_llvm.sh index 2d15c86b6..4b6c066c9 100755 --- a/scripts/compile_pto_to_vpto_llvm.sh +++ b/scripts/compile_pto_to_vpto_llvm.sh @@ -1,4 +1,12 @@ #!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" diff --git a/scripts/ptoas_env.sh b/scripts/ptoas_env.sh index 95dcd9a8d..15ed7466a 100644 --- a/scripts/ptoas_env.sh +++ b/scripts/ptoas_env.sh @@ -1,4 +1,12 @@ #!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + # PTOAS runtime environment bootstrap. # Usage: # source scripts/ptoas_env.sh diff --git a/test/dsl/abs.py b/test/dsl/abs.py index 7c67e5959..1917a5725 100644 --- a/test/dsl/abs.py +++ b/test/dsl/abs.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + import mlir.dialects.pto as pto diff --git a/test/dsl/strict_vecscope.py b/test/dsl/strict_vecscope.py index e882df3d8..7badac2ae 100644 --- a/test/dsl/strict_vecscope.py +++ b/test/dsl/strict_vecscope.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + import mlir.dialects.pto as pto diff --git a/test/dsl/template_abs.py b/test/dsl/template_abs.py index 87b330e32..dc674adee 100644 --- a/test/dsl/template_abs.py +++ b/test/dsl/template_abs.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + import mlir.dialects.pto as pto diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 95e17569a..ab7a8e01c 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + import os import lit.formats diff --git a/test/vpto/cases/kernels/online-softmax-update/compare.py b/test/vpto/cases/kernels/online-softmax-update/compare.py index e6af92b4a..31a8aa717 100644 --- a/test/vpto/cases/kernels/online-softmax-update/compare.py +++ b/test/vpto/cases/kernels/online-softmax-update/compare.py @@ -1,4 +1,12 @@ #!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + # case: kernels/online-softmax-update # family: kernels # target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts diff --git a/test/vpto/cases/kernels/online-softmax-update/golden.py b/test/vpto/cases/kernels/online-softmax-update/golden.py index ea41425eb..497f7eed6 100644 --- a/test/vpto/cases/kernels/online-softmax-update/golden.py +++ b/test/vpto/cases/kernels/online-softmax-update/golden.py @@ -1,4 +1,12 @@ #!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + # case: kernels/online-softmax-update # family: kernels # target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts diff --git a/test/vpto/cases/kernels/online-softmax-update/launch.cpp b/test/vpto/cases/kernels/online-softmax-update/launch.cpp index 5cf6c4e2f..e50841764 100644 --- a/test/vpto/cases/kernels/online-softmax-update/launch.cpp +++ b/test/vpto/cases/kernels/online-softmax-update/launch.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + // ----------------------------------------------------------------------------- // case: kernels/online-softmax-update // family: kernels diff --git a/test/vpto/cases/kernels/online-softmax-update/main.cpp b/test/vpto/cases/kernels/online-softmax-update/main.cpp index 6282f13a8..af6cbb63b 100644 --- a/test/vpto/cases/kernels/online-softmax-update/main.cpp +++ b/test/vpto/cases/kernels/online-softmax-update/main.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + // ----------------------------------------------------------------------------- // case: kernels/online-softmax-update // family: kernels diff --git a/test/vpto/cases/kernels/online-softmax-update/stub.cpp b/test/vpto/cases/kernels/online-softmax-update/stub.cpp index 003519801..389a74d5f 100644 --- a/test/vpto/cases/kernels/online-softmax-update/stub.cpp +++ b/test/vpto/cases/kernels/online-softmax-update/stub.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + // ----------------------------------------------------------------------------- // case: kernels/online-softmax-update // family: kernels From 3112fd89339ef4162f9c84df59a2557ce7d24f7c Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 22:51:50 +0800 Subject: [PATCH 115/192] fix(dsl): support integer string scalar literals (#174) --- tilelang-dsl/python/tilelang_dsl/semantic.py | 63 +++++++++++++++----- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 22 +++++++ 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index bcd7a96a1..596c0a2f7 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -4147,6 +4147,17 @@ def _analyze_scalar_constructor( value=parsed, type=SemanticScalarType(dtype=target_dtype), ) + if ( + is_integer_dtype(target_dtype) + and isinstance(args[0], SemanticLiteralExpr) + and isinstance(args[0].type, SemanticMetaType) + and args[0].type.kind == "string" + ): + parsed = self._parse_integer_literal_string(args[0].value, target_dtype, f"pto.{name} value") + return SemanticLiteralExpr( + value=parsed, + type=SemanticScalarType(dtype=target_dtype), + ) value = self._require_scalar_or_index_expr(args[0], f"pto.{name} value") @@ -4170,20 +4181,8 @@ def _analyze_scalar_constructor( else: casted = None if casted is not None: - bits = integer_bitwidth(target_dtype) - signedness = integer_signedness(target_dtype) - assert bits is not None - if signedness == "unsigned": - min_value = 0 - max_value = (1 << bits) - 1 - else: - min_value = -(1 << (bits - 1)) - max_value = (1 << (bits - 1)) - 1 - if casted < min_value or casted > max_value: - raise TypeError( - f"pto.{name} value {casted} is out of range for {target_dtype.name} in TileLang DSL v1" - ) - return SemanticLiteralExpr(value=casted, type=SemanticScalarType(dtype=target_dtype)) + checked = self._check_integer_literal_range(casted, target_dtype, f"pto.{name} value") + return SemanticLiteralExpr(value=checked, type=SemanticScalarType(dtype=target_dtype)) else: if isinstance(literal_value, (bool, int, float)): return SemanticLiteralExpr( @@ -4228,6 +4227,42 @@ def _parse_float_literal_string( f"{context} string literal {literal!r} is not a valid float literal" ) from exc + def _parse_integer_literal_string( + self, + literal: str, + target_dtype: ScalarType, + context: str, + ) -> int: + text = literal.strip().lower() + try: + parsed = int(text, 0) + except ValueError as exc: + raise TypeError( + f"{context} string literal {literal!r} is not a valid integer literal" + ) from exc + return self._check_integer_literal_range(parsed, target_dtype, context) + + def _check_integer_literal_range( + self, + value: int, + target_dtype: ScalarType, + context: str, + ) -> int: + bits = integer_bitwidth(target_dtype) + signedness = integer_signedness(target_dtype) + assert bits is not None + if signedness == "unsigned": + min_value = 0 + max_value = (1 << bits) - 1 + else: + min_value = -(1 << (bits - 1)) + max_value = (1 << (bits - 1)) - 1 + if value < min_value or value > max_value: + raise TypeError( + f"{context} {value} is out of range for {target_dtype.name} in TileLang DSL v1" + ) + return value + def _float_from_bit_pattern( self, bit_pattern: int, diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 52e6abbae..3b51a8b88 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -4007,6 +4007,17 @@ def kernel(inp: pto.TensorView): self.assertIn("pto.i32 value must be a scalar or index value", str(ctx.exception)) + def test_scalar_constructor_accepts_integer_string_literals(self) -> None: + @pto.vkernel(op="scalar_constructor_integer_string_literals_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i16("0x7FFF") + y = pto.i32("0x7FFFFFFF") + return None + + text = kernel.mlir_text() + self.assertIn("= arith.constant 32767 : i16", text) + self.assertIn("= arith.constant 2147483647 : i32", text) + def test_scalar_constructor_rejects_out_of_range_integer_literal(self) -> None: @pto.vkernel(op="scalar_constructor_oob_int_unique", dtypes=[(pto.f32,)]) def kernel(inp: pto.TensorView): @@ -4018,6 +4029,17 @@ def kernel(inp: pto.TensorView): self.assertIn("out of range for i8", str(ctx.exception)) + def test_scalar_constructor_rejects_out_of_range_integer_string_literal(self) -> None: + @pto.vkernel(op="scalar_constructor_oob_integer_string_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i16("0x8000") + return None + + with self.assertRaises(TypeError) as ctx: + kernel.mlir_text() + + self.assertIn("out of range for i16", str(ctx.exception)) + def test_inferred_vecscope_propagates_bindings_to_constexpr_if(self) -> None: @pto.vkernel( op="inferred_vecscope_binding_propagation_unique", From d482a9af4a3cee52be063518aa3243994f1ff082 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Tue, 21 Apr 2026 23:40:50 +0800 Subject: [PATCH 116/192] feat: vpto ci --- .github/workflows/ci.yml | 175 ++++++ .../binary-vector/vadd-bf16/compare.py | 42 ++ .../binary-vector/vadd-bf16/golden.py | 61 +++ .../binary-vector/vadd-bf16/kernel.pto | 50 ++ .../binary-vector/vadd-bf16/launch.cpp | 53 ++ .../micro-op/binary-vector/vadd-bf16/main.cpp | 102 ++++ .../micro-op/binary-vector/vadd-bf16/stub.cpp | 30 ++ .../binary-vector/vadd-f16/compare.py | 42 ++ .../micro-op/binary-vector/vadd-f16/golden.py | 51 ++ .../binary-vector/vadd-f16/kernel.pto | 50 ++ .../binary-vector/vadd-f16/launch.cpp | 52 ++ .../micro-op/binary-vector/vadd-f16/main.cpp | 101 ++++ .../micro-op/binary-vector/vadd-f16/stub.cpp | 30 ++ .../vadd-f32-exceptional/compare.py | 37 ++ .../vadd-f32-exceptional/golden.py | 54 ++ .../vadd-f32-exceptional/kernel.pto | 46 ++ .../vadd-f32-exceptional/launch.cpp | 46 ++ .../vadd-f32-exceptional/main.cpp | 95 ++++ .../vadd-f32-exceptional/stub.cpp | 24 + .../vadd-i16-signed-overflow/compare.py | 42 ++ .../vadd-i16-signed-overflow/golden.py | 62 +++ .../vadd-i16-signed-overflow/kernel.pto | 53 ++ .../vadd-i16-signed-overflow/launch.cpp | 47 ++ .../vadd-i16-signed-overflow/main.cpp | 92 ++++ .../vadd-i16-signed-overflow/stub.cpp | 25 + .../binary-vector/vadd-i16-signed/compare.py | 42 ++ .../binary-vector/vadd-i16-signed/golden.py | 48 ++ .../binary-vector/vadd-i16-signed/kernel.pto | 50 ++ .../binary-vector/vadd-i16-signed/launch.cpp | 53 ++ .../binary-vector/vadd-i16-signed/main.cpp | 102 ++++ .../binary-vector/vadd-i16-signed/stub.cpp | 30 ++ .../vadd-i16-unsigned-overflow/compare.py | 42 ++ .../vadd-i16-unsigned-overflow/golden.py | 62 +++ .../vadd-i16-unsigned-overflow/kernel.pto | 53 ++ .../vadd-i16-unsigned-overflow/launch.cpp | 47 ++ .../vadd-i16-unsigned-overflow/main.cpp | 92 ++++ .../vadd-i16-unsigned-overflow/stub.cpp | 25 + .../vadd-i16-unsigned/compare.py | 42 ++ .../binary-vector/vadd-i16-unsigned/golden.py | 48 ++ .../vadd-i16-unsigned/kernel.pto | 50 ++ .../vadd-i16-unsigned/launch.cpp | 53 ++ .../binary-vector/vadd-i16-unsigned/main.cpp | 102 ++++ .../binary-vector/vadd-i16-unsigned/stub.cpp | 30 ++ .../binary-vector/vadd-tail/compare.py | 39 ++ .../binary-vector/vadd-tail/golden.py | 51 ++ .../binary-vector/vadd-tail/kernel.pto | 46 ++ .../binary-vector/vadd-tail/launch.cpp | 46 ++ .../micro-op/binary-vector/vadd-tail/main.cpp | 94 ++++ .../micro-op/binary-vector/vadd-tail/stub.cpp | 24 + .../micro-op/binary-vector/vadd/compare.py | 37 ++ .../micro-op/binary-vector/vadd/golden.py | 46 ++ .../micro-op/binary-vector/vadd/kernel.pto | 55 ++ .../micro-op/binary-vector/vadd/launch.cpp | 46 ++ .../micro-op/binary-vector/vadd/main.cpp | 94 ++++ .../micro-op/binary-vector/vadd/stub.cpp | 24 + .../vaddc-carry-boundary/compare.py | 64 +++ .../vaddc-carry-boundary/golden.py | 69 +++ .../vaddc-carry-boundary/kernel.pto | 53 ++ .../vaddc-carry-boundary/launch.cpp | 56 ++ .../vaddc-carry-boundary/main.cpp | 119 +++++ .../vaddc-carry-boundary/stub.cpp | 33 ++ .../micro-op/binary-vector/vaddc/compare.py | 64 +++ .../micro-op/binary-vector/vaddc/golden.py | 64 +++ .../micro-op/binary-vector/vaddc/kernel.pto | 53 ++ .../micro-op/binary-vector/vaddc/launch.cpp | 57 ++ .../micro-op/binary-vector/vaddc/main.cpp | 117 +++++ .../micro-op/binary-vector/vaddc/stub.cpp | 34 ++ .../binary-vector/vand-mask-edge/compare.py | 44 ++ .../binary-vector/vand-mask-edge/golden.py | 50 ++ .../binary-vector/vand-mask-edge/kernel.pto | 52 ++ .../binary-vector/vand-mask-edge/launch.cpp | 55 ++ .../binary-vector/vand-mask-edge/main.cpp | 103 ++++ .../binary-vector/vand-mask-edge/stub.cpp | 32 ++ .../micro-op/binary-vector/vand/compare.py | 42 ++ .../micro-op/binary-vector/vand/golden.py | 49 ++ .../micro-op/binary-vector/vand/kernel.pto | 50 ++ .../micro-op/binary-vector/vand/launch.cpp | 53 ++ .../micro-op/binary-vector/vand/main.cpp | 101 ++++ .../micro-op/binary-vector/vand/stub.cpp | 31 ++ .../binary-vector/vdiv-f16/compare.py | 42 ++ .../micro-op/binary-vector/vdiv-f16/golden.py | 56 ++ .../binary-vector/vdiv-f16/kernel.pto | 52 ++ .../binary-vector/vdiv-f16/launch.cpp | 52 ++ .../micro-op/binary-vector/vdiv-f16/main.cpp | 101 ++++ .../micro-op/binary-vector/vdiv-f16/stub.cpp | 30 ++ .../vdiv-f32-exceptional/compare.py | 37 ++ .../vdiv-f32-exceptional/golden.py | 54 ++ .../vdiv-f32-exceptional/kernel.pto | 46 ++ .../vdiv-f32-exceptional/launch.cpp | 46 ++ .../vdiv-f32-exceptional/main.cpp | 95 ++++ .../vdiv-f32-exceptional/stub.cpp | 24 + .../binary-vector/vdiv-tail/compare.py | 39 ++ .../binary-vector/vdiv-tail/golden.py | 51 ++ .../binary-vector/vdiv-tail/kernel.pto | 46 ++ .../binary-vector/vdiv-tail/launch.cpp | 46 ++ .../micro-op/binary-vector/vdiv-tail/main.cpp | 94 ++++ .../micro-op/binary-vector/vdiv-tail/stub.cpp | 24 + .../micro-op/binary-vector/vdiv/compare.py | 37 ++ .../micro-op/binary-vector/vdiv/golden.py | 49 ++ .../micro-op/binary-vector/vdiv/kernel.pto | 48 ++ .../micro-op/binary-vector/vdiv/launch.cpp | 46 ++ .../micro-op/binary-vector/vdiv/main.cpp | 94 ++++ .../micro-op/binary-vector/vdiv/stub.cpp | 24 + .../binary-vector/vmax-tail/compare.py | 39 ++ .../binary-vector/vmax-tail/golden.py | 51 ++ .../binary-vector/vmax-tail/kernel.pto | 46 ++ .../binary-vector/vmax-tail/launch.cpp | 46 ++ .../micro-op/binary-vector/vmax-tail/main.cpp | 94 ++++ .../micro-op/binary-vector/vmax-tail/stub.cpp | 24 + .../micro-op/binary-vector/vmax/compare.py | 37 ++ .../micro-op/binary-vector/vmax/golden.py | 46 ++ .../micro-op/binary-vector/vmax/kernel.pto | 48 ++ .../micro-op/binary-vector/vmax/launch.cpp | 46 ++ .../micro-op/binary-vector/vmax/main.cpp | 94 ++++ .../micro-op/binary-vector/vmax/stub.cpp | 24 + .../binary-vector/vmin-bf16/compare.py | 42 ++ .../binary-vector/vmin-bf16/golden.py | 61 +++ .../binary-vector/vmin-bf16/kernel.pto | 50 ++ .../binary-vector/vmin-bf16/launch.cpp | 52 ++ .../micro-op/binary-vector/vmin-bf16/main.cpp | 102 ++++ .../micro-op/binary-vector/vmin-bf16/stub.cpp | 31 ++ .../binary-vector/vmin-f16/compare.py | 42 ++ .../micro-op/binary-vector/vmin-f16/golden.py | 48 ++ .../binary-vector/vmin-f16/kernel.pto | 50 ++ .../binary-vector/vmin-f16/launch.cpp | 52 ++ .../micro-op/binary-vector/vmin-f16/main.cpp | 102 ++++ .../micro-op/binary-vector/vmin-f16/stub.cpp | 31 ++ .../vmin-f32-exceptional/compare.py | 37 ++ .../vmin-f32-exceptional/golden.py | 54 ++ .../vmin-f32-exceptional/kernel.pto | 48 ++ .../vmin-f32-exceptional/launch.cpp | 46 ++ .../vmin-f32-exceptional/main.cpp | 94 ++++ .../vmin-f32-exceptional/stub.cpp | 24 + .../vmin-f32-exceptional/vmin/compare.py | 37 ++ .../vmin-f32-exceptional/vmin/golden.py | 46 ++ .../vmin-f32-exceptional/vmin/kernel.pto | 48 ++ .../vmin-f32-exceptional/vmin/launch.cpp | 46 ++ .../vmin-f32-exceptional/vmin/main.cpp | 94 ++++ .../vmin-f32-exceptional/vmin/stub.cpp | 24 + .../binary-vector/vmin-i16-signed/compare.py | 42 ++ .../binary-vector/vmin-i16-signed/golden.py | 49 ++ .../binary-vector/vmin-i16-signed/kernel.pto | 50 ++ .../binary-vector/vmin-i16-signed/launch.cpp | 53 ++ .../binary-vector/vmin-i16-signed/main.cpp | 101 ++++ .../binary-vector/vmin-i16-signed/stub.cpp | 30 ++ .../vmin-i16-unsigned/compare.py | 42 ++ .../binary-vector/vmin-i16-unsigned/golden.py | 49 ++ .../vmin-i16-unsigned/kernel.pto | 50 ++ .../vmin-i16-unsigned/launch.cpp | 53 ++ .../binary-vector/vmin-i16-unsigned/main.cpp | 101 ++++ .../binary-vector/vmin-i16-unsigned/stub.cpp | 30 ++ .../binary-vector/vmin-tail/compare.py | 39 ++ .../binary-vector/vmin-tail/golden.py | 51 ++ .../binary-vector/vmin-tail/kernel.pto | 46 ++ .../binary-vector/vmin-tail/launch.cpp | 46 ++ .../micro-op/binary-vector/vmin-tail/main.cpp | 94 ++++ .../micro-op/binary-vector/vmin-tail/stub.cpp | 24 + .../micro-op/binary-vector/vmin/compare.py | 37 ++ .../micro-op/binary-vector/vmin/golden.py | 46 ++ .../micro-op/binary-vector/vmin/kernel.pto | 48 ++ .../micro-op/binary-vector/vmin/launch.cpp | 46 ++ .../micro-op/binary-vector/vmin/main.cpp | 94 ++++ .../micro-op/binary-vector/vmin/stub.cpp | 24 + .../binary-vector/vmul-tail/compare.py | 39 ++ .../binary-vector/vmul-tail/golden.py | 51 ++ .../binary-vector/vmul-tail/kernel.pto | 46 ++ .../binary-vector/vmul-tail/launch.cpp | 46 ++ .../micro-op/binary-vector/vmul-tail/main.cpp | 94 ++++ .../micro-op/binary-vector/vmul-tail/stub.cpp | 24 + .../micro-op/binary-vector/vmul/compare.py | 37 ++ .../micro-op/binary-vector/vmul/golden.py | 46 ++ .../micro-op/binary-vector/vmul/kernel.pto | 48 ++ .../micro-op/binary-vector/vmul/launch.cpp | 46 ++ .../micro-op/binary-vector/vmul/main.cpp | 94 ++++ .../micro-op/binary-vector/vmul/stub.cpp | 24 + .../micro-op/binary-vector/vor-f16/compare.py | 44 ++ .../micro-op/binary-vector/vor-f16/golden.py | 60 +++ .../micro-op/binary-vector/vor-f16/kernel.pto | 52 ++ .../micro-op/binary-vector/vor-f16/launch.cpp | 54 ++ .../micro-op/binary-vector/vor-f16/main.cpp | 103 ++++ .../micro-op/binary-vector/vor-f16/stub.cpp | 33 ++ .../binary-vector/vor-mask-edge/compare.py | 44 ++ .../binary-vector/vor-mask-edge/golden.py | 50 ++ .../binary-vector/vor-mask-edge/kernel.pto | 52 ++ .../binary-vector/vor-mask-edge/launch.cpp | 55 ++ .../binary-vector/vor-mask-edge/main.cpp | 103 ++++ .../binary-vector/vor-mask-edge/stub.cpp | 32 ++ .../micro-op/binary-vector/vor/compare.py | 44 ++ .../micro-op/binary-vector/vor/golden.py | 50 ++ .../micro-op/binary-vector/vor/kernel.pto | 52 ++ .../micro-op/binary-vector/vor/launch.cpp | 55 ++ .../cases/micro-op/binary-vector/vor/main.cpp | 103 ++++ .../cases/micro-op/binary-vector/vor/stub.cpp | 33 ++ .../vshl-i32-unsigned/compare.py | 44 ++ .../binary-vector/vshl-i32-unsigned/golden.py | 50 ++ .../vshl-i32-unsigned/kernel.pto | 54 ++ .../vshl-i32-unsigned/launch.cpp | 55 ++ .../binary-vector/vshl-i32-unsigned/main.cpp | 103 ++++ .../binary-vector/vshl-i32-unsigned/stub.cpp | 32 ++ .../vshl-shift-boundary/compare.py | 44 ++ .../vshl-shift-boundary/golden.py | 51 ++ .../vshl-shift-boundary/kernel.pto | 52 ++ .../vshl-shift-boundary/launch.cpp | 55 ++ .../vshl-shift-boundary/main.cpp | 103 ++++ .../vshl-shift-boundary/stub.cpp | 32 ++ .../micro-op/binary-vector/vshl/compare.py | 44 ++ .../micro-op/binary-vector/vshl/golden.py | 50 ++ .../micro-op/binary-vector/vshl/kernel.pto | 52 ++ .../micro-op/binary-vector/vshl/launch.cpp | 55 ++ .../micro-op/binary-vector/vshl/main.cpp | 103 ++++ .../micro-op/binary-vector/vshl/stub.cpp | 32 ++ .../binary-vector/vshr-i16-signed/compare.py | 44 ++ .../binary-vector/vshr-i16-signed/golden.py | 50 ++ .../binary-vector/vshr-i16-signed/kernel.pto | 52 ++ .../binary-vector/vshr-i16-signed/launch.cpp | 55 ++ .../binary-vector/vshr-i16-signed/main.cpp | 103 ++++ .../binary-vector/vshr-i16-signed/stub.cpp | 33 ++ .../vshr-shift-boundary/compare.py | 44 ++ .../vshr-shift-boundary/golden.py | 51 ++ .../vshr-shift-boundary/kernel.pto | 52 ++ .../vshr-shift-boundary/launch.cpp | 55 ++ .../vshr-shift-boundary/main.cpp | 103 ++++ .../vshr-shift-boundary/stub.cpp | 32 ++ .../micro-op/binary-vector/vshr/compare.py | 44 ++ .../micro-op/binary-vector/vshr/golden.py | 50 ++ .../micro-op/binary-vector/vshr/kernel.pto | 52 ++ .../micro-op/binary-vector/vshr/launch.cpp | 55 ++ .../micro-op/binary-vector/vshr/main.cpp | 103 ++++ .../micro-op/binary-vector/vshr/stub.cpp | 33 ++ .../binary-vector/vsub-tail/compare.py | 39 ++ .../binary-vector/vsub-tail/golden.py | 51 ++ .../binary-vector/vsub-tail/kernel.pto | 46 ++ .../binary-vector/vsub-tail/launch.cpp | 46 ++ .../micro-op/binary-vector/vsub-tail/main.cpp | 94 ++++ .../micro-op/binary-vector/vsub-tail/stub.cpp | 24 + .../micro-op/binary-vector/vsub/compare.py | 37 ++ .../micro-op/binary-vector/vsub/golden.py | 46 ++ .../micro-op/binary-vector/vsub/kernel.pto | 48 ++ .../micro-op/binary-vector/vsub/launch.cpp | 46 ++ .../micro-op/binary-vector/vsub/main.cpp | 94 ++++ .../micro-op/binary-vector/vsub/stub.cpp | 24 + .../vsubc-borrow-boundary/compare.py | 65 +++ .../vsubc-borrow-boundary/golden.py | 66 +++ .../vsubc-borrow-boundary/kernel.pto | 53 ++ .../vsubc-borrow-boundary/launch.cpp | 56 ++ .../vsubc-borrow-boundary/main.cpp | 119 +++++ .../vsubc-borrow-boundary/stub.cpp | 33 ++ .../micro-op/binary-vector/vsubc/compare.py | 64 +++ .../micro-op/binary-vector/vsubc/golden.py | 63 +++ .../micro-op/binary-vector/vsubc/kernel.pto | 53 ++ .../micro-op/binary-vector/vsubc/launch.cpp | 57 ++ .../micro-op/binary-vector/vsubc/main.cpp | 117 +++++ .../micro-op/binary-vector/vsubc/stub.cpp | 34 ++ .../binary-vector/vxor-mask-edge/compare.py | 44 ++ .../binary-vector/vxor-mask-edge/golden.py | 50 ++ .../binary-vector/vxor-mask-edge/kernel.pto | 52 ++ .../binary-vector/vxor-mask-edge/launch.cpp | 55 ++ .../binary-vector/vxor-mask-edge/main.cpp | 103 ++++ .../binary-vector/vxor-mask-edge/stub.cpp | 32 ++ .../micro-op/binary-vector/vxor/compare.py | 44 ++ .../micro-op/binary-vector/vxor/golden.py | 50 ++ .../micro-op/binary-vector/vxor/kernel.pto | 52 ++ .../micro-op/binary-vector/vxor/launch.cpp | 55 ++ .../micro-op/binary-vector/vxor/main.cpp | 103 ++++ .../micro-op/binary-vector/vxor/stub.cpp | 33 ++ .../compare-select/vcmp-eq/compare.py | 35 ++ .../micro-op/compare-select/vcmp-eq/golden.py | 58 ++ .../compare-select/vcmp-eq/kernel.pto | 49 ++ .../compare-select/vcmp-eq/launch.cpp | 51 ++ .../micro-op/compare-select/vcmp-eq/main.cpp | 103 ++++ .../micro-op/compare-select/vcmp-eq/stub.cpp | 24 + .../vcmp-f32-exceptional/compare.py | 35 ++ .../vcmp-f32-exceptional/golden.py | 62 +++ .../vcmp-f32-exceptional/kernel.pto | 49 ++ .../vcmp-f32-exceptional/launch.cpp | 50 ++ .../vcmp-f32-exceptional/main.cpp | 103 ++++ .../vcmp-f32-exceptional/stub.cpp | 24 + .../compare-select/vcmp-i16-signed/compare.py | 35 ++ .../compare-select/vcmp-i16-signed/golden.py | 56 ++ .../compare-select/vcmp-i16-signed/kernel.pto | 50 ++ .../compare-select/vcmp-i16-signed/launch.cpp | 59 +++ .../compare-select/vcmp-i16-signed/main.cpp | 112 ++++ .../compare-select/vcmp-i16-signed/stub.cpp | 32 ++ .../vcmp-i16-unsigned/compare.py | 35 ++ .../vcmp-i16-unsigned/golden.py | 56 ++ .../vcmp-i16-unsigned/kernel.pto | 50 ++ .../vcmp-i16-unsigned/launch.cpp | 59 +++ .../compare-select/vcmp-i16-unsigned/main.cpp | 112 ++++ .../compare-select/vcmp-i16-unsigned/stub.cpp | 32 ++ .../compare-select/vcmp-lt/compare.py | 35 ++ .../micro-op/compare-select/vcmp-lt/golden.py | 56 ++ .../compare-select/vcmp-lt/kernel.pto | 49 ++ .../compare-select/vcmp-lt/launch.cpp | 51 ++ .../micro-op/compare-select/vcmp-lt/main.cpp | 103 ++++ .../micro-op/compare-select/vcmp-lt/stub.cpp | 24 + .../compare-select/vcmp-tail/compare.py | 35 ++ .../compare-select/vcmp-tail/golden.py | 79 +++ .../compare-select/vcmp-tail/kernel.pto | 49 ++ .../compare-select/vcmp-tail/launch.cpp | 51 ++ .../compare-select/vcmp-tail/main.cpp | 103 ++++ .../compare-select/vcmp-tail/stub.cpp | 24 + .../vcmps-f32-exceptional/compare.py | 45 ++ .../vcmps-f32-exceptional/golden.py | 57 ++ .../vcmps-f32-exceptional/kernel.pto | 46 ++ .../vcmps-f32-exceptional/launch.cpp | 49 ++ .../vcmps-f32-exceptional/main.cpp | 92 ++++ .../vcmps-f32-exceptional/stub.cpp | 22 + .../compare-select/vcmps-f32/compare.py | 45 ++ .../compare-select/vcmps-f32/golden.py | 55 ++ .../compare-select/vcmps-f32/kernel.pto | 46 ++ .../compare-select/vcmps-f32/launch.cpp | 48 ++ .../compare-select/vcmps-f32/main.cpp | 91 ++++ .../compare-select/vcmps-f32/stub.cpp | 22 + .../vcmps-i16-signed/compare.py | 45 ++ .../compare-select/vcmps-i16-signed/golden.py | 53 ++ .../vcmps-i16-signed/kernel.pto | 45 ++ .../vcmps-i16-signed/launch.cpp | 48 ++ .../compare-select/vcmps-i16-signed/main.cpp | 92 ++++ .../compare-select/vcmps-i16-signed/stub.cpp | 22 + .../vcmps-i16-unsigned/compare.py | 45 ++ .../vcmps-i16-unsigned/golden.py | 53 ++ .../vcmps-i16-unsigned/kernel.pto | 46 ++ .../vcmps-i16-unsigned/launch.cpp | 48 ++ .../vcmps-i16-unsigned/main.cpp | 92 ++++ .../vcmps-i16-unsigned/stub.cpp | 22 + .../compare-select/vcmps-tail/compare.py | 45 ++ .../compare-select/vcmps-tail/golden.py | 71 +++ .../compare-select/vcmps-tail/kernel.pto | 45 ++ .../compare-select/vcmps-tail/launch.cpp | 48 ++ .../compare-select/vcmps-tail/main.cpp | 91 ++++ .../compare-select/vcmps-tail/stub.cpp | 22 + .../compare-select/vsel-i16/compare.py | 45 ++ .../compare-select/vsel-i16/golden.py | 44 ++ .../compare-select/vsel-i16/kernel.pto | 48 ++ .../compare-select/vsel-i16/launch.cpp | 50 ++ .../micro-op/compare-select/vsel-i16/main.cpp | 103 ++++ .../micro-op/compare-select/vsel-i16/stub.cpp | 24 + .../vsel-predicate-edge/compare.py | 45 ++ .../vsel-predicate-edge/golden.py | 50 ++ .../vsel-predicate-edge/kernel.pto | 50 ++ .../vsel-predicate-edge/launch.cpp | 51 ++ .../vsel-predicate-edge/main.cpp | 103 ++++ .../vsel-predicate-edge/stub.cpp | 24 + .../compare-select/vsel-tail/compare.py | 50 ++ .../compare-select/vsel-tail/golden.py | 48 ++ .../compare-select/vsel-tail/kernel.pto | 49 ++ .../compare-select/vsel-tail/launch.cpp | 50 ++ .../compare-select/vsel-tail/main.cpp | 102 ++++ .../compare-select/vsel-tail/stub.cpp | 24 + .../micro-op/compare-select/vsel/compare.py | 50 ++ .../micro-op/compare-select/vsel/golden.py | 44 ++ .../micro-op/compare-select/vsel/kernel.pto | 48 ++ .../micro-op/compare-select/vsel/launch.cpp | 50 ++ .../micro-op/compare-select/vsel/main.cpp | 102 ++++ .../micro-op/compare-select/vsel/stub.cpp | 24 + .../compare-select/vselr-f16/compare.py | 49 ++ .../compare-select/vselr-f16/golden.py | 52 ++ .../compare-select/vselr-f16/kernel.pto | 51 ++ .../compare-select/vselr-f16/launch.cpp | 54 ++ .../compare-select/vselr-f16/main.cpp | 106 ++++ .../compare-select/vselr-f16/stub.cpp | 31 ++ .../compare-select/vselr-u8/compare.py | 49 ++ .../compare-select/vselr-u8/golden.py | 53 ++ .../compare-select/vselr-u8/kernel.pto | 51 ++ .../compare-select/vselr-u8/launch.cpp | 54 ++ .../micro-op/compare-select/vselr-u8/main.cpp | 106 ++++ .../micro-op/compare-select/vselr-u8/stub.cpp | 31 ++ .../micro-op/compare-select/vselr/compare.py | 47 ++ .../micro-op/compare-select/vselr/golden.py | 67 +++ .../micro-op/compare-select/vselr/kernel.pto | 71 +++ .../micro-op/compare-select/vselr/launch.cpp | 57 ++ .../micro-op/compare-select/vselr/main.cpp | 109 ++++ .../micro-op/compare-select/vselr/stub.cpp | 30 ++ .../conversion/vcvt-f16-special/compare.py | 40 ++ .../conversion/vcvt-f16-special/golden.py | 72 +++ .../conversion/vcvt-f16-special/kernel.pto | 40 ++ .../conversion/vcvt-f16-special/launch.cpp | 71 +++ .../conversion/vcvt-f16-special/main.cpp | 130 +++++ .../conversion/vcvt-f16-special/stub.cpp | 33 ++ .../vcvt-f16-to-f32-part-even/compare.py | 40 ++ .../vcvt-f16-to-f32-part-even/golden.py | 54 ++ .../vcvt-f16-to-f32-part-even/kernel.pto | 44 ++ .../vcvt-f16-to-f32-part-even/launch.cpp | 71 +++ .../vcvt-f16-to-f32-part-even/main.cpp | 130 +++++ .../vcvt-f16-to-f32-part-even/stub.cpp | 33 ++ .../vcvt-f16-to-f32-part-odd/compare.py | 40 ++ .../vcvt-f16-to-f32-part-odd/golden.py | 54 ++ .../vcvt-f16-to-f32-part-odd/kernel.pto | 44 ++ .../vcvt-f16-to-f32-part-odd/launch.cpp | 71 +++ .../vcvt-f16-to-f32-part-odd/main.cpp | 130 +++++ .../vcvt-f16-to-f32-part-odd/stub.cpp | 33 ++ .../conversion/vcvt-f16-to-f32/compare.py | 40 ++ .../conversion/vcvt-f16-to-f32/golden.py | 48 ++ .../conversion/vcvt-f16-to-f32/kernel.pto | 40 ++ .../conversion/vcvt-f16-to-f32/launch.cpp | 71 +++ .../conversion/vcvt-f16-to-f32/main.cpp | 130 +++++ .../conversion/vcvt-f16-to-f32/stub.cpp | 33 ++ .../conversion/vcvt-f32-special/compare.py | 40 ++ .../conversion/vcvt-f32-special/golden.py | 83 +++ .../conversion/vcvt-f32-special/kernel.pto | 48 ++ .../conversion/vcvt-f32-special/launch.cpp | 71 +++ .../conversion/vcvt-f32-special/main.cpp | 130 +++++ .../conversion/vcvt-f32-special/stub.cpp | 33 ++ .../vcvt-f32-to-f16-pk-b32/compare.py | 40 ++ .../vcvt-f32-to-f16-pk-b32/golden.py | 48 ++ .../vcvt-f32-to-f16-pk-b32/kernel.pto | 46 ++ .../vcvt-f32-to-f16-pk-b32/launch.cpp | 44 ++ .../vcvt-f32-to-f16-pk-b32/main.cpp | 122 +++++ .../vcvt-f32-to-f16-pk-b32/stub.cpp | 21 + .../conversion/vcvt-f32-to-f16/compare.py | 40 ++ .../conversion/vcvt-f32-to-f16/golden.py | 58 ++ .../conversion/vcvt-f32-to-f16/kernel.pto | 48 ++ .../conversion/vcvt-f32-to-f16/launch.cpp | 71 +++ .../conversion/vcvt-f32-to-f16/main.cpp | 130 +++++ .../conversion/vcvt-f32-to-f16/stub.cpp | 33 ++ .../vcvt-i32-to-i16-overflow/compare.py | 42 ++ .../vcvt-i32-to-i16-overflow/golden.py | 60 +++ .../vcvt-i32-to-i16-overflow/kernel.pto | 53 ++ .../vcvt-i32-to-i16-overflow/launch.cpp | 49 ++ .../vcvt-i32-to-i16-overflow/main.cpp | 103 ++++ .../vcvt-i32-to-i16-overflow/stub.cpp | 23 + .../conversion/vcvt-tail-special/compare.py | 43 ++ .../conversion/vcvt-tail-special/golden.py | 88 ++++ .../conversion/vcvt-tail-special/kernel.pto | 48 ++ .../conversion/vcvt-tail-special/launch.cpp | 71 +++ .../conversion/vcvt-tail-special/main.cpp | 130 +++++ .../conversion/vcvt-tail-special/stub.cpp | 33 ++ .../micro-op/conversion/vcvt-tail/compare.py | 43 ++ .../micro-op/conversion/vcvt-tail/golden.py | 68 +++ .../micro-op/conversion/vcvt-tail/kernel.pto | 48 ++ .../micro-op/conversion/vcvt-tail/launch.cpp | 71 +++ .../micro-op/conversion/vcvt-tail/main.cpp | 130 +++++ .../micro-op/conversion/vcvt-tail/stub.cpp | 33 ++ .../conversion/vtrc-f16-rounding/compare.py | 39 ++ .../conversion/vtrc-f16-rounding/golden.py | 50 ++ .../conversion/vtrc-f16-rounding/kernel.pto | 41 ++ .../conversion/vtrc-f16-rounding/launch.cpp | 45 ++ .../conversion/vtrc-f16-rounding/main.cpp | 83 +++ .../conversion/vtrc-f16-rounding/stub.cpp | 22 + .../conversion/vtrc-f32-rounding/compare.py | 206 ++++++++ .../conversion/vtrc-f32-rounding/golden.py | 54 ++ .../conversion/vtrc-f32-rounding/kernel.pto | 75 +++ .../conversion/vtrc-f32-rounding/launch.cpp | 68 +++ .../conversion/vtrc-f32-rounding/main.cpp | 147 ++++++ .../conversion/vtrc-f32-rounding/stub.cpp | 29 + .../conversion/vtrc-f32-special/compare.py | 39 ++ .../conversion/vtrc-f32-special/golden.py | 50 ++ .../conversion/vtrc-f32-special/kernel.pto | 40 ++ .../conversion/vtrc-f32-special/launch.cpp | 45 ++ .../conversion/vtrc-f32-special/main.cpp | 83 +++ .../conversion/vtrc-f32-special/stub.cpp | 22 + .../vtrc-rounding-boundary/compare.py | 206 ++++++++ .../vtrc-rounding-boundary/golden.py | 58 ++ .../vtrc-rounding-boundary/kernel.pto | 75 +++ .../vtrc-rounding-boundary/launch.cpp | 68 +++ .../vtrc-rounding-boundary/main.cpp | 147 ++++++ .../vtrc-rounding-boundary/stub.cpp | 29 + .../micro-op/dsa-sfu/vaxpy-f32/compare.py | 42 ++ .../micro-op/dsa-sfu/vaxpy-f32/golden.py | 52 ++ .../micro-op/dsa-sfu/vaxpy-f32/kernel.pto | 55 ++ .../micro-op/dsa-sfu/vaxpy-f32/launch.cpp | 54 ++ .../cases/micro-op/dsa-sfu/vaxpy-f32/main.cpp | 102 ++++ .../cases/micro-op/dsa-sfu/vaxpy-f32/stub.cpp | 32 ++ .../micro-op/dsa-sfu/vbitsort/compare.py | 42 ++ .../cases/micro-op/dsa-sfu/vbitsort/golden.py | 64 +++ .../micro-op/dsa-sfu/vbitsort/kernel.pto | 40 ++ .../micro-op/dsa-sfu/vbitsort/launch.cpp | 51 ++ .../cases/micro-op/dsa-sfu/vbitsort/main.cpp | 119 +++++ .../cases/micro-op/dsa-sfu/vbitsort/stub.cpp | 25 + .../cases/micro-op/dsa-sfu/vci/compare.py | 42 ++ .../vpto/cases/micro-op/dsa-sfu/vci/golden.py | 49 ++ .../cases/micro-op/dsa-sfu/vci/kernel.pto | 52 ++ .../cases/micro-op/dsa-sfu/vci/launch.cpp | 52 ++ test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp | 91 ++++ test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp | 30 ++ .../dsa-sfu/vexpdiff-boundary/compare.py | 42 ++ .../dsa-sfu/vexpdiff-boundary/golden.py | 65 +++ .../dsa-sfu/vexpdiff-boundary/kernel.pto | 54 ++ .../dsa-sfu/vexpdiff-boundary/launch.cpp | 54 ++ .../dsa-sfu/vexpdiff-boundary/main.cpp | 102 ++++ .../dsa-sfu/vexpdiff-boundary/stub.cpp | 32 ++ .../dsa-sfu/vexpdiff-f16-part/compare.py | 44 ++ .../dsa-sfu/vexpdiff-f16-part/golden.py | 61 +++ .../dsa-sfu/vexpdiff-f16-part/kernel.pto | 60 +++ .../dsa-sfu/vexpdiff-f16-part/launch.cpp | 53 ++ .../dsa-sfu/vexpdiff-f16-part/main.cpp | 101 ++++ .../dsa-sfu/vexpdiff-f16-part/stub.cpp | 30 ++ .../micro-op/dsa-sfu/vexpdiff-f32/compare.py | 42 ++ .../micro-op/dsa-sfu/vexpdiff-f32/golden.py | 47 ++ .../micro-op/dsa-sfu/vexpdiff-f32/kernel.pto | 50 ++ .../micro-op/dsa-sfu/vexpdiff-f32/launch.cpp | 52 ++ .../micro-op/dsa-sfu/vexpdiff-f32/main.cpp | 91 ++++ .../micro-op/dsa-sfu/vexpdiff-f32/stub.cpp | 30 ++ .../micro-op/dsa-sfu/vlrelu-f16/compare.py | 42 ++ .../micro-op/dsa-sfu/vlrelu-f16/golden.py | 50 ++ .../micro-op/dsa-sfu/vlrelu-f16/kernel.pto | 50 ++ .../micro-op/dsa-sfu/vlrelu-f16/launch.cpp | 52 ++ .../micro-op/dsa-sfu/vlrelu-f16/main.cpp | 91 ++++ .../micro-op/dsa-sfu/vlrelu-f16/stub.cpp | 30 ++ .../dsa-sfu/vlrelu-f32-exceptional/compare.py | 37 ++ .../dsa-sfu/vlrelu-f32-exceptional/golden.py | 49 ++ .../dsa-sfu/vlrelu-f32-exceptional/kernel.pto | 44 ++ .../dsa-sfu/vlrelu-f32-exceptional/launch.cpp | 44 ++ .../dsa-sfu/vlrelu-f32-exceptional/main.cpp | 83 +++ .../dsa-sfu/vlrelu-f32-exceptional/stub.cpp | 22 + .../micro-op/dsa-sfu/vlrelu-f32/compare.py | 37 ++ .../micro-op/dsa-sfu/vlrelu-f32/golden.py | 45 ++ .../micro-op/dsa-sfu/vlrelu-f32/kernel.pto | 44 ++ .../micro-op/dsa-sfu/vlrelu-f32/launch.cpp | 44 ++ .../micro-op/dsa-sfu/vlrelu-f32/main.cpp | 83 +++ .../micro-op/dsa-sfu/vlrelu-f32/stub.cpp | 22 + .../micro-op/dsa-sfu/vlrelu-tail/compare.py | 39 ++ .../micro-op/dsa-sfu/vlrelu-tail/golden.py | 51 ++ .../micro-op/dsa-sfu/vlrelu-tail/kernel.pto | 43 ++ .../micro-op/dsa-sfu/vlrelu-tail/launch.cpp | 44 ++ .../micro-op/dsa-sfu/vlrelu-tail/main.cpp | 83 +++ .../micro-op/dsa-sfu/vlrelu-tail/stub.cpp | 22 + .../vmula-accumulator-boundary/compare.py | 54 ++ .../vmula-accumulator-boundary/golden.py | 48 ++ .../vmula-accumulator-boundary/kernel.pto | 50 ++ .../vmula-accumulator-boundary/launch.cpp | 52 ++ .../vmula-accumulator-boundary/main.cpp | 91 ++++ .../vmula-accumulator-boundary/stub.cpp | 30 ++ .../cases/micro-op/dsa-sfu/vmula/compare.py | 42 ++ .../cases/micro-op/dsa-sfu/vmula/golden.py | 47 ++ .../cases/micro-op/dsa-sfu/vmula/kernel.pto | 50 ++ .../cases/micro-op/dsa-sfu/vmula/launch.cpp | 52 ++ .../cases/micro-op/dsa-sfu/vmula/main.cpp | 91 ++++ .../cases/micro-op/dsa-sfu/vmula/stub.cpp | 30 ++ .../cases/micro-op/dsa-sfu/vmull/compare.py | 42 ++ .../cases/micro-op/dsa-sfu/vmull/golden.py | 52 ++ .../cases/micro-op/dsa-sfu/vmull/kernel.pto | 52 ++ .../cases/micro-op/dsa-sfu/vmull/launch.cpp | 52 ++ .../cases/micro-op/dsa-sfu/vmull/main.cpp | 91 ++++ .../cases/micro-op/dsa-sfu/vmull/stub.cpp | 30 ++ .../micro-op/dsa-sfu/vprelu-f32/compare.py | 42 ++ .../micro-op/dsa-sfu/vprelu-f32/golden.py | 49 ++ .../micro-op/dsa-sfu/vprelu-f32/kernel.pto | 54 ++ .../micro-op/dsa-sfu/vprelu-f32/launch.cpp | 54 ++ .../micro-op/dsa-sfu/vprelu-f32/main.cpp | 102 ++++ .../micro-op/dsa-sfu/vprelu-f32/stub.cpp | 32 ++ .../micro-op/dsa-sfu/vprelu-tail/compare.py | 54 ++ .../micro-op/dsa-sfu/vprelu-tail/golden.py | 50 ++ .../micro-op/dsa-sfu/vprelu-tail/kernel.pto | 56 ++ .../micro-op/dsa-sfu/vprelu-tail/launch.cpp | 54 ++ .../micro-op/dsa-sfu/vprelu-tail/main.cpp | 102 ++++ .../micro-op/dsa-sfu/vprelu-tail/stub.cpp | 32 ++ .../vgather2-duplicate-index/compare.py | 209 ++++++++ .../vgather2-duplicate-index/golden.py | 56 ++ .../vgather2-duplicate-index/kernel.pto | 53 ++ .../vgather2-duplicate-index/launch.cpp | 72 +++ .../vgather2-duplicate-index/main.cpp | 141 +++++ .../vgather2-duplicate-index/stub.cpp | 34 ++ .../gather-scatter/vgather2/compare.py | 209 ++++++++ .../gather-scatter/vgather2/golden.py | 55 ++ .../gather-scatter/vgather2/kernel.pto | 72 +++ .../gather-scatter/vgather2/launch.cpp | 73 +++ .../micro-op/gather-scatter/vgather2/main.cpp | 140 +++++ .../micro-op/gather-scatter/vgather2/stub.cpp | 35 ++ .../vgather2_bc-sparse-mask/compare.py | 209 ++++++++ .../vgather2_bc-sparse-mask/golden.py | 57 ++ .../vgather2_bc-sparse-mask/kernel.pto | 55 ++ .../vgather2_bc-sparse-mask/launch.cpp | 72 +++ .../vgather2_bc-sparse-mask/main.cpp | 141 +++++ .../vgather2_bc-sparse-mask/stub.cpp | 34 ++ .../gather-scatter/vgather2_bc/compare.py | 209 ++++++++ .../gather-scatter/vgather2_bc/golden.py | 55 ++ .../gather-scatter/vgather2_bc/kernel.pto | 55 ++ .../gather-scatter/vgather2_bc/launch.cpp | 73 +++ .../gather-scatter/vgather2_bc/main.cpp | 140 +++++ .../gather-scatter/vgather2_bc/stub.cpp | 35 ++ .../vgatherb-block-boundary/compare.py | 209 ++++++++ .../vgatherb-block-boundary/golden.py | 66 +++ .../vgatherb-block-boundary/kernel.pto | 55 ++ .../vgatherb-block-boundary/launch.cpp | 72 +++ .../vgatherb-block-boundary/main.cpp | 141 +++++ .../vgatherb-block-boundary/stub.cpp | 34 ++ .../gather-scatter/vgatherb/compare.py | 209 ++++++++ .../gather-scatter/vgatherb/golden.py | 66 +++ .../gather-scatter/vgatherb/kernel.pto | 55 ++ .../gather-scatter/vgatherb/launch.cpp | 73 +++ .../micro-op/gather-scatter/vgatherb/main.cpp | 140 +++++ .../micro-op/gather-scatter/vgatherb/stub.cpp | 35 ++ .../vscatter-out-of-order-index/compare.py | 209 ++++++++ .../vscatter-out-of-order-index/golden.py | 56 ++ .../vscatter-out-of-order-index/kernel.pto | 55 ++ .../vscatter-out-of-order-index/launch.cpp | 72 +++ .../vscatter-out-of-order-index/main.cpp | 141 +++++ .../vscatter-out-of-order-index/stub.cpp | 34 ++ .../gather-scatter/vscatter/compare.py | 209 ++++++++ .../gather-scatter/vscatter/golden.py | 56 ++ .../gather-scatter/vscatter/kernel.pto | 74 +++ .../gather-scatter/vscatter/launch.cpp | 73 +++ .../micro-op/gather-scatter/vscatter/main.cpp | 140 +++++ .../micro-op/gather-scatter/vscatter/stub.cpp | 35 ++ .../materialization-predicate/pand/compare.py | 64 +++ .../materialization-predicate/pand/golden.py | 62 +++ .../materialization-predicate/pand/kernel.pto | 55 ++ .../materialization-predicate/pand/launch.cpp | 48 ++ .../materialization-predicate/pand/main.cpp | 104 ++++ .../materialization-predicate/pand/stub.cpp | 25 + .../pdintlv_b16-nontrivial/compare.py | 64 +++ .../pdintlv_b16-nontrivial/golden.py | 49 ++ .../pdintlv_b16-nontrivial/kernel.pto | 37 ++ .../pdintlv_b16-nontrivial/launch.cpp | 48 ++ .../pdintlv_b16-nontrivial/main.cpp | 104 ++++ .../pdintlv_b16-nontrivial/stub.cpp | 25 + .../pdintlv_b16/compare.py | 64 +++ .../pdintlv_b16/golden.py | 49 ++ .../pdintlv_b16/kernel.pto | 37 ++ .../pdintlv_b16/launch.cpp | 48 ++ .../pdintlv_b16/main.cpp | 104 ++++ .../pdintlv_b16/stub.cpp | 25 + .../pdintlv_b32-nontrivial/compare.py | 64 +++ .../pdintlv_b32-nontrivial/golden.py | 49 ++ .../pdintlv_b32-nontrivial/kernel.pto | 37 ++ .../pdintlv_b32-nontrivial/launch.cpp | 48 ++ .../pdintlv_b32-nontrivial/main.cpp | 104 ++++ .../pdintlv_b32-nontrivial/stub.cpp | 25 + .../pdintlv_b32/compare.py | 64 +++ .../pdintlv_b32/golden.py | 49 ++ .../pdintlv_b32/kernel.pto | 37 ++ .../pdintlv_b32/launch.cpp | 48 ++ .../pdintlv_b32/main.cpp | 104 ++++ .../pdintlv_b32/stub.cpp | 25 + .../pdintlv_b8-nontrivial/compare.py | 64 +++ .../pdintlv_b8-nontrivial/golden.py | 49 ++ .../pdintlv_b8-nontrivial/kernel.pto | 37 ++ .../pdintlv_b8-nontrivial/launch.cpp | 48 ++ .../pdintlv_b8-nontrivial/main.cpp | 104 ++++ .../pdintlv_b8-nontrivial/stub.cpp | 25 + .../pdintlv_b8/compare.py | 64 +++ .../pdintlv_b8/golden.py | 49 ++ .../pdintlv_b8/kernel.pto | 37 ++ .../pdintlv_b8/launch.cpp | 48 ++ .../pdintlv_b8/main.cpp | 104 ++++ .../pdintlv_b8/stub.cpp | 25 + .../pge-tail-mask-boundary/compare.py | 53 ++ .../pge-tail-mask-boundary/golden.py | 55 ++ .../pge-tail-mask-boundary/kernel.pto | 38 ++ .../pge-tail-mask-boundary/launch.cpp | 48 ++ .../pge-tail-mask-boundary/main.cpp | 104 ++++ .../pge-tail-mask-boundary/stub.cpp | 25 + .../pge-tail-mask/compare.py | 54 ++ .../pge-tail-mask/golden.py | 59 +++ .../pge-tail-mask/kernel.pto | 38 ++ .../pge-tail-mask/launch.cpp | 48 ++ .../pge-tail-mask/main.cpp | 104 ++++ .../pge-tail-mask/stub.cpp | 25 + .../pintlv_b16-nontrivial/compare.py | 64 +++ .../pintlv_b16-nontrivial/golden.py | 49 ++ .../pintlv_b16-nontrivial/kernel.pto | 37 ++ .../pintlv_b16-nontrivial/launch.cpp | 48 ++ .../pintlv_b16-nontrivial/main.cpp | 104 ++++ .../pintlv_b16-nontrivial/stub.cpp | 25 + .../pintlv_b16/compare.py | 64 +++ .../pintlv_b16/golden.py | 49 ++ .../pintlv_b16/kernel.pto | 37 ++ .../pintlv_b16/launch.cpp | 48 ++ .../pintlv_b16/main.cpp | 104 ++++ .../pintlv_b16/stub.cpp | 25 + .../pintlv_b32-nontrivial/compare.py | 64 +++ .../pintlv_b32-nontrivial/golden.py | 49 ++ .../pintlv_b32-nontrivial/kernel.pto | 37 ++ .../pintlv_b32-nontrivial/launch.cpp | 48 ++ .../pintlv_b32-nontrivial/main.cpp | 104 ++++ .../pintlv_b32-nontrivial/stub.cpp | 25 + .../pintlv_b32/compare.py | 64 +++ .../pintlv_b32/golden.py | 49 ++ .../pintlv_b32/kernel.pto | 37 ++ .../pintlv_b32/launch.cpp | 48 ++ .../pintlv_b32/main.cpp | 104 ++++ .../pintlv_b32/stub.cpp | 25 + .../pintlv_b8-nontrivial/compare.py | 64 +++ .../pintlv_b8-nontrivial/golden.py | 49 ++ .../pintlv_b8-nontrivial/kernel.pto | 37 ++ .../pintlv_b8-nontrivial/launch.cpp | 48 ++ .../pintlv_b8-nontrivial/main.cpp | 104 ++++ .../pintlv_b8-nontrivial/stub.cpp | 25 + .../pintlv_b8/compare.py | 64 +++ .../pintlv_b8/golden.py | 49 ++ .../pintlv_b8/kernel.pto | 37 ++ .../pintlv_b8/launch.cpp | 48 ++ .../pintlv_b8/main.cpp | 104 ++++ .../pintlv_b8/stub.cpp | 25 + .../plt-tail-mask-boundary/compare.py | 53 ++ .../plt-tail-mask-boundary/golden.py | 55 ++ .../plt-tail-mask-boundary/kernel.pto | 39 ++ .../plt-tail-mask-boundary/launch.cpp | 48 ++ .../plt-tail-mask-boundary/main.cpp | 104 ++++ .../plt-tail-mask-boundary/stub.cpp | 25 + .../plt-tail-mask/compare.py | 54 ++ .../plt-tail-mask/golden.py | 59 +++ .../plt-tail-mask/kernel.pto | 41 ++ .../plt-tail-mask/launch.cpp | 48 ++ .../plt-tail-mask/main.cpp | 104 ++++ .../plt-tail-mask/stub.cpp | 25 + .../materialization-predicate/pnot/compare.py | 64 +++ .../materialization-predicate/pnot/golden.py | 49 ++ .../materialization-predicate/pnot/kernel.pto | 36 ++ .../materialization-predicate/pnot/launch.cpp | 48 ++ .../materialization-predicate/pnot/main.cpp | 104 ++++ .../materialization-predicate/pnot/stub.cpp | 25 + .../materialization-predicate/por/compare.py | 64 +++ .../materialization-predicate/por/golden.py | 62 +++ .../materialization-predicate/por/kernel.pto | 55 ++ .../materialization-predicate/por/launch.cpp | 48 ++ .../materialization-predicate/por/main.cpp | 104 ++++ .../materialization-predicate/por/stub.cpp | 25 + .../ppack-punpack-nontrivial/compare.py | 64 +++ .../ppack-punpack-nontrivial/golden.py | 49 ++ .../ppack-punpack-nontrivial/kernel.pto | 36 ++ .../ppack-punpack-nontrivial/launch.cpp | 48 ++ .../ppack-punpack-nontrivial/main.cpp | 104 ++++ .../ppack-punpack-nontrivial/stub.cpp | 25 + .../ppack-punpack/compare.py | 64 +++ .../ppack-punpack/golden.py | 49 ++ .../ppack-punpack/kernel.pto | 36 ++ .../ppack-punpack/launch.cpp | 48 ++ .../ppack-punpack/main.cpp | 104 ++++ .../ppack-punpack/stub.cpp | 25 + .../psel-tail-predicate/compare.py | 64 +++ .../psel-tail-predicate/golden.py | 49 ++ .../psel-tail-predicate/kernel.pto | 40 ++ .../psel-tail-predicate/launch.cpp | 48 ++ .../psel-tail-predicate/main.cpp | 104 ++++ .../psel-tail-predicate/stub.cpp | 25 + .../materialization-predicate/psel/compare.py | 64 +++ .../materialization-predicate/psel/golden.py | 49 ++ .../materialization-predicate/psel/kernel.pto | 36 ++ .../materialization-predicate/psel/launch.cpp | 48 ++ .../materialization-predicate/psel/main.cpp | 104 ++++ .../materialization-predicate/psel/stub.cpp | 25 + .../pset-pattern-fragment/compare.py | 64 +++ .../pset-pattern-fragment/golden.py | 49 ++ .../pset-pattern-fragment/kernel.pto | 38 ++ .../pset-pattern-fragment/launch.cpp | 48 ++ .../pset-pattern-fragment/main.cpp | 104 ++++ .../pset-pattern-fragment/stub.cpp | 25 + .../pset-pattern/compare.py | 60 +++ .../pset-pattern/golden.py | 63 +++ .../pset-pattern/kernel.pto | 38 ++ .../pset-pattern/launch.cpp | 69 +++ .../pset-pattern/main.cpp | 119 +++++ .../pset-pattern/stub.cpp | 31 ++ .../materialization-predicate/pxor/compare.py | 64 +++ .../materialization-predicate/pxor/golden.py | 62 +++ .../materialization-predicate/pxor/kernel.pto | 55 ++ .../materialization-predicate/pxor/launch.cpp | 48 ++ .../materialization-predicate/pxor/main.cpp | 104 ++++ .../materialization-predicate/pxor/stub.cpp | 25 + .../vbr-f32/compare.py | 204 +++++++ .../vbr-f32/golden.py | 45 ++ .../vbr-f32/kernel.pto | 48 ++ .../vbr-f32/launch.cpp | 61 +++ .../vbr-f32/main.cpp | 111 ++++ .../vbr-f32/stub.cpp | 23 + .../vbr-i32/compare.py | 204 +++++++ .../vbr-i32/golden.py | 45 ++ .../vbr-i32/kernel.pto | 48 ++ .../vbr-i32/launch.cpp | 61 +++ .../vbr-i32/main.cpp | 111 ++++ .../vbr-i32/stub.cpp | 23 + .../vdup-lane/compare.py | 61 +++ .../vdup-lane/golden.py | 62 +++ .../vdup-lane/kernel.pto | 50 ++ .../vdup-lane/launch.cpp | 46 ++ .../vdup-lane/main.cpp | 119 +++++ .../vdup-lane/stub.cpp | 30 ++ .../vdup-scalar-f16/compare.py | 60 +++ .../vdup-scalar-f16/golden.py | 45 ++ .../vdup-scalar-f16/kernel.pto | 35 ++ .../vdup-scalar-f16/launch.cpp | 43 ++ .../vdup-scalar-f16/main.cpp | 111 ++++ .../vdup-scalar-f16/stub.cpp | 20 + .../vdup-scalar-i8/compare.py | 54 ++ .../vdup-scalar-i8/golden.py | 45 ++ .../vdup-scalar-i8/kernel.pto | 35 ++ .../vdup-scalar-i8/launch.cpp | 43 ++ .../vdup-scalar-i8/main.cpp | 111 ++++ .../vdup-scalar-i8/stub.cpp | 20 + .../vdup-scalar/compare.py | 204 +++++++ .../vdup-scalar/golden.py | 45 ++ .../vdup-scalar/kernel.pto | 48 ++ .../vdup-scalar/launch.cpp | 61 +++ .../vdup-scalar/main.cpp | 111 ++++ .../vdup-scalar/stub.cpp | 23 + .../_predicate_load_store_case.py | 76 +++ .../predicate-load-store/pldi-norm/compare.py | 38 ++ .../predicate-load-store/pldi-norm/golden.py | 74 +++ .../predicate-load-store/pldi-norm/kernel.pto | 58 ++ .../predicate-load-store/pldi-norm/launch.cpp | 45 ++ .../predicate-load-store/pldi-norm/main.cpp | 85 +++ .../predicate-load-store/pldi-norm/stub.cpp | 22 + .../predicate-load-store/plds-norm/compare.py | 38 ++ .../predicate-load-store/plds-norm/golden.py | 74 +++ .../predicate-load-store/plds-norm/kernel.pto | 57 ++ .../predicate-load-store/plds-norm/launch.cpp | 45 ++ .../predicate-load-store/plds-norm/main.cpp | 85 +++ .../predicate-load-store/plds-norm/stub.cpp | 22 + .../psti-norm-pldi-ds/compare.py | 37 ++ .../psti-norm-pldi-ds/golden.py | 56 ++ .../psti-norm-pldi-ds/kernel.pto | 50 ++ .../psti-norm-pldi-ds/launch.cpp | 48 ++ .../psti-norm-pldi-ds/main.cpp | 95 ++++ .../psti-norm-pldi-ds/stub.cpp | 24 + .../psti-pk-pldi-us/compare.py | 37 ++ .../psti-pk-pldi-us/golden.py | 54 ++ .../psti-pk-pldi-us/kernel.pto | 50 ++ .../psti-pk-pldi-us/launch.cpp | 48 ++ .../psti-pk-pldi-us/main.cpp | 95 ++++ .../psti-pk-pldi-us/stub.cpp | 24 + .../predicate-load-store/psti-pk/compare.py | 44 ++ .../predicate-load-store/psti-pk/golden.py | 56 ++ .../predicate-load-store/psti-pk/kernel.pto | 33 ++ .../predicate-load-store/psti-pk/launch.cpp | 43 ++ .../predicate-load-store/psti-pk/main.cpp | 78 +++ .../predicate-load-store/psti-pk/stub.cpp | 20 + .../psts-norm-plds-ds/compare.py | 37 ++ .../psts-norm-plds-ds/golden.py | 42 ++ .../psts-norm-plds-ds/kernel.pto | 55 ++ .../psts-norm-plds-ds/launch.cpp | 48 ++ .../psts-norm-plds-ds/main.cpp | 95 ++++ .../psts-norm-plds-ds/stub.cpp | 24 + .../compare.py | 37 ++ .../psts-pk-plds-us-prefix-boundary/golden.py | 44 ++ .../kernel.pto | 52 ++ .../launch.cpp | 47 ++ .../psts-pk-plds-us-prefix-boundary/main.cpp | 95 ++++ .../psts-pk-plds-us-prefix-boundary/stub.cpp | 24 + .../psts-pk-plds-us/compare.py | 37 ++ .../psts-pk-plds-us/golden.py | 42 ++ .../psts-pk-plds-us/kernel.pto | 51 ++ .../psts-pk-plds-us/launch.cpp | 48 ++ .../psts-pk-plds-us/main.cpp | 95 ++++ .../psts-pk-plds-us/stub.cpp | 24 + .../pstu-init-align-outside-loop/compare.py | 50 ++ .../pstu-init-align-outside-loop/golden.py | 83 +++ .../pstu-init-align-outside-loop/kernel.pto | 45 ++ .../pstu-init-align-outside-loop/launch.cpp | 55 ++ .../pstu-init-align-outside-loop/main.cpp | 112 ++++ .../pstu-init-align-outside-loop/stub.cpp | 32 ++ .../pstu-state-advance-boundary/compare.py | 50 ++ .../pstu-state-advance-boundary/golden.py | 80 +++ .../pstu-state-advance-boundary/kernel.pto | 40 ++ .../pstu-state-advance-boundary/launch.cpp | 58 ++ .../pstu-state-advance-boundary/main.cpp | 111 ++++ .../pstu-state-advance-boundary/stub.cpp | 32 ++ .../predicate-load-store/pstu/compare.py | 53 ++ .../predicate-load-store/pstu/golden.py | 82 +++ .../predicate-load-store/pstu/kernel.pto | 41 ++ .../predicate-load-store/pstu/launch.cpp | 57 ++ .../predicate-load-store/pstu/main.cpp | 110 ++++ .../predicate-load-store/pstu/stub.cpp | 32 ++ .../vintlv-vdintlv-lane-boundary/compare.py | 209 ++++++++ .../vintlv-vdintlv-lane-boundary/golden.py | 54 ++ .../vintlv-vdintlv-lane-boundary/kernel.pto | 52 ++ .../vintlv-vdintlv-lane-boundary/launch.cpp | 71 +++ .../vintlv-vdintlv-lane-boundary/main.cpp | 130 +++++ .../vintlv-vdintlv-lane-boundary/stub.cpp | 33 ++ .../rearrangement/vintlv-vdintlv/compare.py | 209 ++++++++ .../rearrangement/vintlv-vdintlv/golden.py | 51 ++ .../rearrangement/vintlv-vdintlv/kernel.pto | 52 ++ .../rearrangement/vintlv-vdintlv/launch.cpp | 71 +++ .../rearrangement/vintlv-vdintlv/main.cpp | 130 +++++ .../rearrangement/vintlv-vdintlv/stub.cpp | 33 ++ .../rearrangement/vpack-higher/compare.py | 42 ++ .../rearrangement/vpack-higher/golden.py | 63 +++ .../rearrangement/vpack-higher/kernel.pto | 49 ++ .../rearrangement/vpack-higher/launch.cpp | 51 ++ .../rearrangement/vpack-higher/main.cpp | 128 +++++ .../rearrangement/vpack-higher/stub.cpp | 29 + .../rearrangement/vpack-lower/compare.py | 42 ++ .../rearrangement/vpack-lower/golden.py | 63 +++ .../rearrangement/vpack-lower/kernel.pto | 49 ++ .../rearrangement/vpack-lower/launch.cpp | 51 ++ .../rearrangement/vpack-lower/main.cpp | 128 +++++ .../rearrangement/vpack-lower/stub.cpp | 29 + .../vsqz-nontrivial-mask/compare.py | 208 ++++++++ .../vsqz-nontrivial-mask/golden.py | 59 +++ .../vsqz-nontrivial-mask/kernel.pto | 61 +++ .../vsqz-nontrivial-mask/launch.cpp | 70 +++ .../vsqz-nontrivial-mask/main.cpp | 130 +++++ .../vsqz-nontrivial-mask/stub.cpp | 33 ++ .../micro-op/rearrangement/vsqz/compare.py | 209 ++++++++ .../micro-op/rearrangement/vsqz/golden.py | 51 ++ .../micro-op/rearrangement/vsqz/kernel.pto | 68 +++ .../micro-op/rearrangement/vsqz/launch.cpp | 70 +++ .../micro-op/rearrangement/vsqz/main.cpp | 130 +++++ .../micro-op/rearrangement/vsqz/stub.cpp | 33 ++ .../rearrangement/vsunpack/compare.py | 209 ++++++++ .../micro-op/rearrangement/vsunpack/golden.py | 54 ++ .../rearrangement/vsunpack/kernel.pto | 70 +++ .../rearrangement/vsunpack/launch.cpp | 70 +++ .../micro-op/rearrangement/vsunpack/main.cpp | 132 +++++ .../micro-op/rearrangement/vsunpack/stub.cpp | 33 ++ .../vusqz-nontrivial-mask/compare.py | 36 ++ .../vusqz-nontrivial-mask/golden.py | 69 +++ .../vusqz-nontrivial-mask/kernel.pto | 55 ++ .../vusqz-nontrivial-mask/launch.cpp | 55 ++ .../vusqz-nontrivial-mask/main.cpp | 103 ++++ .../vusqz-nontrivial-mask/stub.cpp | 30 ++ .../micro-op/rearrangement/vusqz/compare.py | 36 ++ .../micro-op/rearrangement/vusqz/golden.py | 66 +++ .../micro-op/rearrangement/vusqz/kernel.pto | 55 ++ .../micro-op/rearrangement/vusqz/launch.cpp | 52 ++ .../micro-op/rearrangement/vusqz/main.cpp | 100 ++++ .../micro-op/rearrangement/vusqz/stub.cpp | 30 ++ .../rearrangement/vzunpack/compare.py | 209 ++++++++ .../micro-op/rearrangement/vzunpack/golden.py | 54 ++ .../rearrangement/vzunpack/kernel.pto | 70 +++ .../rearrangement/vzunpack/launch.cpp | 70 +++ .../micro-op/rearrangement/vzunpack/main.cpp | 132 +++++ .../micro-op/rearrangement/vzunpack/stub.cpp | 33 ++ .../micro-op/reduction/vcadd-tail/compare.py | 39 ++ .../micro-op/reduction/vcadd-tail/golden.py | 49 ++ .../micro-op/reduction/vcadd-tail/kernel.pto | 40 ++ .../micro-op/reduction/vcadd-tail/launch.cpp | 44 ++ .../micro-op/reduction/vcadd-tail/main.cpp | 87 +++ .../micro-op/reduction/vcadd-tail/stub.cpp | 22 + .../cases/micro-op/reduction/vcadd/compare.py | 204 +++++++ .../cases/micro-op/reduction/vcadd/golden.py | 48 ++ .../cases/micro-op/reduction/vcadd/kernel.pto | 41 ++ .../cases/micro-op/reduction/vcadd/launch.cpp | 62 +++ .../cases/micro-op/reduction/vcadd/main.cpp | 122 +++++ .../cases/micro-op/reduction/vcadd/stub.cpp | 25 + .../micro-op/reduction/vcgadd-tail/compare.py | 209 ++++++++ .../micro-op/reduction/vcgadd-tail/golden.py | 54 ++ .../micro-op/reduction/vcgadd-tail/kernel.pto | 49 ++ .../micro-op/reduction/vcgadd-tail/launch.cpp | 70 +++ .../micro-op/reduction/vcgadd-tail/main.cpp | 130 +++++ .../micro-op/reduction/vcgadd-tail/stub.cpp | 33 ++ .../micro-op/reduction/vcgadd/compare.py | 209 ++++++++ .../cases/micro-op/reduction/vcgadd/golden.py | 53 ++ .../micro-op/reduction/vcgadd/kernel.pto | 49 ++ .../micro-op/reduction/vcgadd/launch.cpp | 70 +++ .../cases/micro-op/reduction/vcgadd/main.cpp | 130 +++++ .../cases/micro-op/reduction/vcgadd/stub.cpp | 33 ++ .../micro-op/reduction/vcgmax-tie/compare.py | 209 ++++++++ .../micro-op/reduction/vcgmax-tie/golden.py | 58 ++ .../micro-op/reduction/vcgmax-tie/kernel.pto | 49 ++ .../micro-op/reduction/vcgmax-tie/launch.cpp | 70 +++ .../micro-op/reduction/vcgmax-tie/main.cpp | 130 +++++ .../micro-op/reduction/vcgmax-tie/stub.cpp | 33 ++ .../micro-op/reduction/vcgmax/compare.py | 209 ++++++++ .../cases/micro-op/reduction/vcgmax/golden.py | 53 ++ .../micro-op/reduction/vcgmax/kernel.pto | 49 ++ .../micro-op/reduction/vcgmax/launch.cpp | 70 +++ .../cases/micro-op/reduction/vcgmax/main.cpp | 130 +++++ .../cases/micro-op/reduction/vcgmax/stub.cpp | 33 ++ .../micro-op/reduction/vcgmin-tie/compare.py | 209 ++++++++ .../micro-op/reduction/vcgmin-tie/golden.py | 58 ++ .../micro-op/reduction/vcgmin-tie/kernel.pto | 49 ++ .../micro-op/reduction/vcgmin-tie/launch.cpp | 70 +++ .../micro-op/reduction/vcgmin-tie/main.cpp | 130 +++++ .../micro-op/reduction/vcgmin-tie/stub.cpp | 33 ++ .../micro-op/reduction/vcgmin/compare.py | 209 ++++++++ .../cases/micro-op/reduction/vcgmin/golden.py | 53 ++ .../micro-op/reduction/vcgmin/kernel.pto | 49 ++ .../micro-op/reduction/vcgmin/launch.cpp | 70 +++ .../cases/micro-op/reduction/vcgmin/main.cpp | 130 +++++ .../cases/micro-op/reduction/vcgmin/stub.cpp | 33 ++ .../cases/micro-op/reduction/vcmax/compare.py | 204 +++++++ .../cases/micro-op/reduction/vcmax/golden.py | 52 ++ .../cases/micro-op/reduction/vcmax/kernel.pto | 41 ++ .../cases/micro-op/reduction/vcmax/launch.cpp | 62 +++ .../cases/micro-op/reduction/vcmax/main.cpp | 122 +++++ .../cases/micro-op/reduction/vcmax/stub.cpp | 25 + .../cases/micro-op/reduction/vcmin/compare.py | 204 +++++++ .../cases/micro-op/reduction/vcmin/golden.py | 52 ++ .../cases/micro-op/reduction/vcmin/kernel.pto | 41 ++ .../cases/micro-op/reduction/vcmin/launch.cpp | 62 +++ .../cases/micro-op/reduction/vcmin/main.cpp | 122 +++++ .../cases/micro-op/reduction/vcmin/stub.cpp | 25 + .../micro-op/reduction/vcpadd-tail/compare.py | 209 ++++++++ .../micro-op/reduction/vcpadd-tail/golden.py | 56 ++ .../micro-op/reduction/vcpadd-tail/kernel.pto | 49 ++ .../micro-op/reduction/vcpadd-tail/launch.cpp | 70 +++ .../micro-op/reduction/vcpadd-tail/main.cpp | 130 +++++ .../micro-op/reduction/vcpadd-tail/stub.cpp | 33 ++ .../micro-op/reduction/vcpadd/compare.py | 209 ++++++++ .../cases/micro-op/reduction/vcpadd/golden.py | 55 ++ .../micro-op/reduction/vcpadd/kernel.pto | 49 ++ .../micro-op/reduction/vcpadd/launch.cpp | 70 +++ .../cases/micro-op/reduction/vcpadd/main.cpp | 130 +++++ .../cases/micro-op/reduction/vcpadd/stub.cpp | 33 ++ .../load-store-scalar-ub/compare.py | 54 ++ .../load-store-scalar-ub/golden.py | 45 ++ .../load-store-scalar-ub/kernel.pto | 48 ++ .../load-store-scalar-ub/launch.cpp | 45 ++ .../load-store-scalar-ub/main.cpp | 119 +++++ .../load-store-scalar-ub/stub.cpp | 23 + .../get-block-subblock-id/compare.py | 54 ++ .../get-block-subblock-id/golden.py | 51 ++ .../get-block-subblock-id/kernel.pto | 35 ++ .../get-block-subblock-id/launch.cpp | 43 ++ .../get-block-subblock-id/main.cpp | 111 ++++ .../get-block-subblock-id/stub.cpp | 20 + .../micro-op/unary-vector/vabs-f16/compare.py | 209 ++++++++ .../micro-op/unary-vector/vabs-f16/golden.py | 51 ++ .../micro-op/unary-vector/vabs-f16/kernel.pto | 68 +++ .../micro-op/unary-vector/vabs-f16/launch.cpp | 70 +++ .../micro-op/unary-vector/vabs-f16/main.cpp | 130 +++++ .../micro-op/unary-vector/vabs-f16/stub.cpp | 33 ++ .../vabs-f32-exceptional/compare.py | 204 +++++++ .../vabs-f32-exceptional/golden.py | 48 ++ .../vabs-f32-exceptional/kernel.pto | 40 ++ .../vabs-f32-exceptional/launch.cpp | 45 ++ .../vabs-f32-exceptional/main.cpp | 87 +++ .../vabs-f32-exceptional/stub.cpp | 22 + .../vabs-i16-signed-overflow-edge/compare.py | 42 ++ .../vabs-i16-signed-overflow-edge/golden.py | 51 ++ .../vabs-i16-signed-overflow-edge/kernel.pto | 47 ++ .../vabs-i16-signed-overflow-edge/launch.cpp | 49 ++ .../vabs-i16-signed-overflow-edge/main.cpp | 103 ++++ .../vabs-i16-signed-overflow-edge/stub.cpp | 23 + .../unary-vector/vabs-i16-signed/compare.py | 209 ++++++++ .../unary-vector/vabs-i16-signed/golden.py | 51 ++ .../unary-vector/vabs-i16-signed/kernel.pto | 68 +++ .../unary-vector/vabs-i16-signed/launch.cpp | 70 +++ .../unary-vector/vabs-i16-signed/main.cpp | 130 +++++ .../unary-vector/vabs-i16-signed/stub.cpp | 33 ++ .../unary-vector/vabs-i16-unsigned/compare.py | 209 ++++++++ .../unary-vector/vabs-i16-unsigned/golden.py | 51 ++ .../unary-vector/vabs-i16-unsigned/kernel.pto | 68 +++ .../unary-vector/vabs-i16-unsigned/launch.cpp | 70 +++ .../unary-vector/vabs-i16-unsigned/main.cpp | 130 +++++ .../unary-vector/vabs-i16-unsigned/stub.cpp | 33 ++ .../vabs-loop-carried-vreg/compare.py | 198 +++++++ .../vabs-loop-carried-vreg/golden.py | 74 +++ .../vabs-loop-carried-vreg/kernel.pto | 48 ++ .../vabs-loop-carried-vreg/launch.cpp | 48 ++ .../vabs-loop-carried-vreg/main.cpp | 122 +++++ .../vabs-loop-carried-vreg/stub.cpp | 22 + .../unary-vector/vabs-tail/compare.py | 39 ++ .../micro-op/unary-vector/vabs-tail/golden.py | 49 ++ .../unary-vector/vabs-tail/kernel.pto | 40 ++ .../unary-vector/vabs-tail/launch.cpp | 44 ++ .../micro-op/unary-vector/vabs-tail/main.cpp | 87 +++ .../micro-op/unary-vector/vabs-tail/stub.cpp | 22 + .../micro-op/unary-vector/vabs/compare.py | 204 +++++++ .../micro-op/unary-vector/vabs/golden.py | 46 ++ .../micro-op/unary-vector/vabs/kernel.pto | 60 +++ .../micro-op/unary-vector/vabs/launch.cpp | 62 +++ .../cases/micro-op/unary-vector/vabs/main.cpp | 122 +++++ .../cases/micro-op/unary-vector/vabs/stub.cpp | 25 + .../micro-op/unary-vector/vexp-f16/compare.py | 209 ++++++++ .../micro-op/unary-vector/vexp-f16/golden.py | 51 ++ .../micro-op/unary-vector/vexp-f16/kernel.pto | 48 ++ .../micro-op/unary-vector/vexp-f16/launch.cpp | 71 +++ .../micro-op/unary-vector/vexp-f16/main.cpp | 130 +++++ .../micro-op/unary-vector/vexp-f16/stub.cpp | 33 ++ .../vexp-f32-exceptional/compare.py | 204 +++++++ .../vexp-f32-exceptional/golden.py | 48 ++ .../vexp-f32-exceptional/kernel.pto | 40 ++ .../vexp-f32-exceptional/launch.cpp | 44 ++ .../vexp-f32-exceptional/main.cpp | 83 +++ .../vexp-f32-exceptional/stub.cpp | 22 + .../vexp-f32-over-underflow/compare.py | 204 +++++++ .../vexp-f32-over-underflow/golden.py | 48 ++ .../vexp-f32-over-underflow/kernel.pto | 40 ++ .../vexp-f32-over-underflow/launch.cpp | 44 ++ .../vexp-f32-over-underflow/main.cpp | 83 +++ .../vexp-f32-over-underflow/stub.cpp | 22 + .../unary-vector/vexp-tail/compare.py | 39 ++ .../micro-op/unary-vector/vexp-tail/golden.py | 49 ++ .../unary-vector/vexp-tail/kernel.pto | 40 ++ .../unary-vector/vexp-tail/launch.cpp | 43 ++ .../micro-op/unary-vector/vexp-tail/main.cpp | 83 +++ .../micro-op/unary-vector/vexp-tail/stub.cpp | 22 + .../micro-op/unary-vector/vexp/compare.py | 204 +++++++ .../micro-op/unary-vector/vexp/golden.py | 46 ++ .../micro-op/unary-vector/vexp/kernel.pto | 60 +++ .../micro-op/unary-vector/vexp/launch.cpp | 62 +++ .../cases/micro-op/unary-vector/vexp/main.cpp | 122 +++++ .../cases/micro-op/unary-vector/vexp/stub.cpp | 25 + .../vln-domain-boundary/compare.py | 209 ++++++++ .../vln-domain-boundary/golden.py | 65 +++ .../vln-domain-boundary/kernel.pto | 68 +++ .../vln-domain-boundary/launch.cpp | 70 +++ .../unary-vector/vln-domain-boundary/main.cpp | 130 +++++ .../unary-vector/vln-domain-boundary/stub.cpp | 33 ++ .../micro-op/unary-vector/vln/compare.py | 204 +++++++ .../cases/micro-op/unary-vector/vln/golden.py | 46 ++ .../micro-op/unary-vector/vln/kernel.pto | 41 ++ .../micro-op/unary-vector/vln/launch.cpp | 62 +++ .../cases/micro-op/unary-vector/vln/main.cpp | 122 +++++ .../cases/micro-op/unary-vector/vln/stub.cpp | 25 + .../vneg-f32-exceptional/compare.py | 209 ++++++++ .../vneg-f32-exceptional/golden.py | 65 +++ .../vneg-f32-exceptional/kernel.pto | 68 +++ .../vneg-f32-exceptional/launch.cpp | 70 +++ .../vneg-f32-exceptional/main.cpp | 130 +++++ .../vneg-f32-exceptional/stub.cpp | 33 ++ .../micro-op/unary-vector/vneg/compare.py | 209 ++++++++ .../micro-op/unary-vector/vneg/golden.py | 51 ++ .../micro-op/unary-vector/vneg/kernel.pto | 68 +++ .../micro-op/unary-vector/vneg/launch.cpp | 70 +++ .../cases/micro-op/unary-vector/vneg/main.cpp | 130 +++++ .../cases/micro-op/unary-vector/vneg/stub.cpp | 33 ++ .../micro-op/unary-vector/vnot/compare.py | 209 ++++++++ .../micro-op/unary-vector/vnot/golden.py | 56 ++ .../micro-op/unary-vector/vnot/kernel.pto | 48 ++ .../micro-op/unary-vector/vnot/launch.cpp | 70 +++ .../cases/micro-op/unary-vector/vnot/main.cpp | 130 +++++ .../cases/micro-op/unary-vector/vnot/stub.cpp | 33 ++ .../micro-op/unary-vector/vrelu/compare.py | 204 +++++++ .../micro-op/unary-vector/vrelu/golden.py | 46 ++ .../micro-op/unary-vector/vrelu/kernel.pto | 41 ++ .../micro-op/unary-vector/vrelu/launch.cpp | 62 +++ .../micro-op/unary-vector/vrelu/main.cpp | 122 +++++ .../micro-op/unary-vector/vrelu/stub.cpp | 25 + .../vsqrt-domain-boundary/compare.py | 209 ++++++++ .../vsqrt-domain-boundary/golden.py | 65 +++ .../vsqrt-domain-boundary/kernel.pto | 68 +++ .../vsqrt-domain-boundary/launch.cpp | 70 +++ .../vsqrt-domain-boundary/main.cpp | 130 +++++ .../vsqrt-domain-boundary/stub.cpp | 33 ++ .../micro-op/unary-vector/vsqrt/compare.py | 204 +++++++ .../micro-op/unary-vector/vsqrt/golden.py | 47 ++ .../micro-op/unary-vector/vsqrt/kernel.pto | 41 ++ .../micro-op/unary-vector/vsqrt/launch.cpp | 62 +++ .../micro-op/unary-vector/vsqrt/main.cpp | 122 +++++ .../micro-op/unary-vector/vsqrt/stub.cpp | 25 + .../vaddcs-carry-boundary/compare.py | 64 +++ .../vaddcs-carry-boundary/golden.py | 72 +++ .../vaddcs-carry-boundary/kernel.pto | 54 ++ .../vaddcs-carry-boundary/launch.cpp | 53 ++ .../vec-scalar/vaddcs-carry-boundary/main.cpp | 115 ++++ .../vec-scalar/vaddcs-carry-boundary/stub.cpp | 31 ++ .../micro-op/vec-scalar/vaddcs/compare.py | 64 +++ .../micro-op/vec-scalar/vaddcs/golden.py | 64 +++ .../micro-op/vec-scalar/vaddcs/kernel.pto | 52 ++ .../micro-op/vec-scalar/vaddcs/launch.cpp | 54 ++ .../cases/micro-op/vec-scalar/vaddcs/main.cpp | 115 ++++ .../cases/micro-op/vec-scalar/vaddcs/stub.cpp | 31 ++ .../micro-op/vec-scalar/vadds-bf16/compare.py | 37 ++ .../micro-op/vec-scalar/vadds-bf16/golden.py | 57 ++ .../micro-op/vec-scalar/vadds-bf16/kernel.pto | 41 ++ .../micro-op/vec-scalar/vadds-bf16/launch.cpp | 44 ++ .../micro-op/vec-scalar/vadds-bf16/main.cpp | 84 +++ .../micro-op/vec-scalar/vadds-bf16/stub.cpp | 23 + .../micro-op/vec-scalar/vadds-f16/compare.py | 37 ++ .../micro-op/vec-scalar/vadds-f16/golden.py | 44 ++ .../micro-op/vec-scalar/vadds-f16/kernel.pto | 41 ++ .../micro-op/vec-scalar/vadds-f16/launch.cpp | 44 ++ .../micro-op/vec-scalar/vadds-f16/main.cpp | 84 +++ .../micro-op/vec-scalar/vadds-f16/stub.cpp | 23 + .../vadds-f32-exceptional/compare.py | 37 ++ .../vadds-f32-exceptional/golden.py | 49 ++ .../vadds-f32-exceptional/kernel.pto | 43 ++ .../vadds-f32-exceptional/launch.cpp | 44 ++ .../vec-scalar/vadds-f32-exceptional/main.cpp | 83 +++ .../vec-scalar/vadds-f32-exceptional/stub.cpp | 22 + .../vadds-i16-signed-overflow/compare.py | 42 ++ .../vadds-i16-signed-overflow/golden.py | 64 +++ .../vadds-i16-signed-overflow/kernel.pto | 46 ++ .../vadds-i16-signed-overflow/launch.cpp | 51 ++ .../vadds-i16-signed-overflow/main.cpp | 81 +++ .../vadds-i16-signed-overflow/stub.cpp | 29 + .../vec-scalar/vadds-i16-signed/compare.py | 42 ++ .../vec-scalar/vadds-i16-signed/golden.py | 47 ++ .../vec-scalar/vadds-i16-signed/kernel.pto | 46 ++ .../vec-scalar/vadds-i16-signed/launch.cpp | 44 ++ .../vec-scalar/vadds-i16-signed/main.cpp | 80 +++ .../vec-scalar/vadds-i16-signed/stub.cpp | 23 + .../vadds-i16-unsigned-overflow/compare.py | 41 ++ .../vadds-i16-unsigned-overflow/golden.py | 64 +++ .../vadds-i16-unsigned-overflow/kernel.pto | 46 ++ .../vadds-i16-unsigned-overflow/launch.cpp | 51 ++ .../vadds-i16-unsigned-overflow/main.cpp | 81 +++ .../vadds-i16-unsigned-overflow/stub.cpp | 29 + .../vec-scalar/vadds-i16-unsigned/compare.py | 41 ++ .../vec-scalar/vadds-i16-unsigned/golden.py | 47 ++ .../vec-scalar/vadds-i16-unsigned/kernel.pto | 46 ++ .../vec-scalar/vadds-i16-unsigned/launch.cpp | 50 ++ .../vec-scalar/vadds-i16-unsigned/main.cpp | 90 ++++ .../vec-scalar/vadds-i16-unsigned/stub.cpp | 29 + .../micro-op/vec-scalar/vadds-tail/compare.py | 39 ++ .../micro-op/vec-scalar/vadds-tail/golden.py | 50 ++ .../micro-op/vec-scalar/vadds-tail/kernel.pto | 43 ++ .../micro-op/vec-scalar/vadds-tail/launch.cpp | 44 ++ .../micro-op/vec-scalar/vadds-tail/main.cpp | 83 +++ .../micro-op/vec-scalar/vadds-tail/stub.cpp | 22 + .../micro-op/vec-scalar/vadds/compare.py | 37 ++ .../cases/micro-op/vec-scalar/vadds/golden.py | 45 ++ .../micro-op/vec-scalar/vadds/kernel.pto | 44 ++ .../micro-op/vec-scalar/vadds/launch.cpp | 44 ++ .../cases/micro-op/vec-scalar/vadds/main.cpp | 83 +++ .../cases/micro-op/vec-scalar/vadds/stub.cpp | 22 + .../micro-op/vec-scalar/vmaxs-tail/compare.py | 39 ++ .../micro-op/vec-scalar/vmaxs-tail/golden.py | 50 ++ .../micro-op/vec-scalar/vmaxs-tail/kernel.pto | 43 ++ .../micro-op/vec-scalar/vmaxs-tail/launch.cpp | 44 ++ .../micro-op/vec-scalar/vmaxs-tail/main.cpp | 83 +++ .../micro-op/vec-scalar/vmaxs-tail/stub.cpp | 22 + .../micro-op/vec-scalar/vmaxs/compare.py | 37 ++ .../cases/micro-op/vec-scalar/vmaxs/golden.py | 45 ++ .../micro-op/vec-scalar/vmaxs/kernel.pto | 44 ++ .../micro-op/vec-scalar/vmaxs/launch.cpp | 44 ++ .../cases/micro-op/vec-scalar/vmaxs/main.cpp | 83 +++ .../cases/micro-op/vec-scalar/vmaxs/stub.cpp | 22 + .../micro-op/vec-scalar/vmins-tail/compare.py | 39 ++ .../micro-op/vec-scalar/vmins-tail/golden.py | 50 ++ .../micro-op/vec-scalar/vmins-tail/kernel.pto | 43 ++ .../micro-op/vec-scalar/vmins-tail/launch.cpp | 44 ++ .../micro-op/vec-scalar/vmins-tail/main.cpp | 83 +++ .../micro-op/vec-scalar/vmins-tail/stub.cpp | 22 + .../micro-op/vec-scalar/vmins/compare.py | 37 ++ .../cases/micro-op/vec-scalar/vmins/golden.py | 45 ++ .../micro-op/vec-scalar/vmins/kernel.pto | 44 ++ .../micro-op/vec-scalar/vmins/launch.cpp | 44 ++ .../cases/micro-op/vec-scalar/vmins/main.cpp | 83 +++ .../cases/micro-op/vec-scalar/vmins/stub.cpp | 22 + .../micro-op/vec-scalar/vmuls-tail/compare.py | 39 ++ .../micro-op/vec-scalar/vmuls-tail/golden.py | 50 ++ .../micro-op/vec-scalar/vmuls-tail/kernel.pto | 43 ++ .../micro-op/vec-scalar/vmuls-tail/launch.cpp | 44 ++ .../micro-op/vec-scalar/vmuls-tail/main.cpp | 83 +++ .../micro-op/vec-scalar/vmuls-tail/stub.cpp | 22 + .../micro-op/vec-scalar/vmuls/compare.py | 37 ++ .../cases/micro-op/vec-scalar/vmuls/golden.py | 45 ++ .../micro-op/vec-scalar/vmuls/kernel.pto | 44 ++ .../micro-op/vec-scalar/vmuls/launch.cpp | 44 ++ .../cases/micro-op/vec-scalar/vmuls/main.cpp | 83 +++ .../cases/micro-op/vec-scalar/vmuls/stub.cpp | 22 + .../vshls-shift-boundary/compare.py | 41 ++ .../vec-scalar/vshls-shift-boundary/golden.py | 51 ++ .../vshls-shift-boundary/kernel.pto | 46 ++ .../vshls-shift-boundary/launch.cpp | 52 ++ .../vec-scalar/vshls-shift-boundary/main.cpp | 91 ++++ .../vec-scalar/vshls-shift-boundary/stub.cpp | 30 ++ .../micro-op/vec-scalar/vshls/compare.py | 41 ++ .../cases/micro-op/vec-scalar/vshls/golden.py | 47 ++ .../micro-op/vec-scalar/vshls/kernel.pto | 46 ++ .../micro-op/vec-scalar/vshls/launch.cpp | 51 ++ .../cases/micro-op/vec-scalar/vshls/main.cpp | 91 ++++ .../cases/micro-op/vec-scalar/vshls/stub.cpp | 30 ++ .../vshrs-shift-boundary/compare.py | 41 ++ .../vec-scalar/vshrs-shift-boundary/golden.py | 51 ++ .../vshrs-shift-boundary/kernel.pto | 46 ++ .../vshrs-shift-boundary/launch.cpp | 52 ++ .../vec-scalar/vshrs-shift-boundary/main.cpp | 91 ++++ .../vec-scalar/vshrs-shift-boundary/stub.cpp | 30 ++ .../micro-op/vec-scalar/vshrs/compare.py | 41 ++ .../cases/micro-op/vec-scalar/vshrs/golden.py | 47 ++ .../micro-op/vec-scalar/vshrs/kernel.pto | 46 ++ .../micro-op/vec-scalar/vshrs/launch.cpp | 51 ++ .../cases/micro-op/vec-scalar/vshrs/main.cpp | 91 ++++ .../cases/micro-op/vec-scalar/vshrs/stub.cpp | 30 ++ .../vsubcs-borrow-boundary/compare.py | 64 +++ .../vsubcs-borrow-boundary/golden.py | 73 +++ .../vsubcs-borrow-boundary/kernel.pto | 54 ++ .../vsubcs-borrow-boundary/launch.cpp | 53 ++ .../vsubcs-borrow-boundary/main.cpp | 116 ++++ .../vsubcs-borrow-boundary/stub.cpp | 31 ++ .../micro-op/vec-scalar/vsubcs/compare.py | 64 +++ .../micro-op/vec-scalar/vsubcs/golden.py | 65 +++ .../micro-op/vec-scalar/vsubcs/kernel.pto | 52 ++ .../micro-op/vec-scalar/vsubcs/launch.cpp | 54 ++ .../cases/micro-op/vec-scalar/vsubcs/main.cpp | 115 ++++ .../cases/micro-op/vec-scalar/vsubcs/stub.cpp | 31 ++ .../vldas-vldus-state-chain/compare.py | 210 ++++++++ .../vldas-vldus-state-chain/golden.py | 56 ++ .../vldas-vldus-state-chain/kernel.pto | 59 +++ .../vldas-vldus-state-chain/launch.cpp | 70 +++ .../vldas-vldus-state-chain/main.cpp | 130 +++++ .../vldas-vldus-state-chain/stub.cpp | 33 ++ .../vector-load-store/vldas-vldus/compare.py | 210 ++++++++ .../vector-load-store/vldas-vldus/golden.py | 55 ++ .../vector-load-store/vldas-vldus/kernel.pto | 69 +++ .../vector-load-store/vldas-vldus/launch.cpp | 70 +++ .../vector-load-store/vldas-vldus/main.cpp | 130 +++++ .../vector-load-store/vldas-vldus/stub.cpp | 33 ++ .../vlds-brc-b16-f32/compare.py | 40 ++ .../vlds-brc-b16-f32/golden.py | 57 ++ .../vlds-brc-b16-f32/kernel.pto | 45 ++ .../vlds-brc-b16-f32/launch.cpp | 45 ++ .../vlds-brc-b16-f32/main.cpp | 122 +++++ .../vlds-brc-b16-f32/stub.cpp | 21 + .../vector-load-store/vlds-brc-b16/compare.py | 210 ++++++++ .../vector-load-store/vlds-brc-b16/golden.py | 54 ++ .../vector-load-store/vlds-brc-b16/kernel.pto | 52 ++ .../vector-load-store/vlds-brc-b16/launch.cpp | 70 +++ .../vector-load-store/vlds-brc-b16/main.cpp | 130 +++++ .../vector-load-store/vlds-brc-b16/stub.cpp | 33 ++ .../vector-load-store/vlds-brc-b32/compare.py | 209 ++++++++ .../vector-load-store/vlds-brc-b32/golden.py | 53 ++ .../vector-load-store/vlds-brc-b32/kernel.pto | 67 +++ .../vector-load-store/vlds-brc-b32/launch.cpp | 70 +++ .../vector-load-store/vlds-brc-b32/main.cpp | 130 +++++ .../vector-load-store/vlds-brc-b32/stub.cpp | 33 ++ .../vlds-brc-b8-f32/compare.py | 40 ++ .../vlds-brc-b8-f32/golden.py | 56 ++ .../vlds-brc-b8-f32/kernel.pto | 45 ++ .../vlds-brc-b8-f32/launch.cpp | 45 ++ .../vlds-brc-b8-f32/main.cpp | 122 +++++ .../vlds-brc-b8-f32/stub.cpp | 21 + .../vector-load-store/vlds-brc-blk/compare.py | 210 ++++++++ .../vector-load-store/vlds-brc-blk/golden.py | 56 ++ .../vector-load-store/vlds-brc-blk/kernel.pto | 48 ++ .../vector-load-store/vlds-brc-blk/launch.cpp | 70 +++ .../vector-load-store/vlds-brc-blk/main.cpp | 130 +++++ .../vector-load-store/vlds-brc-blk/stub.cpp | 33 ++ .../vector-load-store/vlds-ds-b16/compare.py | 210 ++++++++ .../vector-load-store/vlds-ds-b16/golden.py | 55 ++ .../vector-load-store/vlds-ds-b16/kernel.pto | 48 ++ .../vector-load-store/vlds-ds-b16/launch.cpp | 70 +++ .../vector-load-store/vlds-ds-b16/main.cpp | 130 +++++ .../vector-load-store/vlds-ds-b16/stub.cpp | 33 ++ .../vector-load-store/vlds-tail/compare.py | 210 ++++++++ .../vector-load-store/vlds-tail/golden.py | 53 ++ .../vector-load-store/vlds-tail/kernel.pto | 68 +++ .../vector-load-store/vlds-tail/launch.cpp | 70 +++ .../vector-load-store/vlds-tail/main.cpp | 130 +++++ .../vector-load-store/vlds-tail/stub.cpp | 33 ++ .../vlds-unpk-b16/compare.py | 56 ++ .../vector-load-store/vlds-unpk-b16/golden.py | 55 ++ .../vlds-unpk-b16/kernel.pto | 54 ++ .../vlds-unpk-b16/launch.cpp | 50 ++ .../vector-load-store/vlds-unpk-b16/main.cpp | 102 ++++ .../vector-load-store/vlds-unpk-b16/stub.cpp | 28 + .../vector-load-store/vlds-us-b16/compare.py | 210 ++++++++ .../vector-load-store/vlds-us-b16/golden.py | 55 ++ .../vector-load-store/vlds-us-b16/kernel.pto | 48 ++ .../vector-load-store/vlds-us-b16/launch.cpp | 70 +++ .../vector-load-store/vlds-us-b16/main.cpp | 130 +++++ .../vector-load-store/vlds-us-b16/stub.cpp | 33 ++ .../vector-load-store/vlds/compare.py | 209 ++++++++ .../micro-op/vector-load-store/vlds/golden.py | 51 ++ .../vector-load-store/vlds/kernel.pto | 67 +++ .../vector-load-store/vlds/launch.cpp | 70 +++ .../micro-op/vector-load-store/vlds/main.cpp | 130 +++++ .../micro-op/vector-load-store/vlds/stub.cpp | 33 ++ .../vldsx2-layout-check/compare.py | 209 ++++++++ .../vldsx2-layout-check/golden.py | 62 +++ .../vldsx2-layout-check/kernel.pto | 62 +++ .../vldsx2-layout-check/launch.cpp | 71 +++ .../vldsx2-layout-check/main.cpp | 130 +++++ .../vldsx2-layout-check/stub.cpp | 33 ++ .../vldsx2-vstsx2-b8-f32/compare.py | 201 +++++++ .../vldsx2-vstsx2-b8-f32/golden.py | 50 ++ .../vldsx2-vstsx2-b8-f32/kernel.pto | 59 +++ .../vldsx2-vstsx2-b8-f32/launch.cpp | 50 ++ .../vldsx2-vstsx2-b8-f32/main.cpp | 128 +++++ .../vldsx2-vstsx2-b8-f32/stub.cpp | 28 + .../vldsx2-vstsx2/compare.py | 209 ++++++++ .../vector-load-store/vldsx2-vstsx2/golden.py | 52 ++ .../vldsx2-vstsx2/kernel.pto | 59 +++ .../vldsx2-vstsx2/launch.cpp | 71 +++ .../vector-load-store/vldsx2-vstsx2/main.cpp | 130 +++++ .../vector-load-store/vldsx2-vstsx2/stub.cpp | 33 ++ .../vector-load-store/vsldb/compare.py | 210 ++++++++ .../vector-load-store/vsldb/golden.py | 62 +++ .../vector-load-store/vsldb/kernel.pto | 45 ++ .../vector-load-store/vsldb/launch.cpp | 70 +++ .../micro-op/vector-load-store/vsldb/main.cpp | 130 +++++ .../micro-op/vector-load-store/vsldb/stub.cpp | 33 ++ .../vector-load-store/vsstb/compare.py | 209 ++++++++ .../vector-load-store/vsstb/golden.py | 55 ++ .../vector-load-store/vsstb/kernel.pto | 48 ++ .../vector-load-store/vsstb/launch.cpp | 70 +++ .../micro-op/vector-load-store/vsstb/main.cpp | 130 +++++ .../micro-op/vector-load-store/vsstb/stub.cpp | 33 ++ .../vector-load-store/vstar/compare.py | 236 +++++++++ .../vector-load-store/vstar/golden.py | 51 ++ .../vector-load-store/vstar/kernel.pto | 61 +++ .../vector-load-store/vstar/launch.cpp | 70 +++ .../micro-op/vector-load-store/vstar/main.cpp | 130 +++++ .../micro-op/vector-load-store/vstar/stub.cpp | 33 ++ .../vstas-vstus-offset-update/compare.py | 208 ++++++++ .../vstas-vstus-offset-update/golden.py | 52 ++ .../vstas-vstus-offset-update/kernel.pto | 57 ++ .../vstas-vstus-offset-update/launch.cpp | 70 +++ .../vstas-vstus-offset-update/main.cpp | 130 +++++ .../vstas-vstus-offset-update/stub.cpp | 33 ++ .../vector-load-store/vsts-1pt-b16/compare.py | 258 +++++++++ .../vector-load-store/vsts-1pt-b16/golden.py | 53 ++ .../vector-load-store/vsts-1pt-b16/kernel.pto | 48 ++ .../vector-load-store/vsts-1pt-b16/launch.cpp | 71 +++ .../vector-load-store/vsts-1pt-b16/main.cpp | 130 +++++ .../vector-load-store/vsts-1pt-b16/stub.cpp | 33 ++ .../vector-load-store/vsts-pk-b16/compare.py | 91 ++++ .../vector-load-store/vsts-pk-b16/golden.py | 65 +++ .../vector-load-store/vsts-pk-b16/kernel.pto | 48 ++ .../vector-load-store/vsts-pk-b16/launch.cpp | 71 +++ .../vector-load-store/vsts-pk-b16/main.cpp | 130 +++++ .../vector-load-store/vsts-pk-b16/stub.cpp | 33 ++ .../vsts-pk-b64-f32/compare.py | 40 ++ .../vsts-pk-b64-f32/golden.py | 54 ++ .../vsts-pk-b64-f32/kernel.pto | 48 ++ .../vsts-pk-b64-f32/launch.cpp | 44 ++ .../vsts-pk-b64-f32/main.cpp | 122 +++++ .../vsts-pk-b64-f32/stub.cpp | 21 + .../vector-load-store/vsts-tail/compare.py | 257 +++++++++ .../vector-load-store/vsts-tail/golden.py | 51 ++ .../vector-load-store/vsts-tail/kernel.pto | 50 ++ .../vector-load-store/vsts-tail/launch.cpp | 70 +++ .../vector-load-store/vsts-tail/main.cpp | 130 +++++ .../vector-load-store/vsts-tail/stub.cpp | 33 ++ .../vector-load-store/vsts/compare.py | 209 ++++++++ .../micro-op/vector-load-store/vsts/golden.py | 51 ++ .../vector-load-store/vsts/kernel.pto | 67 +++ .../vector-load-store/vsts/launch.cpp | 70 +++ .../micro-op/vector-load-store/vsts/main.cpp | 130 +++++ .../micro-op/vector-load-store/vsts/stub.cpp | 33 ++ .../vstsx2-layout-check/compare.py | 210 ++++++++ .../vstsx2-layout-check/golden.py | 56 ++ .../vstsx2-layout-check/kernel.pto | 54 ++ .../vstsx2-layout-check/launch.cpp | 71 +++ .../vstsx2-layout-check/main.cpp | 130 +++++ .../vstsx2-layout-check/stub.cpp | 33 ++ .../vstur-init-align-outside-loop/compare.py | 112 ++++ .../vstur-init-align-outside-loop/golden.py | 52 ++ .../vstur-init-align-outside-loop/kernel.pto | 51 ++ .../vstur-init-align-outside-loop/launch.cpp | 54 ++ .../vstur-init-align-outside-loop/main.cpp | 129 +++++ .../vstur-init-align-outside-loop/stub.cpp | 30 ++ .../vector-load-store/vstur/compare.py | 257 +++++++++ .../vector-load-store/vstur/golden.py | 52 ++ .../vector-load-store/vstur/kernel.pto | 59 +++ .../vector-load-store/vstur/launch.cpp | 70 +++ .../micro-op/vector-load-store/vstur/main.cpp | 130 +++++ .../micro-op/vector-load-store/vstur/stub.cpp | 33 ++ test/vpto/npu_validation/common/test_common.h | 61 +++ test/vpto/scripts/run_host_vpto_validation.sh | 496 ++++++++++++++++++ .../run_host_vpto_validation_parallel.sh | 189 +++++++ 1427 files changed, 93603 insertions(+) create mode 100755 test/vpto/cases/micro-op/binary-vector/vadd-bf16/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vadd-bf16/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-bf16/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-bf16/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-bf16/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vadd-f16/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vadd-f16/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-tail/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-tail/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vadd/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vaddc/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vaddc/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vand-mask-edge/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vand-mask-edge/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vand-mask-edge/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vand-mask-edge/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vand-mask-edge/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vand/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vand/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vand/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vand/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vand/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vand/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vdiv-f16/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vdiv-f16/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-tail/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-tail/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vdiv/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax-tail/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax-tail/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmax/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vmin-bf16/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vmin-bf16/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-bf16/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-bf16/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-bf16/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vmin-f16/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vmin-f16/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-tail/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-tail/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmin/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul-tail/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul-tail/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vmul/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vor-f16/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vor-f16/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vor-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vor-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vor-f16/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vor-mask-edge/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vor-mask-edge/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vor-mask-edge/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vor-mask-edge/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vor-mask-edge/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vor/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vor/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vor/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vor/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vor/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vor/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vshl/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vshl/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshl/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vshr/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vshr/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vshr/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub-tail/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub-tail/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub/compare.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsub/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vsubc/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vsubc/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vsubc/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsubc/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vsubc/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/stub.cpp create mode 100755 test/vpto/cases/micro-op/binary-vector/vxor/compare.py create mode 100755 test/vpto/cases/micro-op/binary-vector/vxor/golden.py create mode 100644 test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto create mode 100644 test/vpto/cases/micro-op/binary-vector/vxor/launch.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vxor/main.cpp create mode 100644 test/vpto/cases/micro-op/binary-vector/vxor/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-eq/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-eq/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-eq/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-eq/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-eq/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-lt/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-lt/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-lt/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-lt/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-lt/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-tail/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-tail/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmp-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-f32/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-tail/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-tail/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-i16/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-i16/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-i16/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-i16/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-i16/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/stub.cpp create mode 100755 test/vpto/cases/micro-op/compare-select/vsel-tail/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-tail/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-tail/stub.cpp create mode 100755 test/vpto/cases/micro-op/compare-select/vsel/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vsel/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-f16/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-f16/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-u8/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-u8/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-u8/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-u8/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vselr-u8/stub.cpp create mode 100755 test/vpto/cases/micro-op/compare-select/vselr/compare.py create mode 100755 test/vpto/cases/micro-op/compare-select/vselr/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vselr/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vselr/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vselr/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vselr/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-special/compare.py create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-f16-special/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-special/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-special/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-special/stub.cpp create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/stub.cpp create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/stub.cpp create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/compare.py create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-special/compare.py create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-f32-special/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-special/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-special/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-special/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/compare.py create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail-special/compare.py create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-tail-special/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail-special/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail-special/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail-special/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail/compare.py create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-tail/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-special/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-special/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-special/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-special/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-f32-special/stub.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/stub.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vbitsort/compare.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vbitsort/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vbitsort/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vbitsort/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vbitsort/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vci/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vci/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/stub.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/compare.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/stub.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/compare.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vmula/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vmula/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmula/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmula/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmula/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vmull/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vmull/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmull/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmull/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmull/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/stub.cpp create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/compare.py create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/launch.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/main.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/stub.cpp create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgather2/compare.py create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgather2/golden.py create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2/launch.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2/main.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2/stub.cpp create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/compare.py create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/golden.py create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/launch.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/main.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/stub.cpp create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgather2_bc/compare.py create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgather2_bc/golden.py create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2_bc/launch.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2_bc/main.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgather2_bc/stub.cpp create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgatherb/compare.py create mode 100755 test/vpto/cases/micro-op/gather-scatter/vgatherb/golden.py create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgatherb/launch.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgatherb/main.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vgatherb/stub.cpp create mode 100755 test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/compare.py create mode 100755 test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py create mode 100644 test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto create mode 100644 test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/launch.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/main.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/stub.cpp create mode 100755 test/vpto/cases/micro-op/gather-scatter/vscatter/compare.py create mode 100755 test/vpto/cases/micro-op/gather-scatter/vscatter/golden.py create mode 100644 test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto create mode 100644 test/vpto/cases/micro-op/gather-scatter/vscatter/launch.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vscatter/main.cpp create mode 100644 test/vpto/cases/micro-op/gather-scatter/vscatter/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pand/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pand/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pand/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pand/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pand/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pnot/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pnot/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pnot/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pnot/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pnot/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/por/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/por/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/por/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/por/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/por/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/psel/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/psel/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/psel/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/psel/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/psel/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pset-pattern/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pset-pattern/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pset-pattern/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pset-pattern/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pset-pattern/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pxor/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/pxor/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pxor/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pxor/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/pxor/stub.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-f32/compare.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-f32/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-f32/stub.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i32/compare.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i32/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i32/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i32/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i32/stub.cpp create mode 100755 test/vpto/cases/micro-op/materialization-predicate/vdup-lane/compare.py create mode 100755 test/vpto/cases/micro-op/materialization-predicate/vdup-lane/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-lane/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-lane/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-lane/stub.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/compare.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/compare.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/stub.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/compare.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/_predicate_load_store_case.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pldi-norm/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pldi-norm/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pldi-norm/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pldi-norm/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pldi-norm/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/plds-norm/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/plds-norm/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/plds-norm/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/plds-norm/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/plds-norm/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psti-pk/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/stub.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/compare.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/stub.cpp create mode 100755 test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/predicate-load-store/pstu/compare.py create mode 100755 test/vpto/cases/micro-op/predicate-load-store/pstu/golden.py create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu/launch.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu/main.cpp create mode 100644 test/vpto/cases/micro-op/predicate-load-store/pstu/stub.cpp create mode 100755 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/stub.cpp create mode 100755 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/compare.py create mode 100755 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/stub.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-higher/compare.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-higher/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-higher/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-higher/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-higher/stub.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-lower/compare.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-lower/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-lower/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-lower/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vpack-lower/stub.cpp create mode 100755 test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/compare.py create mode 100755 test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/stub.cpp create mode 100755 test/vpto/cases/micro-op/rearrangement/vsqz/compare.py create mode 100755 test/vpto/cases/micro-op/rearrangement/vsqz/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vsqz/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vsqz/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vsqz/stub.cpp create mode 100755 test/vpto/cases/micro-op/rearrangement/vsunpack/compare.py create mode 100755 test/vpto/cases/micro-op/rearrangement/vsunpack/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vsunpack/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vsunpack/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vsunpack/stub.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/compare.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/stub.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz/compare.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vusqz/stub.cpp create mode 100755 test/vpto/cases/micro-op/rearrangement/vzunpack/compare.py create mode 100755 test/vpto/cases/micro-op/rearrangement/vzunpack/golden.py create mode 100644 test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto create mode 100644 test/vpto/cases/micro-op/rearrangement/vzunpack/launch.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vzunpack/main.cpp create mode 100644 test/vpto/cases/micro-op/rearrangement/vzunpack/stub.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcadd-tail/compare.py create mode 100644 test/vpto/cases/micro-op/reduction/vcadd-tail/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcadd-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcadd-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcadd-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcadd/compare.py create mode 100644 test/vpto/cases/micro-op/reduction/vcadd/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcadd/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcadd/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcadd/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcadd/stub.cpp create mode 100755 test/vpto/cases/micro-op/reduction/vcgadd-tail/compare.py create mode 100755 test/vpto/cases/micro-op/reduction/vcgadd-tail/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcgadd-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgadd-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgadd-tail/stub.cpp create mode 100755 test/vpto/cases/micro-op/reduction/vcgadd/compare.py create mode 100755 test/vpto/cases/micro-op/reduction/vcgadd/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcgadd/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgadd/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgadd/stub.cpp create mode 100755 test/vpto/cases/micro-op/reduction/vcgmax-tie/compare.py create mode 100755 test/vpto/cases/micro-op/reduction/vcgmax-tie/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcgmax-tie/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgmax-tie/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgmax-tie/stub.cpp create mode 100755 test/vpto/cases/micro-op/reduction/vcgmax/compare.py create mode 100755 test/vpto/cases/micro-op/reduction/vcgmax/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcgmax/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgmax/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgmax/stub.cpp create mode 100755 test/vpto/cases/micro-op/reduction/vcgmin-tie/compare.py create mode 100755 test/vpto/cases/micro-op/reduction/vcgmin-tie/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcgmin-tie/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgmin-tie/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgmin-tie/stub.cpp create mode 100755 test/vpto/cases/micro-op/reduction/vcgmin/compare.py create mode 100755 test/vpto/cases/micro-op/reduction/vcgmin/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcgmin/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgmin/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcgmin/stub.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcmax/compare.py create mode 100644 test/vpto/cases/micro-op/reduction/vcmax/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcmax/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcmax/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcmax/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcmax/stub.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcmin/compare.py create mode 100644 test/vpto/cases/micro-op/reduction/vcmin/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcmin/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcmin/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcmin/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcmin/stub.cpp create mode 100755 test/vpto/cases/micro-op/reduction/vcpadd-tail/compare.py create mode 100755 test/vpto/cases/micro-op/reduction/vcpadd-tail/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcpadd-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcpadd-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcpadd-tail/stub.cpp create mode 100755 test/vpto/cases/micro-op/reduction/vcpadd/compare.py create mode 100755 test/vpto/cases/micro-op/reduction/vcpadd/golden.py create mode 100644 test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto create mode 100644 test/vpto/cases/micro-op/reduction/vcpadd/launch.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcpadd/main.cpp create mode 100644 test/vpto/cases/micro-op/reduction/vcpadd/stub.cpp create mode 100644 test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/compare.py create mode 100644 test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/golden.py create mode 100644 test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto create mode 100644 test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/launch.cpp create mode 100644 test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/main.cpp create mode 100644 test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/stub.cpp create mode 100644 test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/compare.py create mode 100644 test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/golden.py create mode 100644 test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/kernel.pto create mode 100644 test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/launch.cpp create mode 100644 test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/main.cpp create mode 100644 test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vabs-f16/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vabs-f16/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-tail/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-tail/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vabs/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vexp-f16/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vexp-f16/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-tail/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-tail/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vexp/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vln/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vln/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vln/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vln/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vln/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vln/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vneg/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vneg/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vneg/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vneg/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vneg/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vnot/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vnot/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vnot/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vnot/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vnot/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vrelu/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vrelu/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vrelu/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vrelu/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vrelu/stub.cpp create mode 100755 test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/compare.py create mode 100755 test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/stub.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt/compare.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt/golden.py create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt/launch.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt/main.cpp create mode 100644 test/vpto/cases/micro-op/unary-vector/vsqrt/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vaddcs/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-bf16/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-bf16/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-bf16/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-bf16/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-bf16/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f16/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f16/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f16/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/stub.cpp create mode 100755 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/compare.py create mode 100755 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/stub.cpp create mode 100755 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/compare.py create mode 100755 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-tail/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-tail/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vadds/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmaxs/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins-tail/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins-tail/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmins/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls-tail/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls-tail/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vmuls/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshls/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vshrs/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/stub.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs/compare.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs/golden.py create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs/launch.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs/main.cpp create mode 100644 test/vpto/cases/micro-op/vec-scalar/vsubcs/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vldas-vldus/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vldas-vldus/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldas-vldus/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldas-vldus/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldas-vldus/stub.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/compare.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/stub.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/compare.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-tail/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-tail/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-tail/stub.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/compare.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vlds/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/stub.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/compare.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsldb/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsldb/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsldb/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsldb/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsldb/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsstb/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsstb/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsstb/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsstb/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsstb/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vstar/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vstar/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstar/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstar/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstar/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/stub.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/compare.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsts-tail/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsts-tail/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-tail/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-tail/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts-tail/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsts/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vsts/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vsts/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/stub.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/compare.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/stub.cpp create mode 100755 test/vpto/cases/micro-op/vector-load-store/vstur/compare.py create mode 100755 test/vpto/cases/micro-op/vector-load-store/vstur/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vstur/stub.cpp create mode 100644 test/vpto/npu_validation/common/test_common.h create mode 100755 test/vpto/scripts/run_host_vpto_validation.sh create mode 100755 test/vpto/scripts/run_host_vpto_validation_parallel.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b020193b..58e45958d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -264,6 +264,181 @@ jobs: path: ${{ env.PAYLOAD_TGZ }} if-no-files-found: error + vpto-sim-validation: + runs-on: [self-hosted, Linux, X64, label-1] + timeout-minutes: 120 + if: >- + ${{ + github.event_name == 'workflow_dispatch' || + github.event_name == 'schedule' || + (github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name == github.repository) + }} + env: + LLVM_COMMIT: cd708029e0b2869e80abe31ddb175f7c35361f90 + LLVM_DIR: ${{ github.workspace }}/llvm-project/llvm/build-shared + MLIR_PYTHONPATH: ${{ github.workspace }}/llvm-project/llvm/build-shared/tools/mlir/python_packages/mlir_core + VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} + ref: ${{ github.event.pull_request.head.sha || github.sha }} + fetch-depth: 1 + persist-credentials: false + + - name: Ensure runner dependencies + shell: bash + run: | + set -euo pipefail + missing_tools=() + for tool in python3 git cmake ninja; do + if ! command -v "${tool}" >/dev/null 2>&1; then + missing_tools+=("${tool}") + fi + done + + if [[ "${#missing_tools[@]}" -gt 0 ]]; then + if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y python3 python3-pip git cmake ninja-build + else + echo "ERROR: missing required tools on self-hosted runner: ${missing_tools[*]}" >&2 + echo "ERROR: automatic installation requires sudo + apt-get" >&2 + exit 1 + fi + fi + + python3 -m pip --version >/dev/null 2>&1 || { + if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y python3-pip + else + echo "ERROR: python3-pip is required on self-hosted runner" >&2 + exit 1 + fi + } + + need_pip_install=0 + python3 -c "import numpy" >/dev/null 2>&1 || need_pip_install=1 + python3 -m pybind11 --cmakedir >/dev/null 2>&1 || need_pip_install=1 + + if [[ "${need_pip_install}" -eq 1 ]]; then + python3 -m pip install --upgrade pip + python3 -m pip install 'pybind11<3' numpy + fi + + - name: Detect reusable LLVM build + id: detect-llvm-build + shell: bash + run: | + set -euo pipefail + if [[ -f "${LLVM_DIR}/lib/cmake/llvm/LLVMConfig.cmake" && \ + -f "${LLVM_DIR}/lib/cmake/mlir/MLIRConfig.cmake" ]]; then + echo "ready=true" >> "${GITHUB_OUTPUT}" + else + echo "ready=false" >> "${GITHUB_OUTPUT}" + fi + + - name: Prepare LLVM source (no rebuild) + if: steps.detect-llvm-build.outputs.ready != 'true' + shell: bash + run: | + set -euo pipefail + mkdir -p llvm-project + cd llvm-project + + if [ ! -d .git ]; then + git init + git remote add origin https://github.com/llvm/llvm-project.git + fi + + git fetch --depth 1 origin tag llvmorg-19.1.7 + git checkout "${LLVM_COMMIT}" + + - name: Build LLVM/MLIR (only if cache miss) + if: steps.detect-llvm-build.outputs.ready != 'true' + shell: bash + run: | + set -euo pipefail + cd llvm-project + cmake -G Ninja -S llvm -B llvm/build-shared \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DBUILD_SHARED_LIBS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=python3 \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_TARGETS_TO_BUILD="host" + + ninja -C llvm/build-shared + + - name: Build PTOAS + shell: bash + run: | + set -euo pipefail + export PYBIND11_CMAKE_DIR="$(python3 -m pybind11 --cmakedir)" + cmake -G Ninja -S . -B build \ + -DLLVM_DIR="${LLVM_DIR}/lib/cmake/llvm" \ + -DMLIR_DIR="${LLVM_DIR}/lib/cmake/mlir" \ + -DPython3_EXECUTABLE=python3 \ + -DPython3_FIND_STRATEGY=LOCATION \ + -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \ + -DCMAKE_BUILD_TYPE=Release + ninja -C build ptoas + + - name: Resolve simulator environment + shell: bash + run: | + set -euo pipefail + + detect_ascend_home() { + for d in \ + "${ASCEND_HOME_PATH:-}" \ + /usr/local/Ascend/cann \ + /usr/local/Ascend/cann-* \ + /usr/local/Ascend/ascend-toolkit/latest + do + [[ -n "${d}" && -d "${d}" ]] || continue + printf '%s\n' "${d}" + return 0 + done + return 1 + } + + ASCEND_HOME_PATH_DETECTED="$(detect_ascend_home || true)" + if [[ -z "${ASCEND_HOME_PATH_DETECTED}" ]]; then + echo "ERROR: failed to detect ASCEND_HOME_PATH on self-hosted runner" >&2 + exit 1 + fi + + echo "ASCEND_HOME_PATH=${ASCEND_HOME_PATH_DETECTED}" >> "${GITHUB_ENV}" + echo "PTOAS_BIN=${GITHUB_WORKSPACE}/build/tools/ptoas/ptoas" >> "${GITHUB_ENV}" + + - name: Run VPTO SIM validation + shell: bash + run: | + set -euo pipefail + mkdir -p "${VPTO_SIM_WORKSPACE}" + WORK_SPACE="${VPTO_SIM_WORKSPACE}" \ + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ + PTOAS_BIN="${PTOAS_BIN}" \ + DEVICE=SIM \ + JOBS="${JOBS:-32}" \ + bash test/vpto/scripts/run_host_vpto_validation_parallel.sh + + - name: Upload VPTO SIM logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: vpto-sim-validation-${{ github.run_id }} + path: | + ${{ env.VPTO_SIM_WORKSPACE }}/parallel-runner.log + ${{ env.VPTO_SIM_WORKSPACE }}/parallel-summary.tsv + if-no-files-found: warn + remote-npu-validation: needs: build-and-test runs-on: ubuntu-22.04 diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/compare.py new file mode 100755 index 000000000..68ceff820 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-bf16 +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-bf16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/golden.py new file mode 100755 index 000000000..c1d417ba0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/golden.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-bf16 +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-bf16, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + wide = values.astype(np.float32, copy=False).view(np.uint32) + rounding = np.uint32(0x7FFF) + ((wide >> 16) & np.uint32(1)) + return ((wide + rounding) >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray: + return (bits.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v2_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v1 = f32_to_bf16_bits(v1_f32) + v2 = f32_to_bf16_bits(v2_f32) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = f32_to_bf16_bits(bf16_bits_to_f32(v1) + bf16_bits_to_f32(v2)) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto new file mode 100644 index 000000000..f86357b6b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-bf16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadd_bf16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/launch.cpp new file mode 100644 index 000000000..13e50fe0b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-bf16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2, + __gm__ bfloat16_t *v3); + +void LaunchVadd_bf16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vadd_bf16_kernel<<<1, nullptr, stream>>>((__gm__ bfloat16_t *)v1, + (__gm__ bfloat16_t *)v2, + (__gm__ bfloat16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/main.cpp new file mode 100644 index 000000000..e130fecc0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-bf16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_bf16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_bf16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/stub.cpp new file mode 100644 index 000000000..d23dccb40 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-bf16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadd_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2, + __gm__ bfloat16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-f16/compare.py new file mode 100755 index 000000000..1254044fb --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-f16 +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-f16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float16, 5e-3, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-f16/golden.py new file mode 100755 index 000000000..442cc35e7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-f16 +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-f16, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v3 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v3 = (v1.astype(np.float32) + v2.astype(np.float32)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto new file mode 100644 index 000000000..7b382fad6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-f16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadd_f16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f16/launch.cpp new file mode 100644 index 000000000..8beb1a003 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-f16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3); + +void LaunchVadd_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vadd_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, (__gm__ half *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f16/main.cpp new file mode 100644 index 000000000..621cf398a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-f16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_f16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f16/stub.cpp new file mode 100644 index 000000000..f5d834b0e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-f16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadd_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/golden.py new file mode 100644 index 000000000..802880fdc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials_a = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + specials_b = np.array( + [np.inf, 2.5, 0.0, -0.0, -1.0, -np.inf, 1.0, np.nan], + dtype=np.float32, + ) + v1 = np.resize(specials_a, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.resize(specials_b, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = (v1 + v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto new file mode 100644 index 000000000..b514a3a81 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vadd_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/launch.cpp new file mode 100644 index 000000000..fbb0031f2 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_f32_exceptional_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vadd_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/main.cpp new file mode 100644 index 000000000..781e0d000 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_f32_exceptional_kernel_2d(float *v1, float *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_f32_exceptional_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/stub.cpp new file mode 100644 index 000000000..114ac5c35 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadd_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/compare.py new file mode 100644 index 000000000..fe6bc69c3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-signed-overflow +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-signed, full-mask, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/golden.py new file mode 100644 index 000000000..960e6d163 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-signed-overflow +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-signed, full-mask, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def wrap_add_i16(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray: + bits = lhs.view(np.uint16).astype(np.uint32) + rhs.view(np.uint16).astype(np.uint32) + return (bits & 0xFFFF).astype(np.uint16).view(np.int16) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + lhs_pattern = np.array( + [32767, 32760, -32768, -32760, 1000, -1000, 12345, -12345], + dtype=np.int16, + ) + rhs_pattern = np.array( + [1, 100, -1, -100, 30000, -30000, 23456, -23456], + dtype=np.int16, + ) + repeats = ELEMS // lhs_pattern.size + v1 = np.tile(lhs_pattern, repeats) + v2 = np.tile(rhs_pattern, repeats) + v3 = np.zeros(ELEMS, dtype=np.int16) + golden_v3 = wrap_add_i16(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto new file mode 100644 index 000000000..59770ba4b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed-overflow +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadd_i16_signed_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/launch.cpp new file mode 100644 index 000000000..7e3c8bb76 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/launch.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_i16_signed_overflow_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVadd_i16_signed_overflow_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream) { + vadd_i16_signed_overflow_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/main.cpp new file mode 100644 index 000000000..26f3895f8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_i16_signed_overflow_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_i16_signed_overflow_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/stub.cpp new file mode 100644 index 000000000..d5f243f4f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadd_i16_signed_overflow_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/compare.py new file mode 100755 index 000000000..2f49f90a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-signed +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-signed, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.int16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/golden.py new file mode 100755 index 000000000..38079c47f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-signed +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-signed, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-1000, 1001, size=(ROWS, COLS), dtype=np.int16) + v2 = rng.integers(-1000, 1001, size=(ROWS, COLS), dtype=np.int16) + v3 = np.zeros((ROWS, COLS), dtype=np.int16) + golden_v3 = (v1.astype(np.int32) + v2.astype(np.int32)).astype(np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto new file mode 100644 index 000000000..13302379f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadd_i16_signed_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/launch.cpp new file mode 100644 index 000000000..5f2e4f059 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVadd_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream) { + vadd_i16_signed_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/main.cpp new file mode 100644 index 000000000..4d22c989d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_i16_signed_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/stub.cpp new file mode 100644 index 000000000..1c27e339b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadd_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/compare.py new file mode 100644 index 000000000..4e992a275 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-unsigned-overflow +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-unsigned, full-mask, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/golden.py new file mode 100644 index 000000000..4673e2501 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-unsigned-overflow +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-unsigned, full-mask, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def wrap_add_u16(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray: + wide = lhs.astype(np.uint32) + rhs.astype(np.uint32) + return (wide & 0xFFFF).astype(np.uint16) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + lhs_pattern = np.array( + [65535, 65530, 65500, 60000, 100, 0, 32768, 12345], + dtype=np.uint16, + ) + rhs_pattern = np.array( + [1, 10, 1000, 10000, 65535, 5, 40000, 60000], + dtype=np.uint16, + ) + repeats = ELEMS // lhs_pattern.size + v1 = np.tile(lhs_pattern, repeats) + v2 = np.tile(rhs_pattern, repeats) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = wrap_add_u16(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto new file mode 100644 index 000000000..043d33afd --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned-overflow +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadd_i16_unsigned_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/launch.cpp new file mode 100644 index 000000000..bfd0fbe37 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/launch.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_i16_unsigned_overflow_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVadd_i16_unsigned_overflow_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vadd_i16_unsigned_overflow_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/main.cpp new file mode 100644 index 000000000..fb6fa53b2 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_i16_unsigned_overflow_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_i16_unsigned_overflow_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/stub.cpp new file mode 100644 index 000000000..f148bd883 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadd_i16_unsigned_overflow_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/compare.py new file mode 100755 index 000000000..29c833e93 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-unsigned +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-unsigned, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/golden.py new file mode 100755 index 000000000..fa3e8e0c1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-unsigned +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-unsigned, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 2001, size=(ROWS, COLS), dtype=np.uint16) + v2 = rng.integers(0, 2001, size=(ROWS, COLS), dtype=np.uint16) + v3 = np.zeros((ROWS, COLS), dtype=np.uint16) + golden_v3 = (v1.astype(np.uint32) + v2.astype(np.uint32)).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto new file mode 100644 index 000000000..a73b810d9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadd_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/launch.cpp new file mode 100644 index 000000000..c4198a017 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVadd_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vadd_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/main.cpp new file mode 100644 index 000000000..dd05d5051 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/stub.cpp new file mode 100644 index 000000000..f03220ca7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadd_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-tail/golden.py new file mode 100644 index 000000000..e967b1153 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] + v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto new file mode 100644 index 000000000..ec2fa5fd6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vadd_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-tail/launch.cpp new file mode 100644 index 000000000..3d1578331 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vadd_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-tail/main.cpp new file mode 100644 index 000000000..40a9881d6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-tail/stub.cpp new file mode 100644 index 000000000..fce6f90f8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd/golden.py new file mode 100644 index 000000000..fbf37245e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + golden_v3 = (v1 + v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto new file mode 100644 index 000000000..b74c45432 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto @@ -0,0 +1,55 @@ +module attributes {pto.target_arch = "a5"} { + func.func @add_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.get_buf "PIPE_MTE2", 0, 0 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.rls_buf "PIPE_MTE2", 0, 0 + pto.get_buf "PIPE_MTE2", 1, 0 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.rls_buf "PIPE_MTE2", 1, 0 + pto.get_buf "PIPE_V", 0, 0 + pto.get_buf "PIPE_V", 1, 0 + pto.get_buf "PIPE_V", 2, 0 + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + pto.rls_buf "PIPE_V", 0, 0 + pto.rls_buf "PIPE_V", 1, 0 + pto.rls_buf "PIPE_V", 2, 0 + + pto.get_buf "PIPE_MTE3", 2, 0 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.rls_buf "PIPE_MTE3", 2, 0 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd/launch.cpp new file mode 100644 index 000000000..7e1cfc9e0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void add_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchAdd_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + add_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd/main.cpp new file mode 100644 index 000000000..78517d1a3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchAdd_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchAdd_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vadd/stub.cpp new file mode 100644 index 000000000..63f9bd3c6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void add_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/compare.py b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/compare.py new file mode 100755 index 000000000..df15d65e4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vaddc-carry-boundary +# family: binary-vector +# target_ops: pto.vaddc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_carry(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_carry() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/golden.py b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/golden.py new file mode 100644 index 000000000..253c44d2c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/golden.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vaddc-carry-boundary +# family: binary-vector +# target_ops: pto.vaddc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros(LANES, dtype=np.uint32) + v2 = np.zeros(LANES, dtype=np.uint32) + pattern_lhs = np.array([0xFFFFFFFF, 0xFFFFFFFE, 0x80000000, 0x7FFFFFFF], dtype=np.uint32) + pattern_rhs = np.array([0x00000001, 0x00000002, 0x80000000, 0x00000001], dtype=np.uint32) + reps = LANES // pattern_lhs.size + v1[:] = np.tile(pattern_lhs, reps) + v2[:] = np.tile(pattern_rhs, reps) + total = v1.astype(np.uint64) + v2.astype(np.uint64) + result = (total & np.uint64(0xFFFFFFFF)).astype(np.uint32) + carry = (total >> np.uint64(32)) != 0 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(carry).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto new file mode 100644 index 000000000..4bab24ec8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc-carry-boundary +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vaddc_carry_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %sum, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %sum, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %carry, %ub_carry[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_carry, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/launch.cpp new file mode 100644 index 000000000..9c3f6d2c9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc-carry-boundary +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vaddc_carry_boundary_kernel_2d(__gm__ uint32_t *v1, __gm__ uint32_t *v2, + __gm__ uint32_t *v3, __gm__ uint8_t *v4); + +void LaunchVaddc_carry_boundary_kernel_2d(uint32_t *v1, uint32_t *v2, + uint32_t *v3, uint8_t *v4, + void *stream) { + vaddc_carry_boundary_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint32_t *)v2, (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/main.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/main.cpp new file mode 100644 index 000000000..486a7314a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc-carry-boundary +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaddc_carry_boundary_kernel_2d(uint32_t *v1, uint32_t *v2, + uint32_t *v3, uint8_t *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaddc_carry_boundary_kernel_2d(v1Device, v2Device, v3Device, v4Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/stub.cpp new file mode 100644 index 000000000..b6c51f4b0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc-carry-boundary +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vaddc_carry_boundary_kernel_2d(__gm__ uint32_t *v1, __gm__ uint32_t *v2, + __gm__ uint32_t *v3, __gm__ uint8_t *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/compare.py b/test/vpto/cases/micro-op/binary-vector/vaddc/compare.py new file mode 100755 index 000000000..af24f7730 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vaddc +# family: binary-vector +# target_ops: pto.vaddc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_carry(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_carry() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/golden.py b/test/vpto/cases/micro-op/binary-vector/vaddc/golden.py new file mode 100644 index 000000000..3fff82cbc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vaddc +# family: binary-vector +# target_ops: pto.vaddc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + v2 = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + total = v1.astype(np.uint64) + v2.astype(np.uint64) + result = (total & np.uint64(0xFFFFFFFF)).astype(np.uint32) + carry = (total >> np.uint64(32)) != 0 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(carry).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto new file mode 100644 index 000000000..dd7246774 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vaddc_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %c128_i64 = arith.constant 128 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr + %false = arith.constant false + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %sum, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %sum, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %carry, %ub_carry[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_carry, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc/launch.cpp new file mode 100644 index 000000000..21640d20f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vaddc_kernel_2d(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVaddc_kernel_2d(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream) { + vaddc_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/main.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc/main.cpp new file mode 100644 index 000000000..8cf91e047 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/main.cpp @@ -0,0 +1,117 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaddc_kernel_2d(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaddc_kernel_2d(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc/stub.cpp new file mode 100644 index 000000000..4194d42a1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/stub.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vaddc_kernel_2d(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3, + __gm__ uint8_t *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/compare.py b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/compare.py new file mode 100755 index 000000000..f42233bb4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vand-mask-edge +# family: binary-vector +# target_ops: pto.vand +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/golden.py b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/golden.py new file mode 100755 index 000000000..27a700901 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vand-mask-edge +# family: binary-vector +# target_ops: pto.vand +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + idx = np.arange(ELEMS, dtype=np.uint16) + v1 = np.where((idx & 1) == 0, np.uint16(0xAAAA), np.uint16(0x0F0F)).astype(np.uint16, copy=False) + v2 = np.where((idx & 2) == 0, np.uint16(0x5555), np.uint16(0x3333)).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_and(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto new file mode 100644 index 000000000..a474d3161 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand-mask-edge +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vand_mask_edge_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vand %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/launch.cpp new file mode 100644 index 000000000..3924e63d3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand-mask-edge +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vand_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVand_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vand_mask_edge_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/main.cpp b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/main.cpp new file mode 100644 index 000000000..eae6df992 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand-mask-edge +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVand_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVand_mask_edge_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/stub.cpp new file mode 100644 index 000000000..e23511e54 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand-mask-edge +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vand_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand/compare.py b/test/vpto/cases/micro-op/binary-vector/vand/compare.py new file mode 100755 index 000000000..28c2a232c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vand +# family: binary-vector +# target_ops: pto.vand +# scenarios: core-i16-unsigned, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vand/golden.py b/test/vpto/cases/micro-op/binary-vector/vand/golden.py new file mode 100755 index 000000000..a67709b57 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vand +# family: binary-vector +# target_ops: pto.vand +# scenarios: core-i16-unsigned, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_and(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto new file mode 100644 index 000000000..453445d96 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vand_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vand %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vand/launch.cpp new file mode 100644 index 000000000..3008d46bc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vand_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVand_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vand_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand/main.cpp b/test/vpto/cases/micro-op/binary-vector/vand/main.cpp new file mode 100644 index 000000000..958b4422e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVand_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVand_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vand/stub.cpp new file mode 100644 index 000000000..613febecb --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vand_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/compare.py b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/compare.py new file mode 100755 index 000000000..1de4f17b7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vdiv-f16 +# family: binary-vector +# target_ops: pto.vdiv +# scenarios: core-f16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float16, 5e-3, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/golden.py b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/golden.py new file mode 100755 index 000000000..627221d7a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vdiv-f16 +# family: binary-vector +# target_ops: pto.vdiv +# scenarios: core-f16, full-mask +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v2_mag = rng.uniform(0.5, 4.0, size=(ROWS, COLS)).astype(np.float32) + v2_sign = np.where(rng.integers(0, 2, size=(ROWS, COLS), dtype=np.int32) == 0, + np.float32(-1.0), np.float32(1.0)) + v2 = (v2_mag * v2_sign).astype(np.float16) + v3 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v3 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS].astype(np.float32) + / v2.reshape(-1)[:LOGICAL_ELEMS].astype(np.float32) + ).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto new file mode 100644 index 000000000..8f416095e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vdiv-f16 +// family: binary-vector +// target_ops: pto.vdiv +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vdiv_f16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %quot = pto.vdiv %lhs, %rhs, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %quot, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/launch.cpp new file mode 100644 index 000000000..1abdc6f6d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vdiv-f16 +// family: binary-vector +// target_ops: pto.vdiv +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdiv_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3); + +void LaunchVdiv_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vdiv_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, (__gm__ half *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/main.cpp new file mode 100644 index 000000000..d0b9cff9a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vdiv-f16 +// family: binary-vector +// target_ops: pto.vdiv +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdiv_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdiv_f16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/stub.cpp new file mode 100644 index 000000000..8fd0cb97f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vdiv-f16 +// family: binary-vector +// target_ops: pto.vdiv +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vdiv_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/compare.py b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/golden.py b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/golden.py new file mode 100644 index 000000000..9caa514c7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + numer = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + denom = np.array( + [2.0, -2.0, 0.0, -0.0, np.inf, 1.0, 1.0, np.nan], + dtype=np.float32, + ) + v1 = np.resize(numer, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.resize(denom, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.divide(v1, v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto new file mode 100644 index 000000000..bc4b2536f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vdiv_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %quot = pto.vdiv %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %quot, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/launch.cpp new file mode 100644 index 000000000..a68d4e95b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdiv_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVdiv_f32_exceptional_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vdiv_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/main.cpp new file mode 100644 index 000000000..d048e2faf --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdiv_f32_exceptional_kernel_2d(float *v1, float *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdiv_f32_exceptional_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/stub.cpp new file mode 100644 index 000000000..73737559f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vdiv_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/golden.py new file mode 100644 index 000000000..c010ada1f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + np.float32(0.5) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + np.float32(0.5) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] / v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto new file mode 100644 index 000000000..96b8c2fe6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vdiv_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %quot = pto.vdiv %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %quot, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/launch.cpp new file mode 100644 index 000000000..85826b5f8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdiv_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVdiv_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vdiv_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/main.cpp new file mode 100644 index 000000000..3f3dcf515 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdiv_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdiv_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/stub.cpp new file mode 100644 index 000000000..94d70b42e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vdiv_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/compare.py b/test/vpto/cases/micro-op/binary-vector/vdiv/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/golden.py b/test/vpto/cases/micro-op/binary-vector/vdiv/golden.py new file mode 100644 index 000000000..ea43aa613 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2_mag = rng.uniform(0.5, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2_sign = np.where(rng.integers(0, 2, size=(ROWS, COLS), dtype=np.int32) == 0, + np.float32(-1.0), np.float32(1.0)) + v2 = (v2_mag * v2_sign).astype(np.float32, copy=False) + golden_v3 = (v1 / v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto new file mode 100644 index 000000000..865320944 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @div_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %quot = pto.vdiv %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %quot, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv/launch.cpp new file mode 100644 index 000000000..fd82d2921 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void div_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchDiv_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + div_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/main.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv/main.cpp new file mode 100644 index 000000000..3972f99b3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchDiv_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchDiv_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv/stub.cpp new file mode 100644 index 000000000..f563b92ac --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void div_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vmax-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vmax-tail/golden.py new file mode 100644 index 000000000..82d3beb41 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = np.maximum( + v1.reshape(-1)[:LOGICAL_ELEMS], v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto new file mode 100644 index 000000000..d8d17ca9b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vmax_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %maxv = pto.vmax %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %maxv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmax-tail/launch.cpp new file mode 100644 index 000000000..bab607bfa --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmax_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vmax_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmax-tail/main.cpp new file mode 100644 index 000000000..40a9881d6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmax-tail/stub.cpp new file mode 100644 index 000000000..fc829f209 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmax_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/compare.py b/test/vpto/cases/micro-op/binary-vector/vmax/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/golden.py b/test/vpto/cases/micro-op/binary-vector/vmax/golden.py new file mode 100644 index 000000000..aca780439 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + golden_v3 = np.maximum(v1, v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto new file mode 100644 index 000000000..1e9948ca9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @max_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %maxv = pto.vmax %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %maxv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmax/launch.cpp new file mode 100644 index 000000000..44e917951 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void max_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMax_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + max_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmax/main.cpp new file mode 100644 index 000000000..9713fd509 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMax_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMax_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmax/stub.cpp new file mode 100644 index 000000000..2225c641b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void max_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/compare.py new file mode 100755 index 000000000..8e84eda9e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-bf16 +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-bf16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/golden.py new file mode 100755 index 000000000..a399eeb9d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/golden.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-bf16 +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-bf16, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + wide = values.astype(np.float32, copy=False).view(np.uint32) + rounding = np.uint32(0x7FFF) + ((wide >> 16) & np.uint32(1)) + return ((wide + rounding) >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray: + return (bits.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v2_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v1 = f32_to_bf16_bits(v1_f32) + v2 = f32_to_bf16_bits(v2_f32) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = f32_to_bf16_bits(np.minimum(bf16_bits_to_f32(v1), bf16_bits_to_f32(v2))) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto new file mode 100644 index 000000000..4af801098 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-bf16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vmin_bf16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/launch.cpp new file mode 100644 index 000000000..1f374fb36 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-bf16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2, + __gm__ bfloat16_t *v3); + +void LaunchVmin_bf16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, void *stream) { + vmin_bf16_kernel<<<1, nullptr, stream>>>((__gm__ bfloat16_t *)v1, + (__gm__ bfloat16_t *)v2, + (__gm__ bfloat16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/main.cpp new file mode 100644 index 000000000..01fb803f3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-bf16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_bf16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_bf16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/stub.cpp new file mode 100644 index 000000000..32d4583e7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-bf16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmin_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2, + __gm__ bfloat16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-f16/compare.py new file mode 100755 index 000000000..d4fe300db --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-f16 +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-f16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float16, 5e-3, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-f16/golden.py new file mode 100755 index 000000000..4ee39a4b3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-f16 +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-f16, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v3 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v3 = np.minimum(v1.astype(np.float32), v2.astype(np.float32)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto new file mode 100644 index 000000000..2654e8d7e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-f16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vmin_f16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f16/launch.cpp new file mode 100644 index 000000000..5151be743 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-f16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3); + +void LaunchVmin_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, void *stream) { + vmin_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f16/main.cpp new file mode 100644 index 000000000..374d8a322 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-f16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_f16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f16/stub.cpp new file mode 100644 index 000000000..bcbc00949 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-f16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmin_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/golden.py new file mode 100644 index 000000000..4d8d2f34a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + lhs = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + rhs = np.array( + [np.inf, -2.5, 0.0, -0.0, -1.0, 1.0, 1.0, np.nan], + dtype=np.float32, + ) + v1 = np.resize(lhs, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.resize(rhs, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.minimum(v1, v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto new file mode 100644 index 000000000..df1987955 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @min_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/launch.cpp new file mode 100644 index 000000000..f2c64c6a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void min_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + min_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/main.cpp new file mode 100644 index 000000000..b952b76a0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMin_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/stub.cpp new file mode 100644 index 000000000..8a84a8941 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void min_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/golden.py new file mode 100644 index 000000000..6d18ab792 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + golden_v3 = np.minimum(v1, v2) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.astype(np.float32, copy=False).reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto new file mode 100644 index 000000000..d896ab4e8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @min_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %minv = pto.vmin %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %minv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/launch.cpp new file mode 100644 index 000000000..f2c64c6a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void min_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + min_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/main.cpp new file mode 100644 index 000000000..b952b76a0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMin_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/stub.cpp new file mode 100644 index 000000000..8a84a8941 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void min_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/compare.py new file mode 100755 index 000000000..2afc3f8ec --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-i16-signed +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-i16-signed, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.int16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/golden.py new file mode 100755 index 000000000..48ce71042 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-i16-signed +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-i16-signed, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-1000, 1001, size=ELEMS, dtype=np.int16) + v2 = rng.integers(-1000, 1001, size=ELEMS, dtype=np.int16) + v3 = np.zeros(ELEMS, dtype=np.int16) + golden_v3 = np.minimum(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto new file mode 100644 index 000000000..19c3779ec --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-signed +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vmin_i16_signed_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/launch.cpp new file mode 100644 index 000000000..923e415d4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-signed +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVmin_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream) { + vmin_i16_signed_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/main.cpp new file mode 100644 index 000000000..029455a99 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-signed +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_i16_signed_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/stub.cpp new file mode 100644 index 000000000..786ae97ef --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-signed +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmin_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/compare.py new file mode 100755 index 000000000..f87d0f17d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-i16-unsigned +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-i16-unsigned, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/golden.py new file mode 100755 index 000000000..7ac5b68a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-i16-unsigned +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-i16-unsigned, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 2001, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 2001, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.minimum(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto new file mode 100644 index 000000000..33f972373 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-unsigned +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vmin_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/launch.cpp new file mode 100644 index 000000000..6cc3d692c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-unsigned +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVmin_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vmin_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/main.cpp new file mode 100644 index 000000000..885dea67a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-unsigned +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/stub.cpp new file mode 100644 index 000000000..4429f6df6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-unsigned +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmin_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-tail/golden.py new file mode 100644 index 000000000..29bbdcd28 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = np.minimum( + v1.reshape(-1)[:LOGICAL_ELEMS], v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto new file mode 100644 index 000000000..8aade9361 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vmin_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-tail/launch.cpp new file mode 100644 index 000000000..f2a890c47 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVmin_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vmin_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-tail/main.cpp new file mode 100644 index 000000000..5a418b3da --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-tail/stub.cpp new file mode 100644 index 000000000..9e1f3ba22 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmin_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin/golden.py new file mode 100644 index 000000000..6d18ab792 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + golden_v3 = np.minimum(v1, v2) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.astype(np.float32, copy=False).reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto new file mode 100644 index 000000000..d896ab4e8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @min_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %minv = pto.vmin %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %minv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin/launch.cpp new file mode 100644 index 000000000..f2c64c6a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void min_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + min_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin/main.cpp new file mode 100644 index 000000000..b952b76a0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMin_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmin/stub.cpp new file mode 100644 index 000000000..8a84a8941 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void min_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vmul-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vmul-tail/golden.py new file mode 100644 index 000000000..553faae15 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] * v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto new file mode 100644 index 000000000..a7eccda6f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vmul_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %prod = pto.vmul %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %prod, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmul-tail/launch.cpp new file mode 100644 index 000000000..6e1ab54ae --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmul_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vmul_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmul-tail/main.cpp new file mode 100644 index 000000000..40a9881d6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmul-tail/stub.cpp new file mode 100644 index 000000000..fa4755459 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmul_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/compare.py b/test/vpto/cases/micro-op/binary-vector/vmul/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/golden.py b/test/vpto/cases/micro-op/binary-vector/vmul/golden.py new file mode 100644 index 000000000..23e4731e7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + golden_v3 = (v1 * v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto new file mode 100644 index 000000000..5763169d7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @mul_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %prod = pto.vmul %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %prod, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmul/launch.cpp new file mode 100644 index 000000000..21ee4384c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mul_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMul_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + mul_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmul/main.cpp new file mode 100644 index 000000000..711269796 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMul_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMul_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vmul/stub.cpp new file mode 100644 index 000000000..e8b456fea --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void mul_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/compare.py b/test/vpto/cases/micro-op/binary-vector/vor-f16/compare.py new file mode 100755 index 000000000..78bd43ef6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor-f16 +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-f16, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/golden.py b/test/vpto/cases/micro-op/binary-vector/vor-f16/golden.py new file mode 100755 index 000000000..471da5094 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/golden.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor-f16 +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-f16, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + bits1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + bits2 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + bits1[:8] = np.array( + [0x0000, 0x8000, 0x3c00, 0xbc00, 0x7c00, 0xfc00, 0x7e00, 0x3555], + dtype=np.uint16, + ) + bits2[:8] = np.array( + [0x0001, 0x0001, 0x4000, 0x2000, 0x0001, 0x0001, 0x0100, 0x0aaa], + dtype=np.uint16, + ) + v1 = bits1.view(np.float16) + v2 = bits2.view(np.float16) + v3 = np.zeros(ELEMS, dtype=np.float16) + golden_v3 = np.bitwise_or(bits1, bits2).view(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto new file mode 100644 index 000000000..4a92da608 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-f16 +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vor_f16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vor %lhs, %rhs, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vor-f16/launch.cpp new file mode 100644 index 000000000..45b41406e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-f16 +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vor_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3); + +void LaunchVor_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vor_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, (__gm__ half *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vor-f16/main.cpp new file mode 100644 index 000000000..826735922 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-f16 +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVor_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVor_f16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vor-f16/stub.cpp new file mode 100644 index 000000000..6153f9277 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-f16 +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vor_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/compare.py b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/compare.py new file mode 100755 index 000000000..58ac20c66 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor-mask-edge +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/golden.py b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/golden.py new file mode 100755 index 000000000..3c28fa036 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor-mask-edge +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + idx = np.arange(ELEMS, dtype=np.uint16) + v1 = np.where((idx & 1) == 0, np.uint16(0xAAAA), np.uint16(0x0F0F)).astype(np.uint16, copy=False) + v2 = np.where((idx & 2) == 0, np.uint16(0x5555), np.uint16(0x3333)).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_or(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto new file mode 100644 index 000000000..49552af6d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-mask-edge +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vor_mask_edge_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vor %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/launch.cpp new file mode 100644 index 000000000..c9e9411c6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-mask-edge +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vor_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVor_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vor_mask_edge_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/main.cpp b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/main.cpp new file mode 100644 index 000000000..634ad7664 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-mask-edge +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVor_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVor_mask_edge_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/stub.cpp new file mode 100644 index 000000000..ac5a8596d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-mask-edge +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vor_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor/compare.py b/test/vpto/cases/micro-op/binary-vector/vor/compare.py new file mode 100755 index 000000000..8b38a30b8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor/golden.py b/test/vpto/cases/micro-op/binary-vector/vor/golden.py new file mode 100755 index 000000000..c0d7ce117 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_or(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto new file mode 100644 index 000000000..97fd111c0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vor_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vor %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vor/launch.cpp new file mode 100644 index 000000000..416e5354f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vor_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVor_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vor_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor/main.cpp b/test/vpto/cases/micro-op/binary-vector/vor/main.cpp new file mode 100644 index 000000000..0ebb0d781 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVor_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVor_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vor/stub.cpp new file mode 100644 index 000000000..287e43ef5 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vor_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/compare.py b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/compare.py new file mode 100755 index 000000000..fcf304f6f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl-i32-unsigned +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i32-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint32, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/golden.py b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/golden.py new file mode 100755 index 000000000..cefd36ee1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl-i32-unsigned +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i32-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 1 << 32, size=ELEMS, dtype=np.uint32) + v2 = rng.integers(0, 32, size=ELEMS, dtype=np.uint32) + v3 = np.zeros(ELEMS, dtype=np.uint32) + golden_v3 = np.left_shift(v1, v2 & np.uint32(31)).astype(np.uint32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto new file mode 100644 index 000000000..85112255a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-i32-unsigned +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i32-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshl_i32_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xui32> + %out = pto.vshl %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/launch.cpp new file mode 100644 index 000000000..ba6ef9dca --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-i32-unsigned +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i32-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshl_i32_unsigned_kernel(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3); + +void LaunchVshl_i32_unsigned_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + void *stream) { + vshl_i32_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/main.cpp new file mode 100644 index 000000000..df4bc2e9a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-i32-unsigned +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i32-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshl_i32_unsigned_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshl_i32_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/stub.cpp new file mode 100644 index 000000000..01ecbcf7c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-i32-unsigned +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i32-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshl_i32_unsigned_kernel(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/compare.py b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/compare.py new file mode 100755 index 000000000..2ef28c2cf --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl-shift-boundary +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/golden.py b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/golden.py new file mode 100755 index 000000000..15261e271 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl-shift-boundary +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 1 << 16, size=ELEMS, dtype=np.uint16) + shift_cycle = np.array([0, 1, 14, 15, 15, 14, 1, 0], dtype=np.uint16) + v2 = np.resize(shift_cycle, ELEMS).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.left_shift(v1, v2 & np.uint16(15)).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto new file mode 100644 index 000000000..42d07f0ec --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-shift-boundary +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshl_shift_boundary_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vshl %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/launch.cpp new file mode 100644 index 000000000..bbf9c75f5 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-shift-boundary +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshl_shift_boundary_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVshl_shift_boundary_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vshl_shift_boundary_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/main.cpp new file mode 100644 index 000000000..13d114f38 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-shift-boundary +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshl_shift_boundary_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshl_shift_boundary_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/stub.cpp new file mode 100644 index 000000000..7051b1442 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-shift-boundary +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshl_shift_boundary_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/compare.py b/test/vpto/cases/micro-op/binary-vector/vshl/compare.py new file mode 100755 index 000000000..7006bca77 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/golden.py b/test/vpto/cases/micro-op/binary-vector/vshl/golden.py new file mode 100755 index 000000000..ed6ca93bc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 16, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.left_shift(v1, v2 & np.uint16(15)).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto new file mode 100644 index 000000000..c21d9eadf --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshl_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vshl %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshl/launch.cpp new file mode 100644 index 000000000..3abb0a4bf --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshl_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVshl_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vshl_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshl/main.cpp new file mode 100644 index 000000000..ce45e4cf1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshl_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshl_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vshl/stub.cpp new file mode 100644 index 000000000..7ddcaf5c8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshl_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/compare.py b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/compare.py new file mode 100755 index 000000000..a8ee34d1b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr-i16-signed +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-signed, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.int16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/golden.py b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/golden.py new file mode 100755 index 000000000..4262e419f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr-i16-signed +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-signed, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-0x8000, 0x8000, size=ELEMS, dtype=np.int16) + v2 = rng.integers(0, 16, size=ELEMS, dtype=np.int16) + v3 = np.zeros(ELEMS, dtype=np.int16) + golden_v3 = np.right_shift(v1, v2 & np.int16(15)).astype(np.int16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto new file mode 100644 index 000000000..459959491 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-i16-signed +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshr_i16_signed_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xsi16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xsi16> + %out = pto.vshr %lhs, %rhs, %mask : !pto.vreg<128xsi16>, !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xsi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xsi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/launch.cpp new file mode 100644 index 000000000..6b62ae139 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-i16-signed +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshr_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVshr_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream) { + vshr_i16_signed_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/main.cpp new file mode 100644 index 000000000..d1e3520e4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-i16-signed +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshr_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshr_i16_signed_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/stub.cpp new file mode 100644 index 000000000..2e4da3914 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-i16-signed +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshr_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/compare.py b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/compare.py new file mode 100755 index 000000000..f5e74791e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr-shift-boundary +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/golden.py b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/golden.py new file mode 100755 index 000000000..6c70bedba --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr-shift-boundary +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 1 << 16, size=ELEMS, dtype=np.uint16) + shift_cycle = np.array([0, 1, 14, 15, 15, 14, 1, 0], dtype=np.uint16) + v2 = np.resize(shift_cycle, ELEMS).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.right_shift(v1, v2 & np.uint16(15)).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto new file mode 100644 index 000000000..b20001bda --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-shift-boundary +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshr_shift_boundary_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vshr %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/launch.cpp new file mode 100644 index 000000000..a24d8e08c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-shift-boundary +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshr_shift_boundary_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVshr_shift_boundary_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vshr_shift_boundary_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/main.cpp new file mode 100644 index 000000000..e8d1e1459 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-shift-boundary +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshr_shift_boundary_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshr_shift_boundary_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/stub.cpp new file mode 100644 index 000000000..08e67f1d9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-shift-boundary +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshr_shift_boundary_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/compare.py b/test/vpto/cases/micro-op/binary-vector/vshr/compare.py new file mode 100755 index 000000000..a2429ec9b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/golden.py b/test/vpto/cases/micro-op/binary-vector/vshr/golden.py new file mode 100755 index 000000000..bd0cda8b9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 16, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.right_shift(v1, v2 & np.uint16(15)).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto new file mode 100644 index 000000000..2ce05e8be --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshr_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vshr %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshr/launch.cpp new file mode 100644 index 000000000..08208c24c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshr_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVshr_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vshr_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshr/main.cpp new file mode 100644 index 000000000..fcdf7cbc0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshr_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshr_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vshr/stub.cpp new file mode 100644 index 000000000..26832ee79 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshr_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vsub-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vsub-tail/golden.py new file mode 100644 index 000000000..954e00c9b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] - v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto new file mode 100644 index 000000000..f124da2d7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vsub_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %diff = pto.vsub %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %diff, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vsub-tail/launch.cpp new file mode 100644 index 000000000..01f113a97 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsub_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vsub_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vsub-tail/main.cpp new file mode 100644 index 000000000..40a9881d6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vsub-tail/stub.cpp new file mode 100644 index 000000000..f8bd5e2b1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vsub_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/compare.py b/test/vpto/cases/micro-op/binary-vector/vsub/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/golden.py b/test/vpto/cases/micro-op/binary-vector/vsub/golden.py new file mode 100644 index 000000000..2f3f82fe6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + golden_v3 = (v1 - v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto new file mode 100644 index 000000000..1721e0273 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @sub_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %diff = pto.vsub %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %diff, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vsub/launch.cpp new file mode 100644 index 000000000..daeaeb5de --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void sub_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchSub_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + sub_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/main.cpp b/test/vpto/cases/micro-op/binary-vector/vsub/main.cpp new file mode 100644 index 000000000..0c7c8359a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchSub_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchSub_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vsub/stub.cpp new file mode 100644 index 000000000..6b357ffb5 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void sub_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/compare.py b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/compare.py new file mode 100755 index 000000000..67df8a750 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/compare.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vsubc-borrow-boundary +# family: binary-vector +# target_ops: pto.vsubc +# scenarios: core-u32-unsigned, full-mask, carry-chain +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_borrow(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_borrow() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/golden.py b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/golden.py new file mode 100755 index 000000000..cf6db0014 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vsubc-borrow-boundary +# family: binary-vector +# target_ops: pto.vsubc +# scenarios: core-u32-unsigned, full-mask, carry-chain +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros(LANES, dtype=np.uint32) + v2 = np.zeros(LANES, dtype=np.uint32) + pattern_lhs = np.array([0x00000000, 0x00000001, 0x7FFFFFFF, 0x80000000], dtype=np.uint32) + pattern_rhs = np.array([0x00000001, 0x00000002, 0x80000000, 0xFFFFFFFF], dtype=np.uint32) + reps = LANES // pattern_lhs.size + v1[:] = np.tile(pattern_lhs, reps) + v2[:] = np.tile(pattern_rhs, reps) + no_borrow = v1 >= v2 + result = (v1 - v2).astype(np.uint32, copy=False) + packed = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(no_borrow): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + packed[byte] |= np.uint8(0x1) + else: + packed[byte] |= np.uint8(0x10) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + packed.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto new file mode 100644 index 000000000..aaeabd1d4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc-borrow-boundary +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsubc_borrow_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %diff, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %diff, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %borrow, %ub_borrow[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_borrow, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/launch.cpp new file mode 100644 index 000000000..972a52f2d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc-borrow-boundary +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vsubc_borrow_boundary_kernel_2d(__gm__ uint32_t *v1, __gm__ uint32_t *v2, + __gm__ uint32_t *v3, __gm__ uint8_t *v4); + +void LaunchVsubc_borrow_boundary_kernel_2d(uint32_t *v1, uint32_t *v2, + uint32_t *v3, uint8_t *v4, + void *stream) { + vsubc_borrow_boundary_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint32_t *)v2, (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/main.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/main.cpp new file mode 100644 index 000000000..43b2eb292 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc-borrow-boundary +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsubc_borrow_boundary_kernel_2d(uint32_t *v1, uint32_t *v2, + uint32_t *v3, uint8_t *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsubc_borrow_boundary_kernel_2d(v1Device, v2Device, v3Device, v4Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/stub.cpp new file mode 100644 index 000000000..17b14c10b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc-borrow-boundary +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vsubc_borrow_boundary_kernel_2d(__gm__ uint32_t *v1, __gm__ uint32_t *v2, + __gm__ uint32_t *v3, __gm__ uint8_t *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/compare.py b/test/vpto/cases/micro-op/binary-vector/vsubc/compare.py new file mode 100755 index 000000000..f68c8267e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vsubc +# family: binary-vector +# target_ops: pto.vsubc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_borrow(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_borrow() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/golden.py b/test/vpto/cases/micro-op/binary-vector/vsubc/golden.py new file mode 100755 index 000000000..1b647ac6c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/golden.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vsubc +# family: binary-vector +# target_ops: pto.vsubc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + v2 = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + diff = (v1 - v2).astype(np.uint32, copy=False) + no_borrow = v1 >= v2 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + diff.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(no_borrow).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto new file mode 100644 index 000000000..7a7dc3720 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-i16-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsubc_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %diff, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %diff, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %borrow, %ub_borrow[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_borrow, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc/launch.cpp new file mode 100644 index 000000000..4f47cec25 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsubc_kernel_2d(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVsubc_kernel_2d(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream) { + vsubc_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/main.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc/main.cpp new file mode 100644 index 000000000..a553603b2 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/main.cpp @@ -0,0 +1,117 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsubc_kernel_2d(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsubc_kernel_2d(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc/stub.cpp new file mode 100644 index 000000000..ab7af88e0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/stub.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vsubc_kernel_2d(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3, + __gm__ uint8_t *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/compare.py b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/compare.py new file mode 100755 index 000000000..e8a187bb2 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vxor-mask-edge +# family: binary-vector +# target_ops: pto.vxor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/golden.py b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/golden.py new file mode 100755 index 000000000..0da3cc44d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vxor-mask-edge +# family: binary-vector +# target_ops: pto.vxor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + idx = np.arange(ELEMS, dtype=np.uint16) + v1 = np.where((idx & 1) == 0, np.uint16(0xAAAA), np.uint16(0x0F0F)).astype(np.uint16, copy=False) + v2 = np.where((idx & 2) == 0, np.uint16(0x5555), np.uint16(0x3333)).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_xor(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto new file mode 100644 index 000000000..1e9ff40bb --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor-mask-edge +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vxor_mask_edge_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vxor %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/launch.cpp new file mode 100644 index 000000000..309646298 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor-mask-edge +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vxor_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVxor_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vxor_mask_edge_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/main.cpp b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/main.cpp new file mode 100644 index 000000000..95d478f5f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor-mask-edge +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVxor_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVxor_mask_edge_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/stub.cpp new file mode 100644 index 000000000..6a158648e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor-mask-edge +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vxor_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/compare.py b/test/vpto/cases/micro-op/binary-vector/vxor/compare.py new file mode 100755 index 000000000..cd8e1ce3e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vxor +# family: binary-vector +# target_ops: pto.vxor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/golden.py b/test/vpto/cases/micro-op/binary-vector/vxor/golden.py new file mode 100755 index 000000000..f5e328e08 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vxor +# family: binary-vector +# target_ops: pto.vxor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_xor(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto new file mode 100644 index 000000000..cf456ee72 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vxor_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vxor %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vxor/launch.cpp new file mode 100644 index 000000000..59d1c049b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vxor_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVxor_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vxor_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/main.cpp b/test/vpto/cases/micro-op/binary-vector/vxor/main.cpp new file mode 100644 index 000000000..99f4a9d98 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVxor_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVxor_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/stub.cpp b/test/vpto/cases/micro-op/binary-vector/vxor/stub.cpp new file mode 100644 index 000000000..772a434c4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vxor_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-eq/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-eq/golden.py new file mode 100644 index 000000000..4d075f357 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + v2 = v1.copy() + mismatch = (np.arange(LANES, dtype=np.int32) % 3) == 1 + v2[mismatch] = (v2[mismatch] + np.float32(1.25)).astype(np.float32) + + mask = np.equal(v1, v2) + golden = encode_b32_mask(mask) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-eq.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto new file mode 100644 index 000000000..e8fcf175f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmp_eq_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "eq" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-eq/launch.cpp new file mode 100644 index 000000000..aedecd7b5 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_eq_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-eq/main.cpp new file mode 100644 index 000000000..15cae7173 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_eq_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-eq/stub.cpp new file mode 100644 index 000000000..ff0d95de0 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/golden.py new file mode 100644 index 000000000..3d3dbb4d8 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + lhs = np.array( + [-np.inf, -3.0, -0.0, 0.0, 0.5, np.inf, np.nan, 1.0], + dtype=np.float32, + ) + rhs = np.array( + [np.inf, -2.0, 0.0, -0.0, 0.5, np.nan, 1.0, -np.inf], + dtype=np.float32, + ) + v1 = np.resize(lhs, LANES).astype(np.float32) + v2 = np.resize(rhs, LANES).astype(np.float32) + mask = np.less(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-f32-exceptional.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto new file mode 100644 index 000000000..66751d6ec --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmp_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/launch.cpp new file mode 100644 index 000000000..97d79d2fd --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_f32_exceptional_kernel_2d(float *v1, float *v2, + unsigned char *v3, void *stream) { + vcmp_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/main.cpp new file mode 100644 index 000000000..b8b92375f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_f32_exceptional_kernel_2d(float *v1, float *v2, + unsigned char *v3, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_f32_exceptional_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/stub.cpp new file mode 100644 index 000000000..d5b0ccf61 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmp_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/golden.py new file mode 100644 index 000000000..08c48389e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b16_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 4 + pair_shift = 2 * (i % 4) + out[byte_index] |= np.uint8(1 << pair_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-1200, 1200, size=(LANES,), dtype=np.int16) + v2 = v1.copy() + mismatch = (np.arange(LANES, dtype=np.int32) % 3) == 1 + v2[mismatch] = (v2[mismatch] + np.int16(7)).astype(np.int16) + mask = np.equal(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b16_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-i16-signed.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto new file mode 100644 index 000000000..df5c04fe9 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmp_eq_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i32 = arith.constant 128 : i32 + %c32_i64_data = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %pred = pto.vcmp %lhs, %rhs, %active, "eq" : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64_data, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/launch.cpp new file mode 100644 index 000000000..0af8b16de --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/launch.cpp @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-signed +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_eq_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/main.cpp new file mode 100644 index 000000000..0857d2e9e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/main.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-signed +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(short); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(short); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + short *v1Host = nullptr; + short *v1Device = nullptr; + short *v2Host = nullptr; + short *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_eq_kernel_2d(reinterpret_cast(v1Device), + reinterpret_cast(v2Device), v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/stub.cpp new file mode 100644 index 000000000..fb202c4c4 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-signed +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/golden.py new file mode 100644 index 000000000..1742febf1 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b16_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 4 + pair_shift = 2 * (i % 4) + out[byte_index] |= np.uint8(1 << pair_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 60000, size=(LANES,), dtype=np.uint16) + v2 = v1.copy() + mismatch = (np.arange(LANES, dtype=np.int32) % 3) == 1 + v2[mismatch] = (v2[mismatch] + np.uint16(7)).astype(np.uint16) + mask = np.equal(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b16_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-i16-unsigned.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto new file mode 100644 index 000000000..8dfa3cf79 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmp_eq_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i32 = arith.constant 128 : i32 + %c32_i64_data = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<128xui16> + %pred = pto.vcmp %lhs, %rhs, %active, "eq" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64_data, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/launch.cpp new file mode 100644 index 000000000..e1199bd24 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/launch.cpp @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-unsigned +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_eq_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/main.cpp new file mode 100644 index 000000000..8e5b3e1d4 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/main.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-unsigned +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(unsigned short); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned short); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + unsigned short *v1Host = nullptr; + unsigned short *v1Device = nullptr; + unsigned short *v2Host = nullptr; + unsigned short *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_eq_kernel_2d(reinterpret_cast(v1Device), + reinterpret_cast(v2Device), v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/stub.cpp new file mode 100644 index 000000000..8c1a42dab --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-unsigned +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-lt/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-lt/golden.py new file mode 100644 index 000000000..6feb1da41 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + delta = rng.uniform(0.25, 1.25, size=(LANES,)).astype(np.float32) + choose_less = (np.arange(LANES, dtype=np.int32) % 2) == 0 + v2 = np.where(choose_less, v1 + delta, v1 - delta).astype(np.float32) + mask = np.less(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-lt.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto new file mode 100644 index 000000000..fbe299c4b --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmp_lt_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-lt/launch.cpp new file mode 100644 index 000000000..2762499e2 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_lt_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_lt_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_lt_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-lt/main.cpp new file mode 100644 index 000000000..fa06a715f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_lt_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_lt_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-lt/stub.cpp new file mode 100644 index 000000000..85314f5ef --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmp_lt_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-tail/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-tail/golden.py new file mode 100644 index 000000000..f59aded57 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/golden.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +LOGICAL_ELEMS = 53 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-6.0, 6.0, size=(LANES,)).astype(np.float32) + delta = rng.uniform(0.1, 2.0, size=(LANES,)).astype(np.float32) + mode = np.arange(LANES, dtype=np.int32) % 5 + + v2 = np.empty((LANES,), dtype=np.float32) + v2[mode == 0] = v1[mode == 0] + delta[mode == 0] + v2[mode == 1] = v1[mode == 1] - delta[mode == 1] + v2[mode == 2] = v1[mode == 2] + v2[mode == 3] = np.nextafter(v1[mode == 3], np.float32(np.inf)) + v2[mode == 4] = np.nextafter(v1[mode == 4], np.float32(-np.inf)) + + v1[:10] = np.array([-3.0, -1.0, -0.0, 0.0, 0.25, 1.0, 2.0, 4.0, -4.0, 6.0], dtype=np.float32) + v2[:10] = np.array([ + -2.0, + -2.0, + 0.0, + np.nextafter(np.float32(0.0), np.float32(np.inf)), + 0.25, + np.nextafter(np.float32(1.0), np.float32(-np.inf)), + 3.0, + 3.0, + np.nextafter(np.float32(-4.0), np.float32(np.inf)), + 6.0, + ], dtype=np.float32) + + mask = np.less(v1, v2) + mask[LOGICAL_ELEMS:] = False + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-tail.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto new file mode 100644 index 000000000..610422e29 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmp_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c53_i32 = arith.constant 53 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c53_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-tail/launch.cpp new file mode 100644 index 000000000..c57830b58 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_tail_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-tail/main.cpp new file mode 100644 index 000000000..ee8661a62 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_tail_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-tail/stub.cpp new file mode 100644 index 000000000..b68dd2f4a --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmp_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/golden.py new file mode 100644 index 000000000..d2ca06dc2 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +THRESHOLD = np.float32(0.5) +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -1.0, -0.0, 0.0, 0.5, 0.75, np.inf, np.nan], + dtype=np.float32, + ) + v1 = np.resize(specials, LANES).astype(np.float32) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-f32-exceptional.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto new file mode 100644 index 000000000..fbd52f20f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmps_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %threshold = arith.constant 5.000000e-01 : f32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/launch.cpp new file mode 100644 index 000000000..e96d87fec --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_f32_exceptional_kernel_2d(float *v1, unsigned char *v2, + void *stream) { + vcmps_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/main.cpp new file mode 100644 index 000000000..d8d7a33b6 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_f32_exceptional_kernel_2d(float *v1, unsigned char *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/stub.cpp new file mode 100644 index 000000000..dd061a1fb --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmps_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-f32/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-f32/golden.py new file mode 100644 index 000000000..7224594ef --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +THRESHOLD = np.float32(0.5) +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-2.0, 2.0, size=(LANES,)).astype(np.float32) + v1[:8] = np.array([0.5, 0.5001, 0.4999, -0.5, 1.0, -1.0, 0.0, 2.0], + dtype=np.float32) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-f32.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto new file mode 100644 index 000000000..3be90dca5 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmps_f32_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %threshold = arith.constant 5.000000e-01 : f32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32/launch.cpp new file mode 100644 index 000000000..49dfceec7 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_f32_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_f32_kernel_2d(float *v1, unsigned char *v2, void *stream) { + vcmps_f32_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32/main.cpp new file mode 100644 index 000000000..e9c28d290 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_f32_kernel_2d(float *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_f32_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32/stub.cpp new file mode 100644 index 000000000..7d7eb253f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmps_f32_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/golden.py new file mode 100644 index 000000000..1a4865835 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 +THRESHOLD = np.int16(5) +OUTPUT_BYTES = 32 + + +def encode_b16_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 4 + bit_shift = 2 * (i % 4) + out[byte_index] |= np.uint8(1 << bit_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-32768, 32767, size=(LANES,), dtype=np.int16) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b16_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-i16-signed.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto new file mode 100644 index 000000000..c1f08269e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmps_i16_signed_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c128_i32 = arith.constant 128 : i32 + %threshold = arith.constant 5 : i16 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/launch.cpp new file mode 100644 index 000000000..6fd05c86e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_i16_signed_kernel_2d(__gm__ int16_t *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_i16_signed_kernel_2d(int16_t *v1, unsigned char *v2, void *stream) { + vcmps_i16_signed_kernel_2d<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/main.cpp new file mode 100644 index 000000000..2d318f173 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_i16_signed_kernel_2d(int16_t *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_i16_signed_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/stub.cpp new file mode 100644 index 000000000..a265a046b --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmps_i16_signed_kernel_2d(__gm__ int16_t *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/golden.py new file mode 100644 index 000000000..b7318cc9f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 +THRESHOLD = np.uint16(513) +OUTPUT_BYTES = 32 + + +def encode_b16_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 4 + bit_shift = 2 * (i % 4) + out[byte_index] |= np.uint8(1 << bit_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 65535, size=(LANES,), dtype=np.uint16) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b16_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-i16-unsigned.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto new file mode 100644 index 000000000..dc2864f86 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmps_i16_unsigned_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c128_i32 = arith.constant 128 : i32 + %threshold_i16 = arith.constant 513 : i16 + %threshold = builtin.unrealized_conversion_cast %threshold_i16 : i16 to ui16 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<128xui16> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/launch.cpp new file mode 100644 index 000000000..2178b15fa --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_i16_unsigned_kernel_2d(__gm__ uint16_t *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_i16_unsigned_kernel_2d(uint16_t *v1, unsigned char *v2, void *stream) { + vcmps_i16_unsigned_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/main.cpp new file mode 100644 index 000000000..717dc9332 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_i16_unsigned_kernel_2d(uint16_t *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_i16_unsigned_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/stub.cpp new file mode 100644 index 000000000..45e1a4015 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmps_i16_unsigned_kernel_2d(__gm__ uint16_t *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-tail/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-tail/golden.py new file mode 100644 index 000000000..e36631b9a --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/golden.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +LOGICAL_ELEMS = 40 +SEED = 19 +THRESHOLD = np.float32(0.5) +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-2.0, 2.0, size=(LANES,)).astype(np.float32) + + v1[:12] = np.array([ + THRESHOLD, + np.nextafter(THRESHOLD, np.float32(np.inf)), + np.nextafter(THRESHOLD, np.float32(-np.inf)), + 0.0, + -0.0, + -1.0, + 1.0, + 2.0, + -2.0, + THRESHOLD + np.float32(0.25), + THRESHOLD - np.float32(0.25), + THRESHOLD, + ], dtype=np.float32) + + mask = np.greater(v1, THRESHOLD) + mask[LOGICAL_ELEMS:] = False + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-tail.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto new file mode 100644 index 000000000..342605924 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmps_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c40_i32 = arith.constant 40 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %threshold = arith.constant 5.000000e-01 : f32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c40_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-tail/launch.cpp new file mode 100644 index 000000000..a210fc2fa --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_tail_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_tail_kernel_2d(float *v1, unsigned char *v2, void *stream) { + vcmps_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-tail/main.cpp new file mode 100644 index 000000000..941741c4b --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_tail_kernel_2d(float *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_tail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-tail/stub.cpp new file mode 100644 index 000000000..d4c3a9ba1 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmps_tail_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/compare.py b/test/vpto/cases/micro-op/compare-select/vsel-i16/compare.py new file mode 100644 index 000000000..cb78833f5 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/golden.py b/test/vpto/cases/micro-op/compare-select/vsel-i16/golden.py new file mode 100644 index 000000000..9a5b86185 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-200, 200, size=(LANES,), dtype=np.int16) + v2 = rng.integers(-200, 200, size=(LANES,), dtype=np.int16) + golden_v3 = np.where(v1 > v2, v1, v2).astype(np.int16, copy=False) + v3 = np.zeros((LANES,), dtype=np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vsel-i16.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto new file mode 100644 index 000000000..58d5eac19 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vsel_i16_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i32 = arith.constant 128 : i32 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %pred = pto.vcmp %lhs, %rhs, %active, "gt" : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.mask + %out = pto.vsel %lhs, %rhs, %pred : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%c0], %active : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel-i16/launch.cpp new file mode 100644 index 000000000..8ca125427 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsel_i16_kernel_2d(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVsel_i16_kernel_2d(int16_t *v1, int16_t *v2, int16_t *v3, void *stream) { + vsel_i16_kernel_2d<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel-i16/main.cpp new file mode 100644 index 000000000..656e4dfec --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_i16_kernel_2d(int16_t *v1, int16_t *v2, int16_t *v3, void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 128; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_i16_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/stub.cpp b/test/vpto/cases/micro-op/compare-select/vsel-i16/stub.cpp new file mode 100644 index 000000000..8db319318 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vsel_i16_kernel_2d(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/compare.py b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/compare.py new file mode 100644 index 000000000..a861864de --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/golden.py b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/golden.py new file mode 100644 index 000000000..e757d1089 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-8.0, 8.0, size=(LANES,)).astype(np.float32) + rhs = lhs.copy() + + lane_ids = np.arange(LANES, dtype=np.int32) + edge_mask = ((lane_ids < 4) | (lane_ids >= 60) | ((lane_ids % 17) == 0)) + rhs[edge_mask] = (rhs[edge_mask] + np.float32(3.5)).astype(np.float32) + rhs[~edge_mask] = (rhs[~edge_mask] - np.float32(2.0)).astype(np.float32) + + out = np.zeros((LANES,), dtype=np.float32) + golden = np.where(lhs > rhs, lhs, rhs).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vsel-predicate-edge.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto new file mode 100644 index 000000000..79ed52b1a --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vsel_predicate_edge_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + %out = pto.vsel %lhs, %rhs, %pred : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%c0], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/launch.cpp new file mode 100644 index 000000000..c23a21ddd --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsel_predicate_edge_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVsel_predicate_edge_kernel_2d(float *v1, float *v2, float *v3, + void *stream) { + vsel_predicate_edge_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/main.cpp new file mode 100644 index 000000000..6fea6f0e8 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_predicate_edge_kernel_2d(float *v1, float *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_predicate_edge_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/stub.cpp b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/stub.cpp new file mode 100644 index 000000000..2d416e94a --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vsel_predicate_edge_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/compare.py b/test/vpto/cases/micro-op/compare-select/vsel-tail/compare.py new file mode 100755 index 000000000..b5d54649a --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vsel-tail +# family: compare-select +# target_ops: pto.vsel +# scenarios: core-f32, tail-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/golden.py b/test/vpto/cases/micro-op/compare-select/vsel-tail/golden.py new file mode 100644 index 000000000..a2d6807fa --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +LOGICAL_ELEMS = 40 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + v2 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + golden_v3 = np.full((LANES,), OUT_SENTINEL, dtype=np.float32) + flat = np.where(v1 > v2, v1, v2).astype(np.float32, copy=False) + golden_v3[:LOGICAL_ELEMS] = flat[:LOGICAL_ELEMS] + v3 = np.full((LANES,), OUT_SENTINEL, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vsel-tail.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto new file mode 100644 index 000000000..0d4602a58 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vsel_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i32 = arith.constant 40 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg2, %ub_out, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c40_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + %out = pto.vsel %lhs, %rhs, %pred : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%c0], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel-tail/launch.cpp new file mode 100644 index 000000000..b4e0598e0 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsel_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVsel_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vsel_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel-tail/main.cpp new file mode 100644 index 000000000..323131056 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/stub.cpp b/test/vpto/cases/micro-op/compare-select/vsel-tail/stub.cpp new file mode 100644 index 000000000..df964325d --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vsel_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel/compare.py b/test/vpto/cases/micro-op/compare-select/vsel/compare.py new file mode 100755 index 000000000..6cd01c922 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vsel +# family: compare-select +# target_ops: pto.vsel +# scenarios: core-f32, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel/golden.py b/test/vpto/cases/micro-op/compare-select/vsel/golden.py new file mode 100644 index 000000000..fef5f7d27 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + v2 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + golden_v3 = np.where(v1 > v2, v1, v2).astype(np.float32, copy=False) + v3 = np.zeros((LANES,), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vsel.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto new file mode 100644 index 000000000..9a330d306 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vsel_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + %out = pto.vsel %lhs, %rhs, %pred : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%c0], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel/launch.cpp new file mode 100644 index 000000000..269405dee --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsel_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVsel_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vsel_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel/main.cpp new file mode 100644 index 000000000..cf71eb295 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel/stub.cpp b/test/vpto/cases/micro-op/compare-select/vsel/stub.cpp new file mode 100644 index 000000000..a7138163e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vsel_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/compare.py b/test/vpto/cases/micro-op/compare-select/vselr-f16/compare.py new file mode 100644 index 000000000..b961a3713 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr-f16 +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-f16, full-mask, explicit-lane-index + +import os +import sys + +import numpy as np + + +def compare_tensor(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + return False + if not np.allclose(golden, output, rtol=0.0, atol=0.0, equal_nan=True): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_tensor("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/golden.py b/test/vpto/cases/micro-op/compare-select/vselr-f16/golden.py new file mode 100644 index 000000000..ae0513ecf --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr-f16 +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-f16, full-mask, explicit-lane-index + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 8 +COLS = 128 +SEED = 23 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-6.0, 6.0, size=(ROWS, COLS)).astype(np.float16, copy=False) + lane_ids = np.arange(COLS, dtype=np.uint16) + idx = np.empty((ROWS, COLS), dtype=np.uint16) + for row in range(ROWS): + idx[row] = (lane_ids[::-1] + row * 11 + (lane_ids % 7) * 3) % COLS + golden = np.take_along_axis(src, idx.astype(np.int64, copy=False), axis=1).astype(np.float16, copy=False) + out = np.zeros((ROWS, COLS), dtype=np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).reshape(-1).tofile(output_dir / "v1.bin") + idx.reshape(-1).tofile(output_dir / "v2.bin") + out.view(np.uint16).reshape(-1).tofile(output_dir / "v3.bin") + golden.view(np.uint16).reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for vselr-f16.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto new file mode 100644 index 000000000..cea3fd793 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-f16 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f16, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vselr_f16_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_idx = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c8_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_idx, %c0_i64, %c8_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %idx = pto.vlds %ub_idx[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vselr %src, %idx : !pto.vreg<128xf16>, !pto.vreg<128xui16> -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c8_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/launch.cpp b/test/vpto/cases/micro-op/compare-select/vselr-f16/launch.cpp new file mode 100644 index 000000000..f00e5672d --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-f16 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f16, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vselr_f16_kernel_2d(__gm__ half *v1, + __gm__ uint16_t *v2, + __gm__ half *v3); + +void LaunchVselr_f16_kernel_2d(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vselr_f16_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ uint16_t *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/main.cpp b/test/vpto/cases/micro-op/compare-select/vselr-f16/main.cpp new file mode 100644 index 000000000..2002d9ee5 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-f16 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f16, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVselr_f16_kernel_2d(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVselr_f16_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/stub.cpp b/test/vpto/cases/micro-op/compare-select/vselr-f16/stub.cpp new file mode 100644 index 000000000..8bc99e40d --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-f16 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f16, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vselr_f16_kernel_2d(__gm__ half *v1, + __gm__ uint16_t *v2, + __gm__ half *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/compare.py b/test/vpto/cases/micro-op/compare-select/vselr-u8/compare.py new file mode 100644 index 000000000..d48e1a42e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr-u8 +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-u8, full-mask, explicit-lane-index + +import os +import sys + +import numpy as np + + +def compare_tensor(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_tensor("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/golden.py b/test/vpto/cases/micro-op/compare-select/vselr-u8/golden.py new file mode 100644 index 000000000..2cb03404a --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr-u8 +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-u8, full-mask, explicit-lane-index + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 4 +COLS = 256 +SEED = 29 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.integers(0, 256, size=(ROWS, COLS), dtype=np.uint8) + lane_ids = np.arange(COLS, dtype=np.uint16) + idx = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + row_idx = (lane_ids[::-1] + row * 19 + (lane_ids % 13) * 5) % COLS + idx[row] = row_idx.astype(np.uint8, copy=False) + golden = np.take_along_axis(src, idx.astype(np.int64, copy=False), axis=1).astype(np.uint8, copy=False) + out = np.zeros((ROWS, COLS), dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + idx.reshape(-1).tofile(output_dir / "v2.bin") + out.reshape(-1).tofile(output_dir / "v3.bin") + golden.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for vselr-u8.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto new file mode 100644 index 000000000..08d052afc --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-u8 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-u8, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vselr_u8_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_idx = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_idx, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%offset] : !pto.ptr -> !pto.vreg<256xui8> + %idx = pto.vlds %ub_idx[%offset] : !pto.ptr -> !pto.vreg<256xui8> + %out = pto.vselr %src, %idx : !pto.vreg<256xui8>, !pto.vreg<256xui8> -> !pto.vreg<256xui8> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/launch.cpp b/test/vpto/cases/micro-op/compare-select/vselr-u8/launch.cpp new file mode 100644 index 000000000..a8d38e8ea --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-u8 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-u8, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vselr_u8_kernel_2d(__gm__ uint8_t *v1, + __gm__ uint8_t *v2, + __gm__ uint8_t *v3); + +void LaunchVselr_u8_kernel_2d(uint8_t *v1, uint8_t *v2, uint8_t *v3, + void *stream) { + vselr_u8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint8_t *)v1, + (__gm__ uint8_t *)v2, + (__gm__ uint8_t *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/main.cpp b/test/vpto/cases/micro-op/compare-select/vselr-u8/main.cpp new file mode 100644 index 000000000..78f0a8d16 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-u8 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-u8, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVselr_u8_kernel_2d(uint8_t *v1, uint8_t *v2, uint8_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint8_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint8_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint8_t); + uint8_t *v1Host = nullptr; + uint8_t *v1Device = nullptr; + uint8_t *v2Host = nullptr; + uint8_t *v2Device = nullptr; + uint8_t *v3Host = nullptr; + uint8_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVselr_u8_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/stub.cpp b/test/vpto/cases/micro-op/compare-select/vselr-u8/stub.cpp new file mode 100644 index 000000000..270618fa7 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-u8 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-u8, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vselr_u8_kernel_2d(__gm__ uint8_t *v1, + __gm__ uint8_t *v2, + __gm__ uint8_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr/compare.py b/test/vpto/cases/micro-op/compare-select/vselr/compare.py new file mode 100755 index 000000000..b0da91196 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-f32, full-mask, explicit-lane-index + +import os +import sys +import numpy as np + +def compare_tensor(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + return False + if not np.allclose(golden, output, rtol=0.0, atol=0.0): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={float(golden[idx])} out={float(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_tensor("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr/golden.py b/test/vpto/cases/micro-op/compare-select/vselr/golden.py new file mode 100755 index 000000000..6362369bd --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/golden.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-f32, full-mask, explicit-lane-index +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32, copy=False) + src = v1.reshape(16, 64) + lane_ids = np.arange(64, dtype=np.int32) + idx = np.empty((16, 64), dtype=np.int32) + for row in range(16): + idx[row] = (lane_ids[::-1] + row * 3 + (lane_ids // 8) * 3) % 64 + golden_v3 = np.take_along_axis(src, idx, axis=1).astype(np.float32, copy=False).reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + idx.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vselr validation." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help="Numpy random seed.", + ) + args = parser.parse_args() + + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto new file mode 100644 index 000000000..f1eff1633 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto @@ -0,0 +1,71 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f32, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vselr_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) { + %c8192_i64 = arith.constant 8192 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32 = arith.constant 32 : index + %0 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %1 = arith.index_castui %c32 : index to i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c4_i64 = arith.constant 4 : i64 + %2 = arith.muli %1, %c4_i64 : i64 + %c128_i64 = arith.constant 128 : i64 + %3 = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + %4 = arith.index_castui %c0_i64 : i64 to index + %5 = pto.addptr %3, %4 : -> + pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + %6 = pto.castptr %5 : !pto.ptr -> !pto.ptr + %false = arith.constant false + pto.copy_gm_to_ubuf %6, %0, %c0_i64, %c32_i64, %2, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + %7 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %8 = pto.castptr %arg1 : !pto.ptr -> !pto.ptr + %9 = pto.addptr %8, %4 : -> + pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + %10 = pto.castptr %9 : !pto.ptr -> !pto.ptr + pto.copy_gm_to_ubuf %10, %7, %c0_i64, %c32_i64, %2, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + %11 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024_i32 = arith.constant 1024 : i32 + pto.vecscope { + %16 = scf.for %arg4 = %c0 to %c16 step %c1 iter_args(%arg5 = %c1024_i32) -> (i32) { + %17 = arith.muli %arg4, %c64 : index + %mask, %scalar_out = pto.plt_b32 %arg5 : i32 -> !pto.mask, i32 + %25 = pto.vlds %0[%17] : !pto.ptr -> !pto.vreg<64xf32> + %26 = pto.vlds %7[%17] : !pto.ptr -> !pto.vreg<64xi32> + %27 = pto.vselr %25, %26 : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + pto.vsts %27, %11[%17], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } + } + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + %c1024_i64 = arith.constant 1024 : i64 + %12 = arith.muli %1, %c4_i64 : i64 + %13 = pto.castptr %arg2 : !pto.ptr -> !pto.ptr + %14 = pto.addptr %13, %4 : -> + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + %15 = pto.castptr %14 : !pto.ptr -> !pto.ptr + pto.copy_ubuf_to_gm %11, %15, %c0_i64, %c32_i64, %12, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr/launch.cpp b/test/vpto/cases/micro-op/compare-select/vselr/launch.cpp new file mode 100644 index 000000000..68e4c6169 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f32, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vselr_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVselr_kernel_2d(float *v1, int *v2, float *v3, + void *stream) { + vselr_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr/main.cpp b/test/vpto/cases/micro-op/compare-select/vselr/main.cpp new file mode 100644 index 000000000..62fd4bebf --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/main.cpp @@ -0,0 +1,109 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f32, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVselr_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVselr_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr/stub.cpp b/test/vpto/cases/micro-op/compare-select/vselr/stub.cpp new file mode 100644 index 000000000..f850d1f28 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f32, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vselr_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/compare.py new file mode 100644 index 000000000..751000b6f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/golden.py new file mode 100755 index 000000000..66921951a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/golden.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f16-special +# family: conversion +# target_ops: pto.vcvt +# scenarios: f16-to-f32, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + special = np.array( + [ + np.float16(0.0), + np.float16(-0.0), + np.float16(1.0), + np.float16(-1.0), + np.float16(np.inf), + np.float16(-np.inf), + np.float16(np.nan), + np.float16(65504.0), + np.float16(-65504.0), + np.float16(6.1035e-05), + np.float16(-6.1035e-05), + np.float16(5.9605e-08), + np.float16(-5.9605e-08), + np.float16(123.75), + np.float16(-123.75), + np.float16(0.33325), + ], + dtype=np.float16, + ) + v1 = np.resize(special, ROWS * COLS).reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f16-special validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto new file mode 100644 index 000000000..d475fbb7a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_f16_special_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %loaded = pto.vlds %ub_in[%offset] {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vcvt %loaded {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/launch.cpp new file mode 100644 index 000000000..214c9b5c8 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f16_special_kernel_2d(__gm__ half *v1, + __gm__ float *v2); + +void LaunchVcvt_f16_special_kernel_2d(uint16_t *v1, float *v2, void *stream) { + vcvt_f16_special_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/main.cpp new file mode 100644 index 000000000..8f83d1edd --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f16_special_kernel_2d(uint16_t *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f16_special_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/stub.cpp new file mode 100644 index 000000000..fcc0046e9 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcvt_f16_special_kernel_2d(__gm__ half *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/compare.py new file mode 100755 index 000000000..751000b6f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/golden.py new file mode 100644 index 000000000..e071074e7 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f16-to-f32-part-even +# family: conversion +# target_ops: pto.vcvt +# scenarios: f16-to-f32, full-mask, part-even + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float16) + # Kernel writes 8 chunks (offset 0..448, step 64), each chunk converts the + # lower 16-bit half (PART_EVEN) from packed f16 pairs in a 128-lane load. + out_elems = 512 + v2 = np.zeros(out_elems, dtype=np.float32) + golden_v2 = np.empty(out_elems, dtype=np.float32) + for block in range(0, out_elems, 64): + src = v1[block : block + 128 : 2].astype(np.float32, copy=False) + golden_v2[block : block + 64] = src + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f16-to-f32 part-even validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto new file mode 100644 index 000000000..60dfc3f26 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_f16_to_f32_part_even_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + // Use packed f16 load (no UNPK): PART_EVEN selects the lower 16-bit + // element from each f16 pair inside a b32 lane. + scf.for %offset = %c0 to %c512 step %c64 { + %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vcvt %loaded {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c16_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/launch.cpp new file mode 100644 index 000000000..8e321d1a5 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-even +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-even +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f16_to_f32_part_even_kernel_2d(__gm__ half *v1, + __gm__ float *v2); + +void LaunchVcvt_f16_to_f32_part_even_kernel_2d(uint16_t *v1, float *v2, void *stream) { + vcvt_f16_to_f32_part_even_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/main.cpp new file mode 100644 index 000000000..1925124c1 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-even +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-even +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f16_to_f32_part_even_kernel_2d(uint16_t *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 512; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f16_to_f32_part_even_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/stub.cpp new file mode 100644 index 000000000..8698fb8bd --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-even +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-even +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcvt_f16_to_f32_part_even_kernel_2d(__gm__ half *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/compare.py new file mode 100755 index 000000000..751000b6f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/golden.py new file mode 100644 index 000000000..2b7822bc8 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f16-to-f32-part-odd +# family: conversion +# target_ops: pto.vcvt +# scenarios: f16-to-f32, full-mask, part-odd + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float16) + # Kernel writes 8 chunks (offset 0..448, step 64), each chunk converts the + # upper 16-bit half (PART_ODD) from packed f16 pairs in a 128-lane load. + out_elems = 512 + v2 = np.zeros(out_elems, dtype=np.float32) + golden_v2 = np.empty(out_elems, dtype=np.float32) + for block in range(0, out_elems, 64): + src = v1[block + 1 : block + 128 : 2].astype(np.float32, copy=False) + golden_v2[block : block + 64] = src + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f16-to-f32 part-odd validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto new file mode 100644 index 000000000..e20d43ad5 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_f16_to_f32_part_odd_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + // Use packed f16 load (no UNPK): PART_ODD then selects the upper 16-bit + // element from each f16 pair inside a b32 lane. + scf.for %offset = %c0 to %c512 step %c64 { + %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vcvt %loaded {part = "ODD"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c16_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/launch.cpp new file mode 100644 index 000000000..db23cbbf4 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-odd +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-odd +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f16_to_f32_part_odd_kernel_2d(__gm__ half *v1, + __gm__ float *v2); + +void LaunchVcvt_f16_to_f32_part_odd_kernel_2d(uint16_t *v1, float *v2, void *stream) { + vcvt_f16_to_f32_part_odd_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/main.cpp new file mode 100644 index 000000000..567aafa0a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-odd +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-odd +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f16_to_f32_part_odd_kernel_2d(uint16_t *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 512; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f16_to_f32_part_odd_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/stub.cpp new file mode 100644 index 000000000..ac1c05443 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-odd +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-odd +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcvt_f16_to_f32_part_odd_kernel_2d(__gm__ half *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/compare.py new file mode 100755 index 000000000..751000b6f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/golden.py new file mode 100755 index 000000000..903c2385d --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f16-to-f32 +# family: conversion +# target_ops: pto.vcvt +# scenarios: f16-to-f32, full-mask + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float16) + v2 = np.zeros(ELEMS, dtype=np.float32) + golden_v2 = v1.astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f16-to-f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto new file mode 100644 index 000000000..a49630d82 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_f16_to_f32_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %loaded = pto.vlds %ub_in[%offset] {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vcvt %loaded {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/launch.cpp new file mode 100644 index 000000000..4998ce110 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f16_to_f32_kernel_2d(__gm__ half *v1, + __gm__ float *v2); + +void LaunchVcvt_f16_to_f32_kernel_2d(uint16_t *v1, float *v2, void *stream) { + vcvt_f16_to_f32_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/main.cpp new file mode 100644 index 000000000..17f92c862 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f16_to_f32_kernel_2d(uint16_t *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f16_to_f32_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/stub.cpp new file mode 100644 index 000000000..984b781f5 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcvt_f16_to_f32_kernel_2d(__gm__ half *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/compare.py new file mode 100644 index 000000000..d2d022505 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 1e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/golden.py new file mode 100755 index 000000000..33742285b --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/golden.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f32-special +# family: conversion +# target_ops: pto.vcvt +# scenarios: f32-to-f16, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + special = np.array( + [ + 0.0, + -0.0, + 1.0, + -1.0, + np.inf, + -np.inf, + np.nan, + 65504.0, + -65504.0, + 1.0e-8, + -1.0e-8, + 1.0e-4, + -1.0e-4, + 123.75, + -123.75, + 0.33333334, + ], + dtype=np.float32, + ) + flat = np.resize(special, ROWS * COLS).astype(np.float32) + v1 = flat.reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_flat = np.zeros(ROWS * COLS, dtype=np.float16) + + for offset in range(0, ROWS * COLS, 128): + lower = flat[offset : offset + 64].astype(np.float16) + upper = flat[offset + 64 : offset + 128].astype(np.float16) + merged = np.empty(128, dtype=np.float16) + merged[0::2] = lower + merged[1::2] = upper + golden_flat[offset : offset + 128] = merged + + golden_v2 = golden_flat.reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f32-special validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto new file mode 100644 index 000000000..dfd8c8ac2 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_f32_special_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> + %even = pto.vcvt %lower {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/launch.cpp new file mode 100644 index 000000000..64c50ea0d --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f32_special_kernel_2d(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_f32_special_kernel_2d(float *v1, uint16_t *v2, void *stream) { + vcvt_f32_special_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/main.cpp new file mode 100644 index 000000000..73f29e4ca --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f32_special_kernel_2d(float *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f32_special_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/stub.cpp new file mode 100644 index 000000000..97bce8b24 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcvt_f32_special_kernel_2d(__gm__ float *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/compare.py new file mode 100644 index 000000000..d2d022505 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 1e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/golden.py new file mode 100644 index 000000000..ee8fd3890 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f32-to-f16-pk-b32 +# family: conversion +# target_ops: pto.vcvt, pto.vsts +# scenarios: f32-to-f16, pk-b32-store, full-mask + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + v2 = np.zeros(ELEMS, dtype=np.float16) + golden_v2 = v1.astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f32-to-f16-pk-b32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto new file mode 100644 index 000000000..c6a08514e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-to-f16-pk-b32 +// family: conversion +// target_ops: pto.vcvt, pto.vsts +// scenarios: f32-to-f16, pk-b32-store, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_f32_to_f16_pk_b32_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %converted = pto.vcvt %loaded {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + pto.vsts %converted, %ub_out[%offset], %mask {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/launch.cpp new file mode 100644 index 000000000..77055836f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f32_to_f16_pk_b32_kernel(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_f32_to_f16_pk_b32_kernel(float *v1, aclFloat16 *v2, void *stream) { + vcvt_f32_to_f16_pk_b32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/main.cpp new file mode 100644 index 000000000..8b7886671 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f32_to_f16_pk_b32_kernel(float *v1, aclFloat16 *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(aclFloat16); + float *v1Host = nullptr; + float *v1Device = nullptr; + aclFloat16 *v2Host = nullptr; + aclFloat16 *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f32_to_f16_pk_b32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/stub.cpp new file mode 100644 index 000000000..b4d615e95 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/stub.cpp @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcvt_f32_to_f16_pk_b32_kernel(__gm__ float *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/compare.py new file mode 100644 index 000000000..d2d022505 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 1e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/golden.py new file mode 100755 index 000000000..55a924e97 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f32-to-f16 +# family: conversion +# target_ops: pto.vcvt +# scenarios: f32-to-f16, full-mask + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + v2 = np.zeros(ELEMS, dtype=np.float16) + golden_v2 = np.zeros(ELEMS, dtype=np.float16) + + # Width-changing f32->f16 lowering uses two 64-lane f32 vectors, converts + # them into EVEN/ODD halves, then merges them into one 128-lane f16 vector. + for offset in range(0, ELEMS, 128): + lower = v1[offset : offset + 64].astype(np.float16) + upper = v1[offset + 64 : offset + 128].astype(np.float16) + merged = np.empty(128, dtype=np.float16) + merged[0::2] = lower + merged[1::2] = upper + golden_v2[offset : offset + 128] = merged + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f32-to-f16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto new file mode 100644 index 000000000..91b9045ef --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_f32_to_f16_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> + %even = pto.vcvt %lower {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/launch.cpp new file mode 100644 index 000000000..8dcc00348 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-to-f16 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f32_to_f16_kernel_2d(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_f32_to_f16_kernel_2d(float *v1, uint16_t *v2, void *stream) { + vcvt_f32_to_f16_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/main.cpp new file mode 100644 index 000000000..cf7a6c2de --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-to-f16 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f32_to_f16_kernel_2d(float *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f32_to_f16_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/stub.cpp new file mode 100644 index 000000000..cf0a001f7 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-to-f16 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcvt_f32_to_f16_kernel_2d(__gm__ float *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/compare.py new file mode 100644 index 000000000..fe3cc3abc --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-i32-to-i16-overflow +# family: conversion +# target_ops: pto.vcvt +# scenarios: i32-to-i16, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/golden.py new file mode 100644 index 000000000..6fbfc4834 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/golden.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-i32-to-i16-overflow +# family: conversion +# target_ops: pto.vcvt +# scenarios: i32-to-i16, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +I16_MIN = np.iinfo(np.int16).min +I16_MAX = np.iinfo(np.int16).max + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + data = rng.integers(-200000, 200000, size=ELEMS, dtype=np.int32) + edge = np.array([ + -40000, -32769, -32768, -32767, -1, 0, 1, 32766, + 32767, 32768, 40000, 70000, -70000, 65535, -65535, 123456, + ], dtype=np.int32) + data[:edge.size] = edge + clipped = np.clip(data, I16_MIN, I16_MAX).astype(np.int16) + golden = np.zeros(ELEMS, dtype=np.int16) + for offset in range(0, ELEMS, 128): + lower = clipped[offset : offset + 64] + upper = clipped[offset + 64 : offset + 128] + merged = np.empty(128, dtype=np.int16) + merged[0::2] = lower + merged[1::2] = upper + golden[offset : offset + 128] = merged + + output_dir.mkdir(parents=True, exist_ok=True) + data.tofile(output_dir / "v1.bin") + np.zeros(ELEMS, dtype=np.int16).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto new file mode 100644 index 000000000..292a4df27 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i32-to-i16-overflow +// family: conversion +// target_ops: pto.vcvt +// scenarios: i32-to-i16, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_i32_to_i16_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xi32> + %even = pto.vcvt %lower {sat = "SAT", part = "EVEN"} : !pto.vreg<64xi32> -> !pto.vreg<128xi16> + %odd = pto.vcvt %upper {sat = "SAT", part = "ODD"} : !pto.vreg<64xi32> -> !pto.vreg<128xi16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/launch.cpp new file mode 100644 index 000000000..0ac6bc67d --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_i32_to_i16_overflow_kernel( + __gm__ int32_t *v1, __gm__ int16_t *v2); + +void LaunchVcvt_i32_to_i16_overflow_kernel(int32_t *v1, int16_t *v2, + void *stream) { + vcvt_i32_to_i16_overflow_kernel<<<1, nullptr, stream>>>( + (__gm__ int32_t *)v1, (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/main.cpp new file mode 100644 index 000000000..ab8500906 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i32-to-i16-overflow +// family: conversion +// target_ops: pto.vcvt +// scenarios: i32-to-i16, integer-overflow +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_i32_to_i16_overflow_kernel(int32_t *v1, int16_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int32_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_i32_to_i16_overflow_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/stub.cpp new file mode 100644 index 000000000..3903859dc --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcvt_i32_to_i16_overflow_kernel( + __gm__ int32_t *v1, __gm__ int16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/compare.py new file mode 100644 index 000000000..166196a8e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/compare.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +LOGICAL_ELEMS = 1000 + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float16, 1e-3, LOGICAL_ELEMS) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/golden.py new file mode 100755 index 000000000..229bb9818 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/golden.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-tail-special +# family: conversion +# target_ops: pto.vcvt +# scenarios: f32-to-f16, tail-mask, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + special = np.array( + [ + 0.0, + -0.0, + 1.0, + -1.0, + np.inf, + -np.inf, + np.nan, + 65504.0, + -65504.0, + 1.0e-8, + -1.0e-8, + 1.0e-4, + -1.0e-4, + 123.75, + -123.75, + 0.33333334, + ], + dtype=np.float32, + ) + flat = np.resize(special, ROWS * COLS).astype(np.float32) + flat[LOGICAL_ELEMS:] = 0.0 + v1 = flat.reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_flat = np.zeros(ROWS * COLS, dtype=np.float16) + + remaining = LOGICAL_ELEMS + for offset in range(0, ROWS * COLS, 128): + lower = flat[offset : offset + 64].astype(np.float16) + upper = flat[offset + 64 : offset + 128].astype(np.float16) + merged = np.empty(128, dtype=np.float16) + merged[0::2] = lower + merged[1::2] = upper + active = min(remaining, 128) + golden_flat[offset : offset + active] = merged[:active] + remaining = max(remaining - 128, 0) + + golden_v2 = golden_flat.reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-tail-special validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto new file mode 100644 index 000000000..9a907e9a1 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_tail_special_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> + %even = pto.vcvt %lower {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/launch.cpp new file mode 100644 index 000000000..128254d29 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_tail_special_kernel_2d(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_tail_special_kernel_2d(float *v1, uint16_t *v2, void *stream) { + vcvt_tail_special_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/main.cpp new file mode 100644 index 000000000..155e88b98 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_tail_special_kernel_2d(float *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_tail_special_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/stub.cpp new file mode 100644 index 000000000..32976ef73 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcvt_tail_special_kernel_2d(__gm__ float *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-tail/compare.py new file mode 100644 index 000000000..166196a8e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/compare.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +LOGICAL_ELEMS = 1000 + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float16, 1e-3, LOGICAL_ELEMS) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-tail/golden.py new file mode 100755 index 000000000..b121f1e29 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/golden.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-tail +# family: conversion +# target_ops: pto.vcvt +# scenarios: f32-to-f16, tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float16(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=ROWS * COLS).astype(np.float32) + flat[LOGICAL_ELEMS:] = 0.0 + v1 = flat.reshape(ROWS, COLS) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float16) + golden_flat = np.full(ROWS * COLS, OUT_SENTINEL, dtype=np.float16) + + remaining = LOGICAL_ELEMS + for offset in range(0, ROWS * COLS, 128): + lower = flat[offset : offset + 64].astype(np.float16) + upper = flat[offset + 64 : offset + 128].astype(np.float16) + merged = np.empty(128, dtype=np.float16) + merged[0::2] = lower + merged[1::2] = upper + active = min(remaining, 128) + golden_flat[offset : offset + active] = merged[:active] + remaining = max(remaining - 128, 0) + + golden_v2 = golden_flat.reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-tail validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto new file mode 100644 index 000000000..0cfaca5e4 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> + %even = pto.vcvt %lower {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail/launch.cpp new file mode 100644 index 000000000..5773d5044 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_tail_kernel_2d(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_tail_kernel_2d(float *v1, uint16_t *v2, void *stream) { + vcvt_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail/main.cpp new file mode 100644 index 000000000..9a0abf5cb --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_tail_kernel_2d(float *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_tail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail/stub.cpp new file mode 100644 index 000000000..e0ffb004e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcvt_tail_kernel_2d(__gm__ float *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/compare.py b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/compare.py new file mode 100644 index 000000000..2aec0a573 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 1e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/golden.py b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/golden.py new file mode 100644 index 000000000..62578c84a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 23 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + values = np.array( + [-7.5, -3.25, -0.5, -0.0, 0.0, 0.5, 1.5, 6.75], + dtype=np.float16, + ) + v1 = np.resize(values, ROWS * COLS).reshape(ROWS, COLS).astype(np.float16) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v2 = np.trunc(v1.astype(np.float32)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vtrc-f16-rounding validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto new file mode 100644 index 000000000..b26bc2308 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vtrc_f16_rounding_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vtrc %vec, %mask, "Z" : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/launch.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/launch.cpp new file mode 100644 index 000000000..ad3a8682a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vtrc_f16_rounding_kernel_2d(__gm__ half *v1, + __gm__ half *v2); + +void LaunchVtrc_f16_rounding_kernel_2d(void *v1, void *v2, void *stream) { + vtrc_f16_rounding_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/main.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/main.cpp new file mode 100644 index 000000000..604ec1198 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVtrc_f16_rounding_kernel_2d(void *v1, void *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + void *v1Host = nullptr; + void *v1Device = nullptr; + void *v2Host = nullptr; + void *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost(&v1Host, fileSize_v1)); + ACL_CHECK(aclrtMallocHost(&v2Host, fileSize_v2)); + ACL_CHECK(aclrtMalloc(&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVtrc_f16_rounding_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/stub.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/stub.cpp new file mode 100644 index 000000000..e6468977b --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vtrc_f16_rounding_kernel_2d(__gm__ half *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/compare.py b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/compare.py new file mode 100644 index 000000000..848571069 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/compare.py @@ -0,0 +1,206 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/golden.py b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/golden.py new file mode 100644 index 000000000..64260e9d6 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + v4 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.rint(v1).astype(np.float32, copy=False) + golden_v3 = np.trunc(v1).astype(np.float32, copy=False) + golden_v4 = np.floor(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + v4.reshape(-1).tofile(output_dir / "v4.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + golden_v4.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vtrc-f32-rounding validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto new file mode 100644 index 000000000..1cf47698c --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto @@ -0,0 +1,75 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vtrc_f32_rounding_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_r = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %ub_z = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_f = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out_r = pto.vtrc %vec, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out_z = pto.vtrc %vec, %mask, "Z" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out_f = pto.vtrc %vec, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out_r, %ub_r[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out_z, %ub_z[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out_f, %ub_f[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_r, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_z, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_f, %arg3, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/launch.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/launch.cpp new file mode 100644 index 000000000..6e4f1a142 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/launch.cpp @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vtrc_f32_rounding_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3, + __gm__ float *v4); + +void LaunchVtrc_f32_rounding_kernel_2d(float *v1, float *v2, float *v3, + float *v4, void *stream) { + vtrc_f32_rounding_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3, + (__gm__ float *)v4); +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/main.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/main.cpp new file mode 100644 index 000000000..b86de567c --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/main.cpp @@ -0,0 +1,147 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVtrc_f32_rounding_kernel_2d(float *v1, float *v2, float *v3, + float *v4, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + size_t elemCount_v4 = 1024; + size_t fileSize_v4 = elemCount_v4 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + float *v4Host = nullptr; + float *v4Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVtrc_f32_rounding_kernel_2d(v1Device, v2Device, v3Device, v4Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/stub.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/stub.cpp new file mode 100644 index 000000000..549737cf3 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vtrc_f32_rounding_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3, + __gm__ float *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/compare.py b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/compare.py new file mode 100644 index 000000000..38d1deb75 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/golden.py b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/golden.py new file mode 100644 index 000000000..f6251171d --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.trunc(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vtrc-f32-special validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto new file mode 100644 index 000000000..948f50b6c --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vtrc_f32_special_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vtrc %vec, %mask, "Z" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/launch.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/launch.cpp new file mode 100644 index 000000000..4d1ad9527 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vtrc_f32_special_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVtrc_f32_special_kernel_2d(float *v1, float *v2, void *stream) { + vtrc_f32_special_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/main.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/main.cpp new file mode 100644 index 000000000..40f3aa5ae --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVtrc_f32_special_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVtrc_f32_special_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/stub.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/stub.cpp new file mode 100644 index 000000000..1d00c60da --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vtrc_f32_special_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/compare.py b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/compare.py new file mode 100644 index 000000000..848571069 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/compare.py @@ -0,0 +1,206 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/golden.py b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/golden.py new file mode 100644 index 000000000..a39eaa122 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + boundary = np.array( + [-3.5, -3.0, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5], + dtype=np.float32, + ) + v1 = np.resize(boundary, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + v4 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.rint(v1).astype(np.float32, copy=False) + golden_v3 = np.trunc(v1).astype(np.float32, copy=False) + golden_v4 = np.floor(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + v4.reshape(-1).tofile(output_dir / "v4.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + golden_v4.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vtrc-f32-rounding validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto new file mode 100644 index 000000000..1cf47698c --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto @@ -0,0 +1,75 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vtrc_f32_rounding_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_r = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %ub_z = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_f = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out_r = pto.vtrc %vec, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out_z = pto.vtrc %vec, %mask, "Z" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out_f = pto.vtrc %vec, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out_r, %ub_r[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out_z, %ub_z[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out_f, %ub_f[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_r, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_z, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_f, %arg3, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/launch.cpp b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/launch.cpp new file mode 100644 index 000000000..6e4f1a142 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/launch.cpp @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vtrc_f32_rounding_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3, + __gm__ float *v4); + +void LaunchVtrc_f32_rounding_kernel_2d(float *v1, float *v2, float *v3, + float *v4, void *stream) { + vtrc_f32_rounding_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3, + (__gm__ float *)v4); +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/main.cpp b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/main.cpp new file mode 100644 index 000000000..b86de567c --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/main.cpp @@ -0,0 +1,147 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVtrc_f32_rounding_kernel_2d(float *v1, float *v2, float *v3, + float *v4, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + size_t elemCount_v4 = 1024; + size_t fileSize_v4 = elemCount_v4 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + float *v4Host = nullptr; + float *v4Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVtrc_f32_rounding_kernel_2d(v1Device, v2Device, v3Device, v4Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/stub.cpp b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/stub.cpp new file mode 100644 index 000000000..549737cf3 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vtrc_f32_rounding_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3, + __gm__ float *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/compare.py new file mode 100755 index 000000000..ae42e6822 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vaxpy-f32 +# family: dsa-sfu +# target_ops: pto.vaxpy +# scenarios: core-f32, scalar-operand, fused-op +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/golden.py new file mode 100755 index 000000000..e99a6f22f --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vaxpy-f32 +# family: dsa-sfu +# target_ops: pto.vaxpy +# scenarios: core-f32, scalar-operand, fused-op +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float32(0.125) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = (ALPHA * v1 + v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto new file mode 100644 index 000000000..517dbdd5b --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vaxpy-f32 +// family: dsa-sfu +// target_ops: pto.vaxpy +// scenarios: core-f32, scalar-operand, fused-op +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vaxpy_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %alpha = arith.constant 1.250000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_addend = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_addend, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %addend = pto.vlds %ub_addend[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vaxpy %vec, %addend, %alpha : !pto.vreg<64xf32>, !pto.vreg<64xf32>, f32 -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/launch.cpp new file mode 100644 index 000000000..00c358ce2 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vaxpy-f32 +// family: dsa-sfu +// target_ops: pto.vaxpy +// scenarios: core-f32, scalar-operand, fused-op +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vaxpy_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVaxpy_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vaxpy_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/main.cpp new file mode 100644 index 000000000..62c80a5fb --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vaxpy-f32 +// family: dsa-sfu +// target_ops: pto.vaxpy +// scenarios: core-f32, scalar-operand, fused-op +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaxpy_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaxpy_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/stub.cpp new file mode 100644 index 000000000..039db6cb6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vaxpy-f32 +// family: dsa-sfu +// target_ops: pto.vaxpy +// scenarios: core-f32, scalar-operand, fused-op +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vaxpy_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/compare.py new file mode 100644 index 000000000..efadc7cd0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vbitsort +# family: dsa-sfu +# target_ops: pto.vbitsort +# scenarios: index-generation, layout-transform + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/golden.py new file mode 100644 index 000000000..7a15acfe0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vbitsort +# family: dsa-sfu +# target_ops: pto.vbitsort +# scenarios: index-generation, layout-transform + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +PROPOSALS = 32 + + +def generate(output_dir: Path, seed: int) -> None: + _ = seed + scores = np.array([ + 3.5, -2.0, 7.0, 7.0, 1.5, 4.25, 0.0, 9.5, + -8.0, 9.5, 2.0, 2.0, 6.0, 6.0, -1.0, 5.75, + 5.75, 4.25, 8.0, 8.0, 3.0, -4.5, 1.25, 1.25, + 10.0, 10.0, -3.0, 0.5, 12.0, 12.0, -7.0, 6.5, + ], dtype=np.float32) + indices = np.array([ + 100, 203, 77, 88, 12, 45, 501, 9, + 333, 7, 900, 901, 31, 32, 400, 62, + 63, 46, 73, 74, 15, 16, 120, 121, + 5, 6, 700, 701, 1, 2, 808, 90, + ], dtype=np.uint32) + + order = np.argsort(-scores, kind="stable") + sorted_scores = scores[order] + sorted_indices = indices[order] + + packed = np.empty(PROPOSALS * 2, dtype=np.uint32) + packed[0::2] = sorted_scores.view(np.uint32) + packed[1::2] = sorted_indices + + output_dir.mkdir(parents=True, exist_ok=True) + scores.tofile(output_dir / "v1.bin") + indices.tofile(output_dir / "v2.bin") + np.zeros(PROPOSALS * 2, dtype=np.uint32).tofile(output_dir / "v3.bin") + packed.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto new file mode 100644 index 000000000..d4a421d50 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto @@ -0,0 +1,40 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vbitsort +// family: dsa-sfu +// target_ops: pto.vbitsort +// scenarios: index-generation, layout-transform +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vbitsort_kernel_f32(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1 = arith.constant 1 : index + %false = arith.constant false + + %ub_scores = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_indices = pto.castptr %c128_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_scores, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_indices, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vbitsort %ub_out, %ub_scores, %ub_indices, %c1 : !pto.ptr, !pto.ptr, !pto.ptr, index + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/launch.cpp new file mode 100644 index 000000000..767eefd22 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbitsort_kernel_f32(__gm__ float *scores, + __gm__ uint32_t *indices, + __gm__ uint32_t *output); + +void LaunchVbitsort_kernel_f32(float *scores, uint32_t *indices, uint32_t *output, + void *stream) { + vbitsort_kernel_f32<<<1, nullptr, stream>>>((__gm__ float *)scores, + (__gm__ uint32_t *)indices, + (__gm__ uint32_t *)output); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/main.cpp new file mode 100644 index 000000000..5d9f5a1b2 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vbitsort +// family: dsa-sfu +// target_ops: pto.vbitsort +// scenarios: index-generation, layout-transform +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbitsort_kernel_f32(float *scores, uint32_t *indices, uint32_t *output, + void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVbitsort_kernel_f32(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/stub.cpp new file mode 100644 index 000000000..31eba2068 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vbitsort_kernel_f32(__gm__ float *scores, + __gm__ uint32_t *indices, + __gm__ uint32_t *output) { + (void)scores; + (void)indices; + (void)output; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vci/compare.py new file mode 100755 index 000000000..4a2c212c6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int32, 0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vci/golden.py new file mode 100755 index 000000000..f044e1819 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + _ = seed + v1 = np.zeros((ROWS, COLS), dtype=np.int32) + v2 = np.zeros((ROWS, COLS), dtype=np.int32) + golden_v2 = np.arange(ROWS * COLS, dtype=np.int32).reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto new file mode 100644 index 000000000..becc05323 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vci +// family: dsa-sfu / conversion +// target_ops: pto.vci +// scenarios: index-generation +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vci_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_zero = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_zero, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %base = arith.index_castui %offset : index to i32 + %indices = pto.vci %base {order = "ASC"} : i32 -> !pto.vreg<64xi32> + pto.vsts %indices, %ub_out[%offset], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp new file mode 100644 index 000000000..33957c516 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vci +// family: dsa-sfu / conversion +// target_ops: pto.vci +// scenarios: index-generation +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int *v1, + __gm__ int *v2); + +void LaunchVci_kernel_2d(int *v1, int *v2, void *stream) { + vci_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1, + (__gm__ int *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp new file mode 100644 index 000000000..2d828b0ba --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vci +// family: dsa-sfu / conversion +// target_ops: pto.vci +// scenarios: index-generation +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVci_kernel_2d(int *v1, int *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + int *v1Host = nullptr; + int *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVci_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp new file mode 100644 index 000000000..8f04031ff --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vci +// family: dsa-sfu / conversion +// target_ops: pto.vci +// scenarios: index-generation +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int *v1, + __gm__ int *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py new file mode 100755 index 000000000..4b67584ba --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdiff-boundary +# family: dsa-sfu +# target_ops: pto.vexpdiff +# scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py new file mode 100755 index 000000000..b82ce1002 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdiff-boundary +# family: dsa-sfu +# target_ops: pto.vexpdiff +# scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + del seed + src_pattern = np.array( + [ + 0.0, 88.0, -120.0, np.nan, np.inf, -np.inf, 1.0, -1.0, + 90.0, -90.0, 50.0, -50.0, 3.0, -3.0, 10.0, -10.0, + ], + dtype=np.float32, + ) + max_pattern = np.array( + [ + 0.0, 0.0, 0.0, 1.0, np.inf, -np.inf, -1.0, 1.0, + 0.0, 0.0, 100.0, -100.0, 3.0, -3.0, 20.0, -20.0, + ], + dtype=np.float32, + ) + flat_src = np.resize(src_pattern, ROWS * COLS).astype(np.float32, copy=False) + flat_max = np.resize(max_pattern, ROWS * COLS).astype(np.float32, copy=False) + v1 = flat_src.reshape(ROWS, COLS) + v2 = flat_max.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.exp(flat_src - flat_max).astype(np.float32, copy=False).reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto new file mode 100644 index 000000000..2d4df9cb0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-boundary +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vexpdiff_boundary_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_max = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_max, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %max = pto.vlds %ub_max[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vexpdiff %vec, %max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp new file mode 100644 index 000000000..1ff4ec8a7 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-boundary +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexpdiff_boundary_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVexpdiff_boundary_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vexpdiff_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp new file mode 100644 index 000000000..59f0d80f4 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-boundary +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexpdiff_boundary_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexpdiff_boundary_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/stub.cpp new file mode 100644 index 000000000..278265849 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-boundary +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vexpdiff_boundary_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py new file mode 100644 index 000000000..c2ea3f6bd --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdiff-f16-part +# family: dsa-sfu +# target_ops: pto.vexpdiff +# scenarios: core-f16, fused-expdiff, part-even-odd + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py new file mode 100644 index 000000000..915730bc8 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdiff-f16-part +# family: dsa-sfu +# target_ops: pto.vexpdiff +# scenarios: core-f16, fused-expdiff, part-even-odd + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 31 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float16) + v2 = rng.uniform(-2.0, 2.0, size=(ROWS, COLS)).astype(np.float16) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + flat1 = v1.reshape(-1) + flat2 = v2.reshape(-1) + golden = np.empty((ROWS * COLS,), dtype=np.float32) + for base in range(0, ROWS * COLS, 128): + chunk1 = flat1[base : base + 128].astype(np.float32) + chunk2 = flat2[base : base + 128].astype(np.float32) + golden[base : base + 64] = np.exp(chunk1[0::2] - chunk2[0::2]).astype( + np.float32 + ) + golden[base + 64 : base + 128] = np.exp( + chunk1[1::2] - chunk2[1::2] + ).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + flat1.tofile(output_dir / "v1.bin") + flat2.tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto new file mode 100644 index 000000000..338dbf8be --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto @@ -0,0 +1,60 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-f16-part +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f16, fused-expdiff, part-even-odd +// NOTE: validates that ODD/EVEN selects odd/even lanes from f16 inputs. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vexpdiff_f16_part_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_max = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_max, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %input = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %max = pto.vlds %ub_max[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %even_mask, %remaining_after_even = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %odd_mask, %next_remaining = pto.plt_b32 %remaining_after_even : i32 -> !pto.mask, i32 + %even = pto.vexpdiff %input, %max, "EVEN" : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %odd = pto.vexpdiff %input, %max, "ODD" : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %odd_offset = arith.addi %offset, %c64 : index + pto.vsts %even, %ub_out[%offset], %even_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %odd, %ub_out[%odd_offset], %odd_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp new file mode 100644 index 000000000..8bee57183 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-f16-part +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f16, fused-expdiff, part-even-odd +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexpdiff_f16_part_kernel_2d(__gm__ half *v1, + __gm__ half *v2, + __gm__ float *v3); + +void LaunchVexpdiff_f16_part_kernel_2d(uint16_t *v1, uint16_t *v2, float *v3, + void *stream) { + vexpdiff_f16_part_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp new file mode 100644 index 000000000..7137c02a0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-f16-part +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f16, fused-expdiff, part-even-odd +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexpdiff_f16_part_kernel_2d(uint16_t *v1, uint16_t *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexpdiff_f16_part_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/stub.cpp new file mode 100644 index 000000000..65011d6bb --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-f16-part +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f16, fused-expdiff, part-even-odd +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vexpdiff_f16_part_kernel_2d(__gm__ half *v1, + __gm__ half *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py new file mode 100755 index 000000000..fdd9df368 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdiff-f32 +# family: dsa-sfu +# target_ops: pto.vexpdiff +# scenarios: core-f32, fused-expdiff +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py new file mode 100755 index 000000000..9c6cf1776 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdiff-f32 +# family: dsa-sfu +# target_ops: pto.vexpdiff +# scenarios: core-f32, fused-expdiff +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.ones((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto new file mode 100644 index 000000000..026f64ace --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-f32 +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f32, fused-expdiff +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vexpdiff_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vexpdiff %vec, %vec, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp new file mode 100644 index 000000000..cc7fb20e8 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-f32 +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f32, fused-expdiff +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexpdiff_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexpdiff_kernel_2d(float *v1, float *v2, void *stream) { + vexpdiff_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp new file mode 100644 index 000000000..ccf695380 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-f32 +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f32, fused-expdiff +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexpdiff_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexpdiff_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/stub.cpp new file mode 100644 index 000000000..30bf48863 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdiff-f32 +// family: dsa-sfu +// target_ops: pto.vexpdiff +// scenarios: core-f32, fused-expdiff +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vexpdiff_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/compare.py new file mode 100755 index 000000000..4717fd3e8 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vlrelu-f16 +# family: dsa-sfu +# target_ops: pto.vlrelu +# scenarios: core-f16, full-mask, scalar-operand +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/golden.py new file mode 100755 index 000000000..bc7c328b9 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vlrelu-f16 +# family: dsa-sfu +# target_ops: pto.vlrelu +# scenarios: core-f16, full-mask, scalar-operand +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float16(0.125) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v2 = np.where(v1 >= 0.0, v1, v1 * ALPHA).astype(np.float16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto new file mode 100644 index 000000000..e071ed15d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vlrelu-f16 +// family: dsa-sfu +// target_ops: pto.vlrelu +// scenarios: core-f16, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %cst = arith.constant 1.250000e-01 : f16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %sum = pto.vlrelu %vec, %cst, %mask : !pto.vreg<128xf16>, f16, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/launch.cpp new file mode 100644 index 000000000..da89bb6f0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vlrelu-f16 +// family: dsa-sfu +// target_ops: pto.vlrelu +// scenarios: core-f16, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/main.cpp new file mode 100644 index 000000000..73e868d99 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vlrelu-f16 +// family: dsa-sfu +// target_ops: pto.vlrelu +// scenarios: core-f16, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/stub.cpp new file mode 100644 index 000000000..7b0e1f2a5 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vlrelu-f16 +// family: dsa-sfu +// target_ops: pto.vlrelu +// scenarios: core-f16, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/golden.py new file mode 100644 index 000000000..938b69b9d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float32(0.125) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -8.0, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.where(v1 >= 0.0, v1, v1 * ALPHA).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto new file mode 100644 index 000000000..3d0cfb074 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 1.250000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vlrelu %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/launch.cpp new file mode 100644 index 000000000..44c07c249 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/main.cpp new file mode 100644 index 000000000..fcb42331f --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/stub.cpp new file mode 100644 index 000000000..dea70f9b6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/golden.py new file mode 100644 index 000000000..dd0899be8 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float32(0.125) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.where(v1 >= 0.0, v1, v1 * ALPHA).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto new file mode 100644 index 000000000..3d0cfb074 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 1.250000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vlrelu %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/launch.cpp new file mode 100644 index 000000000..44c07c249 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/main.cpp new file mode 100644 index 000000000..fcb42331f --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/stub.cpp new file mode 100644 index 000000000..dea70f9b6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/golden.py new file mode 100644 index 000000000..2544a92ff --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float32(0.125) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + flat = v1.reshape(-1) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.where( + flat[:LOGICAL_ELEMS] >= 0.0, flat[:LOGICAL_ELEMS], flat[:LOGICAL_ELEMS] * ALPHA + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto new file mode 100644 index 000000000..53a0e0456 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vadds_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 1.250000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vlrelu %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/launch.cpp new file mode 100644 index 000000000..b4cd46470 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vadds_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/stub.cpp new file mode 100644 index 000000000..67e6846a1 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadds_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/compare.py new file mode 100755 index 000000000..e7e8af91d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmula-accumulator-boundary +# family: dsa-sfu +# target_ops: pto.vmula +# scenarios: core-f32, fused-op, accumulator +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + +ACTIVE_ELEMS = 65 + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.size == count and output.size == count and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, ACTIVE_ELEMS) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/golden.py new file mode 100755 index 000000000..6c0d8c252 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmula-accumulator-boundary +# family: dsa-sfu +# target_ops: pto.vmula +# scenarios: core-f32, fused-op, accumulator, boundary +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 + np.abs(v1) * np.abs(v1)).astype(np.float32, copy=False) + golden_v2.reshape(-1)[65:] = 0.0 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto new file mode 100644 index 000000000..5901724b7 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula-accumulator-boundary +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator, boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c65_i32 = arith.constant 65 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c65_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %acc = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %lhs = pto.vabs %acc, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %rhs = pto.vabs %lhs, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %sum = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/launch.cpp new file mode 100644 index 000000000..8dcf35197 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula-accumulator-boundary +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/main.cpp new file mode 100644 index 000000000..c9b1f36c3 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula-accumulator-boundary +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/stub.cpp new file mode 100644 index 000000000..d5df9bb56 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula-accumulator-boundary +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vmula/compare.py new file mode 100755 index 000000000..cfc4e190a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmula +# family: dsa-sfu +# target_ops: pto.vmula +# scenarios: core-f32, fused-op, accumulator +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vmula/golden.py new file mode 100755 index 000000000..2110c7144 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmula +# family: dsa-sfu +# target_ops: pto.vmula +# scenarios: core-f32, fused-op, accumulator +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 + np.abs(v1) * np.abs(v1)).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto new file mode 100644 index 000000000..3f4216a6b --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %acc = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %lhs = pto.vabs %acc, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %rhs = pto.vabs %lhs, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %sum = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula/launch.cpp new file mode 100644 index 000000000..fa6ae4bb5 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula/main.cpp new file mode 100644 index 000000000..54508a12d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula/stub.cpp new file mode 100644 index 000000000..8d80184f9 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vmull/compare.py new file mode 100755 index 000000000..e2e0b0eef --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmull +# family: dsa-sfu +# target_ops: pto.vmull +# scenarios: widening-op, hi-lo-split +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int32, 0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vmull/golden.py new file mode 100755 index 000000000..96b892823 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmull +# family: dsa-sfu +# target_ops: pto.vmull +# scenarios: widening-op, hi-lo-split +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.integers(-10000, 10000, size=(ROWS // 2, COLS), dtype=np.int32) + rhs = rng.integers(-10000, 10000, size=(ROWS // 2, COLS), dtype=np.int32) + v1 = np.concatenate([lhs, rhs], axis=0).astype(np.int32, copy=False) + prod = lhs.astype(np.int64) * rhs.astype(np.int64) + low = (prod & np.int64(0xFFFFFFFF)).astype(np.uint32).view(np.int32) + high = ((prod >> np.int64(32)) & np.int64(0xFFFFFFFF)).astype(np.uint32).view(np.int32) + golden_v2 = np.concatenate([low, high], axis=0).astype(np.int32, copy=False) + v2 = np.zeros((ROWS, COLS), dtype=np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto new file mode 100644 index 000000000..30af09612 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmull +// family: dsa-sfu +// target_ops: pto.vmull +// scenarios: widening-op, hi-lo-split +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vmull_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c512_i32 = arith.constant 512 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %gm_in = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + %gm_out = pto.castptr %arg1 : !pto.ptr -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_in, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c64 iter_args(%remaining = %c512_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %rhs_offset = arith.addi %offset, %c512 : index + %lhs = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %rhs = pto.vlds %ub_in[%rhs_offset] : !pto.ptr -> !pto.vreg<64xi32> + %low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32>, !pto.vreg<64xi32> + pto.vsts %low, %ub_out[%offset], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + pto.vsts %high, %ub_out[%rhs_offset], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmull/launch.cpp new file mode 100644 index 000000000..1b7ce84c6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmull +// family: dsa-sfu +// target_ops: pto.vmull +// scenarios: widening-op, hi-lo-split +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmull_kernel_2d(__gm__ int *v1, + __gm__ int *v2); + +void LaunchVmull_kernel_2d(int *v1, int *v2, void *stream) { + vmull_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1, + (__gm__ int *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmull/main.cpp new file mode 100644 index 000000000..d853eda16 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmull +// family: dsa-sfu +// target_ops: pto.vmull +// scenarios: widening-op, hi-lo-split +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmull_kernel_2d(int *v1, int *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + int *v1Host = nullptr; + int *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmull_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmull/stub.cpp new file mode 100644 index 000000000..2791c774a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmull +// family: dsa-sfu +// target_ops: pto.vmull +// scenarios: widening-op, hi-lo-split +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmull_kernel_2d(__gm__ int *v1, + __gm__ int *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/compare.py new file mode 100755 index 000000000..35caa0aa6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vprelu-f32 +# family: dsa-sfu +# target_ops: pto.vprelu +# scenarios: core-f32, vector-alpha +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/golden.py new file mode 100755 index 000000000..e1fd7c683 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vprelu-f32 +# family: dsa-sfu +# target_ops: pto.vprelu +# scenarios: core-f32, vector-alpha +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(0.05, 0.5, size=(ROWS, COLS)).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.where(v1 >= 0.0, v1, v1 * v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto new file mode 100644 index 000000000..2dc582748 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-f32 +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vprelu_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_alpha = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_alpha, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %alpha = pto.vlds %ub_alpha[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vprelu %vec, %alpha : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/launch.cpp new file mode 100644 index 000000000..d6002ce63 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-f32 +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vprelu_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVprelu_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vprelu_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/main.cpp new file mode 100644 index 000000000..6a2738912 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-f32 +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVprelu_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVprelu_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/stub.cpp new file mode 100644 index 000000000..bcff63877 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-f32 +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vprelu_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/compare.py new file mode 100755 index 000000000..bbc6ab65a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vprelu-tail +# family: dsa-sfu +# target_ops: pto.vprelu +# scenarios: core-f32, vector-alpha, tail-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + +ACTIVE_ELEMS = 1000 + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.size == count and output.size == count and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, ACTIVE_ELEMS) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/golden.py new file mode 100755 index 000000000..a9e101569 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vprelu-tail +# family: dsa-sfu +# target_ops: pto.vprelu +# scenarios: core-f32, vector-alpha, tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(0.05, 0.5, size=(ROWS, COLS)).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.where(v1 >= 0.0, v1, v1 * v2).astype(np.float32, copy=False) + golden_v3.reshape(-1)[1000:] = 0.0 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto new file mode 100644 index 000000000..5c3951823 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-tail +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vprelu_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_alpha = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_alpha, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %alpha = pto.vlds %ub_alpha[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vprelu %vec, %alpha : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/launch.cpp new file mode 100644 index 000000000..b4a675c6a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-tail +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vprelu_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVprelu_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vprelu_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/main.cpp new file mode 100644 index 000000000..27b55f701 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-tail +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVprelu_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVprelu_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/stub.cpp new file mode 100644 index 000000000..f54e21ef9 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-tail +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vprelu_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/compare.py new file mode 100755 index 000000000..c932750d2 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2-duplicate-index +# family: gather-scatter +# target_ops: pto.vgather2 +# scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py new file mode 100755 index 000000000..f27ecfd0b --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2-duplicate-index +# family: gather-scatter +# target_ops: pto.vgather2 +# scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + pair_ids = ((np.arange((ROWS * COLS) // 2, dtype=np.int32) * 29) + 5) % (ROWS * COLS) + offsets = np.repeat(pair_ids, 2) + gathered = flat[offsets].reshape(ROWS, COLS) + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgather2 duplicate-index validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto new file mode 100644 index 000000000..46bd057fc --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2-duplicate-index +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vgather2_duplicate_index_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %out = pto.vgather2 %ub_in, %offsets, %c64 : !pto.ptr, !pto.vreg<64xi32>, index -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/launch.cpp new file mode 100644 index 000000000..1a2d0359e --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/launch.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2-duplicate-index +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgather2_duplicate_index_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3); + +void LaunchVgather2_duplicate_index_kernel_2d(float *v1, int *v2, float *v3, + void *stream) { + vgather2_duplicate_index_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ int *)v2, (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/main.cpp new file mode 100644 index 000000000..df5907af3 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2-duplicate-index +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgather2_duplicate_index_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgather2_duplicate_index_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/stub.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/stub.cpp new file mode 100644 index 000000000..847e2c028 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/stub.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2-duplicate-index +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vgather2_duplicate_index_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgather2/compare.py new file mode 100755 index 000000000..41f5c3f65 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2 +# family: gather-scatter +# target_ops: pto.vgather2 +# scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2/golden.py new file mode 100755 index 000000000..714c54c63 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2 +# family: gather-scatter +# target_ops: pto.vgather2 +# scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 17) + 3) % (ROWS * COLS) + gathered = flat[offsets].reshape(ROWS, COLS) + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgather2 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto new file mode 100644 index 000000000..34b01dced --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto @@ -0,0 +1,72 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2 +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vgather2_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %out = pto.vgather2 %ub_in, %offsets, %c64 : !pto.ptr, !pto.vreg<64xi32>, index -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2/launch.cpp new file mode 100644 index 000000000..e99c6741a --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/launch.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2 +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgather2_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVgather2_kernel_2d(float *v1, int *v2, float *v3, void *stream) { + vgather2_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2/main.cpp new file mode 100644 index 000000000..e2a9b4804 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2 +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgather2_kernel_2d(float *v1, int *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgather2_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/stub.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2/stub.cpp new file mode 100644 index 000000000..87d007735 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/stub.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2 +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vgather2_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/compare.py new file mode 100755 index 000000000..83f25f4ab --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2_bc-sparse-mask +# family: gather-scatter +# target_ops: pto.vgather2_bc +# scenarios: core-f32, masked-gather, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/golden.py new file mode 100755 index 000000000..e0cea7841 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2_bc-sparse-mask +# family: gather-scatter +# target_ops: pto.vgather2_bc +# scenarios: core-f32, masked-gather, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 17) + 3) % (ROWS * COLS) + gathered = np.zeros((ROWS * COLS,), dtype=np.float32) + active = offsets < 64 + gathered[active] = flat[offsets[active]] + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgather2_bc sparse-mask validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto new file mode 100644 index 000000000..5baaa4a8d --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc-sparse-mask +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vgather2_bc_sparse_mask_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c64_i32 = arith.constant 64 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %gather_mask = pto.vcmps %offsets, %c64_i32, %full_mask, "lt" : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + %out = pto.vgather2_bc %ub_in, %offsets, %gather_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/launch.cpp new file mode 100644 index 000000000..333288162 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/launch.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc-sparse-mask +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgather2_bc_sparse_mask_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3); + +void LaunchVgather2_bc_sparse_mask_kernel_2d(float *v1, int *v2, float *v3, + void *stream) { + vgather2_bc_sparse_mask_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ int *)v2, (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/main.cpp new file mode 100644 index 000000000..66ab70307 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc-sparse-mask +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgather2_bc_sparse_mask_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgather2_bc_sparse_mask_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/stub.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/stub.cpp new file mode 100644 index 000000000..76de79afc --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/stub.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc-sparse-mask +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vgather2_bc_sparse_mask_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/compare.py new file mode 100755 index 000000000..4ebeae5d2 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2_bc +# family: gather-scatter +# target_ops: pto.vgather2_bc +# scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/golden.py new file mode 100755 index 000000000..da03fb5a7 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2_bc +# family: gather-scatter +# target_ops: pto.vgather2_bc +# scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 17) + 3) % (ROWS * COLS) + gathered = np.zeros((ROWS * COLS,), dtype=np.float32) + active = offsets < 256 + gathered[active] = flat[offsets[active]] + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgather2_bc validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto new file mode 100644 index 000000000..80f1c5a28 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vgather2_bc_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i32 = arith.constant 256 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %gather_mask = pto.vcmps %offsets, %c256_i32, %full_mask, "lt" : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + %out = pto.vgather2_bc %ub_in, %offsets, %gather_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/launch.cpp new file mode 100644 index 000000000..2c60a591c --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/launch.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgather2_bc_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVgather2_bc_kernel_2d(float *v1, int *v2, float *v3, void *stream) { + vgather2_bc_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/main.cpp new file mode 100644 index 000000000..73ba0e412 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgather2_bc_kernel_2d(float *v1, int *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgather2_bc_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/stub.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/stub.cpp new file mode 100644 index 000000000..82ab56c08 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/stub.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vgather2_bc_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/compare.py new file mode 100755 index 000000000..6d777d58f --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgatherb-block-boundary +# family: gather-scatter +# target_ops: pto.vgatherb +# scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/golden.py new file mode 100755 index 000000000..2bfb0c0d3 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgatherb-block-boundary +# family: gather-scatter +# target_ops: pto.vgatherb +# scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +BLOCK_FLOATS = 8 +BLOCKS_PER_ITER = 8 +ITER_ELEMS = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + blocks = flat.reshape(-1, BLOCK_FLOATS) + offsets = np.zeros((ROWS * COLS,), dtype=np.int32) + gathered = np.zeros((ROWS * COLS,), dtype=np.float32) + boundary_patterns = np.array([0, 1, 15, 16, 31, 32, 63, 127], dtype=np.int32) + + for chunk in range((ROWS * COLS) // ITER_ELEMS): + block_ids = (boundary_patterns + chunk * 3) % blocks.shape[0] + offsets[chunk * ITER_ELEMS:chunk * ITER_ELEMS + BLOCKS_PER_ITER] = block_ids * 32 + gathered[chunk * ITER_ELEMS:(chunk + 1) * ITER_ELEMS] = blocks[block_ids].reshape(-1) + + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgatherb block-boundary validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto new file mode 100644 index 000000000..ff1ed7abb --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb-block-boundary +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vgatherb_block_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %gather_mask, %_tail = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %out = pto.vgatherb %ub_in, %offsets, %gather_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/launch.cpp new file mode 100644 index 000000000..fb6f40c39 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/launch.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb-block-boundary +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgatherb_block_boundary_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3); + +void LaunchVgatherb_block_boundary_kernel_2d(float *v1, int *v2, float *v3, + void *stream) { + vgatherb_block_boundary_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ int *)v2, (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/main.cpp new file mode 100644 index 000000000..77c2cb46f --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb-block-boundary +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgatherb_block_boundary_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgatherb_block_boundary_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/stub.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/stub.cpp new file mode 100644 index 000000000..fd1a8ef1d --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/stub.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb-block-boundary +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vgatherb_block_boundary_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgatherb/compare.py new file mode 100755 index 000000000..e9d439e1a --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgatherb +# family: gather-scatter +# target_ops: pto.vgatherb +# scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgatherb/golden.py new file mode 100755 index 000000000..e102cecfe --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgatherb +# family: gather-scatter +# target_ops: pto.vgatherb +# scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +BLOCK_FLOATS = 8 +BLOCKS_PER_ITER = 8 +ITER_ELEMS = 64 +SEED = 19 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + blocks = flat.reshape(-1, BLOCK_FLOATS) + offsets = np.zeros((ROWS * COLS,), dtype=np.int32) + gathered = np.full((ROWS * COLS,), OUT_SENTINEL, dtype=np.float32) + + for chunk in range((ROWS * COLS) // ITER_ELEMS): + block_ids = ((np.arange(BLOCKS_PER_ITER, dtype=np.int32) + chunk * 11) * 7 + 3) % blocks.shape[0] + offsets[chunk * ITER_ELEMS:chunk * ITER_ELEMS + BLOCKS_PER_ITER] = block_ids * 32 + gathered[chunk * ITER_ELEMS:(chunk + 1) * ITER_ELEMS] = blocks[block_ids].reshape(-1) + + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgatherb validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto new file mode 100644 index 000000000..abad98343 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vgatherb_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %gather_mask, %_tail = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %out = pto.vgatherb %ub_in, %offsets, %gather_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb/launch.cpp new file mode 100644 index 000000000..589f236be --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/launch.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgatherb_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVgatherb_kernel_2d(float *v1, int *v2, float *v3, void *stream) { + vgatherb_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb/main.cpp new file mode 100644 index 000000000..e16952c96 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgatherb_kernel_2d(float *v1, int *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgatherb_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/stub.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb/stub.cpp new file mode 100644 index 000000000..419830b82 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/stub.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vgatherb_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/compare.py b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/compare.py new file mode 100755 index 000000000..016bfa5b7 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vscatter-out-of-order-index +# family: gather-scatter +# target_ops: pto.vscatter +# scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py new file mode 100755 index 000000000..3bf886b5f --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vscatter-out-of-order-index +# family: gather-scatter +# target_ops: pto.vscatter +# scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 43) + 11) % (ROWS * COLS) + scattered = np.zeros((ROWS * COLS,), dtype=np.float32) + scattered[offsets] = flat + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + scattered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vscatter validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto new file mode 100644 index 000000000..f07933e42 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter-out-of-order-index +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vscatter_out_of_order_index_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg2, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + pto.vscatter %vec, %ub_out, %offsets, %c64 : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, index + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/launch.cpp new file mode 100644 index 000000000..87b02ee5d --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/launch.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter-out-of-order-index +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vscatter_out_of_order_index_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3); + +void LaunchVscatter_out_of_order_index_kernel_2d(float *v1, int *v2, + float *v3, void *stream) { + vscatter_out_of_order_index_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ int *)v2, (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/main.cpp new file mode 100644 index 000000000..f762aa293 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter-out-of-order-index +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVscatter_out_of_order_index_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVscatter_out_of_order_index_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/stub.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/stub.cpp new file mode 100644 index 000000000..4acefd3c5 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/stub.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter-out-of-order-index +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vscatter_out_of_order_index_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/compare.py b/test/vpto/cases/micro-op/gather-scatter/vscatter/compare.py new file mode 100755 index 000000000..ada19a30e --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vscatter +# family: gather-scatter +# target_ops: pto.vscatter +# scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/golden.py b/test/vpto/cases/micro-op/gather-scatter/vscatter/golden.py new file mode 100755 index 000000000..252356095 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vscatter +# family: gather-scatter +# target_ops: pto.vscatter +# scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 29) + 7) % (ROWS * COLS) + scattered = np.zeros((ROWS * COLS,), dtype=np.float32) + scattered[offsets] = flat + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + scattered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vscatter validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto new file mode 100644 index 000000000..2efab3a00 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto @@ -0,0 +1,74 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vscatter_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg2, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + pto.vscatter %vec, %ub_out, %offsets, %c64 : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, index + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter/launch.cpp new file mode 100644 index 000000000..79296467b --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/launch.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vscatter_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVscatter_kernel_2d(float *v1, int *v2, float *v3, void *stream) { + vscatter_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter/main.cpp new file mode 100644 index 000000000..613ee0282 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVscatter_kernel_2d(float *v1, int *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVscatter_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/stub.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter/stub.cpp new file mode 100644 index 000000000..b7d8f63a6 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/stub.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vscatter_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pand/compare.py new file mode 100755 index 000000000..546f4445e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pand +# family: materialization-predicate +# target_ops: pto.pand +# scenarios: predicate-transform, logical-and +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pand/golden.py new file mode 100755 index 000000000..6d0506864 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pand +# family: materialization-predicate +# target_ops: pto.pand +# scenarios: predicate-transform, logical-and +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +PREFIX_BITS = 13 +SUFFIX_BITS = 7 +PREDICATE_BITS = 256 +NIBBLE_COUNT = PREDICATE_BITS // 2 + + +def pack_nibbles(nibbles: np.ndarray) -> np.ndarray: + words = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + for idx, nibble in enumerate(nibbles): + words[idx // 8] |= np.uint32(int(nibble) & 0xF) << np.uint32((idx % 8) * 4) + return words + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + lhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + rhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + lhs[:PREFIX_BITS] = 1 + rhs[:SUFFIX_BITS] = 1 + golden = pack_nibbles(lhs & rhs) + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto new file mode 100644 index 000000000..789ebe48c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pand +// family: materialization-predicate +// target_ops: pto.pand +// scenarios: predicate-transform +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @pand_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c7 = arith.constant 7 : i32 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs, %lhs_next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %rhs, %rhs_next = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %out = pto.pand %lhs, %rhs, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pand/launch.cpp new file mode 100644 index 000000000..bf665aec7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pand +// family: materialization-predicate +// target_ops: pto.pand +// scenarios: predicate-transform +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pand_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPand(uint32_t *v1, void *stream) { + pand_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pand/main.cpp new file mode 100644 index 000000000..751eed2d9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pand +// family: materialization-predicate +// target_ops: pto.pand +// scenarios: predicate-transform +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPand(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPand(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pand/stub.cpp new file mode 100644 index 000000000..973663ec8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pand +// family: materialization-predicate +// target_ops: pto.pand +// scenarios: predicate-transform + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pand_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/compare.py new file mode 100755 index 000000000..25e69c8a2 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b16 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/golden.py new file mode 100755 index 000000000..814eb34b5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b16 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([85, 0, 0, 0, 286331153, 286331153, 286331153, 286331153, 85, 0, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto new file mode 100644 index 000000000..1026206b3 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pdintlv_b16_nontrivial_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b16 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b16 "PAT_M4" : !pto.mask + %low, %high = pto.pdintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/launch.cpp new file mode 100644 index 000000000..182f92536 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b16_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB16Nontrivial(uint32_t *v1, void *stream) { + pdintlv_b16_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/main.cpp new file mode 100644 index 000000000..02d8a3875 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB16Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB16Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/stub.cpp new file mode 100644 index 000000000..e45873930 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pdintlv_b16_nontrivial_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/compare.py new file mode 100755 index 000000000..2a01bf650 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b16 +# family: materialization-predicate +# target_ops: pto.pdintlv_b16 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/golden.py new file mode 100755 index 000000000..e9227db10 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b16 +# family: materialization-predicate +# target_ops: pto.pdintlv_b16 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([1431655765, 1431655765, 1431655765, 1431655765, 0, 0, 0, 0, 1431655765, 1431655765, 1431655765, 1431655765, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto new file mode 100644 index 000000000..54c4cd105 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16 +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pdintlv_b16_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b16 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b16 "PAT_ALLF" : !pto.mask + %low, %high = pto.pdintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/launch.cpp new file mode 100644 index 000000000..519e90a51 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16 +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b16_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB16(uint32_t *v1, void *stream) { + pdintlv_b16_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/main.cpp new file mode 100644 index 000000000..e2491af41 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16 +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB16(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB16(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/stub.cpp new file mode 100644 index 000000000..495bdf0f7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16 +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pdintlv_b16_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/compare.py new file mode 100755 index 000000000..13e93d501 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b32 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/golden.py new file mode 100755 index 000000000..013f7751e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b32 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([4369, 0, 0, 0, 16843009, 16843009, 16843009, 16843009, 4369, 0, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto new file mode 100644 index 000000000..a207737b8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pdintlv_b32_nontrivial_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b32 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b32 "PAT_M4" : !pto.mask + %low, %high = pto.pdintlv_b32 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/launch.cpp new file mode 100644 index 000000000..9a6cd6a5e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b32_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB32Nontrivial(uint32_t *v1, void *stream) { + pdintlv_b32_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/main.cpp new file mode 100644 index 000000000..97d56c906 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB32Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB32Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/stub.cpp new file mode 100644 index 000000000..ff04ee2d5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pdintlv_b32_nontrivial_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/compare.py new file mode 100755 index 000000000..fab797df6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b32 +# family: materialization-predicate +# target_ops: pto.pdintlv_b32 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/golden.py new file mode 100755 index 000000000..cd1487eef --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b32 +# family: materialization-predicate +# target_ops: pto.pdintlv_b32 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286331153, 286331153, 286331153, 286331153, 0, 0, 0, 0, 286331153, 286331153, 286331153, 286331153, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto new file mode 100644 index 000000000..7885cee06 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32 +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pdintlv_b32_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b32 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b32 "PAT_ALLF" : !pto.mask + %low, %high = pto.pdintlv_b32 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/launch.cpp new file mode 100644 index 000000000..316cfc086 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32 +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b32_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB32(uint32_t *v1, void *stream) { + pdintlv_b32_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/main.cpp new file mode 100644 index 000000000..7af8a309d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32 +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB32(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB32(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/stub.cpp new file mode 100644 index 000000000..7fcf1c128 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32 +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pdintlv_b32_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/compare.py new file mode 100755 index 000000000..12db124bd --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b8 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/golden.py new file mode 100755 index 000000000..1c58d3b87 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b8 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([15, 0, 0, 0, 1431655765, 1431655765, 1431655765, 1431655765, 15, 0, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto new file mode 100644 index 000000000..3f82b9263 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pdintlv_b8_nontrivial_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b8 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b8 "PAT_M4" : !pto.mask + %low, %high = pto.pdintlv_b8 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/launch.cpp new file mode 100644 index 000000000..e6e2949f5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b8_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB8Nontrivial(uint32_t *v1, void *stream) { + pdintlv_b8_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/main.cpp new file mode 100644 index 000000000..71e67e085 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB8Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB8Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/stub.cpp new file mode 100644 index 000000000..8c809522f --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pdintlv_b8_nontrivial_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/compare.py new file mode 100755 index 000000000..305f17e97 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b8 +# family: materialization-predicate +# target_ops: pto.pdintlv_b8 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/golden.py new file mode 100755 index 000000000..e0ed75f95 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b8 +# family: materialization-predicate +# target_ops: pto.pdintlv_b8 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([4294967295, 4294967295, 4294967295, 4294967295, 0, 0, 0, 0, 4294967295, 4294967295, 4294967295, 4294967295, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto new file mode 100644 index 000000000..647105b76 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8 +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pdintlv_b8_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b8 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b8 "PAT_ALLF" : !pto.mask + %low, %high = pto.pdintlv_b8 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/launch.cpp new file mode 100644 index 000000000..a8a45dada --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8 +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b8_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB8(uint32_t *v1, void *stream) { + pdintlv_b8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/main.cpp new file mode 100644 index 000000000..d0edfba37 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8 +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB8(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB8(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/stub.cpp new file mode 100644 index 000000000..1f344f07b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8 +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pdintlv_b8_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/compare.py new file mode 100755 index 000000000..8700b0d56 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pge-tail-mask-boundary +# family: materialization-predicate +# target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +# scenarios: tail-mask, boundary +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 32 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print(f"[ERROR] Unexpected word count: golden={golden.size} out={output.size}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed predicate words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/golden.py new file mode 100755 index 000000000..67f3b65df --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pge-tail-mask-boundary +# family: materialization-predicate +# target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +# scenarios: tail-mask, boundary +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 + + +def _pack_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + out[bit_index // 8] |= np.uint8(1 << (bit_index % 8)) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + golden[0:32] = _pack_prefix(active_lanes=1, bit_stride=1, store_bytes=32) + golden[32:64] = _pack_prefix(active_lanes=1, bit_stride=2, store_bytes=32) + golden[64:96] = _pack_prefix(active_lanes=1, bit_stride=4, store_bytes=32) + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto new file mode 100644 index 000000000..2659d6708 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask, boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pge_tail_mask_boundary_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0 = pto.pge_b8 "PAT_VL1" : !pto.mask + %m1 = pto.pge_b16 "PAT_VL1" : !pto.mask + %m2 = pto.pge_b32 "PAT_VL1" : !pto.mask + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/launch.cpp new file mode 100644 index 000000000..6a52d74b3 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask, boundary +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pge_tail_mask_boundary_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPgeTailMaskBoundary(uint32_t *v1, void *stream) { + pge_tail_mask_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/main.cpp new file mode 100644 index 000000000..c0ca90a8c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask, boundary +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPgeTailMaskBoundary(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPgeTailMaskBoundary(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/stub.cpp new file mode 100644 index 000000000..1ffe8bd40 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask, boundary + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pge_tail_mask_boundary_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/compare.py new file mode 100755 index 000000000..2e598e7ae --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pge-tail-mask +# family: materialization-predicate +# target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +# scenarios: tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 32 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print(f"[ERROR] Unexpected word count: golden={golden.size} out={output.size}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed predicate words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/golden.py new file mode 100755 index 000000000..823fc8889 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pge-tail-mask +# family: materialization-predicate +# target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +# scenarios: tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 + + +def _pack_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + out[bit_index // 8] |= np.uint8(1 << (bit_index % 8)) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + golden[0:32] = _pack_prefix(active_lanes=8, bit_stride=1, store_bytes=32) + golden[32:64] = _pack_prefix(active_lanes=8, bit_stride=2, store_bytes=32) + golden[64:96] = _pack_prefix(active_lanes=8, bit_stride=4, store_bytes=32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op pge-tail-mask validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto new file mode 100644 index 000000000..7991ea772 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pge_tail_mask_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0 = pto.pge_b8 "PAT_VL8" : !pto.mask + %m1 = pto.pge_b16 "PAT_VL8" : !pto.mask + %m2 = pto.pge_b32 "PAT_VL8" : !pto.mask + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/launch.cpp new file mode 100644 index 000000000..c38434d88 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pge_tail_mask_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPgeTailMask(uint32_t *v1, void *stream) { + pge_tail_mask_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/main.cpp new file mode 100644 index 000000000..dea4fa6c5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPgeTailMask(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPgeTailMask(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/stub.cpp new file mode 100644 index 000000000..5de76cd6d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pge_tail_mask_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/compare.py new file mode 100755 index 000000000..7704bbfb7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b16-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b16 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/golden.py new file mode 100755 index 000000000..f1729a845 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b16-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b16 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286593301, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto new file mode 100644 index 000000000..2dc9e5ff1 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pintlv_b16_nontrivial_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b16 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b16 "PAT_M4" : !pto.mask + %low, %high = pto.pintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/launch.cpp new file mode 100644 index 000000000..57939cac6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b16_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB16Nontrivial(uint32_t *v1, void *stream) { + pintlv_b16_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/main.cpp new file mode 100644 index 000000000..aca9caf7a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB16Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB16Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/stub.cpp new file mode 100644 index 000000000..d026701ba --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pintlv_b16_nontrivial_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/compare.py new file mode 100755 index 000000000..9c1deb9c8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b16 +# family: materialization-predicate +# target_ops: pto.pintlv_b16 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/golden.py new file mode 100755 index 000000000..a52cba6a2 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b16 +# family: materialization-predicate +# target_ops: pto.pintlv_b16 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto new file mode 100644 index 000000000..8632c930e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16 +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pintlv_b16_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b16 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b16 "PAT_ALLF" : !pto.mask + %low, %high = pto.pintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/launch.cpp new file mode 100644 index 000000000..262d87427 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16 +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b16_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB16(uint32_t *v1, void *stream) { + pintlv_b16_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/main.cpp new file mode 100644 index 000000000..29156f86a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16 +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB16(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB16(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/stub.cpp new file mode 100644 index 000000000..af8247da7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16 +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pintlv_b16_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/compare.py new file mode 100755 index 000000000..daad8dae1 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b32-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b32 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/golden.py new file mode 100755 index 000000000..c28bc3f71 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b32-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b32 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([16843025, 16843025, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto new file mode 100644 index 000000000..59bdf9a89 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pintlv_b32_nontrivial_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b32 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b32 "PAT_M4" : !pto.mask + %low, %high = pto.pintlv_b32 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/launch.cpp new file mode 100644 index 000000000..06dcd4072 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b32_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB32Nontrivial(uint32_t *v1, void *stream) { + pintlv_b32_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/main.cpp new file mode 100644 index 000000000..befee95c9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB32Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB32Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/stub.cpp new file mode 100644 index 000000000..45663e706 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pintlv_b32_nontrivial_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/compare.py new file mode 100755 index 000000000..b3050eb69 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b32 +# family: materialization-predicate +# target_ops: pto.pintlv_b32 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/golden.py new file mode 100755 index 000000000..67cb39fc8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b32 +# family: materialization-predicate +# target_ops: pto.pintlv_b32 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto new file mode 100644 index 000000000..5765ef77f --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32 +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pintlv_b32_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b32 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b32 "PAT_ALLF" : !pto.mask + %low, %high = pto.pintlv_b32 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/launch.cpp new file mode 100644 index 000000000..bb990592f --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32 +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b32_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB32(uint32_t *v1, void *stream) { + pintlv_b32_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/main.cpp new file mode 100644 index 000000000..d0ef0696d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32 +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB32(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB32(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/stub.cpp new file mode 100644 index 000000000..2f9f69823 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32 +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pintlv_b32_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/compare.py new file mode 100755 index 000000000..16b0c224d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b8-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b8 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/golden.py new file mode 100755 index 000000000..de8ae6216 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b8-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b8 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([33707863, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto new file mode 100644 index 000000000..78319891b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pintlv_b8_nontrivial_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b8 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b8 "PAT_M4" : !pto.mask + %low, %high = pto.pintlv_b8 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/launch.cpp new file mode 100644 index 000000000..c466ef9c8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b8_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB8Nontrivial(uint32_t *v1, void *stream) { + pintlv_b8_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/main.cpp new file mode 100644 index 000000000..d27a5f08c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB8Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB8Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/stub.cpp new file mode 100644 index 000000000..df9deef0a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pintlv_b8_nontrivial_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/compare.py new file mode 100755 index 000000000..d6cae1168 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b8 +# family: materialization-predicate +# target_ops: pto.pintlv_b8 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/golden.py new file mode 100755 index 000000000..bae1a196c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b8 +# family: materialization-predicate +# target_ops: pto.pintlv_b8 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto new file mode 100644 index 000000000..fdf5a7afc --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8 +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pintlv_b8_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b8 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b8 "PAT_ALLF" : !pto.mask + %low, %high = pto.pintlv_b8 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/launch.cpp new file mode 100644 index 000000000..d0299e575 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8 +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b8_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB8(uint32_t *v1, void *stream) { + pintlv_b8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/main.cpp new file mode 100644 index 000000000..b4b856773 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8 +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB8(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB8(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/stub.cpp new file mode 100644 index 000000000..9f8242611 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8 +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pintlv_b8_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/compare.py b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/compare.py new file mode 100755 index 000000000..8eac93173 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/plt-tail-mask-boundary +# family: materialization-predicate +# target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +# scenarios: tail-mask, scalar-carry-out, boundary +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 32 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print(f"[ERROR] Unexpected word count: golden={golden.size} out={output.size}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed predicate words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/golden.py b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/golden.py new file mode 100755 index 000000000..c812554d8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/plt-tail-mask-boundary +# family: materialization-predicate +# target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +# scenarios: tail-mask, scalar-carry-out, boundary +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 + + +def _pack_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + out[bit_index // 8] |= np.uint8(1 << (bit_index % 8)) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + golden[0:32] = _pack_prefix(active_lanes=1, bit_stride=1, store_bytes=32) + golden[32:64] = _pack_prefix(active_lanes=1, bit_stride=2, store_bytes=32) + golden[64:96] = _pack_prefix(active_lanes=1, bit_stride=4, store_bytes=32) + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto new file mode 100644 index 000000000..68b29ee1f --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto @@ -0,0 +1,39 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out, boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @plt_tail_mask_boundary_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + %c1_i32 = arith.constant 1 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0, %s0 = pto.plt_b8 %c1_i32 : i32 -> !pto.mask, i32 + %m1, %s1 = pto.plt_b16 %c1_i32 : i32 -> !pto.mask, i32 + %m2, %s2 = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/launch.cpp new file mode 100644 index 000000000..fad09577d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out, boundary +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void plt_tail_mask_boundary_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPltTailMaskBoundary(uint32_t *v1, void *stream) { + plt_tail_mask_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/main.cpp new file mode 100644 index 000000000..0dfe9d502 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out, boundary +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPltTailMaskBoundary(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPltTailMaskBoundary(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/stub.cpp new file mode 100644 index 000000000..ad38acf23 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out, boundary + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void plt_tail_mask_boundary_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/compare.py b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/compare.py new file mode 100755 index 000000000..b2466b6ab --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/plt-tail-mask +# family: materialization-predicate +# target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +# scenarios: tail-mask, scalar-carry-out +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 32 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print(f"[ERROR] Unexpected word count: golden={golden.size} out={output.size}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed predicate words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/golden.py b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/golden.py new file mode 100755 index 000000000..1ef2a1a79 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/plt-tail-mask +# family: materialization-predicate +# target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +# scenarios: tail-mask, scalar-carry-out +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 + + +def _pack_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + out[bit_index // 8] |= np.uint8(1 << (bit_index % 8)) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + golden[0:32] = _pack_prefix(active_lanes=13, bit_stride=1, store_bytes=32) + golden[32:64] = _pack_prefix(active_lanes=7, bit_stride=2, store_bytes=32) + golden[64:96] = _pack_prefix(active_lanes=3, bit_stride=4, store_bytes=32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op plt-tail-mask validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto new file mode 100644 index 000000000..39729a22b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto @@ -0,0 +1,41 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @plt_tail_mask_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + %c13 = arith.constant 13 : i32 + %c7 = arith.constant 7 : i32 + %c3 = arith.constant 3 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0, %s0 = pto.plt_b8 %c13 : i32 -> !pto.mask, i32 + %m1, %s1 = pto.plt_b16 %c7 : i32 -> !pto.mask, i32 + %m2, %s2 = pto.plt_b32 %c3 : i32 -> !pto.mask, i32 + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/launch.cpp new file mode 100644 index 000000000..1c9b21d24 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void plt_tail_mask_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPltTailMask(uint32_t *v1, void *stream) { + plt_tail_mask_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/main.cpp new file mode 100644 index 000000000..7fdb5fe93 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPltTailMask(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPltTailMask(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/stub.cpp new file mode 100644 index 000000000..bed898779 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void plt_tail_mask_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pnot/compare.py new file mode 100755 index 000000000..a75c98c84 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pnot +# family: materialization-predicate +# target_ops: pto.pnot +# scenarios: predicate-transform, logical-not +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pnot/golden.py new file mode 100755 index 000000000..cafe4d7d0 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pnot +# family: materialization-predicate +# target_ops: pto.pnot +# scenarios: predicate-transform, logical-not +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([0, 286261248, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto new file mode 100644 index 000000000..fd8df6e77 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pnot +// family: materialization-predicate +// target_ops: pto.pnot +// scenarios: predicate-transform +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pnot_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %half, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %out = pto.pnot %half, %all : !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pnot/launch.cpp new file mode 100644 index 000000000..50cf29220 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pnot +// family: materialization-predicate +// target_ops: pto.pnot +// scenarios: predicate-transform +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pnot_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPnot(uint32_t *v1, void *stream) { + pnot_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pnot/main.cpp new file mode 100644 index 000000000..64b153376 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pnot +// family: materialization-predicate +// target_ops: pto.pnot +// scenarios: predicate-transform +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPnot(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPnot(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pnot/stub.cpp new file mode 100644 index 000000000..19416b2a3 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pnot +// family: materialization-predicate +// target_ops: pto.pnot +// scenarios: predicate-transform + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pnot_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/compare.py b/test/vpto/cases/micro-op/materialization-predicate/por/compare.py new file mode 100755 index 000000000..2d6c341a8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/por +# family: materialization-predicate +# target_ops: pto.por +# scenarios: predicate-transform, logical-or +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/golden.py b/test/vpto/cases/micro-op/materialization-predicate/por/golden.py new file mode 100755 index 000000000..c9c5dfe1e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/por +# family: materialization-predicate +# target_ops: pto.por +# scenarios: predicate-transform, logical-or +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +PREFIX_BITS = 13 +SUFFIX_BITS = 7 +PREDICATE_BITS = 256 +NIBBLE_COUNT = PREDICATE_BITS // 2 + + +def pack_nibbles(nibbles: np.ndarray) -> np.ndarray: + words = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + for idx, nibble in enumerate(nibbles): + words[idx // 8] |= np.uint32(int(nibble) & 0xF) << np.uint32((idx % 8) * 4) + return words + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + lhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + rhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + lhs[:PREFIX_BITS] = 1 + rhs[:SUFFIX_BITS] = 1 + golden = pack_nibbles(lhs | rhs) + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto new file mode 100644 index 000000000..84de1a29e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/por +// family: materialization-predicate +// target_ops: pto.por +// scenarios: predicate-transform +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @por_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c7 = arith.constant 7 : i32 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs, %lhs_next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %rhs, %rhs_next = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %out = pto.por %lhs, %rhs, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/por/launch.cpp new file mode 100644 index 000000000..caa8684a4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/por +// family: materialization-predicate +// target_ops: pto.por +// scenarios: predicate-transform +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void por_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPor(uint32_t *v1, void *stream) { + por_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/por/main.cpp new file mode 100644 index 000000000..527116eff --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/por +// family: materialization-predicate +// target_ops: pto.por +// scenarios: predicate-transform +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPor(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPor(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/por/stub.cpp new file mode 100644 index 000000000..3fe95cf3c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/por +// family: materialization-predicate +// target_ops: pto.por +// scenarios: predicate-transform + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void por_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/compare.py new file mode 100755 index 000000000..2585ff1e4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/ppack-punpack-nontrivial +# family: materialization-predicate +# target_ops: pto.ppack, pto.punpack +# scenarios: pack-unpack-roundtrip, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/golden.py new file mode 100755 index 000000000..7923992f9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/ppack-punpack-nontrivial +# family: materialization-predicate +# target_ops: pto.ppack, pto.punpack +# scenarios: pack-unpack-roundtrip, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([16843009, 16843009, 16843009, 16843009, 0, 0, 0, 0, 65537, 65537, 65537, 65537, 65537, 65537, 65537, 65537], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto new file mode 100644 index 000000000..2c9d46589 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack-nontrivial +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @ppack_punpack_nontrivial_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src = pto.pset_b32 "PAT_M4" : !pto.mask + %packed = pto.ppack %src, "LOWER" : !pto.mask -> !pto.mask + %roundtrip = pto.punpack %packed, "LOWER" : !pto.mask -> !pto.mask + pto.psts %packed, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %roundtrip, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/launch.cpp new file mode 100644 index 000000000..aac69efd1 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack-nontrivial +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void ppack_punpack_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPpackPunpackNontrivial(uint32_t *v1, void *stream) { + ppack_punpack_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/main.cpp new file mode 100644 index 000000000..30feafd01 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack-nontrivial +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPpackPunpackNontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPpackPunpackNontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/stub.cpp new file mode 100644 index 000000000..b826256fb --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack-nontrivial +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip, nontrivial-pattern + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void ppack_punpack_nontrivial_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/compare.py b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/compare.py new file mode 100755 index 000000000..08e54575d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/ppack-punpack +# family: materialization-predicate +# target_ops: pto.ppack, pto.punpack +# scenarios: pack-unpack-roundtrip +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/golden.py b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/golden.py new file mode 100755 index 000000000..05d8b350b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/ppack-punpack +# family: materialization-predicate +# target_ops: pto.ppack, pto.punpack +# scenarios: pack-unpack-roundtrip +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([1431655765, 1431655765, 1431655765, 1431655765, 0, 0, 0, 0, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto new file mode 100644 index 000000000..05569d1e4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @ppack_punpack_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src = pto.pset_b32 "PAT_ALL" : !pto.mask + %packed = pto.ppack %src, "LOWER" : !pto.mask -> !pto.mask + %roundtrip = pto.punpack %packed, "LOWER" : !pto.mask -> !pto.mask + pto.psts %packed, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %roundtrip, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/launch.cpp new file mode 100644 index 000000000..2dc4b848d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void ppack_punpack_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPpackPunpack(uint32_t *v1, void *stream) { + ppack_punpack_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/main.cpp new file mode 100644 index 000000000..0ad47bb5d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPpackPunpack(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPpackPunpack(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/stub.cpp new file mode 100644 index 000000000..46d8d8ba9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void ppack_punpack_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/compare.py b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/compare.py new file mode 100755 index 000000000..d334ff512 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/psel-tail-predicate +# family: materialization-predicate +# target_ops: pto.psel +# scenarios: predicate-transform, predicate-select, tail-mask +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/golden.py b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/golden.py new file mode 100755 index 000000000..0144a0f36 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/psel-tail-predicate +# family: materialization-predicate +# target_ops: pto.psel +# scenarios: predicate-transform, predicate-select, tail-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286331153, 69905, 0, 0, 0, 0, 0, 0, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto new file mode 100644 index 000000000..91c7f6515 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto @@ -0,0 +1,40 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel-tail-predicate +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select, tail-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @psel_tail_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %sel, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %out = pto.psel %src0, %sel, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + %out_next = pto.psel %sel, %src0, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out_next, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/launch.cpp new file mode 100644 index 000000000..e4c8692cf --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel-tail-predicate +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select, tail-mask +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psel_tail_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPselTailPredicate(uint32_t *v1, void *stream) { + psel_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/main.cpp new file mode 100644 index 000000000..5a5996be6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel-tail-predicate +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select, tail-mask +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPselTailPredicate(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPselTailPredicate(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/stub.cpp new file mode 100644 index 000000000..07959f124 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel-tail-predicate +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select, tail-mask + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void psel_tail_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/compare.py b/test/vpto/cases/micro-op/materialization-predicate/psel/compare.py new file mode 100755 index 000000000..fb258a13e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/psel +# family: materialization-predicate +# target_ops: pto.psel +# scenarios: predicate-transform, predicate-select +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/golden.py b/test/vpto/cases/micro-op/materialization-predicate/psel/golden.py new file mode 100755 index 000000000..101269c58 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/psel +# family: materialization-predicate +# target_ops: pto.psel +# scenarios: predicate-transform, predicate-select +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286331153, 69905, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto new file mode 100644 index 000000000..f9a47a667 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @psel_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %sel, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %out = pto.psel %src0, %sel, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel/launch.cpp new file mode 100644 index 000000000..34e1641cb --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psel_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPsel(uint32_t *v1, void *stream) { + psel_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel/main.cpp new file mode 100644 index 000000000..bfb6d3558 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPsel(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPsel(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel/stub.cpp new file mode 100644 index 000000000..702830b96 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void psel_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/compare.py new file mode 100755 index 000000000..2de1b2000 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pset-pattern-fragment +# family: materialization-predicate +# target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +# scenarios: pattern-mask, pat-vl, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 24 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/golden.py new file mode 100755 index 000000000..fcf402e5a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pset-pattern-fragment +# family: materialization-predicate +# target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +# scenarios: pattern-mask, pat-vl, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([1227133513, 2454267026, 613566756, 1227133513, 2454267026, 613566756, 1227133513, 2454267026, 1431655765, 1431655765, 1431655765, 1431655765, 0, 0, 0, 0, 286331153, 286331153, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto new file mode 100644 index 000000000..4b5992a6b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern-fragment +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, fragment-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pset_pattern_fragment_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0 = pto.pset_b8 "PAT_M3" : !pto.mask + %m1 = pto.pset_b16 "PAT_H" : !pto.mask + %m2 = pto.pset_b32 "PAT_Q" : !pto.mask + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/launch.cpp new file mode 100644 index 000000000..c74390013 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern-fragment +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, fragment-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pset_pattern_fragment_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPsetPatternFragment(uint32_t *v1, void *stream) { + pset_pattern_fragment_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/main.cpp new file mode 100644 index 000000000..116e3e7ae --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern-fragment +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, fragment-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPsetPatternFragment(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPsetPatternFragment(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/stub.cpp new file mode 100644 index 000000000..12b75b3dc --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern-fragment +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, fragment-pattern + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pset_pattern_fragment_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/compare.py new file mode 100755 index 000000000..10290abc6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/compare.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pset-pattern +# family: materialization-predicate +# target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +# scenarios: pattern-mask, pat-all, pat-vl +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 24 + + +def compare_packed_words(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_packed_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/golden.py new file mode 100755 index 000000000..dcc083810 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/golden.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pset-pattern +# family: materialization-predicate +# target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +# scenarios: pattern-mask, pat-all, pat-vl +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 24 + + +def _pack_pset_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + byte_index = bit_index // 8 + bit_in_byte = bit_index % 8 + out[byte_index] |= np.uint8(1 << bit_in_byte) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + + out = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + out[0:32] = _pack_pset_prefix(active_lanes=256, bit_stride=1, store_bytes=32) + out[32:48] = _pack_pset_prefix(active_lanes=8, bit_stride=2, store_bytes=16) + out[64:80] = _pack_pset_prefix(active_lanes=16, bit_stride=4, store_bytes=16) + golden = out.view(np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op pset-pattern validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto new file mode 100644 index 000000000..1a7a30f51 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, pat-all, pat-vl +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pset_pattern_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0 = pto.pset_b8 "PAT_ALL" : !pto.mask + %m1 = pto.pset_b16 "PAT_VL8" : !pto.mask + %m2 = pto.pset_b32 "PAT_VL16" : !pto.mask + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/launch.cpp new file mode 100644 index 000000000..01ec8b624 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/launch.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, pat-all, pat-vl +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pset_pattern_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPset_pattern_kernel_2d(uint32_t *v1, void *stream) { + pset_pattern_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/main.cpp new file mode 100644 index 000000000..15a7b4181 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, pat-all, pat-vl +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPset_pattern_kernel_2d(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 24; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPset_pattern_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/stub.cpp new file mode 100644 index 000000000..4776b7d2b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, pat-all, pat-vl +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void pset_pattern_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pxor/compare.py new file mode 100755 index 000000000..0652ae4a5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pxor +# family: materialization-predicate +# target_ops: pto.pxor +# scenarios: predicate-transform, logical-xor +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pxor/golden.py new file mode 100755 index 000000000..16f212335 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pxor +# family: materialization-predicate +# target_ops: pto.pxor +# scenarios: predicate-transform, logical-xor +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +PREFIX_BITS = 13 +SUFFIX_BITS = 7 +PREDICATE_BITS = 256 +NIBBLE_COUNT = PREDICATE_BITS // 2 + + +def pack_nibbles(nibbles: np.ndarray) -> np.ndarray: + words = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + for idx, nibble in enumerate(nibbles): + words[idx // 8] |= np.uint32(int(nibble) & 0xF) << np.uint32((idx % 8) * 4) + return words + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + lhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + rhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + lhs[:PREFIX_BITS] = 1 + rhs[:SUFFIX_BITS] = 1 + golden = pack_nibbles(np.bitwise_xor(lhs, rhs)) + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto new file mode 100644 index 000000000..63ae77a30 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pxor +// family: materialization-predicate +// target_ops: pto.pxor +// scenarios: predicate-transform +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @pxor_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c7 = arith.constant 7 : i32 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs, %lhs_next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %rhs, %rhs_next = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %out = pto.pxor %lhs, %rhs, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pxor/launch.cpp new file mode 100644 index 000000000..55f3770dc --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pxor +// family: materialization-predicate +// target_ops: pto.pxor +// scenarios: predicate-transform +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pxor_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPxor(uint32_t *v1, void *stream) { + pxor_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pxor/main.cpp new file mode 100644 index 000000000..6bf82fb80 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pxor +// family: materialization-predicate +// target_ops: pto.pxor +// scenarios: predicate-transform +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPxor(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPxor(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/pxor/stub.cpp new file mode 100644 index 000000000..d71da2e77 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pxor +// family: materialization-predicate +// target_ops: pto.pxor +// scenarios: predicate-transform + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pxor_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/compare.py new file mode 100644 index 000000000..8a9b82365 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/golden.py new file mode 100644 index 000000000..86a637182 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.float32(1.25) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vbr-f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto new file mode 100644 index 000000000..45eb413dc --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto @@ -0,0 +1,48 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vbr_f32_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant 1.250000e+00 : f32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vbr %cst : f32 -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/launch.cpp new file mode 100644 index 000000000..cf1d57866 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbr_f32_kernel_2d(__gm__ float *v1); + +void LaunchVbr_f32_kernel_2d(float *v1, void *stream) { + vbr_f32_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/main.cpp new file mode 100644 index 000000000..0fce80155 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbr_f32_kernel_2d(float *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVbr_f32_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/stub.cpp new file mode 100644 index 000000000..a251216f5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vbr_f32_kernel_2d(__gm__ float *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/compare.py new file mode 100644 index 000000000..78ba4226d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.int32, 0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/golden.py new file mode 100644 index 000000000..68e367caf --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.int32(7) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.int32) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vbr-i32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto new file mode 100644 index 000000000..a651c7888 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto @@ -0,0 +1,48 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vbr_i32_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant 7 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vbr %cst : i32 -> !pto.vreg<64xi32> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/launch.cpp new file mode 100644 index 000000000..8db4dd6a4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbr_i32_kernel_2d(__gm__ int *v1); + +void LaunchVbr_i32_kernel_2d(int *v1, void *stream) { + vbr_i32_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/main.cpp new file mode 100644 index 000000000..cae6ce5ec --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbr_i32_kernel_2d(int *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int); + int *v1Host = nullptr; + int *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVbr_i32_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/stub.cpp new file mode 100644 index 000000000..df842be32 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vbr_i32_kernel_2d(__gm__ int *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/compare.py new file mode 100755 index 000000000..f22dfc0f5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/vdup-lane +# family: materialization-predicate +# target_ops: pto.vdup +# scenarios: core-f32, vector-input, lowest-highest +# coding=utf-8 + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={float(diff[idx])} " + f"at idx={idx} (golden={golden[idx]}, out={output[idx]})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_low.bin", "out_low.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_high.bin", "out_high.bin", np.float32, 1e-4) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/golden.py new file mode 100755 index 000000000..712a2adbd --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/vdup-lane +# family: materialization-predicate +# target_ops: pto.vdup +# scenarios: core-f32, vector-input, lowest-highest +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.normal(loc=1.0, scale=3.0, size=(ROWS, COLS)).astype(np.float32) + + src_flat = src.reshape(-1) + low_flat = np.empty_like(src_flat) + high_flat = np.empty_like(src_flat) + block = 64 + for begin in range(0, src_flat.size, block): + chunk = src_flat[begin : begin + block] + low_flat[begin : begin + block] = chunk[0] + high_flat[begin : begin + block] = chunk[-1] + + low = low_flat.reshape(src.shape) + high = high_flat.reshape(src.shape) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "src.bin") + low.reshape(-1).tofile(output_dir / "golden_low.bin") + high.reshape(-1).tofile(output_dir / "golden_high.bin") + np.zeros_like(src.reshape(-1)).tofile(output_dir / "out_low.bin") + np.zeros_like(src.reshape(-1)).tofile(output_dir / "out_high.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO vector-input vdup validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto new file mode 100644 index 000000000..37b29d93a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-lane +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f32, vector-input, lowest-highest +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vdup_lane_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) { + %c8192_i64 = arith.constant 8192 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_low = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_high = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %src = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %low = pto.vdup %src, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %high = pto.vdup %src, %active {position = "HIGHEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %low, %ub_low[%offset], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %high, %ub_high[%offset], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_low, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.copy_ubuf_to_gm %ub_high, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/launch.cpp new file mode 100644 index 000000000..522a86731 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-lane +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f32, vector-input, lowest-highest +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_lane_kernel_2d(__gm__ float *src, + __gm__ float *outLow, + __gm__ float *outHigh); + +void LaunchVdup_lane_kernel_2d(float *src, float *outLow, float *outHigh, + void *stream) { + vdup_lane_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)src, + (__gm__ float *)outLow, + (__gm__ float *)outHigh); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/main.cpp new file mode 100644 index 000000000..685317502 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-lane +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f32, vector-input, lowest-highest +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_lane_kernel_2d(float *src, float *outLow, float *outHigh, void *stream); + +int main() { + size_t elemCount = 1024; + size_t fileSize = elemCount * sizeof(float); + float *srcHost = nullptr; + float *outLowHost = nullptr; + float *outHighHost = nullptr; + float *srcDevice = nullptr; + float *outLowDevice = nullptr; + float *outHighDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&outLowHost), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&outHighHost), fileSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outLowDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outHighDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./src.bin", fileSize, srcHost, fileSize); + ACL_CHECK(aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemset(outLowDevice, fileSize, 0, fileSize)); + ACL_CHECK(aclrtMemset(outHighDevice, fileSize, 0, fileSize)); + + LaunchVdup_lane_kernel_2d(srcDevice, outLowDevice, outHighDevice, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outLowHost, fileSize, outLowDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHighHost, fileSize, outHighDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./out_low.bin", outLowHost, fileSize); + WriteFile("./out_high.bin", outHighHost, fileSize); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outLowDevice); + aclrtFree(outHighDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outLowHost); + aclrtFreeHost(outHighHost); + if (stream != nullptr) { + const aclError ret = aclrtDestroyStream(stream); + if (ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] aclrtDestroyStream(stream) failed: %d (%s:%d)\n", + (int)ret, __FILE__, __LINE__); + } + if (deviceSet) { + const aclError ret = aclrtResetDevice(deviceId); + if (ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] aclrtResetDevice(deviceId) failed: %d (%s:%d)\n", + (int)ret, __FILE__, __LINE__); + } + if (aclInited) { + const aclError ret = aclFinalize(); + if (ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] aclFinalize() failed: %d (%s:%d)\n", + (int)ret, __FILE__, __LINE__); + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/stub.cpp new file mode 100644 index 000000000..9379dd08b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-lane +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f32, vector-input, lowest-highest +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vdup_lane_kernel_2d(__gm__ float *src, + __gm__ float *outLow, + __gm__ float *outHigh) { + (void)src; + (void)outLow; + (void)outHigh; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/compare.py new file mode 100644 index 000000000..e423b7707 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/compare.py @@ -0,0 +1,60 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v1.bin", "v1.bin", np.float16, 0.001) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/golden.py new file mode 100644 index 000000000..3f3ad08ba --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.float16(1.25) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vdup-scalar-f16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto new file mode 100644 index 000000000..dfea1ef9a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto @@ -0,0 +1,35 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-scalar-f16 +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f16, scalar-broadcast +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vdup_scalar_f16_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant 1.250000e+00 : f16 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vdup %cst, %active : f16, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/launch.cpp new file mode 100644 index 000000000..664c961dd --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_scalar_f16_kernel_2d(__gm__ half *v1); + +void LaunchVdup_scalar_f16_kernel_2d(aclFloat16 *v1, void *stream) { + vdup_scalar_f16_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/main.cpp new file mode 100644 index 000000000..b8f469441 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_scalar_f16_kernel_2d(aclFloat16 *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(aclFloat16); + aclFloat16 *v1Host = nullptr; + aclFloat16 *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdup_scalar_f16_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/stub.cpp new file mode 100644 index 000000000..6896f43bb --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/stub.cpp @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vdup_scalar_f16_kernel_2d(__gm__ half *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/compare.py new file mode 100644 index 000000000..bc7b42e1a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden[idx])}, out={int(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v1.bin", "v1.bin", np.int8) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/golden.py new file mode 100644 index 000000000..90e2a44bb --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.int8(-83) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.int8) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.int8) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vdup-scalar-i8 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto new file mode 100644 index 000000000..14b5abffb --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto @@ -0,0 +1,35 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-scalar-i8 +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-i8, scalar-broadcast +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vdup_scalar_i8_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -83 : i8 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vdup %cst, %active : i8, !pto.mask -> !pto.vreg<256xi8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xi8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/launch.cpp new file mode 100644 index 000000000..5054b1777 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_scalar_i8_kernel_2d(__gm__ int8_t *v1); + +void LaunchVdup_scalar_i8_kernel_2d(int8_t *v1, void *stream) { + vdup_scalar_i8_kernel_2d<<<1, nullptr, stream>>>((__gm__ int8_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/main.cpp new file mode 100644 index 000000000..fde52caa7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_scalar_i8_kernel_2d(int8_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int8_t); + int8_t *v1Host = nullptr; + int8_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdup_scalar_i8_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/stub.cpp new file mode 100644 index 000000000..b65c98f53 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/stub.cpp @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vdup_scalar_i8_kernel_2d(__gm__ int8_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/compare.py new file mode 100644 index 000000000..8a9b82365 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/golden.py new file mode 100644 index 000000000..3153c6005 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.float32(-2.5) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vdup-scalar validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto new file mode 100644 index 000000000..6ef474ecf --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto @@ -0,0 +1,48 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vdup_scalar_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -2.500000e+00 : f32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vdup %cst, %active : f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/launch.cpp new file mode 100644 index 000000000..02754e6d2 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_scalar_kernel_2d(__gm__ float *v1); + +void LaunchVdup_scalar_kernel_2d(float *v1, void *stream) { + vdup_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/main.cpp new file mode 100644 index 000000000..6aff66657 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_scalar_kernel_2d(float *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdup_scalar_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/stub.cpp new file mode 100644 index 000000000..255b982b6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vdup_scalar_kernel_2d(__gm__ float *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/_predicate_load_store_case.py b/test/vpto/cases/micro-op/predicate-load-store/_predicate_load_store_case.py new file mode 100644 index 000000000..494247b52 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/_predicate_load_store_case.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +OUTPUT_BYTES = ROWS * COLS +PREDICATE_BITS = 256 +# For the current A5 predicate load/store surface used by these composition +# cases, the user-visible packed NORM footprint is 16 bytes. Bytes beyond that +# range are not part of the checked result footprint. +NORM_STORAGE_BYTES = 16 + + +def prefix_bits(active_bits: int) -> np.ndarray: + bits = np.zeros((PREDICATE_BITS,), dtype=np.uint8) + bits[:active_bits] = 1 + return bits + + +def pk_us_compose(bits: np.ndarray) -> np.ndarray: + packed = bits[::2] + return np.repeat(packed, 2).astype(np.uint8, copy=False) + + +def norm_ds_compose(bits: np.ndarray) -> np.ndarray: + source = np.concatenate( + [bits.astype(np.uint8, copy=False), np.zeros_like(bits, dtype=np.uint8)] + ) + return source[::2][:PREDICATE_BITS].astype(np.uint8, copy=False) + + +def norm_store_bytes(bits: np.ndarray) -> np.ndarray: + packed = np.packbits(bits.astype(np.uint8, copy=False), bitorder="little") + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:NORM_STORAGE_BYTES] = packed[:NORM_STORAGE_BYTES] + return out + + +def write_default_inputs(output_dir: Path) -> None: + np.zeros((ROWS * COLS,), dtype=np.float32).tofile(output_dir / "v1.bin") + np.zeros((ROWS * COLS,), dtype=np.float32).tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + + +def write_case(output_dir: Path, bits: np.ndarray) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + write_default_inputs(output_dir) + norm_store_bytes(bits).tofile(output_dir / "golden_v3.bin") + + +def compare_norm_store(golden_path: str, output_path: str) -> bool: + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.size < NORM_STORAGE_BYTES or output.size < NORM_STORAGE_BYTES: + return False + if not np.array_equal(golden[:NORM_STORAGE_BYTES], output[:NORM_STORAGE_BYTES]): + diff = np.nonzero(golden[:NORM_STORAGE_BYTES] != output[:NORM_STORAGE_BYTES])[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (predicate load/store composition): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/compare.py b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/compare.py new file mode 100644 index 000000000..1f197eec2 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pldi-norm +# family: predicate-load-store +# target_ops: pto.pldi +# scenarios: packed-load, immediate-offset, representative-logical-elements + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.size < 256 or output.size < 256: + print( + f"[ERROR] Packed buffer too small: golden={golden.size} out={output.size}" + ) + raise SystemExit(2) + if not np.array_equal(golden[:256], output[:256]): + diff = np.nonzero(golden[:256] != output[:256])[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (pldi NORM -> vsel): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + raise SystemExit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/golden.py b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/golden.py new file mode 100644 index 000000000..d4811cdb0 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/golden.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pldi-norm +# family: predicate-load-store +# target_ops: pto.pldi +# scenarios: packed-load, immediate-offset, representative-logical-elements + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +ACTIVE_BITS = 145 +OUTPUT_BYTES = 1024 +VECTOR_BYTES = 256 +PACKED_BYTES = 32 + + +def prefix_bits(active_bits: int) -> np.ndarray: + bits = np.zeros((256,), dtype=np.uint8) + bits[:active_bits] = 1 + return bits + + +def make_input_buffer(bits: np.ndarray) -> np.ndarray: + packed = np.packbits(bits.astype(np.uint8, copy=False), bitorder="little") + ones = np.ones((VECTOR_BYTES,), dtype=np.uint8) + zeros = np.zeros((VECTOR_BYTES,), dtype=np.uint8) + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:PACKED_BYTES] = packed[:PACKED_BYTES] + out[PACKED_BYTES : PACKED_BYTES + VECTOR_BYTES] = ones + out[PACKED_BYTES + VECTOR_BYTES : PACKED_BYTES + 2 * VECTOR_BYTES] = zeros + return out + + +def expected_selected_bytes(bits: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:VECTOR_BYTES] = bits.astype(np.uint8, copy=False) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + bits = prefix_bits(ACTIVE_BITS) + input_buffer = make_input_buffer(bits) + golden = expected_selected_bytes(bits) + + output_dir.mkdir(parents=True, exist_ok=True) + input_buffer.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate raw packed predicate input/golden for VPTO micro-op pldi-norm validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto new file mode 100644 index 000000000..b2cfb6aed --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto @@ -0,0 +1,58 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pldi-norm +// family: predicate-load-store +// target_ops: pto.pldi +// scenarios: packed-load, immediate-offset, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pldi_norm_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c32 = arith.constant 32 : index + %c288 = arith.constant 288 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c4_i64 = arith.constant 4 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %c256_i32 = arith.constant 256 : i32 + %c256_loop_i32 = arith.constant 256 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c256 step %c256 iter_args(%remaining = %c256_loop_i32) -> (i32) { + %loaded = pto.pldi %ub_in[%c0], "NORM" : !pto.ptr, index -> !pto.mask + %full_mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %ones_offset = arith.addi %offset, %c32 : index + %zeros_offset = arith.addi %offset, %c288 : index + %ones = pto.vlds %ub_in[%ones_offset] : !pto.ptr -> !pto.vreg<256xui8> + %zeros = pto.vlds %ub_in[%zeros_offset] : !pto.ptr -> !pto.vreg<256xui8> + %out = pto.vsel %ones, %zeros, %loaded : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/launch.cpp new file mode 100644 index 000000000..8044d8893 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pldi_norm_kernel_2d(__gm__ unsigned char *v1, + __gm__ unsigned char *v2); + +void LaunchPldi_norm_kernel_2d(unsigned char *v1, unsigned char *v2, void *stream) { + pldi_norm_kernel_2d<<<1, nullptr, stream>>>((__gm__ unsigned char *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/main.cpp new file mode 100644 index 000000000..a1a8204c2 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/main.cpp @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pldi-norm +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPldi_norm_kernel_2d(unsigned char *v1, unsigned char *v2, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(unsigned char); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + unsigned char *v1Host = nullptr; + unsigned char *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPldi_norm_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/stub.cpp new file mode 100644 index 000000000..bc617e364 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pldi_norm_kernel_2d(__gm__ unsigned char *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/compare.py b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/compare.py new file mode 100644 index 000000000..bd3820b2e --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/plds-norm +# family: predicate-load-store +# target_ops: pto.plds +# scenarios: packed-load, dynamic-offset, representative-logical-elements + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.size < 256 or output.size < 256: + print( + f"[ERROR] Packed buffer too small: golden={golden.size} out={output.size}" + ) + raise SystemExit(2) + if not np.array_equal(golden[:256], output[:256]): + diff = np.nonzero(golden[:256] != output[:256])[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (plds NORM -> vsel): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + raise SystemExit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/golden.py b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/golden.py new file mode 100644 index 000000000..e6cf2fb1b --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/golden.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/plds-norm +# family: predicate-load-store +# target_ops: pto.plds +# scenarios: packed-load, dynamic-offset, representative-logical-elements + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +ACTIVE_BITS = 145 +OUTPUT_BYTES = 1024 +VECTOR_BYTES = 256 +PACKED_BYTES = 32 + + +def prefix_bits(active_bits: int) -> np.ndarray: + bits = np.zeros((256,), dtype=np.uint8) + bits[:active_bits] = 1 + return bits + + +def make_input_buffer(bits: np.ndarray) -> np.ndarray: + packed = np.packbits(bits.astype(np.uint8, copy=False), bitorder="little") + ones = np.ones((VECTOR_BYTES,), dtype=np.uint8) + zeros = np.zeros((VECTOR_BYTES,), dtype=np.uint8) + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:PACKED_BYTES] = packed[:PACKED_BYTES] + out[PACKED_BYTES : PACKED_BYTES + VECTOR_BYTES] = ones + out[PACKED_BYTES + VECTOR_BYTES : PACKED_BYTES + 2 * VECTOR_BYTES] = zeros + return out + + +def expected_selected_bytes(bits: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:VECTOR_BYTES] = bits.astype(np.uint8, copy=False) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + bits = prefix_bits(ACTIVE_BITS) + input_buffer = make_input_buffer(bits) + golden = expected_selected_bytes(bits) + + output_dir.mkdir(parents=True, exist_ok=True) + input_buffer.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate raw packed predicate input/golden for VPTO micro-op plds-norm validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto new file mode 100644 index 000000000..5dda2be3e --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/plds-norm +// family: predicate-load-store +// target_ops: pto.plds +// scenarios: packed-load, dynamic-offset, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @plds_norm_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c32 = arith.constant 32 : index + %c288 = arith.constant 288 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c256_i64 = arith.constant 256 : i64 + %c256_i32 = arith.constant 256 : i32 + %c256_loop_i32 = arith.constant 256 : i32 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_in = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c256 step %c256 iter_args(%remaining = %c256_loop_i32) -> (i32) { + %byte_offset = arith.addi %offset, %c0 : index + %loaded = pto.plds %ub_in[%byte_offset], "NORM" : !pto.ptr, index -> !pto.mask + %full_mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %ones_offset = arith.addi %offset, %c32 : index + %zeros_offset = arith.addi %offset, %c288 : index + %ones = pto.vlds %ub_in[%ones_offset] : !pto.ptr -> !pto.vreg<256xui8> + %zeros = pto.vlds %ub_in[%zeros_offset] : !pto.ptr -> !pto.vreg<256xui8> + %out = pto.vsel %ones, %zeros, %loaded : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/launch.cpp new file mode 100644 index 000000000..9a1e9de07 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void plds_norm_kernel_2d(__gm__ unsigned char *v1, + __gm__ unsigned char *v2); + +void LaunchPlds_norm_kernel_2d(unsigned char *v1, unsigned char *v2, void *stream) { + plds_norm_kernel_2d<<<1, nullptr, stream>>>((__gm__ unsigned char *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/main.cpp new file mode 100644 index 000000000..30f136c67 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/main.cpp @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/plds-norm +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPlds_norm_kernel_2d(unsigned char *v1, unsigned char *v2, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(unsigned char); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + unsigned char *v1Host = nullptr; + unsigned char *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPlds_norm_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/stub.cpp new file mode 100644 index 000000000..05f7914ae --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void plds_norm_kernel_2d(__gm__ unsigned char *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/compare.py new file mode 100644 index 000000000..9a9e71168 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-norm-pldi-ds +# family: predicate-load-store +# target_ops: pto.pldi, pto.psti +# scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/golden.py new file mode 100644 index 000000000..fa4f55ccd --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-norm-pldi-ds +# family: predicate-load-store +# target_ops: pto.pldi, pto.psti +# scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import norm_ds_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 143 + + +def generate(output_dir: Path, seed: int, src_elem_bytes: int) -> None: + del seed + del src_elem_bytes + write_case(output_dir, norm_ds_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for psti-norm-pldi-ds." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument("--seed", type=int, default=SEED, help="Numpy random seed.") + parser.add_argument( + "--src-elem-bytes", + type=int, + default=4, + help="Unused compatibility option kept for the shared runner surface.", + ) + args = parser.parse_args() + generate(args.output_dir, args.seed, args.src_elem_bytes) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto new file mode 100644 index 000000000..7b162a666 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-norm-pldi-ds +// family: predicate-load-store +// target_ops: pto.pldi, pto.psti +// scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @psti_norm_pldi_ds_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c143 = arith.constant 143 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c143) -> (i32) { + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + pto.psti %src, %ub_mid[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.pldi %ub_mid[%c32], "DS" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/launch.cpp new file mode 100644 index 000000000..b1d57a8e8 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psti_norm_pldi_ds_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsti_norm_pldi_ds_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psti_norm_pldi_ds_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/main.cpp new file mode 100644 index 000000000..db6a66d3f --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-norm-pldi-ds +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsti_norm_pldi_ds_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsti_norm_pldi_ds_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/stub.cpp new file mode 100644 index 000000000..62b74bd43 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void psti_norm_pldi_ds_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/compare.py new file mode 100644 index 000000000..5adbcb96c --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-pk-pldi-us +# family: predicate-load-store +# target_ops: pto.pldi, pto.psti +# scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/golden.py new file mode 100644 index 000000000..eb6a105ed --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-pk-pldi-us +# family: predicate-load-store +# target_ops: pto.pldi, pto.psti +# scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import pk_us_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 145 + + +def generate(output_dir: Path, seed: int, src_elem_bytes: int) -> None: + del seed + del src_elem_bytes + write_case(output_dir, pk_us_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for psti-pk-pldi-us.") + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument("--seed", type=int, default=SEED, help="Numpy random seed.") + parser.add_argument( + "--src-elem-bytes", + type=int, + default=4, + help="Unused compatibility option kept for the shared runner surface.", + ) + args = parser.parse_args() + generate(args.output_dir, args.seed, args.src_elem_bytes) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto new file mode 100644 index 000000000..9748b746c --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-pk-pldi-us +// family: predicate-load-store +// target_ops: pto.pldi, pto.psti +// scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @psti_pk_pldi_us_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c145 = arith.constant 145 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c145) -> (i32) { + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + pto.psti %src, %ub_mid[%c8], "PK" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.pldi %ub_mid[%c8], "US" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/launch.cpp new file mode 100644 index 000000000..40a6df38d --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psti_pk_pldi_us_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsti_pk_pldi_us_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psti_pk_pldi_us_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/main.cpp new file mode 100644 index 000000000..0f2567d2c --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-pk-pldi-us +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsti_pk_pldi_us_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsti_pk_pldi_us_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/stub.cpp new file mode 100644 index 000000000..161fc624a --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void psti_pk_pldi_us_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/compare.py new file mode 100644 index 000000000..e0ca16a26 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-pk +# family: predicate-load-store +# target_ops: pto.psti +# scenarios: packed-store, immediate-offset, representative-logical-elements + +import numpy as np + + +EXPECTED_WORDS = 8 +PK_STORAGE_BYTES = 16 + + +def main() -> None: + golden = np.fromfile("golden_v1.bin", dtype=np.uint8) + output = np.fromfile("v1.bin", dtype=np.uint8) + expected_bytes = EXPECTED_WORDS * 4 + if golden.size != expected_bytes or output.size != expected_bytes: + print( + f"[ERROR] Unexpected byte count: golden={golden.size} " + f"out={output.size} expected={expected_bytes}" + ) + raise SystemExit(2) + if not np.array_equal(golden[:PK_STORAGE_BYTES], output[:PK_STORAGE_BYTES]): + diff = np.nonzero(golden[:PK_STORAGE_BYTES] != output[:PK_STORAGE_BYTES])[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (psti PK raw packed store): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + raise SystemExit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/golden.py new file mode 100644 index 000000000..42fd1b842 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-pk +# family: predicate-load-store +# target_ops: pto.psti +# scenarios: packed-store, immediate-offset, representative-logical-elements + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 8 +ACTIVE_BITS = 145 +PK_STORAGE_BYTES = 16 + + +def prefix_bits(active_bits: int) -> np.ndarray: + bits = np.zeros((256,), dtype=np.uint8) + bits[:active_bits] = 1 + return bits + + +def generate(output_dir: Path, seed: int) -> None: + del seed + bits = prefix_bits(ACTIVE_BITS) + packed_pk = np.packbits(bits[::2], bitorder="little") + out = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + out[:PK_STORAGE_BYTES] = packed_pk[:PK_STORAGE_BYTES] + + output_dir.mkdir(parents=True, exist_ok=True) + np.zeros((OUTPUT_WORDS,), dtype=np.uint32).tofile(output_dir / "v1.bin") + out.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op psti-pk validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto new file mode 100644 index 000000000..da7f2f773 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto @@ -0,0 +1,33 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-pk +// family: predicate-load-store +// target_ops: pto.psti +// scenarios: packed-store, immediate-offset, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @psti_pk_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c145 = arith.constant 145 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src, %next = pto.plt_b8 %c145 : i32 -> !pto.mask, i32 + pto.psti %src, %ub_out[%c0], "PK" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/launch.cpp new file mode 100644 index 000000000..5be1e518d --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psti_pk_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPsti_pk_kernel_2d(uint32_t *v1, void *stream) { + psti_pk_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/main.cpp new file mode 100644 index 000000000..8be1b45c4 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-pk +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsti_pk_kernel_2d(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 8; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsti_pk_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/stub.cpp new file mode 100644 index 000000000..963b77e5d --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/stub.cpp @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void psti_pk_kernel_2d(__gm__ uint32_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/compare.py new file mode 100644 index 000000000..39299f639 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-norm-plds-ds +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/golden.py new file mode 100644 index 000000000..19a16409e --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/golden.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-norm-plds-ds +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import norm_ds_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 175 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + write_case(output_dir, norm_ds_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for psts-norm-plds-ds.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto new file mode 100644 index 000000000..fc82cc0f3 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-norm-plds-ds +// family: predicate-load-store +// target_ops: pto.plds, pto.psts +// scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @psts_norm_plds_ds_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c175 = arith.constant 175 : i32 + %c0_i32 = arith.constant 0 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c175) -> (i32) { + %byte_offset = arith.addi %iv, %c32 : index + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %zero, %_unused = pto.plt_b8 %c0_i32 : i32 -> !pto.mask, i32 + pto.psts %src, %ub_mid[%byte_offset], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %zero, %ub_mid[%c64], "NORM" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.plds %ub_mid[%byte_offset], "DS" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/launch.cpp new file mode 100644 index 000000000..7c9a920cc --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psts_norm_plds_ds_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsts_norm_plds_ds_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psts_norm_plds_ds_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/main.cpp new file mode 100644 index 000000000..2f0be35a7 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-norm-plds-ds +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsts_norm_plds_ds_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsts_norm_plds_ds_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/stub.cpp new file mode 100644 index 000000000..021494ffb --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void psts_norm_plds_ds_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/compare.py new file mode 100644 index 000000000..fe48e2bdb --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/golden.py new file mode 100644 index 000000000..f2ab5e6e3 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import pk_us_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 173 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + write_case(output_dir, pk_us_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate inputs/golden for psts-pk-plds-us-prefix-boundary." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto new file mode 100644 index 000000000..534559cee --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary +// family: predicate-load-store +// target_ops: pto.plds, pto.psts +// scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @psts_pk_plds_us_prefix_boundary_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c173 = arith.constant 173 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c173) -> (i32) { + %byte_offset = arith.addi %iv, %c16 : index + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + pto.psts %src, %ub_mid[%byte_offset], "PK" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.plds %ub_mid[%byte_offset], "US" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/launch.cpp new file mode 100644 index 000000000..a2d8377c7 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/launch.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psts_pk_plds_us_prefix_boundary_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsts_pk_plds_us_prefix_boundary_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psts_pk_plds_us_prefix_boundary_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ unsigned char *)v2, (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/main.cpp new file mode 100644 index 000000000..d76ca6ac1 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsts_pk_plds_us_prefix_boundary_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsts_pk_plds_us_prefix_boundary_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/stub.cpp new file mode 100644 index 000000000..966832898 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void psts_pk_plds_us_prefix_boundary_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/compare.py new file mode 100644 index 000000000..8a88450d6 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-pk-plds-us +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/golden.py new file mode 100644 index 000000000..cbbb855be --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/golden.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-pk-plds-us +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import pk_us_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 171 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + write_case(output_dir, pk_us_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for psts-pk-plds-us.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto new file mode 100644 index 000000000..dcc2bede6 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-pk-plds-us +// family: predicate-load-store +// target_ops: pto.plds, pto.psts +// scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @psts_pk_plds_us_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c171 = arith.constant 171 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c171) -> (i32) { + %byte_offset = arith.addi %iv, %c16 : index + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + pto.psts %src, %ub_mid[%byte_offset], "PK" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.plds %ub_mid[%byte_offset], "US" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/launch.cpp new file mode 100644 index 000000000..acc8dfb7b --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psts_pk_plds_us_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsts_pk_plds_us_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psts_pk_plds_us_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/main.cpp new file mode 100644 index 000000000..1462814f4 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-pk-plds-us +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsts_pk_plds_us_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsts_pk_plds_us_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/stub.cpp new file mode 100644 index 000000000..bf4ecc010 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void psts_pk_plds_us_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/compare.py b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/compare.py new file mode 100644 index 000000000..845f5233e --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu-init-align-outside-loop +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 8 + + +def compare_packed_pred_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_packed_pred_mask("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/golden.py b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/golden.py new file mode 100644 index 000000000..bf0e5a2ab --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/golden.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu-init-align-outside-loop +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +PACKED_BYTES_PER_STORE = 8 +OUTPUT_WORDS = 8 + + +def _pack_mask_b32(active_lanes: int) -> np.ndarray: + if active_lanes < 0 or active_lanes > 64: + raise ValueError(f"active_lanes must be in [0, 64], got {active_lanes}") + logical = np.zeros((64,), dtype=np.uint8) + logical[:active_lanes] = 1 + packed = np.packbits(logical, bitorder="little") + out = np.zeros((PACKED_BYTES_PER_STORE,), dtype=np.uint8) + out[: packed.size] = packed + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-1.0, 1.0, size=(ROWS, COLS)).astype(np.float32) + + first = _pack_mask_b32(13) + second = _pack_mask_b32(7) + packed = np.zeros((OUTPUT_WORDS * np.dtype(np.uint32).itemsize,), dtype=np.uint8) + packed[:PACKED_BYTES_PER_STORE] = first + packed[PACKED_BYTES_PER_STORE : 2 * PACKED_BYTES_PER_STORE] = second + packed[2 * PACKED_BYTES_PER_STORE : 3 * PACKED_BYTES_PER_STORE] = first + packed[3 * PACKED_BYTES_PER_STORE : 4 * PACKED_BYTES_PER_STORE] = second + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + output_init.tofile(output_dir / "v3.bin") + packed.view(np.uint32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op pstu-init-align-outside-loop validation." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help="Numpy random seed.", + ) + args = parser.parse_args() + + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto new file mode 100644 index 000000000..1bb6545b0 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto @@ -0,0 +1,45 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-init-align-outside-loop +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pstu_init_align_outside_loop_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c0_i32 = arith.constant 0 : i32 + %c13 = arith.constant 13 : i32 + %c7 = arith.constant 7 : i32 + + %ub_mask = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %align_init = pto.init_align : !pto.align + %align_final, %base_final = scf.for %iter = %c0 to %c2 step %c1 + iter_args(%align_iter = %align_init, %base_iter = %ub_mask) + -> (!pto.align, !pto.ptr) { + %value, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %align_out, %base_out = pto.pstu %align_iter, %value, %base_iter : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + %value_tail, %next_tail = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %align_tail, %base_tail = pto.pstu %align_out, %value_tail, %base_out : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + scf.yield %align_tail, %base_tail : !pto.align, !pto.ptr + } + pto.vstas %align_final, %base_final, %c0_i32 : !pto.align, !pto.ptr, i32 + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_mask, %arg2, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/launch.cpp new file mode 100644 index 000000000..1d97093b5 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-init-align-outside-loop +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +pstu_init_align_outside_loop_kernel_2d(__gm__ float *v1, __gm__ float *v2, + __gm__ uint32_t *v3); + +void LaunchPstu_init_align_outside_loop_kernel_2d(float *v1, float *v2, + uint32_t *v3, void *stream) { + pstu_init_align_outside_loop_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ uint32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/main.cpp new file mode 100644 index 000000000..ad4157e2f --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/main.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-init-align-outside-loop +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPstu_init_align_outside_loop_kernel_2d(float *v1, float *v2, + uint32_t *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 8; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPstu_init_align_outside_loop_kernel_2d(v1Device, v2Device, v3Device, + stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/stub.cpp new file mode 100644 index 000000000..5334e569c --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-init-align-outside-loop +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +pstu_init_align_outside_loop_kernel_2d(__gm__ float *v1, __gm__ float *v2, + __gm__ uint32_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/compare.py b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/compare.py new file mode 100755 index 000000000..bf213031a --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu-state-advance-boundary +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 16 + + +def compare_packed_pred_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint16) + output = np.fromfile(output_path, dtype=np.uint16) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_packed_pred_mask("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/golden.py b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/golden.py new file mode 100755 index 000000000..f9db13980 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/golden.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu-state-advance-boundary +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +PACKED_BYTES_PER_STORE = 16 +OUTPUT_WORDS = 16 + + +def _pack_mask_b16(active_lanes: int) -> np.ndarray: + if active_lanes < 0 or active_lanes > 128: + raise ValueError(f"active_lanes must be in [0, 128], got {active_lanes}") + logical = np.zeros((128,), dtype=np.uint8) + logical[:active_lanes] = 1 + packed = np.packbits(logical, bitorder="little") + out = np.zeros((PACKED_BYTES_PER_STORE,), dtype=np.uint8) + out[: packed.size] = packed + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + + v1 = rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-1.0, 1.0, size=(ROWS, COLS)).astype(np.float32) + + first = _pack_mask_b16(1) + second = _pack_mask_b16(127) + packed = np.concatenate([first, second]).astype(np.uint8, copy=False) + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + output_init.tofile(output_dir / "v3.bin") + packed.view(np.uint16).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op pstu-state-advance-boundary validation." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help="Numpy random seed.", + ) + args = parser.parse_args() + + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto new file mode 100644 index 000000000..259c6c100 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto @@ -0,0 +1,40 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-state-advance-boundary +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pstu_state_advance_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127 = arith.constant 127 : i32 + + %ub_mask = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %align0 = pto.init_align : !pto.align + %value0, %next0 = pto.plt_b16 %c1_i32 : i32 -> !pto.mask, i32 + %align1, %base1 = pto.pstu %align0, %value0, %ub_mask : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + %value1, %next1 = pto.plt_b16 %c127 : i32 -> !pto.mask, i32 + %align2, %base2 = pto.pstu %align1, %value1, %base1 : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + pto.vstas %align2, %base2, %c0_i32 : !pto.align, !pto.ptr, i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_mask, %arg2, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/launch.cpp new file mode 100644 index 000000000..2c01b6ceb --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-state-advance-boundary +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pstu_state_advance_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ uint16_t *v3); + +void LaunchPstu_state_advance_kernel_2d(float *v1, float *v2, uint16_t *v3, + void *stream) { + pstu_state_advance_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/main.cpp new file mode 100644 index 000000000..1f97fcc70 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-state-advance-boundary +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPstu_state_advance_kernel_2d(float *v1, float *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 16; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPstu_state_advance_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/stub.cpp new file mode 100644 index 000000000..fa4a8bb21 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-state-advance-boundary +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pstu_state_advance_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ uint16_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/compare.py b/test/vpto/cases/micro-op/predicate-load-store/pstu/compare.py new file mode 100755 index 000000000..e452bd612 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, representative-logical-elements +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 8 +VALID_BYTES = 16 + + +def compare_packed_pred_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + expected_bytes = EXPECTED_WORDS * np.dtype(np.uint32).itemsize + if golden.size != expected_bytes or output.size != expected_bytes: + return False + if not np.array_equal(golden[:VALID_BYTES], output[:VALID_BYTES]): + diff = np.nonzero(golden[:VALID_BYTES] != output[:VALID_BYTES])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_packed_pred_mask("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/golden.py b/test/vpto/cases/micro-op/predicate-load-store/pstu/golden.py new file mode 100755 index 000000000..678185011 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/golden.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, representative-logical-elements +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +PACKED_BYTES_PER_STORE = 8 +OUTPUT_WORDS = 8 + + +def _pack_mask_b32(active_lanes: int) -> np.ndarray: + if active_lanes < 0 or active_lanes > 64: + raise ValueError(f"active_lanes must be in [0, 64], got {active_lanes}") + logical = np.zeros((64,), dtype=np.uint8) + logical[:active_lanes] = 1 + packed = np.packbits(logical, bitorder="little") + out = np.zeros((PACKED_BYTES_PER_STORE,), dtype=np.uint8) + out[: packed.size] = packed + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-1.0, 1.0, size=(ROWS, COLS)).astype(np.float32) + + first = _pack_mask_b32(13) + second = _pack_mask_b32(7) + packed = np.zeros((OUTPUT_WORDS * np.dtype(np.uint32).itemsize,), dtype=np.uint8) + packed[:PACKED_BYTES_PER_STORE] = first + packed[PACKED_BYTES_PER_STORE : 2 * PACKED_BYTES_PER_STORE] = second + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + output_init.tofile(output_dir / "v3.bin") + packed.view(np.uint32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op pstu validation." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help="Numpy random seed.", + ) + args = parser.parse_args() + + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto new file mode 100644 index 000000000..3669f19af --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto @@ -0,0 +1,41 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @pstu_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c0_i32 = arith.constant 0 : i32 + %c13 = arith.constant 13 : i32 + %c7 = arith.constant 7 : i32 + + %ub_mask = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %align = pto.init_align : !pto.align + %value, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %align_out, %base_out = pto.pstu %align, %value, %ub_mask : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + %value_tail, %next_tail = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %align_tail, %base_tail = pto.pstu %align_out, %value_tail, %base_out : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + pto.vstas %align_tail, %base_tail, %c0_i32 : !pto.align, !pto.ptr, i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_mask, %arg2, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu/launch.cpp new file mode 100644 index 000000000..977155180 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pstu_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ uint32_t *v3); + +void LaunchPstu_kernel_2d(float *v1, float *v2, uint32_t *v3, void *stream) { + pstu_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2, + (__gm__ uint32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu/main.cpp new file mode 100644 index 000000000..5c31f323d --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/main.cpp @@ -0,0 +1,110 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPstu_kernel_2d(float *v1, float *v2, uint32_t *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 8; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPstu_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/stub.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu/stub.cpp new file mode 100644 index 000000000..70df99121 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/stub.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void pstu_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ uint32_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/compare.py b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/compare.py new file mode 100755 index 000000000..f2a3e0459 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +# family: rearrangement +# target_ops: pto.vdintlv, pto.vintlv +# scenarios: paired-roundtrip, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/golden.py b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/golden.py new file mode 100755 index 000000000..bc2746fab --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +# family: rearrangement +# target_ops: pto.vdintlv, pto.vintlv +# scenarios: paired-roundtrip, lane-order +# NOTE: paired vintlv+vdintlv roundtrip should recover the original input, including lane-boundary patterns. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=ROWS * COLS).astype(np.float32) + for base in range(0, flat.size, 128): + flat[base + 62 : base + 66] = np.array([-62.0, -1.0, 1.0, 62.0], dtype=np.float32) + v1 = flat.reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vintlv+vdintlv lane-boundary validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto new file mode 100644 index 000000000..f9ff23e1b --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order, boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vintlv_vdintlv_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %rhs_offset = arith.addi %offset, %c64 : index + %lhs = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_in[%rhs_offset] : !pto.ptr -> !pto.vreg<64xf32> + %ilow, %ihigh = pto.vintlv %lhs, %rhs : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %dlow, %dhigh = pto.vdintlv %ilow, %ihigh : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vsts %dlow, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %dhigh, %ub_out[%rhs_offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/launch.cpp new file mode 100644 index 000000000..f7bb8bf5a --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vintlv_vdintlv_boundary_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVintlv_vdintlv_boundary_kernel_2d(float *v1, float *v2, void *stream) { + vintlv_vdintlv_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/main.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/main.cpp new file mode 100644 index 000000000..f6fb0606b --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVintlv_vdintlv_boundary_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVintlv_vdintlv_boundary_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/stub.cpp new file mode 100644 index 000000000..813765214 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vintlv_vdintlv_boundary_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/compare.py b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/compare.py new file mode 100755 index 000000000..afa3d870f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vintlv-vdintlv +# family: rearrangement +# target_ops: pto.vdintlv, pto.vintlv +# scenarios: paired-roundtrip, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/golden.py b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/golden.py new file mode 100755 index 000000000..8878cae52 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vintlv-vdintlv +# family: rearrangement +# target_ops: pto.vdintlv, pto.vintlv +# scenarios: paired-roundtrip, lane-order +# NOTE: paired vintlv+vdintlv roundtrip should recover the original input. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vintlv+vdintlv validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto new file mode 100644 index 000000000..28fcdaab9 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vintlv_vdintlv_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %rhs_offset = arith.addi %offset, %c64 : index + %lhs = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_in[%rhs_offset] : !pto.ptr -> !pto.vreg<64xf32> + %ilow, %ihigh = pto.vintlv %lhs, %rhs : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %dlow, %dhigh = pto.vdintlv %ilow, %ihigh : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vsts %dlow, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %dhigh, %ub_out[%rhs_offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/launch.cpp new file mode 100644 index 000000000..27baf3164 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vintlv_vdintlv_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVintlv_vdintlv_kernel_2d(float *v1, float *v2, void *stream) { + vintlv_vdintlv_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/main.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/main.cpp new file mode 100644 index 000000000..0a66ddb10 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVintlv_vdintlv_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVintlv_vdintlv_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/stub.cpp new file mode 100644 index 000000000..302b145df --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vintlv_vdintlv_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/compare.py b/test/vpto/cases/micro-op/rearrangement/vpack-higher/compare.py new file mode 100644 index 000000000..c318abde7 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vpack-higher +# family: rearrangement +# target_ops: pto.vpack +# scenarios: narrowing, higher-half-placement, zero-fill-lower-half +# coding=utf-8 +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/golden.py b/test/vpto/cases/micro-op/rearrangement/vpack-higher/golden.py new file mode 100644 index 000000000..b97089067 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/golden.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vpack-higher +# family: rearrangement +# target_ops: pto.vpack +# scenarios: narrowing, higher-half-placement, zero-fill-lower-half, post-pack-consumer +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ELEMS = ROWS * COLS +CHUNK = 64 +OUTPUT_ELEMS = ELEMS * 2 +SEED = 19 +BIAS = np.uint16(1) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(1 << 20), 1 << 20, size=ELEMS, dtype=np.int32) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint16) + golden_v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint16) + + narrowed = v1.astype(np.uint16, copy=False) + for chunk_base in range(0, ELEMS, CHUNK): + chunk = narrowed[chunk_base : chunk_base + CHUNK] + out_base = (chunk_base // CHUNK) * (CHUNK * 2) + golden_v2[out_base : out_base + CHUNK] = BIAS + golden_v2[out_base + CHUNK : out_base + 2 * CHUNK] = ( + chunk.astype(np.uint32) + int(BIAS) + ).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vpack-higher validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto new file mode 100644 index 000000000..a8fe95189 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-higher +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, higher-half-placement, zero-fill-lower-half, post-pack-consumer +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @vpack_higher_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1_i16 = arith.constant 1 : i16 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %store_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %src_offset = %c0 to %c1024 step %c64 { + %dst_offset = arith.muli %src_offset, %c2 : index + %vec = pto.vlds %ub_in[%src_offset] : !pto.ptr -> !pto.vreg<64xi32> + %packed = pto.vpack %vec, "HIGHER" : !pto.vreg<64xi32> -> !pto.vreg<128xui16> + %observed = pto.vadds %packed, %c1_i16, %store_mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %observed, %ub_out[%dst_offset], %store_mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-higher/launch.cpp new file mode 100644 index 000000000..3d5224385 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-higher +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, higher-half-placement, zero-fill-lower-half +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vpack_higher_kernel_2d(__gm__ int *v1, + __gm__ uint16_t *v2); + +void LaunchVpack_higher_kernel_2d(int32_t *v1, uint16_t *v2, void *stream) { + vpack_higher_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1, + (__gm__ uint16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/main.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-higher/main.cpp new file mode 100644 index 000000000..b28ee6e85 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-higher +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, higher-half-placement, zero-fill-lower-half +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVpack_higher_kernel_2d(int32_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int32_t); + size_t elemCount_v2 = 2048; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVpack_higher_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-higher/stub.cpp new file mode 100644 index 000000000..3fbe073e5 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-higher +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, higher-half-placement, zero-fill-lower-half +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vpack_higher_kernel_2d(__gm__ int *v1, + __gm__ uint16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/compare.py b/test/vpto/cases/micro-op/rearrangement/vpack-lower/compare.py new file mode 100644 index 000000000..0caf7195c --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vpack-lower +# family: rearrangement +# target_ops: pto.vpack +# scenarios: narrowing, lower-half-placement, zero-fill-upper-half +# coding=utf-8 +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/golden.py b/test/vpto/cases/micro-op/rearrangement/vpack-lower/golden.py new file mode 100644 index 000000000..37ca69a25 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/golden.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vpack-lower +# family: rearrangement +# target_ops: pto.vpack +# scenarios: narrowing, lower-half-placement, zero-fill-upper-half, post-pack-consumer +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ELEMS = ROWS * COLS +CHUNK = 64 +OUTPUT_ELEMS = ELEMS * 2 +SEED = 19 +BIAS = np.uint16(1) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(1 << 20), 1 << 20, size=ELEMS, dtype=np.int32) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint16) + golden_v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint16) + + narrowed = v1.astype(np.uint16, copy=False) + for chunk_base in range(0, ELEMS, CHUNK): + chunk = narrowed[chunk_base : chunk_base + CHUNK] + out_base = (chunk_base // CHUNK) * (CHUNK * 2) + golden_v2[out_base : out_base + CHUNK] = ( + chunk.astype(np.uint32) + int(BIAS) + ).astype(np.uint16) + golden_v2[out_base + CHUNK : out_base + 2 * CHUNK] = BIAS + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vpack-lower validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto new file mode 100644 index 000000000..d73ee2331 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-lower +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, lower-half-placement, zero-fill-upper-half, post-pack-consumer +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @vpack_lower_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1_i16 = arith.constant 1 : i16 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %store_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %src_offset = %c0 to %c1024 step %c64 { + %dst_offset = arith.muli %src_offset, %c2 : index + %vec = pto.vlds %ub_in[%src_offset] : !pto.ptr -> !pto.vreg<64xi32> + %packed = pto.vpack %vec, "LOWER" : !pto.vreg<64xi32> -> !pto.vreg<128xui16> + %observed = pto.vadds %packed, %c1_i16, %store_mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %observed, %ub_out[%dst_offset], %store_mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-lower/launch.cpp new file mode 100644 index 000000000..3bf8b0da1 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-lower +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, lower-half-placement, zero-fill-upper-half +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vpack_lower_kernel_2d(__gm__ int *v1, + __gm__ uint16_t *v2); + +void LaunchVpack_lower_kernel_2d(int32_t *v1, uint16_t *v2, void *stream) { + vpack_lower_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1, + (__gm__ uint16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/main.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-lower/main.cpp new file mode 100644 index 000000000..5cc58448c --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-lower +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, lower-half-placement, zero-fill-upper-half +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVpack_lower_kernel_2d(int32_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int32_t); + size_t elemCount_v2 = 2048; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVpack_lower_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-lower/stub.cpp new file mode 100644 index 000000000..1d144585a --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-lower +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, lower-half-placement, zero-fill-upper-half +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vpack_lower_kernel_2d(__gm__ int *v1, + __gm__ uint16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/compare.py b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/compare.py new file mode 100755 index 000000000..4cca2c574 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/compare.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsqz-nontrivial-mask +# family: rearrangement +# target_ops: pto.vsqz +# scenarios: predicate-driven-rearrangement, stable-order, nontrivial-mask +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/golden.py b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/golden.py new file mode 100755 index 000000000..e7d456a40 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsqz-nontrivial-mask +# family: rearrangement +# target_ops: pto.vsqz +# scenarios: predicate-driven-rearrangement, stable-order, nontrivial-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +BLOCKS = ROWS * COLS // LANES +ACTIVE_POSITIONS = [1, 4, 5, 9, 12, 16, 21, 24, 29, 33, 36, 40, 45, 49, 54, 60] +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + values = rng.uniform(-8.0, 8.0, size=(BLOCKS, LANES)).astype(np.float32) + mask_seed = np.full((BLOCKS, LANES), -1.0, dtype=np.float32) + golden = np.zeros((BLOCKS, LANES), dtype=np.float32) + + for block in range(BLOCKS): + for pos in ACTIVE_POSITIONS: + mask_seed[block, pos] = 1.0 + kept = values[block, ACTIVE_POSITIONS] + golden[block, :kept.size] = kept + + output_dir.mkdir(parents=True, exist_ok=True) + values.reshape(-1).tofile(output_dir / "v1.bin") + mask_seed.reshape(-1).tofile(output_dir / "v2.bin") + golden.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate nontrivial-mask inputs/golden for VPTO micro-op vsqz validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto new file mode 100644 index 000000000..719d09451 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto @@ -0,0 +1,61 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order, nontrivial-mask +// ----------------------------------------------------------------------------- +// Validate nontrivial predicate-driven compaction: +// - arg0 provides input values. +// - arg1 provides a mask seed (positive => keep lane; non-positive => drop lane) +// and receives the compacted output. +// For each 64-lane chunk: +// 1. Build placement mask via vcmps(mask_seed > 0). +// 2. Run vsqz using that placement mask. +// 3. Store full compacted vector (kept lanes first, tail zeroed) back to UB. + +module attributes {pto.target_arch = "a5"} { + func.func @vsqz_nontrivial_mask_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %zero_f32 = arith.constant 0.0 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mask_seed = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_mask_seed, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c1024 step %c64 { + %store_mask, %unused = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %mask_seed = pto.vlds %ub_mask_seed[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %place = pto.vcmps %mask_seed, %zero_f32, %store_mask, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + %out = pto.vsqz %vec, %place : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/launch.cpp new file mode 100644 index 000000000..be43cc98f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsqz_nontrivial_mask_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVsqzNontrivialMask_kernel_2d(float *v1, float *v2, void *stream) { + vsqz_nontrivial_mask_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/main.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/main.cpp new file mode 100644 index 000000000..8d467ea02 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsqzNontrivialMask_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsqzNontrivialMask_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/stub.cpp new file mode 100644 index 000000000..85dffd7bd --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsqz_nontrivial_mask_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/compare.py b/test/vpto/cases/micro-op/rearrangement/vsqz/compare.py new file mode 100755 index 000000000..f10e14e5f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsqz +# family: rearrangement +# target_ops: pto.vsqz +# scenarios: predicate-driven-rearrangement, stable-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/golden.py b/test/vpto/cases/micro-op/rearrangement/vsqz/golden.py new file mode 100755 index 000000000..5722d4362 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsqz +# family: rearrangement +# target_ops: pto.vsqz +# scenarios: predicate-driven-rearrangement, stable-order +# NOTE: full-mask compaction should preserve original lane order. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.copy() + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsqz validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto new file mode 100644 index 000000000..24ae27544 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vsqz_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vsqz %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz/launch.cpp new file mode 100644 index 000000000..511bb9f16 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsqz_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVsqz_kernel_2d(float *v1, float *v2, void *stream) { + vsqz_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/main.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz/main.cpp new file mode 100644 index 000000000..ddb3c318b --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsqz_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsqz_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz/stub.cpp new file mode 100644 index 000000000..5ed7fa05a --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsqz_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/compare.py b/test/vpto/cases/micro-op/rearrangement/vsunpack/compare.py new file mode 100755 index 000000000..85f2fa8b2 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsunpack +# family: rearrangement +# target_ops: pto.vsunpack +# scenarios: pack-unpack, sign-extend +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.int32, 0.0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/golden.py b/test/vpto/cases/micro-op/rearrangement/vsunpack/golden.py new file mode 100755 index 000000000..8ceca3b5e --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsunpack +# family: rearrangement +# target_ops: pto.vsunpack +# scenarios: pack-unpack, sign-extend +# NOTE: sign-extending unpack of the lower half of each 128-lane i16 chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +INPUT_ELEMS = 2048 +OUTPUT_ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(np.iinfo(np.int16).min, np.iinfo(np.int16).max + 1, size=INPUT_ELEMS, dtype=np.int16) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.int32) + golden_v2 = np.zeros(OUTPUT_ELEMS, dtype=np.int32) + for src_base in range(0, INPUT_ELEMS, 128): + dst_base = (src_base // 128) * 64 + golden_v2[dst_base : dst_base + 64] = v1[src_base : src_base + 64].astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsunpack validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto new file mode 100644 index 000000000..69167a5c3 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsunpack +// family: rearrangement +// target_ops: pto.vsunpack +// scenarios: pack-unpack, sign-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vsunpack_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c2 = arith.constant 2 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %part = arith.constant 0 : index + + %gm_in = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + %gm_out = pto.castptr %arg1 : !pto.ptr -> !pto.ptr + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_in, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %store_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %src_offset = arith.muli %offset, %c2 : index + %vec = pto.vlds %ub_in[%src_offset] : !pto.ptr -> !pto.vreg<128xi16> + %out = pto.vsunpack %vec, %part : !pto.vreg<128xi16> -> !pto.vreg<64xi32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vsunpack/launch.cpp new file mode 100644 index 000000000..938b0ee13 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsunpack +// family: rearrangement +// target_ops: pto.vsunpack +// scenarios: pack-unpack, sign-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsunpack_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVsunpack_kernel_2d(float *v1, float *v2, void *stream) { + vsunpack_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/main.cpp b/test/vpto/cases/micro-op/rearrangement/vsunpack/main.cpp new file mode 100644 index 000000000..8b0e546ac --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/main.cpp @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsunpack +// family: rearrangement +// target_ops: pto.vsunpack +// scenarios: pack-unpack, sign-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsunpack_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 2048; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int32_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int32_t *v2Host = nullptr; + int32_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsunpack_kernel_2d(reinterpret_cast(v1Device), + reinterpret_cast(v2Device), + stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vsunpack/stub.cpp new file mode 100644 index 000000000..c7f220a8f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsunpack +// family: rearrangement +// target_ops: pto.vsunpack +// scenarios: pack-unpack, sign-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsunpack_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/compare.py b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/compare.py new file mode 100644 index 000000000..c5d68b8e4 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vusqz-nontrivial-mask +# family: rearrangement +# target_ops: pto.vusqz +# scenarios: predicate-driven-rearrangement, prefix-count + +import sys +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.int32) + output = np.fromfile("v3.bin", dtype=np.int32) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + sys.exit(2) + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch at idx={idx}: golden={int(golden[idx])} out={int(output[idx])}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/golden.py b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/golden.py new file mode 100644 index 000000000..81fc36cbc --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/golden.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vusqz-nontrivial-mask +# family: rearrangement +# target_ops: pto.vusqz +# scenarios: predicate-driven-rearrangement, prefix-count + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +BLOCKS = ROWS * COLS // LANES +ACTIVE_POSITIONS = [1, 4, 5, 9, 12, 16, 21, 24, 29, 33, 36, 40, 45, 49, 54, 60] +SEED = 19 + + +def build_case() -> tuple[np.ndarray, np.ndarray, np.ndarray]: + src = np.zeros((BLOCKS, LANES), dtype=np.int32) + mask_seed = np.full((BLOCKS, LANES), -1.0, dtype=np.float32) + out = np.zeros((BLOCKS, LANES), dtype=np.int32) + + for block in range(BLOCKS): + src[block] = np.arange(block * 1000 + 7, block * 1000 + 7 + LANES, dtype=np.int32) + for pos in ACTIVE_POSITIONS: + mask_seed[block, pos] = 1.0 + active_count = 0 + out[block, 0] = 0 + for lane in range(1, LANES): + if mask_seed[block, lane - 1] > 0.0: + active_count += 1 + out[block, lane] = active_count + + return src.reshape(ROWS, COLS), mask_seed.reshape(ROWS, COLS), out.reshape(ROWS, COLS) + + +def generate(output_dir: Path) -> None: + src, mask_seed, out = build_case() + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mask_seed.reshape(-1).tofile(output_dir / "v2.bin") + out.reshape(-1).tofile(output_dir / "golden_v3.bin") + np.zeros_like(out.reshape(-1)).tofile(output_dir / "v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate vusqz nontrivial prefix-count inputs/golden." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + del args.seed + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto new file mode 100644 index 000000000..0bf1ff181 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vusqz_nontrivial_mask_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %zero_f32 = arith.constant 0.0 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mask_seed = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_mask_seed, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c1024 step %c64 { + %store_mask, %unused = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %mask_seed = pto.vlds %ub_mask_seed[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %place = pto.vcmps %mask_seed, %zero_f32, %store_mask, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + %out = pto.vusqz %src, %place : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/launch.cpp new file mode 100644 index 000000000..e9edd85e3 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vusqz_nontrivial_mask_kernel_2d(__gm__ int32_t *v1, + __gm__ float *v2, + __gm__ int32_t *v3); + +void LaunchVusqz_nontrivial_mask_kernel_2d(int32_t *v1, + float *v2, + int32_t *v3, + void *stream) { + vusqz_nontrivial_mask_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ int32_t *)v1, (__gm__ float *)v2, (__gm__ int32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/main.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/main.cpp new file mode 100644 index 000000000..50190e7f3 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, \ + __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVusqz_nontrivial_mask_kernel_2d(int32_t *v1, + float *v2, + int32_t *v3, + void *stream); + +int main() { + constexpr size_t elemCount = 1024; + size_t fileSizeV1 = elemCount * sizeof(int32_t); + size_t fileSizeV2 = elemCount * sizeof(float); + size_t fileSizeV3 = elemCount * sizeof(int32_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int32_t *v3Host = nullptr; + int32_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSizeV1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSizeV2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSizeV3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSizeV1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSizeV2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSizeV3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSizeV1, v1Host, fileSizeV1); + ReadFile("./v2.bin", fileSizeV2, v2Host, fileSizeV2); + std::fill_n(v3Host, elemCount, 0); + ACL_CHECK(aclrtMemcpy(v1Device, fileSizeV1, v1Host, fileSizeV1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSizeV2, v2Host, fileSizeV2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSizeV3, v3Host, fileSizeV3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVusqz_nontrivial_mask_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSizeV3, v3Device, fileSizeV3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSizeV3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + (void)aclrtDestroyStream(stream); + if (deviceSet) + (void)aclrtResetDevice(deviceId); + if (aclInited) + (void)aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/stub.cpp new file mode 100644 index 000000000..ec84ac166 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vusqz_nontrivial_mask_kernel_2d(__gm__ int32_t *v1, + __gm__ float *v2, + __gm__ int32_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/compare.py b/test/vpto/cases/micro-op/rearrangement/vusqz/compare.py new file mode 100644 index 000000000..6f3603aab --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vusqz +# family: rearrangement +# target_ops: pto.vusqz +# scenarios: predicate-driven-rearrangement, prefix-count + +import sys +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.int32) + output = np.fromfile("v3.bin", dtype=np.int32) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + sys.exit(2) + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch at idx={idx}: golden={int(golden[idx])} out={int(output[idx])}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/golden.py b/test/vpto/cases/micro-op/rearrangement/vusqz/golden.py new file mode 100644 index 000000000..94c38565a --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vusqz +# family: rearrangement +# target_ops: pto.vusqz +# scenarios: predicate-driven-rearrangement, prefix-count + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +BLOCKS = ROWS * COLS // LANES +ACTIVE_PER_BLOCK = 16 +SEED = 19 + + +def build_case() -> tuple[np.ndarray, np.ndarray, np.ndarray]: + src = np.zeros((BLOCKS, LANES), dtype=np.int32) + mask_seed = np.full((BLOCKS, LANES), -1.0, dtype=np.float32) + out = np.zeros((BLOCKS, LANES), dtype=np.int32) + + for block in range(BLOCKS): + src[block] = np.arange(block * 100 - 31, block * 100 - 31 + LANES, dtype=np.int32) + mask_seed[block, :ACTIVE_PER_BLOCK] = 1.0 + active_count = 0 + out[block, 0] = 0 + for lane in range(1, LANES): + if mask_seed[block, lane - 1] > 0.0: + active_count += 1 + out[block, lane] = active_count + + return src.reshape(ROWS, COLS), mask_seed.reshape(ROWS, COLS), out.reshape(ROWS, COLS) + + +def generate(output_dir: Path) -> None: + src, mask_seed, out = build_case() + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mask_seed.reshape(-1).tofile(output_dir / "v2.bin") + out.reshape(-1).tofile(output_dir / "golden_v3.bin") + np.zeros_like(out.reshape(-1)).tofile(output_dir / "v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate vusqz prefix-count inputs/golden.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + del args.seed + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto new file mode 100644 index 000000000..58fc84285 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vusqz_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %zero_f32 = arith.constant 0.0 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mask_seed = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_mask_seed, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c1024 step %c64 { + %store_mask, %unused = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %mask_seed = pto.vlds %ub_mask_seed[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %place = pto.vcmps %mask_seed, %zero_f32, %store_mask, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + %out = pto.vusqz %src, %place : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz/launch.cpp new file mode 100644 index 000000000..3684fe9b2 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vusqz_kernel_2d(__gm__ int32_t *v1, + __gm__ float *v2, + __gm__ int32_t *v3); + +void LaunchVusqz_kernel_2d(int32_t *v1, float *v2, int32_t *v3, void *stream) { + vusqz_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ int32_t *)v1, (__gm__ float *)v2, (__gm__ int32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/main.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz/main.cpp new file mode 100644 index 000000000..9da958163 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/main.cpp @@ -0,0 +1,100 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, \ + __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVusqz_kernel_2d(int32_t *v1, float *v2, int32_t *v3, void *stream); + +int main() { + constexpr size_t elemCount = 1024; + size_t fileSizeV1 = elemCount * sizeof(int32_t); + size_t fileSizeV2 = elemCount * sizeof(float); + size_t fileSizeV3 = elemCount * sizeof(int32_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int32_t *v3Host = nullptr; + int32_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSizeV1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSizeV2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSizeV3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSizeV1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSizeV2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSizeV3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSizeV1, v1Host, fileSizeV1); + ReadFile("./v2.bin", fileSizeV2, v2Host, fileSizeV2); + std::fill_n(v3Host, elemCount, 0); + ACL_CHECK(aclrtMemcpy(v1Device, fileSizeV1, v1Host, fileSizeV1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSizeV2, v2Host, fileSizeV2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSizeV3, v3Host, fileSizeV3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVusqz_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSizeV3, v3Device, fileSizeV3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSizeV3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + (void)aclrtDestroyStream(stream); + if (deviceSet) + (void)aclrtResetDevice(deviceId); + if (aclInited) + (void)aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz/stub.cpp new file mode 100644 index 000000000..7c08a2624 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vusqz_kernel_2d(__gm__ int32_t *v1, + __gm__ float *v2, + __gm__ int32_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/compare.py b/test/vpto/cases/micro-op/rearrangement/vzunpack/compare.py new file mode 100755 index 000000000..0dc97cc35 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vzunpack +# family: rearrangement +# target_ops: pto.vzunpack +# scenarios: pack-unpack, zero-extend +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint32, 0.0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/golden.py b/test/vpto/cases/micro-op/rearrangement/vzunpack/golden.py new file mode 100755 index 000000000..e6014e397 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vzunpack +# family: rearrangement +# target_ops: pto.vzunpack +# scenarios: pack-unpack, zero-extend +# NOTE: zero-extending unpack of the lower half of each 128-lane ui16 chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +INPUT_ELEMS = 2048 +OUTPUT_ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, np.iinfo(np.uint16).max + 1, size=INPUT_ELEMS, dtype=np.uint16) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint32) + golden_v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint32) + for src_base in range(0, INPUT_ELEMS, 128): + dst_base = (src_base // 128) * 64 + golden_v2[dst_base : dst_base + 64] = v1[src_base : src_base + 64].astype(np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vzunpack validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto new file mode 100644 index 000000000..5cd5705fc --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vzunpack +// family: rearrangement +// target_ops: pto.vzunpack +// scenarios: pack-unpack, zero-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vzunpack_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c2 = arith.constant 2 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %part = arith.constant 0 : index + + %gm_in = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + %gm_out = pto.castptr %arg1 : !pto.ptr -> !pto.ptr + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_in, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %store_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %src_offset = arith.muli %offset, %c2 : index + %vec = pto.vlds %ub_in[%src_offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vzunpack %vec, %part : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vzunpack/launch.cpp new file mode 100644 index 000000000..7fa2a6c4b --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vzunpack +// family: rearrangement +// target_ops: pto.vzunpack +// scenarios: pack-unpack, zero-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vzunpack_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVzunpack_kernel_2d(float *v1, float *v2, void *stream) { + vzunpack_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/main.cpp b/test/vpto/cases/micro-op/rearrangement/vzunpack/main.cpp new file mode 100644 index 000000000..e3693855f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/main.cpp @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vzunpack +// family: rearrangement +// target_ops: pto.vzunpack +// scenarios: pack-unpack, zero-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVzunpack_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 2048; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVzunpack_kernel_2d(reinterpret_cast(v1Device), + reinterpret_cast(v2Device), + stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/stub.cpp b/test/vpto/cases/micro-op/rearrangement/vzunpack/stub.cpp new file mode 100644 index 000000000..8ce60879f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vzunpack +// family: rearrangement +// target_ops: pto.vzunpack +// scenarios: pack-unpack, zero-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vzunpack_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/compare.py b/test/vpto/cases/micro-op/reduction/vcadd-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/golden.py b/test/vpto/cases/micro-op/reduction/vcadd-tail/golden.py new file mode 100644 index 000000000..9ea041d65 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +LOGICAL_ELEMS = 1000 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + for offset in range(0, LOGICAL_ELEMS, LANES): + chunk = flat_in[offset:min(offset + LANES, LOGICAL_ELEMS)] + flat_out[offset] = np.sum(chunk, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto new file mode 100644 index 000000000..0c45c6f56 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vabs_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/launch.cpp b/test/vpto/cases/micro-op/reduction/vcadd-tail/launch.cpp new file mode 100644 index 000000000..494bc5bf3 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_tail_kernel_2d(float *v1, float *v2, void *stream) { + vabs_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/main.cpp b/test/vpto/cases/micro-op/reduction/vcadd-tail/main.cpp new file mode 100644 index 000000000..cf25e5dff --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/main.cpp @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_tail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/stub.cpp b/test/vpto/cases/micro-op/reduction/vcadd-tail/stub.cpp new file mode 100644 index 000000000..507350679 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vabs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd/compare.py b/test/vpto/cases/micro-op/reduction/vcadd/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcadd/golden.py b/test/vpto/cases/micro-op/reduction/vcadd/golden.py new file mode 100644 index 000000000..906f71e66 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + flat_out[offset] = np.sum(chunk, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto new file mode 100644 index 000000000..9c27df73c --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd/launch.cpp b/test/vpto/cases/micro-op/reduction/vcadd/launch.cpp new file mode 100644 index 000000000..9002bcd67 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd/main.cpp b/test/vpto/cases/micro-op/reduction/vcadd/main.cpp new file mode 100644 index 000000000..29454461f --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd/stub.cpp b/test/vpto/cases/micro-op/reduction/vcadd/stub.cpp new file mode 100644 index 000000000..19b277575 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/compare.py b/test/vpto/cases/micro-op/reduction/vcgadd-tail/compare.py new file mode 100755 index 000000000..00c3f8d4b --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgadd-tail +# family: reduction +# target_ops: pto.vcgadd +# scenarios: group-reduction, tail-mask, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/golden.py b/test/vpto/cases/micro-op/reduction/vcgadd-tail/golden.py new file mode 100755 index 000000000..282927eff --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +LOGICAL_ELEMS = 1000 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, LOGICAL_ELEMS, LANES): + chunk = flat_in[offset:min(offset + LANES, LOGICAL_ELEMS)] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGADD writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.sum(chunk[group:group + group_elems], dtype=np.float32) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto new file mode 100644 index 000000000..4b2bc7d34 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd-tail +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, tail-mask, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcgadd_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgadd-tail/launch.cpp new file mode 100644 index 000000000..e35c2b363 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd-tail +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, tail-mask, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgaddTail_kernel_2d(float *v1, float *v2, void *stream) { + vcgadd_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/main.cpp b/test/vpto/cases/micro-op/reduction/vcgadd-tail/main.cpp new file mode 100644 index 000000000..29bdc23bf --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd-tail +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, tail-mask, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgaddTail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgaddTail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/stub.cpp b/test/vpto/cases/micro-op/reduction/vcgadd-tail/stub.cpp new file mode 100644 index 000000000..5f8d7e5c3 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd-tail +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, tail-mask, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcgadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/compare.py b/test/vpto/cases/micro-op/reduction/vcgadd/compare.py new file mode 100755 index 000000000..2c8e5f087 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgadd +# family: reduction +# target_ops: pto.vcgadd +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/golden.py b/test/vpto/cases/micro-op/reduction/vcgadd/golden.py new file mode 100755 index 000000000..efa021477 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGADD writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.sum(chunk[group:group + group_elems], dtype=np.float32) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto new file mode 100644 index 000000000..022af54f1 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcgadd_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgadd/launch.cpp new file mode 100644 index 000000000..16a1993e8 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgadd_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgadd_kernel_2d(float *v1, float *v2, void *stream) { + vcgadd_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/main.cpp b/test/vpto/cases/micro-op/reduction/vcgadd/main.cpp new file mode 100644 index 000000000..712f0755a --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgadd_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgadd_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/stub.cpp b/test/vpto/cases/micro-op/reduction/vcgadd/stub.cpp new file mode 100644 index 000000000..b98cab9ca --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcgadd_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/compare.py b/test/vpto/cases/micro-op/reduction/vcgmax-tie/compare.py new file mode 100755 index 000000000..a4a5c50c3 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgmax-tie +# family: reduction +# target_ops: pto.vcgmax +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/golden.py b/test/vpto/cases/micro-op/reduction/vcgmax-tie/golden.py new file mode 100755 index 000000000..a4d414312 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + flat_seed = v1.reshape(-1) + for offset in range(0, flat_seed.size, LANES): + for group in range(0, LANES, 8): + base = offset + group + flat_seed[base:base + 8] = np.array([7.0, 7.0, -3.0, 1.0, 0.5, -2.0, 4.0, 6.0], dtype=np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGMAX writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.max(chunk[group:group + group_elems]) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto new file mode 100644 index 000000000..ed4b0337e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax-tie +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcgmax_tie_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgmax %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgmax-tie/launch.cpp new file mode 100644 index 000000000..35e5a63b3 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax-tie +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgmax_tie_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgmaxTie_kernel_2d(float *v1, float *v2, void *stream) { + vcgmax_tie_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/main.cpp b/test/vpto/cases/micro-op/reduction/vcgmax-tie/main.cpp new file mode 100644 index 000000000..79ff13b2e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax-tie +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgmaxTie_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgmaxTie_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/stub.cpp b/test/vpto/cases/micro-op/reduction/vcgmax-tie/stub.cpp new file mode 100644 index 000000000..9a72f486d --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax-tie +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcgmax_tie_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/compare.py b/test/vpto/cases/micro-op/reduction/vcgmax/compare.py new file mode 100755 index 000000000..f1f037986 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgmax +# family: reduction +# target_ops: pto.vcgmax +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/golden.py b/test/vpto/cases/micro-op/reduction/vcgmax/golden.py new file mode 100755 index 000000000..d807ff1e0 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGMAX writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.max(chunk[group:group + group_elems]) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto new file mode 100644 index 000000000..7fac87069 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcgmax_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgmax %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgmax/launch.cpp new file mode 100644 index 000000000..33855f496 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgmax_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgmax_kernel_2d(float *v1, float *v2, void *stream) { + vcgmax_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/main.cpp b/test/vpto/cases/micro-op/reduction/vcgmax/main.cpp new file mode 100644 index 000000000..f51aa0ebe --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgmax_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgmax_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/stub.cpp b/test/vpto/cases/micro-op/reduction/vcgmax/stub.cpp new file mode 100644 index 000000000..f2d06c55b --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcgmax_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/compare.py b/test/vpto/cases/micro-op/reduction/vcgmin-tie/compare.py new file mode 100755 index 000000000..05b8ee45c --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgmin-tie +# family: reduction +# target_ops: pto.vcgmin +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/golden.py b/test/vpto/cases/micro-op/reduction/vcgmin-tie/golden.py new file mode 100755 index 000000000..62a18cd0d --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + flat_seed = v1.reshape(-1) + for offset in range(0, flat_seed.size, LANES): + for group in range(0, LANES, 8): + base = offset + group + flat_seed[base:base + 8] = np.array([-7.0, -7.0, 3.0, -1.0, 0.5, 2.0, -4.0, -6.0], dtype=np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGMIN writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.min(chunk[group:group + group_elems]) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto new file mode 100644 index 000000000..a34cfd2db --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin-tie +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcgmin_tie_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgmin %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgmin-tie/launch.cpp new file mode 100644 index 000000000..35f95d660 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin-tie +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgmin_tie_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgminTie_kernel_2d(float *v1, float *v2, void *stream) { + vcgmin_tie_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/main.cpp b/test/vpto/cases/micro-op/reduction/vcgmin-tie/main.cpp new file mode 100644 index 000000000..3a940457b --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin-tie +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgminTie_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgminTie_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/stub.cpp b/test/vpto/cases/micro-op/reduction/vcgmin-tie/stub.cpp new file mode 100644 index 000000000..85f3dbe93 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin-tie +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcgmin_tie_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/compare.py b/test/vpto/cases/micro-op/reduction/vcgmin/compare.py new file mode 100755 index 000000000..57ac3a528 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgmin +# family: reduction +# target_ops: pto.vcgmin +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/golden.py b/test/vpto/cases/micro-op/reduction/vcgmin/golden.py new file mode 100755 index 000000000..5f2413af5 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGMIN writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.min(chunk[group:group + group_elems]) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto new file mode 100644 index 000000000..2d329434f --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcgmin_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgmin %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgmin/launch.cpp new file mode 100644 index 000000000..b6787415c --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgmin_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgmin_kernel_2d(float *v1, float *v2, void *stream) { + vcgmin_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/main.cpp b/test/vpto/cases/micro-op/reduction/vcgmin/main.cpp new file mode 100644 index 000000000..1c4fc7676 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgmin_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgmin_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/stub.cpp b/test/vpto/cases/micro-op/reduction/vcgmin/stub.cpp new file mode 100644 index 000000000..fe6d3d1d0 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcgmin_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcmax/compare.py b/test/vpto/cases/micro-op/reduction/vcmax/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcmax/golden.py b/test/vpto/cases/micro-op/reduction/vcmax/golden.py new file mode 100644 index 000000000..739d372e6 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + flat_out_u32 = flat_out.view(np.uint32) + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + idx = int(np.argmax(chunk)) + flat_out[offset] = chunk[idx] + flat_out_u32[offset + 1] = np.uint32(idx) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto b/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto new file mode 100644 index 000000000..5db925e8e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcmax %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcmax/launch.cpp b/test/vpto/cases/micro-op/reduction/vcmax/launch.cpp new file mode 100644 index 000000000..9002bcd67 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcmax/main.cpp b/test/vpto/cases/micro-op/reduction/vcmax/main.cpp new file mode 100644 index 000000000..29454461f --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcmax/stub.cpp b/test/vpto/cases/micro-op/reduction/vcmax/stub.cpp new file mode 100644 index 000000000..19b277575 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcmin/compare.py b/test/vpto/cases/micro-op/reduction/vcmin/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcmin/golden.py b/test/vpto/cases/micro-op/reduction/vcmin/golden.py new file mode 100644 index 000000000..bbbfe8d57 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + flat_out_u32 = flat_out.view(np.uint32) + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + idx = int(np.argmin(chunk)) + flat_out[offset] = chunk[idx] + flat_out_u32[offset + 1] = np.uint32(idx) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto b/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto new file mode 100644 index 000000000..9f43b2fbc --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcmin %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcmin/launch.cpp b/test/vpto/cases/micro-op/reduction/vcmin/launch.cpp new file mode 100644 index 000000000..9002bcd67 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcmin/main.cpp b/test/vpto/cases/micro-op/reduction/vcmin/main.cpp new file mode 100644 index 000000000..29454461f --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcmin/stub.cpp b/test/vpto/cases/micro-op/reduction/vcmin/stub.cpp new file mode 100644 index 000000000..19b277575 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/compare.py b/test/vpto/cases/micro-op/reduction/vcpadd-tail/compare.py new file mode 100755 index 000000000..fcaa15984 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcpadd-tail +# family: reduction +# target_ops: pto.vcpadd +# scenarios: prefix-op, tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/golden.py b/test/vpto/cases/micro-op/reduction/vcpadd-tail/golden.py new file mode 100755 index 000000000..08dc83922 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +LOGICAL_ELEMS = 1000 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + for offset in range(0, LOGICAL_ELEMS, LANES): + chunk = flat_in[offset:min(offset + LANES, LOGICAL_ELEMS)] + pair_count = (chunk.size + 1) // 2 + for i in range(pair_count): + a = chunk[2 * i] + b = chunk[2 * i + 1] if (2 * i + 1) < chunk.size else np.float32(0.0) + # VCPADD writes pair-reduction results to low half lanes. + flat_out[offset + i] = np.float32(a + b) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto new file mode 100644 index 000000000..817df6a67 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd-tail +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcpadd_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcpadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/launch.cpp b/test/vpto/cases/micro-op/reduction/vcpadd-tail/launch.cpp new file mode 100644 index 000000000..08c0b9ad5 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd-tail +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcpadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcpaddTail_kernel_2d(float *v1, float *v2, void *stream) { + vcpadd_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/main.cpp b/test/vpto/cases/micro-op/reduction/vcpadd-tail/main.cpp new file mode 100644 index 000000000..d571471dc --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd-tail +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcpaddTail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcpaddTail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/stub.cpp b/test/vpto/cases/micro-op/reduction/vcpadd-tail/stub.cpp new file mode 100644 index 000000000..6563b8e2a --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd-tail +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcpadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/compare.py b/test/vpto/cases/micro-op/reduction/vcpadd/compare.py new file mode 100755 index 000000000..8094ed94e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcpadd +# family: reduction +# target_ops: pto.vcpadd +# scenarios: prefix-op, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/golden.py b/test/vpto/cases/micro-op/reduction/vcpadd/golden.py new file mode 100755 index 000000000..eb41c69f0 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + pair_count = (chunk.size + 1) // 2 + for i in range(pair_count): + a = chunk[2 * i] + b = chunk[2 * i + 1] if (2 * i + 1) < chunk.size else np.float32(0.0) + # VCPADD writes pair-reduction results to low half lanes. + flat_out[offset + i] = np.float32(a + b) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto new file mode 100644 index 000000000..1dc57f76e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcpadd_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcpadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/launch.cpp b/test/vpto/cases/micro-op/reduction/vcpadd/launch.cpp new file mode 100644 index 000000000..ad26d59b2 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcpadd_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcpadd_kernel_2d(float *v1, float *v2, void *stream) { + vcpadd_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/main.cpp b/test/vpto/cases/micro-op/reduction/vcpadd/main.cpp new file mode 100644 index 000000000..7f62d2606 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcpadd_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcpadd_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/stub.cpp b/test/vpto/cases/micro-op/reduction/vcpadd/stub.cpp new file mode 100644 index 000000000..0e448ff8e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vcpadd_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/compare.py b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/compare.py new file mode 100644 index 000000000..87edcafae --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden[idx])}, out={int(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/golden.py b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/golden.py new file mode 100644 index 000000000..a4fb8dd28 --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-20000, 20000, size=ELEMS, dtype=np.int16) + v2 = np.zeros(ELEMS, dtype=np.int16) + golden_v2 = (v1.astype(np.int32) + 4).astype(np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO UB scalar load/store validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto new file mode 100644 index 000000000..5b9192cc3 --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/scalar-load-store/load-store-scalar-ub +// family: scalar-load-store +// target_ops: pto.load_scalar, pto.store_scalar +// scenarios: core-i16, ub-roundtrip, scalar-rw +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @load_store_scalar_ub_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c7_i16 = arith.constant 7 : i16 + %c3_i16 = arith.constant 3 : i16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c1024 step %c1 { + %loaded = pto.load_scalar %ub_in[%offset] : !pto.ptr -> i16 + %biased = arith.addi %loaded, %c7_i16 : i16 + pto.store_scalar %biased, %ub_out[%offset] : !pto.ptr, i16 + %echo = pto.load_scalar %ub_out[%offset] : !pto.ptr -> i16 + %result = arith.subi %echo, %c3_i16 : i16 + pto.store_scalar %result, %ub_out[%offset] : !pto.ptr, i16 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/launch.cpp b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/launch.cpp new file mode 100644 index 000000000..dbfc19cfa --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void load_store_scalar_ub_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2); + +void LaunchLoad_store_scalar_ub_kernel(int16_t *v1, int16_t *v2, void *stream) { + load_store_scalar_ub_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/main.cpp b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/main.cpp new file mode 100644 index 000000000..ff7a3e98b --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchLoad_store_scalar_ub_kernel(int16_t *v1, int16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v2Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v1, v2Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v1, v2Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchLoad_store_scalar_ub_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v1, v2Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v1); + +cleanup: + aclrtFree(v2Device); + aclrtFree(v1Device); + aclrtFreeHost(v2Host); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/stub.cpp b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/stub.cpp new file mode 100644 index 000000000..654db952b --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void load_store_scalar_ub_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/compare.py b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/compare.py new file mode 100644 index 000000000..561c706e1 --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden[idx])}, out={int(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v1.bin", "v1.bin", np.int64) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/golden.py b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/golden.py new file mode 100644 index 000000000..785d7bc11 --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +VALUES = np.full(64, -1, dtype=np.int64) +VALUES[0] = 0 +VALUES[1] = 0 +VALUES[2] = 2 +VALUES[3] = 1 +VALUES[32] = 1 +VALUES[33] = 0 +VALUES[34] = 2 +VALUES[35] = 1 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.full(VALUES.shape, -1, dtype=np.int64) + golden_v1 = VALUES.copy() + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden_v1.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO runtime query validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/kernel.pto b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/kernel.pto new file mode 100644 index 000000000..0d638a30c --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/kernel.pto @@ -0,0 +1,35 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/system-runtime-query/get-block-subblock-id +// family: system-runtime-query +// target_ops: pto.get_block_idx, pto.get_subblock_idx, pto.get_block_num, +// pto.get_subblock_num, pto.store_scalar +// scenarios: multi-block, runtime-query +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @get_block_subblock_id_kernel(%arg0: !pto.ptr) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c32 = arith.constant 32 : index + + pto.vecscope { + } + + %block = pto.get_block_idx + %subblock = pto.get_subblock_idx + %block_num = pto.get_block_num + %subblock_num = pto.get_subblock_num + + %block_idx = arith.index_cast %block : i64 to index + %slot_base = arith.muli %block_idx, %c32 : index + + pto.store_scalar %block, %arg0[%slot_base] : !pto.ptr, i64 + %slot_1 = arith.addi %slot_base, %c1 : index + pto.store_scalar %subblock, %arg0[%slot_1] : !pto.ptr, i64 + %slot_2 = arith.addi %slot_base, %c2 : index + pto.store_scalar %block_num, %arg0[%slot_2] : !pto.ptr, i64 + %slot_3 = arith.addi %slot_base, %c3 : index + pto.store_scalar %subblock_num, %arg0[%slot_3] : !pto.ptr, i64 + return + } +} diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/launch.cpp b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/launch.cpp new file mode 100644 index 000000000..c3d8bcdbe --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void get_block_subblock_id_kernel(__gm__ int64_t *v1); + +void LaunchGet_block_subblock_id_kernel(int64_t *v1, void *stream) { + get_block_subblock_id_kernel<<<2, nullptr, stream>>>((__gm__ int64_t *)v1); +} diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/main.cpp b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/main.cpp new file mode 100644 index 000000000..f9ed27342 --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchGet_block_subblock_id_kernel(int64_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(int64_t); + int64_t *v1Host = nullptr; + int64_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchGet_block_subblock_id_kernel(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/stub.cpp b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/stub.cpp new file mode 100644 index 000000000..d06d80ca4 --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/stub.cpp @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void get_block_subblock_id_kernel(__gm__ int64_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-f16/compare.py new file mode 100755 index 000000000..77d269686 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-f16 +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-f16, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-f16/golden.py new file mode 100755 index 000000000..b90b097ce --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-f16 +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-f16, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto new file mode 100644 index 000000000..d74233130 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-f16 +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f16/launch.cpp new file mode 100644 index 000000000..58cdd948a --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-f16 +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f16/main.cpp new file mode 100644 index 000000000..76001407d --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-f16 +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f16/stub.cpp new file mode 100644 index 000000000..fc01ca487 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-f16 +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/golden.py new file mode 100644 index 000000000..95f77e83a --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto new file mode 100644 index 000000000..1e6de64e5 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vabs_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/launch.cpp new file mode 100644 index 000000000..806579491 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream) { + vabs_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/main.cpp new file mode 100644 index 000000000..b3312f7e2 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/main.cpp @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/stub.cpp new file mode 100644 index 000000000..b790f8429 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vabs_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/compare.py new file mode 100644 index 000000000..672b2df43 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-signed-overflow-edge +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-signed, full-mask, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/golden.py new file mode 100644 index 000000000..e8562fdac --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-signed-overflow-edge +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-signed, full-mask, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + data = rng.integers(-30000, 30000, size=ELEMS, dtype=np.int16) + edge = np.array( + [-32768, -32767, -12345, -1, 0, 1, 12345, 32767, + -32768, -2, 2, -32766, 32766, -1024, 1024, -17], + dtype=np.int16, + ) + data[:edge.size] = edge + golden = np.abs(data).astype(np.int16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + data.tofile(output_dir / "v1.bin") + np.zeros(ELEMS, dtype=np.int16).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto new file mode 100644 index 000000000..3e66f18ce --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed-overflow-edge +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vabs_i16_signed_overflow_edge_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %out = pto.vabs %vec, %mask : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/launch.cpp new file mode 100644 index 000000000..be3498f7e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_i16_signed_overflow_edge_kernel( + __gm__ int16_t *v1, __gm__ int16_t *v2); + +void LaunchVabs_i16_signed_overflow_edge_kernel(int16_t *v1, int16_t *v2, + void *stream) { + vabs_i16_signed_overflow_edge_kernel<<<1, nullptr, stream>>>( + (__gm__ int16_t *)v1, (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/main.cpp new file mode 100644 index 000000000..55de29f79 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed-overflow-edge +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask, integer-overflow +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_i16_signed_overflow_edge_kernel(int16_t *v1, int16_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_i16_signed_overflow_edge_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/stub.cpp new file mode 100644 index 000000000..95861591c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vabs_i16_signed_overflow_edge_kernel( + __gm__ int16_t *v1, __gm__ int16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/compare.py new file mode 100755 index 000000000..eca2ddc70 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-signed +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-signed, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/golden.py new file mode 100755 index 000000000..ae05da408 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-signed +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-signed, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto new file mode 100644 index 000000000..7b5cd2e9b --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/launch.cpp new file mode 100644 index 000000000..1d4dc5556 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/main.cpp new file mode 100644 index 000000000..565d8c357 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/stub.cpp new file mode 100644 index 000000000..bd6526db6 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/compare.py new file mode 100755 index 000000000..a6d5c46f7 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-unsigned +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/golden.py new file mode 100755 index 000000000..6f75723de --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-unsigned +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto new file mode 100644 index 000000000..94898cb27 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-unsigned +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/launch.cpp new file mode 100644 index 000000000..8ae1a4350 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-unsigned +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/main.cpp new file mode 100644 index 000000000..d54791913 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-unsigned +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/stub.cpp new file mode 100644 index 000000000..05b4ec452 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-unsigned +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/compare.py new file mode 100644 index 000000000..6098dd82c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/compare.py @@ -0,0 +1,198 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +ACTIVE_ELEMS = 1000 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, ACTIVE_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/golden.py new file mode 100644 index 000000000..7448a6a1c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/golden.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE_ELEMS = 1000 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + edge_values = np.array( + [ + -0.0, + 0.0, + -1.0, + 1.0, + -8.0, + 8.0, + -1.0e-30, + 1.0e-30, + -1.0e10, + 1.0e10, + -3.5, + 3.5, + -7.25, + 7.25, + -2.0, + 2.0, + ], + dtype=np.float32, + ) + v1.reshape(-1)[: edge_values.size] = edge_values + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_v1 = v1.reshape(-1) + flat_golden_v2 = golden_v2.reshape(-1) + flat_golden_v2[:ACTIVE_ELEMS] = np.abs(flat_v1[:ACTIVE_ELEMS]).astype( + np.float32, copy=False + ) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs loop-carried vreg validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto new file mode 100644 index 000000000..58b88ae51 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto @@ -0,0 +1,48 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vabs_loop_carried_vreg_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c1000 = arith.constant 1000 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1000 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = scf.for %iter = %c0 to %c2 step %c1 + iter_args(%carry = %vec) -> (!pto.vreg<64xf32>) { + %abs = pto.vabs %carry, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.yield %abs : !pto.vreg<64xf32> + } + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/launch.cpp new file mode 100644 index 000000000..2663b7625 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vabs_loop_carried_vreg_kernel_2d(__gm__ float *v1, __gm__ float *v2); + +void LaunchVabs_loop_carried_vreg_kernel_2d(float *v1, float *v2, void *stream) { + vabs_loop_carried_vreg_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/main.cpp new file mode 100644 index 000000000..4d4bd221b --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_loop_carried_vreg_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_loop_carried_vreg_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/stub.cpp new file mode 100644 index 000000000..95d0d8c4c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vabs_loop_carried_vreg_kernel_2d(__gm__ float *v1, __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-tail/golden.py new file mode 100644 index 000000000..03fd9a768 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.abs( + v1.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto new file mode 100644 index 000000000..9ef43bd64 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vabs_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-tail/launch.cpp new file mode 100644 index 000000000..494bc5bf3 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_tail_kernel_2d(float *v1, float *v2, void *stream) { + vabs_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-tail/main.cpp new file mode 100644 index 000000000..cf25e5dff --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/main.cpp @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_tail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-tail/stub.cpp new file mode 100644 index 000000000..507350679 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vabs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs/golden.py new file mode 100644 index 000000000..5b04d20bb --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto new file mode 100644 index 000000000..8fb9f1391 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto @@ -0,0 +1,60 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs/launch.cpp new file mode 100644 index 000000000..9002bcd67 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs/main.cpp new file mode 100644 index 000000000..29454461f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vabs/stub.cpp new file mode 100644 index 000000000..19b277575 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp-f16/compare.py new file mode 100755 index 000000000..1971de729 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vexp-f16 +# family: unary-vector +# target_ops: pto.vexp +# scenarios: core-f16, full-mask +# NOTE: f16 vector exp baseline. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 0.01) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp-f16/golden.py new file mode 100755 index 000000000..aa2de48ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vexp-f16 +# family: unary-vector +# target_ops: pto.vexp +# scenarios: core-f16, full-mask +# NOTE: f16 vector exp baseline. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float16) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v2 = np.exp(v1.astype(np.float32)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vexp f16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto new file mode 100644 index 000000000..939ee5b4d --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vexp-f16 +// family: unary-vector +// target_ops: pto.vexp +// scenarios: core-f16, full-mask +// NOTE: f16 vector exp baseline. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vexp_f16_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vexp %vec, %mask : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f16/launch.cpp new file mode 100644 index 000000000..4530d8cea --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vexp-f16 +// family: unary-vector +// target_ops: pto.vexp +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_f16_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_f16_kernel_2d(float *v1, float *v2, void *stream) { + vexp_f16_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f16/main.cpp new file mode 100644 index 000000000..c41afb75e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vexp-f16 +// family: unary-vector +// target_ops: pto.vexp +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_f16_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_f16_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f16/stub.cpp new file mode 100644 index 000000000..0a8c1e774 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vexp-f16 +// family: unary-vector +// target_ops: pto.vexp +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vexp_f16_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/golden.py new file mode 100644 index 000000000..fd76b39a9 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.exp(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto new file mode 100644 index 000000000..8228a02a3 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vexp_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vexp %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/launch.cpp new file mode 100644 index 000000000..f96f1fc2e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream) { + vexp_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/main.cpp new file mode 100644 index 000000000..2a6824d9f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/stub.cpp new file mode 100644 index 000000000..9c55720f1 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vexp_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/golden.py new file mode 100644 index 000000000..11cde41fe --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-120.0, -104.0, -88.0, 0.0, 40.0, 88.0, 90.0, 104.0], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.exp(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto new file mode 100644 index 000000000..7dde55a08 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vexp_f32_over_underflow_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vexp %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/launch.cpp new file mode 100644 index 000000000..219a407d0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_f32_over_underflow_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream) { + vexp_f32_over_underflow_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/main.cpp new file mode 100644 index 000000000..2a6824d9f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/stub.cpp new file mode 100644 index 000000000..6fd67f7f4 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vexp_f32_over_underflow_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp-tail/golden.py new file mode 100644 index 000000000..b77b49528 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.exp( + v1.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto new file mode 100644 index 000000000..9f30419ea --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto @@ -0,0 +1,40 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vexp_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vexp %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-tail/launch.cpp new file mode 100644 index 000000000..723fee5d5 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_tail_kernel_2d(float *v1, float *v2, void *stream) { + vexp_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-tail/main.cpp new file mode 100644 index 000000000..19f1b06f2 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-tail/stub.cpp new file mode 100644 index 000000000..50a0b3806 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vexp_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp/golden.py new file mode 100644 index 000000000..1df0d6853 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.exp(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vexp validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto new file mode 100644 index 000000000..abcad4e82 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto @@ -0,0 +1,60 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vexp_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vexp %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp/launch.cpp new file mode 100644 index 000000000..b6d8cdbf0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream) { + vexp_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp/main.cpp new file mode 100644 index 000000000..f864622ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vexp/stub.cpp new file mode 100644 index 000000000..0c8ab3a00 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/compare.py b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/compare.py new file mode 100755 index 000000000..afee62a98 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vln-domain-boundary +# family: unary-vector +# target_ops: pto.vln +# scenarios: core-f32, domain-positive, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/golden.py b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/golden.py new file mode 100755 index 000000000..64f82ec2d --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vln-domain-boundary +# family: unary-vector +# target_ops: pto.vln +# scenarios: core-f32, domain-positive, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(0.125, 8.0, size=(ROWS, COLS)).astype(np.float32) + flat = v1.reshape(-1) + flat[:8] = np.array( + [ + np.float32(np.finfo(np.float32).tiny), + np.float32(np.finfo(np.float32).tiny * 2.0), + np.float32(1.0), + np.float32(2.0), + np.float32(16.0), + np.float32(1024.0), + np.float32(np.finfo(np.float32).max), + np.float32(0.5), + ], + dtype=np.float32, + ) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.log(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vln domain-boundary validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto new file mode 100644 index 000000000..14a195eeb --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vln-domain-boundary +// family: unary-vector +// target_ops: pto.vln +// scenarios: core-f32, domain-positive, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vln %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/launch.cpp new file mode 100644 index 000000000..6aeeded6c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vln-domain-boundary +// family: unary-vector +// target_ops: pto.vln +// scenarios: core-f32, domain-positive, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/main.cpp b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/main.cpp new file mode 100644 index 000000000..ab31f79d8 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vln-domain-boundary +// family: unary-vector +// target_ops: pto.vln +// scenarios: core-f32, domain-positive, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/stub.cpp new file mode 100644 index 000000000..1f1f413a4 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vln-domain-boundary +// family: unary-vector +// target_ops: pto.vln +// scenarios: core-f32, domain-positive, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln/compare.py b/test/vpto/cases/micro-op/unary-vector/vln/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vln/golden.py b/test/vpto/cases/micro-op/unary-vector/vln/golden.py new file mode 100644 index 000000000..b7a53856f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = np.exp(rng.uniform(-4.0, 2.0, size=(ROWS, COLS)).astype(np.float32)) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.log(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vln validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto new file mode 100644 index 000000000..7746c3655 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vexp_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vln %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vln/launch.cpp new file mode 100644 index 000000000..b6d8cdbf0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream) { + vexp_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln/main.cpp b/test/vpto/cases/micro-op/unary-vector/vln/main.cpp new file mode 100644 index 000000000..f864622ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vln/stub.cpp new file mode 100644 index 000000000..0c8ab3a00 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/compare.py b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/compare.py new file mode 100755 index 000000000..1030c959f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vneg-f32-exceptional +# family: unary-vector +# target_ops: pto.vneg +# scenarios: core-f32, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/golden.py b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/golden.py new file mode 100755 index 000000000..0f394e5ed --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vneg-f32-exceptional +# family: unary-vector +# target_ops: pto.vneg +# scenarios: core-f32, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + flat = v1.reshape(-1) + flat[:8] = np.array( + [ + np.float32(0.0), + np.float32(-0.0), + np.float32(np.inf), + np.float32(-np.inf), + np.float32(np.nan), + np.float32(1.0), + np.float32(-1.0), + np.float32(3.5), + ], + dtype=np.float32, + ) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.negative(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vneg exceptional validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto new file mode 100644 index 000000000..93d2893e1 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg-f32-exceptional +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vneg %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/launch.cpp new file mode 100644 index 000000000..2614d8040 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg-f32-exceptional +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/main.cpp new file mode 100644 index 000000000..de8fba973 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg-f32-exceptional +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/stub.cpp new file mode 100644 index 000000000..a4d613cc0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg-f32-exceptional +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/compare.py b/test/vpto/cases/micro-op/unary-vector/vneg/compare.py new file mode 100755 index 000000000..0ce4e18b6 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vneg +# family: unary-vector +# target_ops: pto.vneg +# scenarios: core-f32, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/golden.py b/test/vpto/cases/micro-op/unary-vector/vneg/golden.py new file mode 100755 index 000000000..a7e86608a --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vneg +# family: unary-vector +# target_ops: pto.vneg +# scenarios: core-f32, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.negative(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vneg validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto new file mode 100644 index 000000000..68197d779 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vneg %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vneg/launch.cpp new file mode 100644 index 000000000..65504cb9e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/main.cpp b/test/vpto/cases/micro-op/unary-vector/vneg/main.cpp new file mode 100644 index 000000000..134aa5b2c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vneg/stub.cpp new file mode 100644 index 000000000..0417e6ccc --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/compare.py b/test/vpto/cases/micro-op/unary-vector/vnot/compare.py new file mode 100755 index 000000000..cdecd8075 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vnot +# family: unary-vector +# target_ops: pto.vnot +# scenarios: core-i16-signed, full-mask +# NOTE: lane-wise bitwise inversion on signed i16 source lanes. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.int16, 0.0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/golden.py b/test/vpto/cases/micro-op/unary-vector/vnot/golden.py new file mode 100755 index 000000000..c0d048c27 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vnot +# family: unary-vector +# target_ops: pto.vnot +# scenarios: core-i16-signed, full-mask +# NOTE: lane-wise bitwise inversion on signed i16 source lanes. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers( + low=np.iinfo(np.int16).min, + high=np.iinfo(np.int16).max + 1, + size=(ROWS, COLS), + dtype=np.int16, + ) + v2 = np.zeros((ROWS, COLS), dtype=np.int16) + golden_v2 = np.bitwise_not(v1).astype(np.int16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vnot validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto new file mode 100644 index 000000000..d33238b99 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vnot +// family: unary-vector +// target_ops: pto.vnot +// scenarios: core-i16-signed, full-mask +// NOTE: lane-wise bitwise inversion on signed i16 source lanes. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vnot_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %out = pto.vnot %vec, %mask : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vnot/launch.cpp new file mode 100644 index 000000000..c2b22293e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vnot +// family: unary-vector +// target_ops: pto.vnot +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vnot_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVnot_kernel_2d(float *v1, float *v2, void *stream) { + vnot_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/main.cpp b/test/vpto/cases/micro-op/unary-vector/vnot/main.cpp new file mode 100644 index 000000000..3b97a8523 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vnot +// family: unary-vector +// target_ops: pto.vnot +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVnot_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVnot_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vnot/stub.cpp new file mode 100644 index 000000000..f5a8abcdb --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vnot +// family: unary-vector +// target_ops: pto.vnot +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vnot_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/compare.py b/test/vpto/cases/micro-op/unary-vector/vrelu/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/golden.py b/test/vpto/cases/micro-op/unary-vector/vrelu/golden.py new file mode 100644 index 000000000..d481d99d4 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.maximum(v1, np.float32(0.0)).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vrelu validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto new file mode 100644 index 000000000..0fb482427 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vexp_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vrelu %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vrelu/launch.cpp new file mode 100644 index 000000000..b6d8cdbf0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream) { + vexp_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/main.cpp b/test/vpto/cases/micro-op/unary-vector/vrelu/main.cpp new file mode 100644 index 000000000..f864622ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vrelu/stub.cpp new file mode 100644 index 000000000..0c8ab3a00 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/compare.py b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/compare.py new file mode 100755 index 000000000..71d8b50c2 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vsqrt-domain-boundary +# family: unary-vector +# target_ops: pto.vsqrt +# scenarios: core-f32, domain-nonnegative, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/golden.py b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/golden.py new file mode 100755 index 000000000..9607fbcc5 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vsqrt-domain-boundary +# family: unary-vector +# target_ops: pto.vsqrt +# scenarios: core-f32, domain-nonnegative, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(0.0, 16.0, size=(ROWS, COLS)).astype(np.float32) + flat = v1.reshape(-1) + flat[:8] = np.array( + [ + np.float32(0.0), + np.nextafter(np.float32(0.0), np.float32(1.0), dtype=np.float32), + np.float32(1.0), + np.float32(4.0), + np.float32(9.0), + np.float32(16.0), + np.float32(1024.0), + np.float32(np.finfo(np.float32).max), + ], + dtype=np.float32, + ) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.sqrt(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsqrt domain-boundary validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto new file mode 100644 index 000000000..f2c48b3d6 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vsqrt-domain-boundary +// family: unary-vector +// target_ops: pto.vsqrt +// scenarios: core-f32, domain-nonnegative, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vsqrt %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/launch.cpp new file mode 100644 index 000000000..1070e35b1 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vsqrt-domain-boundary +// family: unary-vector +// target_ops: pto.vsqrt +// scenarios: core-f32, domain-nonnegative, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/main.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/main.cpp new file mode 100644 index 000000000..95bed1286 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vsqrt-domain-boundary +// family: unary-vector +// target_ops: pto.vsqrt +// scenarios: core-f32, domain-nonnegative, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/stub.cpp new file mode 100644 index 000000000..49e13bfda --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vsqrt-domain-boundary +// family: unary-vector +// target_ops: pto.vsqrt +// scenarios: core-f32, domain-nonnegative, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/compare.py b/test/vpto/cases/micro-op/unary-vector/vsqrt/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/golden.py b/test/vpto/cases/micro-op/unary-vector/vsqrt/golden.py new file mode 100644 index 000000000..a5739a6b3 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + base = rng.uniform(0.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + v1 = np.square(base).astype(np.float32, copy=False) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.sqrt(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsqrt validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto new file mode 100644 index 000000000..3b1de1acd --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vexp_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vsqrt %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt/launch.cpp new file mode 100644 index 000000000..b6d8cdbf0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream) { + vexp_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/main.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt/main.cpp new file mode 100644 index 000000000..f864622ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/stub.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt/stub.cpp new file mode 100644 index 000000000..0c8ab3a00 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/compare.py b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/compare.py new file mode 100644 index 000000000..b5dd9902e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vaddcs-carry-boundary +# family: vec-scalar +# target_ops: pto.vaddcs +# scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_carry(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_carry() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/golden.py b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/golden.py new file mode 100644 index 000000000..ddf74542c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/golden.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vaddcs-carry-boundary +# family: vec-scalar +# target_ops: pto.vaddcs +# scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +LHS_PATTERN = np.array( + [0x00000000, 0x00000001, 0xFFFFFFFE, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 0xAAAAAAAA, 0x55555555], + dtype=np.uint32, +) +RHS_PATTERN = np.array( + [0x00000000, 0xFFFFFFFF, 0x00000001, 0x00000000, 0x80000000, 0x7FFFFFFF, 0x55555555, 0xAAAAAAAA], + dtype=np.uint32, +) + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + repeats = LANES // LHS_PATTERN.size + lhs = np.tile(LHS_PATTERN, repeats) + rhs = np.tile(RHS_PATTERN, repeats) + total = lhs.astype(np.uint64) + rhs.astype(np.uint64) + np.uint64(1) + result = (total & np.uint64(0xFFFFFFFF)).astype(np.uint32) + carry = (total >> np.uint64(32)) != 0 + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(carry).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=19) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto new file mode 100644 index 000000000..b6c9aea8f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs-carry-boundary +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vaddcs_carry_boundary_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %carry_in = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %sum, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %sum, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %carry, %ub_carry[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_carry, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/launch.cpp new file mode 100644 index 000000000..209c04c6a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs-carry-boundary +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vaddcs_carry_boundary_kernel( + __gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVaddcsCarryBoundaryKernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream) { + vaddcs_carry_boundary_kernel<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint32_t *)v2, (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/main.cpp new file mode 100644 index 000000000..6addb079d --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/main.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs-carry-boundary +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaddcsCarryBoundaryKernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaddcsCarryBoundaryKernel(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/stub.cpp new file mode 100644 index 000000000..0dc3f636f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs-carry-boundary +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vaddcs_carry_boundary_kernel( + __gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/compare.py b/test/vpto/cases/micro-op/vec-scalar/vaddcs/compare.py new file mode 100644 index 000000000..130b06567 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vaddcs +# family: vec-scalar +# target_ops: pto.vaddcs +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_carry(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_carry() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/golden.py b/test/vpto/cases/micro-op/vec-scalar/vaddcs/golden.py new file mode 100644 index 000000000..35f57f535 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vaddcs +# family: vec-scalar +# target_ops: pto.vaddcs +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + rhs = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + total = lhs.astype(np.uint64) + rhs.astype(np.uint64) + np.uint64(1) + result = (total & np.uint64(0xFFFFFFFF)).astype(np.uint32) + carry = (total >> np.uint64(32)) != 0 + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(carry).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto new file mode 100644 index 000000000..cd73bb3c6 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vaddcs_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %carry_in = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %sum, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %sum, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %carry, %ub_carry[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_carry, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs/launch.cpp new file mode 100644 index 000000000..d25fa38e7 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vaddcs_kernel(__gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVaddcs_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, uint8_t *v4, + void *stream) { + vaddcs_kernel<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs/main.cpp new file mode 100644 index 000000000..ea0899688 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/main.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaddcs_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, uint8_t *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaddcs_kernel(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs/stub.cpp new file mode 100644 index 000000000..67d47238c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vaddcs_kernel(__gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/compare.py new file mode 100644 index 000000000..896c992b2 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16, 0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/golden.py new file mode 100644 index 000000000..0efc4ec38 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALE = np.float32(1.5) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + wide = values.astype(np.float32, copy=False).view(np.uint32) + rounding = np.uint32(0x7FFF) + ((wide >> 16) & np.uint32(1)) + return ((wide + rounding) >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray: + return (bits.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v1 = f32_to_bf16_bits(v1_f32) + v2 = np.zeros(ELEMS, dtype=np.uint16) + scalar_bits = f32_to_bf16_bits(np.array([SCALE], dtype=np.float32))[0] + scalar = bf16_bits_to_f32(np.array([scalar_bits], dtype=np.uint16))[0] + golden_v2 = f32_to_bf16_bits(bf16_bits_to_f32(v1) + scalar) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto new file mode 100644 index 000000000..7c1ccbf2f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vadds_bf16_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %cst = arith.constant 1.500000e+00 : bf16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<128xbf16>, bf16, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/launch.cpp new file mode 100644 index 000000000..734c5e95a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2); + +void LaunchVadds_bf16_kernel(uint16_t *v1, uint16_t *v2, void *stream) { + vadds_bf16_kernel<<<1, nullptr, stream>>>((__gm__ bfloat16_t *)v1, + (__gm__ bfloat16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/main.cpp new file mode 100644 index 000000000..ce2c7d7bf --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/main.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_bf16_kernel(uint16_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_bf16_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/stub.cpp new file mode 100644 index 000000000..635388bd2 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadds_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/compare.py new file mode 100644 index 000000000..1b47ca433 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 5e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/golden.py new file mode 100644 index 000000000..019cf6980 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALE = np.float16(1.5) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float16) + v2 = np.zeros(ELEMS, dtype=np.float16) + golden_v2 = (v1.astype(np.float32) + np.float32(SCALE)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto new file mode 100644 index 000000000..219522026 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vadds_f16_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %cst = arith.constant 1.500000e+00 : f16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<128xf16>, f16, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/launch.cpp new file mode 100644 index 000000000..e964f1539 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_f16_kernel(__gm__ half *v1, + __gm__ half *v2); + +void LaunchVadds_f16_kernel(uint16_t *v1, uint16_t *v2, void *stream) { + vadds_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/main.cpp new file mode 100644 index 000000000..0e9c0076c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/main.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_f16_kernel(uint16_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_f16_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/stub.cpp new file mode 100644 index 000000000..c2e2d8622 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadds_f16_kernel(__gm__ half *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/golden.py new file mode 100644 index 000000000..c101038fb --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(0.5) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 + SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto new file mode 100644 index 000000000..8309e4a04 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vadds_f32_exceptional_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 5.000000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/launch.cpp new file mode 100644 index 000000000..915da93e1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream) { + vadds_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/main.cpp new file mode 100644 index 000000000..9ba910e1c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/stub.cpp new file mode 100644 index 000000000..1acd5f255 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadds_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/compare.py new file mode 100644 index 000000000..3402d0a6d --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-signed-overflow +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/golden.py new file mode 100644 index 000000000..fcf4c5afb --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-signed-overflow +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALAR = np.int16(1024) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-16000, 16000, size=ELEMS, dtype=np.int16) + v1[:12] = np.array( + [ + 32767, + 32766, + 32760, + 32000, + 0, + 1, + -1, + -32768, + -32767, + -32000, + 12345, + -12345, + ], + dtype=np.int16, + ) + v2 = np.zeros(ELEMS, dtype=np.int16) + golden_v2 = (v1.astype(np.int32) + int(SCALAR)).astype(np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto new file mode 100644 index 000000000..52b50c48e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-signed-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadds_i16_signed_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 1024 : i16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %sum = pto.vadds %vec, %scalar, %mask : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/launch.cpp new file mode 100644 index 000000000..4ac3e8a59 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-signed-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vadds_i16_signed_overflow_kernel(__gm__ int16_t *v1, __gm__ int16_t *v2); + +void LaunchVadds_i16_signed_overflow_kernel(int16_t *v1, int16_t *v2, + void *stream) { + vadds_i16_signed_overflow_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/main.cpp new file mode 100644 index 000000000..b8c63d7a8 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_i16_signed_overflow_kernel(int16_t *v1, int16_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_i16_signed_overflow_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/stub.cpp new file mode 100644 index 000000000..48145aba1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-signed-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vadds_i16_signed_overflow_kernel(__gm__ int16_t *v1, __gm__ int16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/compare.py new file mode 100644 index 000000000..421ac84f5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-signed +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-signed, full-mask, scalar-operand + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/golden.py new file mode 100644 index 000000000..23d6855f0 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-signed +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-signed, full-mask, scalar-operand + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALAR = np.int16(37) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-12000, 12000, size=ELEMS, dtype=np.int16) + v2 = np.zeros(ELEMS, dtype=np.int16) + golden_v2 = (v1.astype(np.int32) + int(SCALAR)).astype(np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto new file mode 100644 index 000000000..fe7c7c1c2 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-signed +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-signed, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadds_i16_signed_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %scalar = arith.constant 37 : i16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %sum = pto.vadds %vec, %scalar, %mask : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/launch.cpp new file mode 100644 index 000000000..fd275ede0 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2); + +void LaunchVadds_i16_signed_kernel(int16_t *v1, int16_t *v2, void *stream) { + vadds_i16_signed_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/main.cpp new file mode 100644 index 000000000..3cbb6afba --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_i16_signed_kernel(int16_t *v1, int16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_i16_signed_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/stub.cpp new file mode 100644 index 000000000..bd934808a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadds_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/compare.py new file mode 100755 index 000000000..a1b852540 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/golden.py new file mode 100755 index 000000000..813ec0287 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALAR = np.uint16(4096) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 65535, size=ELEMS, dtype=np.uint16) + v1[:12] = np.array( + [ + 65535, + 65534, + 65500, + 65000, + 4096, + 2048, + 1024, + 1, + 0, + 32768, + 12345, + 54321, + ], + dtype=np.uint16, + ) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = (v1.astype(np.uint32) + int(SCALAR)).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto new file mode 100644 index 000000000..ad099717a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadds_i16_unsigned_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 4096 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %sum = pto.vadds %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/launch.cpp new file mode 100644 index 000000000..003c7556b --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vadds_i16_unsigned_overflow_kernel(__gm__ uint16_t *v1, __gm__ uint16_t *v2); + +void LaunchVadds_i16_unsigned_overflow_kernel(uint16_t *v1, uint16_t *v2, + void *stream) { + vadds_i16_unsigned_overflow_kernel<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)v1, (__gm__ uint16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/main.cpp new file mode 100644 index 000000000..8f8b2ebe6 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_i16_unsigned_overflow_kernel(uint16_t *v1, uint16_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_i16_unsigned_overflow_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/stub.cpp new file mode 100644 index 000000000..efcb47ac9 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vadds_i16_unsigned_overflow_kernel(__gm__ uint16_t *v1, __gm__ uint16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/compare.py new file mode 100755 index 000000000..437f48ad7 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-unsigned +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/golden.py new file mode 100755 index 000000000..df317a729 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-unsigned +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALAR = np.uint16(37) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 60000, size=ELEMS, dtype=np.uint16) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = (v1.astype(np.uint32) + int(SCALAR)).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto new file mode 100644 index 000000000..b3c4391ff --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vadds_i16_unsigned_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 37 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %sum = pto.vadds %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/launch.cpp new file mode 100644 index 000000000..61ff83045 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2); + +void LaunchVadds_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, void *stream) { + vadds_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/main.cpp new file mode 100644 index 000000000..de7c50e82 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/main.cpp @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_i16_unsigned_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/stub.cpp new file mode 100644 index 000000000..8ca1fd1bb --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadds_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/golden.py new file mode 100644 index 000000000..2f06c22fa --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] + SCALE + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto new file mode 100644 index 000000000..b66075220 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vadds_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/launch.cpp new file mode 100644 index 000000000..b4cd46470 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vadds_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/stub.cpp new file mode 100644 index 000000000..67e6846a1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vadds_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds/golden.py new file mode 100644 index 000000000..273a8d29f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 + SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto new file mode 100644 index 000000000..f8fb5d002 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds/launch.cpp new file mode 100644 index 000000000..44c07c249 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds/main.cpp new file mode 100644 index 000000000..fcb42331f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds/stub.cpp new file mode 100644 index 000000000..dea70f9b6 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/golden.py new file mode 100644 index 000000000..0b08cbcab --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.maximum( + v1.reshape(-1)[:LOGICAL_ELEMS], SCALE + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto new file mode 100644 index 000000000..5f7cf6e77 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vmaxs_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %maxv = pto.vmaxs %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %maxv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/launch.cpp new file mode 100644 index 000000000..d5ae524ce --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmaxs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vmaxs_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/stub.cpp new file mode 100644 index 000000000..8a63f63ac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmaxs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmaxs/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmaxs/golden.py new file mode 100644 index 000000000..d4b379f33 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.maximum(v1, SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto new file mode 100644 index 000000000..217298fde --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vec_max_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %maxv = pto.vmaxs %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %maxv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs/launch.cpp new file mode 100644 index 000000000..a08848672 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_max_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_max_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_max_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs/main.cpp new file mode 100644 index 000000000..47ce3f58b --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_max_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_max_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs/stub.cpp new file mode 100644 index 000000000..e691f926f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_max_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/golden.py new file mode 100644 index 000000000..e4e63235a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.minimum( + v1.reshape(-1)[:LOGICAL_ELEMS], SCALE + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto new file mode 100644 index 000000000..9ee49d201 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vmins_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %minv = pto.vmins %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %minv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/launch.cpp new file mode 100644 index 000000000..2774c3f46 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmins_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vmins_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/stub.cpp new file mode 100644 index 000000000..7f4b72636 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmins_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmins/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmins/golden.py new file mode 100644 index 000000000..7caa057f7 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.minimum(v1, SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto new file mode 100644 index 000000000..de6e39715 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vec_min_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %minv = pto.vmins %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %minv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins/launch.cpp new file mode 100644 index 000000000..23603d652 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_min_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_min_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_min_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins/main.cpp new file mode 100644 index 000000000..888a58876 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_min_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_min_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins/stub.cpp new file mode 100644 index 000000000..437a7aa8e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_min_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/golden.py new file mode 100644 index 000000000..fdfd56fd5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] * SCALE + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto new file mode 100644 index 000000000..11ff8c4fb --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vmuls_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %prod = pto.vmuls %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %prod, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/launch.cpp new file mode 100644 index 000000000..65f00d71a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmuls_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vmuls_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/stub.cpp new file mode 100644 index 000000000..24bf41b71 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmuls_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmuls/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmuls/golden.py new file mode 100644 index 000000000..5233be0ed --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 * SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto new file mode 100644 index 000000000..49442c61c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vec_mul_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %prod = pto.vmuls %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %prod, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls/launch.cpp new file mode 100644 index 000000000..0146a24b5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_mul_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_mul_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_mul_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls/main.cpp new file mode 100644 index 000000000..e99b6c097 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_mul_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_mul_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls/stub.cpp new file mode 100644 index 000000000..dc817f0ef --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vec_mul_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/compare.py b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/compare.py new file mode 100644 index 000000000..f07dd1f4c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshls-shift-boundary +# family: vec-scalar +# target_ops: pto.vshls +# scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/golden.py b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/golden.py new file mode 100644 index 000000000..e0952743a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshls-shift-boundary +# family: vec-scalar +# target_ops: pto.vshls +# scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SHIFT = 15 +PATTERN = np.array( + [0x0000, 0x0001, 0x0002, 0x0003, 0x7FFF, 0x8000, 0x8001, 0xFFFF], + dtype=np.uint16, +) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + repeats = ELEMS // PATTERN.size + v1 = np.tile(PATTERN, repeats) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = np.left_shift(v1.astype(np.uint32), SHIFT).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=19) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto new file mode 100644 index 000000000..0c7ee44f2 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls-shift-boundary +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshls_shift_boundary_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 15 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %shifted = pto.vshls %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %shifted, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/launch.cpp new file mode 100644 index 000000000..ee7141d19 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls-shift-boundary +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshls_shift_boundary_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVshls_shift_boundary_kernel(float *v1, float *v2, void *stream) { + vshls_shift_boundary_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/main.cpp new file mode 100644 index 000000000..3b51b0c33 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls-shift-boundary +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshls_shift_boundary_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshls_shift_boundary_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/stub.cpp new file mode 100644 index 000000000..596890454 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls-shift-boundary +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshls_shift_boundary_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/compare.py b/test/vpto/cases/micro-op/vec-scalar/vshls/compare.py new file mode 100644 index 000000000..d5e9ad835 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshls +# family: vec-scalar +# target_ops: pto.vshls +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/golden.py b/test/vpto/cases/micro-op/vec-scalar/vshls/golden.py new file mode 100644 index 000000000..5d4fe1763 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshls +# family: vec-scalar +# target_ops: pto.vshls +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SHIFT = 3 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 1 << 12, size=ELEMS, dtype=np.uint16) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = np.left_shift(v1.astype(np.uint32), SHIFT).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto new file mode 100644 index 000000000..5700b094f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshls_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 3 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %shifted = pto.vshls %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %shifted, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls/launch.cpp new file mode 100644 index 000000000..ed048246e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshls_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVshls_kernel(float *v1, float *v2, void *stream) { + vshls_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls/main.cpp new file mode 100644 index 000000000..f5cec4212 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshls_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshls_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls/stub.cpp new file mode 100644 index 000000000..c89341a4f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshls_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/compare.py b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/compare.py new file mode 100644 index 000000000..65d6e8920 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshrs-shift-boundary +# family: vec-scalar +# target_ops: pto.vshrs +# scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/golden.py b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/golden.py new file mode 100644 index 000000000..c1f36dae0 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshrs-shift-boundary +# family: vec-scalar +# target_ops: pto.vshrs +# scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SHIFT = 15 +PATTERN = np.array( + [0x0000, 0x0001, 0x0002, 0x0003, 0x7FFF, 0x8000, 0x8001, 0xFFFF], + dtype=np.uint16, +) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + repeats = ELEMS // PATTERN.size + v1 = np.tile(PATTERN, repeats) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = np.right_shift(v1, SHIFT) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=19) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto new file mode 100644 index 000000000..8c00efd4c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs-shift-boundary +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshrs_shift_boundary_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 15 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %shifted = pto.vshrs %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %shifted, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/launch.cpp new file mode 100644 index 000000000..b108e4ba5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs-shift-boundary +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshrs_shift_boundary_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVshrs_shift_boundary_kernel(float *v1, float *v2, void *stream) { + vshrs_shift_boundary_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/main.cpp new file mode 100644 index 000000000..4f5611378 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs-shift-boundary +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshrs_shift_boundary_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshrs_shift_boundary_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/stub.cpp new file mode 100644 index 000000000..adbbba298 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs-shift-boundary +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshrs_shift_boundary_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/compare.py b/test/vpto/cases/micro-op/vec-scalar/vshrs/compare.py new file mode 100644 index 000000000..3c2384aff --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshrs +# family: vec-scalar +# target_ops: pto.vshrs +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/golden.py b/test/vpto/cases/micro-op/vec-scalar/vshrs/golden.py new file mode 100644 index 000000000..82b2a9e07 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshrs +# family: vec-scalar +# target_ops: pto.vshrs +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SHIFT = 3 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, np.iinfo(np.uint16).max + 1, size=ELEMS, dtype=np.uint16) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = np.right_shift(v1, SHIFT) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto new file mode 100644 index 000000000..ca745b48f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vshrs_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 3 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %shifted = pto.vshrs %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %shifted, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs/launch.cpp new file mode 100644 index 000000000..ebf9902d1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshrs_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVshrs_kernel(float *v1, float *v2, void *stream) { + vshrs_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs/main.cpp new file mode 100644 index 000000000..81790da59 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshrs_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshrs_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs/stub.cpp new file mode 100644 index 000000000..d8a2720c0 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vshrs_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/compare.py b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/compare.py new file mode 100644 index 000000000..87847b721 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vsubcs-borrow-boundary +# family: vec-scalar +# target_ops: pto.vsubcs +# scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_borrow(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_borrow() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/golden.py b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/golden.py new file mode 100644 index 000000000..d20ebafc3 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/golden.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vsubcs-borrow-boundary +# family: vec-scalar +# target_ops: pto.vsubcs +# scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +LHS_PATTERN = np.array( + [0x00000000, 0x00000001, 0x00000000, 0xFFFFFFFF, 0x80000000, 0x7FFFFFFF, 0xAAAAAAAA, 0x55555555], + dtype=np.uint32, +) +RHS_PATTERN = np.array( + [0x00000000, 0x00000000, 0x00000001, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 0x55555555, 0xAAAAAAAA], + dtype=np.uint32, +) + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + repeats = LANES // LHS_PATTERN.size + lhs = np.tile(LHS_PATTERN, repeats) + rhs = np.tile(RHS_PATTERN, repeats) + lhs64 = lhs.astype(np.uint64) + rhs64 = rhs.astype(np.uint64) + no_borrow = lhs64 >= rhs64 + result = ((lhs64 - rhs64) & np.uint64(0xFFFFFFFF)).astype(np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(no_borrow).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=19) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto new file mode 100644 index 000000000..830a471c4 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs-borrow-boundary +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsubcs_borrow_boundary_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %borrow_in = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %diff, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %diff, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %borrow, %ub_borrow[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_borrow, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/launch.cpp new file mode 100644 index 000000000..a1cb56e2e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs-borrow-boundary +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsubcs_borrow_boundary_kernel( + __gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVsubcsBorrowBoundaryKernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream) { + vsubcs_borrow_boundary_kernel<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint32_t *)v2, (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/main.cpp new file mode 100644 index 000000000..169bc4512 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/main.cpp @@ -0,0 +1,116 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs-borrow-boundary +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsubcsBorrowBoundaryKernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsubcsBorrowBoundaryKernel(v1Device, v2Device, v3Device, v4Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/stub.cpp new file mode 100644 index 000000000..d119e1ac5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs-borrow-boundary +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vsubcs_borrow_boundary_kernel( + __gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/compare.py b/test/vpto/cases/micro-op/vec-scalar/vsubcs/compare.py new file mode 100644 index 000000000..047f6c245 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vsubcs +# family: vec-scalar +# target_ops: pto.vsubcs +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_borrow(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_borrow() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/golden.py b/test/vpto/cases/micro-op/vec-scalar/vsubcs/golden.py new file mode 100644 index 000000000..d9c1f2e8b --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vsubcs +# family: vec-scalar +# target_ops: pto.vsubcs +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + rhs = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + lhs64 = lhs.astype(np.uint64) + rhs64 = rhs.astype(np.uint64) + no_borrow = lhs64 >= rhs64 + result = ((lhs64 - rhs64) & np.uint64(0xFFFFFFFF)).astype(np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(no_borrow).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto new file mode 100644 index 000000000..8f0476a76 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsubcs_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %borrow_in = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %diff, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %diff, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %borrow, %ub_borrow[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_borrow, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs/launch.cpp new file mode 100644 index 000000000..534b84ab1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vsubcs_kernel(__gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVsubcs_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, uint8_t *v4, + void *stream) { + vsubcs_kernel<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs/main.cpp new file mode 100644 index 000000000..5bcad0fcc --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/main.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsubcs_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, uint8_t *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsubcs_kernel(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/stub.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs/stub.cpp new file mode 100644 index 000000000..e7a3f169c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vsubcs_kernel(__gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/compare.py new file mode 100755 index 000000000..bc0d8fc41 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldas-vldus-state-chain +# family: vector-load-store +# target_ops: pto.vldas, pto.vldus +# scenarios: core-f32, full-mask, unaligned, stream-state, state-update +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 128 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/golden.py new file mode 100755 index 000000000..926db9342 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldas-vldus-state-chain +# family: vector-load-store +# target_ops: pto.vldas, pto.vldus +# scenarios: core-f32, full-mask, unaligned, repeated-no-post +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + flat_out[:LANES] = flat_in[1 : 1 + LANES] + flat_out[LANES : 2 * LANES] = flat_in[65 : 65 + LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldas-vldus state-chain validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto new file mode 100644 index 000000000..580780047 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus-state-chain +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, repeated-no-post +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// Validate repeated no-post unaligned loads. Each `pto.vldus` is paired with +// its own `pto.vldas` and uses an explicit unaligned source pointer; the second +// load does not depend on state returned from the first one. + +module attributes {pto.target_arch = "a5"} { + func.func @vldas_vldus_state_chain_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c64 step %c64 { + %src0 = pto.addptr %ub_in, %c1 : !pto.ptr -> !pto.ptr + %src1 = pto.addptr %src0, %c64 : !pto.ptr -> !pto.ptr + %align0 = pto.vldas %src0 : !pto.ptr -> !pto.align + %mask, %next_remaining = pto.plt_b32 %c1024_i32 : i32 -> !pto.mask, i32 + %out0, %align1 = pto.vldus %src0, %align0 + : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + %align2 = pto.vldas %src1 : !pto.ptr -> !pto.align + %out1, %align3 = pto.vldus %src1, %align2 + : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + pto.vsts %out0, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out1, %ub_out[%c64], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/launch.cpp new file mode 100644 index 000000000..6a077aa65 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus-state-chain +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vldas_vldus_state_chain_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVldasVldusStateChain_kernel_2d(float *v1, float *v2, void *stream) { + vldas_vldus_state_chain_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/main.cpp new file mode 100644 index 000000000..95044a646 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus-state-chain +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVldasVldusStateChain_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVldasVldusStateChain_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/stub.cpp new file mode 100644 index 000000000..21dc9b3a7 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus-state-chain +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vldas_vldus_state_chain_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/compare.py new file mode 100755 index 000000000..f4bfe43cc --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldas-vldus +# family: vector-load-store +# target_ops: pto.vldas, pto.vldus +# scenarios: core-f32, full-mask, unaligned, stream-state +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 64 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/golden.py new file mode 100755 index 000000000..77d537906 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldas-vldus +# family: vector-load-store +# target_ops: pto.vldas, pto.vldus +# scenarios: core-f32, full-mask, unaligned, stream-state +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + flat_out[:LANES] = flat_in[1 : 1 + LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldas-vldus validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto new file mode 100644 index 000000000..0eaf940b7 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto @@ -0,0 +1,69 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c64 step %c64 { + %src0 = pto.addptr %ub_in, %c1 : !pto.ptr -> !pto.ptr + %align0 = pto.vldas %src0 : !pto.ptr -> !pto.align + %mask, %next_remaining = pto.plt_b32 %c1024_i32 : i32 -> !pto.mask, i32 + %out, %next_align = pto.vldus %src0, %align0 + : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/launch.cpp new file mode 100644 index 000000000..b25715db0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/main.cpp new file mode 100644 index 000000000..7bf75309b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/stub.cpp new file mode 100644 index 000000000..5e3ac7966 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/compare.py new file mode 100644 index 000000000..4c19eb038 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/golden.py new file mode 100644 index 000000000..8cc8dbe42 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b16-f32 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, full-mask, aligned, dist-brc-b16, width-agnostic-dist + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + + src_bytes = v1.view(np.uint8) + golden_bytes = np.zeros_like(src_bytes) + chunk_bytes = LANES * 4 + for offset in range(0, src_bytes.size, chunk_bytes): + pattern = src_bytes[offset : offset + 2] + tiled = np.tile(pattern, chunk_bytes // 2) + golden_bytes[offset : offset + chunk_bytes] = tiled + golden_v2 = golden_bytes.view(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds BRC_B16 on f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto new file mode 100644 index 000000000..bd6c2e8d8 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto @@ -0,0 +1,45 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16-f32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b16, width-agnostic-dist +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vlds_brc_b16_f32_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_B16"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/launch.cpp new file mode 100644 index 000000000..530496dba --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vlds_brc_b16_f32_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVlds_brc_b16_f32_kernel(float *v1, float *v2, void *stream) { + vlds_brc_b16_f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} + diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/main.cpp new file mode 100644 index 000000000..661e47152 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVlds_brc_b16_f32_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVlds_brc_b16_f32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/stub.cpp new file mode 100644 index 000000000..39b43ab0f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/stub.cpp @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vlds_brc_b16_f32_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/compare.py new file mode 100755 index 000000000..51c671b74 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f16, full-mask, aligned, dist-brc-b16 +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1024 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float16, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/golden.py new file mode 100755 index 000000000..1f593f79d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f16, full-mask, aligned, dist-brc-b16 +# NOTE: BRC on b16 broadcasts the first f16 element of each 128-lane chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 2048 +ACTIVE_ELEMS = 1024 +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float16) + v2 = np.zeros((ELEMENTS,), dtype=np.float16) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.float16) + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset : offset + LANES] = v1[offset] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds b16 broadcast validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto new file mode 100644 index 000000000..d7a79b2b0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-brc-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `BRC_B16` load on `b16`. +// The case keeps the structure minimal: +// 1. DMA one input tile into UB +// 2. issue `pto.vlds` with `dist = "BRC_B16"` inside `pto.vecscope` +// 3. store the resulting vector back through `pto.vsts` + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_B16"} : !pto.ptr -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/launch.cpp new file mode 100644 index 000000000..1d8ac9f5d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-brc-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/main.cpp new file mode 100644 index 000000000..cbec16893 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-brc-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/stub.cpp new file mode 100644 index 000000000..ad9fd1973 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-brc-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/compare.py new file mode 100755 index 000000000..16f027ac0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b32 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, full-mask, aligned, dist-brc-b32 +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/golden.py new file mode 100755 index 000000000..541e3d770 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b32 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, full-mask, aligned, dist-brc-b32 +# NOTE: BRC on b32 broadcasts the first f32 element of each 64-lane chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2 = np.empty((ELEMENTS,), dtype=np.float32) + for offset in range(0, ELEMENTS, LANES): + golden_v2[offset : offset + LANES] = v1[offset] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds b32 broadcast validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto new file mode 100644 index 000000000..1bedcf05f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b32 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/launch.cpp new file mode 100644 index 000000000..bc599a4ac --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b32 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/main.cpp new file mode 100644 index 000000000..567ff23cf --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b32 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/stub.cpp new file mode 100644 index 000000000..fcec12a1f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b32 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/compare.py new file mode 100644 index 000000000..4c19eb038 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/golden.py new file mode 100644 index 000000000..be66bd820 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b8-f32 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, full-mask, aligned, dist-brc-b8, width-agnostic-dist + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + + src_bytes = v1.view(np.uint8) + golden_bytes = np.zeros_like(src_bytes) + chunk_bytes = LANES * 4 + for offset in range(0, src_bytes.size, chunk_bytes): + pattern = src_bytes[offset] + golden_bytes[offset : offset + chunk_bytes] = pattern + golden_v2 = golden_bytes.view(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds BRC_B8 on f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto new file mode 100644 index 000000000..3e5f101db --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto @@ -0,0 +1,45 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b8-f32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vlds_brc_b8_f32_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_B8"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/launch.cpp new file mode 100644 index 000000000..4628e9dc1 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vlds_brc_b8_f32_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVlds_brc_b8_f32_kernel(float *v1, float *v2, void *stream) { + vlds_brc_b8_f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} + diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/main.cpp new file mode 100644 index 000000000..bf2d99510 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVlds_brc_b8_f32_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVlds_brc_b8_f32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/stub.cpp new file mode 100644 index 000000000..f1705c6c4 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/stub.cpp @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vlds_brc_b8_f32_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/compare.py new file mode 100755 index 000000000..db7576dcf --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-blk +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-u8, full-mask, aligned, dist-brc-blk +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1024 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.uint8, 0.0, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/golden.py new file mode 100755 index 000000000..4b9610b1e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-blk +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-u8, full-mask, aligned, dist-brc-blk +# NOTE: BRC_BLK repeats the first 32-byte block across each 256-byte vector chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 4096 +ACTIVE_ELEMS = 1024 +LANES = 256 +BLOCK_BYTES = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 256, size=(ELEMENTS,), dtype=np.uint8) + v2 = np.zeros((ELEMENTS,), dtype=np.uint8) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.uint8) + repeats = LANES // BLOCK_BYTES + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset : offset + LANES] = np.tile(v1[offset : offset + BLOCK_BYTES], repeats) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds block-broadcast validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto new file mode 100644 index 000000000..13a1330d9 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-blk +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-u8, full-mask, aligned, dist-brc-blk +// ----------------------------------------------------------------------------- +// Validate one representative `BRC_BLK` load. + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_BLK"} : !pto.ptr -> !pto.vreg<256xui8> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/launch.cpp new file mode 100644 index 000000000..0299a6136 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-blk +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-u8, full-mask, aligned, dist-brc-blk +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/main.cpp new file mode 100644 index 000000000..7e62df365 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-blk +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-u8, full-mask, aligned, dist-brc-blk +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/stub.cpp new file mode 100644 index 000000000..1b3e5dd56 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-blk +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-u8, full-mask, aligned, dist-brc-blk +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/compare.py new file mode 100755 index 000000000..e558d22f2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-ds-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-i16, full-mask, aligned, dist-ds-b16 +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1024 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.int16, 0.0, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/golden.py new file mode 100755 index 000000000..63da9f605 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-ds-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-i16, full-mask, aligned, dist-ds-b16 +# NOTE: DS on b16 keeps every other i16 element from a 256-element source window. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 2048 +ACTIVE_ELEMS = 1024 +LANES = 128 +SOURCE_WINDOW = 256 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(2**15), 2**15, size=(ELEMENTS,), dtype=np.int16) + v2 = np.zeros((ELEMENTS,), dtype=np.int16) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.int16) + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset : offset + LANES] = v1[offset : offset + SOURCE_WINDOW : 2][:LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds b16 downsample validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto new file mode 100644 index 000000000..66a727e94 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-ds-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-ds-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `DS_B16` load on `b16`. + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "DS_B16"} : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/launch.cpp new file mode 100644 index 000000000..07ccb8b8d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-ds-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-ds-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/main.cpp new file mode 100644 index 000000000..951256acb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-ds-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-ds-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/stub.cpp new file mode 100644 index 000000000..236035f6d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-ds-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-ds-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/compare.py new file mode 100755 index 000000000..1f3503d9b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-tail +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1000 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/golden.py new file mode 100755 index 000000000..7eca6a3ad --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-tail +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +# NOTE: tail-mask case writes the first 1000 f32 lanes and leaves the +# remaining lanes zero. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LOGICAL_ELEMS = 1000 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2[:LOGICAL_ELEMS] = v1[:LOGICAL_ELEMS] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds tail validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto new file mode 100644 index 000000000..505e0d678 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-tail +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/launch.cpp new file mode 100644 index 000000000..dfbca2f61 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-tail +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/main.cpp new file mode 100644 index 000000000..a9f049135 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-tail +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/stub.cpp new file mode 100644 index 000000000..71bd9b99e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-tail +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/compare.py new file mode 100644 index 000000000..3c394fd2d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/compare.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-unpk-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f16, full-mask, aligned, dist-unpk-b16 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint16) + output = np.fromfile(output_path, dtype=np.uint16) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden=0x{int(golden[idx]):04x}, out=0x{int(output[idx]):04x})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/golden.py new file mode 100644 index 000000000..5d7850f1b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-unpk-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f16, full-mask, aligned, dist-unpk-b16 + +import argparse +from pathlib import Path + +import numpy as np + + +INPUT_ELEMS = 1024 +OUTPUT_ELEMS = 2048 +SRC_CHUNK = 64 +DST_CHUNK = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=INPUT_ELEMS).astype(np.float16) + dst = np.zeros((OUTPUT_ELEMS,), dtype=np.float16) + golden = np.zeros((OUTPUT_ELEMS,), dtype=np.float16) + + for src_base in range(0, INPUT_ELEMS, SRC_CHUNK): + dst_base = src_base * 2 + golden[dst_base : dst_base + DST_CHUNK : 2] = src[src_base : src_base + SRC_CHUNK] + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + dst.view(np.uint16).tofile(output_dir / "v2.bin") + golden.view(np.uint16).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds UNPK_B16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto new file mode 100644 index 000000000..672e884e2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-unpk-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-unpk-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `UNPK_B16` load on `b16`. +// Installed A5 `TCvt.hpp::cast16to32` uses `vlds(..., UNPK_B16)` before +// `vcvt(..., PART_EVEN)`, so this case probes the resulting 128-lane `f16` +// layout directly through `vsts`. + +module attributes {pto.target_arch = "a5"} { + func.func @vlds_unpk_b16_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %src_offset = %c0 to %c1024 step %c64 { + %dst_offset = arith.muli %src_offset, %c2 : index + %out = pto.vlds %ub_in[%src_offset] {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%dst_offset], %full_mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/launch.cpp new file mode 100644 index 000000000..b8544e34a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-unpk-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-unpk-b16 +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vlds_unpk_b16_kernel_2d(__gm__ half *v1, + __gm__ half *v2); + +void LaunchVlds_unpk_b16_kernel_2d(uint16_t *v1, uint16_t *v2, void *stream) { + vlds_unpk_b16_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/main.cpp new file mode 100644 index 000000000..83260e90f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-unpk-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-unpk-b16 +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVlds_unpk_b16_kernel_2d(uint16_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 2048; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVlds_unpk_b16_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/stub.cpp new file mode 100644 index 000000000..768b8bebe --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/stub.cpp @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-unpk-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-unpk-b16 +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vlds_unpk_b16_kernel_2d(__gm__ half *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/compare.py new file mode 100755 index 000000000..5ccc50a39 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-us-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-i16, full-mask, aligned, dist-us-b16 +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1024 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.int16, 0.0, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/golden.py new file mode 100755 index 000000000..214b0269c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-us-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-i16, full-mask, aligned, dist-us-b16 +# NOTE: US on b16 duplicates each source i16 element into two consecutive lanes. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 2048 +ACTIVE_ELEMS = 1024 +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(2**15), 2**15, size=(ELEMENTS,), dtype=np.int16) + v2 = np.zeros((ELEMENTS,), dtype=np.int16) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.int16) + half_lanes = LANES // 2 + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset : offset + LANES] = np.repeat(v1[offset : offset + half_lanes], 2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds b16 upsample validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto new file mode 100644 index 000000000..8e79bcab0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-us-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-us-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `US_B16` load on `b16`. + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "US_B16"} : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/launch.cpp new file mode 100644 index 000000000..4ecd1586e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-us-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-us-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/main.cpp new file mode 100644 index 000000000..2c4c9b679 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-us-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-us-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/stub.cpp new file mode 100644 index 000000000..cdc3a6ec3 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-us-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-us-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds/compare.py new file mode 100755 index 000000000..1c07e2d7c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds/golden.py new file mode 100755 index 000000000..21c58baab --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto new file mode 100644 index 000000000..fca778706 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds/launch.cpp new file mode 100644 index 000000000..2e2fa02fb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds/main.cpp new file mode 100644 index 000000000..ab816737d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds/stub.cpp new file mode 100644 index 000000000..3015e54ab --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/compare.py new file mode 100755 index 000000000..a4c5fae81 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-layout-check +# family: vector-load-store +# target_ops: pto.vldsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/golden.py new file mode 100755 index 000000000..8c481d96b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-layout-check +# family: vector-load-store +# target_ops: pto.vldsx2 +# scenarios: core-f32, full-mask, dintlv, lane-order, split-observation +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32).reshape(-1) + flat = v1.reshape(-1) + + # DINTLV_B32 exposes the two deinterleaved 64-lane results independently. + # Observe them through two plain NORM_B32 stores: + # low -> output[offset : offset + 64] + # high -> output[offset + 64 : offset + 128] + for base in range(0, ROWS * COLS, ACTIVE): + chunk = flat[base : base + ACTIVE] + golden_v2[base : base + 64] = chunk[0::2] + golden_v2[base + 64 : base + 128] = chunk[1::2] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldsx2 layout validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto new file mode 100644 index 000000000..762500c80 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto @@ -0,0 +1,62 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-layout-check +// family: vector-load-store +// target_ops: pto.vldsx2 +// scenarios: core-f32, full-mask, dintlv, lane-order, split-observation +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vldx2_layout_check_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c64_i32 = arith.constant 64 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, + %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, + i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %group = %c0 to %c8 step %c1 { + %group_base = arith.muli %group, %c128 : index + scf.for %chunk = %c0 to %c128 step %c128 { + %offset = arith.addi %group_base, %chunk : index + %high_offset = arith.addi %offset, %c64 : index + %mask, %remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %x, %y = pto.vldsx2 %ub_in[%offset], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vsts %x, %ub_out[%offset], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %y, %ub_out[%high_offset], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/launch.cpp new file mode 100644 index 000000000..d06dda18c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-layout-check +// family: vector-load-store +// target_ops: pto.vldsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vldx2_layout_check_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVldx2_layout_check_kernel(float *v1, float *v2, void *stream) { + vldx2_layout_check_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/main.cpp new file mode 100644 index 000000000..45e578b57 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-layout-check +// family: vector-load-store +// target_ops: pto.vldsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVldx2_layout_check_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVldx2_layout_check_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/stub.cpp new file mode 100644 index 000000000..8cf335928 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-layout-check +// family: vector-load-store +// target_ops: pto.vldsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vldx2_layout_check_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/compare.py new file mode 100644 index 000000000..af950320b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/compare.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +# family: vector-load-store +# target_ops: pto.vldsx2, pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/golden.py new file mode 100644 index 000000000..6732c8799 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +# family: vector-load-store +# target_ops: pto.vldsx2, pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.array(v1, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldsx2-vstsx2-b8-f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto new file mode 100644 index 000000000..193eb3043 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vldx2_vstsx2_b8_f32_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c64_i32 = arith.constant 64 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, + %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, + i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %group = %c0 to %c8 step %c1 { + %group_base = arith.muli %group, %c128 : index + scf.for %chunk = %c0 to %c128 step %c128 { + %offset = arith.addi %group_base, %chunk : index + %mask, %remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %low, %high = pto.vldsx2 %ub_in[%offset], "DINTLV_B8" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vstsx2 %low, %high, %ub_out[%offset], "INTLV_B8", %mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/launch.cpp new file mode 100644 index 000000000..beadc1f7e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vldx2_vstsx2_b8_f32_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVldx2_vstsx2_b8_f32_kernel(float *v1, float *v2, void *stream) { + vldx2_vstsx2_b8_f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/main.cpp new file mode 100644 index 000000000..61686d35b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVldx2_vstsx2_b8_f32_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVldx2_vstsx2_b8_f32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/stub.cpp new file mode 100644 index 000000000..dac6ec4d0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/stub.cpp @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vldx2_vstsx2_b8_f32_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/compare.py new file mode 100755 index 000000000..b28c98567 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-vstsx2 +# family: vector-load-store +# target_ops: pto.vldsx2, pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/golden.py new file mode 100755 index 000000000..14d41a9a3 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-vstsx2 +# family: vector-load-store +# target_ops: pto.vldsx2, pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.array(v1, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldsx2-vstsx2 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto new file mode 100644 index 000000000..12bd7b9a0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vldx2_vstsx2_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c64_i32 = arith.constant 64 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, + %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, + i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %group = %c0 to %c8 step %c1 { + %group_base = arith.muli %group, %c128 : index + scf.for %chunk = %c0 to %c128 step %c128 { + %offset = arith.addi %group_base, %chunk : index + %mask, %remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %low, %high = pto.vldsx2 %ub_in[%offset], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vstsx2 %low, %high, %ub_out[%offset], "INTLV_B32", %mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/launch.cpp new file mode 100644 index 000000000..d7e4c3fed --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vldx2_vstsx2_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVldx2_vstsx2_kernel(float *v1, float *v2, void *stream) { + vldx2_vstsx2_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/main.cpp new file mode 100644 index 000000000..ec0f59491 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVldx2_vstsx2_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVldx2_vstsx2_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/stub.cpp new file mode 100644 index 000000000..2bd6f87bd --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vldx2_vstsx2_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsldb/compare.py new file mode 100755 index 000000000..c47755b60 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsldb +# family: vector-load-store +# target_ops: pto.vsldb +# scenarios: core-f32, full-mask, block-strided-load, block-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 64 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsldb/golden.py new file mode 100755 index 000000000..3d8dad137 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsldb +# family: vector-load-store +# target_ops: pto.vsldb +# scenarios: core-f32, full-mask, block-strided-load, block-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +BLOCK_STRIDE = 2 +REPEAT_STRIDE = 4 +BLOCK_ELEMS = 8 +BLOCK_COUNT = 8 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_golden = golden_v2.reshape(-1) + for blk in range(BLOCK_COUNT): + src_blk = REPEAT_STRIDE + blk * BLOCK_STRIDE + flat_golden[blk * BLOCK_ELEMS:(blk + 1) * BLOCK_ELEMS] = flat_in[ + src_blk * BLOCK_ELEMS:(src_blk + 1) * BLOCK_ELEMS + ] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsldb validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto new file mode 100644 index 000000000..9e9e87ff1 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto @@ -0,0 +1,45 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsldb +// family: vector-load-store +// target_ops: pto.vsldb +// scenarios: core-f32, full-mask, block-strided-load, block-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsldb_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c2_i16 = arith.constant 2 : i16 + %c4_i16 = arith.constant 4 : i16 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %iv = %c0 to %c1 step %c1 { + %mask, %next_remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %loaded = pto.vsldb %ub_in, %c2_i16, %c4_i16, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %loaded, %ub_out[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsldb/launch.cpp new file mode 100644 index 000000000..fe71cc2b6 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsldb +// family: vector-load-store +// target_ops: pto.vsldb +// scenarios: core-f32, full-mask, block-strided-load, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsldb_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsldb_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsldb/main.cpp new file mode 100644 index 000000000..f0b21ff83 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsldb +// family: vector-load-store +// target_ops: pto.vsldb +// scenarios: core-f32, full-mask, block-strided-load, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vsldb/stub.cpp new file mode 100644 index 000000000..c938ff8af --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsldb +// family: vector-load-store +// target_ops: pto.vsldb +// scenarios: core-f32, full-mask, block-strided-load, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsldb_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsstb/compare.py new file mode 100755 index 000000000..cffa8ea8b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsstb +# family: vector-load-store +# target_ops: pto.vsstb +# scenarios: core-f32, full-mask, block-strided-store, block-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsstb/golden.py new file mode 100755 index 000000000..033a8b030 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsstb +# family: vector-load-store +# target_ops: pto.vsstb +# scenarios: core-f32, full-mask, block-strided-store, block-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +BLOCK_STRIDE = 2 +REPEAT_STRIDE = 4 +BLOCK_ELEMS = 8 +BLOCK_COUNT = 8 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.array(v1, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsstb validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto new file mode 100644 index 000000000..b022afee6 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsstb +// family: vector-load-store +// target_ops: pto.vsstb +// scenarios: core-f32, full-mask, block-strided-store, block-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsstb_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c2_i16 = arith.constant 2 : i16 + %c4_i16 = arith.constant 4 : i16 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %iv = %c0 to %c1 step %c1 { + %mask, %next_remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %value = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsstb %value, %ub_out, %c2_i16, %c4_i16, %mask : !pto.vreg<64xf32>, !pto.ptr, i16, i16, !pto.mask + pto.mem_bar "VST_VLD" + %roundtrip = pto.vsldb %ub_out, %c2_i16, %c4_i16, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %roundtrip, %ub_in[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_in, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsstb/launch.cpp new file mode 100644 index 000000000..95d2a57bd --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsstb +// family: vector-load-store +// target_ops: pto.vsstb +// scenarios: core-f32, full-mask, block-strided-store, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsstb_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsstb_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsstb/main.cpp new file mode 100644 index 000000000..72c683928 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsstb +// family: vector-load-store +// target_ops: pto.vsstb +// scenarios: core-f32, full-mask, block-strided-store, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vsstb/stub.cpp new file mode 100644 index 000000000..87ba3e783 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsstb +// family: vector-load-store +// target_ops: pto.vsstb +// scenarios: core-f32, full-mask, block-strided-store, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsstb_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstar/compare.py new file mode 100755 index 000000000..3f233f6e6 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/compare.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstar +# family: vector-load-store +# target_ops: pto.vstar +# scenarios: core-f32, full-mask, aligned, state-update +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +CHECK_OFFSET = 1 +CHECK_COUNT = 8 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + golden = np.fromfile("golden_v2.bin", dtype=np.float32) if os.path.exists("golden_v2.bin") else None + output = np.fromfile("v2.bin", dtype=np.float32) if os.path.exists("v2.bin") else None + lo = CHECK_OFFSET + hi = CHECK_OFFSET + CHECK_COUNT + if output is None: + ok = False + print("[ERROR] Output missing: v2.bin") + elif golden is None: + ok = False + print("[ERROR] Golden missing: golden_v2.bin") + elif golden.size < hi or output.size < hi: + ok = False + print( + f"[ERROR] Flush slice too small: need={hi} elems, " + f"golden={golden.size}, out={output.size}" + ) + elif not np.allclose(golden[lo:hi], output[lo:hi], atol=0.0001, rtol=0.0001, equal_nan=True): + g = golden[lo:hi].astype(np.float64, copy=False) + o = output[lo:hi].astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + ok = False + print( + f"[ERROR] Mismatch (flush slice): golden_v2.bin vs v2.bin, max diff={float(abs_diff[idx])} " + f"at idx={lo + idx} (golden={g[idx]}, out={o[idx]}, dtype=float32)" + ) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstar/golden.py new file mode 100755 index 000000000..d1a1054ba --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstar +# family: vector-load-store +# target_ops: pto.vstar +# scenarios: core-f32, predicate-squeezed, unaligned, state-update +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2.reshape(-1)[1:9] = v1.reshape(-1)[:8] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstar validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto new file mode 100644 index 000000000..f9ddd4309 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto @@ -0,0 +1,61 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstar +// family: vector-load-store +// target_ops: pto.vstar +// scenarios: core-f32, predicate-squeezed, unaligned, state-update +// ----------------------------------------------------------------------------- +// Validate the final flush step of a stateful store chain. +// The case keeps `pto.vstar` as the target op and uses the minimal required +// setup: +// 1. load one aligned vector from `%ub_in` +// 2. squeeze a small active prefix to prime `SPR SQZN` +// 3. prime one store-state carrier from unaligned `%ub_out` +// 4. issue one `pto.vstur ... "POST_UPDATE"` to create residual state +// 5. flush that residual state with `pto.vstar` +// This makes the observable payload come from `vstar` while keeping the chain +// contract valid per docs. + +module attributes {pto.target_arch = "a5"} { + func.func @vstar_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i32 = arith.constant 8 : i32 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1_elem = arith.constant 1 : index + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out1 = pto.addptr %ub_out, %c1_elem : !pto.ptr -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + pto.sprclr "AR" + scf.for %iter = %c0 to %c1 step %c1 { + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %mask, %unused = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %sqz = pto.vsqz %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %align0 = pto.init_align : !pto.align + %align1 = pto.vstur %align0, %sqz, %ub_out1, "POST_UPDATE" + : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align + pto.vstar %align1, %ub_out1 : !pto.align, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstar/launch.cpp new file mode 100644 index 000000000..6d4789d7e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstar +// family: vector-load-store +// target_ops: pto.vstar +// scenarios: core-f32, full-mask, aligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vstar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVstar_kernel_2d(float *v1, float *v2, void *stream) { + vstar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstar/main.cpp new file mode 100644 index 000000000..0f316a695 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstar +// family: vector-load-store +// target_ops: pto.vstar +// scenarios: core-f32, full-mask, aligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstar_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vstar/stub.cpp new file mode 100644 index 000000000..8439d0e02 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstar +// family: vector-load-store +// target_ops: pto.vstar +// scenarios: core-f32, full-mask, aligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vstar_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/compare.py new file mode 100755 index 000000000..0916b067e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/compare.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstas-vstus-offset-update +# family: vector-load-store +# target_ops: pto.vstas, pto.vstus +# scenarios: core-f32, full-mask, immediate-offset, state-update +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, 69) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/golden.py new file mode 100755 index 000000000..b1b68f800 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstas-vstus-offset-update +# family: vector-load-store +# target_ops: pto.vstas, pto.vstus +# scenarios: core-f32, full-mask, immediate-offset, state-update +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +VECTOR_LANES = 64 +POST_UPDATE_OFFSET_ELEMENTS = 3 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2[:POST_UPDATE_OFFSET_ELEMENTS] = v1[:POST_UPDATE_OFFSET_ELEMENTS] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstas/vstus chain validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto new file mode 100644 index 000000000..89b918893 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstas-vstus-offset-update +// family: vector-load-store +// target_ops: pto.vstas, pto.vstus +// scenarios: core-f32, full-mask, immediate-offset, state-update +// ----------------------------------------------------------------------------- +// Validate the state chain required by the plan: +// 1. prime a store-state carrier +// 2. issue one no-post `vstus` with a non-zero explicit offset +// 3. flush the residual state with `vstas` using the same explicit flush point +// The observable effect should match an unaligned store stream where `vstus` +// advances the stream by 3 f32 elements and leaves the buffered tail in +// `!pto.align`, then `vstas` commits that pending tail at the matching flush +// point identified by the original base plus the same scalar offset. + +module attributes {pto.target_arch = "a5"} { + func.func @vstas_vstus_offset_update_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c3_i32 = arith.constant 3 : i32 + %c0_i32 = arith.constant 0 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c64 step %c64 { + %align0 = pto.init_align : !pto.align + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %align1 = pto.vstus %align0, %c3_i32, %vec, %ub_out + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align + pto.vstas %align1, %ub_out, %c3_i32 : !pto.align, !pto.ptr, i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/launch.cpp new file mode 100644 index 000000000..b395937e5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstas-vstus-offset-update +// family: vector-load-store +// target_ops: pto.vstas, pto.vstus +// scenarios: core-f32, full-mask, immediate-offset, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vstas_vstus_offset_update_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVstasVstusOffsetUpdate_kernel_2d(float *v1, float *v2, void *stream) { + vstas_vstus_offset_update_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/main.cpp new file mode 100644 index 000000000..2d2a9469b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstas-vstus-offset-update +// family: vector-load-store +// target_ops: pto.vstas, pto.vstus +// scenarios: core-f32, full-mask, immediate-offset, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstasVstusOffsetUpdate_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstasVstusOffsetUpdate_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/stub.cpp new file mode 100644 index 000000000..c10d4c915 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstas-vstus-offset-update +// family: vector-load-store +// target_ops: pto.vstas, pto.vstus +// scenarios: core-f32, full-mask, immediate-offset, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vstas_vstus_offset_update_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/compare.py new file mode 100755 index 000000000..d6d773550 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/compare.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-1pt-b16 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-i16, full-mask, aligned, dist-1pt-b16 +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +ACTIVE_ELEMS = 1024 +LANES = 128 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def compare_1pt_positions(golden_path, output_path, dtype, active_elems, lanes): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + active_elems = int(active_elems) + lanes = int(lanes) + except Exception: + print(f"[ERROR] Invalid 1PT compare arguments: active_elems={active_elems} lanes={lanes}") + return False + if active_elems <= 0 or lanes <= 0: + print(f"[ERROR] Invalid 1PT compare arguments: active_elems={active_elems} lanes={lanes}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + + positions = np.arange(0, active_elems, lanes, dtype=np.int64) + if positions.size == 0: + print("[ERROR] No 1PT positions selected") + return False + if positions[-1] >= golden.size: + print( + f"[ERROR] 1PT positions out of range: last={int(positions[-1])} size={golden.size}" + ) + return False + + golden_sel = golden[positions] + output_sel = output[positions] + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + pos = int(positions[idx]) + print( + f"[ERROR] Mismatch (1PT positions): idx={pos} " + f"golden={int(golden_sel[idx])} out={int(output_sel[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_1pt_positions("golden_v2.bin", "v2.bin", np.int16, ACTIVE_ELEMS, LANES) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/golden.py new file mode 100755 index 000000000..1ed2947c7 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-1pt-b16 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-i16, full-mask, aligned, dist-1pt-b16 +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 2048 +ACTIVE_ELEMS = 1024 +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(2**15), 2**15, size=(ELEMENTS,), dtype=np.int16) + v2 = np.zeros((ELEMENTS,), dtype=np.int16) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.int16) + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset] = v1[offset] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts 1PT validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto new file mode 100644 index 000000000..a45e227dd --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-1pt-b16 +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-i16, full-mask, aligned, dist-1pt-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `1PT_B16` store distribution on `b16`. + +module attributes {pto.target_arch = "a5"} { + func.func @vsts_1pt_b16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %vec, %ub_out[%offset], %mask {dist = "1PT_B16"} : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/launch.cpp new file mode 100644 index 000000000..3514bffb8 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_1pt_b16_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsts_1pt_b16_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/main.cpp new file mode 100644 index 000000000..6bc7026e2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/stub.cpp new file mode 100644 index 000000000..613c5ea44 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsts_1pt_b16_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/compare.py new file mode 100755 index 000000000..058c478a5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/compare.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-pk-b16 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-i16, full-mask, aligned, dist-pk-b16 +# coding=utf-8 + +import os +import sys +import numpy as np + +OUTPUT_BUFFER_BYTES = 4096 +# Keep this aligned with kernel.pto loop bound (offset: 0..1024 step 128 on i16). +ACTIVE_ELEMS = 1024 +LANES = 128 +BYTES_PER_ELEM = 2 + + +def build_checked_mask(total_bytes): + # For this case kernel: + # - loop offset: 0..1024 step 128 (i16 elements) + # - dist=PK_B16 stores 1 byte per active i16 element + # So each iteration writes 128 bytes at dst_byte_base = offset * 2. + mask = np.zeros((total_bytes,), dtype=bool) + for offset in range(0, ACTIVE_ELEMS, LANES): + dst_byte_base = offset * BYTES_PER_ELEM + mask[dst_byte_base : dst_byte_base + LANES] = True + return mask + + +def compare_bin(golden_path, output_path): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + + if golden.size != OUTPUT_BUFFER_BYTES: + print( + f"[ERROR] Unexpected byte size for this case: got {golden.size}, expected {OUTPUT_BUFFER_BYTES}" + ) + return False + + checked = build_checked_mask(golden.size) + checked_golden = golden[checked] + checked_output = output[checked] + if not np.array_equal(checked_golden, checked_output): + diff = np.nonzero(checked_golden != checked_output)[0] + idx = int(diff[0]) if diff.size else 0 + global_idx = int(np.nonzero(checked)[0][idx]) if diff.size else 0 + print( + f"[ERROR] Mismatch (checked footprint): {golden_path} vs {output_path}, " + f"first diff at checked_idx={idx}, global_idx={global_idx} " + f"(golden=0x{int(checked_golden[idx]):02x}, out=0x{int(checked_output[idx]):02x})" + ) + return False + print( + f"[INFO] compared writable footprint only: {int(np.count_nonzero(checked))}/{golden.size} bytes" + ) + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/golden.py new file mode 100755 index 000000000..b0ebf667c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-pk-b16 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-i16, full-mask, aligned, dist-pk-b16 +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +OUTPUT_BUFFER_BYTES = 4096 +TOTAL_ELEMS_I16 = OUTPUT_BUFFER_BYTES // 2 +# This case kernel only iterates 0..1024 on i16 lanes, so only 1024 packed bytes +# are semantically writable by vsts(pk_b16) in this testcase. +ACTIVE_ELEMS = 1024 +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(2**15), 2**15, size=(TOTAL_ELEMS_I16,), dtype=np.int16) + v2 = rng.integers(0, 256, size=(OUTPUT_BUFFER_BYTES,), dtype=np.uint8) + golden_v2 = v2.copy() + + # PK_B16: write low 8 bits of each active b16 element as a compact byte stream. + # Destination address is unchanged for non-post-update form; within each 256B + # lane chunk only the first 128B are overwritten. + v1_u16 = v1.view(np.uint16) + packed_bytes_per_chunk = LANES + for offset in range(0, ACTIVE_ELEMS, LANES): + src = v1_u16[offset : offset + LANES] + packed = (src & 0x00FF).astype(np.uint8) + dst_byte_base = offset * 2 + golden_v2[dst_byte_base : dst_byte_base + packed_bytes_per_chunk] = packed + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts PK_B16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto new file mode 100644 index 000000000..29ef8cd9c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-pk-b16 +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-i16, full-mask, aligned, dist-pk-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `PK_B16` store distribution on `b16`. + +module attributes {pto.target_arch = "a5"} { + func.func @vsts_pk_b16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %vec, %ub_out[%offset], %mask {dist = "PK_B16"} : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/launch.cpp new file mode 100644 index 000000000..9a902908c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_pk_b16_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsts_pk_b16_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/main.cpp new file mode 100644 index 000000000..6bc7026e2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/stub.cpp new file mode 100644 index 000000000..72cacdb60 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsts_pk_b16_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/compare.py new file mode 100644 index 000000000..4c19eb038 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/golden.py new file mode 100644 index 000000000..c5635db7c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-pk-b64-f32 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, full-mask, aligned, dist-pk-b64, width-agnostic-dist + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + golden_v2 = np.array(v2, copy=True) + + for offset in range(0, ELEMENTS, LANES): + chunk = v1[offset : offset + LANES] + packed = chunk[0::2] + golden_v2[offset : offset + packed.size] = packed + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts PK_B64 on f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto new file mode 100644 index 000000000..797b25ff5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-pk-b64-f32 +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, full-mask, aligned, dist-pk-b64, width-agnostic-dist +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsts_pk_b64_f32_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg1, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%offset], %mask {dist = "PK_B64"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/launch.cpp new file mode 100644 index 000000000..039f11e40 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_pk_b64_f32_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVsts_pk_b64_f32_kernel(float *v1, float *v2, void *stream) { + vsts_pk_b64_f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/main.cpp new file mode 100644 index 000000000..707a88e05 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsts_pk_b64_f32_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsts_pk_b64_f32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/stub.cpp new file mode 100644 index 000000000..1ea4bc718 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/stub.cpp @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vsts_pk_b64_f32_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/compare.py new file mode 100755 index 000000000..1821ec6aa --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/compare.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-tail +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_window(golden_path, output_path, dtype, eps, offset, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + offset = int(offset) + count = int(count) + except Exception: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + if offset < 0 or count <= 0: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + end = offset + count + if golden.size < end or output.size < end: + print( + f"[ERROR] Compare window out of range: offset={offset} count={count}, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[offset:end] + output_sel = output[offset:end] + if not np.allclose(golden_sel, output_sel, atol=eps, rtol=eps, equal_nan=True): + if golden_sel.size: + g = golden_sel.astype(np.float64, copy=False) + o = output_sel.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, max diff={diff} " + f"at idx={offset + idx} (golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, " + f"offset={offset}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, empty window, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_window("golden_v2.bin", "v2.bin", np.float32, 0.0001, 0, 13) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/golden.py new file mode 100755 index 000000000..73bd90f99 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-tail +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE = 13 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2.reshape(-1)[:ACTIVE] = v1.reshape(-1)[:ACTIVE] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts-tail validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto new file mode 100644 index 000000000..da446a7fe --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-tail +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsts_tail_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c13_i32 = arith.constant 13 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, + %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, + i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %iv = %c0 to %c1 step %c1 { + %mask, %remaining = pto.plt_b32 %c13_i32 : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%c0], %mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/launch.cpp new file mode 100644 index 000000000..2a94e832d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-tail +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_tail_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsts_tail_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/main.cpp new file mode 100644 index 000000000..f8da4c77a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-tail +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/stub.cpp new file mode 100644 index 000000000..d9c810f1a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-tail +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsts_tail_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts/compare.py new file mode 100755 index 000000000..dc064cb22 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts/golden.py new file mode 100755 index 000000000..9eb6e0453 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto new file mode 100644 index 000000000..07e60de3d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vsts_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts/launch.cpp new file mode 100644 index 000000000..851e10299 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsts_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts/main.cpp new file mode 100644 index 000000000..6bc7026e2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts/stub.cpp new file mode 100644 index 000000000..033c8aa50 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vsts_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/compare.py new file mode 100755 index 000000000..b2a31f90e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstsx2-layout-check +# family: vector-load-store +# target_ops: pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 128 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/golden.py new file mode 100755 index 000000000..24665fcf7 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstsx2-layout-check +# family: vector-load-store +# target_ops: pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_golden = golden_v2.reshape(-1) + flat_golden[:ACTIVE:2] = flat_in[:64] + flat_golden[1:ACTIVE:2] = flat_in[64:128] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstsx2 layout validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto new file mode 100644 index 000000000..9c54c8af1 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstsx2-layout-check +// family: vector-load-store +// target_ops: pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vstsx2_layout_check_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c64_i32 = arith.constant 64 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, + %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, + i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %iv = %c0 to %c1 step %c1 { + %mask, %remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %x = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %y = pto.vlds %ub_in[%c64] : !pto.ptr -> !pto.vreg<64xf32> + pto.vstsx2 %x, %y, %ub_out[%c0], "INTLV_B32", %mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, + %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/launch.cpp new file mode 100644 index 000000000..bf28f8403 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstsx2-layout-check +// family: vector-load-store +// target_ops: pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vstsx2_layout_check_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVstsx2_layout_check_kernel(float *v1, float *v2, void *stream) { + vstsx2_layout_check_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/main.cpp new file mode 100644 index 000000000..1e380b6b2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstsx2-layout-check +// family: vector-load-store +// target_ops: pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstsx2_layout_check_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstsx2_layout_check_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/stub.cpp new file mode 100644 index 000000000..2daf274bb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstsx2-layout-check +// family: vector-load-store +// target_ops: pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vstsx2_layout_check_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/compare.py new file mode 100644 index 000000000..fde3a5229 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/compare.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstur-init-align-outside-loop +# family: vector-load-store +# target_ops: pto.vstur +# scenarios: core-f32, full-mask, unaligned, state-update, init-align-outside-loop +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_window(golden_path, output_path, dtype, eps, offset, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + offset = int(offset) + count = int(count) + except Exception: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + if offset < 0 or count <= 0: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + end = offset + count + if golden.size < end or output.size < end: + print( + f"[ERROR] Compare window out of range: offset={offset} count={count}, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[offset:end] + output_sel = output[offset:end] + if not np.allclose(golden_sel, output_sel, atol=eps, rtol=eps, equal_nan=True): + if golden_sel.size: + g = golden_sel.astype(np.float64, copy=False) + o = output_sel.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, max diff={diff} " + f"at idx={offset + idx} (golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, " + f"offset={offset}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, empty window, dtype={dtype_np}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_window("golden_v2.bin", "v2.bin", np.float32, 0.0001, 1, 8) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/golden.py new file mode 100644 index 000000000..d13ca8097 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstur-init-align-outside-loop +# family: vector-load-store +# target_ops: pto.vstur +# scenarios: core-f32, predicate-squeezed, unaligned, state-update, init-align-outside-loop +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE_LANES = 8 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2.reshape(-1)[1 : 1 + ACTIVE_LANES] = v1.reshape(-1)[:ACTIVE_LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstur-init-align-outside-loop validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto new file mode 100644 index 000000000..71d6470e6 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur-init-align-outside-loop +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, predicate-squeezed, unaligned, state-update, init-align-outside-loop +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vstur_init_align_outside_loop_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i32 = arith.constant 8 : i32 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out1 = pto.addptr %ub_out, %c1 : !pto.ptr -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + pto.sprclr "AR" + %align0 = pto.init_align : !pto.align + %align_final = scf.for %offset = %c0 to %c1 step %c1 + iter_args(%align_iter = %align0) -> (!pto.align) { + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %mask, %unused = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %sqz = pto.vsqz %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %align1 = pto.vstur %align_iter, %sqz, %ub_out1, "POST_UPDATE" + : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align + scf.yield %align1 : !pto.align + } + pto.vstar %align_final, %ub_out1 : !pto.align, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/launch.cpp new file mode 100644 index 000000000..a56f83c6e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur-init-align-outside-loop +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vstur_init_align_outside_loop_kernel_2d(__gm__ float *v1, __gm__ float *v2); + +void LaunchVstur_init_align_outside_loop_kernel_2d(float *v1, float *v2, + void *stream) { + vstur_init_align_outside_loop_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/main.cpp new file mode 100644 index 000000000..486ca9862 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/main.cpp @@ -0,0 +1,129 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur-init-align-outside-loop +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstur_init_align_outside_loop_kernel_2d(float *v1, float *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstur_init_align_outside_loop_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/stub.cpp new file mode 100644 index 000000000..b9d832003 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur-init-align-outside-loop +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vstur_init_align_outside_loop_kernel_2d(__gm__ float *v1, __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstur/compare.py new file mode 100755 index 000000000..80b4dab8a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/compare.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstur +# family: vector-load-store +# target_ops: pto.vstur +# scenarios: core-f32, full-mask, unaligned, state-update +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_window(golden_path, output_path, dtype, eps, offset, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + offset = int(offset) + count = int(count) + except Exception: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + if offset < 0 or count <= 0: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + end = offset + count + if golden.size < end or output.size < end: + print( + f"[ERROR] Compare window out of range: offset={offset} count={count}, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[offset:end] + output_sel = output[offset:end] + if not np.allclose(golden_sel, output_sel, atol=eps, rtol=eps, equal_nan=True): + if golden_sel.size: + g = golden_sel.astype(np.float64, copy=False) + o = output_sel.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, max diff={diff} " + f"at idx={offset + idx} (golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, " + f"offset={offset}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, empty window, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_window("golden_v2.bin", "v2.bin", np.float32, 0.0001, 1, 8) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstur/golden.py new file mode 100755 index 000000000..96b3c4030 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstur +# family: vector-load-store +# target_ops: pto.vstur +# scenarios: core-f32, predicate-squeezed, unaligned, state-update +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE_LANES = 8 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2.reshape(-1)[1 : 1 + ACTIVE_LANES] = v1.reshape(-1)[:ACTIVE_LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstur validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto new file mode 100644 index 000000000..6c194fde5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, predicate-squeezed, unaligned, state-update +// ----------------------------------------------------------------------------- +// Validate the standalone `vstur` surface with its required SQZN producer. +// The case keeps the sequence minimal: +// 1. load one vector from `%ub_in` +// 2. generate a small predicate and squeeze the vector to prime `SPR SQZN` +// 3. prime one store-state carrier from `%ub_out` +// 4. issue one `pto.vstur ... "POST_UPDATE"` +// 5. flush the residual state with `pto.vstar` +// This preserves the testcase goal around unaligned store state update without +// fabricating extra semantics beyond the installed A5 wrapper contract. + +module attributes {pto.target_arch = "a5"} { + func.func @vstur_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i32 = arith.constant 8 : i32 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out1 = pto.addptr %ub_out, %c1 : !pto.ptr -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + pto.sprclr "AR" + scf.for %offset = %c0 to %c1 step %c1 { + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %mask, %unused = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %sqz = pto.vsqz %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %align0 = pto.init_align : !pto.align + %align1 = pto.vstur %align0, %sqz, %ub_out1, "POST_UPDATE" + : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align + pto.vstar %align1, %ub_out1 : !pto.align, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur/launch.cpp new file mode 100644 index 000000000..b0c69d79c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vstur_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVstur_kernel_2d(float *v1, float *v2, void *stream) { + vstur_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur/main.cpp new file mode 100644 index 000000000..273fb30c4 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstur_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstur_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur/stub.cpp new file mode 100644 index 000000000..fcec1a134 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vstur_kernel_2d(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/npu_validation/common/test_common.h b/test/vpto/npu_validation/common/test_common.h new file mode 100644 index 000000000..3cbb7a3e3 --- /dev/null +++ b/test/vpto/npu_validation/common/test_common.h @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace PtoTestCommon { + +inline bool ReadFile(const std::string &filePath, size_t &fileSize, void *buffer, size_t bufferSize) { + struct stat sBuf; + if (stat(filePath.c_str(), &sBuf) == -1) { + return false; + } + if (!S_ISREG(sBuf.st_mode)) { + return false; + } + + std::ifstream file(filePath, std::ios::binary); + if (!file.is_open()) { + return false; + } + + std::filebuf *buf = file.rdbuf(); + size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in); + if (size == 0 || size > bufferSize) { + return false; + } + buf->pubseekpos(0, std::ios::in); + buf->sgetn(static_cast(buffer), size); + fileSize = size; + return true; +} + +inline bool WriteFile(const std::string &filePath, const void *buffer, size_t size) { + if (buffer == nullptr) { + return false; + } + + int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE); + if (fd < 0) { + return false; + } + + ssize_t writeSize = write(fd, buffer, size); + (void)close(fd); + return writeSize == static_cast(size); +} + +} // namespace PtoTestCommon diff --git a/test/vpto/scripts/run_host_vpto_validation.sh b/test/vpto/scripts/run_host_vpto_validation.sh new file mode 100755 index 000000000..a748ac95a --- /dev/null +++ b/test/vpto/scripts/run_host_vpto_validation.sh @@ -0,0 +1,496 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +VPTO_ROOT="${VPTO_ROOT:-${ROOT_DIR}/test/vpto/cases}" +CASES_ROOT="${CASES_ROOT:-${VPTO_ROOT}}" +NPU_VALIDATION_COMMON_DIR="${NPU_VALIDATION_COMMON_DIR:-${ROOT_DIR}/test/vpto/npu_validation/common}" + +WORK_SPACE="${WORK_SPACE:-}" +ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-}" +PTOAS_BIN="${PTOAS_BIN:-${ROOT_DIR}/build/tools/ptoas/ptoas}" +PTOAS_FLAGS="${PTOAS_FLAGS:---pto-arch a5}" +VPTO_FLAGS="${VPTO_FLAGS:---pto-backend=vpto --vpto-emit-hivm-llvm}" +AICORE_ARCH="${AICORE_ARCH:-dav-c310-vec}" +# set he HOST_RUNNER to "ssh root@localhost" if must change user to root to access the device +HOST_RUNNER="${HOST_RUNNER:-}" +CASE_NAME="${CASE_NAME:-}" +MODULE_ID="${MODULE_ID:-a5d60abf67864aa0}" +DEVICE="${DEVICE:-SIM}" +SIM_LIB_DIR="${SIM_LIB_DIR:-}" +COMPILE_ONLY="${COMPILE_ONLY:-0}" + +log() { + echo "[$(date +'%F %T')] $*" +} + +die() { + echo "ERROR: $*" >&2 + exit 1 +} + +run_remote() { + local cmd="$1" + if [[ "${HOST_RUNNER}" == "ssh root@localhost" ]]; then + ssh -o StrictHostKeyChecking=no root@localhost "${cmd}" + else + bash -lc "${cmd}" + fi +} + +require_env() { + local name="$1" + local value="$2" + if [[ -z "${value}" ]]; then + die "${name} is required" + fi +} + +require_env "WORK_SPACE" "${WORK_SPACE}" +require_env "ASCEND_HOME_PATH" "${ASCEND_HOME_PATH}" +[[ -x "${PTOAS_BIN}" ]] || die "PTOAS_BIN is not executable: ${PTOAS_BIN}" +[[ -d "${CASES_ROOT}" ]] || die "missing cases root: ${CASES_ROOT}" + +if [[ -f "${ASCEND_HOME_PATH}/set_env.sh" ]]; then + set +u + source "${ASCEND_HOME_PATH}/set_env.sh" >/dev/null 2>&1 + set -u +fi + +resolve_sim_lib_dir() { + if [[ "${DEVICE}" != "SIM" ]]; then + return 0 + fi + + if [[ -n "${SIM_LIB_DIR}" ]]; then + [[ -d "${SIM_LIB_DIR}" ]] || + die "SIM_LIB_DIR is set but invalid: ${SIM_LIB_DIR}" + return 0 + fi + + local -a candidates=() + readarray -t candidates < <( + find "${ASCEND_HOME_PATH}" -type d -path '*/simulator/dav_3510/lib' | sort + ) + + if [[ "${#candidates[@]}" -eq 1 ]]; then + SIM_LIB_DIR="${candidates[0]}" + log "SIM_LIB_DIR is unset; auto-selected: ${SIM_LIB_DIR}" + return 0 + fi + + if [[ "${#candidates[@]}" -gt 1 ]]; then + SIM_LIB_DIR="${candidates[0]}" + log "SIM_LIB_DIR is unset; multiple dav_3510 simulator dirs found, using: ${SIM_LIB_DIR}" + return 0 + fi + + die "SIM_LIB_DIR is required for DEVICE=SIM and no dav_3510 simulator lib dir was found under: ${ASCEND_HOME_PATH}" +} + +resolve_sim_lib_dir + +BISHENG_BIN="${BISHENG_BIN:-${ASCEND_HOME_PATH}/bin/bisheng}" +BISHENG_CC1_BIN="${BISHENG_CC1_BIN:-${ASCEND_HOME_PATH}/tools/bisheng_compiler/bin/bisheng}" +CCE_LD_BIN="${CCE_LD_BIN:-${ASCEND_HOME_PATH}/bin/cce-ld}" +LD_LLD_BIN="${LD_LLD_BIN:-${ASCEND_HOME_PATH}/bin/ld.lld}" +CLANG_RESOURCE_DIR="${CLANG_RESOURCE_DIR:-${ASCEND_HOME_PATH}/tools/bisheng_compiler/lib/clang/15.0.5}" +CCE_STUB_DIR="${CCE_STUB_DIR:-${CLANG_RESOURCE_DIR}/include/cce_stub}" + +HOST_ARCH="$(uname -m)" +HOST_TRIPLE="" +HOST_TARGET_CPU="" +HOST_TARGET_ABI="" +HOST_FEATURE_FLAGS=() +HOST_OS_DIR="" + +case "${HOST_ARCH}" in + aarch64) + HOST_TRIPLE="aarch64-unknown-linux-gnu" + HOST_TARGET_CPU="generic" + HOST_TARGET_ABI="aapcs" + HOST_FEATURE_FLAGS=(-target-feature +neon -target-feature +v8a) + HOST_OS_DIR="aarch64-linux" + ;; + x86_64) + HOST_TRIPLE="x86_64-unknown-linux-gnu" + HOST_TARGET_CPU="x86-64" + HOST_OS_DIR="x86_64-linux" + ;; + *) + die "unsupported host arch from uname -m: ${HOST_ARCH}" + ;; +esac + +command -v "${BISHENG_BIN}" >/dev/null 2>&1 || die "bisheng not found: ${BISHENG_BIN}" +command -v python3 >/dev/null 2>&1 || die "python3 not found" + +readarray -t BISHENG_SYSTEM_INCLUDES < <( + "${BISHENG_BIN}" -xc++ -E -v - &1 | + awk ' + /#include <...> search starts here:/ {capture=1; next} + /End of search list\./ {capture=0} + capture && $0 ~ /^ / {sub(/^ +/, "", $0); print} + ' +) + +[[ "${#BISHENG_SYSTEM_INCLUDES[@]}" -gt 0 ]] || die "failed to discover bisheng system include directories" + +CC1_INCLUDE_FLAGS=() +for inc in "${BISHENG_SYSTEM_INCLUDES[@]}"; do + if [[ "${inc}" == */include/c++/* || "${inc}" == */backward ]]; then + CC1_INCLUDE_FLAGS+=(-internal-isystem "${inc}") + elif [[ "${inc}" == "/usr/include" ]]; then + CC1_INCLUDE_FLAGS+=(-internal-externc-isystem "${inc}") + else + CC1_INCLUDE_FLAGS+=(-internal-isystem "${inc}") + fi +done + +mkdir -p "${WORK_SPACE}" +WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" + +discover_cases() { + local required_files=( + kernel.pto + stub.cpp + launch.cpp + main.cpp + golden.py + compare.py + ) + + if [[ -n "${CASE_NAME}" ]]; then + local requested_dir="${CASES_ROOT}/${CASE_NAME}" + [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" + for f in "${required_files[@]}"; do + [[ -f "${requested_dir}/${f}" ]] || die "case ${CASE_NAME} is missing ${f}" + done + printf "%s\n" "${CASE_NAME#/}" + return 0 + fi + + find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do + local ok=1 + for f in "${required_files[@]}"; do + if [[ ! -f "${dir}/${f}" ]]; then + ok=0 + break + fi + done + [[ "${ok}" -eq 1 ]] || continue + local rel="${dir#${CASES_ROOT}/}" + printf "%s\n" "${rel}" + done +} + +readarray -t CASES < <(discover_cases) +[[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" + +build_launch_object() { + local case_dir="$1" + local out_obj="$2" + + "${BISHENG_BIN}" \ + -c -fPIC -xcce -fenable-matrix --cce-aicore-enable-tl \ + -fPIC -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --cce-aicore-arch="${AICORE_ARCH}" \ + -DREGISTER_BASE \ + -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes \ + -I "${ASCEND_HOME_PATH}/include" \ + -I "${ASCEND_HOME_PATH}/pkg_inc" \ + -I "${ASCEND_HOME_PATH}/pkg_inc/profiling" \ + -I "${ASCEND_HOME_PATH}/pkg_inc/runtime/runtime" \ + "${case_dir}/launch.cpp" \ + -o "${out_obj}" +} + +build_host_stub() { + local case_dir="$1" + local device_obj="$2" + local stub_obj="$3" + local module_id="$4" + local host_target_args=( + -triple "${HOST_TRIPLE}" + -target-cpu "${HOST_TARGET_CPU}" + ) + if [[ -n "${HOST_TARGET_ABI}" ]]; then + host_target_args+=(-target-abi "${HOST_TARGET_ABI}") + fi + if [[ ${#HOST_FEATURE_FLAGS[@]} -gt 0 ]]; then + host_target_args+=("${HOST_FEATURE_FLAGS[@]}") + fi + + "${BISHENG_CC1_BIN}" -cc1 \ + "${host_target_args[@]}" \ + -fcce-aicpu-legacy-launch \ + -fcce-is-host \ + -cce-launch-with-flagv2-impl \ + -fcce-aicore-arch "${AICORE_ARCH}" \ + -fcce-fatobj-compile \ + -emit-obj \ + --mrelax-relocations \ + -disable-free \ + -clear-ast-before-backend \ + -disable-llvm-verifier \ + -discard-value-names \ + -main-file-name "stub.cpp" \ + -mrelocation-model pic \ + -pic-level 2 \ + -fhalf-no-semantic-interposition \ + -fenable-matrix \ + -mllvm -enable-matrix \ + -mframe-pointer=non-leaf \ + -fmath-errno \ + -ffp-contract=on \ + -fno-rounding-math \ + -mconstructor-aliases \ + -funwind-tables=2 \ + -fallow-half-arguments-and-returns \ + -mllvm -treat-scalable-fixed-error-as-warning \ + -fcoverage-compilation-dir="${ROOT_DIR}" \ + -resource-dir "${CLANG_RESOURCE_DIR}" \ + -include __clang_cce_runtime_wrapper.h \ + -I "${ASCEND_HOME_PATH}/include" \ + -I "${ASCEND_HOME_PATH}/pkg_inc" \ + -I "${ASCEND_HOME_PATH}/pkg_inc/profiling" \ + -I "${ASCEND_HOME_PATH}/pkg_inc/runtime/runtime" \ + -D _FORTIFY_SOURCE=2 \ + -D REGISTER_BASE \ + "${CC1_INCLUDE_FLAGS[@]}" \ + -O2 \ + -Wno-macro-redefined \ + -Wno-ignored-attributes \ + -std=c++17 \ + -fdeprecated-macro \ + -fdebug-compilation-dir="${ROOT_DIR}" \ + -ferror-limit 19 \ + -stack-protector 2 \ + -fno-signed-char \ + -fgnuc-version=4.2.1 \ + -fcxx-exceptions \ + -fexceptions \ + -vectorize-loops \ + -vectorize-slp \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + -fcce-include-aibinary "${device_obj}" \ + -fcce-device-module-id "${module_id}" \ + -target-feature +outline-atomics \ + -faddrsig \ + -D__GCC_HAVE_DWARF2_CFI_ASM=1 \ + -o "${stub_obj}" \ + -x cce "${case_dir}/stub.cpp" +} + +link_kernel_so() { + local case_name="$1" + local host_stub_obj="$2" + local launch_obj="$3" + local repack_obj="$4" + local repack_so="$5" + local module_id="$6" + local extra_lib_dirs=() + local extra_link_libs=() + + "${CCE_LD_BIN}" \ + "${LD_LLD_BIN}" \ + -x \ + -cce-lite-bin-module-id "${module_id}" \ + -cce-aicore-arch="${AICORE_ARCH}" \ + -r \ + -o "${repack_obj}" \ + -cce-stub-dir "${CCE_STUB_DIR}" \ + -cce-install-dir "$(dirname "${BISHENG_CC1_BIN}")" \ + -cce-inputs-number 1 \ + "${host_stub_obj}" + + if [[ "${DEVICE}" == "SIM" ]]; then + [[ -n "${SIM_LIB_DIR}" && -d "${SIM_LIB_DIR}" ]] || + die "SIM_LIB_DIR is not set or invalid for DEVICE=SIM: ${SIM_LIB_DIR}" + extra_lib_dirs+=(-L "${SIM_LIB_DIR}" -Wl,-rpath,"${SIM_LIB_DIR}") + extra_link_libs+=(-Wl,--no-as-needed -lruntime_camodel) + else + extra_link_libs+=(-Wl,--no-as-needed -lruntime) + fi + + "${BISHENG_BIN}" \ + -fPIC -s -Wl,-z,relro -Wl,-z,now --cce-fatobj-link \ + -shared -Wl,-soname,"lib${case_name}_kernel.so" \ + -L "${ASCEND_HOME_PATH}/lib64" \ + "${extra_lib_dirs[@]}" \ + -Wl,-rpath,"${ASCEND_HOME_PATH}/lib64" \ + -o "${repack_so}" \ + "${repack_obj}" \ + "${launch_obj}" \ + "${extra_link_libs[@]}" +} + +build_host_executable() { + local case_token="$1" + local case_dir="$2" + local out_dir="$3" + local extra_ldflags=() + local extra_lib_dirs=() + if [[ "${DEVICE}" == "SIM" ]]; then + [[ -n "${SIM_LIB_DIR}" && -d "${SIM_LIB_DIR}" ]] || + die "SIM_LIB_DIR is not set or invalid for DEVICE=SIM: ${SIM_LIB_DIR}" + extra_lib_dirs+=(-L "${SIM_LIB_DIR}" -Wl,-rpath,"${SIM_LIB_DIR}") + extra_ldflags+=(-Wl,--allow-shlib-undefined -lruntime_camodel) + else + extra_ldflags+=(-Wl,--allow-shlib-undefined -lruntime) + fi + + "${BISHENG_BIN}" \ + -xc++ -include stdint.h -include stddef.h -std=c++17 \ + "${case_dir}/main.cpp" \ + -I "${case_dir}" \ + -I "${NPU_VALIDATION_COMMON_DIR}" \ + -I "${ASCEND_HOME_PATH}/include" \ + -L "${out_dir}" \ + -L "${ASCEND_HOME_PATH}/lib64" \ + "${extra_lib_dirs[@]}" \ + -Wl,-rpath,"${out_dir}" \ + -Wl,-rpath,"${ASCEND_HOME_PATH}/lib64" \ + -o "${out_dir}/${case_token}" \ + -l"${case_token}_kernel" \ + "${extra_ldflags[@]}" \ + -lstdc++ -lascendcl -lm -ltiling_api -lplatform -lc_sec -ldl -lnnopbase +} + +build_one_impl() { + local case_name="$1" + local case_dir="${CASES_ROOT}/${case_name}" + local case_token + case_token="$(printf '%s' "${case_name}" | sed 's#[/[:space:]]#_#g')" + local out_dir="${WORK_SPACE}/${case_token}" + local case_module_id + case_module_id="$(printf '%s' "${MODULE_ID}-${case_name}" | md5sum | cut -c1-16)" + local llvm_ir="${out_dir}/${case_token}.ll" + local device_obj="${out_dir}/${case_token}.o" + local launch_obj="${out_dir}/launch.o" + local host_stub_obj="${out_dir}/kernel_host_from_llvm.o" + local repack_obj="${out_dir}/${case_token}_stub.cpp.o" + local repack_so="${out_dir}/lib${case_token}_kernel.so" + + [[ -f "${case_dir}/kernel.pto" ]] || die "missing kernel.pto for ${case_name}" + [[ -f "${case_dir}/stub.cpp" ]] || die "missing stub.cpp for ${case_name}" + [[ -f "${case_dir}/main.cpp" ]] || die "missing main.cpp for ${case_name}" + [[ -f "${case_dir}/launch.cpp" ]] || die "missing launch.cpp for ${case_name}" + [[ -f "${case_dir}/golden.py" ]] || die "missing golden.py for ${case_name}" + [[ -f "${case_dir}/compare.py" ]] || die "missing compare.py for ${case_name}" + + log "[$case_name] step 1/6: lower VPTO MLIR to LLVM IR" + "${PTOAS_BIN}" ${PTOAS_FLAGS} ${VPTO_FLAGS} \ + "${case_dir}/kernel.pto" -o "${llvm_ir}" + + log "[$case_name] step 2/6: compile LLVM IR to device object" + "${BISHENG_BIN}" \ + --target=hiipu64-hisilicon-cce \ + -march="${AICORE_ARCH}" \ + --cce-aicore-arch="${AICORE_ARCH}" \ + --cce-aicore-only \ + -O2 \ + -c -x ir "${llvm_ir}" \ + -o "${device_obj}" + + log "[$case_name] step 3/6: build launch object and host fatobj stub" + build_launch_object "${case_dir}" "${launch_obj}" + build_host_stub "${case_dir}" "${device_obj}" "${host_stub_obj}" "${case_module_id}" + + log "[$case_name] step 4/6: link kernel shared library" + link_kernel_so "${case_token}" "${host_stub_obj}" "${launch_obj}" "${repack_obj}" "${repack_so}" "${case_module_id}" + + if [[ "${COMPILE_ONLY}" == "1" ]]; then + log "[$case_name] compile-only mode: stop after kernel shared library" + log "[$case_name] output dir: ${out_dir}" + return 0 + fi + + log "[$case_name] step 5/6: build host executable and golden" + build_host_executable "${case_token}" "${case_dir}" "${out_dir}" + ( + cd "${out_dir}" + python3 "${case_dir}/golden.py" + ) + + log "[$case_name] step 6/6: run NPU validation" + local remote_run_cmd + remote_run_cmd=$(cat </dev/null 2>&1; fi && \ +LD_LIBRARY_PATH="${out_dir}:${SIM_LIB_DIR}:\$ASCEND_HOME_PATH/lib64:\${LD_LIBRARY_PATH:-}" "./${case_token}" +EOF +) + run_remote "${remote_run_cmd}" + + local remote_ldd_cmd + remote_ldd_cmd=$(cat </dev/null 2>&1; fi && \ +LD_LIBRARY_PATH="${out_dir}:${SIM_LIB_DIR}:\$ASCEND_HOME_PATH/lib64:\${LD_LIBRARY_PATH:-}" ldd "./${case_token}" | grep "lib${case_token}_kernel.so" +EOF +) + local ldd_output + ldd_output="$(run_remote "${remote_ldd_cmd}")" + [[ "${ldd_output}" == *"${repack_so}"* || "${ldd_output}" == *"lib${case_token}_kernel.so"* ]] || \ + die "${case_name} did not load expected kernel so: ${ldd_output}" + + ( + cd "${out_dir}" + COMPARE_STRICT=1 python3 "${case_dir}/compare.py" + ) + + log "[$case_name] compare passed" + log "[$case_name] output dir: ${out_dir}" +} + +build_one() { + local case_name="$1" + local case_token + case_token="$(printf '%s' "${case_name}" | sed 's#[/[:space:]]#_#g')" + local out_dir="${WORK_SPACE}/${case_token}" + local case_log="${out_dir}/validation.log" + + rm -rf "${out_dir}" + mkdir -p "${out_dir}" + + ( + build_one_impl "${case_name}" + ) 2>&1 | tee "${case_log}" +} + +log "=== VPTO Host Validation ===" +log "WORK_SPACE=${WORK_SPACE}" +log "ASCEND_HOME_PATH=${ASCEND_HOME_PATH}" +log "PTOAS_BIN=${PTOAS_BIN}" +log "PTOAS_FLAGS=${PTOAS_FLAGS}" +log "VPTO_FLAGS=${VPTO_FLAGS}" +log "COMPILE_ONLY=${COMPILE_ONLY}" +log "CASE_NAME=${CASE_NAME:-}" + +for case_name in "${CASES[@]}"; do + build_one "${case_name}" +done + +log "All ${#CASES[@]} VPTO case(s) passed" diff --git a/test/vpto/scripts/run_host_vpto_validation_parallel.sh b/test/vpto/scripts/run_host_vpto_validation_parallel.sh new file mode 100755 index 000000000..b9acb680d --- /dev/null +++ b/test/vpto/scripts/run_host_vpto_validation_parallel.sh @@ -0,0 +1,189 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +VPTO_ROOT="${VPTO_ROOT:-${ROOT_DIR}/test/vpto/cases}" +CASES_ROOT="${CASES_ROOT:-${VPTO_ROOT}}" +SERIAL_SCRIPT="${SCRIPT_DIR}/run_host_vpto_validation.sh" + +WORK_SPACE="${WORK_SPACE:-}" +CASE_NAME="${CASE_NAME:-}" +CASE_PREFIX="${CASE_PREFIX:-}" +JOBS="${JOBS:-}" + +log() { + echo "[$(date +'%F %T')] $*" +} + +die() { + echo "ERROR: $*" >&2 + exit 1 +} + +clean_tmp_inode_hotspots() { + local -a targets=( + /tmp/pto-microop-full + /tmp/pto-microop-full-redownload + ) + + log "tmp inode usage before cleanup" + df -ih /tmp + + for dir in "${targets[@]}"; do + if [[ -e "${dir}" ]]; then + log "remove ${dir}" + rm -rf "${dir}" + fi + done + + log "tmp inode usage after cleanup" + df -ih /tmp +} + +clean_tmp_inode_hotspots + +[[ -x "${SERIAL_SCRIPT}" ]] || die "missing serial validation script: ${SERIAL_SCRIPT}" +[[ -d "${CASES_ROOT}" ]] || die "missing cases root: ${CASES_ROOT}" +[[ -n "${WORK_SPACE}" ]] || die "WORK_SPACE is required" + +if [[ -z "${JOBS}" ]]; then + if command -v nproc >/dev/null 2>&1; then + JOBS="$(nproc)" + else + JOBS=1 + fi + if [[ "${JOBS}" -gt 1 ]]; then + JOBS="$((JOBS / 2))" + fi +fi + +[[ "${JOBS}" =~ ^[0-9]+$ ]] || die "JOBS must be a positive integer, got: ${JOBS}" +[[ "${JOBS}" -ge 1 ]] || die "JOBS must be >= 1" + +mkdir -p "${WORK_SPACE}" +WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" +SUMMARY_FILE="${WORK_SPACE}/parallel-summary.tsv" +RUNNER_LOG="${WORK_SPACE}/parallel-runner.log" + +discover_cases() { + local required_files=( + kernel.pto + stub.cpp + launch.cpp + main.cpp + golden.py + compare.py + ) + + if [[ -n "${CASE_NAME}" ]]; then + local requested_dir="${CASES_ROOT}/${CASE_NAME}" + [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" + for f in "${required_files[@]}"; do + [[ -f "${requested_dir}/${f}" ]] || die "case ${CASE_NAME} is missing ${f}" + done + printf "%s\n" "${CASE_NAME#/}" + return 0 + fi + + find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do + local ok=1 + for f in "${required_files[@]}"; do + if [[ ! -f "${dir}/${f}" ]]; then + ok=0 + break + fi + done + [[ "${ok}" -eq 1 ]] || continue + local rel="${dir#${CASES_ROOT}/}" + if [[ -n "${CASE_PREFIX}" && "${rel}" != "${CASE_PREFIX}"* ]]; then + continue + fi + printf "%s\n" "${rel}" + done +} + +readarray -t CASES < <(discover_cases) +[[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" + +: > "${SUMMARY_FILE}" +: > "${RUNNER_LOG}" + +declare -A PID_TO_CASE=() + +launch_case() { + local case_name="$1" + + log "[${case_name}] launch" | tee -a "${RUNNER_LOG}" + ( + CASE_NAME="${case_name}" "${SERIAL_SCRIPT}" + ) & + + local pid=$! + PID_TO_CASE["${pid}"]="${case_name}" +} + +reap_one() { + local pid="$1" + local case_name="${PID_TO_CASE[${pid}]}" + local result="FAIL" + local detail="1" + + if wait "${pid}"; then + result="PASS" + detail="0" + fi + + printf '%s\t%s\t%s\n' "${case_name}" "${result}" "${detail}" >> "${SUMMARY_FILE}" + log "[${case_name}] ${result} (${detail})" | tee -a "${RUNNER_LOG}" + unset 'PID_TO_CASE['"${pid}"']' +} + +log "=== VPTO Host Validation Parallel ===" | tee -a "${RUNNER_LOG}" +log "WORK_SPACE=${WORK_SPACE}" | tee -a "${RUNNER_LOG}" +log "CASE_NAME=${CASE_NAME:-}" | tee -a "${RUNNER_LOG}" +log "CASE_PREFIX=${CASE_PREFIX:-}" | tee -a "${RUNNER_LOG}" +log "JOBS=${JOBS}" | tee -a "${RUNNER_LOG}" +log "TOTAL_CASES=${#CASES[@]}" | tee -a "${RUNNER_LOG}" + +next_index=0 +while [[ "${next_index}" -lt "${#CASES[@]}" || "${#PID_TO_CASE[@]}" -gt 0 ]]; do + while [[ "${next_index}" -lt "${#CASES[@]}" && "${#PID_TO_CASE[@]}" -lt "${JOBS}" ]]; do + launch_case "${CASES[${next_index}]}" + next_index="$((next_index + 1))" + done + + if [[ "${#PID_TO_CASE[@]}" -eq 0 ]]; then + continue + fi + + while true; do + for pid in "${!PID_TO_CASE[@]}"; do + if ! kill -0 "${pid}" 2>/dev/null; then + reap_one "${pid}" + break 2 + fi + done + sleep 1 + done +done + +pass_count="$(awk -F '\t' '$2 == "PASS" {count++} END {print count + 0}' "${SUMMARY_FILE}")" +fail_count="$(awk -F '\t' '$2 != "PASS" {count++} END {print count + 0}' "${SUMMARY_FILE}")" + +log "PASS=${pass_count} FAIL=${fail_count}" | tee -a "${RUNNER_LOG}" +log "summary: ${SUMMARY_FILE}" | tee -a "${RUNNER_LOG}" + +if [[ "${fail_count}" -ne 0 ]]; then + die "parallel validation finished with ${fail_count} failing case(s)" +fi + +log "All ${pass_count} case(s) passed" | tee -a "${RUNNER_LOG}" From ed2b927e7343259685c66d541fab3f6fce735107 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Wed, 22 Apr 2026 09:18:03 +0800 Subject: [PATCH 117/192] feat: ignore local workspace --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 4fbeb8f5e..61f15f6b0 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,7 @@ test/samples/**/npu_validation/* .DS_Store .ipynb_checkpoints/ *.orig + +# Local workspace +.work/ +.local/ From 6b1321edc6d4fbd9ed61322e38b1edaef1661127 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Wed, 22 Apr 2026 09:51:45 +0800 Subject: [PATCH 118/192] feat: support tilelang dsl ci --- .github/workflows/ci.yml | 36 ++++++++-- test/tilelang_st/script/run_all_st.py | 96 +++++++++++++++++++++++---- 2 files changed, 114 insertions(+), 18 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 58e45958d..6002d9ad9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -101,8 +101,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 with: - repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} - ref: ${{ github.event.pull_request.head.sha || github.sha }} fetch-depth: 1 persist-credentials: false @@ -279,12 +277,11 @@ jobs: LLVM_DIR: ${{ github.workspace }}/llvm-project/llvm/build-shared MLIR_PYTHONPATH: ${{ github.workspace }}/llvm-project/llvm/build-shared/tools/mlir/python_packages/mlir_core VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci + TILELANG_DSL_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ci steps: - name: Checkout uses: actions/checkout@v4 with: - repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} - ref: ${{ github.event.pull_request.head.sha || github.sha }} fetch-depth: 1 persist-credentials: false @@ -293,7 +290,7 @@ jobs: run: | set -euo pipefail missing_tools=() - for tool in python3 git cmake ninja; do + for tool in python3 git cmake ninja make; do if ! command -v "${tool}" >/dev/null 2>&1; then missing_tools+=("${tool}") fi @@ -302,7 +299,7 @@ jobs: if [[ "${#missing_tools[@]}" -gt 0 ]]; then if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then sudo apt-get update - sudo apt-get install -y python3 python3-pip git cmake ninja-build + sudo apt-get install -y python3 python3-pip git cmake ninja-build make else echo "ERROR: missing required tools on self-hosted runner: ${missing_tools[*]}" >&2 echo "ERROR: automatic installation requires sudo + apt-get" >&2 @@ -341,6 +338,14 @@ jobs: echo "ready=false" >> "${GITHUB_OUTPUT}" fi + - name: Clean CI work dirs + shell: bash + run: | + set -euo pipefail + rm -rf "${GITHUB_WORKSPACE}/build" + rm -rf "${VPTO_SIM_WORKSPACE}" + rm -rf "${TILELANG_DSL_WORKSPACE}" + - name: Prepare LLVM source (no rebuild) if: steps.detect-llvm-build.outputs.ready != 'true' shell: bash @@ -429,6 +434,25 @@ jobs: JOBS="${JOBS:-32}" \ bash test/vpto/scripts/run_host_vpto_validation_parallel.sh + - name: Run TileLang DSL CI + shell: bash + run: | + set -euo pipefail + mkdir -p "${TILELANG_DSL_WORKSPACE}" + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ + PTOAS_BIN="${PTOAS_BIN}" \ + bash test/tilelang_st/script/run_ci.sh -r sim -v a5 --jobs 32 \ + 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/run_ci.log" + + - name: Upload TileLang DSL logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: tilelang-dsl-ci-${{ github.run_id }} + path: | + ${{ env.TILELANG_DSL_WORKSPACE }}/run_ci.log + if-no-files-found: warn + - name: Upload VPTO SIM logs if: always() uses: actions/upload-artifact@v4 diff --git a/test/tilelang_st/script/run_all_st.py b/test/tilelang_st/script/run_all_st.py index a939aa6bf..b24fc5659 100755 --- a/test/tilelang_st/script/run_all_st.py +++ b/test/tilelang_st/script/run_all_st.py @@ -10,7 +10,9 @@ """Batch runner for TileLang ST, suitable for CI/self-hosted runner usage.""" import argparse +import concurrent.futures import os +import subprocess import sys import traceback @@ -66,6 +68,10 @@ def parse_args(): "--list", action="store_true", help="List discovered testcases and exit.", ) + parser.add_argument( + "-j", "--jobs", type=int, default=1, + help="Number of testcases to run in parallel after the shared build (default: 1).", + ) return parser.parse_args() @@ -89,6 +95,28 @@ def resolve_selected_testcases(all_testcases, requested): return requested_set +def run_testcase_subprocess(script_path, run_mode, soc_version, ptoas_bin, testcase): + command = [ + sys.executable, + script_path, + "-r", run_mode, + "-v", soc_version, + "-t", testcase, + "-p", ptoas_bin, + "-w", + ] + env = os.environ.copy() + result = subprocess.run( + command, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + env=env, + ) + return testcase, result.returncode, result.stdout + + def main(): args = parse_args() @@ -99,6 +127,9 @@ def main(): file=sys.stderr, ) sys.exit(1) + if args.jobs < 1: + print("[ERROR] --jobs must be >= 1", file=sys.stderr) + sys.exit(1) script_path = os.path.abspath(__file__) tilelang_st_root = os.path.dirname(os.path.dirname(script_path)) @@ -142,6 +173,7 @@ def main(): print(f"[INFO] ptoas={ptoas_bin}") print(f"[INFO] target_dir={target_dir}") print(f"[INFO] selected_testcases={', '.join(selected_testcases)}") + print(f"[INFO] jobs={args.jobs}") original_dir = os.getcwd() failures = [] @@ -155,18 +187,58 @@ def main(): run_st.build_project(args.run_mode, default_soc_version, "all", ptoas_bin) total = len(selected_testcases) - for index, testcase in enumerate(selected_testcases, start=1): - print(f"[INFO] [{index}/{total}] running testcase: {testcase}") - try: - run_st.run_gen_data(testcase) - run_st.run_binary(testcase) - run_st.run_compare(testcase) - except Exception as exc: # pragma: no cover - CI-side aggregation path - failures.append((testcase, str(exc))) - print(f"[ERROR] testcase failed: {testcase}") - traceback.print_exc() - if args.fail_fast: - break + if args.jobs == 1: + for index, testcase in enumerate(selected_testcases, start=1): + print(f"[INFO] [{index}/{total}] running testcase: {testcase}") + try: + run_st.run_gen_data(testcase) + run_st.run_binary(testcase) + run_st.run_compare(testcase) + except Exception as exc: # pragma: no cover - CI-side aggregation path + failures.append((testcase, str(exc))) + print(f"[ERROR] testcase failed: {testcase}") + traceback.print_exc() + if args.fail_fast: + break + else: + print(f"[INFO] running testcases in parallel with jobs={args.jobs}") + max_workers = min(args.jobs, total) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_testcase = {} + for index, testcase in enumerate(selected_testcases, start=1): + print(f"[INFO] [{index}/{total}] queue testcase: {testcase}") + future = executor.submit( + run_testcase_subprocess, + script_path, + args.run_mode, + args.soc_version, + ptoas_bin, + testcase, + ) + future_to_testcase[future] = testcase + + for future in concurrent.futures.as_completed(future_to_testcase): + testcase = future_to_testcase[future] + try: + _, returncode, output = future.result() + except Exception as exc: # pragma: no cover - executor/host failure + failures.append((testcase, str(exc))) + print(f"[ERROR] testcase runner crashed: {testcase}") + traceback.print_exc() + if args.fail_fast: + break + continue + + print(f"[INFO] ===== testcase {testcase} output begin =====") + if output: + print(output, end="" if output.endswith("\n") else "\n") + print(f"[INFO] ===== testcase {testcase} output end =====") + + if returncode != 0: + failures.append((testcase, f"subprocess exited with {returncode}")) + print(f"[ERROR] testcase failed: {testcase}") + if args.fail_fast: + break except Exception as exc: print(f"[ERROR] batch run failed: {exc}", file=sys.stderr) From 0fa8bb775da7801d51978abeb93e7cc2d5bfb137 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Wed, 22 Apr 2026 14:48:13 +0800 Subject: [PATCH 119/192] remove the fork repo pr protects --- .github/workflows/ci.yml | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6002d9ad9..5cfad5a06 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -269,13 +269,10 @@ jobs: ${{ github.event_name == 'workflow_dispatch' || github.event_name == 'schedule' || - (github.event_name == 'pull_request' && - github.event.pull_request.head.repo.full_name == github.repository) + github.event_name == 'pull_request' }} env: LLVM_COMMIT: cd708029e0b2869e80abe31ddb175f7c35361f90 - LLVM_DIR: ${{ github.workspace }}/llvm-project/llvm/build-shared - MLIR_PYTHONPATH: ${{ github.workspace }}/llvm-project/llvm/build-shared/tools/mlir/python_packages/mlir_core VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci TILELANG_DSL_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ci steps: @@ -285,6 +282,16 @@ jobs: fetch-depth: 1 persist-credentials: false + - name: Resolve LLVM directories + shell: bash + env: + TOOL_CACHE: ${{ runner.tool_cache }} + run: | + set -euo pipefail + echo "LLVM_ROOT=${TOOL_CACHE}/llvm-project" >> "${GITHUB_ENV}" + echo "LLVM_DIR=${TOOL_CACHE}/llvm-project/llvm/build-shared" >> "${GITHUB_ENV}" + echo "MLIR_PYTHONPATH=${TOOL_CACHE}/llvm-project/llvm/build-shared/tools/mlir/python_packages/mlir_core" >> "${GITHUB_ENV}" + - name: Ensure runner dependencies shell: bash run: | @@ -351,8 +358,8 @@ jobs: shell: bash run: | set -euo pipefail - mkdir -p llvm-project - cd llvm-project + mkdir -p "${LLVM_ROOT}" + cd "${LLVM_ROOT}" if [ ! -d .git ]; then git init @@ -367,7 +374,7 @@ jobs: shell: bash run: | set -euo pipefail - cd llvm-project + cd "${LLVM_ROOT}" cmake -G Ninja -S llvm -B llvm/build-shared \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ -DBUILD_SHARED_LIBS=ON \ From f9c89fc608f9eaf04f02f204645bcae9690a011e Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 11:46:36 +0800 Subject: [PATCH 120/192] fix(vpto): normalize signed integer vector decls --- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 28 +++- test/basic/issue_173_vpto_llvm.pto | 38 +++++ .../issue-173-vsts-signed-signless/compare.py | 43 +++++ .../issue-173-vsts-signed-signless/golden.py | 57 +++++++ .../issue-173-vsts-signed-signless/kernel.pto | 81 ++++++++++ .../issue-173-vsts-signed-signless/launch.cpp | 55 +++++++ .../issue-173-vsts-signed-signless/main.cpp | 150 ++++++++++++++++++ .../issue-173-vsts-signed-signless/stub.cpp | 29 ++++ 8 files changed, 478 insertions(+), 3 deletions(-) create mode 100644 test/basic/issue_173_vpto_llvm.pto create mode 100644 test/vpto/cases/vpto/issue-173-vsts-signed-signless/compare.py create mode 100644 test/vpto/cases/vpto/issue-173-vsts-signed-signless/golden.py create mode 100644 test/vpto/cases/vpto/issue-173-vsts-signed-signless/kernel.pto create mode 100644 test/vpto/cases/vpto/issue-173-vsts-signed-signless/launch.cpp create mode 100644 test/vpto/cases/vpto/issue-173-vsts-signed-signless/main.cpp create mode 100644 test/vpto/cases/vpto/issue-173-vsts-signed-signless/stub.cpp diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 0bd5ae62d..e0619414a 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -48,9 +48,31 @@ static std::string getElementTypeFragment(Type type); static Type getElementTypeFromVectorLike(Type type); static std::optional getElementCountFromVectorLike(Type type); +static Type normalizeIntegerTypeForLLVMLowering(Type type, Builder &builder) { + if (auto intType = dyn_cast(type)) { + if (!intType.isSignless()) + return builder.getIntegerType(intType.getWidth()); + return type; + } + + if (auto vecType = dyn_cast(type)) { + Type normalizedElement = + normalizeIntegerTypeForLLVMLowering(vecType.getElementType(), builder); + if (normalizedElement == vecType.getElementType()) + return type; + return VectorType::get(vecType.getShape(), normalizedElement, + vecType.getScalableDims()); + } + + return type; +} + static Type convertVPTOType(Type type, Builder &builder) { - if (auto vecType = dyn_cast(type)) - return VectorType::get({vecType.getElementCount()}, vecType.getElementType()); + if (auto vecType = dyn_cast(type)) { + Type elementType = + normalizeIntegerTypeForLLVMLowering(vecType.getElementType(), builder); + return VectorType::get({vecType.getElementCount()}, elementType); + } if (isa(type)) return VectorType::get({256}, builder.getI1Type()); if (isa(type)) @@ -60,7 +82,7 @@ static Type convertVPTOType(Type type, Builder &builder) { builder.getContext(), static_cast(ptrType.getMemorySpace().getAddressSpace())); } - return type; + return normalizeIntegerTypeForLLVMLowering(type, builder); } static bool hasVPTOConvertibleType(Type type) { diff --git a/test/basic/issue_173_vpto_llvm.pto b/test/basic/issue_173_vpto_llvm.pto new file mode 100644 index 000000000..3fb291635 --- /dev/null +++ b/test/basic/issue_173_vpto_llvm.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --vpto-emit-hivm-llvm %s -o %t.ll +// RUN: FileCheck %s < %t.ll + +// Regression test for issue #173: signed and signless i16 vector stores should +// share the same LLVM/HIVM declaration after VPTO type conversion. +module attributes {pto.target_arch = "a5"} { + func.func @store_signed_i16(%value: !pto.vreg<128xsi16>, %dst: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + pto.vsts %value, %dst[%c0], %mask : !pto.vreg<128xsi16>, !pto.ptr, !pto.mask + } + return + } + + func.func @store_signless_i16(%value: !pto.vreg<128xi16>, %dst: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + pto.vsts %value, %dst[%c0], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-COUNT-1: declare void @llvm.hivm.vstsx1.v128s16(<128 x i16>, ptr addrspace(6), i32, i32, i32, <256 x i1>) +// CHECK-LABEL: define void @store_signed_i16( +// CHECK: call void @llvm.hivm.vstsx1.v128s16(<128 x i16> {{.*}}, ptr addrspace(6) {{.*}}, i32 0, i32 1, i32 0, <256 x i1> {{.*}}) +// CHECK-LABEL: define void @store_signless_i16( +// CHECK: call void @llvm.hivm.vstsx1.v128s16(<128 x i16> {{.*}}, ptr addrspace(6) {{.*}}, i32 0, i32 1, i32 0, <256 x i1> {{.*}}) diff --git a/test/vpto/cases/vpto/issue-173-vsts-signed-signless/compare.py b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/compare.py new file mode 100644 index 000000000..fb0e653e2 --- /dev/null +++ b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/compare.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: vpto/issue-173-vsts-signed-signless +# family: vpto +# target_ops: pto.vlds, pto.vsts +# scenarios: signed-i16, signless-i16, same-module, issue-173-regression + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + ok = compare_bin("golden_v4.bin", "v4.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vpto/issue-173-vsts-signed-signless/golden.py b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/golden.py new file mode 100644 index 000000000..53bda53f6 --- /dev/null +++ b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: vpto/issue-173-vsts-signed-signless +# family: vpto +# target_ops: pto.vlds, pto.vsts +# scenarios: signed-i16, signless-i16, same-module, issue-173-regression + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 173 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + signed = rng.integers(-32768, 32768, size=ELEMS, dtype=np.int16) + signless = rng.integers(-32768, 32768, size=ELEMS, dtype=np.int16) + + signed[:16] = np.array( + [-32768, -30000, -12345, -1, 0, 1, 2, 3, 7, 15, 127, 255, 1024, 12345, 30000, 32767], + dtype=np.int16, + ) + signless[:16] = np.array( + [32767, 30000, 12345, 1024, 255, 127, 15, 7, 3, 2, 1, 0, -1, -12345, -30000, -32768], + dtype=np.int16, + ) + + output_dir.mkdir(parents=True, exist_ok=True) + signed.tofile(output_dir / "v1.bin") + np.zeros(ELEMS, dtype=np.int16).tofile(output_dir / "v2.bin") + signless.tofile(output_dir / "v3.bin") + np.zeros(ELEMS, dtype=np.int16).tofile(output_dir / "v4.bin") + signed.tofile(output_dir / "golden_v2.bin") + signless.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vpto/issue-173-vsts-signed-signless/kernel.pto b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/kernel.pto new file mode 100644 index 000000000..e1227f2fc --- /dev/null +++ b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/kernel.pto @@ -0,0 +1,81 @@ +// ----------------------------------------------------------------------------- +// case: vpto/issue-173-vsts-signed-signless +// family: vpto +// target_ops: pto.vlds, pto.vsts +// scenarios: signed-i16, signless-i16, same-module, issue-173-regression +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @copy_signed_i16_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xsi16> + pto.vsts %vec, %ub_out[%offset], %mask : !pto.vreg<128xsi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @copy_signless_i16_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %vec, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vpto/issue-173-vsts-signed-signless/launch.cpp b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/launch.cpp new file mode 100644 index 000000000..3153aaff6 --- /dev/null +++ b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void copy_signed_i16_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2); +extern "C" __global__ [aicore] void copy_signless_i16_kernel( + __gm__ int16_t *v3, __gm__ int16_t *v4); + +void LaunchCopySignedI16Kernel(int16_t *v1, int16_t *v2, void *stream) { + copy_signed_i16_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2); +} + +void LaunchCopySignlessI16Kernel(int16_t *v3, int16_t *v4, void *stream) { + copy_signless_i16_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v3, + (__gm__ int16_t *)v4); +} diff --git a/test/vpto/cases/vpto/issue-173-vsts-signed-signless/main.cpp b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/main.cpp new file mode 100644 index 000000000..a7ae99cc8 --- /dev/null +++ b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/main.cpp @@ -0,0 +1,150 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: vpto/issue-173-vsts-signed-signless +// family: vpto +// target_ops: pto.vlds, pto.vsts +// scenarios: signed-i16, signless-i16, same-module, issue-173-regression +// ----------------------------------------------------------------------------- + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchCopySignedI16Kernel(int16_t *v1, int16_t *v2, void *stream); +void LaunchCopySignlessI16Kernel(int16_t *v3, int16_t *v4, void *stream); + +int main() { + constexpr size_t elemCount = 1024; + constexpr size_t fileSize = elemCount * sizeof(int16_t); + size_t inputFileSize = fileSize; + + int16_t *v1Host = nullptr; + int16_t *v2Host = nullptr; + int16_t *v3Host = nullptr; + int16_t *v4Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Device = nullptr; + int16_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize)); + + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + FILE_CHECK(ReadFile("./v1.bin", inputFileSize, v1Host, fileSize) && + inputFileSize == fileSize, + "./v1.bin"); + inputFileSize = fileSize; + FILE_CHECK(ReadFile("./v2.bin", inputFileSize, v2Host, fileSize) && + inputFileSize == fileSize, + "./v2.bin"); + inputFileSize = fileSize; + FILE_CHECK(ReadFile("./v3.bin", inputFileSize, v3Host, fileSize) && + inputFileSize == fileSize, + "./v3.bin"); + inputFileSize = fileSize; + FILE_CHECK(ReadFile("./v4.bin", inputFileSize, v4Host, fileSize) && + inputFileSize == fileSize, + "./v4.bin"); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSize, v1Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize, v2Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize, v3Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize, v4Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchCopySignedI16Kernel(v1Device, v2Device, stream); + LaunchCopySignlessI16Kernel(v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(v2Host, fileSize, v2Device, fileSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize, v4Device, fileSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + + FILE_CHECK(WriteFile("./v2.bin", v2Host, fileSize), "./v2.bin"); + FILE_CHECK(WriteFile("./v4.bin", v4Host, fileSize), "./v4.bin"); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vpto/issue-173-vsts-signed-signless/stub.cpp b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/stub.cpp new file mode 100644 index 000000000..105082f90 --- /dev/null +++ b/test/vpto/cases/vpto/issue-173-vsts-signed-signless/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void copy_signed_i16_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2) { + (void)v1; + (void)v2; +} + +extern "C" __global__ [aicore] void copy_signless_i16_kernel(__gm__ int16_t *v3, + __gm__ int16_t *v4) { + (void)v3; + (void)v4; +} From 3b3959de45cecd816ab984c980f8c525596b0d58 Mon Sep 17 00:00:00 2001 From: mly <978226558@qq.com> Date: Wed, 22 Apr 2026 17:06:15 +0800 Subject: [PATCH 121/192] [WIP] feat: new dma load/store op (#141) * feat: new dma load/store op * update testcases --------- Co-authored-by: mouliangyu --- docs/isa-legacy/02-dma-copy.md | 634 ++++++++++++++++++ docs/isa/01-pipeline-sync.md | 26 +- docs/isa/02-dma-copy.md | 397 ++++------- docs/vpto-spec.md | 32 +- include/PTO/IR/PTO.h | 18 + include/PTO/IR/VPTOOps.td | 86 +++ lib/PTO/IR/VPTO.cpp | 476 +++++++++++++ lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp | 91 ++- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 38 +- .../binary-vector/vadd-bf16/kernel.pto | 15 +- .../binary-vector/vadd-f16/kernel.pto | 15 +- .../vadd-f32-exceptional/kernel.pto | 15 +- .../vadd-i16-signed-overflow/kernel.pto | 15 +- .../binary-vector/vadd-i16-signed/kernel.pto | 15 +- .../vadd-i16-unsigned-overflow/kernel.pto | 15 +- .../vadd-i16-unsigned/kernel.pto | 15 +- .../binary-vector/vadd-tail/kernel.pto | 15 +- .../micro-op/binary-vector/vadd/kernel.pto | 15 +- .../vaddc-carry-boundary/kernel.pto | 20 +- .../micro-op/binary-vector/vaddc/kernel.pto | 21 +- .../binary-vector/vand-mask-edge/kernel.pto | 15 +- .../micro-op/binary-vector/vand/kernel.pto | 15 +- .../binary-vector/vdiv-f16/kernel.pto | 15 +- .../vdiv-f32-exceptional/kernel.pto | 15 +- .../binary-vector/vdiv-tail/kernel.pto | 15 +- .../micro-op/binary-vector/vdiv/kernel.pto | 16 +- .../binary-vector/vmax-tail/kernel.pto | 15 +- .../micro-op/binary-vector/vmax/kernel.pto | 16 +- .../binary-vector/vmin-bf16/kernel.pto | 15 +- .../binary-vector/vmin-f16/kernel.pto | 15 +- .../vmin-f32-exceptional/kernel.pto | 16 +- .../vmin-f32-exceptional/vmin/kernel.pto | 16 +- .../binary-vector/vmin-i16-signed/kernel.pto | 15 +- .../vmin-i16-unsigned/kernel.pto | 15 +- .../binary-vector/vmin-tail/kernel.pto | 15 +- .../micro-op/binary-vector/vmin/kernel.pto | 16 +- .../binary-vector/vmul-tail/kernel.pto | 15 +- .../micro-op/binary-vector/vmul/kernel.pto | 16 +- .../micro-op/binary-vector/vor-f16/kernel.pto | 15 +- .../binary-vector/vor-mask-edge/kernel.pto | 15 +- .../micro-op/binary-vector/vor/kernel.pto | 15 +- .../vshl-i32-unsigned/kernel.pto | 15 +- .../vshl-shift-boundary/kernel.pto | 15 +- .../micro-op/binary-vector/vshl/kernel.pto | 15 +- .../binary-vector/vshr-i16-signed/kernel.pto | 15 +- .../vshr-shift-boundary/kernel.pto | 15 +- .../micro-op/binary-vector/vshr/kernel.pto | 15 +- .../binary-vector/vsub-tail/kernel.pto | 15 +- .../micro-op/binary-vector/vsub/kernel.pto | 16 +- .../vsubc-borrow-boundary/kernel.pto | 20 +- .../micro-op/binary-vector/vsubc/kernel.pto | 20 +- .../binary-vector/vxor-mask-edge/kernel.pto | 15 +- .../micro-op/binary-vector/vxor/kernel.pto | 15 +- .../compare-select/vcmp-eq/kernel.pto | 16 +- .../vcmp-f32-exceptional/kernel.pto | 16 +- .../compare-select/vcmp-i16-signed/kernel.pto | 16 +- .../vcmp-i16-unsigned/kernel.pto | 16 +- .../compare-select/vcmp-lt/kernel.pto | 16 +- .../compare-select/vcmp-tail/kernel.pto | 16 +- .../vcmps-f32-exceptional/kernel.pto | 12 +- .../compare-select/vcmps-f32/kernel.pto | 12 +- .../vcmps-i16-signed/kernel.pto | 12 +- .../vcmps-i16-unsigned/kernel.pto | 12 +- .../compare-select/vcmps-tail/kernel.pto | 12 +- .../compare-select/vsel-i16/kernel.pto | 15 +- .../vsel-predicate-edge/kernel.pto | 16 +- .../compare-select/vsel-tail/kernel.pto | 19 +- .../micro-op/compare-select/vsel/kernel.pto | 15 +- .../compare-select/vselr-f16/kernel.pto | 15 +- .../compare-select/vselr-u8/kernel.pto | 15 +- .../micro-op/compare-select/vselr/kernel.pto | 15 +- .../conversion/vcvt-f16-special/kernel.pto | 10 +- .../vcvt-f16-to-f32-part-even/kernel.pto | 10 +- .../vcvt-f16-to-f32-part-odd/kernel.pto | 10 +- .../conversion/vcvt-f16-to-f32/kernel.pto | 10 +- .../conversion/vcvt-f32-special/kernel.pto | 10 +- .../vcvt-f32-to-f16-pk-b32/kernel.pto | 10 +- .../conversion/vcvt-f32-to-f16/kernel.pto | 10 +- .../vcvt-i32-to-i16-overflow/kernel.pto | 10 +- .../conversion/vcvt-tail-special/kernel.pto | 10 +- .../micro-op/conversion/vcvt-tail/kernel.pto | 10 +- .../conversion/vtrc-f16-rounding/kernel.pto | 10 +- .../conversion/vtrc-f32-rounding/kernel.pto | 20 +- .../conversion/vtrc-f32-special/kernel.pto | 10 +- .../vtrc-rounding-boundary/kernel.pto | 20 +- .../micro-op/dsa-sfu/vaxpy-f32/kernel.pto | 16 +- .../micro-op/dsa-sfu/vbitsort/kernel.pto | 17 +- .../cases/micro-op/dsa-sfu/vci/kernel.pto | 16 +- .../dsa-sfu/vexpdiff-boundary/kernel.pto | 16 +- .../dsa-sfu/vexpdiff-f16-part/kernel.pto | 15 +- .../micro-op/dsa-sfu/vexpdiff-f32/kernel.pto | 11 +- .../micro-op/dsa-sfu/vlrelu-f16/kernel.pto | 11 +- .../dsa-sfu/vlrelu-f32-exceptional/kernel.pto | 11 +- .../micro-op/dsa-sfu/vlrelu-f32/kernel.pto | 11 +- .../micro-op/dsa-sfu/vlrelu-tail/kernel.pto | 11 +- .../vmula-accumulator-boundary/kernel.pto | 10 +- .../cases/micro-op/dsa-sfu/vmula/kernel.pto | 10 +- .../cases/micro-op/dsa-sfu/vmull/kernel.pto | 10 +- .../micro-op/dsa-sfu/vprelu-f32/kernel.pto | 16 +- .../micro-op/dsa-sfu/vprelu-tail/kernel.pto | 16 +- .../vgather2-duplicate-index/kernel.pto | 15 +- .../gather-scatter/vgather2/kernel.pto | 15 +- .../vgather2_bc-sparse-mask/kernel.pto | 15 +- .../gather-scatter/vgather2_bc/kernel.pto | 15 +- .../vgatherb-block-boundary/kernel.pto | 15 +- .../gather-scatter/vgatherb/kernel.pto | 15 +- .../vscatter-out-of-order-index/kernel.pto | 20 +- .../gather-scatter/vscatter/kernel.pto | 20 +- .../materialization-predicate/pand/kernel.pto | 5 +- .../pdintlv_b16-nontrivial/kernel.pto | 5 +- .../pdintlv_b16/kernel.pto | 5 +- .../pdintlv_b32-nontrivial/kernel.pto | 5 +- .../pdintlv_b32/kernel.pto | 5 +- .../pdintlv_b8-nontrivial/kernel.pto | 5 +- .../pdintlv_b8/kernel.pto | 5 +- .../pge-tail-mask-boundary/kernel.pto | 5 +- .../pge-tail-mask/kernel.pto | 5 +- .../pintlv_b16-nontrivial/kernel.pto | 5 +- .../pintlv_b16/kernel.pto | 5 +- .../pintlv_b32-nontrivial/kernel.pto | 5 +- .../pintlv_b32/kernel.pto | 5 +- .../pintlv_b8-nontrivial/kernel.pto | 5 +- .../pintlv_b8/kernel.pto | 5 +- .../plt-tail-mask-boundary/kernel.pto | 5 +- .../plt-tail-mask/kernel.pto | 5 +- .../materialization-predicate/pnot/kernel.pto | 5 +- .../materialization-predicate/por/kernel.pto | 5 +- .../ppack-punpack-nontrivial/kernel.pto | 5 +- .../ppack-punpack/kernel.pto | 5 +- .../psel-tail-predicate/kernel.pto | 5 +- .../materialization-predicate/psel/kernel.pto | 5 +- .../pset-pattern-fragment/kernel.pto | 5 +- .../pset-pattern/kernel.pto | 5 +- .../materialization-predicate/pxor/kernel.pto | 5 +- .../vbr-f32/kernel.pto | 5 +- .../vbr-i32/kernel.pto | 5 +- .../vdup-lane/kernel.pto | 14 +- .../vdup-scalar-f16/kernel.pto | 5 +- .../vdup-scalar-i8/kernel.pto | 5 +- .../vdup-scalar/kernel.pto | 5 +- .../predicate-load-store/pldi-norm/kernel.pto | 11 +- .../predicate-load-store/plds-norm/kernel.pto | 11 +- .../psti-norm-pldi-ds/kernel.pto | 11 +- .../psti-pk-pldi-us/kernel.pto | 11 +- .../predicate-load-store/psti-pk/kernel.pto | 5 +- .../psts-norm-plds-ds/kernel.pto | 11 +- .../kernel.pto | 11 +- .../psts-pk-plds-us/kernel.pto | 11 +- .../pstu-init-align-outside-loop/kernel.pto | 5 +- .../pstu-state-advance-boundary/kernel.pto | 5 +- .../predicate-load-store/pstu/kernel.pto | 5 +- .../vintlv-vdintlv-lane-boundary/kernel.pto | 10 +- .../rearrangement/vintlv-vdintlv/kernel.pto | 10 +- .../rearrangement/vpack-higher/kernel.pto | 10 +- .../rearrangement/vpack-lower/kernel.pto | 10 +- .../vsqz-nontrivial-mask/kernel.pto | 14 +- .../micro-op/rearrangement/vsqz/kernel.pto | 10 +- .../rearrangement/vsunpack/kernel.pto | 10 +- .../vusqz-nontrivial-mask/kernel.pto | 14 +- .../micro-op/rearrangement/vusqz/kernel.pto | 14 +- .../rearrangement/vzunpack/kernel.pto | 10 +- .../micro-op/reduction/vcadd-tail/kernel.pto | 10 +- .../cases/micro-op/reduction/vcadd/kernel.pto | 10 +- .../micro-op/reduction/vcgadd-tail/kernel.pto | 10 +- .../micro-op/reduction/vcgadd/kernel.pto | 10 +- .../micro-op/reduction/vcgmax-tie/kernel.pto | 10 +- .../micro-op/reduction/vcgmax/kernel.pto | 10 +- .../micro-op/reduction/vcgmin-tie/kernel.pto | 10 +- .../micro-op/reduction/vcgmin/kernel.pto | 10 +- .../cases/micro-op/reduction/vcmax/kernel.pto | 10 +- .../cases/micro-op/reduction/vcmin/kernel.pto | 10 +- .../micro-op/reduction/vcpadd-tail/kernel.pto | 10 +- .../micro-op/reduction/vcpadd/kernel.pto | 10 +- .../load-store-scalar-ub/kernel.pto | 10 +- .../micro-op/unary-vector/vabs-f16/kernel.pto | 10 +- .../vabs-f32-exceptional/kernel.pto | 10 +- .../vabs-i16-signed-overflow-edge/kernel.pto | 10 +- .../unary-vector/vabs-i16-signed/kernel.pto | 10 +- .../unary-vector/vabs-i16-unsigned/kernel.pto | 10 +- .../vabs-loop-carried-vreg/kernel.pto | 10 +- .../unary-vector/vabs-tail/kernel.pto | 10 +- .../micro-op/unary-vector/vabs/kernel.pto | 10 +- .../micro-op/unary-vector/vexp-f16/kernel.pto | 10 +- .../vexp-f32-exceptional/kernel.pto | 10 +- .../vexp-f32-over-underflow/kernel.pto | 10 +- .../unary-vector/vexp-tail/kernel.pto | 10 +- .../micro-op/unary-vector/vexp/kernel.pto | 10 +- .../vln-domain-boundary/kernel.pto | 10 +- .../micro-op/unary-vector/vln/kernel.pto | 10 +- .../vneg-f32-exceptional/kernel.pto | 10 +- .../micro-op/unary-vector/vneg/kernel.pto | 10 +- .../micro-op/unary-vector/vnot/kernel.pto | 10 +- .../micro-op/unary-vector/vrelu/kernel.pto | 10 +- .../vsqrt-domain-boundary/kernel.pto | 10 +- .../micro-op/unary-vector/vsqrt/kernel.pto | 10 +- .../vaddcs-carry-boundary/kernel.pto | 21 +- .../micro-op/vec-scalar/vaddcs/kernel.pto | 21 +- .../micro-op/vec-scalar/vadds-bf16/kernel.pto | 11 +- .../micro-op/vec-scalar/vadds-f16/kernel.pto | 11 +- .../vadds-f32-exceptional/kernel.pto | 11 +- .../vadds-i16-signed-overflow/kernel.pto | 10 +- .../vec-scalar/vadds-i16-signed/kernel.pto | 10 +- .../vadds-i16-unsigned-overflow/kernel.pto | 11 +- .../vec-scalar/vadds-i16-unsigned/kernel.pto | 11 +- .../micro-op/vec-scalar/vadds-tail/kernel.pto | 11 +- .../micro-op/vec-scalar/vadds/kernel.pto | 11 +- .../micro-op/vec-scalar/vmaxs-tail/kernel.pto | 11 +- .../micro-op/vec-scalar/vmaxs/kernel.pto | 11 +- .../micro-op/vec-scalar/vmins-tail/kernel.pto | 11 +- .../micro-op/vec-scalar/vmins/kernel.pto | 11 +- .../micro-op/vec-scalar/vmuls-tail/kernel.pto | 11 +- .../micro-op/vec-scalar/vmuls/kernel.pto | 11 +- .../vshls-shift-boundary/kernel.pto | 11 +- .../micro-op/vec-scalar/vshls/kernel.pto | 11 +- .../vshrs-shift-boundary/kernel.pto | 11 +- .../micro-op/vec-scalar/vshrs/kernel.pto | 11 +- .../vsubcs-borrow-boundary/kernel.pto | 21 +- .../micro-op/vec-scalar/vsubcs/kernel.pto | 21 +- .../vldas-vldus-state-chain/kernel.pto | 10 +- .../vector-load-store/vldas-vldus/kernel.pto | 10 +- .../vlds-brc-b16-f32/kernel.pto | 10 +- .../vector-load-store/vlds-brc-b16/kernel.pto | 10 +- .../vector-load-store/vlds-brc-b32/kernel.pto | 10 +- .../vlds-brc-b8-f32/kernel.pto | 10 +- .../vector-load-store/vlds-brc-blk/kernel.pto | 10 +- .../vector-load-store/vlds-ds-b16/kernel.pto | 10 +- .../vector-load-store/vlds-tail/kernel.pto | 10 +- .../vlds-unpk-b16/kernel.pto | 10 +- .../vector-load-store/vlds-us-b16/kernel.pto | 10 +- .../vector-load-store/vlds/kernel.pto | 10 +- .../vldsx2-layout-check/kernel.pto | 17 +- .../vldsx2-vstsx2-b8-f32/kernel.pto | 17 +- .../vldsx2-vstsx2/kernel.pto | 17 +- .../vector-load-store/vsldb/kernel.pto | 10 +- .../vector-load-store/vsstb/kernel.pto | 10 +- .../vector-load-store/vstar/kernel.pto | 10 +- .../vstas-vstus-offset-update/kernel.pto | 14 +- .../vector-load-store/vsts-1pt-b16/kernel.pto | 10 +- .../vector-load-store/vsts-pk-b16/kernel.pto | 10 +- .../vsts-pk-b64-f32/kernel.pto | 16 +- .../vector-load-store/vsts-tail/kernel.pto | 17 +- .../vector-load-store/vsts/kernel.pto | 10 +- .../vstsx2-layout-check/kernel.pto | 17 +- .../vstur-init-align-outside-loop/kernel.pto | 10 +- .../vector-load-store/vstur/kernel.pto | 10 +- 245 files changed, 3088 insertions(+), 1475 deletions(-) create mode 100644 docs/isa-legacy/02-dma-copy.md diff --git a/docs/isa-legacy/02-dma-copy.md b/docs/isa-legacy/02-dma-copy.md new file mode 100644 index 000000000..6e2fdc4f3 --- /dev/null +++ b/docs/isa-legacy/02-dma-copy.md @@ -0,0 +1,634 @@ +# 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](01-pipeline-sync.md)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +## Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +## Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +## Pad Value Configuration + +### `pto.set_mov_pad_val` + +- **syntax:** `pto.set_mov_pad_val %value : T` +- **supported `T`:** `i8`, `i16`, `i32`, `f16`, `bf16`, `f32` +- **semantics:** Configure the pad fill value used by GM→UB DMA when `data_select_bit = true`. + +This op programs the hardware pad register consumed by `pto.copy_gm_to_ubuf`. The operand is a typed scalar. Its raw bit pattern is encoded into the underlying hardware configuration payload: + +- integer inputs use their zero-extended bit pattern +- floating-point inputs use their bitcast-to-integer bit pattern, then zero-extend to `i64` + +This configuration affects only the GM→UB padding path. UB→GM DMA ignores the pad value. + +**Parameter Table:** + +| Parameter | Description | +|-----------|-------------| +| `%value` | Pad fill scalar. Must be one of `i8/i16/i32/f16/bf16/f32`. | + +**Example:** + +```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 +``` + +--- + +## DMA Transfer Execution + +### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +## Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +## Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +## Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +## Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +## Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index c7b32b677..b3f350626 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -54,16 +54,16 @@ pipe_barrier(pipe); **Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` -**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: +**Example:** Two back-to-back `dma_store` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: ```mlir // Both stores target the same GM address — order matters! -pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +pto.dma_store %ub_partial_0, %gm_result, ... // Without pipe_barrier, MTE3 could execute the second copy before the first // completes, producing a non-deterministic result at %gm_result. pto.pipe_barrier "PIPE_MTE3" // After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. -pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +pto.dma_store %ub_partial_1, %gm_result, ... ``` --- @@ -160,7 +160,7 @@ For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slice ```mlir // set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices // MTE2 loads large tile once -pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.dma_load %gm_ptr, %ub_tile, ... pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop // Vector consumes in 8 slices — but wait_flag can only fire ONCE @@ -177,7 +177,7 @@ With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peel // get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine // MTE2 loads large tile pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 -pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.dma_load %gm_ptr, %ub_tile, ... pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 // Vector acquires/releases per slice — all 8 iterations work correctly @@ -252,7 +252,7 @@ Each cross-pipeline data dependency requires an explicit signal/wait pair. The p ```mlir // ─── Stage 1: MTE2 loads data from GM into UB ─── -pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +pto.dma_load %gm_ptr, %ub_ptr, ... // MTE2 signals: "UB data is ready for Vector pipe" pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -275,7 +275,7 @@ pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] // MTE3 waits until Vector's signal arrives pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] -pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +pto.dma_store %ub_out, %gm_out, ... ``` **Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). @@ -290,7 +290,7 @@ Instead of naming events, each pipeline declares when it **acquires** (`get_buf` // ─── Stage 1: MTE2 loads data into UB ─── // MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) -pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +pto.dma_load %gm_ptr, %ub_ptr, ... // MTE2 done writing ub_ptr — release it so Vector can consume pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 @@ -315,7 +315,7 @@ pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 // ─── Stage 3: MTE3 stores result to GM ─── // MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 -pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +pto.dma_store %ub_out, %gm_out, ... // MTE3 done reading ub_out — release so Vector can reuse it in next iteration pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 ``` @@ -368,7 +368,7 @@ scf.for %i = %c0 to %N step %c1 { // ── MTE2: load tile[i] into buf_in[i%2] ── // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] - pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + pto.dma_load %gm_ptr[%i], %ub_in[%pp], ... // RAW: signal Vector that buf_in[i%2] data is ready pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] @@ -391,7 +391,7 @@ scf.for %i = %c0 to %N step %c1 { // ── MTE3: store result from buf_out[i%2] to GM ── // RAW: wait for Vector to finish writing buf_out[i%2] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] - pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.dma_store %ub_out[%pp], %gm_out[%i], ... // WAR: tell Vector "done reading buf_out[i%2]" pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] } @@ -422,7 +422,7 @@ scf.for %i = %c0 to %N step %c1 { // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 - pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.dma_load %gm_ptr[%i], %ub_buf[%pp], ... pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // ── Vector: compute on buf[i%2] ── @@ -442,7 +442,7 @@ scf.for %i = %c0 to %N step %c1 { // ── MTE3: store result ── // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 - pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.dma_store %ub_out[%pp], %gm_out[%i], ... pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 } // No post-loop drain needed — last rls_buf completes the pipeline. diff --git a/docs/isa/02-dma-copy.md b/docs/isa/02-dma-copy.md index 6e2fdc4f3..e402d8117 100644 --- a/docs/isa/02-dma-copy.md +++ b/docs/isa/02-dma-copy.md @@ -5,215 +5,123 @@ DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](01-pipeline-sync.md)). -The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. +This document describes the public grouped DMA interfaces: ---- - -## Loop Stride Configuration (GM→UB) - -These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. - -### `pto.set_loop_size_outtoub` - -- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` -- **semantics:** Configure HW loop iteration counts for GM→UB DMA. - -**Parameter Table:** - -| Parameter | Width | Description | -|-----------|-------|-------------| -| `%loop1_count` | 21 bits | Inner HW loop iteration count | -| `%loop2_count` | 21 bits | Outer HW loop iteration count | - -When not using multi-level looping, set both to 1. - ---- - -### `pto.set_loop2_stride_outtoub` - -- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` -- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. - -**Parameter Table:** - -| Parameter | Width | Description | -|-----------|-------|-------------| -| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | -| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | - -After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. +- `pto.dma_load` +- `pto.dma_store` ---- - -### `pto.set_loop1_stride_outtoub` - -- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` -- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. - -**Parameter Table:** - -| Parameter | Width | Description | -|-----------|-------|-------------| -| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | -| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | - ---- - -## Loop Stride Configuration (UB→GM) - -These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. - -Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). - -### `pto.set_loop_size_ubtoout` - -- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` -- **semantics:** Configure HW loop iteration counts for UB→GM DMA. - -**Parameter Table:** - -| Parameter | Width | Description | -|-----------|-------|-------------| -| `%loop1_count` | 21 bits | Inner HW loop iteration count | -| `%loop2_count` | 21 bits | Outer HW loop iteration count | +The legacy low-level DMA configuration and raw copy interfaces are documented in +[02-dma-copy-legacy.md](02-dma-copy-legacy.md). --- -### `pto.set_loop2_stride_ubtoout` - -- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` -- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. - -**Parameter Table:** - -| Parameter | Width | Description | -|-----------|-------|-------------| -| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | -| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | - ---- +## DMA Transfer Execution -### `pto.set_loop1_stride_ubtoout` +### `pto.dma_load` -- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` -- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. +- **syntax:** +```mlir +pto.dma_load %gm_src, %ub_dst, %sid, %l2_cache_ctl, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + [loop1(%loop1_count, %loop1_src_stride, %loop1_dst_stride)] + [loop2(%loop2_count, %loop2_src_stride, %loop2_dst_stride)] + [pad(%pad_value[, %left_padding_count, %right_padding_count])] + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64, + [loop1 i64, i64, i64,] + [loop2 i64, i64, i64,] + [pad T[, i64, i64]] +``` +- **semantics:** Grouped GM→UB DMA transfer. It carries the burst, optional HW loop, and optional padding configuration on the copy op itself. **Parameter Table:** | Parameter | Width | Description | |-----------|-------|-------------| -| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | -| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | - ---- - -## Pad Value Configuration - -### `pto.set_mov_pad_val` - -- **syntax:** `pto.set_mov_pad_val %value : T` -- **supported `T`:** `i8`, `i16`, `i32`, `f16`, `bf16`, `f32` -- **semantics:** Configure the pad fill value used by GM→UB DMA when `data_select_bit = true`. - -This op programs the hardware pad register consumed by `pto.copy_gm_to_ubuf`. The operand is a typed scalar. Its raw bit pattern is encoded into the underlying hardware configuration payload: - -- integer inputs use their zero-extended bit pattern -- floating-point inputs use their bitcast-to-integer bit pattern, then zero-extend to `i64` - -This configuration affects only the GM→UB padding path. UB→GM DMA ignores the pad value. - -**Parameter Table:** - -| Parameter | Description | -|-----------|-------------| -| `%value` | Pad fill scalar. Must be one of `i8/i16/i32/f16/bf16/f32`. | +| `%gm_src` | ptr | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | 32 bits | Stream ID | +| `%l2_cache_ctl` | 2 bits | L2 cache allocate control | +| `%len_burst` | 16 bits | Contiguous bytes transferred per burst row | +| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 40 bits / 21 bits | Required innermost burst loop: count, GM source stride, UB destination stride | +| `loop1(%loop1_count, %loop1_src_stride, %loop1_dst_stride)` | 21 bits / 40 bits / 21 bits | Optional inner HW loop: count, GM source stride, UB destination stride | +| `loop2(%loop2_count, %loop2_src_stride, %loop2_dst_stride)` | 21 bits / 40 bits / 21 bits | Optional outer HW loop: count, GM source stride, UB destination stride | +| `pad(%pad_value[, %left_padding_count, %right_padding_count])` | scalar / 8 bits / 8 bits | Optional padding: fill value, optional left padding count, optional right padding count | + +**Constraints:** + +- `nburst(...)` is always required. +- `loop1(...)` and `loop2(...)` must each be provided as a complete group when present. +- `pad(...)` may contain only `%pad_value`; omitted left and right padding counts default to 0. +- If either left or right padding count is provided, both counts must be provided. +- `loop1(...)` may be used without `loop2(...)`; in that case `loop2_count` is treated as 1 when programming the loop-size register. +- `loop2(...)` requires `loop1(...)`; `loop2` without `loop1` is rejected by the verifier. +- `pad(...)` is independent of `loop1(...)` and `loop2(...)`. +- A DMA load may use `nburst(...) pad(...)` without any HW loop group. **Example:** ```mlir -%pad = arith.constant 0 : i16 -pto.set_mov_pad_val %pad : i16 +pto.dma_load %gm_in, %ub_out, %sid, %cache, %len_burst + nburst(%rows, %gm_row_stride, %ub_row_stride) + loop1(%tiles, %gm_tile_stride, %ub_tile_stride) + pad(%pad) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64, loop1 i64, i64, i64, pad f16 ``` --- -## DMA Transfer Execution - -### `pto.copy_gm_to_ubuf` - -- **syntax:** -```mlir -pto.copy_gm_to_ubuf %gm_src, %ub_dst, - %sid, %n_burst, %len_burst, %left_padding, %right_padding, - %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i1, i64, i64, i64 -``` -- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). - -**Parameters:** - -| Parameter | Description | -|-----------|-------------| -| `%gm_src` | GM source pointer (`!pto.ptr`) | -| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | -| `%sid` | Stream ID (usually 0) | -| `%n_burst` | Number of burst rows (innermost loop count) | -| `%len_burst` | Contiguous bytes transferred per burst row | -| `%left_padding` | Left padding count (bytes) | -| `%right_padding` | Right padding count (bytes) | -| `%data_select_bit` | Padding / data-select control bit (`i1`) | -| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | -| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | -| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | - ---- - -### `pto.copy_ubuf_to_gm` +### `pto.dma_store` - **syntax:** ```mlir -pto.copy_ubuf_to_gm %ub_src, %gm_dst, - %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +pto.dma_store %ub_src, %gm_dst, %sid, %reserved, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + [loop1(%loop1_count, %loop1_src_stride, %loop1_dst_stride)] + [loop2(%loop2_count, %loop2_src_stride, %loop2_dst_stride)] + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64, + [loop1 i64, i64, i64,] + [loop2 i64, i64, i64] ``` -- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). +- **semantics:** Grouped UB→GM DMA transfer. It carries the burst and optional HW loop configuration on the copy op itself. -**Parameters:** - -| Parameter | Description | -|-----------|-------------| -| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | -| `%gm_dst` | GM destination pointer (`!pto.ptr`) | -| `%sid` | Stream ID (usually 0) | -| `%n_burst` | Number of burst rows | -| `%len_burst` | Contiguous bytes transferred per burst row | -| `%reserved` | Reserved field (set to 0) | -| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | -| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | +**Parameter Table:** ---- +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | ptr | GM destination pointer (`!pto.ptr`) | +| `%sid` | 32 bits | Stream ID | +| `%reserved` | 8 bits | Reserved field, normally 0 | +| `%len_burst` | 16 bits | Contiguous bytes transferred per burst row | +| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 21 bits / 40 bits | Required innermost burst loop: count, UB source stride, GM destination stride | +| `loop1(%loop1_count, %loop1_src_stride, %loop1_dst_stride)` | 21 bits / 21 bits / 40 bits | Optional inner HW loop: count, UB source stride, GM destination stride | +| `loop2(%loop2_count, %loop2_src_stride, %loop2_dst_stride)` | 21 bits / 21 bits / 40 bits | Optional outer HW loop: count, UB source stride, GM destination stride | + +**Constraints:** + +- `nburst(...)` is always required. +- `loop1(...)` and `loop2(...)` must each be provided as a complete group when present. +- `loop1(...)` may be used without `loop2(...)`; in that case `loop2_count` is treated as 1 when programming the loop-size register. +- `loop2(...)` requires `loop1(...)`; `loop2` without `loop1` is rejected by the verifier. -### `pto.copy_ubuf_to_ubuf` +**Example:** -- **syntax:** ```mlir -pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride - : !pto.ptr, !pto.ptr, i64 x5 +pto.dma_store %ub_in, %gm_out, %sid, %zero, %len_burst + nburst(%rows, %ub_row_stride, %gm_row_stride) + loop1(%tiles, %ub_tile_stride, %gm_tile_stride) + loop2(%batches, %ub_batch_stride, %gm_batch_stride) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64, loop1 i64, i64, i64, loop2 i64, i64, i64 ``` -- **semantics:** Copy within Unified Buffer. -**Parameters:** +--- -| Parameter | Description | -|-----------|-------------| -| `%source` | UB source pointer | -| `%dest` | UB destination pointer | -| `%sid` | Stream ID | -| `%n_burst` | Number of bursts | -| `%len_burst` | Length per burst | -| `%src_stride` | Source stride | -| `%dst_stride` | Destination stride | +For the legacy low-level DMA copy family, see +[02-dma-copy-legacy.md](02-dma-copy-legacy.md). --- @@ -232,10 +140,10 @@ pad = ub_stride - lenBurst, padded to the 32B alignment boundary ### Alignment Constraints - **UB addresses** (both source and destination) must be **32-byte aligned**. -- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **GM→UB padding**: When `pad(...)` is present on `pto.dma_load`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val`. This ensures every UB row starts at a 32B-aligned offset. - **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. -### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) +### 2D Diagram: GM→UB (`pto.dma_load`) ``` GM (source, `!pto.ptr`): @@ -260,12 +168,12 @@ Row N-1: [##DATA########][000000 PAD 000000000000000] N = n_burst stride = start of row[r] to start of row[r+1] -pad = filled with pad_val to 32B boundary (data_select_bit=true) +pad = filled with pad_val to 32B boundary (`pad(...)` present) [DATA] = valid data transferred by DMA -[PAD] = pad_val fill (set via set_mov_pad_val) +[PAD] = pad_val fill (from `pad(...)`) ``` -### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) +### 2D Diagram: UB→GM (`pto.dma_store`) ``` UB (source, `!pto.ptr`, 32B-aligned start addr): @@ -297,7 +205,8 @@ Only len_burst bytes are written to each GM row. ## Multi-Level Loop Semantics (C Code) -The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. +The full DMA transfer is a nested loop. `loop1(...)` / `loop2(...)` control the +outer levels, and `nburst(...)` controls the innermost burst level. ### GM→UB Full Loop @@ -315,7 +224,7 @@ for (int j = 0; j < loop2_count; j++) { // HW outer loop memcpy(ub2 + r * dst_stride, // UB dest row gm2 + r * src_stride, // GM src row len_burst); // contiguous bytes - if (data_select_bit) + if (pad_enabled) memset(ub2 + r * dst_stride + len_burst, pad_val, dst_stride - len_burst); } @@ -374,21 +283,11 @@ UB layout (32 × 32 f32, 32B-aligned, contiguous): ``` ```mlir -// Simple 2D load — no multi-level loops needed -pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - -pto.copy_gm_to_ubuf %arg0, %ub_in, - %c0_i64, // sid = 0 - %c32_i64, // n_burst = 32 (32 rows) - %c128_i64, // len_burst = 128 bytes per row - %c0_i64, // left_padding = 0 - %c0_i64, // right_padding = 0 - %false, // data_select_bit = false - %c0_i64, // l2_cache_ctl = 0 - %c128_i64, // src_stride = 128 bytes - %c128_i64 // dst_stride = 128 bytes - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i1, i64, i64, i64 +// Simple 2D load — only nburst(...) is needed +pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64 ``` --- @@ -424,23 +323,10 @@ UB layout (64 × 128 f16, 32B-aligned, contiguous): ``` ```mlir -// Simple 2D load — no multi-level loops needed -pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 -pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 -pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 - -pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, - %c0_i64, // sid = 0 - %c64_i64, // n_burst = 64 (64 rows) - %c256_i64, // len_burst = 256 bytes per row - %c0_i64, // left_padding = 0 - %c0_i64, // right_padding = 0 - %false, // data_select_bit = false - %c0_i64, // l2_cache_ctl = 0 - %c1024_i64, // src_stride = 1024 bytes (full matrix row) - %c256_i64 // dst_stride = 256 bytes (tile row) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i1, i64, i64, i64 +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c1024_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64 ``` --- @@ -476,23 +362,11 @@ UB (128 cols wide, 32B-aligned, padded): ```mlir %pad = arith.constant 0 : i16 -pto.set_mov_pad_val %pad : i16 -pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 -pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 -pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 - -pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, - %c0_i64, // sid = 0 - %c64_i64, // n_burst = 64 - %c200_i64, // len_burst = 200 bytes - %c0_i64, // left_padding = 0 - %c0_i64, // right_padding = 0 - %true, // data_select_bit = true (enable padding) - %c0_i64, // l2_cache_ctl = 0 - %c200_i64, // src_stride = 200 bytes - %c256_i64 // dst_stride = 256 bytes (32B-aligned) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i1, i64, i64, i64 +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c0_i64, %c200_i64 + nburst(%c64_i64, %c200_i64, %c256_i64) + pad(%pad, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64, pad i16, i64, i64 ``` --- @@ -524,17 +398,10 @@ GM (dest, 32 × 32 f32): ``` ```mlir -// Configure MTE3 strides -pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - -pto.copy_ubuf_to_gm %ub_out, %arg1, - %c0_i64, // sid = 0 - %c32_i64, // n_burst = 32 - %c128_i64, // len_burst = 128 bytes - %c0_i64, // reserved = 0 - %c128_i64, // dst_stride = 128 bytes - %c128_i64 // src_stride = 128 bytes - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64 ``` --- @@ -570,19 +437,10 @@ GM (dest, into 1024 × 512 matrix): ``` ```mlir -// Configure MTE3 strides -pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 -pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 -pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 - -pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, - %c0_i64, // sid = 0 - %c64_i64, // n_burst = 64 - %c256_i64, // len_burst = 256 bytes - %c0_i64, // reserved = 0 - %c1024_i64, // dst_stride = 1024 bytes (GM row) - %c256_i64 // src_stride = 256 bytes (UB row) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +pto.dma_store %ub_ptr, %gm_ptr, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64 ``` --- @@ -603,25 +461,12 @@ GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): ``` ```mlir -// loop1_count = 4 batches, loop2_count = 1 (not used) -pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 - -// loop1 stride: advance by one batch (2048 bytes) in both GM and UB -pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 -pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 - -pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, - %c0_i64, // sid = 0 - %c8_i64, // n_burst = 8 rows per batch - %c256_i64, // len_burst = 256 bytes per row - %c0_i64, // left_padding = 0 - %c0_i64, // right_padding = 0 - %false, // data_select_bit = false - %c0_i64, // l2_cache_ctl = 0 - %c256_i64, // src_stride = 256 (contiguous rows) - %c256_i64 // dst_stride = 256 (contiguous rows) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i1, i64, i64, i64 +// loop1_count = 4 batches, loop2 omitted +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c0_i64, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + loop1(%c4_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64, loop1 i64, i64, i64 ``` Execution trace: diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 70bc42c55..553c2133e 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -141,11 +141,11 @@ The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer └─────────────────────────────────────────────┘ ``` -1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +1. **GM → UB**: DMA transfer via MTE2 (`pto.dma_load`) 2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) 3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) 4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) -5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) +5. **UB → GM**: DMA transfer via MTE3 (`pto.dma_store`) **Load/Store Access Patterns**: @@ -246,14 +246,10 @@ pto.strict_vecscope(%ub, %ub_out, %lane) { ### Example: VecScope ```mlir -%pad = arith.constant 0 : i32 -pto.set_mov_pad_val %pad : i32 -pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 -pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 -pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 -pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, - %false, %c0_i64, %c128_i64, %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 +pto.dma_load %7, %2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -269,11 +265,10 @@ pto.vecscope { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] -pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 -pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 -pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 -pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +pto.dma_store %8, %14, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i64 ``` ### Example: Strict VecScope @@ -887,7 +882,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | # | Group | Description | Count | Details | |---|-------|-------------|-------|---------| | 1 | [Pipeline Sync](isa/01-pipeline-sync.md) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | -| 2 | [DMA Copy Programming](isa/02-dma-copy.md) | DMA configuration and transfer between GM↔UB | 10 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.set_mov_pad_val`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 2 | [DMA Copy Programming](isa/02-dma-copy.md) | Public DMA transfer interface between GM↔UB | 2 | `pto.dma_load`, `pto.dma_store` | | 3 | [Vector Load/Store](isa/03-vector-load-store.md) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | | 4 | [Predicate Load/Store](isa/04-predicate-load-store.md) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | | 5 | [Materialization & Predicate Ops](isa/05-materialization-predicate.md) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | @@ -910,9 +905,8 @@ This section provides a categorized overview of all PTO micro Instruction operat | Operation | Group | Description | |-----------|-------|-------------| -| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | -| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | -| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| GM→UB DMA | 2 | `pto.dma_load` | +| UB→GM DMA | 2 | `pto.dma_store` | | Contiguous Load | 3 | `pto.vlds` with `NORM` dist | | Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | | Gather | 3 | `pto.vgather2`, `pto.vgatherb` | diff --git a/include/PTO/IR/PTO.h b/include/PTO/IR/PTO.h index b8c6f90b2..2ccafc409 100644 --- a/include/PTO/IR/PTO.h +++ b/include/PTO/IR/PTO.h @@ -66,6 +66,24 @@ // PTO Dialect Operations //===----------------------------------------------------------------------===// +namespace mlir { +namespace pto { + +struct DmaLoopConfig { + Value count; + Value srcStride; + Value dstStride; +}; + +struct DmaPadConfig { + Value value; + Value leftCount; + Value rightCount; +}; + +} // namespace pto +} // namespace mlir + #define GET_OP_CLASSES #include "PTO/IR/PTOOps.h.inc" diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 324c45c69..a5a1cf3a6 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -232,6 +232,51 @@ def PTO_CopyGmToUbufOp : PTO_Op<"copy_gm_to_ubuf", [ }]; } +def PTO_DmaLoadOp : PTO_Op<"dma_load", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$l2_cache_ctl, + I64:$len_burst, + I64:$n_burst, + I64:$nburst_src_stride, + I64:$nburst_dst_stride, + Optional:$loop1_count, + Optional:$loop1_src_stride, + Optional:$loop1_dst_stride, + Optional:$loop2_count, + Optional:$loop2_src_stride, + Optional:$loop2_dst_stride, + Optional>:$pad_value, + Optional:$left_padding_count, + Optional:$right_padding_count + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$sid, + "::mlir::Value":$l2CacheCtl, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop1, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop2, + "::std::optional<::mlir::pto::DmaPadConfig>":$pad + )> + ]; + +} + def PTO_CopyUbufToUbufOp : PTO_Op<"copy_ubuf_to_ubuf"> { let arguments = (ins PTO_BufferType:$source, @@ -1178,6 +1223,47 @@ def PTO_CopyUbufToGmOp : PTO_Op<"copy_ubuf_to_gm", [ }]; } +def PTO_DmaStoreOp : PTO_Op<"dma_store", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$reserved, + I64:$len_burst, + I64:$n_burst, + I64:$nburst_src_stride, + I64:$nburst_dst_stride, + Optional:$loop1_count, + Optional:$loop1_src_stride, + Optional:$loop1_dst_stride, + Optional:$loop2_count, + Optional:$loop2_src_stride, + Optional:$loop2_dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$sid, + "::mlir::Value":$reserved, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop1, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop2 + )> + ]; + +} + // NOTE: Unvalidated new x2 / pair / align-store-family abstractions. Added to // reflect CCE builtin families but not yet end-to-end validated. def PTO_VselrOp : PTO_Op<"vselr", [Pure]> { diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index c8f2ad7c3..a2691ada2 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -947,6 +947,137 @@ static int64_t getPtrElementByteSize(Type type) { return 0; } +static bool hasAll(Value first, Value second, Value third) { + return static_cast(first) && static_cast(second) && + static_cast(third); +} + +static bool hasAny(Value first, Value second, Value third) { + return static_cast(first) || static_cast(second) || + static_cast(third); +} + +static ParseResult parseRequiredOperandWithComma( + OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand) { + if (parser.parseOperand(operand)) + return failure(); + return parser.parseComma(); +} + +static ParseResult parseDmaTripleGroup( + OpAsmParser &parser, StringRef keyword, + SmallVectorImpl &operands) { + if (parser.parseKeyword(keyword) || parser.parseLParen()) + return failure(); + for (int i = 0; i < 3; ++i) { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand)) + return failure(); + operands.push_back(operand); + if (i != 2 && parser.parseComma()) + return failure(); + } + return parser.parseRParen(); +} + +static ParseResult parseOptionalDmaTripleGroup( + OpAsmParser &parser, StringRef keyword, + SmallVectorImpl &operands) { + if (failed(parser.parseOptionalKeyword(keyword))) + return success(); + if (parser.parseLParen()) + return failure(); + for (int i = 0; i < 3; ++i) { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand)) + return failure(); + operands.push_back(operand); + if (i != 2 && parser.parseComma()) + return failure(); + } + return parser.parseRParen(); +} + +static ParseResult parseOptionalDmaPadGroup( + OpAsmParser &parser, + SmallVectorImpl &operands) { + if (failed(parser.parseOptionalKeyword("pad"))) + return success(); + if (parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand value; + if (parser.parseOperand(value)) + return failure(); + operands.push_back(value); + if (succeeded(parser.parseOptionalComma())) { + OpAsmParser::UnresolvedOperand left; + OpAsmParser::UnresolvedOperand right; + if (parser.parseOperand(left) || parser.parseComma() || + parser.parseOperand(right)) + return failure(); + operands.push_back(left); + operands.push_back(right); + } + return parser.parseRParen(); +} + +static ParseResult parseDmaTripleTypes(OpAsmParser &parser, + SmallVectorImpl &types) { + for (int i = 0; i < 3; ++i) { + Type type; + if (parser.parseType(type)) + return failure(); + types.push_back(type); + if (i != 2 && parser.parseComma()) + return failure(); + } + return success(); +} + +static ParseResult parseDmaPadTypes(OpAsmParser &parser, + SmallVectorImpl &types) { + Type valueType; + if (parser.parseType(valueType)) + return failure(); + types.push_back(valueType); + if (succeeded(parser.parseOptionalComma())) { + Type leftType; + Type rightType; + if (parser.parseType(leftType) || parser.parseComma() || + parser.parseType(rightType)) + return failure(); + types.push_back(leftType); + types.push_back(rightType); + } + return success(); +} + +static void printDmaTripleGroup(OpAsmPrinter &printer, StringRef keyword, + Value first, Value second, Value third) { + printer << " " << keyword << "(" << first << ", " << second << ", " << third + << ")"; +} + +static void printDmaTripleTypes(OpAsmPrinter &printer, StringRef keyword, + Type first, Type second, Type third) { + printer << ", " << keyword << " " << first << ", " << second << ", " << third; +} + +static void printDmaPadGroup(OpAsmPrinter &printer, Value value, Value left, + Value right) { + printer << " pad(" << value; + if (left || right) + printer << ", " << left << ", " << right; + printer << ")"; +} + +static void printDmaPadTypes(OpAsmPrinter &printer, Type valueType, + Type leftType, Type rightType) { + printer << ", pad " << valueType; + if (leftType || rightType) + printer << ", " << leftType << ", " << rightType; +} + template static LogicalResult verifyCopyGmToUbufOp(CopyOp op, bool expectSourceGM) { auto sourceType = dyn_cast(op.getSource().getType()); @@ -982,6 +1113,38 @@ static LogicalResult verifyCopyGmToUbufOp(CopyOp op, bool expectSourceGM) { return success(); } +template +static LogicalResult verifyOptionalDmaLoopGroup(DmaOp op, Value count, + Value srcStride, + Value dstStride, + StringRef name) { + if (hasAny(count, srcStride, dstStride) && !hasAll(count, srcStride, dstStride)) + return op.emitOpError() << "requires " << name + << " group to provide count, src stride, and dst stride together"; + return success(); +} + +static LogicalResult verifyDmaLoadStoreLoopGroups(Operation *op, Value loop1Count, + Value loop1SrcStride, + Value loop1DstStride, + Value loop2Count, + Value loop2SrcStride, + Value loop2DstStride) { + auto emitError = [&]() { return op->emitOpError(); }; + if (hasAny(loop1Count, loop1SrcStride, loop1DstStride) && + !hasAll(loop1Count, loop1SrcStride, loop1DstStride)) + return emitError() + << "requires loop1 group to provide count, src stride, and dst stride together"; + if (hasAny(loop2Count, loop2SrcStride, loop2DstStride) && + !hasAll(loop2Count, loop2SrcStride, loop2DstStride)) + return emitError() + << "requires loop2 group to provide count, src stride, and dst stride together"; + if (hasAll(loop2Count, loop2SrcStride, loop2DstStride) && + !hasAll(loop1Count, loop1SrcStride, loop1DstStride)) + return emitError() << "requires loop1 when loop2 is present"; + return success(); +} + template static LogicalResult verifyCopyUbufToGmOp(CopyOp op, bool expectSourceGM) { auto sourceType = dyn_cast(op.getSource().getType()); @@ -1139,6 +1302,185 @@ LogicalResult CopyGmToUbufOp::verify() { return verifyCopyGmToUbufOp(*this, true); } +void DmaLoadOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value sid, Value l2CacheCtl, + Value lenBurst, pto::DmaLoopConfig nburst, + std::optional loop1, + std::optional loop2, + std::optional pad) { + state.addOperands({source, destination, sid, l2CacheCtl, lenBurst, + nburst.count, nburst.srcStride, nburst.dstStride}); + if (loop1) + state.addOperands({loop1->count, loop1->srcStride, loop1->dstStride}); + if (loop2) + state.addOperands({loop2->count, loop2->srcStride, loop2->dstStride}); + bool hasPadCounts = pad && pad->leftCount && pad->rightCount; + assert((!pad || static_cast(pad->leftCount) == + static_cast(pad->rightCount)) && + "dma_load pad config must provide both left and right counts, or omit both"); + if (pad) { + state.addOperands(pad->value); + if (hasPadCounts) + state.addOperands({pad->leftCount, pad->rightCount}); + } + + state.addAttribute( + getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, 1, 1, 1, 1, 1, + loop1 ? 1 : 0, loop1 ? 1 : 0, loop1 ? 1 : 0, + loop2 ? 1 : 0, loop2 ? 1 : 0, loop2 ? 1 : 0, + pad ? 1 : 0, hasPadCounts ? 1 : 0, hasPadCounts ? 1 : 0})); +} + +ParseResult DmaLoadOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination, sid, l2CacheCtl, lenBurst; + SmallVector nburstOperands; + SmallVector loop1Operands; + SmallVector loop2Operands; + SmallVector padOperands; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parseRequiredOperandWithComma(parser, sid) || + parseRequiredOperandWithComma(parser, l2CacheCtl) || + parser.parseOperand(lenBurst) || + parseDmaTripleGroup(parser, "nburst", nburstOperands) || + parseOptionalDmaTripleGroup(parser, "loop1", loop1Operands) || + parseOptionalDmaTripleGroup(parser, "loop2", loop2Operands) || + parseOptionalDmaPadGroup(parser, padOperands)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, sidType, l2CacheCtlType, lenBurstType; + SmallVector nburstTypes, loop1Types, loop2Types, padTypes; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(sidType) || parser.parseComma() || + parser.parseType(l2CacheCtlType) || parser.parseComma() || + parser.parseType(lenBurstType) || parser.parseComma() || + parseDmaTripleTypes(parser, nburstTypes)) + return failure(); + while (succeeded(parser.parseOptionalComma())) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + if (keyword == "loop1") { + if (!loop1Types.empty() || parseDmaTripleTypes(parser, loop1Types)) + return failure(); + continue; + } + if (keyword == "loop2") { + if (!loop2Types.empty() || parseDmaTripleTypes(parser, loop2Types)) + return failure(); + continue; + } + if (keyword == "pad") { + if (!padTypes.empty() || parseDmaPadTypes(parser, padTypes)) + return failure(); + continue; + } + return parser.emitError(parser.getCurrentLocation(), + "expected one of 'loop1', 'loop2', or 'pad'"); + } + + auto &segments = + result.getOrAddProperties().operandSegmentSizes; + llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, 1, 1, + static_cast(loop1Operands.size() ? 1 : 0), + static_cast(loop1Operands.size() ? 1 : 0), + static_cast(loop1Operands.size() ? 1 : 0), + static_cast(loop2Operands.size() ? 1 : 0), + static_cast(loop2Operands.size() ? 1 : 0), + static_cast(loop2Operands.size() ? 1 : 0), + static_cast(padOperands.size() ? 1 : 0), + static_cast(padOperands.size() == 3 ? 1 : 0), + static_cast(padOperands.size() == 3 ? 1 : 0)}, + segments.begin()); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(sid, sidType, result.operands) || + parser.resolveOperand(l2CacheCtl, l2CacheCtlType, result.operands) || + parser.resolveOperand(lenBurst, lenBurstType, result.operands) || + parser.resolveOperands(nburstOperands, nburstTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(loop1Operands, loop1Types, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(loop2Operands, loop2Types, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(padOperands, padTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + return success(); +} + +void DmaLoadOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " << getSid() + << ", " << getL2CacheCtl() << ", " << getLenBurst(); + printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcStride(), + getNburstDstStride()); + if (hasAll(getLoop1Count(), getLoop1SrcStride(), getLoop1DstStride())) + printDmaTripleGroup(printer, "loop1", getLoop1Count(), getLoop1SrcStride(), + getLoop1DstStride()); + if (hasAll(getLoop2Count(), getLoop2SrcStride(), getLoop2DstStride())) + printDmaTripleGroup(printer, "loop2", getLoop2Count(), getLoop2SrcStride(), + getLoop2DstStride()); + if (getPadValue()) + printDmaPadGroup(printer, getPadValue(), getLeftPaddingCount(), + getRightPaddingCount()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getSid().getType() << ", " << getL2CacheCtl().getType() + << ", " << getLenBurst().getType() << ", " << getNBurst().getType() + << ", " << getNburstSrcStride().getType() << ", " + << getNburstDstStride().getType(); + if (hasAll(getLoop1Count(), getLoop1SrcStride(), getLoop1DstStride())) + printDmaTripleTypes(printer, "loop1", getLoop1Count().getType(), + getLoop1SrcStride().getType(), + getLoop1DstStride().getType()); + if (hasAll(getLoop2Count(), getLoop2SrcStride(), getLoop2DstStride())) + printDmaTripleTypes(printer, "loop2", getLoop2Count().getType(), + getLoop2SrcStride().getType(), + getLoop2DstStride().getType()); + if (getPadValue()) + printDmaPadTypes(printer, getPadValue().getType(), + getLeftPaddingCount() ? getLeftPaddingCount().getType() : Type{}, + getRightPaddingCount() ? getRightPaddingCount().getType() : Type{}); +} + +void DmaLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult DmaLoadOp::verify() { + if (failed(verifyCopyGmToUbufOp(*this, true))) + return failure(); + if (failed(verifyDmaLoadStoreLoopGroups( + getOperation(), getLoop1Count(), getLoop1SrcStride(), + getLoop1DstStride(), getLoop2Count(), getLoop2SrcStride(), + getLoop2DstStride()))) + return failure(); + if (!getPadValue() && (getLeftPaddingCount() || getRightPaddingCount())) + return emitOpError() << "requires pad group to provide a pad value"; + if (getPadValue() && static_cast(getLeftPaddingCount()) != + static_cast(getRightPaddingCount())) + return emitOpError() + << "requires pad group to provide both left and right counts, or omit both"; + if (Value padValue = getPadValue()) { + Type valueType = padValue.getType(); + if (!isSupportedMovPadScalarType(valueType)) + return emitOpError() + << "expects pad value to be i8/i16/i32 or f16/bf16/f32 scalar, but got " + << valueType; + } + return success(); +} + LogicalResult SetMovPadValOp::verify() { Type valueType = getValue().getType(); if (isSupportedMovPadScalarType(valueType)) @@ -2968,3 +3310,137 @@ void CopyUbufToGmOp::getEffects( LogicalResult CopyUbufToGmOp::verify() { return verifyCopyUbufToGmOp(*this, false); } + +void DmaStoreOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value sid, Value reserved, + Value lenBurst, pto::DmaLoopConfig nburst, + std::optional loop1, + std::optional loop2) { + state.addOperands({source, destination, sid, reserved, lenBurst, nburst.count, + nburst.srcStride, nburst.dstStride}); + if (loop1) + state.addOperands({loop1->count, loop1->srcStride, loop1->dstStride}); + if (loop2) + state.addOperands({loop2->count, loop2->srcStride, loop2->dstStride}); + + state.addAttribute( + getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, 1, 1, 1, 1, 1, + loop1 ? 1 : 0, loop1 ? 1 : 0, loop1 ? 1 : 0, + loop2 ? 1 : 0, loop2 ? 1 : 0, loop2 ? 1 : 0})); +} + +ParseResult DmaStoreOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination, sid, reserved, lenBurst; + SmallVector nburstOperands; + SmallVector loop1Operands; + SmallVector loop2Operands; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parseRequiredOperandWithComma(parser, sid) || + parseRequiredOperandWithComma(parser, reserved) || + parser.parseOperand(lenBurst) || + parseDmaTripleGroup(parser, "nburst", nburstOperands) || + parseOptionalDmaTripleGroup(parser, "loop1", loop1Operands) || + parseOptionalDmaTripleGroup(parser, "loop2", loop2Operands)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, sidType, reservedType, lenBurstType; + SmallVector nburstTypes, loop1Types, loop2Types; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(sidType) || parser.parseComma() || + parser.parseType(reservedType) || parser.parseComma() || + parser.parseType(lenBurstType) || parser.parseComma() || + parseDmaTripleTypes(parser, nburstTypes)) + return failure(); + while (succeeded(parser.parseOptionalComma())) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + if (keyword == "loop1") { + if (!loop1Types.empty() || parseDmaTripleTypes(parser, loop1Types)) + return failure(); + continue; + } + if (keyword == "loop2") { + if (!loop2Types.empty() || parseDmaTripleTypes(parser, loop2Types)) + return failure(); + continue; + } + return parser.emitError(parser.getCurrentLocation(), + "expected one of 'loop1' or 'loop2'"); + } + + auto &segments = + result.getOrAddProperties().operandSegmentSizes; + llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, 1, 1, + static_cast(loop1Operands.size() ? 1 : 0), + static_cast(loop1Operands.size() ? 1 : 0), + static_cast(loop1Operands.size() ? 1 : 0), + static_cast(loop2Operands.size() ? 1 : 0), + static_cast(loop2Operands.size() ? 1 : 0), + static_cast(loop2Operands.size() ? 1 : 0)}, + segments.begin()); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(sid, sidType, result.operands) || + parser.resolveOperand(reserved, reservedType, result.operands) || + parser.resolveOperand(lenBurst, lenBurstType, result.operands) || + parser.resolveOperands(nburstOperands, nburstTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(loop1Operands, loop1Types, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(loop2Operands, loop2Types, parser.getCurrentLocation(), + result.operands)) + return failure(); + return success(); +} + +void DmaStoreOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " << getSid() + << ", " << getReserved() << ", " << getLenBurst(); + printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcStride(), + getNburstDstStride()); + if (hasAll(getLoop1Count(), getLoop1SrcStride(), getLoop1DstStride())) + printDmaTripleGroup(printer, "loop1", getLoop1Count(), getLoop1SrcStride(), + getLoop1DstStride()); + if (hasAll(getLoop2Count(), getLoop2SrcStride(), getLoop2DstStride())) + printDmaTripleGroup(printer, "loop2", getLoop2Count(), getLoop2SrcStride(), + getLoop2DstStride()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getSid().getType() << ", " << getReserved().getType() + << ", " << getLenBurst().getType() << ", " << getNBurst().getType() + << ", " << getNburstSrcStride().getType() << ", " + << getNburstDstStride().getType(); + if (hasAll(getLoop1Count(), getLoop1SrcStride(), getLoop1DstStride())) + printDmaTripleTypes(printer, "loop1", getLoop1Count().getType(), + getLoop1SrcStride().getType(), + getLoop1DstStride().getType()); + if (hasAll(getLoop2Count(), getLoop2SrcStride(), getLoop2DstStride())) + printDmaTripleTypes(printer, "loop2", getLoop2Count().getType(), + getLoop2SrcStride().getType(), + getLoop2DstStride().getType()); +} + +void DmaStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult DmaStoreOp::verify() { + if (failed(verifyCopyUbufToGmOp(*this, false))) + return failure(); + return verifyDmaLoadStoreLoopGroups( + getOperation(), getLoop1Count(), getLoop1SrcStride(), + getLoop1DstStride(), getLoop2Count(), getLoop2SrcStride(), + getLoop2DstStride()); +} diff --git a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp index 9b4711982..cbd02a602 100644 --- a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp +++ b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -71,6 +72,18 @@ static Value offsetBufferPointer(Value basePtr, Type elementType, offsetIndex); } +static bool isKnownOne(Value value) { + APInt intValue; + return value && matchPattern(value, m_ConstantInt(&intValue)) && + intValue.isOne(); +} + +static bool shouldRestoreDmaLoopSize(Value loop1Count, Value loop2Count) { + if (!loop1Count) + return false; + return !isKnownOne(loop1Count) || !isKnownOne(loop2Count); +} + struct ExpandUvldPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -98,6 +111,81 @@ struct ExpandUvldPattern : public OpRewritePattern { } }; +struct ExpandDmaLoadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::DmaLoadOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value one = rewriter.create(loc, 1, 64); + Value loop2Size = op.getLoop2Count(); + if (!loop2Size) + loop2Size = one; + if (Value loop2Count = op.getLoop2Count()) + rewriter.create( + loc, op.getLoop2SrcStride(), op.getLoop2DstStride()); + + if (Value loop1Count = op.getLoop1Count()) { + rewriter.create( + loc, op.getLoop1SrcStride(), op.getLoop1DstStride()); + rewriter.create(loc, loop2Size, loop1Count); + } + + Value leftPadding = op.getLeftPaddingCount(); + if (!leftPadding) + leftPadding = rewriter.create(loc, 0, 64); + Value rightPadding = op.getRightPaddingCount(); + if (!rightPadding) + rightPadding = rewriter.create(loc, 0, 64); + Value dataSelect = rewriter.create( + loc, rewriter.getI1Type(), + rewriter.getBoolAttr(static_cast(op.getPadValue()))); + + if (Value padValue = op.getPadValue()) + rewriter.create(loc, padValue); + + rewriter.create( + loc, op.getSource(), op.getDestination(), op.getSid(), op.getNBurst(), + op.getLenBurst(), leftPadding, rightPadding, dataSelect, + op.getL2CacheCtl(), op.getNburstSrcStride(), op.getNburstDstStride()); + if (shouldRestoreDmaLoopSize(op.getLoop1Count(), loop2Size)) + rewriter.create(loc, one, one); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandDmaStorePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::DmaStoreOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value one = rewriter.create(loc, 1, 64); + Value loop2Size = op.getLoop2Count(); + if (!loop2Size) + loop2Size = one; + if (Value loop2Count = op.getLoop2Count()) + rewriter.create( + loc, op.getLoop2SrcStride(), op.getLoop2DstStride()); + + if (Value loop1Count = op.getLoop1Count()) { + rewriter.create( + loc, op.getLoop1SrcStride(), op.getLoop1DstStride()); + rewriter.create(loc, loop2Size, loop1Count); + } + + rewriter.create( + loc, op.getSource(), op.getDestination(), op.getSid(), op.getNBurst(), + op.getLenBurst(), op.getReserved(), op.getNburstDstStride(), + op.getNburstSrcStride()); + if (shouldRestoreDmaLoopSize(op.getLoop1Count(), loop2Size)) + rewriter.create(loc, one, one); + rewriter.eraseOp(op); + return success(); + } +}; + struct PTOVPTOExpandBridgeOpsPass : public pto::impl::PTOVPTOExpandBridgeOpsBase { using pto::impl::PTOVPTOExpandBridgeOpsBase< @@ -109,7 +197,8 @@ struct PTOVPTOExpandBridgeOpsPass return; RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add( + &getContext()); if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) signalPassFailure(); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index e0619414a..e2dac4705 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -854,6 +854,23 @@ packCopyGmToUbConfig1(Operation *anchor, ValueRange operands) { return packLoopPair(anchor, operands[9], operands[10]); } +static FailureOr packCopyGmToUbConfig0(Operation *anchor, Value sid, + Value nBurst, Value lenBurst, + Value leftPadding, + Value rightPadding, + Value dataSelect, + Value cacheCtl) { + SmallVector operands(11); + operands[2] = sid; + operands[3] = nBurst; + operands[4] = lenBurst; + operands[5] = leftPadding; + operands[6] = rightPadding; + operands[7] = dataSelect; + operands[8] = cacheCtl; + return packCopyGmToUbConfig0(anchor, operands); +} + static FailureOr packCopyUbToGmConfig0(Operation *anchor, ValueRange operands) { if (operands.size() != 8) @@ -896,6 +913,17 @@ packCopyUbToGmConfig1(Operation *anchor, ValueRange operands) { return packLoopPair(anchor, operands[6], operands[7]); } +static FailureOr packCopyUbToGmConfig0(Operation *anchor, Value sid, + Value nBurst, Value lenBurst, + Value reserved) { + SmallVector operands(8); + operands[2] = sid; + operands[3] = nBurst; + operands[4] = lenBurst; + operands[5] = reserved; + return packCopyUbToGmConfig0(anchor, operands); +} + static FailureOr packVbitsortConfig(Operation *anchor, Value repeatTimes) { OpBuilder builder(anchor); builder.setInsertionPoint(anchor); @@ -1259,8 +1287,11 @@ static StringRef getReductionUnaryStem() { } static FailureOr buildCopyGmToUbCallee(MLIRContext *context, - pto::CopyGmToUbufOp op) { - Type elementType = cast(op.getSource().getType()).getElementType(); + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + Type elementType = ptrType.getElementType(); std::string elem = getCopyElementFragment(elementType); if (elem.empty()) return failure(); @@ -2172,7 +2203,7 @@ class LowerCopyOpPattern final : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { FailureOr calleeName = failure(); if constexpr (std::is_same_v) - calleeName = buildCopyGmToUbCallee(op.getContext(), op); + calleeName = buildCopyGmToUbCallee(op.getContext(), op.getSource().getType()); else calleeName = buildCopyUbToGmCallee(op.getContext()); if (failed(calleeName)) @@ -2215,6 +2246,7 @@ class LowerCopyOpPattern final : public OpConversionPattern { LoweringState &state; }; + template class LowerVecScalarMaskedOpPattern final : public OpConversionPattern { diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto index f86357b6b..bf31d5ab5 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto index 7b382fad6..b70ba1dd8 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto index b514a3a81..8eeb85a2e 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto @@ -17,10 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto index 59770ba4b..ea2ca63da 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,8 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto index 13302379f..87659b202 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto index 043d33afd..234e09bf8 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,8 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto index a73b810d9..774431150 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto index ec2fa5fd6..f41776c3b 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto @@ -17,10 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto index b74c45432..bdc4892d1 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto @@ -20,12 +20,14 @@ module attributes {pto.target_arch = "a5"} { %false = arith.constant false pto.get_buf "PIPE_MTE2", 0, 0 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.rls_buf "PIPE_MTE2", 0, 0 pto.get_buf "PIPE_MTE2", 1, 0 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.rls_buf "PIPE_MTE2", 1, 0 pto.get_buf "PIPE_V", 0, 0 pto.get_buf "PIPE_V", 1, 0 @@ -46,8 +48,9 @@ module attributes {pto.target_arch = "a5"} { pto.rls_buf "PIPE_V", 2, 0 pto.get_buf "PIPE_MTE3", 2, 0 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.rls_buf "PIPE_MTE3", 2, 0 pto.barrier #pto.pipe return diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto index 4bab24ec8..7131fc4c8 100644 --- a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,10 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_carry, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_carry, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto index dd7246774..3aef43a9d 100644 --- a/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto @@ -23,11 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,10 +44,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_carry, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_carry, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto index a474d3161..2a6085ba9 100644 --- a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto index 453445d96..b2a66c186 100644 --- a/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto index 8f416095e..dbe3913f4 100644 --- a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto @@ -23,10 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto index bc4b2536f..657f79491 100644 --- a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto @@ -17,10 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto index 96b8c2fe6..1af9192c2 100644 --- a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto @@ -17,10 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto index 865320944..62b41c796 100644 --- a/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto @@ -18,10 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto index d8d17ca9b..52c6db41e 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto @@ -17,10 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto index 1e9948ca9..b24b18bdc 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto @@ -18,10 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto index 4af801098..ebe54983a 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto index 2654e8d7e..88a3bf770 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto index df1987955..66241fe64 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto @@ -18,10 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto index d896ab4e8..93043db06 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto @@ -18,10 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto index 19c3779ec..5f3b0852c 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto index 33f972373..2ed31bbe1 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto index 8aade9361..a6270b077 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto @@ -17,10 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto index d896ab4e8..93043db06 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto @@ -18,10 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto index a7eccda6f..b890b58b2 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto @@ -17,10 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto index 5763169d7..d0940b673 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto @@ -18,10 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto index 4a92da608..4544d4f73 100644 --- a/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto index 49552af6d..7ed862f26 100644 --- a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto index 97fd111c0..e531dc633 100644 --- a/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto index 85112255a..4a1d178b0 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto @@ -26,10 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,8 +48,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto index 42d07f0ec..977665238 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto index c21d9eadf..0a405c63d 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto index 459959491..98516150a 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto index b20001bda..bc5156d50 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto index 2ce05e8be..eefb2416c 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto index f124da2d7..460e5f486 100644 --- a/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto @@ -17,10 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto index 1721e0273..df7837848 100644 --- a/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto @@ -18,10 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto index aaeabd1d4..19df47975 100644 --- a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,10 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_borrow, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_borrow, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto index 7a7dc3720..c555dd4f4 100644 --- a/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,10 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_borrow, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_borrow, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto index 1e9ff40bb..42a7ebe87 100644 --- a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto index cf456ee72..fff570fd6 100644 --- a/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto @@ -24,10 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto index e8fcf175f..0b0b9be36 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto @@ -15,12 +15,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,11 +40,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto index 66751d6ec..382b2e455 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto @@ -15,12 +15,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,11 +40,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto index df5c04fe9..291ad31a7 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto @@ -16,12 +16,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,11 +41,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64_data, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64_data + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto index 8dfa3cf79..833388e7d 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto @@ -16,12 +16,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,11 +41,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64_data, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64_data + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto index fbe299c4b..0601a6bb4 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto @@ -15,12 +15,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,11 +40,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto index 610422e29..bb63cbb1f 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto @@ -15,12 +15,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,11 +40,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto index fbd52f20f..631e5c875 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto @@ -14,11 +14,11 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,11 +35,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto index 3be90dca5..1d7396289 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto @@ -14,11 +14,11 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,11 +35,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto index c1f08269e..c8fc83432 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto @@ -13,11 +13,11 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,11 +34,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto index dc2864f86..c95b45413 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto @@ -14,11 +14,11 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,11 +35,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto index 342605924..c2abf5c19 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto @@ -13,11 +13,11 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,11 +34,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto index 58d5eac19..c2f616389 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto @@ -14,12 +14,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,10 +40,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto index 79ed52b1a..8b25e4070 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto @@ -15,12 +15,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,11 +41,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto index 0d4602a58..46f470aa7 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto @@ -14,13 +14,17 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg2, %ub_out, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg2, %ub_out, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,10 +43,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto index 9a330d306..2993878a7 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto @@ -14,12 +14,14 @@ module attributes {pto.target_arch = "a5"} { %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,10 +40,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c64_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto index cea3fd793..8c5fa6e07 100644 --- a/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_idx = pto.castptr %c2048_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c8_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_idx, %c0_i64, %c8_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_idx, %c0_i64, %c0_i64, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,8 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c8_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto index 08d052afc..6f9fa6f6e 100644 --- a/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto @@ -22,10 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_idx = pto.castptr %c1024_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_idx, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_idx, %c0_i64, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,8 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto index f1eff1633..aea856716 100644 --- a/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto @@ -22,18 +22,20 @@ module attributes {pto.target_arch = "a5"} { %5 = pto.addptr %3, %4 : -> pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 %6 = pto.castptr %5 : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.copy_gm_to_ubuf %6, %0, %c0_i64, %c32_i64, %2, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %6, %0, %c0_i64, %c0_i64, %2 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 %7 = pto.castptr %c4096_i64 : i64 -> !pto.ptr %8 = pto.castptr %arg1 : !pto.ptr -> !pto.ptr %9 = pto.addptr %8, %4 : -> pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 %10 = pto.castptr %9 : !pto.ptr -> !pto.ptr - pto.copy_gm_to_ubuf %10, %7, %c0_i64, %c32_i64, %2, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %10, %7, %c0_i64, %c0_i64, %2 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] %11 = pto.castptr %c8192_i64 : i64 -> !pto.ptr @@ -60,11 +62,12 @@ module attributes {pto.target_arch = "a5"} { %12 = arith.muli %1, %c4_i64 : i64 %13 = pto.castptr %arg2 : !pto.ptr -> !pto.ptr %14 = pto.addptr %13, %4 : -> - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 %15 = pto.castptr %14 : !pto.ptr -> !pto.ptr - pto.copy_ubuf_to_gm %11, %15, %c0_i64, %c32_i64, %12, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %11, %15, %c0_i64, %c0_i64, %12 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto index d475fbb7a..5c4766df6 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto index 60dfc3f26..3829a35bc 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto @@ -17,8 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,8 +37,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c16_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c16_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto index e20d43ad5..f23a6b4b4 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto @@ -17,8 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,8 +37,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c16_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c16_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto index a49630d82..b0b5558e9 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto index dfd8c8ac2..67682c1da 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto @@ -18,8 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto index c6a08514e..45d1582e4 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto @@ -21,8 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +39,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto index 91b9045ef..b9a88ed04 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto @@ -18,8 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto index 292a4df27..75a9d625b 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,8 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto index 9a907e9a1..f1c8b0c80 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto @@ -18,8 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto index 0cfaca5e4..fc77a7652 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto @@ -18,8 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto index b26bc2308..6d55727f1 100644 --- a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto index 1cf47698c..7451aa8a2 100644 --- a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto @@ -41,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { %ub_f = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -63,12 +64,15 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_r, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_z, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_f, %arg3, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_r, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_z, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_f, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto index 948f50b6c..ff0e6eb4c 100644 --- a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto @@ -14,8 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto index 1cf47698c..7451aa8a2 100644 --- a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto @@ -41,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { %ub_f = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -63,12 +64,15 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_r, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_z, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_f, %arg3, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_r, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_z, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_f, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto index 517dbdd5b..eb4f8570b 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto @@ -26,10 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_addend, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_addend, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +48,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto index d4a421d50..f45b2cb9b 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto @@ -18,11 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_scores = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_indices = pto.castptr %c128_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_scores, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_indices, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_scores, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_indices, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -31,9 +32,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto index becc05323..577e9dd51 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto @@ -23,10 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_zero, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_zero, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg0, %ub_out, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto index 2d4df9cb0..30c802bc0 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto @@ -25,10 +25,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_max, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_max, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto index 338dbf8be..553552e3c 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto @@ -27,10 +27,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_max, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_max, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -52,8 +54,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto index 026f64ace..1e98a0c2e 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto index e071ed15d..3e58304af 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto @@ -24,8 +24,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto index 3d0cfb074..0f864b522 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto @@ -17,8 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto index 3d0cfb074..0f864b522 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto @@ -17,8 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto index 53a0e0456..13309185c 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto @@ -16,8 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto index 5901724b7..4ef16ba54 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto @@ -22,8 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +43,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto index 3f4216a6b..ef2debd20 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto @@ -22,8 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,8 +43,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto index 30af09612..5ce7fd332 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %gm_out = pto.castptr %arg1 : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %gm_in, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %gm_in, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto index 2dc582748..49e7afa73 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto @@ -25,10 +25,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_alpha, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_alpha, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto index 5c3951823..14c1a1271 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto @@ -26,10 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_alpha, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_alpha, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,9 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto index 46bd057fc..e623a95c5 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto @@ -25,10 +25,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,8 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto index 34b01dced..07abdcd98 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto @@ -44,10 +44,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -64,8 +66,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto index 5baaa4a8d..aa0ffe7a6 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto @@ -26,10 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,8 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto index 80f1c5a28..5e396a036 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto @@ -26,10 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,8 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto index ff1ed7abb..a25a6f38b 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto @@ -26,10 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,8 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto index abad98343..8a6de533a 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto @@ -26,10 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,8 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto index f07933e42..63275d2b8 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto @@ -25,12 +25,15 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg2, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg2, %ub_out, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,8 +50,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto index 2efab3a00..021e3de6d 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto @@ -44,12 +44,15 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_offsets, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg2, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg2, %ub_out, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -66,8 +69,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto index 789ebe48c..6ce0c6b76 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto @@ -47,8 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto index 1026206b3..97928963f 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto index 54c4cd105..4dc743519 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto index a207737b8..eaff46fb8 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto index 7885cee06..f6716db6e 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto index 3f82b9263..907a1f831 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto index 647105b76..5bdcd1e62 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto index 2659d6708..36e8153e0 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto @@ -30,8 +30,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto index 7991ea772..c35613534 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto @@ -30,8 +30,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto index 2dc9e5ff1..bd0b898b3 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto index 8632c930e..09468bda1 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto index 59bdf9a89..073e994bb 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto index 5765ef77f..21ad9e9f0 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto index 78319891b..33bf4a3a5 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto index fdf5a7afc..576e7f712 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto index 68b29ee1f..7a035d937 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto @@ -31,8 +31,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto index 39729a22b..03a64089b 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto @@ -33,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto index fd8df6e77..e7cd4dc89 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto @@ -28,8 +28,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto index 84de1a29e..b0074ccb1 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto @@ -47,8 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto index 2c9d46589..cd68b8024 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto @@ -28,8 +28,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto index 05569d1e4..37eae598b 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto @@ -28,8 +28,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto index 91c7f6515..f3966c735 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto @@ -32,8 +32,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto index f9a47a667..1dbfd52f5 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto @@ -28,8 +28,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto index 4b5992a6b..4ebd9389e 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto @@ -30,8 +30,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto index 1a7a30f51..9a409c85c 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto @@ -30,8 +30,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c96_i64, %c0_i64, %c96_i64, %c96_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto index 63ae77a30..925d6a810 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto @@ -47,8 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto index 45eb413dc..f31fe25fd 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto @@ -40,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto index a651c7888..5e1cfe7f2 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto @@ -40,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto index 37b29d93a..c9ba40fc0 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +42,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_low, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.copy_ubuf_to_gm %ub_high, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_low, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_high, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto index dfea1ef9a..a5dc03db4 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto @@ -27,8 +27,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto index 14b5abffb..3cf27b45b 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto @@ -27,8 +27,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto index 6ef474ecf..4c736a0cc 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto @@ -40,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg0, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto index b2cfb6aed..c9c206be4 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto @@ -25,9 +25,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -48,10 +48,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto index 5dda2be3e..05c0727b3 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,10 +47,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto index 7b162a666..1ce7b1f7b 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,10 +40,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto index 9748b746c..297c5a42d 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,10 +40,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto index da7f2f773..696e33544 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto @@ -25,8 +25,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto index fc82cc0f3..ec078bc62 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,10 +45,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto index 534559cee..7394d1d98 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,10 +42,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto index dcc2bede6..ca010a7d5 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_mid, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,10 +41,11 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto index 1bb6545b0..7f984fd38 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto @@ -37,8 +37,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_mask, %arg2, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_mask, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto index 259c6c100..f98e46330 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto @@ -32,8 +32,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_mask, %arg2, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_mask, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto index 3669f19af..3c6d66504 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto @@ -33,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_mask, %arg2, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_mask, %arg2, %c0_i64, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto index f9ff23e1b..d9441c369 100644 --- a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto @@ -22,8 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto index 28fcdaab9..923366521 100644 --- a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto @@ -22,8 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto index a8fe95189..5964dc138 100644 --- a/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto @@ -22,8 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto index d73ee2331..956b3f46a 100644 --- a/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto @@ -22,8 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto index 719d09451..77adf7d9f 100644 --- a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto @@ -33,9 +33,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_mask_seed, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -53,8 +56,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto index 24ae27544..4bf1d2a0c 100644 --- a/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto index 69167a5c3..db726fa47 100644 --- a/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto @@ -44,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %gm_in, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %gm_in, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -62,8 +63,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto index 0bf1ff181..4feb0d6a6 100644 --- a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto @@ -27,9 +27,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_mask_seed, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,8 +50,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto index 58fc84285..6f890a9ea 100644 --- a/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto @@ -27,9 +27,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_src, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_mask_seed, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,8 +50,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto index 5cd5705fc..174baf184 100644 --- a/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto @@ -44,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %gm_in, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %gm_in, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -62,8 +63,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto index 0c45c6f56..fa7d8818f 100644 --- a/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto @@ -14,8 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto index 9c27df73c..ae8062a97 100644 --- a/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto index 4b2bc7d34..255bf3b56 100644 --- a/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto index 022af54f1..b3d7f4a18 100644 --- a/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto index ed4b0337e..e9e35c09e 100644 --- a/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto index 7fac87069..177639bdf 100644 --- a/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto index a34cfd2db..6f8c02864 100644 --- a/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto index 2d329434f..cd4597eb9 100644 --- a/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto b/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto index 5db925e8e..0079b3724 100644 --- a/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto b/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto index 9f43b2fbc..171d088f8 100644 --- a/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto index 817df6a67..87930eb3f 100644 --- a/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto index 1dc57f76e..02217c761 100644 --- a/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto index 5b9192cc3..6b9b69a30 100644 --- a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto @@ -21,8 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto index d74233130..d27d811d6 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto index 1e6de64e5..5e70d9aa0 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto @@ -14,8 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto index 3e66f18ce..a30c4171f 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto @@ -21,8 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,8 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto index 7b5cd2e9b..bd15cdeff 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto index 94898cb27..b06a30b8c 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto index 58b88ae51..289fff56b 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto @@ -18,8 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto index 9ef43bd64..f40c918dc 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto @@ -14,8 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto index 8fb9f1391..dece39ab6 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto @@ -34,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -52,8 +53,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto index 939ee5b4d..e393f27c0 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto @@ -22,8 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto index 8228a02a3..677c46ed3 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto @@ -14,8 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto index 7dde55a08..3a0cfa12b 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto @@ -14,8 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto index 9f30419ea..417b59631 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto @@ -14,8 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,8 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto index abcad4e82..361a3f405 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto @@ -34,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -52,8 +53,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto index 14a195eeb..4dc052fbf 100644 --- a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto index 7746c3655..2c4b24b4d 100644 --- a/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto index 93d2893e1..f42d91a17 100644 --- a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto index 68197d779..a3b0db15f 100644 --- a/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto index d33238b99..638c376db 100644 --- a/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto @@ -22,8 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto index 0fb482427..7a3b10eac 100644 --- a/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto index f2c48b3d6..e1b5d4f0f 100644 --- a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto index 3b1de1acd..dd97fb468 100644 --- a/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,8 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto index b6c9aea8f..ee4012a99 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto @@ -23,11 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,10 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_carry, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_carry, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto index cd73bb3c6..76bf73839 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto @@ -21,11 +21,12 @@ module attributes {pto.target_arch = "a5"} { %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,10 +43,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_carry, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_carry, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto index 7c1ccbf2f..f4c829c50 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto index 219522026..7e79f1862 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto @@ -15,8 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto index 8309e4a04..17839109d 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto @@ -16,8 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto index 52b50c48e..a450713a8 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto @@ -21,8 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +39,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto index fe7c7c1c2..0e319b6cf 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto @@ -21,8 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,8 +39,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto index ad099717a..c2b7cf7a2 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto index b3c4391ff..ae7984b3c 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto index b66075220..f6dde0983 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto @@ -16,8 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto index f8fb5d002..4bb2cbfbf 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto @@ -17,8 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto index 5f7cf6e77..91b3c6cfc 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto @@ -16,8 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto index 217298fde..ec3fa1106 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto @@ -17,8 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto index 9ee49d201..59092b972 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto @@ -16,8 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto index de6e39715..d2fa7b4e7 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto @@ -17,8 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto index 11ff8c4fb..6c8527f8c 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto @@ -16,8 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto index 49442c61c..9992d7887 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto @@ -17,8 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto index 0c7ee44f2..2a6a4f6b5 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto index 5700b094f..184d1a1f5 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto index 8c00efd4c..07bfbed77 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto index ca745b48f..783e5cd09 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto index 830a471c4..8bb3028e4 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto @@ -23,11 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,10 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_borrow, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_borrow, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto index 8f0476a76..0969c996a 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto @@ -21,11 +21,12 @@ module attributes {pto.target_arch = "a5"} { %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_lhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_rhs, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,10 +43,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg2, %c0_i64, %c1_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_borrow, %arg3, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_borrow, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto index 580780047..1592a02fa 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto @@ -27,8 +27,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -51,8 +52,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto index 0eaf940b7..e49e98352 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,8 +62,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto index bd6c2e8d8..df14c972c 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,8 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto index d7a79b2b0..9aa74cd03 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto @@ -27,8 +27,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,8 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto index 1bedcf05f..c7d4982e7 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -59,8 +60,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto index 3e5f101db..5d0375db2 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,8 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto index 13a1330d9..ff5b07da8 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto index 66a727e94..56a9938ef 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto index 505e0d678..1270dffc5 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto @@ -43,8 +43,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,8 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto index 672e884e2..7ce1a5b81 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto @@ -29,8 +29,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,8 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto index 8e79bcab0..d7ced22df 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto index fca778706..56373cafc 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -59,8 +60,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto index 762500c80..0eeefab4e 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto @@ -22,13 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, - %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, - i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -52,10 +48,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c128_i64, %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto index 193eb3043..899859fd9 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto @@ -21,13 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, - %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, - i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,10 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c128_i64, %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto index 12bd7b9a0..e3fe71b6f 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto @@ -21,13 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, - %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, - i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,10 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c128_i64, %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto index 9e9e87ff1..454684cf1 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,8 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto index b022afee6..defe42ef5 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_in, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_in, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto index f9ddd4309..89d0b856c 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto @@ -32,8 +32,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out1 = pto.addptr %ub_out, %c1_elem : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -53,8 +54,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto index 89b918893..34dc63560 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto @@ -30,9 +30,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_out, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,8 +52,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto index a45e227dd..592b12b7c 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto index 29ef8cd9c..e6e2ee5c0 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto @@ -23,8 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto index 797b25ff5..273b38bda 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto @@ -20,11 +20,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg1, %ub_out, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_out, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,8 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto index da446a7fe..5ea2df48c 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto @@ -17,13 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, - %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, - i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,10 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c128_i64, %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto index 07e60de3d..f5eaba126 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto @@ -42,8 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -59,8 +60,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto index 9c54c8af1..6dcfb7139 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto @@ -19,13 +19,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, - %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, - i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,10 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, - %c0_i64, %c128_i64, %c128_i64 - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto index 71d6470e6..38a90af23 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto @@ -20,8 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out1 = pto.addptr %ub_out, %c1 : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,8 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto index 6c194fde5..104be5a26 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto @@ -30,8 +30,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out1 = pto.addptr %ub_out, %c1 : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 - pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -51,8 +52,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 - pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } From b9c00a7d94d38f0ca736d84ae4731dc95559b3b8 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 14:54:12 +0800 Subject: [PATCH 122/192] Support packed vcvt part modes --- docs/isa/09-conversion-ops.md | 17 ++- lib/PTO/IR/VPTO.cpp | 63 ++++++++++- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 17 ++- test/basic/vcvt_part_modes_verify_invalid.pto | 30 +++++ test/basic/vcvt_part_modes_vpto_llvm.pto | 34 ++++++ .../vcvt-u32-to-u8-part-p0123/compare.py | 37 +++++++ .../vcvt-u32-to-u8-part-p0123/golden.py | 88 +++++++++++++++ .../vcvt-u32-to-u8-part-p0123/kernel.pto | 63 +++++++++++ .../vcvt-u32-to-u8-part-p0123/launch.cpp | 49 +++++++++ .../vcvt-u32-to-u8-part-p0123/main.cpp | 103 ++++++++++++++++++ .../vcvt-u32-to-u8-part-p0123/stub.cpp | 23 ++++ 11 files changed, 517 insertions(+), 7 deletions(-) create mode 100644 test/basic/vcvt_part_modes_verify_invalid.pto create mode 100644 test/basic/vcvt_part_modes_vpto_llvm.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/compare.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/stub.cpp diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md index f3953089b..47f9f50a0 100644 --- a/docs/isa/09-conversion-ops.md +++ b/docs/isa/09-conversion-ops.md @@ -94,13 +94,26 @@ for (int i = 0; i < min(N, M); i++) ### Part Modes Use `part` when a width-changing conversion writes only one half of each wider -destination lane group. This is typically used in even/odd placement forms such -as `32 -> 16` or `16 -> 32` style conversions. +destination lane group. + +- `Part` (`PART_EVEN`, `PART_ODD`) + - Used by ordinary width-changing conversions. + - Typical cases include `32 -> 16`, `16 -> 32`, and other even/odd packing + or unpacking forms. +- `Part_T` (`PART_P0`, `PART_P1`, `PART_P2`, `PART_P3`) + - Used by lower-level packed placement forms. + - Typical cases include `32 -> 8`, packed fp8/fp4 conversion paths, and + other flows where the result is written into one of four sub-parts before a + later merge or compact step. | Mode | Description | |------|-------------| | `EVEN` | Output to even-indexed lanes | | `ODD` | Output to odd-indexed lanes | +| `P0` | Output to sub-part 0 in 4-way packed placement forms | +| `P1` | Output to sub-part 1 in 4-way packed placement forms | +| `P2` | Output to sub-part 2 in 4-way packed placement forms | +| `P3` | Output to sub-part 3 in 4-way packed placement forms | --- diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index a2691ada2..ce0481be3 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -560,6 +560,24 @@ static std::optional normalizeEvenOddPartToken(StringRef token) { return std::nullopt; } +static std::optional normalizePacked4PartToken(StringRef token) { + if (token == "P0" || token == "PART_P0") + return StringRef("P0"); + if (token == "P1" || token == "PART_P1") + return StringRef("P1"); + if (token == "P2" || token == "PART_P2") + return StringRef("P2"); + if (token == "P3" || token == "PART_P3") + return StringRef("P3"); + return std::nullopt; +} + +static std::optional normalizeVcvtPartToken(StringRef token) { + if (auto normalized = normalizeEvenOddPartToken(token)) + return normalized; + return normalizePacked4PartToken(token); +} + namespace { enum class VcvtElemKind { @@ -582,6 +600,11 @@ struct VcvtContract { bool requiresPart; }; +enum class VcvtPartFamily { + EvenOdd, + Packed4, +}; + static VcvtElemKind classifyVcvtElemType(Type type) { if (type.isF16()) return VcvtElemKind::F16; @@ -628,6 +651,27 @@ static std::optional getVcvtElemBitWidth(VcvtElemKind kind) { return std::nullopt; } +static std::optional classifyVcvtPartFamily(unsigned srcBits, + unsigned dstBits) { + unsigned largerBits = std::max(srcBits, dstBits); + unsigned smallerBits = std::min(srcBits, dstBits); + if (largerBits == smallerBits * 2) + return VcvtPartFamily::EvenOdd; + if (largerBits == smallerBits * 4) + return VcvtPartFamily::Packed4; + return std::nullopt; +} + +static bool isValidVcvtPartForFamily(StringRef part, VcvtPartFamily family) { + switch (family) { + case VcvtPartFamily::EvenOdd: + return part == "EVEN" || part == "ODD"; + case VcvtPartFamily::Packed4: + return part == "P0" || part == "P1" || part == "P2" || part == "P3"; + } + return false; +} + static std::optional lookupVcvtContract(VcvtElemKind src, VcvtElemKind dst) { switch (src) { @@ -2652,8 +2696,7 @@ ParseResult VcvtOp::parse(OpAsmParser &parser, OperationState &result) { normalizeRoundModeToken)) || failed(normalizeNamedStringAttr("rnd", "rnd", normalizeRoundModeToken)) || failed(normalizeNamedStringAttr("sat", "sat", normalizeSaturationToken)) || - failed( - normalizeNamedStringAttr("part", "part", normalizeEvenOddPartToken))) + failed(normalizeNamedStringAttr("part", "part", normalizeVcvtPartToken))) return failure(); result.addAttributes(attrs); @@ -2713,8 +2756,20 @@ LogicalResult VcvtOp::verify() { if (getPartAttr()) { StringRef part = *getPart(); - if (!normalizeEvenOddPartToken(part)) - return emitOpError("part must be EVEN or ODD"); + auto normalizedPart = normalizeVcvtPartToken(part); + if (!normalizedPart) + return emitOpError("part must be one of EVEN/ODD/P0/P1/P2/P3"); + auto partFamily = classifyVcvtPartFamily(*inputElemBits, *resultElemBits); + if (!partFamily) + return emitOpError("part attr is not supported for this vcvt width relation"); + if (!isValidVcvtPartForFamily(*normalizedPart, *partFamily)) { + switch (*partFamily) { + case VcvtPartFamily::EvenOdd: + return emitOpError("part must be EVEN or ODD for 8/16 and 16/32 vcvt forms"); + case VcvtPartFamily::Packed4: + return emitOpError("part must be P0, P1, P2, or P3 for 8/32 vcvt forms"); + } + } } if (static_cast(getPartAttr()) != contract->requiresPart) { return contract->requiresPart ? emitOpError("requires part attr for this vcvt type pair") diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index e2dac4705..bd950edc9 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -376,6 +376,20 @@ static std::optional parsePartImmediate(StringRef part) { return std::nullopt; } +static std::optional parseVcvtPartImmediate(StringRef part) { + if (part == "EVEN" || part == "PART_EVEN" || part == "P0" || + part == "PART_P0") + return 0; + if (part == "ODD" || part == "PART_ODD" || part == "P1" || + part == "PART_P1") + return 1; + if (part == "P2" || part == "PART_P2") + return 2; + if (part == "P3" || part == "PART_P3") + return 3; + return std::nullopt; +} + static std::optional parsePredicateStoreDistImmediate(StringRef dist) { if (dist == "NORM") return 0; @@ -4241,7 +4255,8 @@ class LowerVcvtOpPattern final : public OpConversionPattern { } if ((*contract).requiresPart) { - auto part = op.getPartAttr() ? parsePartImmediate(*op.getPart()) : std::nullopt; + auto part = + op.getPartAttr() ? parseVcvtPartImmediate(*op.getPart()) : std::nullopt; if (!part) return rewriter.notifyMatchFailure(op, "vcvt requires valid part attr"); Value partValue = getI32Constant(rewriter, op.getLoc(), *part); diff --git a/test/basic/vcvt_part_modes_verify_invalid.pto b/test/basic/vcvt_part_modes_verify_invalid.pto new file mode 100644 index 000000000..caa8cae08 --- /dev/null +++ b/test/basic/vcvt_part_modes_verify_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s 2>&1 | FileCheck %s + +// CHECK: error: 'pto.vcvt' op part must be P0, P1, P2, or P3 for 8/32 vcvt forms +// CHECK: error: 'pto.vcvt' op part must be EVEN or ODD for 8/16 and 16/32 vcvt forms + +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_u32_to_u8_rejects_even(%seed: ui32) attributes {pto.kernel_kind = #pto.kernel_kind} { + pto.vecscope { + %src = pto.vbr %seed : ui32 -> !pto.vreg<64xui32> + %bad = pto.vcvt %src {sat = "SAT", part = "EVEN"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + } + return + } + + func.func @vcvt_u16_to_u8_rejects_p0(%seed: ui16) attributes {pto.kernel_kind = #pto.kernel_kind} { + pto.vecscope { + %src = pto.vbr %seed : ui16 -> !pto.vreg<128xui16> + %bad = pto.vcvt %src {sat = "SAT", part = "P0"} : !pto.vreg<128xui16> -> !pto.vreg<256xui8> + } + return + } +} diff --git a/test/basic/vcvt_part_modes_vpto_llvm.pto b/test/basic/vcvt_part_modes_vpto_llvm.pto new file mode 100644 index 000000000..a3dfc75f2 --- /dev/null +++ b/test/basic/vcvt_part_modes_vpto_llvm.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --vpto-emit-hivm-llvm %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_u32_to_u8_packed_parts(%seed: ui32, %dst: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b8 "PAT_ALL" : !pto.mask + %src = pto.vbr %seed : ui32 -> !pto.vreg<64xui32> + %p0 = pto.vcvt %src {sat = "SAT", part = "P0"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %p1 = pto.vcvt %src {sat = "SAT", part = "P1"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %p2 = pto.vcvt %src {sat = "SAT", part = "P2"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %p3 = pto.vcvt %src {sat = "SAT", part = "P3"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %m01 = pto.vor %p0, %p1, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + %m23 = pto.vor %p2, %p3, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + %out = pto.vor %m01, %m23, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %out, %dst[%c0], %mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: define void @vcvt_u32_to_u8_packed_parts( +// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}} i32 0, i32 0) +// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}} i32 0, i32 1) +// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}} i32 0, i32 2) +// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}} i32 0, i32 3) diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/compare.py new file mode 100644 index 000000000..918a4b0bc --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/golden.py new file mode 100644 index 000000000..487f52bfb --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/golden.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-u32-to-u8-part-p0123 +# family: conversion +# target_ops: pto.vcvt +# scenarios: u32-to-u8, sat, part-p0123 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 23 +CHUNK = 256 +SUBCHUNK = 64 +U8_MAX = np.iinfo(np.uint8).max + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + data = rng.integers(0, 2000, size=ELEMS, dtype=np.uint32) + edge = np.array( + [ + 0, + 1, + 2, + 3, + 4, + 7, + 15, + 31, + 63, + 127, + 128, + 129, + 254, + 255, + 256, + 257, + 511, + 512, + 1023, + 65535, + 0xFFFFFFFF, + ], + dtype=np.uint32, + ) + data[: edge.size] = edge + + clipped = np.clip(data, 0, U8_MAX).astype(np.uint8) + golden = np.empty(ELEMS, dtype=np.uint8) + for offset in range(0, ELEMS, CHUNK): + p0 = clipped[offset : offset + SUBCHUNK] + p1 = clipped[offset + SUBCHUNK : offset + 2 * SUBCHUNK] + p2 = clipped[offset + 2 * SUBCHUNK : offset + 3 * SUBCHUNK] + p3 = clipped[offset + 3 * SUBCHUNK : offset + 4 * SUBCHUNK] + merged = np.empty(CHUNK, dtype=np.uint8) + merged[0::4] = p0 + merged[1::4] = p1 + merged[2::4] = p2 + merged[3::4] = p3 + golden[offset : offset + CHUNK] = merged + + output_dir.mkdir(parents=True, exist_ok=True) + data.tofile(output_dir / "v1.bin") + np.zeros(ELEMS, dtype=np.uint8).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto new file mode 100644 index 000000000..d3d46469e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto @@ -0,0 +1,63 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-u32-to-u8-part-p0123 +// family: conversion +// target_ops: pto.vcvt +// scenarios: u32-to-u8, sat, part-p0123 +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_u32_to_u8_part_p0123_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c256 { + %offset_p1 = arith.addi %offset, %c64 : index + %offset_p2 = arith.addi %offset, %c128 : index + %offset_p3 = arith.addi %offset, %c192 : index + %src_p0 = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xui32> + %src_p1 = pto.vlds %ub_in[%offset_p1] : !pto.ptr -> !pto.vreg<64xui32> + %src_p2 = pto.vlds %ub_in[%offset_p2] : !pto.ptr -> !pto.vreg<64xui32> + %src_p3 = pto.vlds %ub_in[%offset_p3] : !pto.ptr -> !pto.vreg<64xui32> + %part_p0 = pto.vcvt %src_p0 {sat = "SAT", part = "P0"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %part_p1 = pto.vcvt %src_p1 {sat = "SAT", part = "P1"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %part_p2 = pto.vcvt %src_p2 {sat = "SAT", part = "P2"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %part_p3 = pto.vcvt %src_p3 {sat = "SAT", part = "P3"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %merged01 = pto.vor %part_p0, %part_p1, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + %merged23 = pto.vor %part_p2, %part_p3, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + %merged = pto.vor %merged01, %merged23, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %merged, %ub_out[%offset], %full_mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp new file mode 100644 index 000000000..0923981ed --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, WHETHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_u32_to_u8_part_p0123_kernel( + __gm__ uint32_t *v1, __gm__ uint8_t *v2); + +void LaunchVcvt_u32_to_u8_part_p0123_kernel(uint32_t *v1, uint8_t *v2, + void *stream) { + vcvt_u32_to_u8_part_p0123_kernel<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint8_t *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp new file mode 100644 index 000000000..07cedfe4e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, WHETHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-u32-to-u8-part-p0123 +// family: conversion +// target_ops: pto.vcvt +// scenarios: u32-to-u8, sat, part-p0123 +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_u32_to_u8_part_p0123_kernel(uint32_t *v1, uint8_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint8_t *v2Host = nullptr; + uint8_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_u32_to_u8_part_p0123_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/stub.cpp new file mode 100644 index 000000000..9e3547e7e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, WHETHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcvt_u32_to_u8_part_p0123_kernel( + __gm__ uint32_t *v1, __gm__ uint8_t *v2) { + (void)v1; + (void)v2; +} From bda4063ed8b3593bb8cd259aa5c9a4b82e9cba70 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 15:35:32 +0800 Subject: [PATCH 123/192] Support p0/p1/p2/p3 vcvt part mode --- .../11-vector-arithmetic-operations.md | 16 ++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 6 +- tilelang-dsl/python/tilelang_dsl/types.py | 4 ++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 67 +++++++++++++++++++ 4 files changed, 89 insertions(+), 4 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index dd60b3692..09e940314 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1354,7 +1354,7 @@ family. **Attribute Enums**: - `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` - `pto.VcvtSatMode`: `SAT`, `NOSAT` -- `pto.VcvtPartMode`: `EVEN`, `ODD` +- `pto.VcvtPartMode`: `EVEN`, `ODD`, `P0`, `P1`, `P2`, `P3` **Parameters**: | Parameter | Type | Description | @@ -1364,7 +1364,7 @@ family. | `mask` | `MaskType` | Predicate mask selecting active source lanes. Its granularity must match the source vector family, not the destination family | | `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `rnd` | | `sat` | `pto.VcvtSatMode` \| `None` | Optional saturation attribute lowered to VPTO `sat` | -| `part` | `pto.VcvtPartMode` \| `None` | Optional even/odd packing selector lowered to VPTO `part` | +| `part` | `pto.VcvtPartMode` \| `None` | Optional width-changing lane-placement selector lowered to VPTO `part` | **Returns**: | Return Value | Type | Description | @@ -1384,6 +1384,18 @@ family. `i8`/`si8`/`ui8` use `mask_b8`. - The enum form is preferred. For compatibility, canonical strings such as `"R"`, `"SAT"`, and `"EVEN"` are also accepted. +- VPTO `part` supports two families: `Part` (`EVEN`/`ODD`) for ordinary + width-changing conversions (e.g. `32 -> 16`, `16 -> 32`), and `Part_T` + (`P0`–`P3`) for 4-way packed placement (e.g. `32 -> 8`, fp8/fp4 flows). + + | Mode | VPTO spelling | Family | Description | TileLang DSL v1 status | + |------|---------------|--------|-------------|------------------------| + | `EVEN` | `PART_EVEN` | `Part` | Output to even-indexed lanes | Exposed as `pto.VcvtPartMode.EVEN` | + | `ODD` | `PART_ODD` | `Part` | Output to odd-indexed lanes | Exposed as `pto.VcvtPartMode.ODD` | + | `P0` | `PART_P0` | `Part_T` | Output to sub-part 0 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P0` | + | `P1` | `PART_P1` | `Part_T` | Output to sub-part 1 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P1` | + | `P2` | `PART_P2` | `Part_T` | Output to sub-part 2 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P2` | + | `P3` | `PART_P3` | `Part_T` | Output to sub-part 3 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P3` | - Only backend-supported source/destination type pairs are legal. For the full A5 `vcvt` type matrix, width-changing packing rules, and attribute-sensitive forms, refer to diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 596c0a2f7..a5ae35e42 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -5453,8 +5453,10 @@ def _normalize_vcvt_part_mode(self, expr: SemanticExpr | None) -> SemanticExpr | if part_mode not in {mode.value for mode in VcvtPartMode}: raise TypeError( "pto.vcvt part must be a VcvtPartMode enum such as " - "`pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD`, or one of the " - 'canonical strings `"EVEN"` / `"ODD"` in TileLang DSL v1' + "`pto.VcvtPartMode.EVEN`, `pto.VcvtPartMode.ODD`, or " + "`pto.VcvtPartMode.P0`..`pto.VcvtPartMode.P3`, or one of the " + 'canonical strings `"EVEN"`, `"ODD"`, `"P0"`, `"P1"`, `"P2"`, or `"P3"` ' + "in TileLang DSL v1" ) return SemanticLiteralExpr(value=part_mode, type=SemanticMetaType(kind="string")) diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 1dcf0e374..20f423857 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -448,6 +448,10 @@ class VcvtSatMode(str, Enum): class VcvtPartMode(str, Enum): EVEN = "EVEN" ODD = "ODD" + P0 = "P0" + P1 = "P1" + P2 = "P2" + P3 = "P3" class PostUpdateMode(str, Enum): diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 3b51a8b88..7e1040429 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -160,7 +160,12 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.CmpMode.GE.value, "ge") self.assertEqual(pto.VcvtRoundMode.R.value, "R") self.assertEqual(pto.VcvtSatMode.SAT.value, "SAT") + self.assertEqual(pto.VcvtPartMode.EVEN.value, "EVEN") self.assertEqual(pto.VcvtPartMode.ODD.value, "ODD") + self.assertEqual(pto.VcvtPartMode.P0.value, "P0") + self.assertEqual(pto.VcvtPartMode.P1.value, "P1") + self.assertEqual(pto.VcvtPartMode.P2.value, "P2") + self.assertEqual(pto.VcvtPartMode.P3.value, "P3") self.assertEqual(pto.PostUpdateMode.POST_UPDATE.value, "POST_UPDATE") self.assertEqual(pto.PostUpdateMode.NO_POST_UPDATE.value, "NO_POST_UPDATE") self.assertEqual(pto.Event.ID31.value, "EVENT_ID31") @@ -2983,6 +2988,68 @@ def kernel(dst: pto.Tile, src: pto.Tile): r"= pto\.vcvt %[^,\s]+(?: \{[^}]+\})? : !pto\.vreg<[^>]+> -> !pto\.vreg<[^>]+>", ) + def test_vcvt_supports_part_t_modes_with_enum(self) -> None: + @pto.vkernel( + op="vcvt_part_t_enum_unique", + dtypes=[(pto.i8, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.i8, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.P0, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcvt", text) + self.assertIn('rnd = "R"', text) + self.assertIn('sat = "SAT"', text) + self.assertIn('part = "P0"', text) + + def test_vcvt_supports_part_t_modes_with_canonical_string(self) -> None: + @pto.vkernel( + op="vcvt_part_t_string_unique", + dtypes=[(pto.i8, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.i8, + src_mask, + rnd="R", + sat="SAT", + part="P3", + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcvt", text) + self.assertIn('part = "P3"', text) + def test_vcvt_i32_to_i64_reuses_b32_mask_and_emits_i64_vreg(self) -> None: @pto.vkernel( op="vcvt_i32_to_i64_unique", From 8fa8abf10f8a9637557c0dbd4e820ea2e1c7374e Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 15:55:30 +0800 Subject: [PATCH 124/192] Support the whole mem barrier types --- docs/isa/01-pipeline-sync.md | 29 ++++--- test/basic/membar_barrier_types_vpto_llvm.pto | 43 ++++++++++ .../docs/user_guide/08-sync-dma-operations.md | 8 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 10 ++- tilelang-dsl/python/tilelang_dsl/types.py | 9 +++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 81 +++++++++++++++++++ 6 files changed, 164 insertions(+), 16 deletions(-) create mode 100644 test/basic/membar_barrier_types_vpto_llvm.pto diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md index b3f350626..f042fa9b2 100644 --- a/docs/isa/01-pipeline-sync.md +++ b/docs/isa/01-pipeline-sync.md @@ -116,21 +116,30 @@ The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline ### `pto.mem_bar` - **syntax:** `pto.mem_bar "BARRIER_TYPE"` -- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. +- **semantics:** Shared-memory (UB address space) memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between memory operations. The barrier type selects which classes of prior instructions must complete before which classes of subsequent instructions may proceed. ```c mem_bar(barrier_type); ``` -**Barrier types:** - -| Type | Semantics | -|------|-----------| -| `VV_ALL` | All prior vector ops complete before subsequent | -| `VST_VLD` | All prior vector stores visible before subsequent loads | -| `VLD_VST` | All prior vector loads complete before subsequent stores | - -**Example:** Ensure stores are visible before loads to same UB region: +**Barrier types** are organized into three families by the scope of prior vs. subsequent instructions: + +| Family | Barrier type | Prior instructions | Subsequent instructions | +|--------|-------------|-------------------|------------------------| +| **VV** (vector→vector) | `VV_ALL` | All vector load/store | All vector load/store | +| | `VST_VLD` | All vector store | All vector load | +| | `VLD_VST` | All vector load | All vector store | +| | `VST_VST` | All vector store | All vector store | +| **VS** (vector→scalar) | `VS_ALL` | All vector load/store | All scalar load/store | +| | `VST_LD` | All vector store | All scalar load | +| | `VLD_ST` | All vector load | All scalar store | +| | `VST_ST` | All vector store | All scalar store | +| **SV** (scalar→vector) | `SV_ALL` | All scalar load/store | All vector load/store | +| | `ST_VLD` | All scalar store | All vector load | +| | `LD_VST` | All scalar load | All vector store | +| | `ST_VST` | All scalar store | All vector store | + +**Example:** Ensure vector stores are visible before subsequent vector loads to the same UB region: ```mlir pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr pto.mem_bar "VST_VLD" diff --git a/test/basic/membar_barrier_types_vpto_llvm.pto b/test/basic/membar_barrier_types_vpto_llvm.pto new file mode 100644 index 000000000..ec95361ac --- /dev/null +++ b/test/basic/membar_barrier_types_vpto_llvm.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --vpto-emit-hivm-llvm %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @membar_all_documented_kinds() attributes {pto.kernel_kind = #pto.kernel_kind} { + pto.vecscope { + pto.mem_bar "VV_ALL" + pto.mem_bar "VST_VLD" + pto.mem_bar "VLD_VST" + pto.mem_bar "VST_VST" + pto.mem_bar "VS_ALL" + pto.mem_bar "VST_LD" + pto.mem_bar "VLD_ST" + pto.mem_bar "VST_ST" + pto.mem_bar "SV_ALL" + pto.mem_bar "ST_VLD" + pto.mem_bar "LD_VST" + pto.mem_bar "ST_VST" + } + return + } +} + +// CHECK-LABEL: define void @membar_all_documented_kinds( +// CHECK: call void @llvm.hivm.mem.bar.vv.all() +// CHECK: call void @llvm.hivm.mem.bar.vst.vld() +// CHECK: call void @llvm.hivm.mem.bar.vld.vst() +// CHECK: call void @llvm.hivm.mem.bar.vst.vst() +// CHECK: call void @llvm.hivm.mem.bar.vs.all() +// CHECK: call void @llvm.hivm.mem.bar.vst.ld() +// CHECK: call void @llvm.hivm.mem.bar.vld.st() +// CHECK: call void @llvm.hivm.mem.bar.vst.st() +// CHECK: call void @llvm.hivm.mem.bar.sv.all() +// CHECK: call void @llvm.hivm.mem.bar.st.vld() +// CHECK: call void @llvm.hivm.mem.bar.ld.vst() +// CHECK: call void @llvm.hivm.mem.bar.st.vst() diff --git a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md index 2ee6ec05a..883e5104a 100644 --- a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md +++ b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md @@ -7,9 +7,9 @@ Operations for pipeline synchronization and buffer management. The following enum types provide type-safe parameter specification for synchronization operations: - **`BarrierType`**: Memory barrier types for `pto.mem_bar` - - `VV_ALL`: All prior vector ops complete before subsequent - - `VST_VLD`: All prior vector stores visible before subsequent loads - - `VLD_VST`: All prior vector loads complete before subsequent stores + - `VV_ALL`, `VST_VLD`, `VLD_VST`, `VST_VST`: vector→vector barriers + - `VS_ALL`, `VST_LD`, `VLD_ST`, `VST_ST`: vector→scalar barriers + - `SV_ALL`, `ST_VLD`, `LD_VST`, `ST_VST`: scalar→vector barriers - **`Pipe`**: Hardware pipeline identifiers - `MTE2`: Memory Transfer Engine 2 pipeline @@ -127,7 +127,7 @@ pto.rls_buf(Pipe.MTE2, 0, 0) **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| -| `barrier_type` | `BarrierType` | Barrier type: `BarrierType.VV_ALL` (all prior vector ops complete before subsequent), `BarrierType.VST_VLD` (all prior vector stores visible before subsequent loads), `BarrierType.VLD_VST` (all prior vector loads complete before subsequent stores) | +| `barrier_type` | `BarrierType` | Barrier type controlling prior/subsequent instruction ordering. Supported values are `BarrierType.VV_ALL`, `BarrierType.VST_VLD`, `BarrierType.VLD_VST`, `BarrierType.VST_VST`, `BarrierType.VS_ALL`, `BarrierType.VST_LD`, `BarrierType.VLD_ST`, `BarrierType.VST_ST`, `BarrierType.SV_ALL`, `BarrierType.ST_VLD`, `BarrierType.LD_VST`, and `BarrierType.ST_VST`. | **Returns**: None (side-effect operation) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index a5ae35e42..b45156518 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -6007,8 +6007,14 @@ def _require_barrier_type(self, expr: SemanticExpr, context: str) -> str: if expr.type.kind == "barrier_type" and isinstance(expr.binding.value, BarrierType): return expr.binding.value.value if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": - return expr.value - raise TypeError(f"{context} must be a BarrierType symbol or string literal in TileLang DSL v1") + if expr.value in {barrier_type.value for barrier_type in BarrierType}: + return expr.value + raise TypeError( + f"{context} must be a BarrierType symbol or canonical barrier string " + "(`VV_ALL`, `VST_VLD`, `VLD_VST`, `VST_VST`, `VS_ALL`, `VST_LD`, " + "`VLD_ST`, `VST_ST`, `SV_ALL`, `ST_VLD`, `LD_VST`, or `ST_VST`) " + "in TileLang DSL v1" + ) def _normalize_event_id_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "event" and isinstance(expr.value, Event): diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 20f423857..3ff5e387f 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -178,6 +178,15 @@ class BarrierType(str, Enum): VV_ALL = "VV_ALL" VST_VLD = "VST_VLD" VLD_VST = "VLD_VST" + VST_VST = "VST_VST" + VS_ALL = "VS_ALL" + VST_LD = "VST_LD" + VLD_ST = "VLD_ST" + VST_ST = "VST_ST" + SV_ALL = "SV_ALL" + ST_VLD = "ST_VLD" + LD_VST = "LD_VST" + ST_VST = "ST_VST" class MaskPattern(str, Enum): diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 7e1040429..dfb1fb996 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -128,6 +128,15 @@ def test_package_exports_surface(self) -> None: self.assertTrue(hasattr(pto, "si64")) self.assertTrue(hasattr(pto, "ui64")) self.assertEqual(pto.BarrierType.VST_VLD.value, "VST_VLD") + self.assertEqual(pto.BarrierType.VST_VST.value, "VST_VST") + self.assertEqual(pto.BarrierType.VS_ALL.value, "VS_ALL") + self.assertEqual(pto.BarrierType.VST_LD.value, "VST_LD") + self.assertEqual(pto.BarrierType.VLD_ST.value, "VLD_ST") + self.assertEqual(pto.BarrierType.VST_ST.value, "VST_ST") + self.assertEqual(pto.BarrierType.SV_ALL.value, "SV_ALL") + self.assertEqual(pto.BarrierType.ST_VLD.value, "ST_VLD") + self.assertEqual(pto.BarrierType.LD_VST.value, "LD_VST") + self.assertEqual(pto.BarrierType.ST_VST.value, "ST_VST") self.assertEqual(pto.PadMode.PadNull.value, "PadNull") self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") self.assertEqual(pto.PadMode.PadValue.value, "PadValue") @@ -5368,6 +5377,78 @@ def kernel( self.assertIn("pto.wait_flag_dev %arg3, %c8_i64 : i64, i64", text) self.assertIn("pto.wait_intra_core %arg4, %c31_i64 : i64, i64", text) + def test_mem_bar_accepts_extended_barrier_type_enum(self) -> None: + BarrierType = pto.BarrierType + + @pto.vkernel( + op="mem_bar_extended_enum_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + pto.mem_bar(BarrierType.ST_VST) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIsInstance(semantic_kernel.body[0], SemanticMemBarStmt) + + text = specialized.mlir_text() + self.assertIn('pto.mem_bar "ST_VST"', text) + + def test_mem_bar_accepts_extended_barrier_type_enum_vst_st(self) -> None: + BarrierType = pto.BarrierType + + @pto.vkernel( + op="mem_bar_extended_enum_vst_st_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + pto.mem_bar(BarrierType.VST_ST) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn('pto.mem_bar "VST_ST"', text) + + def test_mem_bar_rejects_unknown_barrier_string(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="mem_bar_invalid_string_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + pto.mem_bar("NOT_A_BARRIER") + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("canonical barrier string", str(ctx.exception)) + def test_runtime_block_queries_and_scalar_pointer_helpers_lower_to_v0_3_surface(self) -> None: @pto.vkernel( op="runtime_block_queries_and_scalar_helpers", From 32eaa4d58d540b82f2c17188e3baa04542b88304 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 17:36:04 +0800 Subject: [PATCH 125/192] chore(ci): fix license headers for pr199 --- lib/TileOps/tload_template.py | 8 ++++++++ lib/TileOps/tstore_template.py | 8 ++++++++ .../conversion/vcvt-u32-to-u8-part-p0123/launch.cpp | 2 +- .../conversion/vcvt-u32-to-u8-part-p0123/main.cpp | 2 +- .../conversion/vcvt-u32-to-u8-part-p0123/stub.cpp | 2 +- tilelang-dsl/python/tilelang_dsl/frontend_ast.py | 8 ++++++++ tilelang-dsl/python/tilelang_dsl/support_matrix.py | 8 ++++++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 8 ++++++++ 8 files changed, 43 insertions(+), 3 deletions(-) diff --git a/lib/TileOps/tload_template.py b/lib/TileOps/tload_template.py index 3366ac2cd..d298bdfac 100644 --- a/lib/TileOps/tload_template.py +++ b/lib/TileOps/tload_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """`pto.tload` 的 TileLang DSL 模板""" import tilelang_dsl as pto diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py index 2850a1651..278d2c25f 100644 --- a/lib/TileOps/tstore_template.py +++ b/lib/TileOps/tstore_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """`pto.tstore` 的 TileLang DSL 模板""" import tilelang_dsl as pto diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp index 0923981ed..a3b66083b 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp @@ -2,7 +2,7 @@ // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). // Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, WHETHER EXPRESS OR IMPLIED, +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp index 07cedfe4e..ae417b2d7 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp @@ -2,7 +2,7 @@ // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). // Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, WHETHER EXPRESS OR IMPLIED, +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/stub.cpp index 9e3547e7e..716245ae9 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/stub.cpp +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/stub.cpp @@ -2,7 +2,7 @@ // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). // Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, WHETHER EXPRESS OR IMPLIED, +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index dcd58a9e4..d4cff2da3 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """Frontend AST nodes for TileLang DSL descriptor materialization.""" from __future__ import annotations diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index a82a2d3f1..9fcfc6559 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """Support-matrix definitions and diagnostics for TileLang DSL v1.""" from __future__ import annotations diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index dfb1fb996..4096fa168 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + import tempfile import unittest from unittest import mock From 293875268b9609678bc93008e6565e01bca8ce83 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 17:28:21 +0800 Subject: [PATCH 126/192] fix(expand-tileop): add cmp_mode attribute handling for TCmpOp and TCmpSOp --- lib/PTO/Transforms/ExpandTileOp.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 6c009e556..e9f4bf912 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -255,6 +255,18 @@ static void appendOpContextAttrs( if (roundMode) attrs.emplace_back("round_mode", *roundMode); } + if (auto tcmp = dyn_cast(op)) { + if (auto cmpModeAttr = tcmp.getCmpModeAttr()) { + attrs.emplace_back("cmp_mode", + stringifyCmpMode(cmpModeAttr.getValue()).str()); + } + } + if (auto tcmps = dyn_cast(op)) { + if (auto cmpModeAttr = tcmps.getCmpModeAttr()) { + attrs.emplace_back("cmp_mode", + stringifyCmpMode(cmpModeAttr.getValue()).str()); + } + } } static bool getStaticIntFromValue(Value value, int64_t &out) { From fed07fa556cfb0078d9f05ee143e9ca8bff752da Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Wed, 22 Apr 2026 21:30:33 +0800 Subject: [PATCH 127/192] feat: support dsl ut ci --- .github/workflows/ci.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5cfad5a06..299c9c7b4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -265,6 +265,9 @@ jobs: vpto-sim-validation: runs-on: [self-hosted, Linux, X64, label-1] timeout-minutes: 120 + concurrency: + group: vpto-sim-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true if: >- ${{ github.event_name == 'workflow_dispatch' || @@ -275,6 +278,7 @@ jobs: LLVM_COMMIT: cd708029e0b2869e80abe31ddb175f7c35361f90 VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci TILELANG_DSL_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ci + TILELANG_DSL_UT_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ut-ci steps: - name: Checkout uses: actions/checkout@v4 @@ -460,6 +464,24 @@ jobs: ${{ env.TILELANG_DSL_WORKSPACE }}/run_ci.log if-no-files-found: warn + - name: Run TileLang DSL unit tests + shell: bash + run: | + set -euo pipefail + mkdir -p "${TILELANG_DSL_UT_WORKSPACE}" + cd tilelang-dsl + PYTHONPATH=python python3 -m unittest discover -s tests -p 'test_*.py' \ + 2>&1 | tee "${TILELANG_DSL_UT_WORKSPACE}/unittest.log" + + - name: Upload TileLang DSL unit test logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: tilelang-dsl-ut-ci-${{ github.run_id }} + path: | + ${{ env.TILELANG_DSL_UT_WORKSPACE }}/unittest.log + if-no-files-found: warn + - name: Upload VPTO SIM logs if: always() uses: actions/upload-artifact@v4 From cfee28c6d0f55930b8e3220ed4dc6ddc48281b47 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 21:02:30 +0800 Subject: [PATCH 128/192] feat(tilelang-dsl): add vscatter surface lowering --- .../user_guide/09-vector-memory-operations.md | 2 +- tilelang-dsl/python/tilelang_dsl/lowering.py | 30 +++++++++++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 44 ++++++++++++++++++- .../python/tilelang_dsl/support_matrix.py | 1 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 32 ++++++++++++++ 5 files changed, 107 insertions(+), 2 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md index 57def372d..68bdb1bde 100644 --- a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -768,7 +768,7 @@ pto.vsta(align, tile[k:]) **Constraints**: - Only `b8`, `b16`, and `b32` element sizes are supported -- Index vector must use a supported integer element type and layout +- Current TileLang DSL / VPTO path requires `i32` index vectors - Each computed address must be element-aligned - If indices alias, only one write is guaranteed (winning lane is implementation-defined) - Only the first `active_lanes` offsets participate in the scatter diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 838ce2df8..6cfa56bb1 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -65,6 +65,7 @@ SemanticType, SemanticTupleExpr, SemanticTupleType, + SemanticVScatterStmt, SemanticVRegType, SemanticVectorPairStoreStmt, SemanticVectorStoreStmt, @@ -313,6 +314,12 @@ def _collect_used_tile_buffers_from_stmt( self._collect_used_tile_buffers_from_expr(index, used) self._collect_used_tile_buffers_from_expr(stmt.mask, used) return + if isinstance(stmt, SemanticVScatterStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + self._record_tile_buffer_use(stmt.destination, used) + self._collect_used_tile_buffers_from_expr(stmt.offsets, used) + self._collect_used_tile_buffers_from_expr(stmt.active_lanes, used) + return if isinstance(stmt, SemanticPredicateStoreStmt): self._collect_used_tile_buffers_from_expr(stmt.value, used) self._record_tile_buffer_use(stmt.destination, used) @@ -432,6 +439,8 @@ def _render_stmt( return self._render_dma_store(stmt, env, indent=indent) if isinstance(stmt, SemanticVectorStoreStmt): return self._render_vector_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticVScatterStmt): + return self._render_vscatter(stmt, env, indent=indent) if isinstance(stmt, SemanticVectorPairStoreStmt): return self._render_vector_pair_store(stmt, env, indent=indent) if isinstance(stmt, SemanticPredicateStoreStmt): @@ -1088,6 +1097,27 @@ def _render_vector_pair_store( ) return lines + def _render_vscatter( + self, + stmt: SemanticVScatterStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + offsets = self._lower_expr(stmt.offsets, env, indent=indent, into=lines) + active_lanes = self._lower_to_index(stmt.active_lanes, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + "pto.vscatter " + + f"{value.name}, {destination.name}, {offsets.name}, {active_lanes.name} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, " + + f"{self._render_type(offsets.type)}, {self._render_type(active_lanes.type)}" + ) + return lines + def _render_predicate_store( self, stmt: SemanticPredicateStoreStmt, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index b45156518..2229e4a9d 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -596,6 +596,14 @@ class SemanticVectorPairStoreStmt(SemanticStmt): mask: SemanticExpr +@dataclass(frozen=True) +class SemanticVScatterStmt(SemanticStmt): + value: SemanticExpr + destination: SemanticExpr + offsets: SemanticExpr + active_lanes: SemanticExpr + + @dataclass(frozen=True) class SemanticPredicateStoreStmt(SemanticStmt): op_name: str @@ -1166,6 +1174,7 @@ def _should_infer_vecscope( "vsta", "vstas", "vstar", + "vscatter", "vsts", "vstsx2", "vstus", @@ -1266,6 +1275,7 @@ def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> boo "vsta", "vstas", "vstar", + "vscatter", "vsts", "vstsx2", "vstus", @@ -1365,6 +1375,8 @@ def _semantic_block_contains_vector_activity( return True if isinstance(stmt, SemanticVectorStoreStmt): return True + if isinstance(stmt, SemanticVScatterStmt): + return True if isinstance(stmt, SemanticPredicateStoreStmt): return True if isinstance(stmt, SemanticAlignStoreStmt): @@ -1640,7 +1652,7 @@ def _is_vector_store_call(self, expr: FrontendExprNode) -> bool: return ( isinstance(expr, FrontendCallExpr) and expr.namespace == "pto" - and expr.name in {"psts", "pst", "psti", "vsst", "vsta", "vstas", "vstar", "vsts", "vstsx2"} + and expr.name in {"psts", "pst", "psti", "vsst", "vsta", "vstas", "vstar", "vscatter", "vsts", "vstsx2"} ) def _is_scalar_store_call(self, expr: FrontendExprNode) -> bool: @@ -1889,6 +1901,35 @@ def _analyze_vector_store_stmt( dict(env), ) + if expr.name == "vscatter": + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 4: + raise TypeError("pto.vscatter expects exactly 4 positional arguments in TileLang DSL v1") + value, destination, offsets, active_lanes = args + value_type = self._require_vreg_expr(value, "pto.vscatter value") + self._require_vector_pointer_expr(destination, "pto.vscatter destination") + offsets_type = self._require_vreg_expr(offsets, "pto.vscatter offsets") + if not is_integer_dtype(offsets_type.element_dtype): + raise TypeError("pto.vscatter offsets must use an integer vector type in TileLang DSL v1") + if integer_bitwidth(offsets_type.element_dtype) != 32: + raise TypeError("pto.vscatter currently requires i32 offset vectors in TileLang DSL v1") + if value_type.lanes != offsets_type.lanes: + raise TypeError("pto.vscatter value and offsets must use the same lane count in TileLang DSL v1") + self._require_i32_like_expr(active_lanes, "pto.vscatter active_lanes") + self._require_matching_vector_pointer(value_type, destination.type, "pto.vscatter") + return ( + SemanticVScatterStmt( + value=value, + destination=destination, + offsets=offsets, + active_lanes=active_lanes, + ), + dict(env), + ) + if expr.name == "vsst": if len(expr.args) == 3: scalar = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) @@ -6448,6 +6489,7 @@ def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: "SemanticTupleType", "SemanticType", "SemanticVRegType", + "SemanticVScatterStmt", "SemanticVectorPairStoreStmt", "SemanticVectorStoreStmt", "SemanticWaitFlagDevStmt", diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 9fcfc6559..33602fd07 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -157,6 +157,7 @@ ADVANCED_VECSCOPE_PTO_CALLS = frozenset( { + "vscatter", "vcmp", "vcmps", "vsel", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 4096fa168..0c80dd4ff 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -74,6 +74,7 @@ SemanticTileConfigType, SemanticTileType, SemanticVecscopeStmt, + SemanticVScatterStmt, SemanticVectorPairStoreStmt, SemanticVectorStoreStmt, SemanticVRegType, @@ -444,6 +445,7 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vsort32"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vldsx2"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vstsx2"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vscatter"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.vbitsort"), ADVANCED_TIER) self.assertEqual(get_feature_tier("pto.vmrgsort4"), ADVANCED_TIER) self.assertEqual(get_feature_tier("PadMode"), BASIC_TIER) @@ -5532,6 +5534,36 @@ def kernel(src: pto.Tile, dst: pto.Tile): self.assertIn('"DINTLV"', text) self.assertIn('"INTLV"', text) + def test_vscatter_lowers_from_advanced_pointer_surface(self) -> None: + @pto.vkernel( + op="vscatter_pointer_surface", + dtypes=[(pto.i32, pto.f32)], + advanced=True, + ) + def kernel( + offsets_src: pto.ptr(pto.i32, pto.MemorySpace.UB), + dst: pto.ptr(pto.f32, pto.MemorySpace.UB), + ): + vec = pto.vbr(1.0) + offsets = pto.vlds(offsets_src, 0) + pto.vscatter(vec, dst, offsets, 64) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)) + scatter_stmt = next(stmt for stmt in vecscope.body if isinstance(stmt, SemanticVScatterStmt)) + + self.assertIsInstance(scatter_stmt, SemanticVScatterStmt) + self.assertEqual(scatter_stmt.destination.type.memory_space, "ub") + self.assertEqual(scatter_stmt.value.type.element_dtype, pto.f32) + self.assertEqual(scatter_stmt.offsets.type.element_dtype, pto.i32) + + text = specialized.mlir_text() + self.assertIn("pto.vscatter", text) + self.assertIn("!pto.vreg<64xf32>", text) + self.assertIn("!pto.vreg<64xi32>", text) + def test_align_load_and_stateful_store_ops_lower_to_current_vpto_surface(self) -> None: @pto.vkernel( op="align_load_and_stateful_store_ops", From 72798b87e698704bee52620b92c1ca1d217bd052 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 10:04:17 +0800 Subject: [PATCH 129/192] fix(dsl): require hex strings for integer bit patterns (#174) --- .../docs/user_guide/05-type-system.md | 31 +++++++++++++++++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 18 +++++++++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 25 ++++++++++++--- 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index eadd7d270..e174a0eff 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -37,6 +37,37 @@ Integer sign semantics are part of the DSL type surface. `pto.si16`, `pto.ui16`, and `pto.i16` are distinct scalar dtypes and lower to `si16`, `ui16`, and `i16` respectively in VPTO IR. +### Integer Literal Guidance + +For ordinary integer constants, prefer plain integer literals instead of +string forms. + +```python +count = pto.i32(1024) +delta = pto.i16(-12) +min_i32 = pto.i32(-2147483648) +unsigned_hi = pto.ui16(32768) +``` + +Integer string literals are reserved for explicit bit-pattern authoring. They +must use hex form. + +```python +# Use hex strings only when you intentionally want fixed-width bit-pattern +# interpretation at the target dtype width. +hi_bit = pto.i32("0x80000000") # -2147483648 +all_ones = pto.i16("0xFFFF") # -1 +unsigned_hi = pto.ui16("0x8000") # 32768 +``` + +Rules: +- Prefer plain integer literals such as `pto.i32(1024)` or `pto.i16(-12)` for normal integer authoring. +- Integer string literals must use hex bit-pattern form such as `"0xFFFF"`. +- Ordinary integer strings such as `"1024"` or `"-12"` are rejected; write them as integer literals instead. +- For signed and signless integer dtypes (`pto.i*`, `pto.si*`), hex strings use two's-complement interpretation at the target dtype width. +- For unsigned integer dtypes (`pto.ui*`), hex strings keep their unsigned value. +- Hex strings must fit within the target bit width. For example, `pto.i16("0x10000")` is rejected because the literal exceeds 16 bits. + ### Floating-Point Literal Forms `pto.f16(...)`, `pto.bf16(...)`, and `pto.f32(...)` accept multiple literal forms. diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 2229e4a9d..96f2937ac 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -4275,12 +4275,26 @@ def _parse_integer_literal_string( context: str, ) -> int: text = literal.strip().lower() + bits = integer_bitwidth(target_dtype) + signedness = integer_signedness(target_dtype) + assert bits is not None + signless_or_signed = signedness != "unsigned" + if not text.startswith("0x"): + raise TypeError( + f"{context} string literals must use hex bit-pattern form like \"0xFF\" in TileLang DSL v1" + ) try: - parsed = int(text, 0) + parsed = int(text, 16) except ValueError as exc: raise TypeError( - f"{context} string literal {literal!r} is not a valid integer literal" + f"{context} string literal {literal!r} is not a valid hex bit-pattern" ) from exc + if parsed >= (1 << bits): + raise TypeError( + f"{context} bit-pattern literal {literal!r} exceeds {bits}-bit width for {target_dtype.name}" + ) + if signless_or_signed and parsed >= (1 << (bits - 1)): + parsed -= 1 << bits return self._check_integer_literal_range(parsed, target_dtype, context) def _check_integer_literal_range( diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 0c80dd4ff..7c0def37a 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -4093,16 +4093,33 @@ def kernel(inp: pto.TensorView): self.assertIn("pto.i32 value must be a scalar or index value", str(ctx.exception)) - def test_scalar_constructor_accepts_integer_string_literals(self) -> None: - @pto.vkernel(op="scalar_constructor_integer_string_literals_unique", dtypes=[(pto.f32,)]) + def test_scalar_constructor_accepts_integer_hex_bit_pattern_strings(self) -> None: + @pto.vkernel(op="scalar_constructor_integer_hex_bit_patterns_unique", dtypes=[(pto.f32,)]) def kernel(inp: pto.TensorView): x = pto.i16("0x7FFF") y = pto.i32("0x7FFFFFFF") + z = pto.i16("0x8000") + a = pto.i32("0x80000000") + b = pto.ui16("0x8000") return None text = kernel.mlir_text() self.assertIn("= arith.constant 32767 : i16", text) self.assertIn("= arith.constant 2147483647 : i32", text) + self.assertIn("= arith.constant -32768 : i16", text) + self.assertIn("= arith.constant -2147483648 : i32", text) + self.assertIn("= arith.constant 32768 : i16", text) + + def test_scalar_constructor_rejects_non_hex_integer_string_literals(self) -> None: + @pto.vkernel(op="scalar_constructor_non_hex_integer_strings_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i32("1024") + return None + + with self.assertRaises(TypeError) as ctx: + kernel.mlir_text() + + self.assertIn("string literals must use hex bit-pattern form", str(ctx.exception)) def test_scalar_constructor_rejects_out_of_range_integer_literal(self) -> None: @pto.vkernel(op="scalar_constructor_oob_int_unique", dtypes=[(pto.f32,)]) @@ -4118,13 +4135,13 @@ def kernel(inp: pto.TensorView): def test_scalar_constructor_rejects_out_of_range_integer_string_literal(self) -> None: @pto.vkernel(op="scalar_constructor_oob_integer_string_unique", dtypes=[(pto.f32,)]) def kernel(inp: pto.TensorView): - x = pto.i16("0x8000") + x = pto.i16("0x10000") return None with self.assertRaises(TypeError) as ctx: kernel.mlir_text() - self.assertIn("out of range for i16", str(ctx.exception)) + self.assertIn("exceeds 16-bit width for i16", str(ctx.exception)) def test_inferred_vecscope_propagates_bindings_to_constexpr_if(self) -> None: @pto.vkernel( From fb201c6f4d059a94a16827507688d0962b545c87 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 10:15:28 +0800 Subject: [PATCH 130/192] docs: add license headers to issue 174 updates --- tilelang-dsl/docs/user_guide/05-type-system.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index e174a0eff..1c573fe97 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -1,3 +1,13 @@ + + ## Type System ### Scalar Types From 1eac185d2b9b36363f5ef3faf3b701a3a121e1b3 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 16:58:23 +0800 Subject: [PATCH 131/192] Add bitcast after arith.constant --- tilelang-dsl/python/tilelang_dsl/lowering.py | 66 +++++++++++++++++--- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 26 ++++++++ 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 6cfa56bb1..6f977d522 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2505,16 +2505,12 @@ def _lower_expr( if isinstance(expr, SemanticBindingRef): return env.get(expr.binding.name, _RenderedValue(expr.binding.ssa_name, expr.type)) if isinstance(expr, SemanticLiteralExpr): - if desired_name is not None and into is not None: - into.append( - self._indent(indent) - + f"{desired_name} = arith.constant {self._format_constant(expr.value, expr.type)} : " - f"{self._render_arith_constant_type(expr.type)}" - ) - return _RenderedValue(name=desired_name, type=expr.type) - return _RenderedValue( - name=self._materialize_constant(expr.value, expr.type), - type=expr.type, + return self._lower_literal_expr( + expr.value, + expr.type, + indent=indent, + desired_name=desired_name, + into=into, ) if isinstance(expr, SemanticSubscriptAccess): return self._lower_subscript_access( @@ -3711,6 +3707,56 @@ def _materialize_constant(self, value: object, ty: SemanticType) -> str: ) return name + def _signless_integer_scalar_type(self, ty: SemanticType) -> SemanticScalarType | None: + if not isinstance(ty, SemanticScalarType) or not is_integer_dtype(ty.dtype): + return None + signedness = integer_signedness(ty.dtype) + if signedness in {None, "signless"}: + return None + bitwidth = integer_bitwidth(ty.dtype) + if bitwidth not in {8, 16, 32, 64}: + raise NotImplementedError( + f"unsupported integer bitwidth {bitwidth!r} for signless literal lowering" + ) + return SemanticScalarType(dtype=ScalarType(f"i{bitwidth}")) + + def _lower_literal_expr( + self, + value: object, + ty: SemanticType, + *, + indent: int, + desired_name: str | None = None, + into: list[str] | None = None, + ) -> _RenderedValue: + raw_type = self._signless_integer_scalar_type(ty) or ty + if desired_name is not None and into is not None and raw_type == ty: + into.append( + self._indent(indent) + + f"{desired_name} = arith.constant {self._format_constant(value, ty)} : " + f"{self._render_arith_constant_type(ty)}" + ) + return _RenderedValue(name=desired_name, type=ty) + + if desired_name is not None and into is not None: + raw_name = self._new_temp() + into.append( + self._indent(indent) + + f"{raw_name} = arith.constant {self._format_constant(value, raw_type)} : " + f"{self._render_arith_constant_type(raw_type)}" + ) + into.append( + self._indent(indent) + + f"{desired_name} = builtin.unrealized_conversion_cast {raw_name} : " + f"{self._render_type(raw_type)} to {self._render_type(ty)}" + ) + return _RenderedValue(name=desired_name, type=ty) + + return _RenderedValue( + name=self._materialize_constant(value, ty), + type=ty, + ) + def _constant_name(self, value: object, ty: SemanticType) -> str: if isinstance(ty, SemanticIndexType): stem = f"c{value}" diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 7c0def37a..d0cae63cc 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2450,6 +2450,32 @@ def kernel(tile: pto.Tile): self.assertIn("arith.constant 4294967295 : i32", text) self.assertNotIn("arith.constant 4294967295 : ui32", text) + def test_unsigned_pad_value_eval_broadcast_bitcasts_signless_literal(self) -> None: + @pto.vkernel(op="tile_pad_value_ui16_vbr_unique", dtypes=[(pto.ui16,)], advanced=True) + def kernel(tile: pto.Tile): + scalar = tile.pad_value.eval() + vec = pto.vbr(scalar) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "pad_value": pto.PadValue.MAX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("dtype=ui16", text) + self.assertIn("arith.constant 65535 : i16", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i16 to ui16", text) + self.assertIn("pto.vbr", text) + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) From 6a3c1c2f9a4bb202237d729cdcf58f73a5fd083b Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 17:13:18 +0800 Subject: [PATCH 132/192] fix: handle type compatibility in tryCloneOpLibInlineBridgeOp --- lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp b/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp index 93eb585f1..42f630902 100644 --- a/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp +++ b/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + #include "PTO/IR/PTO.h" #include "PTOLowerToOpLibCalls.h" @@ -201,7 +209,15 @@ FailureOr mlir::pto::tryCloneOpLibInlineBridgeOp(OpBuilder &builder, if (!mappedSrc) return failure(); - mapping.map(cast.getResult(0), mappedSrc); + Type dstTy = cast.getResult(0).getType(); + if (mappedSrc.getType() == dstTy) { + mapping.map(cast.getResult(0), mappedSrc); + return true; + } + + auto clonedCast = builder.create( + cast.getLoc(), TypeRange{dstTy}, ValueRange{mappedSrc}); + mapping.map(cast.getResult(0), clonedCast.getResult(0)); return true; } From 4d4e23d642357b152d1e8aabcf5d1973d5dcbb46 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 21:54:44 +0800 Subject: [PATCH 133/192] fix(dsl): align vexpdif surface with VPTO --- docs/isa/13-dsa-sfu-ops.md | 6 +- docs/vpto-spec.md | 4 +- include/PTO/IR/VPTOOps.td | 2 +- lib/PTO/IR/VPTO.cpp | 2 +- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 34 +++++----- .../kernels/online-softmax-update/compare.py | 2 +- .../kernels/online-softmax-update/golden.py | 2 +- .../kernels/online-softmax-update/kernel.pto | 10 +-- .../kernels/online-softmax-update/launch.cpp | 2 +- .../kernels/online-softmax-update/main.cpp | 2 +- .../kernels/online-softmax-update/stub.cpp | 2 +- .../dsa-sfu/vexpdiff-boundary/compare.py | 4 +- .../dsa-sfu/vexpdiff-boundary/golden.py | 4 +- .../dsa-sfu/vexpdiff-boundary/kernel.pto | 8 +-- .../dsa-sfu/vexpdiff-boundary/launch.cpp | 8 +-- .../dsa-sfu/vexpdiff-boundary/main.cpp | 4 +- .../dsa-sfu/vexpdiff-boundary/stub.cpp | 6 +- .../dsa-sfu/vexpdiff-f16-part/compare.py | 4 +- .../dsa-sfu/vexpdiff-f16-part/golden.py | 4 +- .../dsa-sfu/vexpdiff-f16-part/kernel.pto | 10 +-- .../dsa-sfu/vexpdiff-f16-part/launch.cpp | 8 +-- .../dsa-sfu/vexpdiff-f16-part/main.cpp | 4 +- .../dsa-sfu/vexpdiff-f16-part/stub.cpp | 6 +- .../micro-op/dsa-sfu/vexpdiff-f32/compare.py | 4 +- .../micro-op/dsa-sfu/vexpdiff-f32/golden.py | 4 +- .../micro-op/dsa-sfu/vexpdiff-f32/kernel.pto | 8 +-- .../micro-op/dsa-sfu/vexpdiff-f32/launch.cpp | 8 +-- .../micro-op/dsa-sfu/vexpdiff-f32/main.cpp | 4 +- .../micro-op/dsa-sfu/vexpdiff-f32/stub.cpp | 6 +- .../11-vector-arithmetic-operations.md | 14 +++-- .../docs/vpto_spec/vpto-spec-current.md | 10 +-- tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md | 8 +-- tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md | 10 +-- tilelang-dsl/python/tilelang_dsl/lowering.py | 12 +++- tilelang-dsl/python/tilelang_dsl/semantic.py | 63 ++++++++++++++++++- .../python/tilelang_dsl/support_matrix.py | 1 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 30 ++++++++- 37 files changed, 210 insertions(+), 110 deletions(-) diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md index 32fc75b82..06edaef68 100644 --- a/docs/isa/13-dsa-sfu-ops.md +++ b/docs/isa/13-dsa-sfu-ops.md @@ -57,9 +57,9 @@ for (int i = 0; i < N; i++) --- -### `pto.vexpdiff` +### `pto.vexpdif` -- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vexpdif %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` - **A5 types:** input `f16` or `f32`, output `f32` - **semantics:** Fused exp(x - max) for numerically stable softmax. @@ -219,7 +219,7 @@ for (int i = 0; i < N; i++) ```mlir // Softmax with fused expdiff %max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> -%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // Leaky ReLU activation %activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 553c2133e..2e2beb5c0 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -893,7 +893,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | 10 | [Reduction Ops](isa/10-reduction-ops.md) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | | 11 | [Compare & Select](isa/11-compare-select.md) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | | 12 | [Data Rearrangement](isa/12-data-rearrangement.md) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | -| 13 | [DSA/SFU Ops](isa/13-dsa-sfu-ops.md) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 13 | [DSA/SFU Ops](isa/13-dsa-sfu-ops.md) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdif`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | | 14 | [Arith (Shared MLIR Dialect)](isa/14-shared-arith.md) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | | 15 | [SCF (Shared MLIR Dialect)](isa/15-shared-scf.md) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | @@ -982,7 +982,7 @@ pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, ! %max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> // 2. exp(x - max) using fused op -%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp = pto.vexpdif %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // 3. Sum %sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index a5a1cf3a6..5e4b5178c 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -1456,7 +1456,7 @@ class PTO_UnmaskedBinaryVecOp : PTO_Op { } def PTO_VpreluOp : PTO_UnmaskedBinaryVecOp<"vprelu">; -def PTO_VexpdiffOp : PTO_Op<"vexpdiff", [Pure]> { +def PTO_VexpdifOp : PTO_Op<"vexpdif", [Pure]> { let arguments = (ins PTO_VectorType:$input, PTO_VectorType:$max, diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index ce0481be3..4ac498d6a 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -2941,7 +2941,7 @@ static LogicalResult verifyFloatBinaryVecNoMaskOp(BinaryVecNoMaskOp op) { } LogicalResult VpreluOp::verify() { return verifyFloatBinaryVecNoMaskOp(*this); } -LogicalResult VexpdiffOp::verify() { +LogicalResult VexpdifOp::verify() { if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input type")) || failed(verifyVRegTypeLike(*this, getMax().getType(), "max type")) || failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index bd950edc9..7ccc2230e 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -1669,9 +1669,9 @@ static FailureOr buildVtrcCallee(MLIRContext *context, Type resultTyp return StringAttr::get(context, "llvm.hivm.vtrc." + vec + ".x").getValue(); } -static FailureOr buildVexpdiffCallee(MLIRContext *context, - Type inputType, - Type resultType) { +static FailureOr buildVexpdifCallee(MLIRContext *context, + Type inputType, + Type resultType) { std::string srcVec = getElementTypeFragment(getElementTypeFromVectorLike(inputType)); auto srcLanes = getElementCountFromVectorLike(inputType); @@ -4090,38 +4090,38 @@ class LowerVciOpPattern final : public OpConversionPattern { LoweringState &state; }; -class LowerVexpdiffOpPattern final - : public OpConversionPattern { +class LowerVexpdifOpPattern final + : public OpConversionPattern { public: - explicit LowerVexpdiffOpPattern(TypeConverter &typeConverter, - MLIRContext *context, LoweringState &state) - : OpConversionPattern(typeConverter, context), + explicit LowerVexpdifOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} LogicalResult - matchAndRewrite(pto::VexpdiffOp op, pto::VexpdiffOp::Adaptor adaptor, + matchAndRewrite(pto::VexpdifOp op, pto::VexpdifOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto laneCount = getElementCountFromVectorLike(op.getInput().getType()); Type elemType = getElementTypeFromVectorLike(op.getInput().getType()); auto part = parsePartImmediate(op.getPart()); if (!laneCount || !elemType || !part) - return rewriter.notifyMatchFailure(op, "unsupported vexpdiff signature"); + return rewriter.notifyMatchFailure(op, "unsupported vexpdif signature"); FailureOr mask = materializeDynamicPltMask( rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), elemType); if (failed(mask)) - return rewriter.notifyMatchFailure(op, "failed to materialize vexpdiff mask"); + return rewriter.notifyMatchFailure(op, "failed to materialize vexpdif mask"); Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!resultType) - return rewriter.notifyMatchFailure(op, "failed to convert vexpdiff result type"); + return rewriter.notifyMatchFailure(op, "failed to convert vexpdif result type"); FailureOr calleeName = - buildVexpdiffCallee(op.getContext(), op.getInput().getType(), - op.getResult().getType()); + buildVexpdifCallee(op.getContext(), op.getInput().getType(), + op.getResult().getType()); if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported vexpdiff callee"); + return rewriter.notifyMatchFailure(op, "unsupported vexpdif callee"); Value partValue = getI32Constant(rewriter, op.getLoc(), *part); auto funcType = rewriter.getFunctionType( @@ -5051,7 +5051,7 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerVgather2OpPattern, LowerVgather2BcOpPattern, LowerVgatherbOpPattern, LowerVscatterOpPattern, LowerVpreluOpPattern, LowerVaxpyOpPattern, - LowerVciOpPattern, LowerVexpdiffOpPattern, + LowerVciOpPattern, LowerVexpdifOpPattern, LowerVbitsortOpPattern, LowerVtrcOpPattern, LowerVcvtOpPattern, LowerVbitcastOpPattern, LowerPredicateLoadOpPattern, @@ -5110,7 +5110,7 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, pto::PintlvB8Op, pto::PintlvB16Op, pto::PintlvB32Op, pto::VsunpackOp, pto::VzunpackOp, pto::VpackOp, pto::VintlvOp, pto::VdintlvOp, pto::VpreluOp, - pto::VaxpyOp, pto::VciOp, pto::VexpdiffOp, + pto::VaxpyOp, pto::VciOp, pto::VexpdifOp, pto::VbitsortOp, pto::VtrcOp, pto::VcvtOp, pto::VbitcastOp, pto::VcmpOp, pto::VcmpsOp, diff --git a/test/vpto/cases/kernels/online-softmax-update/compare.py b/test/vpto/cases/kernels/online-softmax-update/compare.py index 31a8aa717..40eba2276 100644 --- a/test/vpto/cases/kernels/online-softmax-update/compare.py +++ b/test/vpto/cases/kernels/online-softmax-update/compare.py @@ -9,7 +9,7 @@ # case: kernels/online-softmax-update # family: kernels -# target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +# target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts # scenarios: online-softmax-update, 16x128-f32, oldmax-oldsum-qk-to-newmax-newsum-expmax-out import os diff --git a/test/vpto/cases/kernels/online-softmax-update/golden.py b/test/vpto/cases/kernels/online-softmax-update/golden.py index 497f7eed6..1dad1c0df 100644 --- a/test/vpto/cases/kernels/online-softmax-update/golden.py +++ b/test/vpto/cases/kernels/online-softmax-update/golden.py @@ -9,7 +9,7 @@ # case: kernels/online-softmax-update # family: kernels -# target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +# target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts # scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out import argparse diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel.pto b/test/vpto/cases/kernels/online-softmax-update/kernel.pto index a2dab4d2e..07bd45279 100644 --- a/test/vpto/cases/kernels/online-softmax-update/kernel.pto +++ b/test/vpto/cases/kernels/online-softmax-update/kernel.pto @@ -1,7 +1,7 @@ // ----------------------------------------------------------------------------- // case: kernels/online-softmax-update // family: kernels -// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts // scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out // ----------------------------------------------------------------------------- module attributes {pto.target_arch = "a5"} { @@ -104,9 +104,9 @@ module attributes {pto.target_arch = "a5"} { %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %scaled_running = pto.vexpdiff %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdif %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %chunk_exp = pto.vexpdiff %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdif %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> @@ -117,7 +117,7 @@ module attributes {pto.target_arch = "a5"} { scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> } - %raw_expmax = pto.vexpdiff %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask @@ -138,7 +138,7 @@ module attributes {pto.target_arch = "a5"} { %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 %chunk_base = arith.addi %row_qk, %chunk : index %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> - %exp = pto.vexpdiff %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } diff --git a/test/vpto/cases/kernels/online-softmax-update/launch.cpp b/test/vpto/cases/kernels/online-softmax-update/launch.cpp index e50841764..d06402730 100644 --- a/test/vpto/cases/kernels/online-softmax-update/launch.cpp +++ b/test/vpto/cases/kernels/online-softmax-update/launch.cpp @@ -9,7 +9,7 @@ // ----------------------------------------------------------------------------- // case: kernels/online-softmax-update // family: kernels -// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts // scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out // ----------------------------------------------------------------------------- #ifndef __VEC_SCOPE__ diff --git a/test/vpto/cases/kernels/online-softmax-update/main.cpp b/test/vpto/cases/kernels/online-softmax-update/main.cpp index af6cbb63b..7c6f120e2 100644 --- a/test/vpto/cases/kernels/online-softmax-update/main.cpp +++ b/test/vpto/cases/kernels/online-softmax-update/main.cpp @@ -9,7 +9,7 @@ // ----------------------------------------------------------------------------- // case: kernels/online-softmax-update // family: kernels -// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts // scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out // ----------------------------------------------------------------------------- #include "test_common.h" diff --git a/test/vpto/cases/kernels/online-softmax-update/stub.cpp b/test/vpto/cases/kernels/online-softmax-update/stub.cpp index 389a74d5f..4764fd305 100644 --- a/test/vpto/cases/kernels/online-softmax-update/stub.cpp +++ b/test/vpto/cases/kernels/online-softmax-update/stub.cpp @@ -9,7 +9,7 @@ // ----------------------------------------------------------------------------- // case: kernels/online-softmax-update // family: kernels -// target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts // scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out // ----------------------------------------------------------------------------- #include diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py index 4b67584ba..5353e5df9 100755 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py @@ -7,9 +7,9 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -# case: micro-op/dsa-sfu/vexpdiff-boundary +# case: micro-op/dsa-sfu/vexpdif-boundary # family: dsa-sfu -# target_ops: pto.vexpdiff +# target_ops: pto.vexpdif # scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow # NOTE: bulk-generated coverage skeleton. diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py index b82ce1002..b4b417320 100755 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py @@ -7,9 +7,9 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -# case: micro-op/dsa-sfu/vexpdiff-boundary +# case: micro-op/dsa-sfu/vexpdif-boundary # family: dsa-sfu -# target_ops: pto.vexpdiff +# target_ops: pto.vexpdif # scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow # NOTE: bulk-generated coverage skeleton. # coding=utf-8 diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto index 30c802bc0..da20f83ab 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto @@ -1,13 +1,13 @@ // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-boundary +// case: micro-op/dsa-sfu/vexpdif-boundary // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow // NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is // still a valid test conclusion in the current coverage-first phase. // ----------------------------------------------------------------------------- module attributes {pto.target_arch = "a5"} { - func.func @vexpdiff_boundary_kernel_2d(%arg0: !pto.ptr, + func.func @vexpdif_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) { %c0 = arith.constant 0 : index @@ -40,7 +40,7 @@ module attributes {pto.target_arch = "a5"} { scf.for %offset = %c0 to %c1024 step %c64 { %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %max = pto.vlds %ub_max[%offset] : !pto.ptr -> !pto.vreg<64xf32> - %sum = pto.vexpdiff %vec, %max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %sum = pto.vexpdif %vec, %max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp index 1ff4ec8a7..e2f5057e6 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-boundary +// case: micro-op/dsa-sfu/vexpdif-boundary // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow // NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is // still a valid test conclusion in the current coverage-first phase. @@ -43,12 +43,12 @@ struct MrgSortExecutedNumList { #include "acl/acl.h" #endif -extern "C" __global__ [aicore] void vexpdiff_boundary_kernel_2d(__gm__ float *v1, +extern "C" __global__ [aicore] void vexpdif_boundary_kernel_2d(__gm__ float *v1, __gm__ float *v2, __gm__ float *v3); void LaunchVexpdiff_boundary_kernel_2d(float *v1, float *v2, float *v3, void *stream) { - vexpdiff_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + vexpdif_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2, (__gm__ float *)v3); } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp index 59f0d80f4..3f29604cb 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-boundary +// case: micro-op/dsa-sfu/vexpdif-boundary // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow // NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is // still a valid test conclusion in the current coverage-first phase. diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/stub.cpp index 278265849..97f735224 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/stub.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/stub.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-boundary +// case: micro-op/dsa-sfu/vexpdif-boundary // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow // NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is // still a valid test conclusion in the current coverage-first phase. @@ -23,7 +23,7 @@ #define __gm__ #endif -extern "C" __global__ [aicore] void vexpdiff_boundary_kernel_2d(__gm__ float *v1, +extern "C" __global__ [aicore] void vexpdif_boundary_kernel_2d(__gm__ float *v1, __gm__ float *v2, __gm__ float *v3) { (void)v1; diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py index c2ea3f6bd..8ca6af6cf 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py @@ -7,9 +7,9 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -# case: micro-op/dsa-sfu/vexpdiff-f16-part +# case: micro-op/dsa-sfu/vexpdif-f16-part # family: dsa-sfu -# target_ops: pto.vexpdiff +# target_ops: pto.vexpdif # scenarios: core-f16, fused-expdiff, part-even-odd import os diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py index 915730bc8..1d493c5bb 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py @@ -7,9 +7,9 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -# case: micro-op/dsa-sfu/vexpdiff-f16-part +# case: micro-op/dsa-sfu/vexpdif-f16-part # family: dsa-sfu -# target_ops: pto.vexpdiff +# target_ops: pto.vexpdif # scenarios: core-f16, fused-expdiff, part-even-odd import argparse diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto index 553552e3c..3c4e4e4fa 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto @@ -1,12 +1,12 @@ // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-f16-part +// case: micro-op/dsa-sfu/vexpdif-f16-part // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f16, fused-expdiff, part-even-odd // NOTE: validates that ODD/EVEN selects odd/even lanes from f16 inputs. // ----------------------------------------------------------------------------- module attributes {pto.target_arch = "a5"} { - func.func @vexpdiff_f16_part_kernel_2d(%arg0: !pto.ptr, + func.func @vexpdif_f16_part_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) { %c0 = arith.constant 0 : index @@ -43,8 +43,8 @@ module attributes {pto.target_arch = "a5"} { %max = pto.vlds %ub_max[%offset] : !pto.ptr -> !pto.vreg<128xf16> %even_mask, %remaining_after_even = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 %odd_mask, %next_remaining = pto.plt_b32 %remaining_after_even : i32 -> !pto.mask, i32 - %even = pto.vexpdiff %input, %max, "EVEN" : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<64xf32> - %odd = pto.vexpdiff %input, %max, "ODD" : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %even = pto.vexpdif %input, %max, "EVEN" : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %odd = pto.vexpdif %input, %max, "ODD" : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<64xf32> %odd_offset = arith.addi %offset, %c64 : index pto.vsts %even, %ub_out[%offset], %even_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask pto.vsts %odd, %ub_out[%odd_offset], %odd_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp index 8bee57183..78f8bef63 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-f16-part +// case: micro-op/dsa-sfu/vexpdif-f16-part // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f16, fused-expdiff, part-even-odd // ----------------------------------------------------------------------------- #ifndef __VEC_SCOPE__ @@ -41,13 +41,13 @@ struct MrgSortExecutedNumList { #include "acl/acl.h" #endif -extern "C" __global__ [aicore] void vexpdiff_f16_part_kernel_2d(__gm__ half *v1, +extern "C" __global__ [aicore] void vexpdif_f16_part_kernel_2d(__gm__ half *v1, __gm__ half *v2, __gm__ float *v3); void LaunchVexpdiff_f16_part_kernel_2d(uint16_t *v1, uint16_t *v2, float *v3, void *stream) { - vexpdiff_f16_part_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + vexpdif_f16_part_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, (__gm__ half *)v2, (__gm__ float *)v3); } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp index 7137c02a0..58b1f6c5d 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-f16-part +// case: micro-op/dsa-sfu/vexpdif-f16-part // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f16, fused-expdiff, part-even-odd // ----------------------------------------------------------------------------- /** diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/stub.cpp index 65011d6bb..f978d9722 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/stub.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/stub.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-f16-part +// case: micro-op/dsa-sfu/vexpdif-f16-part // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f16, fused-expdiff, part-even-odd // ----------------------------------------------------------------------------- @@ -21,7 +21,7 @@ #define __gm__ #endif -extern "C" __global__ [aicore] void vexpdiff_f16_part_kernel_2d(__gm__ half *v1, +extern "C" __global__ [aicore] void vexpdif_f16_part_kernel_2d(__gm__ half *v1, __gm__ half *v2, __gm__ float *v3) { (void)v1; diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py index fdd9df368..8575e7aa5 100755 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py @@ -7,9 +7,9 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -# case: micro-op/dsa-sfu/vexpdiff-f32 +# case: micro-op/dsa-sfu/vexpdif-f32 # family: dsa-sfu -# target_ops: pto.vexpdiff +# target_ops: pto.vexpdif # scenarios: core-f32, fused-expdiff # NOTE: bulk-generated coverage skeleton. diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py index 9c6cf1776..874d5b6e5 100755 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py @@ -7,9 +7,9 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -# case: micro-op/dsa-sfu/vexpdiff-f32 +# case: micro-op/dsa-sfu/vexpdif-f32 # family: dsa-sfu -# target_ops: pto.vexpdiff +# target_ops: pto.vexpdif # scenarios: core-f32, fused-expdiff # NOTE: bulk-generated coverage skeleton. # coding=utf-8 diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto index 1e98a0c2e..e9c83b6d7 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto @@ -1,13 +1,13 @@ // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-f32 +// case: micro-op/dsa-sfu/vexpdif-f32 // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f32, fused-expdiff // NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is // still a valid test conclusion in the current coverage-first phase. // ----------------------------------------------------------------------------- module attributes {pto.target_arch = "a5"} { - func.func @vexpdiff_kernel_2d(%arg0: !pto.ptr, + func.func @vexpdif_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { %c0 = arith.constant 0 : index %c64 = arith.constant 64 : index @@ -34,7 +34,7 @@ module attributes {pto.target_arch = "a5"} { %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> - %sum = pto.vexpdiff %vec, %vec, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %sum = pto.vexpdif %vec, %vec, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp index cc7fb20e8..00ada867d 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-f32 +// case: micro-op/dsa-sfu/vexpdif-f32 // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f32, fused-expdiff // NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is // still a valid test conclusion in the current coverage-first phase. @@ -43,10 +43,10 @@ struct MrgSortExecutedNumList { #include "acl/acl.h" #endif -extern "C" __global__ [aicore] void vexpdiff_kernel_2d(__gm__ float *v1, +extern "C" __global__ [aicore] void vexpdif_kernel_2d(__gm__ float *v1, __gm__ float *v2); void LaunchVexpdiff_kernel_2d(float *v1, float *v2, void *stream) { - vexpdiff_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + vexpdif_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp index ccf695380..4afacde3a 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-f32 +// case: micro-op/dsa-sfu/vexpdif-f32 // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f32, fused-expdiff // NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is // still a valid test conclusion in the current coverage-first phase. diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/stub.cpp index 30bf48863..d0818a67b 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/stub.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/stub.cpp @@ -7,9 +7,9 @@ // See LICENSE in the root of the software repository for the full text of the License. // ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vexpdiff-f32 +// case: micro-op/dsa-sfu/vexpdif-f32 // family: dsa-sfu -// target_ops: pto.vexpdiff +// target_ops: pto.vexpdif // scenarios: core-f32, fused-expdiff // NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is // still a valid test conclusion in the current coverage-first phase. @@ -23,7 +23,7 @@ #define __gm__ #endif -extern "C" __global__ [aicore] void vexpdiff_kernel_2d(__gm__ float *v1, +extern "C" __global__ [aicore] void vexpdif_kernel_2d(__gm__ float *v1, __gm__ float *v2) { (void)v1; (void)v2; diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 09e940314..39882950e 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -343,23 +343,27 @@ neg_vec = pto.vneg(vec_f32, mask32) **Constraints**: - Operates on integer vector types only -#### `pto.vexpdiff(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, part: pto.VcvtPartMode) -> VRegType` -**Description**: Exponential difference of vector elements. +**Description**: Fused exponential difference `exp(vec - max_vec)` for numerically stable softmax lowering. **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| | `vec` | `VRegType` | Input vector | -| `mask` | `MaskType` | Predicate mask | +| `max_vec` | `VRegType` | Per-lane max vector subtracted before exponentiation | +| `part` | `pto.VcvtPartMode` | Output part selector enum. Use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD`. | **Returns**: | Return Value | Type | Description | |--------------|------|-------------| -| `result` | `VRegType` | Exponential difference values | +| `result` | `VRegType` | Exponential difference values; result element type is `f32` | **Constraints**: -- For floating-point vector types only +- Supports `f16` and `f32` input vectors only +- `vec` and `max_vec` must use the same vector type +- `part` should use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD` +- Canonical strings `"EVEN"` / `"ODD"` are still accepted for compatibility ### Binary Vector Operations diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md index 8db80caef..620bb407b 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md @@ -894,7 +894,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | | 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | | 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | -| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdif`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | | 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | | 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | @@ -4888,9 +4888,9 @@ for (int i = 0; i < N; i++) --- -##### `pto.vexpdiff` +##### `pto.vexpdif` -- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vexpdif %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` - **A5 types:** input `f16` or `f32`, output `f32` - **semantics:** Fused exp(x - max) for numerically stable softmax. @@ -5049,7 +5049,7 @@ for (int i = 0; i < N; i++) ```mlir // Softmax with fused expdiff %max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> -%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // Leaky ReLU activation %activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> @@ -5285,7 +5285,7 @@ pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, ! %max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> // 2. exp(x - max) using fused op -%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp = pto.vexpdif %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // 3. Sum %sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md index 3c1e31419..6c06c4c07 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md @@ -4490,9 +4490,9 @@ for (int i = 0; i < N; i++) --- -##### `pto.vexpdiff` +##### `pto.vexpdif` -- **syntax:** `%result = pto.vexpdiff %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vexpdif %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` - **A5 types:** f16, f32 - **semantics:** Fused exp(x - max) for numerically stable softmax. @@ -4736,7 +4736,7 @@ for (int i = 0; i < N; i++) ```mlir // Softmax with fused expdiff %max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> -%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // Leaky ReLU activation %activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> @@ -4975,7 +4975,7 @@ pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, ! %max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> // 2. exp(x - max) using fused op -%exp = pto.vexpdiff %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp = pto.vexpdif %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // 3. Sum %sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md index 8de281795..c2f10ab6d 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md @@ -894,7 +894,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | | 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | | 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | -| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdif`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | | 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | | 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | @@ -4855,9 +4855,9 @@ for (int i = 0; i < N; i++) --- -##### `pto.vexpdiff` +##### `pto.vexpdif` -- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vexpdif %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` - **A5 types:** input `f16` or `f32`, output `f32` - **semantics:** Fused exp(x - max) for numerically stable softmax. @@ -5016,7 +5016,7 @@ for (int i = 0; i < N; i++) ```mlir // Softmax with fused expdiff %max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> -%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // Leaky ReLU activation %activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> @@ -5252,7 +5252,7 @@ pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, ! %max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> // 2. exp(x - max) using fused op -%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp = pto.vexpdif %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> // 3. Sum %sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 6f977d522..744008bda 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -3038,7 +3038,6 @@ def _lower_call_expr( "vzunpack", "vusqz", "vsqz", - "vexpdiff", "vcgadd", "vcgmax", "vcgmin", @@ -3054,6 +3053,17 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vexpdif": + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vexpdif {lhs.name}, {rhs.name}, {part} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name in { "vadd", "vsub", diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 96f2937ac..fbaf78e06 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -228,7 +228,6 @@ def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: "vzunpack", "vusqz", "vsqz", - "vexpdiff", "vtrc", "vcgadd", "vcgmax", @@ -275,6 +274,7 @@ def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: _TERNARY_VECTOR_OPS = {"vaxpy", "vmula"} _MULTI_RESULT_VECTOR_OPS = {"vmull", "vldsx2", "vldus", "pstu"} _BROADCAST_VECTOR_OPS = {"vbr", "vdup", "vci"} +_VEXPDIF_OP_ALIASES = {"vexpdif", "vexpdiff"} _LOW_LEVEL_DMA_UNARY_CONFIG_OPS = {"set_mov_pad_val"} _LOW_LEVEL_DMA_CONFIG_OPS = { "set_loop2_stride_outtoub", @@ -1188,6 +1188,7 @@ def _should_infer_vecscope( | _MULTI_RESULT_VECTOR_OPS | _BROADCAST_VECTOR_OPS | _ADVANCED_VECTOR_ACTIVITY_OPS + | _VEXPDIF_OP_ALIASES ) def _block_can_live_in_inferred_vecscope( @@ -1289,6 +1290,7 @@ def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> boo | _MULTI_RESULT_VECTOR_OPS | _BROADCAST_VECTOR_OPS | _ADVANCED_VECTOR_ACTIVITY_OPS + | _VEXPDIF_OP_ALIASES ) ) @@ -1412,6 +1414,7 @@ def _expr_contains_vector_activity(self, expr: SemanticExpr) -> bool: | _MULTI_RESULT_VECTOR_OPS | _BROADCAST_VECTOR_OPS | _ADVANCED_VECTOR_ACTIVITY_OPS + | _VEXPDIF_OP_ALIASES ): return True return any(self._expr_contains_vector_activity(arg) for arg in expr.args) @@ -4033,6 +4036,8 @@ def _analyze_call_expr( return self._analyze_broadcast_vector_op(name, args) if name in _MULTI_RESULT_VECTOR_OPS: return self._analyze_multi_result_vector_op(name, args) + if name in _VEXPDIF_OP_ALIASES: + return self._analyze_vexpdif_op(args) if name in _UNARY_VECTOR_OPS: return self._analyze_unary_vector_op(name, args) if name in _BINARY_VECTOR_OPS: @@ -4742,6 +4747,26 @@ def _analyze_unary_vector_op( result_type = self._vcadd_result_vreg_type(vreg) return SemanticCallExpr(namespace="pto", name=name, args=args, type=result_type) + def _analyze_vexpdif_op( + self, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError("pto.vexpdif expects exactly 3 positional arguments in TileLang DSL v1") + input_expr, max_expr, part_expr = args + input_type = self._require_vreg_expr(input_expr, "pto.vexpdif input") + max_type = self._require_vreg_expr(max_expr, "pto.vexpdif max") + if input_type != max_type: + raise TypeError("pto.vexpdif requires input/max vector types to match") + self._validate_vexpdif_dtype(input_type.element_dtype) + part = self._normalize_vexpdif_part(part_expr, "pto.vexpdif part") + return SemanticCallExpr( + namespace="pto", + name="vexpdif", + args=(input_expr, max_expr, part), + type=self._vexpdif_result_vreg_type(input_type), + ) + def _analyze_binary_vector_op( self, name: str, @@ -5366,6 +5391,11 @@ def _vcadd_result_vreg_type(self, vreg_type: SemanticVRegType) -> SemanticVRegTy return self._vreg_type_for_dtype(widened_dtype) return vreg_type + def _vexpdif_result_vreg_type(self, vreg_type: SemanticVRegType) -> SemanticVRegType: + if vreg_type.element_dtype.name == "f32": + return vreg_type + return SemanticVRegType(element_dtype=f32, lanes=vreg_type.lanes // 2) + def _normalize_position_mode( self, expr: SemanticExpr | None, @@ -5582,6 +5612,31 @@ def _require_string_expr(self, expr: SemanticExpr, context: str) -> str: return expr.binding.value raise TypeError(f"{context} must be a string literal in TileLang DSL") + def _normalize_vexpdif_part(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_part_mode" + and isinstance(expr.value, VcvtPartMode) + ): + part = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_part_mode" + and isinstance(expr.binding.value, VcvtPartMode) + ): + part = expr.binding.value.value + else: + part = self._require_string_expr(expr, context) + if part not in {VcvtPartMode.EVEN.value, VcvtPartMode.ODD.value}: + raise TypeError( + "pto.vexpdif part must be `pto.VcvtPartMode.EVEN` or " + "`pto.VcvtPartMode.ODD`, or one of the canonical strings " + '`"EVEN"` / `"ODD"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=part, type=SemanticMetaType(kind="string")) + def _normalize_cmp_mode(self, expr: SemanticExpr, context: str) -> SemanticExpr: if ( isinstance(expr, SemanticSymbolExpr) @@ -5953,7 +6008,7 @@ def _vreg_type_for_dtype(self, dtype: ScalarType) -> SemanticVRegType: return SemanticVRegType(element_dtype=dtype, lanes=256 // width) def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: - if name in {"vexp", "vln", "vsqrt", "vrec", "vrsqrt", "vexpdiff"} and dtype.name not in {"f16", "f32"}: + if name in {"vexp", "vln", "vsqrt", "vrec", "vrsqrt"} and dtype.name not in {"f16", "f32"}: raise TypeError(f"pto.{name} only supports f16/f32 in TileLang DSL v1") if name == "vrelu" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vrelu only supports f16/f32 in TileLang DSL v1") @@ -6003,6 +6058,10 @@ def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: ): raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + def _validate_vexpdif_dtype(self, dtype: ScalarType) -> None: + if dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vexpdif only supports f16/f32 in TileLang DSL v1") + def _validate_vector_scalar_dtype(self, name: str, dtype: ScalarType) -> None: if name == "vdivs" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vdivs only supports f16/f32 in TileLang DSL v1") diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index 33602fd07..d92bb661c 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -104,6 +104,7 @@ "vzunpack", "vusqz", "vsqz", + "vexpdif", "vexpdiff", "vtrc", "vbr", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index d0cae63cc..38924e83c 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2965,7 +2965,7 @@ def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): out = pto.vsqrt(out, all_mask) out = pto.vrec(out, all_mask) out = pto.vrsqrt(out, all_mask) - out = pto.vexpdiff(out, all_mask) + out = pto.vexpdif(out, vec1, pto.VcvtPartMode.ODD) out = pto.vcadd(out, all_mask) out = pto.vcmax(out, all_mask) out = pto.vcmin(out, all_mask) @@ -2987,7 +2987,7 @@ def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): self.assertIn("pto.vsqrt", text) self.assertIn("pto.vrec", text) self.assertIn("pto.vrsqrt", text) - self.assertIn("pto.vexpdiff", text) + self.assertIn("pto.vexpdif", text) self.assertIn("pto.vcadd", text) self.assertIn("pto.vcmax", text) self.assertIn("pto.vcmin", text) @@ -2997,6 +2997,32 @@ def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): self.assertIn("pto.vlrelu", text) self.assertIn("pto.vcvt", text) + def test_vexpdif_f16_surface_lowers_to_f32_half_lanes(self) -> None: + @pto.vkernel( + op="vexpdif_f16_surface_unique", + dtypes=[(pto.f32, pto.f16, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, max_src: pto.Tile): + vec = pto.vlds(src, 0) + max_vec = pto.vlds(max_src, 0) + out = pto.vexpdif(vec, max_vec, pto.VcvtPartMode.ODD) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + max_src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex( + text, + r'pto\.vexpdif %\w+_\d+, %\w+_\d+, "ODD" : !pto\.vreg<128xf16>, !pto\.vreg<128xf16> -> !pto\.vreg<64xf32>', + ) + def test_vcvt_supports_keyword_attrs_with_enums(self) -> None: @pto.vkernel( op="vcvt_keyword_attrs_unique", From 4b741bb8945eb924f50a79af9f3ffa0bcfd14c72 Mon Sep 17 00:00:00 2001 From: FangRui Date: Wed, 22 Apr 2026 10:07:42 +0800 Subject: [PATCH 134/192] fix: avoid false A5 tprelu vec overflow in memory planning (cherry picked from commit f6ba2c41f856cfb293aee8b5e3b92dffc552602b) --- .../basic/issue531_tprelu_vec_overflow_a5.pto | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 test/basic/issue531_tprelu_vec_overflow_a5.pto diff --git a/test/basic/issue531_tprelu_vec_overflow_a5.pto b/test/basic/issue531_tprelu_vec_overflow_a5.pto new file mode 100644 index 000000000..b4cee9238 --- /dev/null +++ b/test/basic/issue531_tprelu_vec_overflow_a5.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +// Regression for issue #531: +// A5 TPRELU should not treat tmp as a required scratch write in local-memory +// planning, otherwise this shape triggers a false vec overflow. +// CHECK-NOT: vec overflow +// CHECK: AICORE void TPRELU_OVERFLOW() + +module { + func.func @TPRELU_OVERFLOW() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tprelu ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} From f2ecaca2bd43c06e0a77cde8ba2366fa5c4bf4e6 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 22 Apr 2026 23:41:05 +0800 Subject: [PATCH 135/192] fix(dsl): unify type conversion logics in frontend --- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 4 +- tilelang-dsl/python/tilelang_dsl/lowering.py | 201 ++++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 676 ++++++++++++++++++- 3 files changed, 828 insertions(+), 53 deletions(-) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 7ccc2230e..7f4588755 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -2528,7 +2528,9 @@ class LowerVbrOpPattern final : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "failed to convert vbr result type"); Value scalar = adaptor.getValue(); - if (!scalar || scalar.getType() != op.getValue().getType()) + Type expectedScalarType = + this->getTypeConverter()->convertType(op.getValue().getType()); + if (!scalar || !expectedScalarType || scalar.getType() != expectedScalarType) return rewriter.notifyMatchFailure(op, "unexpected converted vbr operand type"); diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 744008bda..c0dca2161 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -3278,15 +3278,10 @@ def _lower_to_i32( into: list[str], ) -> _RenderedValue: value = self._lower_expr(expr, env, indent=indent, into=into) - if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i32": - return value - if isinstance(value.type, SemanticIndexType): - cast_name = self._new_temp() - into.append( - self._indent(indent) - + f"{cast_name} = arith.index_cast {value.name} : index to i32" - ) - return _RenderedValue(name=cast_name, type=_I32_TYPE) + if isinstance(value.type, SemanticIndexType) or ( + isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype) + ): + return self._coerce_rendered_value(value, _I32_TYPE, indent=indent, into=into) raise NotImplementedError("expected an i32 or index operand during TileLang DSL v1 lowering") def _lower_to_index( @@ -3298,11 +3293,21 @@ def _lower_to_index( into: list[str], ) -> _RenderedValue: value = self._lower_expr(expr, env, indent=indent, into=into) + return self._coerce_rendered_to_index(value, indent=indent, into=into) + + def _coerce_rendered_to_index( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: if isinstance(value.type, SemanticIndexType): return value if isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype): bits = integer_bitwidth(value.type.dtype) if bits in {32, 64}: + value = self._bridge_rendered_to_signless_integer(value, indent=indent, into=into) cast_name = self._new_temp() into.append( self._indent(indent) @@ -3318,15 +3323,10 @@ def _coerce_rendered_to_i64( indent: int, into: list[str], ) -> _RenderedValue: - if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i64": - return value - if isinstance(value.type, SemanticIndexType): - cast_name = self._new_temp() - into.append( - self._indent(indent) - + f"{cast_name} = arith.index_castui {value.name} : index to i64" - ) - return _RenderedValue(name=cast_name, type=_I64_TYPE) + if isinstance(value.type, SemanticIndexType) or ( + isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype) + ): + return self._coerce_rendered_value(value, _I64_TYPE, indent=indent, into=into) raise NotImplementedError("expected an i64 or index operand during TileLang DSL v1 lowering") def _lower_remaining_to_i32( @@ -3338,17 +3338,50 @@ def _lower_remaining_to_i32( into: list[str], ) -> _RenderedValue: value = self._lower_expr(expr, env, indent=indent, into=into) - if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i32": - return value - if isinstance(value.type, SemanticIndexType): - cast_name = self._new_temp() - into.append( - self._indent(indent) - + f"{cast_name} = arith.index_cast {value.name} : index to i32" - ) - return _RenderedValue(name=cast_name, type=_I32_TYPE) + if isinstance(value.type, SemanticIndexType) or ( + isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype) + ): + return self._coerce_rendered_value(value, _I32_TYPE, indent=indent, into=into) raise NotImplementedError("tail make_mask lowering expects an i32 or index remaining operand") + def _bridge_rendered_to_signless_integer( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if not isinstance(value.type, SemanticScalarType) or not is_integer_dtype(value.type.dtype): + return value + raw_type = self._signless_integer_scalar_type(value.type) + if raw_type is None: + return value + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = builtin.unrealized_conversion_cast {value.name} : " + f"{self._render_type(value.type)} to {self._render_type(raw_type)}" + ) + return _RenderedValue(name=cast_name, type=raw_type) + + def _bridge_rendered_integer_to_target( + self, + value: _RenderedValue, + target_type: SemanticScalarType, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if value.type == target_type: + return value + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = builtin.unrealized_conversion_cast {value.name} : " + f"{self._render_type(value.type)} to {self._render_type(target_type)}" + ) + return _RenderedValue(name=cast_name, type=target_type) + def _materialize_copy_buffer_ptr( self, value: _RenderedValue, @@ -3393,27 +3426,49 @@ def _scalar_int_sign(dtype: ScalarType) -> str: sign = integer_signedness(dtype) return "signless" if sign is None else sign + def _signless_int_type_for_bits(bits: int) -> SemanticScalarType: + if bits not in {8, 16, 32, 64}: + raise NotImplementedError( + f"unsupported integer bitwidth {bits!r} for signless coercion in TileLang DSL v1 lowering" + ) + return SemanticScalarType(dtype=ScalarType(f"i{bits}")) + if type(value.type) is type(target_type) and value.type == target_type: return value if isinstance(value.type, SemanticIndexType) and isinstance(target_type, SemanticScalarType): target_int_bits = _scalar_int_bits(target_type.dtype) target_sign = _scalar_int_sign(target_type.dtype) - if target_int_bits == 32: - op = "arith.index_castui" if target_sign == "unsigned" else "arith.index_cast" - cast_name = self._new_temp() - into.append( - self._indent(indent) - + f"{cast_name} = {op} {value.name} : index to {target_type.dtype.name}" + signless_target_type = self._signless_integer_scalar_type(target_type) + if signless_target_type is None and target_int_bits in {8, 16, 32, 64}: + signless_target_type = target_type + if target_int_bits in {8, 16, 32, 64} and signless_target_type is not None: + carrier_type = ( + _signless_int_type_for_bits(32) + if target_int_bits in {8, 16} + else _signless_int_type_for_bits(target_int_bits) ) - return _RenderedValue(name=cast_name, type=target_type) - if target_int_bits == 64: - op = "arith.index_castui" if target_sign in {"signless", "unsigned"} else "arith.index_cast" + op = "arith.index_castui" if target_sign == "unsigned" else "arith.index_cast" cast_name = self._new_temp() into.append( self._indent(indent) - + f"{cast_name} = {op} {value.name} : index to {target_type.dtype.name}" + + f"{cast_name} = {op} {value.name} : index to {carrier_type.dtype.name}" ) - return _RenderedValue(name=cast_name, type=target_type) + lowered_value = _RenderedValue(name=cast_name, type=carrier_type) + if target_int_bits in {8, 16}: + lowered_value = self._coerce_rendered_value( + lowered_value, + signless_target_type, + indent=indent, + into=into, + ) + if signless_target_type != target_type: + return self._coerce_rendered_value( + lowered_value, + target_type, + indent=indent, + into=into, + ) + return lowered_value if target_type.dtype.name in {"f16", "bf16", "f32"}: index_to_int_name = self._new_temp() index_to_int_op = "arith.index_castui" @@ -3432,35 +3487,73 @@ def _scalar_int_sign(dtype: ScalarType) -> str: dst = target_type.dtype.name if src == dst: return value - cast_name = self._new_temp() src_bits = _scalar_int_bits(value.type.dtype) dst_bits = _scalar_int_bits(target_type.dtype) if src_bits is not None and dst_bits is not None: - if src_bits == dst_bits: - op = "arith.bitcast" - elif src_bits < dst_bits: + src_sign = _scalar_int_sign(value.type.dtype) + signless_value = self._bridge_rendered_to_signless_integer(value, indent=indent, into=into) + signless_target_type = self._signless_integer_scalar_type(target_type) or target_type + if signless_value.type == signless_target_type: + if signless_target_type != target_type: + return self._bridge_rendered_integer_to_target( + signless_value, + target_type, + indent=indent, + into=into, + ) + return signless_value + cast_name = self._new_temp() + if src_bits < dst_bits: op = "arith.extui" if _scalar_int_sign(value.type.dtype) == "unsigned" else "arith.extsi" - else: + elif src_bits > dst_bits: op = "arith.trunci" + else: + raise NotImplementedError( + f"unsupported same-width integer coercion from {value.type!r} to {target_type!r} " + "in TileLang DSL v1 lowering" + ) into.append( self._indent(indent) - + f"{cast_name} = {op} {value.name} : {src} to {dst}" + + f"{cast_name} = {op} {signless_value.name} : " + f"{self._render_type(signless_value.type)} to {self._render_type(signless_target_type)}" ) - return _RenderedValue(name=cast_name, type=target_type) + lowered_value = _RenderedValue(name=cast_name, type=signless_target_type) + if signless_target_type != target_type: + return self._bridge_rendered_integer_to_target( + lowered_value, + target_type, + indent=indent, + into=into, + ) + return lowered_value if src_bits is not None and dst in {"f16", "bf16", "f32"}: + signless_value = self._bridge_rendered_to_signless_integer(value, indent=indent, into=into) + cast_name = self._new_temp() op = "arith.uitofp" if _scalar_int_sign(value.type.dtype) == "unsigned" else "arith.sitofp" into.append( self._indent(indent) - + f"{cast_name} = {op} {value.name} : {src} to {dst}" + + f"{cast_name} = {op} {signless_value.name} : " + f"{self._render_type(signless_value.type)} to {dst}" ) return _RenderedValue(name=cast_name, type=target_type) if src in {"f16", "bf16", "f32"} and dst_bits is not None: + signless_target_type = self._signless_integer_scalar_type(target_type) or target_type + cast_name = self._new_temp() op = "arith.fptoui" if _scalar_int_sign(target_type.dtype) == "unsigned" else "arith.fptosi" into.append( self._indent(indent) - + f"{cast_name} = {op} {value.name} : {src} to {dst}" + + f"{cast_name} = {op} {value.name} : {src} to {self._render_type(signless_target_type)}" ) - return _RenderedValue(name=cast_name, type=target_type) + lowered_value = _RenderedValue(name=cast_name, type=signless_target_type) + if signless_target_type != target_type: + return self._bridge_rendered_integer_to_target( + lowered_value, + target_type, + indent=indent, + into=into, + ) + return lowered_value + cast_name = self._new_temp() if src in {"f16", "bf16", "f32"} and dst in {"f16", "bf16", "f32"}: op = "arith.extf" if src in {"f16", "bf16"} and dst == "f32" else "arith.truncf" into.append( @@ -3708,6 +3801,18 @@ def _materialize_constant(self, value: object, ty: SemanticType) -> str: if cache_key in self._constant_cache: return self._constant_cache[cache_key] + raw_type = self._signless_integer_scalar_type(ty) + if raw_type is not None: + raw_name = self._materialize_constant(value, raw_type) + name = self._constant_name(value, ty) + self._constant_cache[cache_key] = name + self._constant_lines.append( + self._indent(4) + + f"{name} = builtin.unrealized_conversion_cast {raw_name} : " + f"{self._render_type(raw_type)} to {self._render_type(ty)}" + ) + return name + name = self._constant_name(value, ty) self._constant_cache[cache_key] = name self._constant_lines.append( diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 38924e83c..65daee63f 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2450,6 +2450,79 @@ def kernel(tile: pto.Tile): self.assertIn("arith.constant 4294967295 : i32", text) self.assertNotIn("arith.constant 4294967295 : ui32", text) + def test_cached_unsigned_integer_constructor_constant_preserves_typed_bridge(self) -> None: + @pto.vkernel( + op="cached_ui16_constructor_constant_bridge_unique", + dtypes=[(pto.ui16, pto.ui16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + biased = pto.vadds(vec, pto.ui16(1), all_mask) + out = pto.vadds(biased, pto.ui16(1), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertEqual(text.count("arith.constant 1 : i16"), 1) + self.assertEqual(text.count(": i16 to ui16"), 1) + self.assertNotIn("arith.constant 1 : ui16", text) + + def test_narrow_typed_integer_zero_constructors_lower_with_signless_bridge(self) -> None: + @pto.vkernel(op="si16_zero_constructor_bridge_unique", dtypes=[(pto.si16, pto.si16)], advanced=True) + def si16_kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.si16(0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + @pto.vkernel(op="ui16_zero_constructor_bridge_unique", dtypes=[(pto.ui16, pto.ui16)], advanced=True) + def ui16_kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.ui16(0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + @pto.vkernel(op="si8_zero_constructor_bridge_unique", dtypes=[(pto.si8, pto.si8)], advanced=True) + def si8_kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.si8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.si8(0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + @pto.vkernel(op="ui8_zero_constructor_bridge_unique", dtypes=[(pto.ui8, pto.ui8)], advanced=True) + def ui8_kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.ui8(0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + tile_specs = dict( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + for dtype_name, raw_type, kernel in ( + ("si16", "i16", si16_kernel), + ("ui16", "i16", ui16_kernel), + ("si8", "i8", si8_kernel), + ("ui8", "i8", ui8_kernel), + ): + with self.subTest(dtype=dtype_name): + text = kernel.specialize(**tile_specs).mlir_text() + self.assertEqual(text.count(f"arith.constant 0 : {raw_type}"), 1) + self.assertEqual(text.count(f": {raw_type} to {dtype_name}"), 1) + self.assertNotIn(f"arith.constant 0 : {dtype_name}", text) + def test_unsigned_pad_value_eval_broadcast_bitcasts_signless_literal(self) -> None: @pto.vkernel(op="tile_pad_value_ui16_vbr_unique", dtypes=[(pto.ui16,)], advanced=True) def kernel(tile: pto.Tile): @@ -2476,6 +2549,325 @@ def kernel(tile: pto.Tile): self.assertIn(": i16 to ui16", text) self.assertIn("pto.vbr", text) + def test_index_to_unsigned_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_ui32_constructor_unique", dtypes=[(pto.ui32,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui32(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui32, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_castui", text) + self.assertIn(": index to i32", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i32 to ui32", text) + self.assertNotIn(": index to ui32", text) + + def test_index_to_ui16_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_ui16_constructor_unique", dtypes=[(pto.ui16,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui16(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui16, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_castui", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i16 to ui16", text) + self.assertNotIn(": index to ui16", text) + + def test_index_to_ui8_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_ui8_constructor_unique", dtypes=[(pto.ui8,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui8(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui8, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_castui", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i8", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i8 to ui8", text) + self.assertNotIn(": index to ui8", text) + + def test_index_to_si8_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_si8_constructor_unique", dtypes=[(pto.si8,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.si8(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.si8, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_cast", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i8", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i8 to si8", text) + self.assertNotIn(": index to si8", text) + + def test_index_to_i16_scalar_constructor_lowers_via_index_cast_then_trunci(self) -> None: + @pto.vkernel(op="index_to_i16_constructor_unique", dtypes=[(pto.i16,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.i16(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.i16, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_cast", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertNotIn("builtin.unrealized_conversion_cast", text) + self.assertNotIn(": index to i16", text) + + def test_index_to_si16_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_si16_constructor_unique", dtypes=[(pto.si16,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.si16(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.si16, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_cast", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i16 to si16", text) + self.assertNotIn(": index to si16", text) + + def test_index_to_32bit_integer_scalar_constructors_bridge_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_i32_constructor_unique", dtypes=[(pto.i32,)], advanced=True) + def i32_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.i32(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.i32, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + @pto.vkernel(op="index_to_si32_constructor_unique", dtypes=[(pto.si32,)], advanced=True) + def si32_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.si32(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.si32, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + @pto.vkernel(op="index_to_ui32_constructor_bridge_unique", dtypes=[(pto.ui32,)], advanced=True) + def ui32_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui32(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui32, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + tile_spec = pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + + for dtype_name, op_name, kernel in ( + ("i32", "arith.index_cast", i32_kernel), + ("si32", "arith.index_cast", si32_kernel), + ("ui32", "arith.index_castui", ui32_kernel), + ): + with self.subTest(dtype=dtype_name): + text = kernel.specialize(tile=tile_spec).mlir_text() + self.assertIn(op_name, text) + self.assertIn(": index to i32", text) + if dtype_name == "i32": + self.assertNotIn("builtin.unrealized_conversion_cast", text) + else: + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(f": i32 to {dtype_name}", text) + self.assertNotIn(f": index to {dtype_name}", text) + + def test_index_to_64bit_integer_scalar_constructors_bridge_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_i64_constructor_unique", dtypes=[(pto.i64,)], advanced=True) + def i64_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.i64(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.i64, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + @pto.vkernel(op="index_to_si64_constructor_unique", dtypes=[(pto.si64,)], advanced=True) + def si64_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.si64(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.si64, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + @pto.vkernel(op="index_to_ui64_constructor_unique", dtypes=[(pto.ui64,)], advanced=True) + def ui64_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui64(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui64, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + tile_spec = pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + + for dtype_name, op_name, kernel in ( + ("i64", "arith.index_cast", i64_kernel), + ("si64", "arith.index_cast", si64_kernel), + ("ui64", "arith.index_castui", ui64_kernel), + ): + with self.subTest(dtype=dtype_name): + text = kernel.specialize(tile=tile_spec).mlir_text() + self.assertIn(op_name, text) + self.assertIn(": index to i64", text) + if dtype_name == "i64": + self.assertNotIn("builtin.unrealized_conversion_cast", text) + else: + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(f": i64 to {dtype_name}", text) + self.assertNotIn(f": index to {dtype_name}", text) + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) @@ -3887,6 +4279,65 @@ def kernel(dst: pto.Tile, src: pto.Tile, scalar: pto.i32): self.assertIn("pto.vors", text) self.assertIn("pto.vxors", text) + def test_vci_typed_integer_inputs_lower_without_typed_arith(self) -> None: + @pto.vkernel( + op="vci_typed_integer_inputs_unique", + dtypes=[(pto.ui16, pto.si16, pto.i32)], + advanced=True, + ) + def kernel(dst_u: pto.Tile, dst_s: pto.Tile, seed: pto.i32): + unsigned_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + signed_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + + unsigned_idx = pto.vci(pto.ui16(0)) + signed_idx = pto.vci(pto.si16(seed), pto.OrderMode.ASC) + + pto.vsts(unsigned_idx, dst_u, 0, unsigned_mask) + pto.vsts(signed_idx, dst_s, 0, signed_mask) + return None + + specialized = kernel.specialize( + dst_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vci", text) + self.assertIn(": i16 to ui16", text) + self.assertIn(": i16 to si16", text) + self.assertNotIn("arith.constant 0 : ui16", text) + self.assertNotRegex(text, r"arith\.(extsi|extui|trunci|bitcast) %\w+ : .* to (ui16|si16)") + + def test_vector_scalar_bitwise_typed_scalar_inputs_lower_without_typed_arith(self) -> None: + @pto.vkernel( + op="vector_scalar_bitwise_typed_scalar_inputs_unique", + dtypes=[(pto.ui16, pto.ui16, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): + mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + scalar = pto.ui16(seed) + out = pto.vands(vec, scalar, mask) + out = pto.vors(out, scalar, mask) + out = pto.vxors(out, scalar, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vands", text) + self.assertIn("pto.vors", text) + self.assertIn("pto.vxors", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn(": i16 to ui16", text) + self.assertNotRegex(text, r"arith\.trunci %\w+ : i32 to ui16") + def test_broadcast_and_index_vector_ops_surface_lowers(self) -> None: @pto.vkernel( op="broadcast_and_index_vector_ops_unique", @@ -3959,6 +4410,45 @@ def kernel(dst: pto.Tile, seed: pto.i32): self.assertIn("pto.vdup scalar input does not accept `position`", str(ctx.exception)) + def test_vbr_and_vdup_accept_narrow_typed_scalar_constructors_with_explicit_bridges(self) -> None: + @pto.vkernel( + op="narrow_typed_vbr_vdup_scalar_constructors_unique", + dtypes=[(pto.si16, pto.ui16)], + advanced=True, + ) + def kernel(dst_s: pto.Tile, dst_u: pto.Tile): + signed_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + unsigned_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + + signed = pto.vadd( + pto.vbr(pto.si16(0)), + pto.vdup(pto.si16(0), signed_mask), + signed_mask, + ) + unsigned = pto.vadd( + pto.vbr(pto.ui16(0)), + pto.vdup(pto.ui16(0), unsigned_mask), + unsigned_mask, + ) + + pto.vsts(signed, dst_s, 0, signed_mask) + pto.vsts(unsigned, dst_u, 0, unsigned_mask) + return None + + specialized = kernel.specialize( + dst_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbr", text) + self.assertIn("pto.vdup", text) + self.assertIn("arith.constant 0 : i16", text) + self.assertIn(": i16 to si16", text) + self.assertIn(": i16 to ui16", text) + self.assertNotIn("arith.constant 0 : si16", text) + self.assertNotIn("arith.constant 0 : ui16", text) + def test_signed_and_unsigned_integer_dtypes_lower_distinctly(self) -> None: @pto.vkernel( op="signed_unsigned_integer_types_unique", @@ -3989,6 +4479,90 @@ def kernel(dst_s: pto.Tile, src_s: pto.Tile, dst_u: pto.Tile, src_u: pto.Tile): self.assertIn("!pto.vreg<128xsi16>", text) self.assertIn("!pto.vreg<128xui16>", text) + def test_vcmps_literal_scalar_uses_signless_integer_bridge(self) -> None: + @pto.vkernel( + op="vcmps_literal_scalar_bridge_unique", + dtypes=[(pto.si16, pto.si16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + cmp_mask = pto.vcmps(vec, pto.si16(-1), all_mask, pto.CmpMode.GT) + selected = pto.vsel(vec, vec, cmp_mask) + pto.vsts(selected, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcmps", text) + self.assertIn("arith.constant -1 : i16", text) + self.assertIn(": i16 to si16", text) + self.assertNotIn("arith.constant -1 : si16", text) + + def test_vadds_index_constructor_scalar_uses_signless_integer_bridge(self) -> None: + @pto.vkernel( + op="vadds_index_constructor_bridge_unique", + dtypes=[(pto.ui16, pto.ui16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + cols = dst.valid_shape[1] + vec = pto.vlds(src, 0) + mask, _ = pto.make_mask(pto.ui16, 1) + for col in range(0, cols, 1): + scalar = pto.ui16(col) + out = pto.vadds(vec, scalar, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + ), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vadds", text) + self.assertIn("arith.index_castui", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn(": i16 to ui16", text) + self.assertNotIn(": index to ui16", text) + + def test_vshrs_cast_result_scalar_uses_signless_integer_bridge(self) -> None: + @pto.vkernel( + op="vshrs_cast_result_scalar_bridge_unique", + dtypes=[(pto.i32, pto.i32, pto.ui16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, shift_seed: pto.ui16): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + shift = pto.i16(shift_seed) + out = pto.vshrs(vec, shift, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vshrs", text) + self.assertIn(": ui16 to i16", text) + self.assertNotRegex(text, r"arith\.bitcast %\w+ : ui16 to i16") + self.assertNotRegex(text, r"arith\.trunci %\w+ : ui16 to i16") + def test_vbr_accepts_float_literal_constant(self) -> None: @pto.vkernel( op="broadcast_float_literal_constant_unique", @@ -4069,6 +4643,95 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn("arith.extf", text) self.assertIn("arith.truncf", text) + def test_typed_integer_scalar_coercion_uses_signless_integer_carriers(self) -> None: + @pto.vkernel( + op="typed_integer_scalar_coercion_unique", + dtypes=[(pto.si16, pto.si16, pto.ui16, pto.ui16, pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel( + dst_s: pto.Tile, + src_s: pto.Tile, + dst_u: pto.Tile, + src_u: pto.Tile, + dst_i: pto.Tile, + src_i: pto.Tile, + seed: pto.i32, + ): + signed_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + unsigned_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + scalar_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + + signed_scalar = pto.si16(seed) + unsigned_scalar = pto.ui16(seed) + + signed_vec = pto.vlds(src_s, 0) + unsigned_vec = pto.vlds(src_u, 0) + scalar_vec = pto.vlds(src_i, 0) + + signed_out = pto.vadds(signed_vec, signed_scalar, signed_mask) + unsigned_out = pto.vadds(unsigned_vec, unsigned_scalar, unsigned_mask) + scalar_out = pto.vadds(scalar_vec, pto.i32(signed_scalar), scalar_mask) + scalar_out = pto.vadds(scalar_out, pto.i32(unsigned_scalar), scalar_mask) + + pto.vsts(signed_out, dst_s, 0, signed_mask) + pto.vsts(unsigned_out, dst_u, 0, unsigned_mask) + pto.vsts(scalar_out, dst_i, 0, scalar_mask) + return None + + specialized = kernel.specialize( + dst_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_i=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src_i=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn(": i16 to si16", text) + self.assertIn(": i16 to ui16", text) + self.assertIn(": si16 to i16", text) + self.assertIn(": ui16 to i16", text) + self.assertIn("arith.extsi", text) + self.assertIn("arith.extui", text) + self.assertNotRegex(text, r"arith\.trunci %\w+ : i32 to (si16|ui16)") + self.assertNotRegex(text, r"arith\.extsi %\w+ : si16 to i32") + self.assertNotRegex(text, r"arith\.extui %\w+ : ui16 to i32") + + def test_typed_integer_float_scalar_coercion_uses_signless_integer_carriers(self) -> None: + @pto.vkernel( + op="typed_integer_float_scalar_coercion_unique", + dtypes=[(pto.ui16, pto.ui16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + scalar = pto.ui16(1) + flt = pto.f32(scalar) + back = pto.ui16(flt) + out = pto.vadds(vec, back, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn(": ui16 to i16", text) + self.assertIn("arith.uitofp", text) + self.assertIn(": i16 to f32", text) + self.assertIn("arith.fptoui", text) + self.assertIn(": f32 to i16", text) + self.assertIn(": i16 to ui16", text) + self.assertNotRegex(text, r"arith\.uitofp %\w+ : ui16 to f32") + self.assertNotRegex(text, r"arith\.fptoui %\w+ : f32 to ui16") + def test_scalar_constructor_accepts_signed_float_literals(self) -> None: @pto.vkernel(op="scalar_constructor_signed_float_literals_unique", dtypes=[(pto.f32,)]) def kernel(inp: pto.TensorView): @@ -4945,7 +5608,7 @@ def kernel(dst: pto.Tile): ) text = specialized.mlir_text() - self.assertIn("arith.bitcast", text) + self.assertIn("builtin.unrealized_conversion_cast", text) self.assertRegex(text, r"pto\.set_mov_pad_val %[^ ]+ : i16") def test_copy_ubuf_to_gm_keyword_surface_lowers_in_advanced_mode(self) -> None: @@ -5806,17 +6469,19 @@ def kernel(mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB)): def test_predicate_load_store_alias_and_immediate_forms_lower_to_supported_ops(self) -> None: @pto.vkernel( op="predicate_load_store_alias_and_immediate_forms", - dtypes=[(pto.ui32, pto.ui32)], + dtypes=[(pto.ui32, pto.ui32, pto.ui32, pto.si32)], advanced=True, ) def kernel( mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB), mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB), + off_u: pto.ui32, + off_s: pto.si32, ): mask0 = pto.pld(mask_src, 0, pto.PredicateDist.NORM) - mask1 = pto.pldi(mask_src, pto.i32(8), pto.PredicateDist.US) + mask1 = pto.pldi(mask_src, pto.i32(off_u), pto.PredicateDist.US) pto.pst(mask0, mask_dst, 0) - pto.psti(mask1, mask_dst, pto.i32(8), pto.PredicateDist.PK) + pto.psti(mask1, mask_dst, pto.i32(off_s), pto.PredicateDist.PK) return None text = kernel.specialize().mlir_text() @@ -5824,7 +6489,10 @@ def kernel( self.assertIn("pto.pldi", text) self.assertIn("pto.psts", text) self.assertIn("pto.psti", text) + self.assertIn("builtin.unrealized_conversion_cast", text) self.assertIn("arith.index_cast", text) + self.assertNotRegex(text, r"arith\.extsi %\w+ : si32 to i32") + self.assertNotRegex(text, r"arith\.extui %\w+ : ui32 to i32") def test_predicate_reorder_families_lower_to_supported_ops(self) -> None: @pto.vkernel( From 304e18640bb5e87d199ab2f05aea3064eccee4da Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 21 Apr 2026 23:36:01 +0800 Subject: [PATCH 136/192] fix(ptoas): make vpto backend own tile op expansion (#162) --- lib/PTO/Transforms/ExpandTileOp.cpp | 3 +- test/dsl/expand_tile_op_tilelang_tadds.pto | 1 + tools/ptoas/ptoas.cpp | 51 +++++++++++----------- 3 files changed, 27 insertions(+), 28 deletions(-) diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index e9f4bf912..674892450 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -818,8 +818,7 @@ void ExpandTileOpPass::runOnOperation() { if (tilelangPath.empty()) { mod.emitError( - "ExpandTileOp requires a non-empty tilelang-path when " - "--enable-tile-op-expand is set"); + "ExpandTileOp requires a non-empty tilelang-path on the VPTO backend"); signalPassFailure(); return; } diff --git a/test/dsl/expand_tile_op_tilelang_tadds.pto b/test/dsl/expand_tile_op_tilelang_tadds.pto index e7360c376..55c206856 100644 --- a/test/dsl/expand_tile_op_tilelang_tadds.pto +++ b/test/dsl/expand_tile_op_tilelang_tadds.pto @@ -5,6 +5,7 @@ // Test ExpandTileOp expansion for pto.tadds in the VPTO pipeline. // +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s // RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --emit-vpto %s -o - | FileCheck %s // CHECK: func.func @TADDS() diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index de0f94ae5..53ad57c9d 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -200,7 +200,8 @@ static llvm::cl::opt enableInsertSync("enable-insert-sync", static llvm::cl::opt enableTileOpExpand( "enable-tile-op-expand", llvm::cl::desc( - "Enable Tile-to-Vector lowering path (memref->tile_buf recovery)"), + "Deprecated compatibility flag. TileOp expansion is controlled by " + "--pto-backend=vpto."), llvm::cl::init(false)); #ifndef PTOAS_DEFAULT_TILELANG_PATH @@ -1140,31 +1141,29 @@ static LogicalResult prepareVPTOForEmission(ModuleOp module) { static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { PassManager backendPM(module.getContext()); - if (enableTileOpExpand) { - // TileOp Expand path: - // 1. MemrefToTileBuf: recover tile_buf from memref - // 2. ExpandTileOp: instantiate TileLang DSL templates, replace tile ops - // with func.call to template functions (tile_buf params) - // 3. InlineLibCall: inline template function bodies - // 4. FoldTileBufIntrinsics: fold tile_buf_addr / tile_valid_rows / - // tile_valid_cols to concrete memref/constant values - backendPM.addPass(pto::createMemrefToTileBufPass()); - - pto::ExpandTileOpOptions expandOpts; - expandOpts.tilelangPath = tilelangPath; - expandOpts.tilelangPkgPath = tilelangPkgPath; - backendPM.addPass(pto::createExpandTileOpPass(expandOpts)); - - backendPM.addPass(pto::createPTOInlineLibCallPass()); - backendPM.addNestedPass( - pto::createFoldTileBufIntrinsicsPass()); - // FoldTileBufIntrinsics materializes many constant branch conditions. - // Clean them up immediately on the TileOp expansion path before the - // authoring-stage VPTO verifier and let the existing CSE passes remove the - // resulting dead values later in the pipeline. - backendPM.addPass(mlir::createSCCPPass()); - backendPM.addPass(mlir::createCanonicalizerPass()); - } + // TileOp Expand path: + // 1. MemrefToTileBuf: recover tile_buf from memref + // 2. ExpandTileOp: instantiate TileLang DSL templates, replace tile ops + // with func.call to template functions (tile_buf params) + // 3. InlineLibCall: inline template function bodies + // 4. FoldTileBufIntrinsics: fold tile_buf_addr / tile_valid_rows / + // tile_valid_cols to concrete memref/constant values + backendPM.addPass(pto::createMemrefToTileBufPass()); + + pto::ExpandTileOpOptions expandOpts; + expandOpts.tilelangPath = tilelangPath; + expandOpts.tilelangPkgPath = tilelangPkgPath; + backendPM.addPass(pto::createExpandTileOpPass(expandOpts)); + + backendPM.addPass(pto::createPTOInlineLibCallPass()); + backendPM.addNestedPass( + pto::createFoldTileBufIntrinsicsPass()); + // FoldTileBufIntrinsics materializes many constant branch conditions. + // Clean them up immediately on the TileOp expansion path before the + // authoring-stage VPTO verifier and let the existing CSE passes remove the + // resulting dead values later in the pipeline. + backendPM.addPass(mlir::createSCCPPass()); + backendPM.addPass(mlir::createCanonicalizerPass()); if (failed(applyConfiguredPassManagerCLOptions(backendPM, "VPTO backend lowering"))) return failure(); From d1f50330c3871fe576528188839eda08584d25ff Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Thu, 23 Apr 2026 16:31:53 +0800 Subject: [PATCH 137/192] feat: support pto.pbitcast op --- docs/isa/09-conversion-ops.md | 25 ++++ docs/vpto-spec.md | 4 +- include/PTO/IR/VPTOOps.td | 2 + lib/PTO/IR/VPTO.cpp | 7 ++ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 29 ++++- .../compare.py | 44 +++++++ .../golden.py | 54 ++++++++ .../kernel.pto | 77 ++++++++++++ .../launch.cpp | 51 ++++++++ .../vsel-f32-plds-us-pintlv-pbitcast/main.cpp | 116 ++++++++++++++++++ .../vsel-f32-plds-us-pintlv-pbitcast/stub.cpp | 25 ++++ .../docs/user_guide/05-type-system.md | 13 ++ .../user_guide/10-predicate-operations.md | 30 +++++ .../python/tilelang_dsl/frontend_ast.py | 1 + tilelang-dsl/python/tilelang_dsl/lowering.py | 9 ++ tilelang-dsl/python/tilelang_dsl/semantic.py | 47 +++++-- .../python/tilelang_dsl/support_matrix.py | 1 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 91 +++++++++++++- 18 files changed, 608 insertions(+), 18 deletions(-) create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/stub.cpp diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md index 47f9f50a0..5c9cd5515 100644 --- a/docs/isa/09-conversion-ops.md +++ b/docs/isa/09-conversion-ops.md @@ -318,3 +318,28 @@ for (int i = 0; i < N; i++) %unsigned = pto.vbitcast %signed : !pto.vreg<64xsi32> -> !pto.vreg<64xui32> // Bits are identical; interpretation changes from signed to unsigned ``` + +## `pto.pbitcast` + +- **syntax:** `%result = pto.pbitcast %input : !pto.mask -> !pto.mask` +- **semantics:** Bitwise reinterpretation of a predicate register without + changing the underlying predicate-register image. This op makes mask-family + reinterpretation explicit in VPTO IR when a producer and consumer expect + different `!pto.mask<...>` views of the same hardware predicate state. + +- **inputs:** + `%input` is the source predicate register value. +- **outputs:** + `%result` is the reinterpreted predicate register value. +- **constraints and limitations:** + 1. Both source and result must be `!pto.mask<...>` types. + 2. `pto.pbitcast` does not materialize or normalize predicate contents; it + only changes which mask granularity the surrounding VPTO IR uses to + interpret the same predicate bits. + +**Example: Reinterpret a b16 predicate as b32 before a consumer** +```mlir +%m16 = pto.pintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask +%m32 = pto.pbitcast %m16#0 : !pto.mask -> !pto.mask +%result = pto.vsel %a, %b, %m32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 2e2beb5c0..7a2d177c8 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -889,7 +889,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | 6 | [Unary Vector Ops](isa/06-unary-vector-ops.md) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | | 7 | [Binary Vector Ops](isa/07-binary-vector-ops.md) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | | 8 | [Vec-Scalar Ops](isa/08-vec-scalar-ops.md) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | -| 9 | [Conversion Ops](isa/09-conversion-ops.md) | Type conversion with rounding/saturation control | 3 | `pto.vcvt`, `pto.vtrc`, `pto.vbitcast` | +| 9 | [Conversion Ops](isa/09-conversion-ops.md) | Type conversion with rounding/saturation control | 4 | `pto.vcvt`, `pto.vtrc`, `pto.vbitcast`, `pto.pbitcast` | | 10 | [Reduction Ops](isa/10-reduction-ops.md) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | | 11 | [Compare & Select](isa/11-compare-select.md) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | | 12 | [Data Rearrangement](isa/12-data-rearrangement.md) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | @@ -928,7 +928,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | Operation | Group | Description | |-----------|-------|-------------| -| Type Conversion | 9 | `pto.vcvt`, `pto.vbitcast` | +| Type Conversion | 9 | `pto.vcvt`, `pto.vbitcast`, `pto.pbitcast` | | Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | | Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 5e4b5178c..4c4b6df6b 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -584,6 +584,8 @@ def PTO_PunpackOp : PTO_MaskUnaryOp<"punpack"> { }]; } +def PTO_PbitcastOp : PTO_MaskUnaryOp<"pbitcast">; + def PTO_PnotOp : PTO_Op<"pnot", [Pure]> { let arguments = (ins PTO_MaskTypeConstraint:$input, diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 4ac498d6a..25120cf54 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -2117,6 +2117,13 @@ LogicalResult PunpackOp::verify() { return success(); } +LogicalResult PbitcastOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + return success(); +} + LogicalResult PnotOp::verify() { if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 7f4588755..0a5c05dea 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -4300,6 +4300,30 @@ class LowerVbitcastOpPattern final } }; +class LowerPbitcastOpPattern final + : public OpConversionPattern { +public: + explicit LowerPbitcastOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult + matchAndRewrite(pto::PbitcastOp op, pto::PbitcastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert pbitcast result type"); + if (adaptor.getInput().getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "pbitcast expects identical lowered input/result types"); + } + rewriter.replaceOp(op, adaptor.getInput()); + return success(); + } +}; + class LowerVtrcOpPattern final : public OpConversionPattern { public: explicit LowerVtrcOpPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -5055,7 +5079,7 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerVpreluOpPattern, LowerVaxpyOpPattern, LowerVciOpPattern, LowerVexpdifOpPattern, LowerVbitsortOpPattern, LowerVtrcOpPattern, LowerVcvtOpPattern, - LowerVbitcastOpPattern, + LowerVbitcastOpPattern, LowerPbitcastOpPattern, LowerPredicateLoadOpPattern, LowerPredicateLoadOpPattern, LowerPredicateStoreOpPattern, @@ -5106,7 +5130,8 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, pto::VcaddOp, pto::VcmaxOp, pto::VcminOp, pto::VcgaddOp, pto::VcgmaxOp, pto::VcgminOp, pto::VcpaddOp, pto::VdupOp, pto::VbrOp, - pto::PpackOp, pto::PunpackOp, pto::VselOp, pto::VselrOp, + pto::PpackOp, pto::PunpackOp, pto::PbitcastOp, + pto::VselOp, pto::VselrOp, pto::PnotOp, pto::PselOp, pto::PandOp, pto::PorOp, pto::PxorOp, pto::PdintlvB8Op, pto::PdintlvB16Op, pto::PdintlvB32Op, pto::PintlvB8Op, pto::PintlvB16Op, pto::PintlvB32Op, diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/compare.py b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/compare.py new file mode 100644 index 000000000..0823d7be6 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/golden.py b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/golden.py new file mode 100644 index 000000000..d48384a1f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +MASK_BYTES = 32 +SEED = 19 + + +def unpack_mask_lanes(packed: np.ndarray) -> np.ndarray: + bits = np.unpackbits(packed[:16], bitorder="little") + return bits.astype(np.bool_, copy=False) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + v2 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + packed = rng.integers(0, 256, size=(MASK_BYTES,), dtype=np.uint8) + lanes = unpack_mask_lanes(packed) + v4 = np.zeros((LANES,), dtype=np.float32) + golden_v4 = np.where(lanes, v1, v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + packed.tofile(output_dir / "v3.bin") + v4.tofile(output_dir / "v4.bin") + golden_v4.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate inputs/golden for VPTO vsel-f32-plds-us-pintlv-pbitcast." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto new file mode 100644 index 000000000..66ec875f8 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto @@ -0,0 +1,77 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast +// family: compare-select +// target_ops: pto.plds, pto.pintlv_b16, pto.pbitcast, pto.vsel +// scenarios: packed-us-mask-expand-to-b32, f32-select-from-compressed-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vsel_f32_plds_us_pintlv_pbitcast_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c1056_i64 = arith.constant 1056 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c512_i64 : i64 -> !pto.ptr + %ub_mask = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c1056_i64 : i64 -> !pto.ptr + + pto.set_loop1_stride_outtoub %c128_i64, %c128_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c128_i64, %c128_i64 : i64, i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.dma_load %arg2, %ub_mask, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask_b8 = pto.plds %ub_mask[%c0], "US" : !pto.ptr, index -> !pto.mask + %mask_b16 = pto.pbitcast %mask_b8 : !pto.mask -> !pto.mask + %all_b16 = pto.pset_b16 "PAT_ALL" : !pto.mask + %all_b32 = pto.pset_b32 "PAT_ALL" : !pto.mask + %mask0_b16, %mask1_b16 = pto.pintlv_b16 %mask_b16, %all_b16 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %mask0_b32 = pto.pbitcast %mask0_b16 : !pto.mask -> !pto.mask + %mask1_b32 = pto.pbitcast %mask1_b16 : !pto.mask -> !pto.mask + %lhs0 = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs0 = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %lhs1 = pto.vlds %ub_lhs[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %rhs1 = pto.vlds %ub_rhs[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %out0 = pto.vsel %lhs0, %rhs0, %mask0_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out1 = pto.vsel %lhs1, %rhs1, %mask1_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out0, %ub_out[%c0], %all_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out1, %ub_out[%c64], %all_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c128_i64, %c128_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c128_i64, %c128_i64 : i64, i64 + pto.dma_store %ub_out, %arg3, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/launch.cpp new file mode 100644 index 000000000..73a736714 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vsel_f32_plds_us_pintlv_pbitcast_kernel_2d(__gm__ float *v1, __gm__ float *v2, + __gm__ unsigned char *v3, + __gm__ float *v4); + +void LaunchVsel_f32_plds_us_pintlv_pbitcast_kernel_2d(float *v1, float *v2, + unsigned char *v3, + float *v4, + void *stream) { + vsel_f32_plds_us_pintlv_pbitcast_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ unsigned char *)v3, + (__gm__ float *)v4); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/main.cpp new file mode 100644 index 000000000..4f36cd466 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/main.cpp @@ -0,0 +1,116 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_f32_plds_us_pintlv_pbitcast_kernel_2d(float *v1, float *v2, + unsigned char *v3, + float *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + size_t elemCount_v4 = 128; + size_t fileSize_v4 = elemCount_v4 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + float *v4Host = nullptr; + float *v4Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_f32_plds_us_pintlv_pbitcast_kernel_2d(v1Device, v2Device, v3Device, + v4Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/stub.cpp b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/stub.cpp new file mode 100644 index 000000000..ea1e5e63b --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +vsel_f32_plds_us_pintlv_pbitcast_kernel_2d(__gm__ float *v1, __gm__ float *v2, + __gm__ unsigned char *v3, + __gm__ float *v4) { + (void)v1; + (void)v2; + (void)v3; + (void)v4; +} diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index 1c573fe97..b68432daf 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -208,6 +208,19 @@ mask_ty = pto.mask_b32 mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) ``` +Typed masks also support explicit type reinterpretation via `pto.pbitcast`: + +```python +mask_b8 = pto.plds(mask_ptr, offset, pto.PredicateDist.US) +mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) +mask_b32 = pto.pbitcast(mask_b16, pto.mask_b32) +``` + +`pto.pbitcast(...)` is the predicate analogue of `pto.vbitcast(...)`: +- it changes the static mask granularity seen by later DSL/VPTO consumers +- it preserves the underlying predicate bit image +- it does not perform pack/unpack or interleave/deinterleave by itself + Mask operations must match the vector element family: - `f32`, `i32`, `si32`, and `ui32` vectors use `mask_b32` - `f16`, `bf16`, `i16`, `si16`, and `ui16` vectors use `mask_b16` diff --git a/tilelang-dsl/docs/user_guide/10-predicate-operations.md b/tilelang-dsl/docs/user_guide/10-predicate-operations.md index 21f8d4879..8cc92da2c 100644 --- a/tilelang-dsl/docs/user_guide/10-predicate-operations.md +++ b/tilelang-dsl/docs/user_guide/10-predicate-operations.md @@ -318,6 +318,36 @@ packed = pto.ppack(mask, pto.PredicatePart.LOWER) unpacked = pto.punpack(mask, pto.PredicatePart.HIGHER) ``` +#### `pto.pbitcast(mask: MaskType, to_type: MaskType) -> MaskType` + +**Description**: Reinterprets a typed predicate mask as another typed mask granularity without changing the underlying predicate bit image. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | +| `to_type` | `MaskType` | Target mask type marker such as `pto.mask_b16` or `pto.mask_b32` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Reinterpreted mask with the requested target granularity | + +**Constraints**: +- `mask` must already be a typed predicate value +- `to_type` must be one of the DSL mask type markers: `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` +- this is a bit reinterpretation helper, not a logical predicate transform; it does not insert packing, unpacking, interleaving, or deinterleaving by itself +- use `pto.ppack`, `pto.punpack`, `pto.pdintlv_b8`, or `pto.pintlv_b16` when the predicate image itself must be rearranged + +**Example**: +```python +mask_b8 = pto.plds(mask_ptr, offset, pto.PredicateDist.US) +mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) + +mask0_b16, mask1_b16 = pto.pintlv_b16(mask_b16, pto.pset_b16(PAT.ALL)) +mask0_b32 = pto.pbitcast(mask0_b16, pto.mask_b32) +``` + #### `pto.pnot(mask: MaskType, gate: MaskType) -> MaskType` **Description**: Predicate negation under a same-granularity mask gate. diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index d4cff2da3..821b6b27f 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -829,6 +829,7 @@ def _collect_reachable_inline_procs( "vlds": frozenset({"dist"}), "vsts": frozenset({"dist"}), "vbitcast": frozenset(), + "pbitcast": frozenset(), } diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index c0dca2161..b640c7eda 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2975,6 +2975,15 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "pbitcast": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.pbitcast {value.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vbitsort": destination = self._lower_expr(expr.args[0], env, indent=indent, into=into) source = self._lower_expr(expr.args[1], env, indent=indent, into=into) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index fbaf78e06..c0b9579db 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3655,21 +3655,25 @@ def _analyze_as_ptr_method(self, base: SemanticExpr) -> SemanticExpr: raise TypeError("`as_ptr()` expects a TensorView/PartitionTensorView or Tile value in TileLang DSL v1") def _analyze_astype_method(self, base: SemanticExpr, args: tuple[SemanticExpr, ...]) -> SemanticExpr: - """Analyze vreg.astype(dtype) method call.""" if len(args) != 1: raise TypeError("`astype()` expects exactly 1 positional argument (target dtype) in TileLang DSL v1") - # Verify target dtype is a valid dtype symbol - target_dtype = self._require_dtype_symbol(args[0], "astype target dtype") - # Verify base is a vector register - if not isinstance(base.type, SemanticVRegType): - raise TypeError("`astype()` expects a vector register value in TileLang DSL v1") - # Convert to pto.vbitcast call, pass original dtype expression as second argument - return SemanticCallExpr( - namespace="pto", - name="vbitcast", - args=(base, args[0]), - type=self._vreg_type_for_dtype(target_dtype), - ) + if isinstance(base.type, SemanticVRegType): + target_dtype = self._require_dtype_symbol(args[0], "astype target dtype") + return SemanticCallExpr( + namespace="pto", + name="vbitcast", + args=(base, args[0]), + type=self._vreg_type_for_dtype(target_dtype), + ) + if isinstance(base.type, SemanticMaskType): + target_mask_type = self._require_mask_type_expr(args[0], "astype target dtype") + return SemanticCallExpr( + namespace="pto", + name="pbitcast", + args=(base, args[0]), + type=SemanticMaskType(granularity=target_mask_type.granularity), + ) + raise TypeError("`astype()` expects a vector register or mask value in TileLang DSL v1") def _valid_shape_expr(self, base: SemanticExpr) -> SemanticExpr: base_type = base.type @@ -4026,6 +4030,8 @@ def _analyze_call_expr( return self._analyze_vcvt(args) if name == "vbitcast": return self._analyze_vbitcast(args) + if name == "pbitcast": + return self._analyze_pbitcast(args) if name == "vtrc": return self._analyze_vtrc(args) if name == "vbitsort": @@ -5211,6 +5217,21 @@ def _analyze_vbitcast(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: type=self._vreg_type_for_dtype(target_dtype), ) + def _analyze_pbitcast(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.pbitcast expects exactly 2 positional arguments in TileLang DSL") + self._require_mask_expr(args[0], "pto.pbitcast mask") + target_mask_type = self._require_mask_type_expr(args[1], "pto.pbitcast to_type") + return SemanticCallExpr( + namespace="pto", + name="pbitcast", + args=( + args[0], + args[1], + ), + type=SemanticMaskType(granularity=target_mask_type.granularity), + ) + def _analyze_vbitsort(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: if len(args) != 4: raise TypeError("pto.vbitsort expects exactly 4 positional arguments in TileLang DSL v1") diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index d92bb661c..bb4b311dd 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -152,6 +152,7 @@ "vmrgsort", "vcvt", "vbitcast", + "pbitcast", "vci", } ) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 65daee63f..ec0ccaabb 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -4115,7 +4115,7 @@ def kernel(dst: pto.Tile, src: pto.Tile): ) specialized.mlir_text() - self.assertIn("vector register value", str(ctx.exception)) + self.assertIn("vector register or mask value", str(ctx.exception)) def test_vbitcast_supports_element_size_change(self) -> None: @pto.vkernel( @@ -4146,6 +4146,95 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<64xf32> -> !pto\.vreg<128xf16>") self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<128xf16> -> !pto\.vreg<64xf32>") + def test_pbitcast_supports_direct_interface(self) -> None: + @pto.vkernel( + op="pbitcast_direct_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.pbitcast(src_mask, pto.mask_b32) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.pbitcast", text) + self.assertRegex(text, r'%src_mask_\d+ = pto\.pset_b16 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r'= pto\.pbitcast %[^:]+ : !pto\.mask -> !pto\.mask') + + def test_pbitcast_rejects_non_mask_input(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="pbitcast_non_mask_input_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + vec = pto.vlds(src, 0) + mask = pto.pbitcast(vec, pto.mask_b32) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("mask value", str(ctx.exception)) + + def test_mask_astype_lowers_to_pbitcast(self) -> None: + @pto.vkernel( + op="mask_astype_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = src_mask.astype(pto.mask_b32) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.pbitcast", text) + self.assertRegex(text, r'%src_mask_\d+ = pto\.pset_b16 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r'= pto\.pbitcast %[^:]+ : !pto\.mask -> !pto\.mask') + + def test_astype_rejects_non_vreg_or_mask_receiver(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="astype_invalid_receiver_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + scalar = pto.f32(1.0) + mask = scalar.astype(pto.mask_b32) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("vector register or mask value", str(ctx.exception)) + def test_index_to_float_scalar_cast_lowers_via_integer_bridge(self) -> None: @pto.vkernel( op="index_to_float_scalar_cast_unique", From 398beb5c63f2e00971672028e34604a1d474460d Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 23 Apr 2026 16:39:36 +0800 Subject: [PATCH 138/192] fix(dsl): fix the order mode lowering of pto.vci --- docs/isa/09-conversion-ops.md | 4 +-- docs/isa/13-dsa-sfu-ops.md | 4 +-- .../11-vector-arithmetic-operations.md | 9 ++++-- tilelang-dsl/python/tilelang_dsl/semantic.py | 6 ++-- tilelang-dsl/python/tilelang_dsl/types.py | 3 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 32 +++++++++++++++++-- 6 files changed, 45 insertions(+), 13 deletions(-) diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md index 5c9cd5515..9ea2f5b78 100644 --- a/docs/isa/09-conversion-ops.md +++ b/docs/isa/09-conversion-ops.md @@ -29,14 +29,14 @@ Cycle-accurate simulator **popped→retire** latency (cycles). Only representati ## `pto.vci` -- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : integer -> !pto.vreg` - **semantics:** Generate a lane-index vector from a scalar seed/index value. - **inputs:** `%index` is the scalar seed or base index. - **outputs:** `%result` is the generated index vector. - **constraints and limitations:** - This is an index-generation family, not a numeric conversion. `ORDER` and the + This is an index-generation family, not a numeric conversion. `order` and the result element type together determine how indices are generated. `%result` uses an integer element type, and the scalar `%index` type matches that result element type. diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md index 06edaef68..c32fedd79 100644 --- a/docs/isa/13-dsa-sfu-ops.md +++ b/docs/isa/13-dsa-sfu-ops.md @@ -149,7 +149,7 @@ for (int i = 0; i < N; i++) ### `pto.vci` -- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : integer -> !pto.vreg` - **semantics:** Generate lane index vector. ```c @@ -208,7 +208,7 @@ for (int i = 0; i < N; i++) - `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` - `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` -- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vci %index {order = "ASC|DESC"} : integer -> !pto.vreg` - `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` - `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 39882950e..7edf20936 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1499,7 +1499,7 @@ None. The op writes UB memory directly. - `dest` and `src0` through `src3` must be UB-backed pointers - Inputs must already be sorted according to the order encoded by `config` -**Order Mode Enum**: The `OrderMode` enum provides type-safe order selection for `pto.vci` operations. Currently only `ASC` (ascending order) is supported, with more order options planned for future releases. +**Order Mode Enum**: The `OrderMode` enum provides type-safe order selection for `pto.vci` operations. `ASC` and `DESC` are supported. #### `pto.vci(index: ScalarType, order: OrderMode = OrderMode.ASC) -> VRegType` @@ -1509,7 +1509,7 @@ None. The op writes UB memory directly. | Parameter | Type | Description | |-----------|------|-------------| | `index` | `ScalarType` | Scalar seed or base index value | -| `order` | `OrderMode` | Order mode enum (default: `OrderMode.ASC` for ascending order) | +| `order` | `OrderMode` | Order mode enum (default: `OrderMode.ASC`; supported values: `ASC`, `DESC`) | **Returns**: | Return Value | Type | Description | @@ -1519,13 +1519,16 @@ None. The op writes UB memory directly. **Constraints**: - This is an index-generation family, not a numeric conversion - The `order` parameter and result element type together determine how indices are generated -- Currently only ascending order (`OrderMode.ASC`) is supported +- Supported order modes are ascending (`OrderMode.ASC`) and descending (`OrderMode.DESC`) **Example**: ```python # Generate ascending indices starting from 0 indices = pto.vci(pto.i32(0), OrderMode.ASC) +# Generate descending indices starting from the seed value +indices_desc = pto.vci(pto.i32(63), OrderMode.DESC) + # Keyword form for the optional order argument is also supported indices_kw = pto.vci(pto.i32(0), order=OrderMode.ASC) ``` diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index c0b9579db..13d7ec73c 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -5469,8 +5469,10 @@ def _normalize_order_mode( ): return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) order = self._require_string_expr(expr, context) - if order != OrderMode.ASC.value: - raise TypeError("pto.vci currently only supports order `OrderMode.ASC` in TileLang DSL v1") + if order not in {OrderMode.ASC.value, OrderMode.DESC.value}: + raise TypeError( + "pto.vci currently only supports order `OrderMode.ASC` or `OrderMode.DESC` in TileLang DSL v1" + ) return SemanticLiteralExpr(value=order, type=SemanticMetaType(kind="string")) def _normalize_vcvt_round_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 3ff5e387f..0e806f82f 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -437,7 +437,8 @@ class PositionMode(str, Enum): class OrderMode(str, Enum): - ASC = "ORDER_ASC" + ASC = "ASC" + DESC = "DESC" class VcvtRoundMode(str, Enum): diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index ec0ccaabb..3e8414d24 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -161,7 +161,8 @@ def test_package_exports_surface(self) -> None: self.assertEqual(pto.InterleaveDist.INTLV.value, "INTLV") self.assertEqual(pto.PositionMode.LOWEST.value, "LOWEST") self.assertEqual(pto.PositionMode.HIGHEST.value, "HIGHEST") - self.assertEqual(pto.OrderMode.ASC.value, "ORDER_ASC") + self.assertEqual(pto.OrderMode.ASC.value, "ASC") + self.assertEqual(pto.OrderMode.DESC.value, "DESC") self.assertEqual(pto.PredicateDist.NORM.value, "NORM") self.assertEqual(pto.PredicateDist.US.value, "US") self.assertEqual(pto.PredicateDist.DS.value, "DS") @@ -4471,11 +4472,36 @@ def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): self.assertNotIn('position = "POS_LOWEST"', text) self.assertRegex( text, - r'pto\.vci\s+%[^\s]+\s+\{order = "ORDER_ASC"\}\s+:', + r'pto\.vci\s+%[^\s]+\s+\{order = "ASC"\}\s+:', ) self.assertNotRegex( text, - r'pto\.vci\s+%[^\s]+,\s*"ORDER_ASC"\s+:', + r'pto\.vci\s+%[^\s]+,\s*"ASC"\s+:', + ) + + def test_vci_desc_lowers_to_desc_order_attr(self) -> None: + @pto.vkernel( + op="vci_desc_order_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + indices = pto.vci(seed, pto.OrderMode.DESC) + vec = pto.vlds(src, 0) + out = pto.vadd(vec, indices, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex( + text, + r'pto\.vci\s+%[^\s]+\s+\{order = "DESC"\}\s+:', ) def test_vdup_scalar_input_rejects_position_argument(self) -> None: From bf4dc9a7e6e2f15a94d00e19df526d5a33861559 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Thu, 23 Apr 2026 17:43:17 +0800 Subject: [PATCH 139/192] feat: support copy_ubuf_to_ubuf / dma_copy lowering --- docs/isa-legacy/02-dma-copy.md | 61 ++++++--- docs/isa/02-dma-copy.md | 69 +++++++++- docs/tilelang-dsl-syntax-sugar-proposals.md | 10 +- docs/vpto-spec.md | 6 +- include/PTO/IR/VPTOOps.td | 26 ++++ lib/PTO/IR/VPTO.cpp | 16 +++ lib/PTO/Transforms/HIVMIntrinsicNaming.cpp | 3 +- lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp | 16 ++- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 87 +++++++++++- .../cases/vpto/dma-copy-rearrange/compare.py | 50 +++++++ .../cases/vpto/dma-copy-rearrange/golden.py | 44 ++++++ .../cases/vpto/dma-copy-rearrange/kernel.pto | 67 +++++++++ .../cases/vpto/dma-copy-rearrange/launch.cpp | 54 ++++++++ .../cases/vpto/dma-copy-rearrange/main.cpp | 128 ++++++++++++++++++ .../cases/vpto/dma-copy-rearrange/stub.cpp | 31 +++++ 15 files changed, 632 insertions(+), 36 deletions(-) create mode 100644 test/vpto/cases/vpto/dma-copy-rearrange/compare.py create mode 100644 test/vpto/cases/vpto/dma-copy-rearrange/golden.py create mode 100644 test/vpto/cases/vpto/dma-copy-rearrange/kernel.pto create mode 100644 test/vpto/cases/vpto/dma-copy-rearrange/launch.cpp create mode 100644 test/vpto/cases/vpto/dma-copy-rearrange/main.cpp create mode 100644 test/vpto/cases/vpto/dma-copy-rearrange/stub.cpp diff --git a/docs/isa-legacy/02-dma-copy.md b/docs/isa-legacy/02-dma-copy.md index 6e2fdc4f3..1a24f62d0 100644 --- a/docs/isa-legacy/02-dma-copy.md +++ b/docs/isa-legacy/02-dma-copy.md @@ -198,35 +198,46 @@ pto.copy_ubuf_to_gm %ub_src, %gm_dst, - **syntax:** ```mlir -pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride - : !pto.ptr, !pto.ptr, i64 x5 +pto.copy_ubuf_to_ubuf %ub_src, %ub_dst, + %sid, %n_burst, %len_burst, %src_gap, %dst_gap + : !pto.ptr, !pto.ptr, + i64, i64, i64, i64, i64 ``` -- **semantics:** Copy within Unified Buffer. +- **semantics:** Raw UB→UB copy within Unified Buffer. `pto.dma_copy` uses the same operand contract. -**Parameters:** +**Parameter Table:** -| Parameter | Description | -|-----------|-------------| -| `%source` | UB source pointer | -| `%dest` | UB destination pointer | -| `%sid` | Stream ID | -| `%n_burst` | Number of bursts | -| `%len_burst` | Length per burst | -| `%src_stride` | Source stride | -| `%dst_stride` | Destination stride | +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | 16 bits | Stream ID | +| `%n_burst` | 16 bits | Number of bursts | +| `%len_burst` | 16 bits | Burst length in units of 32 bytes | +| `%src_gap` | 16 bits | Source gap between consecutive bursts, in units of 32 bytes | +| `%dst_gap` | 16 bits | Destination gap between consecutive bursts, in units of 32 bytes | --- -## Burst / Stride / Pad Model +## Burst / Stride / Gap / Pad Model + +The legacy DMA copy family uses two different innermost-burst contracts: -All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. +- `pto.copy_gm_to_ubuf` / `pto.copy_ubuf_to_gm` are **stride-based**. Their + source and destination stride operands are start-to-start distances in bytes. +- `pto.copy_ubuf_to_ubuf` is **gap-based**. Its + `%len_burst`, `%src_gap`, and `%dst_gap` operands are encoded in units of + 32 bytes. ### Key Terms ``` -burst = lenBurst contiguous bytes transferred per row -stride = distance (bytes) from start of row[r] to start of row[r+1] -pad = ub_stride - lenBurst, padded to the 32B alignment boundary +GM↔UB burst = lenBurst contiguous bytes transferred per row +GM↔UB stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +UB→UB burst = len_burst * 32 bytes +UB→UB next source start = previous source start + (len_burst + src_gap) * 32 bytes +UB→UB next destination start = previous destination start + (len_burst + dst_gap) * 32 bytes ``` ### Alignment Constraints @@ -235,6 +246,20 @@ pad = ub_stride - lenBurst, padded to the 32B alignment boundary - **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. - **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. +### UB→UB Raw Copy (`pto.copy_ubuf_to_ubuf`) + +For UB→UB raw copy, each burst copies `len_burst * 32` bytes. + +After burst `r`, the next burst starts at: + +```text +src_next = src_curr + (len_burst + src_gap) * 32 bytes +dst_next = dst_curr + (len_burst + dst_gap) * 32 bytes +``` + +So `src_gap` and `dst_gap` are not start-to-start strides. They are additional +gaps inserted after the copied 32B blocks. + ### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) ``` diff --git a/docs/isa/02-dma-copy.md b/docs/isa/02-dma-copy.md index e402d8117..845df918b 100644 --- a/docs/isa/02-dma-copy.md +++ b/docs/isa/02-dma-copy.md @@ -9,9 +9,12 @@ This document describes the public grouped DMA interfaces: - `pto.dma_load` - `pto.dma_store` +- `pto.dma_copy` -The legacy low-level DMA configuration and raw copy interfaces are documented in -[02-dma-copy-legacy.md](02-dma-copy-legacy.md). +This chapter covers the public grouped DMA interfaces. The legacy raw copy +family remains documented separately; in particular, `pto.copy_ubuf_to_ubuf` +shares the same UB→UB copy contract as `pto.dma_copy` but remains a legacy +surface op. --- @@ -120,14 +123,49 @@ pto.dma_store %ub_in, %gm_out, %sid, %zero, %len_burst --- -For the legacy low-level DMA copy family, see -[02-dma-copy-legacy.md](02-dma-copy-legacy.md). +### `pto.dma_copy` + +- **syntax:** +```mlir +pto.dma_copy %ub_src, %ub_dst, %sid, %len_burst + nburst(%n_burst, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 +``` +- **semantics:** Grouped UB→UB raw copy.. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | 16 bits | Stream ID | +| `%len_burst` | 16 bits | Burst length in units of 32 bytes | +| `nburst(%n_burst, %src_gap, %dst_gap)` | 16 bits / 16 bits / 16 bits | Required UB→UB outer burst group: count, source gap, destination gap | + +**Constraints:** + +- UB source and destination addresses must be 32B-aligned. +- `%len_burst`, `%src_gap`, and `%dst_gap` are encoded in units of 32 bytes. + +**Example:** + +```mlir +pto.dma_copy %ub_src, %ub_dst, %sid, %len32b + nburst(%rows, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 +``` --- -## Burst / Stride / Pad Model +## GM↔UB Burst / Stride / Pad Model -All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. +This section describes the grouped GM↔UB DMA interfaces in this document: +`pto.dma_load` and `pto.dma_store`. + +For these grouped GM↔UB DMA ops, the innermost `nburst(...)` group is +**stride-based**: the source and destination stride operands are the +start-to-start byte distance from one burst row to the next row. ### Key Terms @@ -143,6 +181,25 @@ pad = ub_stride - lenBurst, padded to the 32B alignment boundary - **GM→UB padding**: When `pad(...)` is present on `pto.dma_load`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val`. This ensures every UB row starts at a 32B-aligned offset. - **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. +--- + +## UB→UB Burst / Gap Model + +This section describes the grouped UB→UB DMA interface in this document: +`pto.dma_copy`. + +For `pto.dma_copy`, each burst copies `len_burst * 32` bytes. + +The next burst starts at: + +```text +src_next = src_curr + (len_burst + src_gap) * 32 bytes +dst_next = dst_curr + (len_burst + dst_gap) * 32 bytes +``` + +So `src_gap` and `dst_gap` are additional gaps after the copied 32B blocks. +They are not start-to-start strides. + ### 2D Diagram: GM→UB (`pto.dma_load`) ``` diff --git a/docs/tilelang-dsl-syntax-sugar-proposals.md b/docs/tilelang-dsl-syntax-sugar-proposals.md index 8a60466d9..bff549e4c 100644 --- a/docs/tilelang-dsl-syntax-sugar-proposals.md +++ b/docs/tilelang-dsl-syntax-sugar-proposals.md @@ -16,11 +16,11 @@ next_ptr = pto.addptr(ub_ptr, 4096) **Problem**: Users must manage byte offsets and memory spaces manually. ### 2. **Verbose Copy Operations** -The `pto.copy_ubuf_to_ubuf` operation has 7 parameters: -- `src_offset`, `src_stride0`, `src_stride1` -- `dst_offset`, `dst_stride0`, `dst_stride1` +The `pto.copy_ubuf_to_ubuf` / `pto.dma_copy` operand contract is low-level: +- source pointer, destination pointer, `sid` +- `n_burst`, `len_burst`, `src_gap`, `dst_gap` -**Problem**: Correctly setting stride parameters is error-prone, especially for multi-dimensional data. +**Problem**: Correctly setting burst and gap parameters is error-prone, especially for multi-dimensional data. ### 3. **Precise Mask Type Matching** ```python @@ -401,4 +401,4 @@ These enhancements will significantly improve the TileLang DSL's usability while 4. Mask 的隐式推导(针对边界处理) NPU 算子经常要处理尾部不对齐的数据(Tail processing)。 -优化建议:虽然底层需要具体的 Mask 寄存器配置(如 PAT_ALL),但在 for 循环的最后一步边界处理时,能否提供一个类似 pto.make_mask(remaining_elements) 的宏/内联函数?让它在生成 MLIR 时,自动展开为对应的硬件 plt_b32 等指令,这样可以大幅减少手写冗长边界判断的样板代码。 \ No newline at end of file +优化建议:虽然底层需要具体的 Mask 寄存器配置(如 PAT_ALL),但在 for 循环的最后一步边界处理时,能否提供一个类似 pto.make_mask(remaining_elements) 的宏/内联函数?让它在生成 MLIR 时,自动展开为对应的硬件 plt_b32 等指令,这样可以大幅减少手写冗长边界判断的样板代码。 diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 7a2d177c8..c4d748a55 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -147,6 +147,10 @@ The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer 4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) 5. **UB → GM**: DMA transfer via MTE3 (`pto.dma_store`) +The grouped DMA surface in this specification covers GM↔UB transfer only. +Low-level raw copy families such as UB→UB copy use separate operand contracts +and are outside this grouped DMA interface. + **Load/Store Access Patterns**: For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](isa/03-vector-load-store.md) group in the ISA specification. @@ -882,7 +886,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | # | Group | Description | Count | Details | |---|-------|-------------|-------|---------| | 1 | [Pipeline Sync](isa/01-pipeline-sync.md) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | -| 2 | [DMA Copy Programming](isa/02-dma-copy.md) | Public DMA transfer interface between GM↔UB | 2 | `pto.dma_load`, `pto.dma_store` | +| 2 | [DMA Copy Programming](isa/02-dma-copy.md) | Public DMA transfer interface between GM↔UB and UB↔UB | 3 | `pto.dma_load`, `pto.dma_store`, `pto.dma_copy` | | 3 | [Vector Load/Store](isa/03-vector-load-store.md) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | | 4 | [Predicate Load/Store](isa/04-predicate-load-store.md) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | | 5 | [Materialization & Predicate Ops](isa/05-materialization-predicate.md) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 4c4b6df6b..8630c8936 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -299,6 +299,32 @@ def PTO_CopyUbufToUbufOp : PTO_Op<"copy_ubuf_to_ubuf"> { }]; } +def PTO_DmaCopyOp : PTO_Op<"dma_copy", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $len_burst + `nburst` `(` $n_burst `,` $src_stride `,` $dst_stride `)` + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` + type($n_burst) `,` type($len_burst) `,` type($src_stride) `,` + type($dst_stride) + }]; +} + def PTO_VldsOp : PTO_Op<"vlds", [ DeclareOpInterfaceMethods ]> { diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 25120cf54..75e38c830 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1661,6 +1661,22 @@ LogicalResult CopyUbufToUbufOp::verify() { return success(); } +void DmaCopyOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult DmaCopyOp::verify() { + if (!isBufferLike(getSource().getType()) || !isBufferLike(getDestination().getType())) + return emitOpError("requires pointer-like source and destination"); + if (classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getDestination().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed source and destination"); + return success(); +} + void VgatherbOp::getEffects( SmallVectorImpl> &effects) { diff --git a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp index d17b8ad38..1ae923dac 100644 --- a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp +++ b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp @@ -578,8 +578,7 @@ FailureOr selectStoreIntrinsic(Operation *op) { if (isa(op)) { usedFields = {"family=copy_ubuf_to_ubuf"}; - return makeUnresolved(op, "copy_ubuf_to_ubuf", "copy_ubuf_to_ubuf", - usedFields, missingFields, ""); + return makeResolved(op, "llvm.hivm.MOV.UB.TO.UB.v310", usedFields, ""); } return failure(); diff --git a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp index cbd02a602..2580a9cab 100644 --- a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp +++ b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp @@ -186,6 +186,18 @@ struct ExpandDmaStorePattern : public OpRewritePattern { } }; +struct ExpandDmaCopyPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::DmaCopyOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getSource(), op.getDestination(), op.getSid(), op.getNBurst(), + op.getLenBurst(), op.getSrcStride(), op.getDstStride()); + return success(); + } +}; + struct PTOVPTOExpandBridgeOpsPass : public pto::impl::PTOVPTOExpandBridgeOpsBase { using pto::impl::PTOVPTOExpandBridgeOpsBase< @@ -197,8 +209,8 @@ struct PTOVPTOExpandBridgeOpsPass return; RewritePatternSet patterns(&getContext()); - patterns.add( - &getContext()); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) signalPassFailure(); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 0a5c05dea..e505c17d7 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -938,6 +938,41 @@ static FailureOr packCopyUbToGmConfig0(Operation *anchor, Value sid, return packCopyUbToGmConfig0(anchor, operands); } +static FailureOr +packCopyUbToUbConfig(Operation *anchor, ValueRange operands) { + if (operands.size() != 7) + return failure(); + + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; + + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value srcStride = getI64Operand(5); + Value dstStride = getI64Operand(6); + if (!nBurst || !lenBurst || !srcStride || !dstStride) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = nBurst; + config = bitOr(config, shl(lenBurst, 16)); + config = bitOr(config, shl(srcStride, 32)); + config = bitOr(config, shl(dstStride, 48)); + return config; +} + static FailureOr packVbitsortConfig(Operation *anchor, Value repeatTimes) { OpBuilder builder(anchor); builder.setInsertionPoint(anchor); @@ -1319,6 +1354,10 @@ static StringRef buildCopyUbToGmCallee(MLIRContext *context) { .getValue(); } +static StringRef buildCopyUbToUbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.UB.v310").getValue(); +} + static StringRef buildPstiCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.psti.b8").getValue(); } @@ -2260,6 +2299,48 @@ class LowerCopyOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerCopyUbufToUbufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyUbufToUbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::CopyUbufToUbufOp op, + pto::CopyUbufToUbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmSourceType = + dyn_cast(adaptor.getOperands()[0].getType()); + auto llvmDestType = + dyn_cast(adaptor.getOperands()[1].getType()); + if (!llvmSourceType || !llvmDestType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); + + FailureOr config = packCopyUbToUbConfig(op, adaptor.getOperands()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + + StringRef calleeName = buildCopyUbToUbCallee(op.getContext()); + SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], + *config}; + auto funcType = rewriter.getFunctionType( + TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type()}, + TypeRange{}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + (void)call; + return success(); + } + +private: + LoweringState &state; +}; + template class LowerVecScalarMaskedOpPattern final @@ -5086,7 +5167,8 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerPredicateStoreOpPattern, LowerPstuOpPattern, LowerVstusOpPattern, LowerVsturOpPattern, LowerCopyOpPattern, - LowerCopyOpPattern>( + LowerCopyOpPattern, + LowerCopyUbufToUbufOpPattern>( typeConverter, patterns.getContext(), state); } @@ -5141,7 +5223,8 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, pto::VbitsortOp, pto::VtrcOp, pto::VcvtOp, pto::VbitcastOp, pto::VcmpOp, pto::VcmpsOp, - pto::CopyGmToUbufOp, pto::CopyUbufToGmOp>(); + pto::CopyGmToUbufOp, pto::CopyUbufToGmOp, + pto::CopyUbufToUbufOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); } diff --git a/test/vpto/cases/vpto/dma-copy-rearrange/compare.py b/test/vpto/cases/vpto/dma-copy-rearrange/compare.py new file mode 100644 index 000000000..a9e80ddd7 --- /dev/null +++ b/test/vpto/cases/vpto/dma-copy-rearrange/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: vpto/dma-copy-rearrange +# family: vpto +# target_ops: pto.copy_gm_to_ubuf, pto.dma_copy, pto.copy_ubuf_to_gm +# scenarios: i16, ub-rearrange, permute-4x16-rows + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={int(golden[idx])}, out={int(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vpto/dma-copy-rearrange/golden.py b/test/vpto/cases/vpto/dma-copy-rearrange/golden.py new file mode 100644 index 000000000..0b8128d54 --- /dev/null +++ b/test/vpto/cases/vpto/dma-copy-rearrange/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: vpto/dma-copy-rearrange +# family: vpto +# target_ops: pto.copy_gm_to_ubuf, pto.dma_copy, pto.copy_ubuf_to_gm +# scenarios: i16, ub-rearrange, permute-4x16-rows + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 4 +COLS = 16 + + +def generate(output_dir: Path) -> None: + v1 = np.arange(ROWS * COLS, dtype=np.int16).reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.int16) + golden_v2 = v1[[2, 0, 3, 1], :].copy() + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vpto/dma-copy-rearrange/kernel.pto b/test/vpto/cases/vpto/dma-copy-rearrange/kernel.pto new file mode 100644 index 000000000..d8e2429aa --- /dev/null +++ b/test/vpto/cases/vpto/dma-copy-rearrange/kernel.pto @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------------- +// case: vpto/dma-copy-rearrange +// family: vpto +// target_ops: pto.copy_gm_to_ubuf, pto.dma_copy, pto.copy_ubuf_to_gm +// scenarios: i16, ub-rearrange, permute-4x16-rows +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @dma_copy_rearrange_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: i64, + %arg3: i64, + %arg4: i64, + %arg5: i64) { + %false = arith.constant false + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c96_i64 = arith.constant 96 : i64 + %c128_i64 = arith.constant 128 : i64 + %c160_i64 = arith.constant 160 : i64 + %c192_i64 = arith.constant 192 : i64 + %c224_i64 = arith.constant 224 : i64 + %c256_i64 = arith.constant 256 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c128_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c1_i64, %c128_i64, + %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.barrier #pto.pipe + + %src_row0 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %src_row1 = pto.castptr %c32_i64 : i64 -> !pto.ptr + %src_row2 = pto.castptr %c64_i64 : i64 -> !pto.ptr + %src_row3 = pto.castptr %c96_i64 : i64 -> !pto.ptr + + %dst_row0 = pto.castptr %c128_i64 : i64 -> !pto.ptr + %dst_row1 = pto.castptr %c160_i64 : i64 -> !pto.ptr + %dst_row2 = pto.castptr %c192_i64 : i64 -> !pto.ptr + %dst_row3 = pto.castptr %c224_i64 : i64 -> !pto.ptr + + pto.dma_copy %src_row2, %dst_row0, %c0_i64, %arg3 + nburst(%arg2, %arg4, %arg5) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_copy %src_row0, %dst_row1, %c0_i64, %arg3 + nburst(%arg2, %arg4, %arg5) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_copy %src_row3, %dst_row2, %c0_i64, %arg3 + nburst(%arg2, %arg4, %arg5) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_copy %src_row1, %dst_row3, %c0_i64, %arg3 + nburst(%arg2, %arg4, %arg5) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.barrier #pto.pipe + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vpto/dma-copy-rearrange/launch.cpp b/test/vpto/cases/vpto/dma-copy-rearrange/launch.cpp new file mode 100644 index 000000000..e418c4f9a --- /dev/null +++ b/test/vpto/cases/vpto/dma-copy-rearrange/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void dma_copy_rearrange_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + int64_t n_burst, + int64_t len_burst, + int64_t src_gap, + int64_t dst_gap); + +void LaunchDma_copy_rearrange_kernel(int16_t *v1, int16_t *v2, + int64_t n_burst, int64_t len_burst, + int64_t src_gap, int64_t dst_gap, + void *stream) { + dma_copy_rearrange_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + n_burst, len_burst, + src_gap, dst_gap); +} diff --git a/test/vpto/cases/vpto/dma-copy-rearrange/main.cpp b/test/vpto/cases/vpto/dma-copy-rearrange/main.cpp new file mode 100644 index 000000000..6b50936a4 --- /dev/null +++ b/test/vpto/cases/vpto/dma-copy-rearrange/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: vpto/dma-copy-rearrange +// family: vpto +// target_ops: pto.copy_gm_to_ubuf, pto.dma_copy, pto.copy_ubuf_to_gm +// scenarios: i16, ub-rearrange, permute-4x16-rows +// ----------------------------------------------------------------------------- + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchDma_copy_rearrange_kernel(int16_t *v1, int16_t *v2, + int64_t n_burst, int64_t len_burst, + int64_t src_gap, int64_t dst_gap, + void *stream); + +int main() { + constexpr size_t elemCount = 64; + constexpr size_t fileSize = elemCount * sizeof(int16_t); + size_t inputFileSize = fileSize; + + int16_t *v1Host = nullptr; + int16_t *v2Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Device = nullptr; + const int64_t nBurst = 1; + const int64_t lenBurst = 1; + const int64_t srcGap = 0; + const int64_t dstGap = 0; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + FILE_CHECK(ReadFile("./v1.bin", inputFileSize, v1Host, fileSize) && + inputFileSize == fileSize, + "./v1.bin"); + inputFileSize = fileSize; + FILE_CHECK(ReadFile("./v2.bin", inputFileSize, v2Host, fileSize) && + inputFileSize == fileSize, + "./v2.bin"); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSize, v1Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize, v2Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchDma_copy_rearrange_kernel(v1Device, v2Device, nBurst, lenBurst, + srcGap, dstGap, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(v2Host, fileSize, v2Device, fileSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + + FILE_CHECK(WriteFile("./v2.bin", v2Host, fileSize), "./v2.bin"); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vpto/dma-copy-rearrange/stub.cpp b/test/vpto/cases/vpto/dma-copy-rearrange/stub.cpp new file mode 100644 index 000000000..67b75188c --- /dev/null +++ b/test/vpto/cases/vpto/dma-copy-rearrange/stub.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void dma_copy_rearrange_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + int64_t n_burst, + int64_t len_burst, + int64_t src_gap, + int64_t dst_gap) { + (void)v1; + (void)v2; + (void)n_burst; + (void)len_burst; + (void)src_gap; + (void)dst_gap; +} From 555c9217ea9db2eb409151847c9b74d1c50c0489 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 23 Apr 2026 17:17:54 +0800 Subject: [PATCH 140/192] fix(vpto): align vrelu i32 support (#220) --- docs/isa/06-unary-vector-ops.md | 7 ++-- lib/PTO/IR/VPTO.cpp | 15 ++++++- lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 22 ++++++++++ test/basic/issue220_vrelu_i32_vpto_llvm.pto | 42 +++++++++++++++++++ test/basic/vrelu_verify_invalid.pto | 25 +++++++++++ .../11-vector-arithmetic-operations.md | 2 + .../docs/vpto_spec/vpto-spec-current.md | 7 ++-- tilelang-dsl/python/tilelang_dsl/semantic.py | 7 +++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 33 +++++++++++++++ 9 files changed, 151 insertions(+), 9 deletions(-) create mode 100644 test/basic/issue220_vrelu_i32_vpto_llvm.pto create mode 100644 test/basic/vrelu_verify_invalid.pto diff --git a/docs/isa/06-unary-vector-ops.md b/docs/isa/06-unary-vector-ops.md index 2706ac39b..8eff4fcec 100644 --- a/docs/isa/06-unary-vector-ops.md +++ b/docs/isa/06-unary-vector-ops.md @@ -126,7 +126,7 @@ for (int i = 0; i < N; i++) ### `pto.vrelu` - **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` -- **A5 types:** f16, f32 +- **A5 types:** si32, i32, f16, f32 ```c for (int i = 0; i < N; i++) @@ -135,8 +135,9 @@ for (int i = 0; i < N; i++) - **inputs:** `%input` is the source vector and `%mask` selects active lanes. - **outputs:** `%result` holds `max(input[i], 0)` per active lane. -- **constraints and limitations:** Only floating-point element types are legal - on the current A5 surface described here. +- **constraints and limitations:** Signed or signless 32-bit integer and + floating-point element types are legal on the current A5 surface described + here. --- diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 75e38c830..c9582c29d 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -2356,7 +2356,20 @@ LogicalResult VexpOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VlnOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VsqrtOp::verify() { return verifyUnaryVecOp(*this); } LogicalResult VnegOp::verify() { return verifyUnaryVecOp(*this); } -LogicalResult VreluOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VreluOp::verify() { + if (failed(verifyUnaryVecOp(*this))) + return failure(); + auto inputType = cast(getInput().getType()); + Type elemType = inputType.getElementType(); + if (auto intType = dyn_cast(elemType)) { + if (intType.getWidth() != 32 || intType.isUnsigned()) + return emitOpError("requires si32/i32/f16/f32 vector element type"); + return success(); + } + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires si32/i32/f16/f32 vector element type"); + return success(); +} LogicalResult VnotOp::verify() { return verifyUnaryVecOp(*this); } template diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp index 7f1a54a18..ce9f76cd3 100644 --- a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -600,6 +600,27 @@ class VPTOLegalityValidator { .Default([](Operation *) { return success(); }); } + static LogicalResult validateUnaryElementTypeContracts(Operation *op) { + return llvm::TypeSwitch(op) + .Case([](VreluOp concreteOp) { + auto vecType = dyn_cast(concreteOp.getInput().getType()); + if (!vecType) + return success(); + + Type elemType = vecType.getElementType(); + if (auto intType = dyn_cast(elemType)) { + if (intType.getWidth() == 32 && !intType.isUnsigned()) + return success(); + } else if (elemType.isF16() || elemType.isF32()) { + return success(); + } + + concreteOp.emitOpError("requires si32/i32/f16/f32 vector element type"); + return failure(); + }) + .Default([](Operation *) { return success(); }); + } + static LogicalResult validateMaskGranularityContracts(Operation *op) { return llvm::TypeSwitch(op) .Case, %dst: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %out = pto.vrelu %value, %mask : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + pto.vsts %out, %dst[%c0], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + return + } + + func.func @vrelu_signed_i32_store(%value: !pto.vreg<64xsi32>, %dst: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %out = pto.vrelu %value, %mask : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xsi32> + pto.vsts %out, %dst[%c0], %mask : !pto.vreg<64xsi32>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-COUNT-1: declare <64 x i32> @llvm.hivm.vrelu.v64s32.x(<64 x i32>, <256 x i1>) +// CHECK-LABEL: define void @vrelu_signless_i32_store( +// CHECK: %[[RELU0:[^ ]+]] = call <64 x i32> @llvm.hivm.vrelu.v64s32.x(<64 x i32> %0, <256 x i1> %[[MASK0:[^ ]+]]) +// CHECK: call void @llvm.hivm.vstsx1.v64s32(<64 x i32> %[[RELU0]], ptr addrspace(6) %1, i32 0, i32 2, i32 0, <256 x i1> %[[MASK0]]) +// CHECK-LABEL: define void @vrelu_signed_i32_store( +// CHECK: %[[RELU1:[^ ]+]] = call <64 x i32> @llvm.hivm.vrelu.v64s32.x(<64 x i32> %0, <256 x i1> %[[MASK1:[^ ]+]]) +// CHECK: call void @llvm.hivm.vstsx1.v64s32(<64 x i32> %[[RELU1]], ptr addrspace(6) %1, i32 0, i32 2, i32 0, <256 x i1> %[[MASK1]]) diff --git a/test/basic/vrelu_verify_invalid.pto b/test/basic/vrelu_verify_invalid.pto new file mode 100644 index 000000000..02e7af66c --- /dev/null +++ b/test/basic/vrelu_verify_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ! ptoas --pto-arch=a5 --pto-backend=vpto %s 2>&1 | FileCheck %s + +// Negative tests for pto.vrelu verifier. +// +// CHECK: error: 'pto.vrelu' op requires si32/i32/f16/f32 vector element type + +func.func @vrelu_bf16_invalid(%src: !pto.ptr, %dst: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec = pto.vlds %src[%c0] : !pto.ptr -> !pto.vreg<128xbf16> + %out = pto.vrelu %vec, %mask : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %out, %dst[%c0], %mask : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + return +} diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 7edf20936..513539b7e 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -89,6 +89,8 @@ abs_vec = pto.vabs(vec_f32, mask32) **Description**: ReLU activation (max(0, x)) of vector elements. +**Supported dtypes**: `si32`, `i32`, `f16`, `f32` + **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md index 620bb407b..a8c11d7a5 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md @@ -3326,7 +3326,7 @@ for (int i = 0; i < N; i++) ##### `pto.vrelu` - **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` -- **A5 types:** f16, f32 +- **A5 types:** si32, i32, f16, f32 ```c for (int i = 0; i < N; i++) @@ -3335,8 +3335,9 @@ for (int i = 0; i < N; i++) - **inputs:** `%input` is the source vector and `%mask` selects active lanes. - **outputs:** `%result` holds `max(input[i], 0)` per active lane. -- **constraints and limitations:** Only floating-point element types are legal - on the current A5 surface described here. +- **constraints and limitations:** Signed or signless 32-bit integer and + floating-point element types are legal on the current A5 surface described + here. --- diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 13d7ec73c..f0e90c5eb 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -6033,8 +6033,11 @@ def _vreg_type_for_dtype(self, dtype: ScalarType) -> SemanticVRegType: def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: if name in {"vexp", "vln", "vsqrt", "vrec", "vrsqrt"} and dtype.name not in {"f16", "f32"}: raise TypeError(f"pto.{name} only supports f16/f32 in TileLang DSL v1") - if name == "vrelu" and dtype.name not in {"f16", "f32"}: - raise TypeError("pto.vrelu only supports f16/f32 in TileLang DSL v1") + if name == "vrelu" and not ( + dtype.name in {"f16", "f32"} + or (is_integer_dtype(dtype) and integer_bitwidth(dtype) == 32) + ): + raise TypeError("pto.vrelu only supports i32/f16/f32 in TileLang DSL v1") if name in {"vnot", "vbcnt", "vcls", "vsunpack", "vzunpack", "vusqz", "vsqz"} and not ( is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} ): diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 3e8414d24..eac6467d3 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2917,6 +2917,39 @@ def kernel(tile: pto.Tile, scale: pto.f32): self.assertRegex(text, r"%activated_\d+ = pto\.vrelu %summed_\d+, %mask_\d+ : !pto\.vreg<64xf32>, !pto\.mask -> !pto\.vreg<64xf32>") self.assertRegex(text, r"pto\.vsts %activated_\d+, %dst_\d+\[%lane_\d+\], %mask_\d+ : !pto\.vreg<64xf32>, !pto\.ptr, !pto\.mask") + def test_vrelu_accepts_i32_inside_strict_vecscope(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.i32, pto.i32)], advanced=True) + def kernel(tile: pto.Tile, bias: pto.i32): + with pto.strict_vecscope(tile, tile, bias, 0, 256, 64) as ( + src, + dst, + offset, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + shifted = pto.vadds(vec, offset, mask) + activated = pto.vrelu(shifted, mask) + pto.vsts(activated, dst, lane, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r'%mask_\d+ = pto\.pset_b32 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r"%vec_\d+ = pto\.vlds %src_\d+\[%lane_\d+\] : !pto\.ptr -> !pto\.vreg<64xi32>") + self.assertRegex(text, r"%shifted_\d+ = pto\.vadds %vec_\d+, %offset_\d+, %mask_\d+ : !pto\.vreg<64xi32>, i32, !pto\.mask -> !pto\.vreg<64xi32>") + self.assertRegex(text, r"%activated_\d+ = pto\.vrelu %shifted_\d+, %mask_\d+ : !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>") + self.assertRegex(text, r"pto\.vsts %activated_\d+, %dst_\d+\[%lane_\d+\], %mask_\d+ : !pto\.vreg<64xi32>, !pto\.ptr, !pto\.mask") + def test_tail_make_mask_lowers_to_typed_plt_and_updates_remaining(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.i32)], advanced=True) def kernel(tile: pto.Tile, remaining: pto.i32): From d58d89c7bee728c76e575dc3ed9e5efafb516146 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 23 Apr 2026 18:11:26 +0800 Subject: [PATCH 141/192] Fix DSL frontend vecscope auto inference --- tilelang-dsl/python/tilelang_dsl/semantic.py | 167 ++++++------------ .../python/tilelang_dsl/support_matrix.py | 13 ++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 62 +++++++ 3 files changed, 125 insertions(+), 117 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index f0e90c5eb..0baa18ee2 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -44,6 +44,8 @@ ) from .support_matrix import ( DEFERRED_PTO_SURFACES, + INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS, + INFERRED_VECSCOPE_NEUTRAL_PTO_CALLS, advanced_mode_message, deferred_surface_message, unsupported_feature_message, @@ -297,44 +299,7 @@ def _is_supported_mov_pad_scalar_dtype(dtype: ScalarType) -> bool: return dtype.name in {"f16", "bf16", "f32"} -_COMPARE_SELECT_OPS = {"vcmp", "vcmps", "vsel", "vselr", "vselrv2"} -_PREDICATE_MOVEMENT_OPS = { - "pset_b8", - "pset_b16", - "pset_b32", - "pge_b8", - "pge_b16", - "pge_b32", - "plt_b8", - "plt_b16", - "plt_b32", - "plds", - "pld", - "pldi", - "psts", - "pst", - "psti", - "pstu", - "pnot", - "psel", - "pand", - "por", - "pxor", - "ppack", - "punpack", - "pdintlv_b8", - "pintlv_b16", -} -_CARRY_OPS = {"vaddc", "vsubc", "vaddcs", "vsubcs"} -_REARRANGEMENT_OPS = {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"} _UB_HELPER_OPS = {"vbitsort", "vmrgsort4"} -_ADVANCED_VECTOR_ACTIVITY_OPS = ( - _COMPARE_SELECT_OPS - | _PREDICATE_MOVEMENT_OPS - | _CARRY_OPS - | _REARRANGEMENT_OPS - | {"vcvt"} -) _TENSORVIEW_RANK = 5 @@ -1043,12 +1008,12 @@ def _analyze_block( semantic_statements = [] index = 0 while index < len(statements): - if self._stmt_can_participate_in_inferred_vecscope( + if self._stmt_can_start_inferred_vecscope_run( statements[index], allow_inferred_vecscope=allow_inferred_vecscope, ): end = index + 1 - while end < len(statements) and self._stmt_can_participate_in_inferred_vecscope( + while end < len(statements) and self._stmt_can_continue_inferred_vecscope_run( statements[end], allow_inferred_vecscope=allow_inferred_vecscope, ): @@ -1083,21 +1048,40 @@ def _analyze_block( index += 1 return tuple(semantic_statements), current_env - def _stmt_can_participate_in_inferred_vecscope( + def _stmt_can_start_inferred_vecscope_run( + self, + stmt: FrontendStmtNode, + *, + allow_inferred_vecscope: bool, + ) -> bool: + if not self._stmt_allows_inferred_vecscope(allow_inferred_vecscope): + return False + if self._frontend_stmt_is_vecscope_boundary(stmt): + return False + return self._frontend_stmt_can_live_in_inferred_vecscope(stmt) + + def _stmt_can_continue_inferred_vecscope_run( self, stmt: FrontendStmtNode, *, allow_inferred_vecscope: bool, ) -> bool: + if not self._stmt_allows_inferred_vecscope(allow_inferred_vecscope): + return False + if self._frontend_stmt_is_vecscope_boundary(stmt): + return False + return self._frontend_stmt_can_live_in_inferred_vecscope( + stmt + ) or self._frontend_stmt_is_neutral_vecscope_stmt(stmt) + + def _stmt_allows_inferred_vecscope(self, allow_inferred_vecscope: bool) -> bool: if self._has_explicit_vecscope: return False if self._disable_inference_depth > 0: return False if not allow_inferred_vecscope: return False - if self._frontend_stmt_is_vecscope_boundary(stmt): - return False - return self._frontend_stmt_can_live_in_inferred_vecscope(stmt) + return True def _analyze_stmt_or_inline( self, @@ -1160,36 +1144,7 @@ def _should_infer_vecscope( if isinstance(stmt, FrontendForStmt): return self._block_can_live_in_inferred_vecscope(stmt.body) name = self._frontend_vector_call_name(stmt) - return name in ( - { - "make_mask", - "init_align", - "vlds", - "vldas", - "vldus", - "plds", - "psts", - "pstu", - "vsst", - "vsta", - "vstas", - "vstar", - "vscatter", - "vsts", - "vstsx2", - "vstus", - "vstur", - } - | _UNARY_VECTOR_OPS - | _BINARY_VECTOR_OPS - | _VECTOR_SCALAR_OPS - | _VECTOR_IMMEDIATE_OPS - | _TERNARY_VECTOR_OPS - | _MULTI_RESULT_VECTOR_OPS - | _BROADCAST_VECTOR_OPS - | _ADVANCED_VECTOR_ACTIVITY_OPS - | _VEXPDIF_OP_ALIASES - ) + return name in INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS def _block_can_live_in_inferred_vecscope( self, @@ -1214,6 +1169,13 @@ def _frontend_stmt_is_vecscope_boundary(self, stmt: FrontendStmtNode) -> bool: return True if isinstance(stmt, FrontendIfStmt): return not stmt.is_constexpr + if ( + isinstance(stmt, FrontendExprStmt) + and isinstance(stmt.expr, FrontendCallExpr) + and stmt.expr.namespace == "pto" + and stmt.expr.name in INFERRED_VECSCOPE_NEUTRAL_PTO_CALLS + ): + return False return ( isinstance(stmt, FrontendExprStmt) and ( @@ -1244,12 +1206,20 @@ def _frontend_stmt_is_scalar_vecscope_stmt( stmt: FrontendStmtNode, ) -> bool: return isinstance(stmt, FrontendNoOpStmt) or isinstance(stmt, FrontendAssignStmt) or ( + self._frontend_stmt_is_neutral_vecscope_stmt(stmt) + ) or ( + isinstance(stmt, FrontendIfStmt) and stmt.is_constexpr + ) + + def _frontend_stmt_is_neutral_vecscope_stmt( + self, + stmt: FrontendStmtNode, + ) -> bool: + return ( isinstance(stmt, FrontendExprStmt) and isinstance(stmt.expr, FrontendCallExpr) and stmt.expr.namespace == "pto" - and stmt.expr.name == "store_scalar" - ) or ( - isinstance(stmt, FrontendIfStmt) and stmt.is_constexpr + and stmt.expr.name in INFERRED_VECSCOPE_NEUTRAL_PTO_CALLS ) def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> bool: @@ -1262,36 +1232,7 @@ def _frontend_stmt_contains_vector_activity(self, stmt: FrontendStmtNode) -> boo return False return ( expr.namespace == "pto" - and expr.name in ( - { - "make_mask", - "init_align", - "vlds", - "vldas", - "vldus", - "plds", - "psts", - "pstu", - "vsst", - "vsta", - "vstas", - "vstar", - "vscatter", - "vsts", - "vstsx2", - "vstus", - "vstur", - } - | _UNARY_VECTOR_OPS - | _BINARY_VECTOR_OPS - | _VECTOR_SCALAR_OPS - | _VECTOR_IMMEDIATE_OPS - | _TERNARY_VECTOR_OPS - | _MULTI_RESULT_VECTOR_OPS - | _BROADCAST_VECTOR_OPS - | _ADVANCED_VECTOR_ACTIVITY_OPS - | _VEXPDIF_OP_ALIASES - ) + and expr.name in INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS ) def _run_contains_vector_op(self, statements: tuple[FrontendStmtNode, ...]) -> bool: @@ -1404,17 +1345,9 @@ def _semantic_block_contains_vector_activity( def _expr_contains_vector_activity(self, expr: SemanticExpr) -> bool: if isinstance(expr, SemanticCallExpr): - if expr.namespace == "pto" and expr.name in ( - {"make_mask", "vlds"} - | _UNARY_VECTOR_OPS - | _BINARY_VECTOR_OPS - | _VECTOR_SCALAR_OPS - | _VECTOR_IMMEDIATE_OPS - | _TERNARY_VECTOR_OPS - | _MULTI_RESULT_VECTOR_OPS - | _BROADCAST_VECTOR_OPS - | _ADVANCED_VECTOR_ACTIVITY_OPS - | _VEXPDIF_OP_ALIASES + if ( + expr.namespace == "pto" + and expr.name in INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS ): return True return any(self._expr_contains_vector_activity(arg) for arg in expr.args) diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index bb4b311dd..f4aa6ab41 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -200,6 +200,17 @@ } ) +INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS = frozenset( + SUPPORTED_VECSCOPE_PTO_CALLS | ADVANCED_VECSCOPE_PTO_CALLS +) + +INFERRED_VECSCOPE_NEUTRAL_PTO_CALLS = frozenset( + { + "mem_bar", + "store_scalar", + } +) + ADVANCED_EXPR_PTO_CALLS = frozenset( { "ptr", @@ -440,6 +451,8 @@ def get_surface_group_tier(group_name: str) -> str: "ADVANCED_EXPR_PTO_CALLS", "ADVANCED_TOPLEVEL_PTO_CALLS", "ADVANCED_VECSCOPE_PTO_CALLS", + "INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS", + "INFERRED_VECSCOPE_NEUTRAL_PTO_CALLS", "SUPPORTED_TOPLEVEL_PTO_CALLS", "SUPPORTED_VECSCOPE_PTO_CALLS", "BASIC_TIER", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index eac6467d3..2256ed034 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -24,6 +24,10 @@ AUTHORING_TIER_SURFACE_GROUPS, BASIC_TIER, BASIC_TILE_INDEXING_SURFACES, + ADVANCED_VECSCOPE_PTO_CALLS, + INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS, + INFERRED_VECSCOPE_NEUTRAL_PTO_CALLS, + SUPPORTED_VECSCOPE_PTO_CALLS, get_feature_tier, get_surface_group_tier, ) @@ -499,6 +503,21 @@ def test_unsupported_features_do_not_report_legacy_tiers(self) -> None: with self.assertRaises(KeyError): get_feature_tier("pto.vreduce") + def test_inferred_vecscope_tables_follow_supported_vecscope_surfaces(self) -> None: + self.assertTrue( + SUPPORTED_VECSCOPE_PTO_CALLS.issubset( + INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS + ) + ) + self.assertTrue( + ADVANCED_VECSCOPE_PTO_CALLS.issubset( + INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS + ) + ) + self.assertIn("vbitcast", INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS) + self.assertIn("vselr", INFERRED_VECSCOPE_ACTIVITY_PTO_CALLS) + self.assertIn("mem_bar", INFERRED_VECSCOPE_NEUTRAL_PTO_CALLS) + class TileLangDSLMatcherEntryTests(unittest.TestCase): def test_select_kernel_returns_descriptor_from_default_registry(self) -> None: @@ -6100,6 +6119,49 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): self.assertIn(" = pto.vselrv2 ", text) self.assertIn("pto.vsts ", text) + def test_inferred_vecscope_keeps_vbitcast_and_mem_bar_with_vector_users(self) -> None: + @pto.vkernel(op="issue_217_vecscope", dtypes=[(pto.i32, pto.ui8)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + idx_mask = pto.make_mask(pto.i16, pto.PAT.ALL) + v_idx = pto.vci(pto.i8(0), pto.OrderMode.ASC) + v_idx_i16 = pto.vbitcast(v_idx, pto.i16) + v_idx_i16 = pto.vmuls(v_idx_i16, pto.i16(4), idx_mask) + v_idx_ui8 = pto.vbitcast(v_idx_i16, pto.ui8) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.i32)): + store_mask, remained = pto.make_mask(pto.ui8, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.ui8, + full_mask, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.P0, + ) + result = pto.vselr(converted, v_idx_ui8) + pto.mem_bar(pto.BarrierType.VST_VST) + pto.vsts(result, dst[row, col:], store_mask, dist="NORM_B8") + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertIn("pto.vbitcast", text) + self.assertIn('pto.mem_bar "VST_VST"', text) + self.assertIn("pto.vselr", text) + self.assertIn("pto.vsts", text) + def test_elementwise_kernel_positive_regression_covers_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): From b2e071b2088f1d20d9e4e748eb107798f83a1a44 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 23 Apr 2026 00:02:43 +0800 Subject: [PATCH 142/192] feat(tilelang-dsl): support constructor calls on static dtype bindings --- .../docs/user_guide/05-type-system.md | 9 +++ tilelang-dsl/python/tilelang_dsl/kernel.py | 66 +++++++++++++++++++ tilelang-dsl/python/tilelang_dsl/semantic.py | 55 ++++++++++++++-- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 39 +++++++++++ 4 files changed, 162 insertions(+), 7 deletions(-) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md index b68432daf..2b1e3772b 100644 --- a/tilelang-dsl/docs/user_guide/05-type-system.md +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -43,6 +43,15 @@ y: pto.i32 = 1024 # Type annotation z = pto.ui16(7) # Explicit unsigned 16-bit constant ``` +Static dtype bindings can also be called like constructors. This is useful when +the dtype comes from compile-time metadata such as `element_type`: + +```python +idx_dtype = tile.element_type +zero_idx = idx_dtype(0) +v_col = idx_dtype(col) +``` + Integer sign semantics are part of the DSL type surface. `pto.si16`, `pto.ui16`, and `pto.i16` are distinct scalar dtypes and lower to `si16`, `ui16`, and `i16` respectively in VPTO IR. diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 743188511..eb01241be 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -61,6 +61,27 @@ | ADVANCED_TOPLEVEL_PTO_CALLS ) +_DSL_DTYPE_NAMES = frozenset( + { + "i1", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + "f16", + "bf16", + "f32", + } +) + _INLINE_PROC_REGISTRY: dict[tuple[str, str], "InlineProcDescriptor"] = {} @@ -242,6 +263,7 @@ def __init__(self, source_info: _FunctionSourceInfo, *, advanced_enabled: bool, self.advanced_enabled = advanced_enabled self.module_name = module_name self._vecscope_depth = 0 + self._static_dtype_bindings: set[str] = set() def validate(self) -> None: for stmt in self.source_info.function_def.body: @@ -295,6 +317,22 @@ def visit_If(self, node: ast.If) -> None: for stmt in node.orelse: self.visit(stmt) + def visit_Assign(self, node: ast.Assign) -> None: + self.visit(node.value) + is_static_dtype = self._expr_is_static_dtype_expr(node.value) + for target in node.targets: + self._update_static_dtype_bindings(target, is_static_dtype=is_static_dtype) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + if node.value is not None: + self.visit(node.value) + is_static_dtype = node.value is not None and self._expr_is_static_dtype_expr(node.value) + self._update_static_dtype_bindings(node.target, is_static_dtype=is_static_dtype) + + def visit_AugAssign(self, node: ast.AugAssign) -> None: + self.visit(node.value) + self._update_static_dtype_bindings(node.target, is_static_dtype=False) + def visit_With(self, node: ast.With) -> None: if len(node.items) != 1: raise self.source_info.error(node, "only single with item is supported in TileLang DSL v1") @@ -484,6 +522,9 @@ def visit_Call(self, node: ast.Call) -> None: if node.func.id == "range": self._validate_call_keywords(node) return + if node.func.id in self._static_dtype_bindings: + self._validate_call_keywords(node) + return inline_proc = _find_inline_proc(node.func.id, module_name=self.module_name) if inline_proc is not None: _validate_inline_proc_call_surface(self.source_info, node, inline_proc) @@ -498,6 +539,31 @@ def visit_Call(self, node: ast.Call) -> None: "unsupported call surface in TileLang DSL v1", ) + def _expr_is_static_dtype_expr(self, node: ast.AST) -> bool: + if isinstance(node, ast.Name): + return node.id in self._static_dtype_bindings + if isinstance(node, ast.Attribute): + if ( + isinstance(node.value, ast.Name) + and node.value.id == "pto" + and node.attr in _DSL_DTYPE_NAMES + ): + return True + if node.attr == "element_type": + return True + return False + + def _update_static_dtype_bindings(self, target: ast.expr, *, is_static_dtype: bool) -> None: + if isinstance(target, ast.Name): + if is_static_dtype: + self._static_dtype_bindings.add(target.id) + else: + self._static_dtype_bindings.discard(target.id) + return + if isinstance(target, (ast.Tuple, ast.List)): + for element in target.elts: + self._update_static_dtype_bindings(element, is_static_dtype=False) + def _load_function_source_info(py_fn: Callable[..., Any]) -> _FunctionSourceInfo | None: try: diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 0baa18ee2..a6bfb530f 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3055,6 +3055,27 @@ def _analyze_expr( expr, ) if isinstance(expr, FrontendCallExpr): + if expr.namespace is None: + binding = env.get(expr.name) + if ( + binding is not None + and isinstance(binding.type, SemanticMetaType) + and binding.type.kind == "dtype" + and isinstance(binding.value, ScalarType) + ): + if expr.keywords: + raise TypeError( + f"`{expr.name}` does not support keyword arguments in TileLang DSL v1" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_scalar_constructor_for_dtype( + binding.value, + args, + surface_name=expr.name, + ) if expr.namespace is None and expr.name in self._inline_proc_nodes: if expr.keywords: raise TypeError( @@ -4116,18 +4137,30 @@ def _analyze_scalar_constructor( self, name: str, args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + return self._analyze_scalar_constructor_for_dtype( + _DTYPE_SYMBOLS[name], + args, + surface_name=f"pto.{name}", + ) + + def _analyze_scalar_constructor_for_dtype( + self, + target_dtype: ScalarType, + args: tuple[SemanticExpr, ...], + *, + surface_name: str, ) -> SemanticExpr: if len(args) != 1: - raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") + raise TypeError(f"{surface_name} expects exactly 1 positional argument in TileLang DSL v1") - target_dtype = _DTYPE_SYMBOLS[name] if ( target_dtype.name in {"f16", "bf16", "f32"} and isinstance(args[0], SemanticLiteralExpr) and isinstance(args[0].type, SemanticMetaType) and args[0].type.kind == "string" ): - parsed = self._parse_float_literal_string(args[0].value, target_dtype, f"pto.{name} value") + parsed = self._parse_float_literal_string(args[0].value, target_dtype, f"{surface_name} value") return SemanticLiteralExpr( value=parsed, type=SemanticScalarType(dtype=target_dtype), @@ -4138,13 +4171,17 @@ def _analyze_scalar_constructor( and isinstance(args[0].type, SemanticMetaType) and args[0].type.kind == "string" ): - parsed = self._parse_integer_literal_string(args[0].value, target_dtype, f"pto.{name} value") + parsed = self._parse_integer_literal_string( + args[0].value, + target_dtype, + f"{surface_name} value", + ) return SemanticLiteralExpr( value=parsed, type=SemanticScalarType(dtype=target_dtype), ) - value = self._require_scalar_or_index_expr(args[0], f"pto.{name} value") + value = self._require_scalar_or_index_expr(args[0], f"{surface_name} value") if isinstance(value.type, SemanticScalarType) and value.type.dtype == target_dtype: return value @@ -4166,7 +4203,11 @@ def _analyze_scalar_constructor( else: casted = None if casted is not None: - checked = self._check_integer_literal_range(casted, target_dtype, f"pto.{name} value") + checked = self._check_integer_literal_range( + casted, + target_dtype, + f"{surface_name} value", + ) return SemanticLiteralExpr(value=checked, type=SemanticScalarType(dtype=target_dtype)) else: if isinstance(literal_value, (bool, int, float)): @@ -4177,7 +4218,7 @@ def _analyze_scalar_constructor( return SemanticCallExpr( namespace="pto", - name=name, + name=target_dtype.name, args=(value,), type=SemanticScalarType(dtype=target_dtype), ) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 2256ed034..9688ea718 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -2446,6 +2446,45 @@ def kernel(tile: pto.Tile): self.assertIsInstance(scalar_assign.targets[0].type, SemanticScalarType) self.assertEqual(scalar_assign.targets[0].type.dtype, pto.f32) + def test_static_dtype_binding_supports_constructor_call_surface(self) -> None: + @pto.vkernel(op="static_dtype_binding_constructor_unique", dtypes=[(pto.i32,)]) + def kernel(tile: pto.Tile): + idx_dtype = tile.element_type + cols = tile.shape[1] + zero_idx = idx_dtype(0) + v_col = idx_dtype(cols) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + dtype_assign, cols_assign, zero_assign, cast_assign = semantic_kernel.body[:4] + + self.assertIsInstance(dtype_assign, SemanticAssignStmt) + self.assertIsInstance(dtype_assign.value, SemanticSymbolExpr) + self.assertEqual(dtype_assign.value.value, pto.i32) + + self.assertIsInstance(cols_assign, SemanticAssignStmt) + self.assertIsInstance(cols_assign.targets[0].type, SemanticIndexType) + + self.assertIsInstance(zero_assign, SemanticAssignStmt) + self.assertIsInstance(zero_assign.value, SemanticLiteralExpr) + self.assertEqual(zero_assign.value.value, 0) + self.assertIsInstance(zero_assign.targets[0].type, SemanticScalarType) + self.assertEqual(zero_assign.targets[0].type.dtype, pto.i32) + + self.assertIsInstance(cast_assign, SemanticAssignStmt) + self.assertIsInstance(cast_assign.value, SemanticCallExpr) + self.assertEqual(cast_assign.value.namespace, "pto") + self.assertEqual(cast_assign.value.name, "i32") + self.assertIsInstance(cast_assign.targets[0].type, SemanticScalarType) + self.assertEqual(cast_assign.targets[0].type.dtype, pto.i32) + def test_unsigned_integer_constants_lower_with_signless_arith_types(self) -> None: @pto.vkernel(op="tile_pad_value_ui32_max_eval_unique", dtypes=[(pto.ui32,)]) def kernel(tile: pto.Tile): From dd3759b48023c917b426205b6a9501318fcb40fc Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 23 Apr 2026 19:42:25 +0800 Subject: [PATCH 143/192] fix(dsl): fix DSL frontend punpack lowering --- lib/PTO/IR/VPTO.cpp | 15 ++++++++ lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 26 +++++++++++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 13 ++++++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 37 ++++++++++++++++++++ 4 files changed, 89 insertions(+), 2 deletions(-) diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index c9582c29d..c92ecf0b3 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -87,6 +87,12 @@ static LogicalResult verifyMaskTypeWithGranularityLike(Operation *op, Type type, return success(); } +static bool isMaskGranularityAdjacentWidening(StringRef inputGranularity, + StringRef resultGranularity) { + return (inputGranularity == "b8" && resultGranularity == "b16") || + (inputGranularity == "b16" && resultGranularity == "b32"); +} + static LogicalResult verifyEnclosingLoopLike(Operation *op, StringRef opNameForDiag) { if (!op->getParentOfType()) { @@ -2130,6 +2136,15 @@ LogicalResult PunpackOp::verify() { return failure(); if (getPart() != "LOWER") return emitOpError("currently supports only LOWER part"); + auto inputMaskType = cast(getInput().getType()); + auto resultMaskType = cast(getResult().getType()); + StringRef inputGranularity = inputMaskType.getGranularity(); + StringRef resultGranularity = resultMaskType.getGranularity(); + if (inputGranularity != resultGranularity && + !isMaskGranularityAdjacentWidening(inputGranularity, resultGranularity)) { + return emitOpError( + "requires result mask granularity to match the input or widen by one step"); + } return success(); } diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp index ce9f76cd3..172cade39 100644 --- a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -389,6 +389,27 @@ class VPTOLegalityValidator { << rhsRole << " " << rhsType; } + static bool isAdjacentMaskGranularityWidening(VPTOMaskGranularity input, + VPTOMaskGranularity result) { + return (input == VPTOMaskGranularity::B8 && + result == VPTOMaskGranularity::B16) || + (input == VPTOMaskGranularity::B16 && + result == VPTOMaskGranularity::B32); + } + + static LogicalResult validatePunpackMaskGranularity(PunpackOp op) { + auto input = VPTOLegalityHelper::getMaskGranularity(op.getInput().getType()); + auto result = VPTOLegalityHelper::getMaskGranularity(op.getResult().getType()); + if (!input || !result || *input == *result || + isAdjacentMaskGranularityWidening(*input, *result)) + return success(); + + return op.emitOpError() + << "input mask type " << op.getInput().getType() + << " does not match result mask type " << op.getResult().getType() + << " for pto.punpack"; + } + template static LogicalResult validateInputMaskVectorConsumer(OpTy op) { return validateMaskMatchesVectorFamily(op, op.getMask().getType(), @@ -643,9 +664,12 @@ class VPTOLegalityValidator { return validateCompareFamilyContract(concreteOp, concreteOp.getSrc().getType()); }) - .Case([](auto concreteOp) { + .Case([](PpackOp concreteOp) { return validateMaskOnlyUnaryContract(concreteOp); }) + .Case([](PunpackOp concreteOp) { + return validatePunpackMaskGranularity(concreteOp); + }) .Case( [](PnotOp concreteOp) { return validateMaskOnlyPnotContract(concreteOp); }) .Case( diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index a6bfb530f..96b15f0c2 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -4851,7 +4851,18 @@ def _analyze_mask_part_op( raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") mask = self._require_mask_expr(args[0], f"pto.{name} mask") part = self._normalize_predicate_part(args[1], f"pto.{name} part") - return SemanticCallExpr(namespace="pto", name=name, args=(args[0], part), type=mask) + result_granularity = mask.granularity + if name == "punpack": + if mask.granularity == "b8": + result_granularity = "b16" + elif mask.granularity == "b16": + result_granularity = "b32" + return SemanticCallExpr( + namespace="pto", + name=name, + args=(args[0], part), + type=SemanticMaskType(granularity=result_granularity), + ) def _analyze_mask_logic_op( self, diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 9688ea718..e44c1f2bf 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -6183,6 +6183,7 @@ def kernel(src: pto.Tile, dst: pto.Tile): result = pto.vselr(converted, v_idx_ui8) pto.mem_bar(pto.BarrierType.VST_VST) pto.vsts(result, dst[row, col:], store_mask, dist="NORM_B8") + return None specialized = kernel.specialize( @@ -6201,6 +6202,42 @@ def kernel(src: pto.Tile, dst: pto.Tile): self.assertIn("pto.vselr", text) self.assertIn("pto.vsts", text) + def test_punpack_widens_b16_mask_for_norm_b32_store_in_advanced_mode(self) -> None: + @pto.vkernel(op="punpack_widen_b16_to_b32_unique", dtypes=[(pto.si8, pto.i32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + lanes_i32 = pto.get_lanes(pto.i32) + for row in range(0, valid_rows, 1): + b8_mask = pto.make_mask(pto.i8, pto.PAT.ALL) + mask_b16, _ = pto.make_mask(pto.i16, valid_cols) + mask_b32 = pto.punpack(mask_b16, pto.PredicatePart.LOWER) + vec_si8 = pto.vlds(src[row, 0:], dist="UNPK_B8") + vec_ui8 = pto.vbitcast(vec_si8, pto.ui8) + v_zero_i8 = pto.vdup(pto.i8(0), b8_mask) + v_zero = pto.vbitcast(v_zero_i8, pto.ui8) + wide_lo, _ = pto.vintlv(vec_ui8, v_zero) + narrowed = pto.vbitcast(wide_lo, pto.si8) + converted = pto.vcvt(narrowed, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) + pto.vsts(converted, dst[row, 0:], mask_b32, dist="NORM_B32") + pto.vsts(converted, dst[row, lanes_i32:], mask_b32, dist="NORM_B32") + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn(' = pto.punpack ', text) + self.assertRegex( + text, + r"pto\.punpack %mask_b16_\d+, \"LOWER\" : !pto\.mask -> !pto\.mask", + ) + self.assertRegex( + text, + r"pto\.vsts %converted_\d+, %tmp_\d+\[%c0\], %mask_b32_\d+ \{dist = \"NORM_B32\"\} : !pto\.vreg<64xi32>, memref<\?x\?xi32, strided<\[\?, \?\], offset: \?>, #pto\.address_space>, !pto\.mask", + ) + def test_elementwise_kernel_positive_regression_covers_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): From 7ea5b6f0f528a7d025c498e82cd58ec72a385a68 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Thu, 23 Apr 2026 22:45:04 +0800 Subject: [PATCH 144/192] bugfix: vbr/cmps with vreg accepts i16 as scalar operand --- lib/PTO/IR/VPTO.cpp | 27 ++- lib/PTO/Transforms/PTOToVPTOLowering.cpp | 37 +++- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 51 ++++- .../vcmps-i16-unsigned/kernel.pto | 5 +- .../compare-select/vcmps-i8-signed/compare.py | 45 ++++ .../compare-select/vcmps-i8-signed/golden.py | 53 +++++ .../compare-select/vcmps-i8-signed/kernel.pto | 45 ++++ .../compare-select/vcmps-i8-signed/launch.cpp | 48 +++++ .../compare-select/vcmps-i8-signed/main.cpp | 92 ++++++++ .../compare-select/vcmps-i8-signed/stub.cpp | 22 ++ .../vcmps-i8-unsigned/compare.py | 45 ++++ .../vcmps-i8-unsigned/golden.py | 53 +++++ .../vcmps-i8-unsigned/kernel.pto | 45 ++++ .../vcmps-i8-unsigned/launch.cpp | 48 +++++ .../compare-select/vcmps-i8-unsigned/main.cpp | 92 ++++++++ .../compare-select/vcmps-i8-unsigned/stub.cpp | 22 ++ .../vbr-i8/compare.py | 204 ++++++++++++++++++ .../vbr-i8/golden.py | 45 ++++ .../vbr-i8/kernel.pto | 49 +++++ .../vbr-i8/launch.cpp | 61 ++++++ .../materialization-predicate/vbr-i8/main.cpp | 111 ++++++++++ .../materialization-predicate/vbr-i8/stub.cpp | 23 ++ .../vbr-u8/compare.py | 204 ++++++++++++++++++ .../vbr-u8/golden.py | 45 ++++ .../vbr-u8/kernel.pto | 49 +++++ .../vbr-u8/launch.cpp | 61 ++++++ .../materialization-predicate/vbr-u8/main.cpp | 111 ++++++++++ .../materialization-predicate/vbr-u8/stub.cpp | 23 ++ .../vdup-scalar-i8/kernel.pto | 10 +- .../vdup-scalar-u8/compare.py | 54 +++++ .../vdup-scalar-u8/golden.py | 45 ++++ .../vdup-scalar-u8/kernel.pto | 36 ++++ .../vdup-scalar-u8/launch.cpp | 43 ++++ .../vdup-scalar-u8/main.cpp | 101 +++++++++ .../vdup-scalar-u8/stub.cpp | 20 ++ 35 files changed, 2003 insertions(+), 22 deletions(-) create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/stub.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/compare.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/golden.py create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/launch.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/main.cpp create mode 100644 test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/stub.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i8/compare.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i8/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i8/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i8/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-i8/stub.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-u8/compare.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-u8/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-u8/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-u8/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vbr-u8/stub.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/compare.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/golden.py create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/launch.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/main.cpp create mode 100644 test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/stub.cpp diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index c92ecf0b3..7cbfcdf4d 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1540,6 +1540,23 @@ LogicalResult SetMovPadValOp::verify() { << valueType; } +static bool isCompatibleScalarForSemanticType(Type semanticType, + Type scalarType) { + if (semanticType == scalarType) + return true; + + auto semanticInt = dyn_cast(semanticType); + auto scalarInt = dyn_cast(scalarType); + if (!semanticInt || !scalarInt || semanticInt.getWidth() != scalarInt.getWidth()) + return false; + + if (semanticInt.isSigned()) + return scalarInt.isSigned() || scalarInt.isSignless(); + if (semanticInt.isUnsigned()) + return scalarInt.isUnsigned() || scalarInt.isSignless(); + return scalarInt.isSignless(); +} + LogicalResult VbrOp::verify() { if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) return failure(); @@ -1548,7 +1565,8 @@ LogicalResult VbrOp::verify() { Type elementType = getValue().getType(); if (isa(elementType)) return emitOpError("value must be a scalar matching the result element type"); - if (elementType != resultVecType.getElementType()) + Type resultElementType = resultVecType.getElementType(); + if (!isCompatibleScalarForSemanticType(resultElementType, elementType)) return emitOpError("value type must match result element type"); return success(); } @@ -1954,7 +1972,8 @@ LogicalResult VdupOp::verify() { if (getPosition()) return emitOpError("position is only supported for vector input"); - if (inputType != resultType.getElementType()) + Type resultElementType = resultType.getElementType(); + if (!isCompatibleScalarForSemanticType(resultElementType, inputType)) return emitOpError("scalar input must match result element type"); return success(); @@ -2633,7 +2652,9 @@ LogicalResult VcmpsOp::verify() { failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) return failure(); auto srcType = cast(getSrc().getType()); - if (getScalar().getType() != srcType.getElementType()) + Type srcElementType = srcType.getElementType(); + Type scalarType = getScalar().getType(); + if (!isCompatibleScalarForSemanticType(srcElementType, scalarType)) return emitOpError("requires scalar type to match source element type"); if (!isSupportedCmpMode(getCmpMode())) return emitOpError("requires cmp_mode to be one of eq/ne/lt/le/gt/ge"); diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp index 23c673505..cdb2f711d 100644 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -1728,6 +1728,23 @@ VPTOUnaryContract buildUnaryContract(StringRef family, Value src) { return contract; } +static bool isCompatibleScalarForSemanticType(Type semanticType, + Type scalarType) { + if (semanticType == scalarType) + return true; + + auto semanticInt = dyn_cast(semanticType); + auto scalarInt = dyn_cast(scalarType); + if (!semanticInt || !scalarInt || semanticInt.getWidth() != scalarInt.getWidth()) + return false; + + if (semanticInt.isSigned()) + return scalarInt.isSigned() || scalarInt.isSignless(); + if (semanticInt.isUnsigned()) + return scalarInt.isUnsigned() || scalarInt.isSignless(); + return scalarInt.isSignless(); +} + VPTOUnaryContract extractTExpContract(TExpOp op) { return buildUnaryContract("exp", op.getSrc()); } @@ -5050,7 +5067,8 @@ LogicalResult lowerTCmpS(TCmpSOp op, PatternRewriter &rewriter) { auto dstElemType = dyn_cast_or_null(getElementType(op.getDst())); if (!dstElemType || !dstElemType.isUnsignedInteger(8)) return op.emitOpError("tcmps lowering currently requires ui8 destination tiles"); - if (op.getScalar().getType() != contract.elementType) + if (!isCompatibleScalarForSemanticType(contract.elementType, + op.getScalar().getType())) return op.emitOpError("tcmps lowering requires scalar type to match source element type"); Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, @@ -5491,7 +5509,7 @@ LogicalResult lowerTExpandS(TExpandsOp op, PatternRewriter &rewriter) { return op.emitOpError("expands lowering requires a concrete element type"); Type scalarType = op.getScalar().getType(); - if (scalarType != contract.elementType) + if (!isCompatibleScalarForSemanticType(contract.elementType, scalarType)) return op.emitOpError("expands lowering requires scalar type to match destination element type"); if (!(contract.elementType.isF16() || contract.elementType.isF32() || @@ -6212,7 +6230,8 @@ LogicalResult lowerTMulS(TMulSOp op, PatternRewriter &rewriter, }, "f16, f32, and 16/32-bit integer element types"))) return failure(); - if (op.getScalar().getType() != contract.elementType) + if (!isCompatibleScalarForSemanticType(contract.elementType, + op.getScalar().getType())) return op.emitOpError("tmuls lowering requires scalar type to match source element type"); return buildScalarUnaryVecScope("muls", contract, strategy, op.getSrc0(), op.getScalar(), op.getDst(), rewriter, op.getLoc()); @@ -6539,7 +6558,8 @@ LogicalResult lowerTDivS(TDivSOp op, PatternRewriter &rewriter, [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) return failure(); - if (scalarOperand.getType() != contract.elementType) + if (!isCompatibleScalarForSemanticType(contract.elementType, + scalarOperand.getType())) return op.emitOpError( "divs lowering requires scalar type to match source element type"); return buildScalarDivVecScope(contract, strategy, tileOperand, scalarOperand, op.getDst(), @@ -6560,7 +6580,8 @@ LogicalResult lowerTAddS(TAddSOp op, PatternRewriter &rewriter, }, "f16, f32, bf16, and 16/32-bit integer element types"))) return failure(); - if (op.getScalar().getType() != contract.elementType) + if (!isCompatibleScalarForSemanticType(contract.elementType, + op.getScalar().getType())) return op.emitOpError("tadds lowering requires scalar type to match source element type"); return buildScalarUnaryVecScope("adds", contract, strategy, op.getSrc(), op.getScalar(), op.getDst(), rewriter, op.getLoc()); @@ -6670,7 +6691,8 @@ LogicalResult lowerTMaxS(TMaxSOp op, PatternRewriter &rewriter, op, contract, op.getDst(), [](Type type) { return type.isF32(); }, "f32 element type"))) return failure(); - if (op.getScalar().getType() != contract.elementType) + if (!isCompatibleScalarForSemanticType(contract.elementType, + op.getScalar().getType())) return op.emitOpError("tmaxs lowering requires scalar type to match source element type"); return buildScalarUnaryVecScope("maxs", contract, strategy, op.getSrc(), op.getScalar(), op.getDst(), rewriter, op.getLoc()); @@ -6683,7 +6705,8 @@ LogicalResult lowerTMinS(TMinSOp op, PatternRewriter &rewriter, op, contract, op.getDst(), [](Type type) { return type.isF32(); }, "f32 element type"))) return failure(); - if (op.getScalar().getType() != contract.elementType) + if (!isCompatibleScalarForSemanticType(contract.elementType, + op.getScalar().getType())) return op.emitOpError("tmins lowering requires scalar type to match source element type"); return buildScalarUnaryVecScope("mins", contract, strategy, op.getSrc(), op.getScalar(), op.getDst(), rewriter, op.getLoc()); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index e505c17d7..eccabbc0c 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -278,6 +278,37 @@ static FailureOr normalizeVdupScalarOperand(OpBuilder &builder, Location return builder.create(loc, i16Type, input).getResult(); } +static Value normalizeByteScalarOperandForHivmCall(OpBuilder &builder, Location loc, + Value input, + Type semanticElementType) { + auto intType = dyn_cast(input.getType()); + if (!intType || intType.getWidth() != 8) + return input; + + Type i16Type = builder.getIntegerType(16); + auto semanticIntType = dyn_cast(semanticElementType); + if (semanticIntType && semanticIntType.isUnsigned()) + return builder.create(loc, i16Type, input).getResult(); + return builder.create(loc, i16Type, input).getResult(); +} + +static bool isCompatibleScalarForSemanticType(Type semanticType, + Type scalarType) { + if (semanticType == scalarType) + return true; + + auto semanticInt = dyn_cast(semanticType); + auto scalarInt = dyn_cast(scalarType); + if (!semanticInt || !scalarInt || semanticInt.getWidth() != scalarInt.getWidth()) + return false; + + if (semanticInt.isSigned()) + return scalarInt.isSigned() || scalarInt.isSignless(); + if (semanticInt.isUnsigned()) + return scalarInt.isUnsigned() || scalarInt.isSignless(); + return scalarInt.isSignless(); +} + static std::string getCopyElementFragment(Type elementType) { if (!elementType) return {}; @@ -1178,8 +1209,9 @@ static FailureOr buildVdupCallee(MLIRContext *context, pto::VdupOp op .getValue(); } -static FailureOr buildVbrCallee(MLIRContext *context, Type scalarType) { - std::string scalar = getVbrScalarFragment(scalarType); +static FailureOr buildVbrCallee(MLIRContext *context, + Type semanticElementType) { + std::string scalar = getVbrScalarFragment(semanticElementType); if (scalar.empty()) return failure(); return StringAttr::get(context, "llvm.hivm.vbr." + scalar + ".v300").getValue(); @@ -2563,7 +2595,10 @@ class LowerVdupOpPattern final : public OpConversionPattern { callArgs.push_back(input); } else { Type scalarType = getElementTypeFromVectorLike(op.getResult().getType()); - if (!scalarType || op.getInput().getType() != scalarType) { + if (!scalarType || + (op.getInput().getType() != scalarType && + !isCompatibleScalarForSemanticType(scalarType, + op.getInput().getType()))) { return rewriter.notifyMatchFailure(op, "unexpected scalar-input vdup type"); } @@ -2600,7 +2635,8 @@ class LowerVbrOpPattern final : public OpConversionPattern { matchAndRewrite(pto::VbrOp op, pto::VbrOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { FailureOr calleeName = - buildVbrCallee(op.getContext(), op.getValue().getType()); + buildVbrCallee(op.getContext(), + cast(op.getResult().getType()).getElementType()); if (failed(calleeName)) return rewriter.notifyMatchFailure(op, "unsupported vbr VPTO signature"); @@ -2615,6 +2651,10 @@ class LowerVbrOpPattern final : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "unexpected converted vbr operand type"); + scalar = normalizeByteScalarOperandForHivmCall( + rewriter, op.getLoc(), scalar, + cast(op.getResult().getType()).getElementType()); + auto funcType = rewriter.getFunctionType(TypeRange{scalar.getType()}, TypeRange{resultType}); auto call = rewriter.create(op.getLoc(), *calleeName, @@ -3035,6 +3075,9 @@ class LowerCmpOpPattern final : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "unexpected converted scalar-compare operand types"); } + callArgs[1] = normalizeByteScalarOperandForHivmCall( + rewriter, op.getLoc(), callArgs[1], + cast(op.getSrc().getType()).getElementType()); } else { if (callArgs.size() != 3 || !callArgs[0] || !callArgs[1] || !callArgs[2] || callArgs[0].getType() != callArgs[1].getType() || diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto index c95b45413..4df6d097d 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto @@ -8,8 +8,7 @@ module attributes {pto.target_arch = "a5"} { %c64_i64 = arith.constant 64 : i64 %c256_i64 = arith.constant 256 : i64 %c128_i32 = arith.constant 128 : i32 - %threshold_i16 = arith.constant 513 : i16 - %threshold = builtin.unrealized_conversion_cast %threshold_i16 : i16 to ui16 + %threshold = arith.constant 513 : i16 %false = arith.constant false %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr @@ -27,7 +26,7 @@ module attributes {pto.target_arch = "a5"} { %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<128xui16> - %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.mask + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.mask pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index scf.yield %next : i32 } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/golden.py new file mode 100644 index 000000000..9b26d1c00 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 256 +SEED = 19 +THRESHOLD = np.int8(5) +OUTPUT_BYTES = 32 + + +def encode_b8_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 8 + bit_shift = i % 8 + out[byte_index] |= np.uint8(1 << bit_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-128, 127, size=(LANES,), dtype=np.int8) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b8_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-i8-signed.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto new file mode 100644 index 000000000..be566c5ec --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmps_i8_signed_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c256_i32 = arith.constant 256 : i32 + %threshold = arith.constant 5 : i8 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c256_i32) -> (i32) { + %active, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<256xsi8> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<256xsi8>, i8, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/launch.cpp new file mode 100644 index 000000000..3423328de --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_i8_signed_kernel_2d(__gm__ int8_t *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_i8_signed_kernel_2d(int8_t *v1, unsigned char *v2, void *stream) { + vcmps_i8_signed_kernel_2d<<<1, nullptr, stream>>>((__gm__ int8_t *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/main.cpp new file mode 100644 index 000000000..bb2a47b15 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_i8_signed_kernel_2d(int8_t *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 256; + size_t fileSize_v1 = elemCount_v1 * sizeof(int8_t); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + int8_t *v1Host = nullptr; + int8_t *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_i8_signed_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/stub.cpp new file mode 100644 index 000000000..5cca51101 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmps_i8_signed_kernel_2d(__gm__ int8_t *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/golden.py new file mode 100644 index 000000000..334f772c7 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 256 +SEED = 19 +THRESHOLD = np.uint8(129) +OUTPUT_BYTES = 32 + + +def encode_b8_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 8 + bit_shift = i % 8 + out[byte_index] |= np.uint8(1 << bit_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 255, size=(LANES,), dtype=np.uint8) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b8_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-i8-unsigned.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto new file mode 100644 index 000000000..a051e39a6 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vcmps_i8_unsigned_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c256_i32 = arith.constant 256 : i32 + %threshold = arith.constant -127 : i8 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c256_i32) -> (i32) { + %active, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<256xui8> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<256xui8>, i8, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/launch.cpp new file mode 100644 index 000000000..ce3b2eca9 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_i8_unsigned_kernel_2d(__gm__ uint8_t *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_i8_unsigned_kernel_2d(uint8_t *v1, unsigned char *v2, void *stream) { + vcmps_i8_unsigned_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint8_t *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/main.cpp new file mode 100644 index 000000000..f24020b81 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_i8_unsigned_kernel_2d(uint8_t *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 256; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint8_t); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + uint8_t *v1Host = nullptr; + uint8_t *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_i8_unsigned_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/stub.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/stub.cpp new file mode 100644 index 000000000..9edba2229 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/stub.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcmps_i8_unsigned_kernel_2d(__gm__ uint8_t *v1, + __gm__ unsigned char *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/compare.py new file mode 100644 index 000000000..ceb7196aa --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.int8, 0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/golden.py new file mode 100644 index 000000000..ea6c7ad81 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.int8(-7) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.int8) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.int8) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vbr-i8 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto new file mode 100644 index 000000000..80b6d376d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto @@ -0,0 +1,49 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vbr_i8_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -7 : i8 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c256 { + %vec = pto.vbr %cst : i8 -> !pto.vreg<256xsi8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xsi8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/launch.cpp new file mode 100644 index 000000000..d3b16ce39 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbr_i8_kernel_2d(__gm__ int8_t *v1); + +void LaunchVbr_i8_kernel_2d(int8_t *v1, void *stream) { + vbr_i8_kernel_2d<<<1, nullptr, stream>>>((__gm__ int8_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/main.cpp new file mode 100644 index 000000000..e9b756b4d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbr_i8_kernel_2d(int8_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int8_t); + int8_t *v1Host = nullptr; + int8_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVbr_i8_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/stub.cpp new file mode 100644 index 000000000..c71834af8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vbr_i8_kernel_2d(__gm__ int8_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/compare.py new file mode 100644 index 000000000..207bac38b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.uint8, 0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/golden.py new file mode 100644 index 000000000..41a4e75f9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.uint8(201) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.uint8) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vbr-u8 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto new file mode 100644 index 000000000..23da4ecce --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto @@ -0,0 +1,49 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5"} { + func.func @vbr_u8_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -55 : i8 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c256 { + %vec = pto.vbr %cst : i8 -> !pto.vreg<256xui8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/launch.cpp new file mode 100644 index 000000000..20b8a37c7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbr_u8_kernel_2d(__gm__ uint8_t *v1); + +void LaunchVbr_u8_kernel_2d(uint8_t *v1, void *stream) { + vbr_u8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint8_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/main.cpp new file mode 100644 index 000000000..53637c15b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbr_u8_kernel_2d(uint8_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint8_t); + uint8_t *v1Host = nullptr; + uint8_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVbr_u8_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/stub.cpp new file mode 100644 index 000000000..44bab5746 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +// The runtime launcher resolves the real device implementation from the +// embedded aibinary. The host-side fatobj still needs a concrete kernel symbol +// with the final ABI name, but it does not need the original EmitC body. +extern "C" __global__ [aicore] void vbr_u8_kernel_2d(__gm__ uint8_t *v1) { + (void)v1; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto index 3cf27b45b..83729c53e 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto @@ -5,7 +5,7 @@ // scenarios: core-i8, scalar-broadcast // ----------------------------------------------------------------------------- module attributes {pto.target_arch = "a5"} { - func.func @vdup_scalar_i8_kernel_2d(%arg0: !pto.ptr) { + func.func @vdup_scalar_i8_kernel_2d(%arg0: !pto.ptr) { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index %c1024 = arith.constant 1024 : index @@ -15,13 +15,13 @@ module attributes {pto.target_arch = "a5"} { %c128_i64 = arith.constant 128 : i64 %cst = arith.constant -83 : i8 - %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr pto.vecscope { %active = pto.pset_b8 "PAT_ALL" : !pto.mask scf.for %offset = %c0 to %c1024 step %c128 { - %vec = pto.vdup %cst, %active : i8, !pto.mask -> !pto.vreg<256xi8> - pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xi8>, !pto.ptr, !pto.mask + %vec = pto.vdup %cst, %active : i8, !pto.mask -> !pto.vreg<256xsi8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xsi8>, !pto.ptr, !pto.mask } } @@ -29,7 +29,7 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/compare.py new file mode 100644 index 000000000..54831ce84 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden[idx])}, out={int(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v1.bin", "v1.bin", np.uint8) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/golden.py new file mode 100644 index 000000000..20eed0a35 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.uint8(173) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.uint8) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vdup-scalar-u8 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto new file mode 100644 index 000000000..9dc6debf7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-scalar-u8 +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-u8, scalar-broadcast-signless +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vdup_scalar_u8_kernel_2d(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -83 : i8 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vdup %cst, %active : i8, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/launch.cpp new file mode 100644 index 000000000..d25c635c9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_scalar_u8_kernel_2d(__gm__ uint8_t *v1); + +void LaunchVdup_scalar_u8_kernel_2d(uint8_t *v1, void *stream) { + vdup_scalar_u8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint8_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/main.cpp new file mode 100644 index 000000000..32dfff59a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_scalar_u8_kernel_2d(uint8_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint8_t); + uint8_t *v1Host = nullptr; + uint8_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdup_scalar_u8_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/stub.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/stub.cpp new file mode 100644 index 000000000..cb150965d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/stub.cpp @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vdup_scalar_u8_kernel_2d(__gm__ uint8_t *v1) { + (void)v1; +} From c1d7dcea9d6961ffcb7ec86948ff980fa14bf559 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Fri, 24 Apr 2026 00:17:04 +0800 Subject: [PATCH 145/192] feat: vci support si8/si16/si32/f16/f32 --- docs/isa/09-conversion-ops.md | 18 +++-- docs/isa/13-dsa-sfu-ops.md | 24 +++--- include/PTO/IR/VPTOOps.td | 2 +- lib/PTO/IR/VPTO.cpp | 15 ++-- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 21 ++++- .../cases/micro-op/dsa-sfu/vci-f16/compare.py | 42 ++++++++++ .../cases/micro-op/dsa-sfu/vci-f16/golden.py | 49 ++++++++++++ .../cases/micro-op/dsa-sfu/vci-f16/kernel.pto | 26 ++++++ .../cases/micro-op/dsa-sfu/vci-f16/launch.cpp | 44 ++++++++++ .../cases/micro-op/dsa-sfu/vci-f16/main.cpp | 79 ++++++++++++++++++ .../cases/micro-op/dsa-sfu/vci-f16/stub.cpp | 21 +++++ .../cases/micro-op/dsa-sfu/vci-si8/compare.py | 42 ++++++++++ .../cases/micro-op/dsa-sfu/vci-si8/golden.py | 49 ++++++++++++ .../cases/micro-op/dsa-sfu/vci-si8/kernel.pto | 32 ++++++++ .../cases/micro-op/dsa-sfu/vci-si8/launch.cpp | 44 ++++++++++ .../cases/micro-op/dsa-sfu/vci-si8/main.cpp | 80 +++++++++++++++++++ .../cases/micro-op/dsa-sfu/vci-si8/stub.cpp | 23 ++++++ .../cases/micro-op/dsa-sfu/vci/kernel.pto | 34 ++------ .../cases/micro-op/dsa-sfu/vci/launch.cpp | 18 ++--- test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp | 27 ++----- test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp | 13 +-- 21 files changed, 607 insertions(+), 96 deletions(-) create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vci-f16/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vci-f16/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci-f16/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci-f16/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci-f16/stub.cpp create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vci-si8/compare.py create mode 100755 test/vpto/cases/micro-op/dsa-sfu/vci-si8/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci-si8/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci-si8/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vci-si8/stub.cpp diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md index 9ea2f5b78..090921a86 100644 --- a/docs/isa/09-conversion-ops.md +++ b/docs/isa/09-conversion-ops.md @@ -29,17 +29,21 @@ Cycle-accurate simulator **popped→retire** latency (cycles). Only representati ## `pto.vci` -- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : integer -> !pto.vreg` -- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar base value. - **inputs:** - `%index` is the scalar seed or base index. + `%index` is the scalar base value. Supported scalar types are `i8/i16/i32`, + `f16`, and `f32`. - **outputs:** `%result` is the generated index vector. - **constraints and limitations:** - This is an index-generation family, not a numeric conversion. `order` and the - result element type together determine how indices are generated. `%result` - uses an integer element type, and the scalar `%index` type matches that - result element type. + This is an index-generation family, not a numeric conversion. `order` and + the result element type together determine whether lanes are generated as + `base + lane_id` or `base - lane_id`. Supported result types are + `!pto.vreg<256xsi8>`, `!pto.vreg<128xsi16>`, `!pto.vreg<64xsi32>`, + `!pto.vreg<128xf16>`, and `!pto.vreg<64xf32>`. `%index` must use the + matching scalar type for `f16`/`f32`; for integer results, `%index` must use + the same bit width and may be signless or signed. --- diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md index c32fedd79..f4d88fe98 100644 --- a/docs/isa/13-dsa-sfu-ops.md +++ b/docs/isa/13-dsa-sfu-ops.md @@ -149,21 +149,25 @@ for (int i = 0; i < N; i++) ### `pto.vci` -- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : integer -> !pto.vreg` -- **semantics:** Generate lane index vector. +- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` +- **semantics:** Generate a lane index vector from a scalar base value. ```c for (int i = 0; i < N; i++) - dst[i] = base_index + i; + dst[i] = (order == ASC) ? (base_index + i) : (base_index - i); ``` **Use case:** Generate indices for gather/scatter, argsort, etc. -- **inputs:** `%index` is the scalar seed/base index. +- **inputs:** `%index` is the scalar base value. Supported scalar types are + `i8/i16/i32`, `f16`, and `f32`. - **outputs:** `%result` is the generated index vector. -- **constraints and limitations:** This page documents the arithmetic/indexing - use of the family; the conversion page also records the same opcode for - completeness. +- **constraints and limitations:** `%result` element type determines both the + generated element type and the lane count. Supported result types are + `!pto.vreg<256xsi8>`, `!pto.vreg<128xsi16>`, `!pto.vreg<64xsi32>`, + `!pto.vreg<128xf16>`, and `!pto.vreg<64xf32>`. `%index` must use the + matching scalar type for `f16`/`f32`; for integer results, `%index` must use + the same bit width and may be signless or signed. --- @@ -208,7 +212,7 @@ for (int i = 0; i < N; i++) - `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` - `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` -- `pto.vci %index {order = "ASC|DESC"} : integer -> !pto.vreg` +- `pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` - `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` - `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` @@ -224,6 +228,6 @@ for (int i = 0; i < N; i++) // Leaky ReLU activation %activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> -// Generate indices for argsort -%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +// Generate ascending si32 indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xsi32> ``` diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 8630c8936..35c9d7723 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -1031,7 +1031,7 @@ def PTO_VbitcastOp : PTO_Op<"vbitcast", [Pure]> { def PTO_VciOp : PTO_Op<"vci", [Pure]> { let arguments = (ins - AnyInteger:$index, + AnyTypeOf<[AnyInteger, AnyFloat], "integer/float scalar">:$index, OptionalAttr:$order ); let results = (outs PTO_VectorType:$result); diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 7cbfcdf4d..1413e70ef 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1640,12 +1640,15 @@ LogicalResult VciOp::verify() { auto resultType = dyn_cast(getResult().getType()); if (!resultType) return emitOpError("result must be !pto.vreg<...>"); - if (!isa(resultType.getElementType())) - return emitOpError("result element type must be integer"); - auto indexType = dyn_cast(getIndex().getType()); - if (!indexType) - return emitOpError("index must be an integer scalar"); - if (indexType != resultType.getElementType()) + Type resultElemType = resultType.getElementType(); + bool supportedInteger = false; + if (auto intType = dyn_cast(resultElemType)) + supportedInteger = intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + bool supportedFloat = resultElemType.isF16() || resultElemType.isF32(); + if (!supportedInteger && !supportedFloat) + return emitOpError("result element type must be integer or f16/f32"); + if (!isCompatibleScalarForSemanticType(resultElemType, getIndex().getType())) return emitOpError("index type must match result element type"); return success(); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index eccabbc0c..d4dc0eb63 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -4200,13 +4200,30 @@ class LowerVciOpPattern final : public OpConversionPattern { if (failed(calleeName)) return rewriter.notifyMatchFailure(op, "unsupported vci callee"); + Value indexValue = adaptor.getIndex(); + Type resultElemType = + cast(op.getResult().getType()).getElementType(); + if (auto intType = dyn_cast(resultElemType)) { + if (intType.getWidth() == 8) { + Type loweredIndexType = rewriter.getI16Type(); + if (intType.isUnsigned()) + indexValue = rewriter.create(op.getLoc(), + loweredIndexType, + indexValue); + else + indexValue = rewriter.create(op.getLoc(), + loweredIndexType, + indexValue); + } + } + Value orderValue = getI32Constant(rewriter, op.getLoc(), *order); auto funcType = rewriter.getFunctionType( - TypeRange{adaptor.getIndex().getType(), orderValue.getType()}, + TypeRange{indexValue.getType(), orderValue.getType()}, TypeRange{resultType}); auto call = rewriter.create( op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getIndex(), orderValue}); + ValueRange{indexValue, orderValue}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.replaceOp(op, call.getResults()); return success(); diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/compare.py new file mode 100755 index 000000000..8c2628b88 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 0.001) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/golden.py new file mode 100755 index 000000000..c19fcdb99 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 1 +COLS = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + _ = seed + v1 = np.zeros((ROWS, COLS), dtype=np.float16) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v2 = np.arange(ROWS * COLS, dtype=np.float16).reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto new file mode 100644 index 000000000..f2b1d360e --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto @@ -0,0 +1,26 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vci_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant 0.0 : f16 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %indices = pto.vci %cst {order = "ASC"} : f16 -> !pto.vreg<128xf16> + pto.vsts %indices, %ub_out[%c0], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c2_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/launch.cpp new file mode 100644 index 000000000..8647dab79 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ half *v1, + __gm__ half *v2); + +void LaunchVci_kernel_2d(aclFloat16 *v1, aclFloat16 *v2, void *stream) { + vci_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/main.cpp new file mode 100644 index 000000000..b628b0747 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/main.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVci_kernel_2d(aclFloat16 *v1, aclFloat16 *v2, void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(aclFloat16); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(aclFloat16); + aclFloat16 *v1Host = nullptr; + aclFloat16 *v1Device = nullptr; + aclFloat16 *v2Host = nullptr; + aclFloat16 *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVci_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/stub.cpp new file mode 100644 index 000000000..915e7d018 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/stub.cpp @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ half *v1, + __gm__ half *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/compare.py new file mode 100755 index 000000000..326fa7450 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int8, 0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/golden.py new file mode 100755 index 000000000..b3482d94d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + _ = seed + v1 = np.zeros((ROWS, COLS), dtype=np.int8) + v2 = np.zeros((ROWS, COLS), dtype=np.int8) + golden_v2 = np.arange(ROWS * COLS, dtype=np.int32).astype(np.int8).reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto new file mode 100644 index 000000000..59397a667 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto @@ -0,0 +1,32 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vci_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c8_i64 = arith.constant 8 : i64 + %c128_i64 = arith.constant 128 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %base = arith.index_castui %offset : index to i8 + %indices = pto.vci %base {order = "ASC"} : i8 -> !pto.vreg<256xsi8> + pto.vsts %indices, %ub_out[%offset], %mask : !pto.vreg<256xsi8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + nburst(%c8_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/launch.cpp new file mode 100644 index 000000000..0b3f33084 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int8_t *v1, + __gm__ int8_t *v2); + +void LaunchVci_kernel_2d(int8_t *v1, int8_t *v2, void *stream) { + vci_kernel_2d<<<1, nullptr, stream>>>((__gm__ int8_t *)v1, + (__gm__ int8_t *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/main.cpp new file mode 100644 index 000000000..204d6efa9 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVci_kernel_2d(int8_t *v1, int8_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int8_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int8_t); + int8_t *v1Host = nullptr; + int8_t *v1Device = nullptr; + int8_t *v2Host = nullptr; + int8_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVci_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/stub.cpp new file mode 100644 index 000000000..3a7ba2167 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/stub.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int8_t *v1, + __gm__ int8_t *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto index 577e9dd51..3286f2a14 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto @@ -1,44 +1,22 @@ -// ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vci -// family: dsa-sfu / conversion -// target_ops: pto.vci -// scenarios: index-generation -// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is -// still a valid test conclusion in the current coverage-first phase. -// ----------------------------------------------------------------------------- module attributes {pto.target_arch = "a5"} { - func.func @vci_kernel_2d(%arg0: !pto.ptr, - %arg1: !pto.ptr) { + func.func @vci_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) { %c0 = arith.constant 0 : index %c64 = arith.constant 64 : index %c1024 = arith.constant 1024 : index %c0_i64 = arith.constant 0 : i64 - %c1_i64 = arith.constant 1 : i64 %c32_i64 = arith.constant 32 : i64 %c128_i64 = arith.constant 128 : i64 - %c4096_i64 = arith.constant 4096 : i64 %c1024_i32 = arith.constant 1024 : i32 - %ub_zero = pto.castptr %c0_i64 : i64 -> !pto.ptr - %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - - %false = arith.constant false - pto.dma_load %arg0, %ub_zero, %c0_i64, %c0_i64, %c128_i64 - nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg0, %ub_out, %c0_i64, %c0_i64, %c128_i64 - nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - - pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] - pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr pto.vecscope { %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 %base = arith.index_castui %offset : index to i32 - %indices = pto.vci %base {order = "ASC"} : i32 -> !pto.vreg<64xi32> - pto.vsts %indices, %ub_out[%offset], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + %indices = pto.vci %base {order = "ASC"} : i32 -> !pto.vreg<64xsi32> + pto.vsts %indices, %ub_out[%offset], %mask : !pto.vreg<64xsi32>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 } } @@ -47,7 +25,7 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp index 33957c516..0ce203973 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp @@ -6,14 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vci -// family: dsa-sfu / conversion -// target_ops: pto.vci -// scenarios: index-generation -// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is -// still a valid test conclusion in the current coverage-first phase. -// ----------------------------------------------------------------------------- #ifndef __VEC_SCOPE__ #define __VEC_SCOPE__ #endif @@ -43,10 +35,10 @@ struct MrgSortExecutedNumList { #include "acl/acl.h" #endif -extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int *v1, - __gm__ int *v2); +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int32_t *v1, + __gm__ int32_t *v2); -void LaunchVci_kernel_2d(int *v1, int *v2, void *stream) { - vci_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1, - (__gm__ int *)v2); +void LaunchVci_kernel_2d(int32_t *v1, int32_t *v2, void *stream) { + vci_kernel_2d<<<1, nullptr, stream>>>((__gm__ int32_t *)v1, + (__gm__ int32_t *)v2); } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp index 2d828b0ba..0baf928bd 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp @@ -6,20 +6,9 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vci -// family: dsa-sfu / conversion -// target_ops: pto.vci -// scenarios: index-generation -// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is -// still a valid test conclusion in the current coverage-first phase. -// ----------------------------------------------------------------------------- -/** -Copyright (c) 2025 Huawei Technologies Co., Ltd. -*/ - #include "test_common.h" #include "acl/acl.h" +#include #include #include @@ -36,17 +25,17 @@ using namespace PtoTestCommon; } \ } while (0) -void LaunchVci_kernel_2d(int *v1, int *v2, void *stream); +void LaunchVci_kernel_2d(int32_t *v1, int32_t *v2, void *stream); int main() { size_t elemCount_v1 = 1024; - size_t fileSize_v1 = elemCount_v1 * sizeof(int); + size_t fileSize_v1 = elemCount_v1 * sizeof(int32_t); size_t elemCount_v2 = 1024; - size_t fileSize_v2 = elemCount_v2 * sizeof(int); - int *v1Host = nullptr; - int *v1Device = nullptr; - int *v2Host = nullptr; - int *v2Device = nullptr; + size_t fileSize_v2 = elemCount_v2 * sizeof(int32_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + int32_t *v2Host = nullptr; + int32_t *v2Device = nullptr; int rc = 0; bool aclInited = false; bool deviceSet = false; diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp index 8f04031ff..2c824129c 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/stub.cpp @@ -6,14 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// ----------------------------------------------------------------------------- -// case: micro-op/dsa-sfu/vci -// family: dsa-sfu / conversion -// target_ops: pto.vci -// scenarios: index-generation -// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is -// still a valid test conclusion in the current coverage-first phase. -// ----------------------------------------------------------------------------- +#include #ifndef __global__ #define __global__ @@ -23,8 +16,8 @@ #define __gm__ #endif -extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int *v1, - __gm__ int *v2) { +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int32_t *v1, + __gm__ int32_t *v2) { (void)v1; (void)v2; } From ec226edfb88e11c315dea6efaf0df83552567d83 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Fri, 24 Apr 2026 01:23:19 +0800 Subject: [PATCH 146/192] feat: remove dead code --- include/PTO/Transforms/Passes.h | 2 - include/PTO/Transforms/Passes.td | 24 - include/PTO/Transforms/VPTOLowering.h | 214 +- lib/PTO/Transforms/CMakeLists.txt | 3 +- lib/PTO/Transforms/PTOToVPTO.cpp | 612 -- lib/PTO/Transforms/PTOToVPTOLowering.cpp | 7367 ----------------- .../Transforms/VPTOBufferMaterialization.cpp | 89 + tools/ptoas/ptoas.cpp | 1 - 8 files changed, 91 insertions(+), 8221 deletions(-) delete mode 100644 lib/PTO/Transforms/PTOToVPTO.cpp delete mode 100644 lib/PTO/Transforms/PTOToVPTOLowering.cpp create mode 100644 lib/PTO/Transforms/VPTOBufferMaterialization.cpp diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 2994168a0..e5a847b6e 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -70,8 +70,6 @@ std::unique_ptr createVPTOPtrNormalizePass(); std::unique_ptr createVPTOPtrCastCleanupPass(); std::unique_ptr createPTOValidateVPTOIRPass(); std::unique_ptr createPTOValidateVPTOEmissionIRPass(); -std::unique_ptr createLowerPTOToVPTOPass(); -std::unique_ptr createLowerPTOToVPTOPass(StringRef loweringStrategy); std::unique_ptr createMemrefToTileBufPass(); std::unique_ptr createExpandTileOpPass(); std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 6c07552d0..b7c95cb57 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -463,28 +463,4 @@ def VPTOPtrCastCleanup "mlir::memref::MemRefDialect"]; } -def PTOToVPTO : Pass<"pto-to-vpto", "ModuleOp"> { - let summary = "Lower PTO tile ops to VPTO backend ops"; - let description = [{ - Lowers PTO tile ops to VPTO backend ops. For already-planned fusion groups, - the pass rewrites the `pto.fusion_region` body in place and preserves the - wrapper until explicit flatten. Residual non-fused PTO ops may continue to - be lowered directly in their parent block and are not wrapped into - synthetic `pto.fusion_region` containers solely for backend lowering. - }]; - let constructor = "mlir::pto::createLowerPTOToVPTOPass()"; - let options = [ - Option<"loweringStrategy", "pto-lowering-strategy", "std::string", - "\"post-update\"", - "vector lowering strategy: post-update or no-post-update"> - ]; - let dependentDialects = [ - "pto::PTODialect", - "func::FuncDialect", - "arith::ArithDialect", - "memref::MemRefDialect", - "scf::SCFDialect" - ]; -} - #endif // MLIR_DIALECT_PTO_PASSES diff --git a/include/PTO/Transforms/VPTOLowering.h b/include/PTO/Transforms/VPTOLowering.h index 2c80d2354..2f7a8332e 100644 --- a/include/PTO/Transforms/VPTOLowering.h +++ b/include/PTO/Transforms/VPTOLowering.h @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -//===- VPTOLowering.h - PTO to VPTO lowering contracts ----------*- C++ -*-===// +//===- VPTOLowering.h - VPTO buffer materialization contracts ---*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -18,228 +18,16 @@ #define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ #include "PTO/IR/PTO.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/raw_ostream.h" namespace mlir { namespace pto { -enum class VPTOTileDomain { - Vec, - Acc, - Mat, -}; - -enum class VPTOLoweringStrategy { - PostUpdate, - NoPostUpdate, -}; - -struct VPTOPartitionTrace { - SmallVector offsets; - SmallVector sizes; - bool hasDynamicOffsets = false; - bool hasDynamicSizes = false; -}; - -struct VPTOLoopProgramming { - int64_t loop2 = 1; - int64_t loop1 = 1; - int64_t srcLoop2Stride = 1; - int64_t srcLoop1Stride = 1; - int64_t dstLoop2Stride = 1; - int64_t dstLoop1Stride = 1; -}; - -enum class VPTOLoopScopeKind { - None, - AIVVectorScope, -}; - -struct VPTOLoopScopeContract { - VPTOLoopScopeKind kind = VPTOLoopScopeKind::None; - StringRef loweredAttr = "llvm.loop.aivector_scope"; - int64_t loopDepth = 0; -}; - -struct VPTOLoadContract { - StringRef sourceLayout; - SmallVector sourceShape; - SmallVector sourceStrides; - StringRef tileLayout; - VPTOTileDomain tileDomain = VPTOTileDomain::Vec; - Type elementType; - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - StringRef padMode; - Value padValue; - Value leftPaddingNum; - Value rightPaddingNum; - bool initOutBuffer = false; - Value initCondition; - VPTOPartitionTrace trace; -}; - -struct VPTOUnaryContract { - StringRef family; - VPTOTileDomain tileDomain = VPTOTileDomain::Vec; - StringRef tileLayout; - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - Type elementType; - VPTOLoopScopeContract loopScope; -}; - -struct VPTOBinaryContract { - StringRef family; - VPTOTileDomain tileDomain = VPTOTileDomain::Vec; - StringRef tileLayout; - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - Type elementType; - VPTOLoopScopeContract loopScope; -}; - -struct VPTOStoreContract { - VPTOTileDomain srcDomain = VPTOTileDomain::Vec; - StringRef destinationLayout; - SmallVector destinationShape; - SmallVector destinationStrides; - Type elementType; - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - VPTOPartitionTrace trace; -}; - -void set_loop2_stride_outtoub(Operation *copyOp, int64_t dstStride, - int64_t srcStride, Builder &builder); -void set_loop1_stride_outtoub(Operation *copyOp, int64_t dstStride, - int64_t srcStride, Builder &builder); -void set_loop_size_outtoub(Operation *copyOp, int64_t loop2, int64_t loop1, - Builder &builder); -void set_loop2_stride_ubtoout(Operation *copyOp, int64_t srcStride, - int64_t dstStride, Builder &builder); -void set_loop1_stride_ubtoout(Operation *copyOp, int64_t srcStride, - int64_t dstStride, Builder &builder); -void set_loop_size_ubtoout(Operation *copyOp, int64_t loop2, int64_t loop1, - Builder &builder); -FailureOr -createLoopScopeRegion(Location loc, const VPTOLoopScopeContract &contract, - PatternRewriter &rewriter); Value materializeBufferPointer(Value value, Type elementType, Attribute memorySpace, PatternRewriter &rewriter, Location loc); - -LogicalResult lowerTLOAD(TLoadOp op, PatternRewriter &rewriter); -LogicalResult lowerTABS(TAbsOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTADD(TAddOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTSUB(TSubOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTMUL(TMulOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTDIV(TDivOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTMAX(TMaxOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTMIN(TMinOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTAND(TAndOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTANDS(TAndSOp op, PatternRewriter &rewriter); -LogicalResult lowerTOR(TOrOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTORS(TOrSOp op, PatternRewriter &rewriter); -LogicalResult lowerTXOR(TXorOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTXORS(TXorSOp op, PatternRewriter &rewriter); -LogicalResult lowerTEXP(TExpOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTLOG(TLogOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTSQRT(TSqrtOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTRSQRT(TRsqrtOp op, PatternRewriter &rewriter); -LogicalResult lowerTRECIP(TRecipOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTNEG(TNegOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTLRELU(TLReluOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTCI(TCIOp op, PatternRewriter &rewriter); -LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter); -LogicalResult lowerTCmp(TCmpOp op, PatternRewriter &rewriter); -LogicalResult lowerTCmpS(TCmpSOp op, PatternRewriter &rewriter); -LogicalResult lowerTSel(TSelOp op, PatternRewriter &rewriter); -LogicalResult lowerTAddC(TAddCOp op, PatternRewriter &rewriter); -LogicalResult lowerTAddS(TAddSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTAddSC(TAddSCOp op, PatternRewriter &rewriter); -LogicalResult lowerTMinS(TMinSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTDivS(TDivSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTMulS(TMulSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTSubC(TSubCOp op, PatternRewriter &rewriter); -LogicalResult lowerTSubS(TSubSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTSubSC(TSubSCOp op, PatternRewriter &rewriter); -LogicalResult lowerTMaxS(TMaxSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTSelS(TSelSOp op, PatternRewriter &rewriter); -LogicalResult lowerTRELU(TReluOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTNOT(TNotOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTTRANS(TTransOp op, PatternRewriter &rewriter); -LogicalResult lowerTFILLPAD(TFillPadOp op, PatternRewriter &rewriter); -LogicalResult lowerTFILLPADExpand(TFillPadExpandOp op, PatternRewriter &rewriter); -LogicalResult lowerTRowMax(TRowMaxOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTRowMin(TRowMinOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTRowSum(TRowSumOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTColMax(TColMaxOp op, PatternRewriter &rewriter); -LogicalResult lowerTColMin(TColMinOp op, PatternRewriter &rewriter); -LogicalResult lowerTColSum(TColSumOp op, PatternRewriter &rewriter); -LogicalResult lowerTRowExpand(TRowExpandOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTColExpand(TColExpandOp op, PatternRewriter &rewriter); -LogicalResult lowerTRowExpandMul(TRowExpandMulOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTRowExpandDiv(TRowExpandDivOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTRowExpandSub(TRowExpandSubOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy); -LogicalResult lowerTPartAdd(TPartAddOp op, PatternRewriter &rewriter); -LogicalResult lowerTPartMax(TPartMaxOp op, PatternRewriter &rewriter); -LogicalResult lowerTPartMin(TPartMinOp op, PatternRewriter &rewriter); -LogicalResult lowerTExpandS(TExpandsOp op, PatternRewriter &rewriter); -LogicalResult lowerTGather(TGatherOp op, PatternRewriter &rewriter); -LogicalResult lowerTGatherB(TGatherBOp op, PatternRewriter &rewriter); -LogicalResult lowerTScatter(TScatterOp op, PatternRewriter &rewriter); -LogicalResult lowerTMrgSort(TMrgSortOp op, PatternRewriter &rewriter); -LogicalResult lowerTSort32(TSort32Op op, PatternRewriter &rewriter); -LogicalResult lowerTSTORE(TStoreOp op, PatternRewriter &rewriter); -LogicalResult lowerSetFlag(SetFlagOp op, PatternRewriter &rewriter); -LogicalResult lowerWaitFlag(WaitFlagOp op, PatternRewriter &rewriter); -LogicalResult lowerBarrier(BarrierOp op, PatternRewriter &rewriter); -LogicalResult lowerGetBuf(GetBufOp op, PatternRewriter &rewriter); -LogicalResult lowerRlsBuf(RlsBufOp op, PatternRewriter &rewriter); LogicalResult convertVPTOEmissionBoundaryToPtr( ModuleOp module, llvm::raw_ostream *diagOS = nullptr); diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 39637791b..e79c97f64 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -19,8 +19,7 @@ add_mlir_dialect_library(PTOTransforms VPTOPtrCastCleanup.cpp PTOVPTOExpandBridgeOps.cpp PTOVPTOPtrBoundary.cpp - PTOToVPTO.cpp - PTOToVPTOLowering.cpp + VPTOBufferMaterialization.cpp PTOValidateVPTOIR.cpp InsertSync/PTOInsertSync.cpp diff --git a/lib/PTO/Transforms/PTOToVPTO.cpp b/lib/PTO/Transforms/PTOToVPTO.cpp deleted file mode 100644 index c73774bf4..000000000 --- a/lib/PTO/Transforms/PTOToVPTO.cpp +++ /dev/null @@ -1,612 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -//===- PTOToVPTO.cpp - PTO to VPTO pass wiring ---------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PTO/Transforms/VPTOLowering.h" -#include "PTO/Transforms/Passes.h" - -#include "PTO/IR/PTO.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" - -namespace mlir { -namespace pto { - -#define GEN_PASS_DEF_PTOTOVPTO -#include "PTO/Transforms/Passes.h.inc" - -namespace { - - -FailureOr -parseVPTOLoweringStrategy(StringRef strategyName) { - if (strategyName == "post-update") - return VPTOLoweringStrategy::PostUpdate; - if (strategyName == "no-post-update") - return VPTOLoweringStrategy::NoPostUpdate; - return failure(); -} - -LogicalResult lowerTLOADOp(TLoadOp op, PatternRewriter &rewriter) { - return lowerTLOAD(op, rewriter); -} - -LogicalResult lowerTABSOp(TAbsOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTABS(op, rewriter, strategy); -} - -LogicalResult lowerTADDOp(TAddOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTADD(op, rewriter, strategy); -} - -LogicalResult lowerTSUBOp(TSubOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTSUB(op, rewriter, strategy); -} - -LogicalResult lowerTMULOp(TMulOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTMUL(op, rewriter, strategy); -} - -LogicalResult lowerTDIVOp(TDivOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTDIV(op, rewriter, strategy); -} - -LogicalResult lowerTMAXOp(TMaxOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTMAX(op, rewriter, strategy); -} - -LogicalResult lowerTMINOp(TMinOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTMIN(op, rewriter, strategy); -} - -LogicalResult lowerTANDOp(TAndOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTAND(op, rewriter, strategy); -} - -LogicalResult lowerTANDSOp(TAndSOp op, PatternRewriter &rewriter) { - return lowerTANDS(op, rewriter); -} - -LogicalResult lowerTOROp(TOrOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTOR(op, rewriter, strategy); -} - -LogicalResult lowerTORSOp(TOrSOp op, PatternRewriter &rewriter) { - return lowerTORS(op, rewriter); -} - -LogicalResult lowerTXOROp(TXorOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTXOR(op, rewriter, strategy); -} - -LogicalResult lowerTXORSOp(TXorSOp op, PatternRewriter &rewriter) { - return lowerTXORS(op, rewriter); -} - -LogicalResult lowerTEXPOp(TExpOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTEXP(op, rewriter, strategy); -} - -LogicalResult lowerTLOGOp(TLogOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTLOG(op, rewriter, strategy); -} - -LogicalResult lowerTSQRTOp(TSqrtOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTSQRT(op, rewriter, strategy); -} - -LogicalResult lowerTRSQRTOp(TRsqrtOp op, PatternRewriter &rewriter) { - return lowerTRSQRT(op, rewriter); -} - -LogicalResult lowerTRECIPOp(TRecipOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRECIP(op, rewriter, strategy); -} - -LogicalResult lowerTNEGOp(TNegOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTNEG(op, rewriter, strategy); -} - -LogicalResult lowerTLRELUOp(TLReluOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTLRELU(op, rewriter, strategy); -} - -LogicalResult lowerTCIOp(TCIOp op, PatternRewriter &rewriter) { - return lowerTCI(op, rewriter); -} - -LogicalResult lowerTCVTOp(TCvtOp op, PatternRewriter &rewriter) { - return lowerTCVT(op, rewriter); -} - -LogicalResult lowerTCmpOp(TCmpOp op, PatternRewriter &rewriter) { - return lowerTCmp(op, rewriter); -} - -LogicalResult lowerTCmpSOp(TCmpSOp op, PatternRewriter &rewriter) { - return lowerTCmpS(op, rewriter); -} - -LogicalResult lowerTSelOp(TSelOp op, PatternRewriter &rewriter) { - return lowerTSel(op, rewriter); -} - -LogicalResult lowerTAddCOp(TAddCOp op, PatternRewriter &rewriter) { - return lowerTAddC(op, rewriter); -} - -LogicalResult lowerTAddSOp(TAddSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTAddS(op, rewriter, strategy); -} - -LogicalResult lowerTAddSCOp(TAddSCOp op, PatternRewriter &rewriter) { - return lowerTAddSC(op, rewriter); -} - -LogicalResult lowerTMinSOp(TMinSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTMinS(op, rewriter, strategy); -} - -LogicalResult lowerTSubCOp(TSubCOp op, PatternRewriter &rewriter) { - return lowerTSubC(op, rewriter); -} - -LogicalResult lowerTSubSOp(TSubSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTSubS(op, rewriter, strategy); -} - -LogicalResult lowerTSubSCOp(TSubSCOp op, PatternRewriter &rewriter) { - return lowerTSubSC(op, rewriter); -} - -LogicalResult lowerTMaxSOp(TMaxSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTMaxS(op, rewriter, strategy); -} - -LogicalResult lowerTDivSOp(TDivSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTDivS(op, rewriter, strategy); -} - -LogicalResult lowerTMulSOp(TMulSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTMulS(op, rewriter, strategy); -} - -LogicalResult lowerTSelSOp(TSelSOp op, PatternRewriter &rewriter) { - return lowerTSelS(op, rewriter); -} - -LogicalResult lowerTRELUOp(TReluOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRELU(op, rewriter, strategy); -} - -LogicalResult lowerTNOTOp(TNotOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTNOT(op, rewriter, strategy); -} - -LogicalResult lowerTTRANSOp(TTransOp op, PatternRewriter &rewriter) { - return lowerTTRANS(op, rewriter); -} - -LogicalResult lowerTFILLPADOp(TFillPadOp op, PatternRewriter &rewriter) { - return lowerTFILLPAD(op, rewriter); -} - -LogicalResult lowerTFILLPADExpandOp(TFillPadExpandOp op, PatternRewriter &rewriter) { - return lowerTFILLPADExpand(op, rewriter); -} - -LogicalResult lowerTRowMaxOp(TRowMaxOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowMax(op, rewriter, strategy); -} - -LogicalResult lowerTRowMinOp(TRowMinOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowMin(op, rewriter, strategy); -} - -LogicalResult lowerTRowSumOp(TRowSumOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowSum(op, rewriter, strategy); -} - -LogicalResult lowerTColMaxOp(TColMaxOp op, PatternRewriter &rewriter) { - return lowerTColMax(op, rewriter); -} - -LogicalResult lowerTColMinOp(TColMinOp op, PatternRewriter &rewriter) { - return lowerTColMin(op, rewriter); -} - -LogicalResult lowerTColSumOp(TColSumOp op, PatternRewriter &rewriter) { - return lowerTColSum(op, rewriter); -} - -LogicalResult lowerTRowExpandOp(TRowExpandOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowExpand(op, rewriter, strategy); -} - -LogicalResult lowerTColExpandOp(TColExpandOp op, PatternRewriter &rewriter) { - return lowerTColExpand(op, rewriter); -} - -LogicalResult lowerTRowExpandMulOp(TRowExpandMulOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowExpandMul(op, rewriter, strategy); -} - -LogicalResult lowerTRowExpandDivOp(TRowExpandDivOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowExpandDiv(op, rewriter, strategy); -} - -LogicalResult lowerTRowExpandSubOp(TRowExpandSubOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowExpandSub(op, rewriter, strategy); -} - -LogicalResult lowerTPartAddOp(TPartAddOp op, PatternRewriter &rewriter) { - return lowerTPartAdd(op, rewriter); -} - -LogicalResult lowerTPartMaxOp(TPartMaxOp op, PatternRewriter &rewriter) { - return lowerTPartMax(op, rewriter); -} - -LogicalResult lowerTPartMinOp(TPartMinOp op, PatternRewriter &rewriter) { - return lowerTPartMin(op, rewriter); -} - -LogicalResult lowerTExpandSOp(TExpandsOp op, PatternRewriter &rewriter) { - return lowerTExpandS(op, rewriter); -} - -LogicalResult lowerTGatherOp(TGatherOp op, PatternRewriter &rewriter) { - return lowerTGather(op, rewriter); -} - -LogicalResult lowerTGatherBOp(TGatherBOp op, PatternRewriter &rewriter) { - return lowerTGatherB(op, rewriter); -} - -LogicalResult lowerTScatterOp(TScatterOp op, PatternRewriter &rewriter) { - return lowerTScatter(op, rewriter); -} - -LogicalResult lowerTMrgSortOp(TMrgSortOp op, PatternRewriter &rewriter) { - return lowerTMrgSort(op, rewriter); -} - -LogicalResult lowerTSort32Op(TSort32Op op, PatternRewriter &rewriter) { - return lowerTSort32(op, rewriter); -} - -LogicalResult lowerTSTOREOp(TStoreOp op, PatternRewriter &rewriter) { - return lowerTSTORE(op, rewriter); -} - -LogicalResult lowerSetFlagOp(SetFlagOp op, PatternRewriter &rewriter) { - return lowerSetFlag(op, rewriter); -} - -LogicalResult lowerWaitFlagOp(WaitFlagOp op, PatternRewriter &rewriter) { - return lowerWaitFlag(op, rewriter); -} - -LogicalResult lowerBarrierOp(BarrierOp op, PatternRewriter &rewriter) { - return lowerBarrier(op, rewriter); -} - -LogicalResult lowerGetBufOp(GetBufOp op, PatternRewriter &rewriter) { - return lowerGetBuf(op, rewriter); -} - -LogicalResult lowerRlsBufOp(RlsBufOp op, PatternRewriter &rewriter) { - return lowerRlsBuf(op, rewriter); -} - -LogicalResult lowerTensorPipelineOp(Operation *op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - rewriter.setInsertionPoint(op); - - LogicalResult lowered = success(); - if (auto tload = dyn_cast(op)) - lowered = lowerTLOADOp(tload, rewriter); - else if (auto tabs = dyn_cast(op)) - lowered = lowerTABSOp(tabs, rewriter, strategy); - else if (auto tadd = dyn_cast(op)) - lowered = lowerTADDOp(tadd, rewriter, strategy); - else if (auto tsub = dyn_cast(op)) - lowered = lowerTSUBOp(tsub, rewriter, strategy); - else if (auto tmul = dyn_cast(op)) - lowered = lowerTMULOp(tmul, rewriter, strategy); - else if (auto tdiv = dyn_cast(op)) - lowered = lowerTDIVOp(tdiv, rewriter, strategy); - else if (auto tmax = dyn_cast(op)) - lowered = lowerTMAXOp(tmax, rewriter, strategy); - else if (auto tmin = dyn_cast(op)) - lowered = lowerTMINOp(tmin, rewriter, strategy); - else if (auto tand = dyn_cast(op)) - lowered = lowerTANDOp(tand, rewriter, strategy); - else if (auto tands = dyn_cast(op)) - lowered = lowerTANDSOp(tands, rewriter); - else if (auto tor = dyn_cast(op)) - lowered = lowerTOROp(tor, rewriter, strategy); - else if (auto tors = dyn_cast(op)) - lowered = lowerTORSOp(tors, rewriter); - else if (auto txor = dyn_cast(op)) - lowered = lowerTXOROp(txor, rewriter, strategy); - else if (auto txors = dyn_cast(op)) - lowered = lowerTXORSOp(txors, rewriter); - else if (auto texp = dyn_cast(op)) - lowered = lowerTEXPOp(texp, rewriter, strategy); - else if (auto tlog = dyn_cast(op)) - lowered = lowerTLOGOp(tlog, rewriter, strategy); - else if (auto tsqrt = dyn_cast(op)) - lowered = lowerTSQRTOp(tsqrt, rewriter, strategy); - else if (auto trsqr = dyn_cast(op)) - lowered = lowerTRSQRTOp(trsqr, rewriter); - else if (auto trecip = dyn_cast(op)) - lowered = lowerTRECIPOp(trecip, rewriter, strategy); - else if (auto tneg = dyn_cast(op)) - lowered = lowerTNEGOp(tneg, rewriter, strategy); - else if (auto tlrelu = dyn_cast(op)) - lowered = lowerTLRELUOp(tlrelu, rewriter, strategy); - else if (auto tci = dyn_cast(op)) - lowered = lowerTCIOp(tci, rewriter); - else if (auto tcvt = dyn_cast(op)) - lowered = lowerTCVTOp(tcvt, rewriter); - else if (auto tcmp = dyn_cast(op)) - lowered = lowerTCmpOp(tcmp, rewriter); - else if (auto tcmps = dyn_cast(op)) - lowered = lowerTCmpSOp(tcmps, rewriter); - else if (auto tsel = dyn_cast(op)) - lowered = lowerTSelOp(tsel, rewriter); - else if (auto taddc = dyn_cast(op)) - lowered = lowerTAddCOp(taddc, rewriter); - else if (auto tadds = dyn_cast(op)) - lowered = lowerTAddSOp(tadds, rewriter, strategy); - else if (auto taddsc = dyn_cast(op)) - lowered = lowerTAddSCOp(taddsc, rewriter); - else if (auto tmins = dyn_cast(op)) - lowered = lowerTMinSOp(tmins, rewriter, strategy); - else if (auto tsubc = dyn_cast(op)) - lowered = lowerTSubCOp(tsubc, rewriter); - else if (auto tsubs = dyn_cast(op)) - lowered = lowerTSubSOp(tsubs, rewriter, strategy); - else if (auto tsubsc = dyn_cast(op)) - lowered = lowerTSubSCOp(tsubsc, rewriter); - else if (auto tmaxs = dyn_cast(op)) - lowered = lowerTMaxSOp(tmaxs, rewriter, strategy); - else if (auto tdivs = dyn_cast(op)) - lowered = lowerTDivSOp(tdivs, rewriter, strategy); - else if (auto tmuls = dyn_cast(op)) - lowered = lowerTMulSOp(tmuls, rewriter, strategy); - else if (auto tsels = dyn_cast(op)) - lowered = lowerTSelSOp(tsels, rewriter); - else if (auto trelu = dyn_cast(op)) - lowered = lowerTRELUOp(trelu, rewriter, strategy); - else if (auto tnot = dyn_cast(op)) - lowered = lowerTNOTOp(tnot, rewriter, strategy); - else if (auto ttrans = dyn_cast(op)) - lowered = lowerTTRANSOp(ttrans, rewriter); - else if (auto tfillpad = dyn_cast(op)) - lowered = lowerTFILLPADOp(tfillpad, rewriter); - else if (auto tfillpadExpand = dyn_cast(op)) - lowered = lowerTFILLPADExpandOp(tfillpadExpand, rewriter); - else if (auto trowmax = dyn_cast(op)) - lowered = lowerTRowMaxOp(trowmax, rewriter, strategy); - else if (auto trowmin = dyn_cast(op)) - lowered = lowerTRowMinOp(trowmin, rewriter, strategy); - else if (auto trowsum = dyn_cast(op)) - lowered = lowerTRowSumOp(trowsum, rewriter, strategy); - else if (auto tcolmax = dyn_cast(op)) - lowered = lowerTColMaxOp(tcolmax, rewriter); - else if (auto tcolmin = dyn_cast(op)) - lowered = lowerTColMinOp(tcolmin, rewriter); - else if (auto tcolsum = dyn_cast(op)) - lowered = lowerTColSumOp(tcolsum, rewriter); - else if (auto trowexpand = dyn_cast(op)) - lowered = lowerTRowExpandOp(trowexpand, rewriter, strategy); - else if (auto tcolexpand = dyn_cast(op)) - lowered = lowerTColExpandOp(tcolexpand, rewriter); - else if (auto trowexpandmul = dyn_cast(op)) - lowered = lowerTRowExpandMulOp(trowexpandmul, rewriter, strategy); - else if (auto trowexpanddiv = dyn_cast(op)) - lowered = lowerTRowExpandDivOp(trowexpanddiv, rewriter, strategy); - else if (auto trowexpandsub = dyn_cast(op)) - lowered = lowerTRowExpandSubOp(trowexpandsub, rewriter, strategy); - else if (auto tpartadd = dyn_cast(op)) - lowered = lowerTPartAddOp(tpartadd, rewriter); - else if (auto tpartmax = dyn_cast(op)) - lowered = lowerTPartMaxOp(tpartmax, rewriter); - else if (auto tpartmin = dyn_cast(op)) - lowered = lowerTPartMinOp(tpartmin, rewriter); - else if (auto texpands = dyn_cast(op)) - lowered = lowerTExpandSOp(texpands, rewriter); - else if (auto tgather = dyn_cast(op)) - lowered = lowerTGatherOp(tgather, rewriter); - else if (auto tgatherb = dyn_cast(op)) - lowered = lowerTGatherBOp(tgatherb, rewriter); - else if (auto tscatter = dyn_cast(op)) - lowered = lowerTScatterOp(tscatter, rewriter); - else if (auto tmrgsort = dyn_cast(op)) - lowered = lowerTMrgSortOp(tmrgsort, rewriter); - else if (auto tsort32 = dyn_cast(op)) - lowered = lowerTSort32Op(tsort32, rewriter); - else if (auto tstore = dyn_cast(op)) - lowered = lowerTSTOREOp(tstore, rewriter); - else - return success(); - - if (failed(lowered)) - return failure(); - - rewriter.eraseOp(op); - return success(); -} - -LogicalResult lowerResidualPTOOp(Operation *op, PatternRewriter &rewriter) { - rewriter.setInsertionPoint(op); - - LogicalResult lowered = success(); - if (auto setFlag = dyn_cast(op)) - lowered = lowerSetFlagOp(setFlag, rewriter); - else if (auto waitFlag = dyn_cast(op)) - lowered = lowerWaitFlagOp(waitFlag, rewriter); - else if (auto barrier = dyn_cast(op)) - lowered = lowerBarrierOp(barrier, rewriter); - else if (auto getBuf = dyn_cast(op)) - lowered = lowerGetBufOp(getBuf, rewriter); - else if (auto rlsBuf = dyn_cast(op)) - lowered = lowerRlsBufOp(rlsBuf, rewriter); - else if (isa(op) && op->use_empty()) - lowered = success(); - else - return success(); - - if (failed(lowered)) - return failure(); - - rewriter.eraseOp(op); - return success(); -} - -struct PTOToVPTOPass : public impl::PTOToVPTOBase { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOToVPTOPass) - - PTOToVPTOPass() = default; - - explicit PTOToVPTOPass(StringRef loweringStrategy) { - this->loweringStrategy = loweringStrategy.str(); - } - - void runOnOperation() override { - ModuleOp module = getOperation(); - FailureOr loweringStrategy = - parseVPTOLoweringStrategy(this->loweringStrategy); - if (failed(loweringStrategy)) { - module.emitError() - << "unsupported pto-lowering-strategy: " << this->loweringStrategy - << " (expected post-update or no-post-update)"; - signalPassFailure(); - return; - } - SmallVector tensorPipelineOps; - SmallVector residualPTOOps; - module.walk([&](Operation *op) { - if (isa(op)) - tensorPipelineOps.push_back(op); - else if (isa(op)) - residualPTOOps.push_back(op); - }); - - PatternRewriter rewriter(&getContext()); - bool sawFailure = false; - for (Operation *op : tensorPipelineOps) { - if (!op->getBlock()) - continue; - if (failed(lowerTensorPipelineOp(op, rewriter, *loweringStrategy))) - sawFailure = true; - } - for (Operation *op : residualPTOOps) { - if (!op->getBlock()) - continue; - if (failed(lowerResidualPTOOp(op, rewriter))) - sawFailure = true; - } - - bool erasedDeadScaffold = true; - while (erasedDeadScaffold) { - erasedDeadScaffold = false; - SmallVector deadScaffoldOps; - module.walk([&](Operation *op) { - if ((isa(op)) && op->use_empty()) - deadScaffoldOps.push_back(op); - }); - for (Operation *op : deadScaffoldOps) { - if (!op->getBlock()) - continue; - rewriter.setInsertionPoint(op); - rewriter.eraseOp(op); - erasedDeadScaffold = true; - } - } - - // Keep the backend mainline memref-first through PTOToVPTO. Pointer ABI - // bridging belongs to the emission boundary, where text/LLVM emitters can - // materialize the required ptr-only signature on a cloned module. - - if (sawFailure) - signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr createLowerPTOToVPTOPass() { - return std::make_unique(); -} - -std::unique_ptr createLowerPTOToVPTOPass(StringRef loweringStrategy) { - return std::make_unique(loweringStrategy); -} - -} // namespace pto -} // namespace mlir diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp deleted file mode 100644 index cdb2f711d..000000000 --- a/lib/PTO/Transforms/PTOToVPTOLowering.cpp +++ /dev/null @@ -1,7367 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -//===- PTOToVPTOLowering.cpp - PTO to VPTO lowering helpers --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PTO/Transforms/VPTOLowering.h" - -#include "PTO/IR/PTO.h" -#include "PTO/IR/PTOSyncUtils.h" - -#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/ADT/APFloat.h" - -#include -#include - -namespace mlir { -namespace pto { - -namespace { - -constexpr StringLiteral kLoweredLoopScopeAttrName = "llvm.loop.aivector_scope"; - -static Type getVcaddResultElementType(MLIRContext *context, Type inputElementType) { - if (auto intType = dyn_cast(inputElementType)) { - if (intType.getWidth() == 8) - return IntegerType::get(context, 16, intType.getSignedness()); - if (intType.getWidth() == 16) - return IntegerType::get(context, 32, intType.getSignedness()); - } - return inputElementType; -} - -static pto::VRegType getVcaddResultVRegType(MLIRContext *context, - pto::VRegType inputType) { - int64_t resultLanes = inputType.getElementCount(); - if (auto intType = dyn_cast(inputType.getElementType())) { - if (intType.getWidth() == 8 || intType.getWidth() == 16) - resultLanes /= 2; - } - return pto::VRegType::get( - context, resultLanes, - getVcaddResultElementType(context, inputType.getElementType())); -} - -struct ResolvedTensorView { - Value root; - Attribute layoutAttr; - SmallVector shape; - SmallVector strides; - OpFoldResult offsetElems; -}; - -struct VecNdTransferPlan { - Value outerCount; - Value outerSrcStrideElems; - Value outerDstStrideElems; - Value loop2Size; - Value loop1Size; - Value loop2FirstStrideBytes; - Value loop2SecondStrideBytes; - Value loop1FirstStrideBytes; - Value loop1SecondStrideBytes; - Value nBurst; - Value lenBurst; - Value firstStrideBytes; - Value secondStrideBytes; -}; - -struct VPTORowReduceContract { - StringRef family; - VPTOTileDomain srcDomain = VPTOTileDomain::Vec; - VPTOTileDomain dstDomain = VPTOTileDomain::Vec; - StringRef srcLayout; - StringRef dstLayout; - Type elementType; - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - int64_t dstValidCols = ShapedType::kDynamic; - VPTOLoopScopeContract loopScope; -}; - -struct VPTOColReduceContract { - StringRef family; - VPTOTileDomain srcDomain = VPTOTileDomain::Vec; - VPTOTileDomain dstDomain = VPTOTileDomain::Vec; - StringRef srcLayout; - StringRef dstLayout; - Type elementType; - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - int64_t dstValidRows = ShapedType::kDynamic; - int64_t dstValidCols = ShapedType::kDynamic; - bool isBinary = false; - Value tmp; - VPTOLoopScopeContract loopScope; -}; - -struct VPTOPartContract { - StringRef family; - VPTOTileDomain src0Domain = VPTOTileDomain::Vec; - VPTOTileDomain src1Domain = VPTOTileDomain::Vec; - VPTOTileDomain dstDomain = VPTOTileDomain::Vec; - StringRef src0Layout; - StringRef src1Layout; - StringRef dstLayout; - Type elementType; - Value src0ValidRowsValue; - Value src0ValidColsValue; - Value src1ValidRowsValue; - Value src1ValidColsValue; - Value dstValidRowsValue; - Value dstValidColsValue; - int64_t src0ValidRows = ShapedType::kDynamic; - int64_t src0ValidCols = ShapedType::kDynamic; - int64_t src1ValidRows = ShapedType::kDynamic; - int64_t src1ValidCols = ShapedType::kDynamic; - int64_t dstValidRows = ShapedType::kDynamic; - int64_t dstValidCols = ShapedType::kDynamic; - VPTOLoopScopeContract loopScope; -}; - -struct VPTOExpandContract { - StringRef family; - VPTOTileDomain srcDomain = VPTOTileDomain::Vec; - VPTOTileDomain dstDomain = VPTOTileDomain::Vec; - StringRef srcLayout; - StringRef dstLayout; - Type elementType; - Value srcValidRowsValue; - Value srcValidColsValue; - Value dstValidRowsValue; - Value dstValidColsValue; - int64_t srcValidRows = ShapedType::kDynamic; - int64_t srcValidCols = ShapedType::kDynamic; - int64_t dstValidRows = ShapedType::kDynamic; - int64_t dstValidCols = ShapedType::kDynamic; - VPTOLoopScopeContract loopScope; -}; - -StringRef inferVecTransferLayoutFromTile(StringRef explicitLayout, - StringRef tileLayout) { - if (explicitLayout != "nd") - return explicitLayout; - if (tileLayout == "col_major") - return "dn"; - return "nd"; -} - -int64_t getElementByteSize(Type type); -Value materializeIndexValue(Value maybeValue, int64_t fallback, - PatternRewriter &rewriter, Location loc); -Value materializeI64Value(Value maybeValue, int64_t fallback, - PatternRewriter &rewriter, Location loc); - -LogicalResult emitUnresolvedInstalledA5BaselineError(Operation *op, - StringRef family) { - return op->emitOpError() - << family - << " lowering is intentionally unresolved until the installed A5 PTO " - "helper baseline is located and traced"; -} - -std::optional getConstInt(Value value) { - if (!value) - return std::nullopt; - - if (auto constIndex = value.getDefiningOp()) - return constIndex.value(); - if (auto constInt = value.getDefiningOp()) - return constInt.value(); - if (auto constOp = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(constOp.getValue())) - return intAttr.getInt(); - } - return std::nullopt; -} - -std::optional getConstInt(OpFoldResult value) { - if (auto attr = dyn_cast(value)) { - if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt(); - return std::nullopt; - } - return getConstInt(cast(value)); -} - -Value materializeIndexOfr(OpFoldResult value, PatternRewriter &rewriter, - Location loc) { - if (auto attr = dyn_cast(value)) { - if (auto intAttr = dyn_cast(attr)) - return rewriter.create(loc, intAttr.getInt()); - return {}; - } - Value v = cast(value); - if (v.getType().isIndex()) - return v; - if (isa(v.getType())) - return rewriter.create(loc, rewriter.getIndexType(), v); - return {}; -} - -Value materializeI64Ofr(OpFoldResult value, PatternRewriter &rewriter, - Location loc) { - if (auto attr = dyn_cast(value)) { - if (auto intAttr = dyn_cast(attr)) - return rewriter.create(loc, intAttr.getInt(), 64); - return {}; - } - return materializeI64Value(cast(value), ShapedType::kDynamic, rewriter, loc); -} - -Value materializeIndexBuilder(OpFoldResult value, PatternRewriter &rewriter, Location loc) { - if (auto attr = dyn_cast(value)) { - if (auto intAttr = dyn_cast(attr)) - return rewriter.create(loc, intAttr.getInt()); - return {}; - } - Value v = cast(value); - if (v.getType().isIndex()) - return v; - if (isa(v.getType())) - return rewriter.create(loc, rewriter.getIndexType(), v); - return {}; -} - -Value createI64Mul(Value lhs, Value rhs, PatternRewriter &rewriter, Location loc) { - if (!lhs || !rhs) - return {}; - if (std::optional lhsConst = getConstInt(lhs)) { - if (std::optional rhsConst = getConstInt(rhs)) - return rewriter.create(loc, (*lhsConst) * (*rhsConst), 64); - } - return rewriter.create(loc, lhs, rhs); -} - -Value createI64Add(Value lhs, Value rhs, PatternRewriter &rewriter, Location loc) { - if (!lhs || !rhs) - return {}; - if (std::optional lhsConst = getConstInt(lhs)) { - if (std::optional rhsConst = getConstInt(rhs)) - return rewriter.create(loc, (*lhsConst) + (*rhsConst), 64); - } - return rewriter.create(loc, lhs, rhs); -} - -OpFoldResult addOfr(OpFoldResult lhs, OpFoldResult rhs, PatternRewriter &rewriter, - Location loc) { - if (auto lhsConst = getConstInt(lhs)) { - if (auto rhsConst = getConstInt(rhs)) - return rewriter.getIndexAttr((*lhsConst) + (*rhsConst)); - } - Value lhsValue = materializeIndexBuilder(lhs, rewriter, loc); - Value rhsValue = materializeIndexBuilder(rhs, rewriter, loc); - if (!lhsValue || !rhsValue) - return {}; - return rewriter.create(loc, lhsValue, rhsValue).getResult(); -} - -OpFoldResult multiplyOfr(OpFoldResult lhs, OpFoldResult rhs, PatternRewriter &rewriter, - Location loc) { - if (auto lhsConst = getConstInt(lhs)) { - if (auto rhsConst = getConstInt(rhs)) - return rewriter.getIndexAttr((*lhsConst) * (*rhsConst)); - } - Value lhsValue = materializeIndexBuilder(lhs, rewriter, loc); - Value rhsValue = materializeIndexBuilder(rhs, rewriter, loc); - if (!lhsValue || !rhsValue) - return {}; - return rewriter.create(loc, lhsValue, rhsValue).getResult(); -} - -bool resolveTensorView(Value value, ResolvedTensorView &info, PatternRewriter &rewriter, - Location loc) { - if (!value) - return false; - - if (auto part = value.getDefiningOp()) { - if (!resolveTensorView(part.getSource(), info, rewriter, loc)) - return false; - SmallVector offsets; - offsets.reserve(part.getOffsets().size()); - for (Value offset : part.getOffsets()) - offsets.push_back(offset); - if (offsets.size() != info.strides.size()) - return false; - OpFoldResult totalOffset = info.offsetElems; - for (auto [offset, stride] : llvm::zip(offsets, info.strides)) { - OpFoldResult term = multiplyOfr(offset, stride, rewriter, loc); - if (!term) - return false; - totalOffset = addOfr(totalOffset, term, rewriter, loc); - if (!totalOffset) - return false; - } - info.offsetElems = totalOffset; - info.shape.clear(); - for (Value size : part.getSizes()) - info.shape.push_back(size); - return true; - } - - if (auto source = value.getDefiningOp()) { - info.root = source.getPtr(); - info.layoutAttr = source.getLayoutAttr(); - info.shape.assign(source.getShape().begin(), source.getShape().end()); - info.strides.assign(source.getStrides().begin(), source.getStrides().end()); - info.offsetElems = rewriter.getIndexAttr(0); - return true; - } - - if (auto subview = value.getDefiningOp()) { - ResolvedTensorView parent; - Value source = subview.getSource(); - if (auto reinterpret = source.getDefiningOp()) { - Value root = reinterpret.getSource(); - while (true) { - if (auto cast = root.getDefiningOp()) { - root = cast.getSource(); - continue; - } - break; - } - parent.root = root; - if (Attribute layout = reinterpret->getAttr("layout")) - parent.layoutAttr = layout; - auto parentShapes = - getMixedValues(reinterpret.getStaticSizes(), reinterpret.getSizes(), rewriter); - auto parentStrides = - getMixedValues(reinterpret.getStaticStrides(), reinterpret.getStrides(), rewriter); - auto offsets = - getMixedValues(reinterpret.getStaticOffsets(), reinterpret.getOffsets(), rewriter); - parent.shape.assign(parentShapes.begin(), parentShapes.end()); - parent.strides.assign(parentStrides.begin(), parentStrides.end()); - parent.offsetElems = - offsets.empty() ? OpFoldResult(rewriter.getIndexAttr(0)) : offsets.front(); - } else if (!resolveTensorView(source, parent, rewriter, loc)) { - return false; - } - - if (parent.strides.empty()) { - auto sourceType = dyn_cast(source.getType()); - if (!sourceType) - return false; - SmallVector strides; - int64_t offset = 0; - if (failed(getStridesAndOffset(sourceType, strides, offset))) { - strides.assign(sourceType.getRank(), 1); - int64_t running = 1; - for (int i = sourceType.getRank() - 1; i >= 0; --i) { - strides[i] = running; - int64_t dim = sourceType.getShape()[i]; - if (dim != ShapedType::kDynamic) - running *= dim; - } - } - for (int64_t stride : strides) - parent.strides.push_back(rewriter.getIndexAttr(stride == ShapedType::kDynamic ? 1 : stride)); - parent.offsetElems = rewriter.getIndexAttr(offset); - parent.root = source; - } - - info = parent; - if (subview.getMixedOffsets().size() != info.strides.size()) - return false; - - OpFoldResult totalOffset = info.offsetElems; - for (auto [offset, stride] : llvm::zip(subview.getMixedOffsets(), info.strides)) { - OpFoldResult term = multiplyOfr(offset, stride, rewriter, loc); - if (!term) - return false; - totalOffset = addOfr(totalOffset, term, rewriter, loc); - if (!totalOffset) - return false; - } - - SmallVector newStrides; - newStrides.reserve(info.strides.size()); - for (auto [srcStride, step] : llvm::zip(info.strides, subview.getMixedStrides())) { - OpFoldResult product = multiplyOfr(srcStride, step, rewriter, loc); - if (!product) - return false; - newStrides.push_back(product); - } - - info.offsetElems = totalOffset; - info.shape.assign(subview.getMixedSizes().begin(), subview.getMixedSizes().end()); - info.strides = std::move(newStrides); - return true; - } - - if (auto reinterpret = value.getDefiningOp()) { - Value root = reinterpret.getSource(); - while (true) { - if (auto cast = root.getDefiningOp()) { - root = cast.getSource(); - continue; - } - if (auto unrealized = root.getDefiningOp()) { - if (!unrealized.getInputs().empty()) { - root = unrealized.getInputs().front(); - continue; - } - } - break; - } - info.root = root; - if (Attribute layout = reinterpret->getAttr("layout")) - info.layoutAttr = layout; - auto reinterpretShapes = - getMixedValues(reinterpret.getStaticSizes(), reinterpret.getSizes(), rewriter); - auto reinterpretStrides = - getMixedValues(reinterpret.getStaticStrides(), reinterpret.getStrides(), rewriter); - auto offsets = - getMixedValues(reinterpret.getStaticOffsets(), reinterpret.getOffsets(), rewriter); - info.shape.assign(reinterpretShapes.begin(), reinterpretShapes.end()); - info.strides.assign(reinterpretStrides.begin(), reinterpretStrides.end()); - if (!offsets.empty()) { - if (offsets.size() != 1) - return false; - info.offsetElems = offsets.front(); - } else { - info.offsetElems = rewriter.getIndexAttr(0); - } - return true; - } - - if (auto cast = value.getDefiningOp()) - return resolveTensorView(cast.getSource(), info, rewriter, loc); - - if (auto memrefType = dyn_cast(value.getType())) { - info.root = value; - info.shape.clear(); - for (int64_t dim : memrefType.getShape()) - info.shape.push_back(rewriter.getIndexAttr(dim == ShapedType::kDynamic ? 1 : dim)); - SmallVector strides; - int64_t offset = 0; - if (failed(getStridesAndOffset(memrefType, strides, offset))) { - strides.assign(memrefType.getRank(), 1); - int64_t running = 1; - for (int i = memrefType.getRank() - 1; i >= 0; --i) { - strides[i] = running; - int64_t dim = memrefType.getShape()[i]; - if (dim != ShapedType::kDynamic) - running *= dim; - } - offset = 0; - } - info.strides.clear(); - for (int64_t stride : strides) - info.strides.push_back(rewriter.getIndexAttr(stride == ShapedType::kDynamic ? 1 : stride)); - info.offsetElems = rewriter.getIndexAttr(offset); - return true; - } - - return false; -} - -void normalizeMixedGlobalShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &globalShape, - SmallVectorImpl &globalStride, - PatternRewriter &rewriter, Location loc) { - constexpr int64_t kRank = 5; - globalShape.assign(kRank, rewriter.getIndexAttr(1)); - globalStride.assign(kRank, rewriter.getIndexAttr(1)); - - size_t rank = std::min(shape.size(), strides.size()); - rank = std::min(rank, kRank); - size_t base = kRank - rank; - for (size_t i = 0; i < rank; ++i) { - globalShape[base + i] = shape[shape.size() - rank + i]; - globalStride[base + i] = strides[strides.size() - rank + i]; - } - - for (int i = static_cast(kRank) - 2; i >= 0; --i) { - if (i >= static_cast(base)) - continue; - OpFoldResult product = multiplyOfr(globalStride[i + 1], globalShape[i + 1], rewriter, loc); - if (!product) - product = rewriter.getIndexAttr(ShapedType::kDynamic); - globalStride[i] = product; - } -} - -Value adjustPointerByElemOffset(Value ptr, Value elemOffsetI64, int64_t elemBytes, - PatternRewriter &rewriter, Location loc) { - if (!ptr || !elemOffsetI64 || elemBytes <= 0) - return {}; - - Value offset = elemOffsetI64.getType().isIndex() - ? rewriter.create( - loc, rewriter.getI64Type(), elemOffsetI64) - : elemOffsetI64; - Value byteOffset = offset; - if (elemBytes != 1) { - Value elemBytesValue = rewriter.create(loc, elemBytes, 64); - byteOffset = createI64Mul(offset, elemBytesValue, rewriter, loc); - } - if (auto ptrType = dyn_cast(ptr.getType())) { - auto bytePtrType = PtrType::get(rewriter.getContext(), rewriter.getI8Type(), - ptrType.getMemorySpace()); - Value bytePtr = ptrType == bytePtrType - ? ptr - : rewriter.create(loc, bytePtrType, ptr).getResult(); - Value byteOffsetIndex = - byteOffset.getType().isIndex() - ? byteOffset - : rewriter.create(loc, rewriter.getIndexType(), - byteOffset); - return rewriter.create(loc, bytePtrType, bytePtr, byteOffsetIndex); - } - return {}; -} - -Value castPtrToElementType(Value ptr, Type elementType, PatternRewriter &rewriter, - Location loc) { - auto ptrType = dyn_cast_or_null(ptr.getType()); - if (!ptrType || !elementType) - return {}; - auto targetType = - PtrType::get(rewriter.getContext(), elementType, ptrType.getMemorySpace()); - if (targetType == ptrType) - return ptr; - return rewriter.create(loc, targetType, ptr).getResult(); -} - -Type getCopyTransferElementType(Type elementType, Builder &builder) { - if (getElementByteSize(elementType) == 8) - return builder.getI32Type(); - return elementType; -} - -LogicalResult buildVecNdLoadPlan(ArrayRef shape, - ArrayRef strides, int64_t tileCols, - Value validColsValue, int64_t validCols, - Type elementType, PatternRewriter &rewriter, - Location loc, VecNdTransferPlan &plan) { - if (tileCols == ShapedType::kDynamic) - return failure(); - int64_t elemBytes = getElementByteSize(elementType); - if (elemBytes <= 0) - return failure(); - - SmallVector globalShape; - SmallVector globalStride; - normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, rewriter, loc); - - auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; - Value gShape0 = toI64(globalShape[0]); - Value gShape1 = toI64(globalShape[1]); - Value gShape2 = toI64(globalShape[2]); - Value gShape3 = toI64(globalShape[3]); - Value gStride0 = toI64(globalStride[0]); - Value gStride1 = toI64(globalStride[1]); - Value gStride2 = toI64(globalStride[2]); - Value gStride3 = toI64(globalStride[3]); - Value validColsI64 = materializeI64Value(validColsValue, validCols, rewriter, loc); - if (!gShape0 || !gShape1 || !gShape2 || !gShape3 || !gStride0 || !gStride1 || - !gStride2 || !gStride3 || !validColsI64) - return failure(); - - Value tileColsI64 = rewriter.create(loc, tileCols, 64); - Value elemBytesI64 = rewriter.create(loc, elemBytes, 64); - Value dstStride2 = createI64Mul(gShape3, tileColsI64, rewriter, loc); - Value dstStride1 = createI64Mul(gShape2, dstStride2, rewriter, loc); - Value dstStride0 = createI64Mul(gShape1, dstStride1, rewriter, loc); - - plan.outerCount = gShape0; - plan.outerSrcStrideElems = gStride0; - plan.outerDstStrideElems = dstStride0; - plan.loop2Size = gShape1; - plan.loop1Size = gShape2; - plan.loop2FirstStrideBytes = createI64Mul(dstStride1, elemBytesI64, rewriter, loc); - plan.loop2SecondStrideBytes = createI64Mul(gStride1, elemBytesI64, rewriter, loc); - plan.loop1FirstStrideBytes = createI64Mul(dstStride2, elemBytesI64, rewriter, loc); - plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); - plan.nBurst = gShape3; - plan.lenBurst = createI64Mul(validColsI64, elemBytesI64, rewriter, loc); - plan.firstStrideBytes = createI64Mul(gStride3, elemBytesI64, rewriter, loc); - plan.secondStrideBytes = createI64Mul(tileColsI64, elemBytesI64, rewriter, loc); - return success(); -} - -LogicalResult buildVecDnLoadPlan(ArrayRef shape, - ArrayRef strides, int64_t tileRows, - Value validRowsValue, int64_t validRows, - Type elementType, PatternRewriter &rewriter, - Location loc, VecNdTransferPlan &plan) { - if (tileRows == ShapedType::kDynamic) - return failure(); - int64_t elemBytes = getElementByteSize(elementType); - if (elemBytes <= 0) - return failure(); - - SmallVector globalShape; - SmallVector globalStride; - normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, - rewriter, loc); - - auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; - Value gShape0 = toI64(globalShape[0]); - Value gShape1 = toI64(globalShape[1]); - Value gShape2 = toI64(globalShape[2]); - Value gShape4 = toI64(globalShape[4]); - Value gStride0 = toI64(globalStride[0]); - Value gStride1 = toI64(globalStride[1]); - Value gStride2 = toI64(globalStride[2]); - Value gStride4 = toI64(globalStride[4]); - Value validRowsI64 = materializeI64Value(validRowsValue, validRows, rewriter, loc); - if (!gShape0 || !gShape1 || !gShape2 || !gShape4 || !gStride0 || !gStride1 || - !gStride2 || !gStride4 || !validRowsI64) - return failure(); - - Value tileRowsI64 = rewriter.create(loc, tileRows, 64); - Value elemBytesI64 = rewriter.create(loc, elemBytes, 64); - Value dstStride2 = createI64Mul(gShape4, tileRowsI64, rewriter, loc); - Value dstStride1 = createI64Mul(gShape2, dstStride2, rewriter, loc); - Value dstStride0 = createI64Mul(gShape1, dstStride1, rewriter, loc); - - plan.outerCount = gShape0; - plan.outerSrcStrideElems = gStride0; - plan.outerDstStrideElems = dstStride0; - plan.loop2Size = gShape1; - plan.loop1Size = gShape2; - plan.loop2FirstStrideBytes = createI64Mul(dstStride1, elemBytesI64, rewriter, loc); - plan.loop2SecondStrideBytes = createI64Mul(gStride1, elemBytesI64, rewriter, loc); - plan.loop1FirstStrideBytes = createI64Mul(dstStride2, elemBytesI64, rewriter, loc); - plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); - plan.nBurst = gShape4; - plan.lenBurst = createI64Mul(validRowsI64, elemBytesI64, rewriter, loc); - plan.firstStrideBytes = createI64Mul(gStride4, elemBytesI64, rewriter, loc); - plan.secondStrideBytes = createI64Mul(tileRowsI64, elemBytesI64, rewriter, loc); - return success(); -} - -LogicalResult buildVecNdStorePlan(ArrayRef shape, - ArrayRef strides, int64_t tileCols, - Value validColsValue, int64_t validCols, - Type elementType, PatternRewriter &rewriter, - Location loc, VecNdTransferPlan &plan) { - if (failed(buildVecNdLoadPlan(shape, strides, tileCols, validColsValue, validCols, - elementType, rewriter, loc, plan))) - return failure(); - std::swap(plan.outerSrcStrideElems, plan.outerDstStrideElems); - std::swap(plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); - std::swap(plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); - return success(); -} - -LogicalResult buildVecDnStorePlan(ArrayRef shape, - ArrayRef strides, int64_t tileRows, - Value validRowsValue, int64_t validRows, - Type elementType, PatternRewriter &rewriter, - Location loc, VecNdTransferPlan &plan) { - if (tileRows == ShapedType::kDynamic) - return failure(); - int64_t elemBytes = getElementByteSize(elementType); - if (elemBytes <= 0) - return failure(); - - SmallVector globalShape; - SmallVector globalStride; - normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, - rewriter, loc); - - auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; - Value gShape0 = toI64(globalShape[0]); - Value gShape1 = toI64(globalShape[1]); - Value gShape2 = toI64(globalShape[2]); - Value gShape4 = toI64(globalShape[4]); - Value gStride0 = toI64(globalStride[0]); - Value gStride1 = toI64(globalStride[1]); - Value gStride2 = toI64(globalStride[2]); - Value gStride4 = toI64(globalStride[4]); - Value validRowsI64 = materializeI64Value(validRowsValue, validRows, rewriter, loc); - if (!gShape0 || !gShape1 || !gShape2 || !gShape4 || !gStride0 || !gStride1 || - !gStride2 || !gStride4 || !validRowsI64) - return failure(); - - Value tileRowsI64 = rewriter.create(loc, tileRows, 64); - Value elemBytesI64 = rewriter.create(loc, elemBytes, 64); - Value outerSrcStride = - createI64Mul(createI64Mul(createI64Mul(gShape1, gShape2, rewriter, loc), - gShape4, rewriter, loc), - tileRowsI64, rewriter, loc); - Value loop1SrcStride = - createI64Mul(createI64Mul(tileRowsI64, gShape4, rewriter, loc), elemBytesI64, - rewriter, loc); - Value loop2SrcStride = - createI64Mul(createI64Mul(createI64Mul(gShape2, tileRowsI64, rewriter, loc), - gShape4, rewriter, loc), - elemBytesI64, rewriter, loc); - - plan.outerCount = gShape0; - plan.outerSrcStrideElems = outerSrcStride; - plan.outerDstStrideElems = gStride0; - plan.loop2Size = gShape1; - plan.loop1Size = gShape2; - plan.loop2FirstStrideBytes = loop2SrcStride; - plan.loop2SecondStrideBytes = createI64Mul(gStride1, elemBytesI64, rewriter, loc); - plan.loop1FirstStrideBytes = loop1SrcStride; - plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); - plan.nBurst = gShape4; - plan.lenBurst = createI64Mul(validRowsI64, elemBytesI64, rewriter, loc); - plan.firstStrideBytes = createI64Mul(gStride4, elemBytesI64, rewriter, loc); - plan.secondStrideBytes = createI64Mul(tileRowsI64, elemBytesI64, rewriter, loc); - return success(); -} - -StringRef stringifyTileLayout(TileBufType type) { - if (auto layoutAttr = dyn_cast_or_null(type.getBLayoutAttr())) { - switch (layoutAttr.getValue()) { - case BLayout::RowMajor: - return "row_major"; - case BLayout::ColMajor: - return "col_major"; - } - } - return "row_major"; -} - -StringRef stringifyTileLayoutConfig(TileBufConfigAttr config) { - if (!config) - return "row_major"; - if (auto layoutAttr = dyn_cast_or_null(config.getBLayout())) { - switch (layoutAttr.getValue()) { - case BLayout::RowMajor: - return "row_major"; - case BLayout::ColMajor: - return "col_major"; - } - } - return "row_major"; -} - -StringRef stringifyPadModeAttr(PadModeAttr padMode) { - if (!padMode) - return "none"; - - switch (padMode.getPadmode()) { - case PadMode::PadNull: - return "none"; - case PadMode::PadFirstElem: - return "first_elem"; - case PadMode::PadValue: - return "value"; - } - return "none"; -} - -StringRef stringifyLayoutAttr(Attribute layoutAttr) { - if (auto attr = dyn_cast_or_null(layoutAttr)) - return stringifyLayout(attr.getLayout()); - return "nd"; -} - -PipeAttr stringifyPipeAttr(PipeAttr pipe, PatternRewriter &rewriter) { - return PipeAttr::get(rewriter.getContext(), pipe.getPipe()); -} - -EventAttr stringifyEventAttr(EventAttr event, PatternRewriter &rewriter) { - return EventAttr::get(rewriter.getContext(), event.getEvent()); -} - -StringRef stringifyCmpModeAttr(CmpModeAttr cmpMode) { - if (!cmpMode) - return "eq"; - switch (cmpMode.getValue()) { - case CmpMode::EQ: - return "eq"; - case CmpMode::NE: - return "ne"; - case CmpMode::LT: - return "lt"; - case CmpMode::LE: - return "le"; - case CmpMode::GT: - return "gt"; - case CmpMode::GE: - return "ge"; - } - return "eq"; -} - -StringRef stringifyElementTypeFragment(Type type) { - if (!type) - return "unknown"; - if (type.isF16()) - return "f16"; - if (type.isBF16()) - return "bf16"; - if (type.isF32()) - return "f32"; - if (auto intType = dyn_cast(type)) { - if (intType.isUnsigned()) - switch (intType.getWidth()) { - case 8: - return "u8"; - case 16: - return "u16"; - case 32: - return "u32"; - case 64: - return "u64"; - default: - break; - } - switch (intType.getWidth()) { - case 8: - return "s8"; - case 16: - return "s16"; - case 32: - return "s32"; - case 64: - return "s64"; - default: - break; - } - } - return "unknown"; -} - -StringRef stringifyCopyTransferTypeFragment(Type type) { - switch (getElementByteSize(type)) { - case 1: - return "u8"; - case 2: - return "u16"; - case 4: - case 8: - return "u32"; - default: - return stringifyElementTypeFragment(type); - } -} - -static bool isSupportedPackedCmp32ElementType(Type type) { - if (!type) - return false; - if (type.isF32()) - return true; - auto intType = dyn_cast(type); - return intType && intType.getWidth() == 32; -} - -VPTOTileDomain deriveTileDomain(Attribute memorySpace) { - if (auto addrSpace = dyn_cast_or_null(memorySpace)) { - switch (addrSpace.getAddressSpace()) { - case AddressSpace::ACC: - return VPTOTileDomain::Acc; - case AddressSpace::MAT: - return VPTOTileDomain::Mat; - case AddressSpace::VEC: - default: - return VPTOTileDomain::Vec; - } - } - if (auto intAttr = dyn_cast_or_null(memorySpace)) { - switch (intAttr.getInt()) { - case static_cast(AddressSpace::ACC): - return VPTOTileDomain::Acc; - case static_cast(AddressSpace::MAT): - return VPTOTileDomain::Mat; - default: - return VPTOTileDomain::Vec; - } - } - return VPTOTileDomain::Vec; -} - -void getValidShape(TileBufType type, int64_t &rows, int64_t &cols) { - ArrayRef validShape = type.getValidShape(); - rows = validShape.size() > 0 ? validShape[0] : ShapedType::kDynamic; - cols = validShape.size() > 1 ? validShape[1] : ShapedType::kDynamic; -} - -static std::pair getIfResultYieldedValues(Value value) { - auto result = dyn_cast(value); - if (!result) - return {Value(), Value()}; - auto ifOp = dyn_cast(result.getOwner()); - if (!ifOp) - return {Value(), Value()}; - unsigned resultNumber = result.getResultNumber(); - auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); - auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); - if (!thenYield || !elseYield) - return {Value(), Value()}; - if (resultNumber >= thenYield.getNumOperands() || - resultNumber >= elseYield.getNumOperands()) - return {Value(), Value()}; - return {thenYield.getOperand(resultNumber), elseYield.getOperand(resultNumber)}; -} - -static bool equalOrBothNull(Value lhs, Value rhs) { - if (!lhs && !rhs) - return true; - if (!lhs || !rhs) - return false; - if (lhs == rhs) - return true; - auto lhsConst = getConstInt(lhs); - auto rhsConst = getConstInt(rhs); - return lhsConst && rhsConst && *lhsConst == *rhsConst; -} - -TileBufConfigAttr lookupTileConfig(Value value) { - if (!value) - return {}; - if (auto bind = value.getDefiningOp()) - return bind.getConfig(); - if (auto cast = value.getDefiningOp()) - return cast.getConfig().value_or(TileBufConfigAttr{}); - if (auto subview = value.getDefiningOp()) - return lookupTileConfig(subview.getSource()); - if (auto reinterpret = value.getDefiningOp()) - return lookupTileConfig(reinterpret.getSource()); - if (auto cast = value.getDefiningOp()) - return lookupTileConfig(cast.getSource()); - if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); - thenValue && elseValue) { - TileBufConfigAttr thenConfig = lookupTileConfig(thenValue); - TileBufConfigAttr elseConfig = lookupTileConfig(elseValue); - if (thenConfig && elseConfig && thenConfig == elseConfig) - return thenConfig; - } - return {}; -} - -bool hasStructuredTileDriver(Value value) { - if (!value) - return false; - if (isa(value.getType())) - return true; - if (value.getDefiningOp()) - return true; - if (auto subview = value.getDefiningOp()) - return hasStructuredTileDriver(subview.getSource()); - if (auto reinterpret = value.getDefiningOp()) - return hasStructuredTileDriver(reinterpret.getSource()); - if (auto cast = value.getDefiningOp()) - return hasStructuredTileDriver(cast.getSource()); - if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); - thenValue && elseValue) { - return hasStructuredTileDriver(thenValue) && hasStructuredTileDriver(elseValue); - } - return false; -} - -void lookupValidDims(Value value, Value &validRow, Value &validCol) { - if (!value) { - validRow = {}; - validCol = {}; - return; - } - if (auto bind = value.getDefiningOp()) { - validRow = bind.getValidRow(); - validCol = bind.getValidCol(); - return; - } - if (auto cast = value.getDefiningOp()) { - validRow = cast.getValidRow(); - validCol = cast.getValidCol(); - return; - } - if (auto subview = value.getDefiningOp()) { - lookupValidDims(subview.getSource(), validRow, validCol); - return; - } - if (auto reinterpret = value.getDefiningOp()) { - lookupValidDims(reinterpret.getSource(), validRow, validCol); - return; - } - if (auto cast = value.getDefiningOp()) { - lookupValidDims(cast.getSource(), validRow, validCol); - return; - } - if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); - thenValue && elseValue) { - Value thenRow; - Value thenCol; - Value elseRow; - Value elseCol; - lookupValidDims(thenValue, thenRow, thenCol); - lookupValidDims(elseValue, elseRow, elseCol); - validRow = equalOrBothNull(thenRow, elseRow) ? thenRow : Value(); - validCol = equalOrBothNull(thenCol, elseCol) ? thenCol : Value(); - return; - } - validRow = {}; - validCol = {}; -} - -Type getElementType(Value value) { - Type type = value.getType(); - if (auto tileType = dyn_cast(type)) - return tileType.getElementType(); - if (auto memrefType = dyn_cast(type)) - return memrefType.getElementType(); - if (auto ptrType = dyn_cast(type)) - return ptrType.getElementType(); - return {}; -} - -Attribute getMemorySpace(Value value) { - Type type = value.getType(); - if (auto tileType = dyn_cast(type)) - return tileType.getMemorySpace(); - if (auto memrefType = dyn_cast(type)) - return memrefType.getMemorySpace(); - if (auto ptrType = dyn_cast(type)) - return ptrType.getMemorySpace(); - return {}; -} - -StringRef deriveTileLayout(Value value) { - if (auto tileType = dyn_cast(value.getType())) - return stringifyTileLayout(tileType); - return stringifyTileLayoutConfig(lookupTileConfig(value)); -} - -void deriveValidShape(Value value, int64_t &rows, int64_t &cols) { - if (auto tileType = dyn_cast(value.getType())) { - getValidShape(tileType, rows, cols); - return; - } - - Value validRow; - Value validCol; - lookupValidDims(value, validRow, validCol); - rows = getConstInt(validRow).value_or(ShapedType::kDynamic); - cols = getConstInt(validCol).value_or(ShapedType::kDynamic); - if (rows != ShapedType::kDynamic && cols != ShapedType::kDynamic) - return; - if (!hasStructuredTileDriver(value)) - return; - - auto shapedType = dyn_cast(value.getType()); - if (!shapedType || !shapedType.hasRank()) - return; - - ArrayRef shape = shapedType.getShape(); - if (shape.empty()) { - if (rows == ShapedType::kDynamic) - rows = 1; - if (cols == ShapedType::kDynamic) - cols = 1; - return; - } - if (shape.size() == 1) { - if (rows == ShapedType::kDynamic) - rows = 1; - if (cols == ShapedType::kDynamic) - cols = shape.front(); - return; - } - - if (cols == ShapedType::kDynamic) - cols = shape.back(); - if (rows == ShapedType::kDynamic) { - int64_t flatRows = 1; - for (int64_t dim : shape.drop_back()) { - if (dim == ShapedType::kDynamic) { - flatRows = ShapedType::kDynamic; - break; - } - flatRows *= dim; - } - rows = flatRows; - } -} - -void deriveValidShapeValues(Value value, Value &rows, Value &cols) { - if (auto tileType = dyn_cast(value.getType())) { - ArrayRef validShape = tileType.getValidShape(); - rows = {}; - cols = {}; - (void)validShape; - lookupValidDims(value, rows, cols); - return; - } - lookupValidDims(value, rows, cols); -} - -void appendStaticSizes(ValueRange values, SmallVectorImpl &out, - bool &hasDynamic) { - out.clear(); - hasDynamic = false; - out.reserve(values.size()); - for (Value value : values) { - if (std::optional constant = getConstInt(value)) { - out.push_back(*constant); - continue; - } - out.push_back(ShapedType::kDynamic); - hasDynamic = true; - } -} - -int64_t getElementByteSize(Type type) { - if (auto floatType = dyn_cast(type)) - return (floatType.getWidth() + 7) / 8; - if (auto intType = dyn_cast(type)) - return (intType.getWidth() + 7) / 8; - return 0; -} - -Value materializeIndexValue(Value maybeValue, int64_t fallback, - PatternRewriter &rewriter, Location loc) { - if (maybeValue) - return maybeValue; - if (fallback != ShapedType::kDynamic) - return rewriter.create(loc, fallback); - return {}; -} - -Value materializeI64Value(Value maybeValue, int64_t fallback, - PatternRewriter &rewriter, Location loc) { - if (maybeValue) { - Type type = maybeValue.getType(); - if (type.isIndex()) - return rewriter.create(loc, rewriter.getI64Type(), maybeValue); - if (type.isInteger(64)) - return maybeValue; - if (auto intType = dyn_cast(type)) - return rewriter.create(loc, rewriter.getI64Type(), maybeValue); - } - if (fallback != ShapedType::kDynamic) - return rewriter.create(loc, fallback, 64); - return {}; -} - -void recordStaticValues(ValueRange values, SmallVectorImpl &out) { - out.clear(); - out.reserve(values.size()); - for (Value value : values) - out.push_back(getConstInt(value).value_or(ShapedType::kDynamic)); -} - -void recordStaticSizes(ArrayRef values, - SmallVectorImpl &out, bool &hasDynamic) { - out.clear(); - hasDynamic = false; - out.reserve(values.size()); - for (OpFoldResult value : values) { - if (auto attr = dyn_cast(value)) { - if (auto intAttr = dyn_cast(attr)) { - out.push_back(intAttr.getInt()); - continue; - } - } else if (std::optional constant = - getConstInt(cast(value))) { - out.push_back(*constant); - continue; - } - out.push_back(ShapedType::kDynamic); - hasDynamic = true; - } -} - -void mergeSubviewTrace(VPTOPartitionTrace &trace, ArrayRef offsets, - ArrayRef sizes, bool hasDynamicOffsets, - bool hasDynamicSizes) { - if (trace.offsets.empty()) { - trace.offsets.assign(offsets.begin(), offsets.end()); - trace.hasDynamicOffsets = hasDynamicOffsets; - } else { - size_t count = std::min(trace.offsets.size(), offsets.size()); - for (size_t i = 0; i < count; ++i) { - if (trace.offsets[i] == ShapedType::kDynamic || - offsets[i] == ShapedType::kDynamic) { - trace.offsets[i] = ShapedType::kDynamic; - trace.hasDynamicOffsets = true; - continue; - } - trace.offsets[i] += offsets[i]; - } - trace.hasDynamicOffsets = trace.hasDynamicOffsets || hasDynamicOffsets; - } - - trace.sizes.assign(sizes.begin(), sizes.end()); - trace.hasDynamicSizes = hasDynamicSizes; -} - -Value resolveTensorViewBase(Value value, Attribute &layoutAttr, - SmallVectorImpl &shape, - SmallVectorImpl &strides) { - if (!value) - return {}; - - if (auto part = value.getDefiningOp()) { - return resolveTensorViewBase(part.getSource(), layoutAttr, shape, strides); - } - - if (auto source = value.getDefiningOp()) { - layoutAttr = source.getLayoutAttr(); - auto tensorType = dyn_cast(source.getResult().getType()); - shape.assign(tensorType.getShape().begin(), tensorType.getShape().end()); - recordStaticValues(source.getStrides(), strides); - return source.getPtr(); - } - - if (auto subview = value.getDefiningOp()) { - Value base = - resolveTensorViewBase(subview.getSource(), layoutAttr, shape, strides); - if (shape.empty()) { - bool hasDynamicSizes = false; - recordStaticSizes(subview.getMixedSizes(), shape, hasDynamicSizes); - } - return base ? base : value; - } - - if (auto reinterpret = value.getDefiningOp()) { - if (Attribute layout = reinterpret->getAttr("layout")) - layoutAttr = layout; - if (shape.empty()) { - bool hasDynamicSizes = false; - recordStaticSizes(reinterpret.getMixedSizes(), shape, hasDynamicSizes); - } - if (strides.empty()) { - bool hasDynamicStrides = false; - recordStaticSizes(reinterpret.getMixedStrides(), strides, - hasDynamicStrides); - } - Value base = - resolveTensorViewBase(reinterpret.getSource(), layoutAttr, shape, strides); - return base ? base : value; - } - - if (auto cast = value.getDefiningOp()) { - Value base = - resolveTensorViewBase(cast.getSource(), layoutAttr, shape, strides); - return base ? base : value; - } - - if (auto memrefType = dyn_cast(value.getType())) { - if (shape.empty()) - shape.assign(memrefType.getShape().begin(), memrefType.getShape().end()); - if (strides.empty()) { - int64_t offset = 0; - if (failed(mlir::getStridesAndOffset(memrefType, strides, offset))) - strides.assign(shape.size(), ShapedType::kDynamic); - } - return value; - } - - return {}; -} - -pto::VRegType getVPTOVRegType(MLIRContext *context, Type elementType) { - unsigned bitWidth = 0; - if (auto floatType = dyn_cast(elementType)) - bitWidth = floatType.getWidth(); - else if (auto intType = dyn_cast(elementType)) - bitWidth = intType.getWidth(); - - if (bitWidth == 0 || 2048 % bitWidth != 0) - return {}; - return pto::VRegType::get(context, 2048 / bitWidth, elementType); -} - -pto::MaskType getVPTOMaskType(MLIRContext *context, StringRef granularity) { - return pto::MaskType::get(context, granularity); -} - -pto::MaskType getVPTOMaskTypeForElementType(MLIRContext *context, - Type elementType) { - unsigned bitWidth = 0; - if (auto floatType = dyn_cast(elementType)) - bitWidth = floatType.getWidth(); - else if (auto intType = dyn_cast(elementType)) - bitWidth = intType.getWidth(); - - switch (bitWidth) { - case 8: - return getVPTOMaskType(context, "b8"); - case 16: - return getVPTOMaskType(context, "b16"); - case 32: - return getVPTOMaskType(context, "b32"); - default: - return {}; - } -} - -ArrayAttr asI64ArrayAttr(Builder &builder, ArrayRef values) { - SmallVector attrs; - attrs.reserve(values.size()); - for (int64_t value : values) - attrs.push_back(builder.getI64IntegerAttr(value)); - return builder.getArrayAttr(attrs); -} - -void normalizeToPTOGlobalShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &globalShape, - SmallVectorImpl &globalStride) { - constexpr int64_t kRank = 5; - globalShape.assign(kRank, 1); - globalStride.assign(kRank, 1); - - size_t shapeRank = std::min(shape.size(), kRank); - size_t strideRank = std::min(strides.size(), kRank); - size_t rank = std::min(shapeRank, strideRank); - size_t base = kRank - rank; - - for (size_t i = 0; i < rank; ++i) { - globalShape[base + i] = shape[shape.size() - rank + i]; - globalStride[base + i] = strides[strides.size() - rank + i]; - } - - for (int i = static_cast(kRank) - 2; i >= 0; --i) { - if (i >= static_cast(base)) - continue; - if (globalStride[i + 1] == ShapedType::kDynamic || - globalShape[i + 1] == ShapedType::kDynamic) { - globalStride[i] = ShapedType::kDynamic; - continue; - } - globalStride[i] = globalStride[i + 1] * globalShape[i + 1]; - } -} - -int64_t packLoopStrideConfig(int64_t first, int64_t second) { - return (static_cast(first) << 40) | static_cast(second); -} - -int64_t packLoopSizeConfig(int64_t loop2, int64_t loop1) { - return (static_cast(loop2) << 21) | static_cast(loop1); -} - -LogicalResult deriveVecNDTransferConfig(ArrayRef shape, - ArrayRef strides, - StringRef tileLayout, Type elementType, - int64_t validRows, int64_t validCols, - SmallVectorImpl &globalShape, - SmallVectorImpl &globalStride, - int64_t &nBurst, int64_t &lenBurst, - int64_t &gmStrideBytes, - int64_t &ubStrideBytes, - int64_t &loop1Size, - int64_t &loop2Size, - int64_t &loop1FirstStrideBytes, - int64_t &loop1SecondStrideBytes, - int64_t &loop2FirstStrideBytes, - int64_t &loop2SecondStrideBytes) { - if (tileLayout != "row_major") - return failure(); - - int64_t elemBytes = getElementByteSize(elementType); - if (elemBytes <= 0) - return failure(); - - normalizeToPTOGlobalShapeAndStride(shape, strides, globalShape, globalStride); - if (globalShape.size() != 5 || globalStride.size() != 5) - return failure(); - if (llvm::any_of(globalShape, [](int64_t v) { return v == ShapedType::kDynamic; }) || - llvm::any_of(globalStride, [](int64_t v) { return v == ShapedType::kDynamic; })) - return failure(); - nBurst = globalShape[3]; - lenBurst = (validCols == ShapedType::kDynamic) ? ShapedType::kDynamic - : validCols * elemBytes; - gmStrideBytes = globalStride[3] * elemBytes; - ubStrideBytes = globalShape[4] * elemBytes; - - int64_t dstStride2 = globalShape[3] * validCols; - int64_t dstStride1 = globalShape[2] * dstStride2; - - loop2Size = globalShape[1]; - loop1Size = globalShape[2]; - loop2FirstStrideBytes = dstStride1 * elemBytes; - loop2SecondStrideBytes = globalStride[1] * elemBytes; - loop1FirstStrideBytes = dstStride2 * elemBytes; - loop1SecondStrideBytes = globalStride[2] * elemBytes; - return success(); -} - -std::pair getStaticTileRowsCols(Value value) { - if (auto shapedType = dyn_cast(value.getType())) { - ArrayRef shape = shapedType.getShape(); - if (shape.size() >= 2) - return {shape[shape.size() - 2], shape[shape.size() - 1]}; - } - return {ShapedType::kDynamic, ShapedType::kDynamic}; -} - -Value materializeStaticOrDynamicDimAsIndex(Value value, int64_t dim, - unsigned dimPos, - PatternRewriter &rewriter, - Location loc) { - if (dim != ShapedType::kDynamic) - return rewriter.create(loc, dim); - if (isa(value.getType())) - return rewriter.create(loc, value, dimPos); - return {}; -} - -LogicalResult materializeShapeBackedValidShapeValues(Value value, Value &rows, - Value &cols, - PatternRewriter &rewriter, - Location loc) { - rows = {}; - cols = {}; - - auto shapedType = dyn_cast(value.getType()); - if (!shapedType || !shapedType.hasRank() || !hasStructuredTileDriver(value)) - return failure(); - - ArrayRef shape = shapedType.getShape(); - if (shape.empty()) { - rows = rewriter.create(loc, 1); - cols = rewriter.create(loc, 1); - return success(); - } - if (shape.size() == 1) { - rows = rewriter.create(loc, 1); - cols = materializeStaticOrDynamicDimAsIndex(value, shape.front(), 0, rewriter, loc); - return success(cols != nullptr); - } - - cols = materializeStaticOrDynamicDimAsIndex(value, shape.back(), shape.size() - 1, - rewriter, loc); - if (!cols) - return failure(); - - Value flatRows = rewriter.create(loc, 1); - for (auto [idx, dim] : llvm::enumerate(shape.drop_back())) { - Value dimValue = - materializeStaticOrDynamicDimAsIndex(value, dim, idx, rewriter, loc); - if (!dimValue) - return failure(); - flatRows = rewriter.create(loc, flatRows, dimValue); - } - rows = flatRows; - return success(); -} - -LogicalResult resolveExecutionValidShape(Value carrier, Value &rowsValue, - Value &colsValue, int64_t &rows, - int64_t &cols, - PatternRewriter &rewriter, - Location loc) { - rowsValue = materializeIndexValue(rowsValue, rows, rewriter, loc); - colsValue = materializeIndexValue(colsValue, cols, rewriter, loc); - if (rowsValue && colsValue) - return success(); - - if (succeeded(materializeShapeBackedValidShapeValues(carrier, rowsValue, colsValue, - rewriter, loc))) { - deriveValidShape(carrier, rows, cols); - return success(rowsValue && colsValue); - } - return failure(); -} - -Attribute getGmMemorySpace(MLIRContext *context) { - return AddressSpaceAttr::get(context, AddressSpace::GM); -} - -AddressSpaceAttr getNormalizedPtrMemorySpace(Attribute memorySpace, - MLIRContext *context) { - if (auto addrSpace = dyn_cast_or_null(memorySpace)) - return addrSpace; - if (auto intAttr = dyn_cast_or_null(memorySpace)) - return AddressSpaceAttr::get(context, - static_cast(intAttr.getInt())); - return AddressSpaceAttr::get(context, AddressSpace::GM); -} - -Value materializeMemRefView(Value value, ArrayRef shape, Type elementType, - Attribute memorySpace, - PatternRewriter &rewriter, Location loc) { - auto memrefType = - MemRefType::get(shape, elementType, AffineMap(), memorySpace); - if (value.getType() == memrefType) - return value; - return rewriter - .create( - loc, TypeRange(ArrayRef{memrefType}), value) - .getResult(0); -} - -Value materializeTileBufferView(Value value, PatternRewriter &rewriter, - Location loc) { - if (auto memrefType = dyn_cast(value.getType())) - return value; - - auto tileType = dyn_cast(value.getType()); - if (!tileType) - return {}; - - return materializeMemRefView(value, tileType.getShape(), tileType.getElementType(), - tileType.getMemorySpace(), rewriter, loc); -} - -} // namespace - -Value materializeBufferPointer(Value value, Type elementType, - Attribute memorySpace, - PatternRewriter &rewriter, Location loc) { - if (!value) - return {}; - - auto ptrMemorySpace = - getNormalizedPtrMemorySpace(memorySpace, rewriter.getContext()); - auto ptrType = PtrType::get(rewriter.getContext(), elementType, ptrMemorySpace); - - if (value.getType() == ptrType) - return value; - - if (auto bind = value.getDefiningOp()) - return materializeBufferPointer(bind.getSource(), elementType, memorySpace, - rewriter, loc); - - if (auto cast = value.getDefiningOp()) { - if (cast.getAddrs().empty()) - return {}; - return rewriter.create(loc, ptrType, cast.getAddrs().front()) - .getResult(); - } - - Value memrefValue = materializeTileBufferView(value, rewriter, loc); - auto memrefType = dyn_cast_or_null(memrefValue.getType()); - if (!memrefValue || !memrefType) - return {}; - return rewriter.create(loc, ptrType, memrefValue).getResult(); -} - -namespace { - -Value materializeBufferLikeAddress(Value value, Type elementType, - Attribute memorySpace, - PatternRewriter &rewriter, Location loc) { - if (!value) - return {}; - - if (auto bind = value.getDefiningOp()) - return materializeBufferLikeAddress(bind.getSource(), elementType, memorySpace, - rewriter, loc); - - // Keep memref semantics through the VPTO mainline whenever possible. - Value memrefValue = materializeTileBufferView(value, rewriter, loc); - if (memrefValue && isa(memrefValue.getType())) - return memrefValue; - - return materializeBufferPointer(value, elementType, memorySpace, rewriter, loc); -} - -Value offsetBufferPointer(Value basePtr, Type elementType, Value elementOffset, - PatternRewriter &rewriter, Location loc) { - if (!basePtr) - return {}; - - if (auto ptrType = dyn_cast(basePtr.getType())) { - Value offsetIndex = - elementOffset.getType().isIndex() - ? elementOffset - : rewriter.create(loc, - rewriter.getIndexType(), - elementOffset); - return rewriter.create(loc, ptrType, basePtr, offsetIndex); - } - return {}; -} - -Value buildPackedCountI64(PatternRewriter &rewriter, Location loc, - ArrayRef counts) { - Value packed = rewriter.create(loc, 0, 64); - for (auto [idx, count] : llvm::enumerate(counts)) { - Value countI64 = count.getType().isIndex() - ? rewriter.create( - loc, rewriter.getI64Type(), count) - : count; - if (idx != 0) { - Value shift = rewriter.create(loc, idx * 16, 64); - countI64 = rewriter.create(loc, countI64, shift); - } - packed = rewriter.create(loc, packed, countI64); - } - return packed; -} - -Value buildCeilDivPositiveI64(PatternRewriter &rewriter, Location loc, Value lhs, - int64_t rhs) { - Value rhsValue = rewriter.create(loc, rhs, 64); - Value rhsMinusOne = rewriter.create(loc, rhs - 1, 64); - Value biased = rewriter.create(loc, lhs, rhsMinusOne); - return rewriter.create(loc, biased, rhsValue); -} - -VPTOPartitionTrace extractPartitionTrace(Value value) { - VPTOPartitionTrace trace; - if (auto part = value.getDefiningOp()) { - appendStaticSizes(part.getOffsets(), trace.offsets, trace.hasDynamicOffsets); - appendStaticSizes(part.getSizes(), trace.sizes, trace.hasDynamicSizes); - return trace; - } - if (auto subview = value.getDefiningOp()) { - trace = extractPartitionTrace(subview.getSource()); - SmallVector offsets; - SmallVector sizes; - bool hasDynamicOffsets = false; - bool hasDynamicSizes = false; - recordStaticSizes(subview.getMixedOffsets(), offsets, hasDynamicOffsets); - recordStaticSizes(subview.getMixedSizes(), sizes, hasDynamicSizes); - mergeSubviewTrace(trace, offsets, sizes, hasDynamicOffsets, hasDynamicSizes); - return trace; - } - if (auto reinterpret = value.getDefiningOp()) - return extractPartitionTrace(reinterpret.getSource()); - if (auto cast = value.getDefiningOp()) - return extractPartitionTrace(cast.getSource()); - if (auto unrealized = value.getDefiningOp()) { - if (!unrealized.getInputs().empty()) - return extractPartitionTrace(unrealized.getInputs().front()); - } - return trace; -} - -VPTOLoadContract extractTLoadContract(TLoadOp op) { - VPTOLoadContract contract; - contract.trace = extractPartitionTrace(op.getSrc()); - contract.elementType = getElementType(op.getDst()); - - Attribute layoutAttr; - Value base = resolveTensorViewBase(op.getSrc(), layoutAttr, contract.sourceShape, - contract.sourceStrides); - (void)base; - contract.sourceLayout = stringifyLayoutAttr(layoutAttr); - - contract.tileLayout = deriveTileLayout(op.getDst()); - contract.tileDomain = deriveTileDomain(getMemorySpace(op.getDst())); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - contract.padMode = stringifyPadModeAttr(op.getPadModeAttr()); - contract.padValue = op.getPadValue(); - contract.leftPaddingNum = op.getLeftPaddingNum(); - contract.rightPaddingNum = op.getRightPaddingNum(); - contract.initOutBuffer = op.getInitOutBuffer(); - contract.initCondition = op.getInitCondition(); - return contract; -} - -VPTOUnaryContract extractTAbsContract(TAbsOp op) { - VPTOUnaryContract contract; - contract.family = "abs"; - contract.tileDomain = deriveTileDomain(getMemorySpace(op.getSrc())); - contract.tileLayout = deriveTileLayout(op.getSrc()); - deriveValidShapeValues(op.getSrc(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getSrc(), contract.validRows, contract.validCols); - contract.elementType = getElementType(op.getSrc()); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -VPTOBinaryContract buildBinaryContract(StringRef family, Value src0) { - VPTOBinaryContract contract; - contract.family = family; - contract.tileDomain = deriveTileDomain(getMemorySpace(src0)); - contract.tileLayout = deriveTileLayout(src0); - deriveValidShapeValues(src0, contract.validRowsValue, contract.validColsValue); - deriveValidShape(src0, contract.validRows, contract.validCols); - contract.elementType = getElementType(src0); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -VPTOBinaryContract extractTAddContract(TAddOp op) { - return buildBinaryContract("add", op.getSrc0()); -} - -VPTOBinaryContract extractTSubContract(TSubOp op) { - return buildBinaryContract("sub", op.getSrc0()); -} - -VPTOBinaryContract extractTMulContract(TMulOp op) { - return buildBinaryContract("mul", op.getSrc0()); -} - -VPTOBinaryContract extractTDivContract(TDivOp op) { - return buildBinaryContract("div", op.getSrc0()); -} - -VPTOUnaryContract buildUnaryContract(StringRef family, Value src) { - VPTOUnaryContract contract; - contract.family = family; - contract.tileDomain = deriveTileDomain(getMemorySpace(src)); - contract.tileLayout = deriveTileLayout(src); - deriveValidShapeValues(src, contract.validRowsValue, contract.validColsValue); - deriveValidShape(src, contract.validRows, contract.validCols); - contract.elementType = getElementType(src); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -static bool isCompatibleScalarForSemanticType(Type semanticType, - Type scalarType) { - if (semanticType == scalarType) - return true; - - auto semanticInt = dyn_cast(semanticType); - auto scalarInt = dyn_cast(scalarType); - if (!semanticInt || !scalarInt || semanticInt.getWidth() != scalarInt.getWidth()) - return false; - - if (semanticInt.isSigned()) - return scalarInt.isSigned() || scalarInt.isSignless(); - if (semanticInt.isUnsigned()) - return scalarInt.isUnsigned() || scalarInt.isSignless(); - return scalarInt.isSignless(); -} - -VPTOUnaryContract extractTExpContract(TExpOp op) { - return buildUnaryContract("exp", op.getSrc()); -} - -VPTOUnaryContract extractTLogContract(TLogOp op) { - return buildUnaryContract("log", op.getSrc()); -} - -VPTOUnaryContract extractTSqrtContract(TSqrtOp op) { - return buildUnaryContract("sqrt", op.getSrc()); -} - -VPTOUnaryContract extractTRecipContract(TRecipOp op) { - return buildUnaryContract("recip", op.getSrc()); -} - -VPTOUnaryContract extractTReluContract(TReluOp op) { - return buildUnaryContract("relu", op.getSrc()); -} - -VPTOUnaryContract extractTNotContract(TNotOp op) { - return buildUnaryContract("not", op.getSrc()); -} - -static FailureOr stringifyA5RoundMode(TCvtOp op, - PatternRewriter &rewriter) { - switch (op.getRmode()) { - case RoundMode::NONE: - case RoundMode::RINT: - case RoundMode::CAST_RINT: - return rewriter.getStringAttr("ROUND_R"); - case RoundMode::ROUND: - return rewriter.getStringAttr("ROUND_A"); - case RoundMode::FLOOR: - return rewriter.getStringAttr("ROUND_F"); - case RoundMode::CEIL: - return rewriter.getStringAttr("ROUND_C"); - case RoundMode::TRUNC: - return rewriter.getStringAttr("ROUND_Z"); - case RoundMode::ODD: - return rewriter.getStringAttr("ROUND_O"); - } - return failure(); -} - -enum class VPTOCvtLoweringKind { - Vtrc, - F32ToBF16, - F16ToF32, - BF16ToF16, - BF16ToF32, -}; - -static FailureOr classifyA5CvtLowering(Type srcElemType, - Type dstElemType) { - if (srcElemType.isF32() && dstElemType.isF32()) - return VPTOCvtLoweringKind::Vtrc; - if (srcElemType.isF32() && dstElemType.isBF16()) - return VPTOCvtLoweringKind::F32ToBF16; - if (srcElemType.isF16() && dstElemType.isF32()) - return VPTOCvtLoweringKind::F16ToF32; - if (srcElemType.isBF16() && dstElemType.isF16()) - return VPTOCvtLoweringKind::BF16ToF16; - if (srcElemType.isBF16() && dstElemType.isF32()) - return VPTOCvtLoweringKind::BF16ToF32; - return failure(); -} - -VPTOUnaryContract extractTExpandSContract(TExpandsOp op) { - VPTOUnaryContract contract; - contract.family = "expands"; - contract.tileDomain = deriveTileDomain(getMemorySpace(op.getDst())); - contract.tileLayout = deriveTileLayout(op.getDst()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, - contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - contract.elementType = getElementType(op.getDst()); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -VPTOExpandContract extractTRowExpandContract(TRowExpandOp op) { - VPTOExpandContract contract; - contract.family = "rowexpand"; - contract.srcDomain = deriveTileDomain(getMemorySpace(op.getSrc())); - contract.dstDomain = deriveTileDomain(getMemorySpace(op.getDst())); - contract.srcLayout = deriveTileLayout(op.getSrc()); - contract.dstLayout = deriveTileLayout(op.getDst()); - contract.elementType = getElementType(op.getSrc()); - deriveValidShapeValues(op.getSrc(), contract.srcValidRowsValue, - contract.srcValidColsValue); - deriveValidShape(op.getSrc(), contract.srcValidRows, contract.srcValidCols); - deriveValidShapeValues(op.getDst(), contract.dstValidRowsValue, - contract.dstValidColsValue); - deriveValidShape(op.getDst(), contract.dstValidRows, contract.dstValidCols); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -VPTOExpandContract extractTColExpandContract(TColExpandOp op) { - VPTOExpandContract contract; - contract.family = "colexpand"; - contract.srcDomain = deriveTileDomain(getMemorySpace(op.getSrc())); - contract.dstDomain = deriveTileDomain(getMemorySpace(op.getDst())); - contract.srcLayout = deriveTileLayout(op.getSrc()); - contract.dstLayout = deriveTileLayout(op.getDst()); - contract.elementType = getElementType(op.getSrc()); - deriveValidShapeValues(op.getSrc(), contract.srcValidRowsValue, - contract.srcValidColsValue); - deriveValidShape(op.getSrc(), contract.srcValidRows, contract.srcValidCols); - deriveValidShapeValues(op.getDst(), contract.dstValidRowsValue, - contract.dstValidColsValue); - deriveValidShape(op.getDst(), contract.dstValidRows, contract.dstValidCols); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -VPTORowReduceContract extractTRowReduceContract(Value src, Value dst, - StringRef family) { - VPTORowReduceContract contract; - contract.family = family; - contract.srcDomain = deriveTileDomain(getMemorySpace(src)); - contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); - contract.srcLayout = deriveTileLayout(src); - contract.dstLayout = deriveTileLayout(dst); - contract.elementType = getElementType(src); - deriveValidShapeValues(src, contract.validRowsValue, contract.validColsValue); - deriveValidShape(src, contract.validRows, contract.validCols); - int64_t dstRows = ShapedType::kDynamic; - deriveValidShape(dst, dstRows, contract.dstValidCols); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -VPTORowReduceContract extractTRowMaxContract(TRowMaxOp op) { - return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowmax"); -} - -VPTORowReduceContract extractTRowMinContract(TRowMinOp op) { - return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowmin"); -} - -VPTORowReduceContract extractTRowSumContract(TRowSumOp op) { - return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowsum"); -} - -VPTOColReduceContract extractTColReduceContract(Value src, Value dst, - StringRef family) { - VPTOColReduceContract contract; - contract.family = family; - contract.srcDomain = deriveTileDomain(getMemorySpace(src)); - contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); - contract.srcLayout = deriveTileLayout(src); - contract.dstLayout = deriveTileLayout(dst); - contract.elementType = getElementType(src); - deriveValidShapeValues(src, contract.validRowsValue, contract.validColsValue); - deriveValidShape(src, contract.validRows, contract.validCols); - deriveValidShape(dst, contract.dstValidRows, contract.dstValidCols); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -VPTOColReduceContract extractTColMaxContract(TColMaxOp op) { - return extractTColReduceContract(op.getSrc(), op.getDst(), "colmax"); -} - -VPTOColReduceContract extractTColMinContract(TColMinOp op) { - return extractTColReduceContract(op.getSrc(), op.getDst(), "colmin"); -} - -VPTOColReduceContract extractTColSumContract(TColSumOp op) { - VPTOColReduceContract contract = - extractTColReduceContract(op.getSrc(), op.getDst(), "colsum"); - contract.isBinary = op.getIsBinary(); - contract.tmp = op.getTmp(); - return contract; -} - -VPTOPartContract extractTPartContract(Value src0, Value src1, Value dst, - StringRef family) { - VPTOPartContract contract; - contract.family = family; - contract.src0Domain = deriveTileDomain(getMemorySpace(src0)); - contract.src1Domain = deriveTileDomain(getMemorySpace(src1)); - contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); - contract.src0Layout = deriveTileLayout(src0); - contract.src1Layout = deriveTileLayout(src1); - contract.dstLayout = deriveTileLayout(dst); - contract.elementType = getElementType(dst); - deriveValidShapeValues(src0, contract.src0ValidRowsValue, contract.src0ValidColsValue); - deriveValidShapeValues(src1, contract.src1ValidRowsValue, contract.src1ValidColsValue); - deriveValidShapeValues(dst, contract.dstValidRowsValue, contract.dstValidColsValue); - deriveValidShape(src0, contract.src0ValidRows, contract.src0ValidCols); - deriveValidShape(src1, contract.src1ValidRows, contract.src1ValidCols); - deriveValidShape(dst, contract.dstValidRows, contract.dstValidCols); - contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; - contract.loopScope.loopDepth = 0; - return contract; -} - -VPTOPartContract extractTPartAddContract(TPartAddOp op) { - return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partadd"); -} - -VPTOPartContract extractTPartMaxContract(TPartMaxOp op) { - return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partmax"); -} - -VPTOPartContract extractTPartMinContract(TPartMinOp op) { - return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partmin"); -} - -VPTOStoreContract extractTStoreContract(TStoreOp op) { - VPTOStoreContract contract; - contract.trace = extractPartitionTrace(op.getDst()); - - contract.srcDomain = deriveTileDomain(getMemorySpace(op.getSrc())); - deriveValidShapeValues(op.getSrc(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getSrc(), contract.validRows, contract.validCols); - contract.elementType = getElementType(op.getSrc()); - - Attribute layoutAttr; - Value base = resolveTensorViewBase(op.getDst(), layoutAttr, - contract.destinationShape, - contract.destinationStrides); - (void)base; - contract.destinationLayout = stringifyLayoutAttr(layoutAttr); - return contract; -} - -void attachLoadContractAttrs(Operation *op, const VPTOLoadContract &contract) { - Builder builder(op->getContext()); - SmallVector globalShape; - SmallVector globalStride; - normalizeToPTOGlobalShapeAndStride(contract.sourceShape, contract.sourceStrides, - globalShape, globalStride); - op->setAttr("g_shape", asI64ArrayAttr(builder, globalShape)); - op->setAttr("g_strides", asI64ArrayAttr(builder, globalStride)); -} - -void attachStoreContractAttrs(Operation *op, const VPTOStoreContract &contract) { - Builder builder(op->getContext()); - SmallVector globalShape; - SmallVector globalStride; - normalizeToPTOGlobalShapeAndStride(contract.destinationShape, - contract.destinationStrides, globalShape, - globalStride); - op->setAttr("g_shape", asI64ArrayAttr(builder, globalShape)); - op->setAttr("g_strides", asI64ArrayAttr(builder, globalStride)); -} - -LogicalResult lowerUnsupportedAccStore(Location loc) { - emitError(loc) << "TSTORE ACC lowering TODO for vpto backend"; - return failure(); -} - -LogicalResult lowerUnsupportedMatStore(Location loc) { - emitError(loc) << "TSTORE MAT lowering TODO for vpto backend"; - return failure(); -} - -} // namespace - -FailureOr -createLoopScopeRegion(Location loc, const VPTOLoopScopeContract &contract, - PatternRewriter &rewriter) { - if (contract.kind == VPTOLoopScopeKind::None) - return failure(); - if (contract.kind != VPTOLoopScopeKind::AIVVectorScope) - return failure(); - - auto vecScope = rewriter.create(loc); - vecScope.getBody().push_back(new Block()); - return vecScope; -} - -void set_loop2_stride_outtoub(Operation *copyOp, int64_t dstStride, - int64_t srcStride, Builder &builder) { - copyOp->setAttr("pto.set_loop2_stride_outtoub", - builder.getI64IntegerAttr( - packLoopStrideConfig(dstStride, srcStride))); -} - -void set_loop1_stride_outtoub(Operation *copyOp, int64_t dstStride, - int64_t srcStride, Builder &builder) { - copyOp->setAttr("pto.set_loop1_stride_outtoub", - builder.getI64IntegerAttr( - packLoopStrideConfig(dstStride, srcStride))); -} - -void set_loop_size_outtoub(Operation *copyOp, int64_t loop2, int64_t loop1, - Builder &builder) { - copyOp->setAttr("pto.set_loop_size_outtoub", - builder.getI64IntegerAttr(packLoopSizeConfig(loop2, loop1))); -} - -void set_loop2_stride_ubtoout(Operation *copyOp, int64_t srcStride, - int64_t dstStride, Builder &builder) { - copyOp->setAttr("pto.set_loop2_stride_ubtoout", - builder.getI64IntegerAttr( - packLoopStrideConfig(srcStride, dstStride))); -} - -void set_loop1_stride_ubtoout(Operation *copyOp, int64_t srcStride, - int64_t dstStride, Builder &builder) { - copyOp->setAttr("pto.set_loop1_stride_ubtoout", - builder.getI64IntegerAttr( - packLoopStrideConfig(srcStride, dstStride))); -} - -void set_loop_size_ubtoout(Operation *copyOp, int64_t loop2, int64_t loop1, - Builder &builder) { - copyOp->setAttr("pto.set_loop_size_ubtoout", - builder.getI64IntegerAttr(packLoopSizeConfig(loop2, loop1))); -} - -LogicalResult programCopyGmToUbLoops(Operation *copyOp, - const VPTOLoadContract &contract, - Builder &builder) { - SmallVector globalShape; - SmallVector globalStride; - int64_t nBurst = 0, lenBurst = 0, gmStrideBytes = 0, ubStrideBytes = 0; - int64_t loop1Size = 0, loop2Size = 0; - int64_t loop1DstStrideBytes = 0, loop1SrcStrideBytes = 0; - int64_t loop2DstStrideBytes = 0, loop2SrcStrideBytes = 0; - if (failed(deriveVecNDTransferConfig(contract.sourceShape, contract.sourceStrides, - contract.tileLayout, contract.elementType, - contract.validRows, contract.validCols, - globalShape, globalStride, nBurst, lenBurst, - gmStrideBytes, ubStrideBytes, loop1Size, - loop2Size, loop1DstStrideBytes, - loop1SrcStrideBytes, loop2DstStrideBytes, - loop2SrcStrideBytes))) - return failure(); - - set_loop2_stride_outtoub(copyOp, loop2DstStrideBytes, loop2SrcStrideBytes, builder); - set_loop1_stride_outtoub(copyOp, loop1DstStrideBytes, loop1SrcStrideBytes, builder); - set_loop_size_outtoub(copyOp, loop2Size, loop1Size, builder); - return success(); -} - -LogicalResult programCopyUbToGmLoops(Operation *copyOp, - const VPTOStoreContract &contract, - Builder &builder) { - SmallVector globalShape; - SmallVector globalStride; - int64_t nBurst = 0, lenBurst = 0, burstDstStrideBytes = 0, burstSrcStrideBytes = 0; - int64_t loop1Size = 0, loop2Size = 0; - int64_t loop1SrcStrideBytes = 0, loop1DstStrideBytes = 0; - int64_t loop2SrcStrideBytes = 0, loop2DstStrideBytes = 0; - if (failed(deriveVecNDTransferConfig(contract.destinationShape, - contract.destinationStrides, - "row_major", contract.elementType, - contract.validRows, contract.validCols, - globalShape, globalStride, nBurst, lenBurst, - burstDstStrideBytes, burstSrcStrideBytes, - loop1Size, loop2Size, loop1SrcStrideBytes, - loop1DstStrideBytes, loop2SrcStrideBytes, - loop2DstStrideBytes))) - return failure(); - - set_loop_size_ubtoout(copyOp, loop2Size, loop1Size, builder); - set_loop1_stride_ubtoout(copyOp, loop1SrcStrideBytes, loop1DstStrideBytes, builder); - set_loop2_stride_ubtoout(copyOp, loop2SrcStrideBytes, loop2DstStrideBytes, builder); - return success(); -} - -int64_t deriveStaticRowStride(Value value) { - StringRef layout = deriveTileLayout(value); - if (layout == "col_major") - return 1; - - if (auto tileType = dyn_cast(value.getType())) { - ArrayRef shape = tileType.getShape(); - if (shape.size() >= 2) - return shape[shape.size() - 1]; - } - if (auto shapedType = dyn_cast(value.getType())) { - ArrayRef shape = shapedType.getShape(); - if (shape.size() >= 2) - return shape[shape.size() - 1]; - } - return ShapedType::kDynamic; -} - -int64_t deriveStaticShapeDim(Value value, unsigned dim) { - if (auto tileType = dyn_cast(value.getType())) { - ArrayRef shape = tileType.getShape(); - if (dim < shape.size()) - return shape[dim]; - } - if (auto shapedType = dyn_cast(value.getType())) { - ArrayRef shape = shapedType.getShape(); - if (dim < shape.size()) - return shape[dim]; - } - return ShapedType::kDynamic; -} - -int64_t deriveStaticTileCols(Value value) { - if (auto tileType = dyn_cast(value.getType())) { - ArrayRef shape = tileType.getShape(); - if (!shape.empty()) - return shape.back(); - } - if (auto shapedType = dyn_cast(value.getType())) { - ArrayRef shape = shapedType.getShape(); - if (!shape.empty()) - return shape.back(); - } - return ShapedType::kDynamic; -} - -Value buildFullWidthColsCondition(ArrayRef tileCols, - Value validColsValue, - PatternRewriter &rewriter, Location loc) { - Value condition; - for (int64_t tileCol : tileCols) { - if (tileCol == ShapedType::kDynamic) - return {}; - Value tileColValue = rewriter.create(loc, tileCol); - Value isFullWidth = rewriter.create( - loc, arith::CmpIPredicate::eq, validColsValue, tileColValue); - condition = condition ? rewriter.create(loc, condition, isFullWidth) - : isFullWidth; - } - return condition; -} - -Value buildMinIndexValue(PatternRewriter &rewriter, Location loc, Value lhs, - Value rhs) { - auto lhsLtRhs = rewriter.create(loc, arith::CmpIPredicate::slt, - lhs, rhs); - return rewriter.create(loc, lhsLtRhs, lhs, rhs); -} - -struct PredicateMaterialization { - Value mask; - Value nextScalar; -}; - -PredicateMaterialization buildPredicateForLaneCount(PatternRewriter &rewriter, - Location loc, - Type elementType, - Value laneCount) { - auto maskType = getVPTOMaskTypeForElementType(rewriter.getContext(), elementType); - Value laneCountI32 = laneCount; - if (laneCount.getType().isIndex()) { - laneCountI32 = - rewriter.create(loc, rewriter.getI32Type(), laneCount); - } else if (auto intType = dyn_cast(laneCount.getType())) { - if (intType.getWidth() < 32) - laneCountI32 = rewriter.create(loc, rewriter.getI32Type(), laneCount); - else if (intType.getWidth() > 32) - laneCountI32 = - rewriter.create(loc, rewriter.getI32Type(), laneCount); - } - unsigned bitWidth = 0; - if (auto intType = dyn_cast(elementType)) - bitWidth = intType.getWidth(); - else if (auto floatType = dyn_cast(elementType)) - bitWidth = floatType.getWidth(); - if (bitWidth == 8) { - auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), - laneCountI32); - return {plt.getMask(), plt.getScalarOut()}; - } - if (bitWidth == 16) { - auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), - laneCountI32); - return {plt.getMask(), plt.getScalarOut()}; - } - if (bitWidth == 32) { - auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), - laneCountI32); - return {plt.getMask(), plt.getScalarOut()}; - } - llvm_unreachable("unsupported element type for predicate lane-count lowering"); -} - -Value buildPredicateMaskForLaneCount(PatternRewriter &rewriter, Location loc, - Type elementType, Value laneCount) { - return buildPredicateForLaneCount(rewriter, loc, elementType, laneCount).mask; -} - -Value buildAllPredicateMask(PatternRewriter &rewriter, Location loc, - Type elementType) { - auto maskType = getVPTOMaskTypeForElementType(rewriter.getContext(), elementType); - StringAttr allPattern = rewriter.getStringAttr("PAT_ALL"); - unsigned bitWidth = 0; - if (auto intType = dyn_cast(elementType)) - bitWidth = intType.getWidth(); - else if (auto floatType = dyn_cast(elementType)) - bitWidth = floatType.getWidth(); - if (bitWidth == 8) - return rewriter.create(loc, maskType, allPattern).getResult(); - if (bitWidth == 16) - return rewriter.create(loc, maskType, allPattern).getResult(); - if (bitWidth == 32) - return rewriter.create(loc, maskType, allPattern).getResult(); - llvm_unreachable("unsupported element type for full predicate mask lowering"); -} - -LogicalResult buildMaskedVectorStore(PatternRewriter &rewriter, Location loc, - Value value, Value dstBuffer, - Value dstOffset, Value activeLanes, - int64_t vectorWidth) { - auto vecType = cast(value.getType()); - Value mask = buildPredicateMaskForLaneCount(rewriter, loc, - vecType.getElementType(), - activeLanes); - rewriter.create(loc, value, dstBuffer, dstOffset, StringAttr(), - mask); - return success(); -} - -Attribute buildRowReduceInitValue(Type elementType, StringRef family, - Builder &builder) { - if (!isa(elementType)) - return {}; - - if (family == "rowsum") - return builder.getFloatAttr(elementType, 0.0); - - const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { - if (elementType.isF16()) - return llvm::APFloat::IEEEhalf(); - if (elementType.isBF16()) - return llvm::APFloat::BFloat(); - return llvm::APFloat::IEEEsingle(); - }(); - bool negative = family == "rowmax"; - return builder.getFloatAttr(elementType, llvm::APFloat::getInf(semantics, negative)); -} - -Attribute buildPartPadValue(Type elementType, StringRef family, Builder &builder) { - if (family == "partadd") - return builder.getZeroAttr(elementType); - if (isa(elementType)) { - const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { - if (elementType.isF16()) - return llvm::APFloat::IEEEhalf(); - if (elementType.isBF16()) - return llvm::APFloat::BFloat(); - return llvm::APFloat::IEEEsingle(); - }(); - bool negative = family == "partmax"; - return builder.getFloatAttr(elementType, llvm::APFloat::getInf(semantics, negative)); - } - if (auto intType = dyn_cast(elementType)) { - unsigned width = intType.getWidth(); - if (intType.isUnsigned()) { - if (family == "partmax") - return builder.getIntegerAttr(elementType, 0); - return builder.getIntegerAttr(elementType, llvm::APInt::getAllOnes(width)); - } - if (family == "partmax") - return builder.getIntegerAttr(elementType, llvm::APInt::getSignedMinValue(width)); - return builder.getIntegerAttr(elementType, llvm::APInt::getSignedMaxValue(width)); - } - return {}; -} - -Attribute buildFillPadValue(Type elementType, PadValueAttr padAttr, Builder &builder) { - if (!padAttr) - return {}; - - switch (padAttr.getValue()) { - case PadValue::Null: - return {}; - case PadValue::Zero: - return builder.getZeroAttr(elementType); - case PadValue::Max: - if (isa(elementType)) { - const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { - if (elementType.isF16()) - return llvm::APFloat::IEEEhalf(); - if (elementType.isBF16()) - return llvm::APFloat::BFloat(); - return llvm::APFloat::IEEEsingle(); - }(); - return builder.getFloatAttr(elementType, - llvm::APFloat::getLargest(semantics)); - } - if (auto intType = dyn_cast(elementType)) { - unsigned width = intType.getWidth(); - return intType.isUnsigned() - ? builder.getIntegerAttr(elementType, - llvm::APInt::getMaxValue(width)) - : builder.getIntegerAttr(elementType, - llvm::APInt::getSignedMaxValue(width)); - } - return {}; - case PadValue::Min: - if (isa(elementType)) { - const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { - if (elementType.isF16()) - return llvm::APFloat::IEEEhalf(); - if (elementType.isBF16()) - return llvm::APFloat::BFloat(); - return llvm::APFloat::IEEEsingle(); - }(); - auto value = llvm::APFloat::getLargest(semantics); - value.changeSign(); - return builder.getFloatAttr(elementType, value); - } - if (auto intType = dyn_cast(elementType)) { - unsigned width = intType.getWidth(); - return intType.isUnsigned() - ? builder.getIntegerAttr(elementType, llvm::APInt(width, 0)) - : builder.getIntegerAttr(elementType, - llvm::APInt::getSignedMinValue(width)); - } - return {}; - } - return {}; -} - -LogicalResult buildRowReduceVecScope(StringRef family, - const VPTORowReduceContract &contract, - VPTOLoweringStrategy strategy, Value src, - Value dst, - PatternRewriter &rewriter, Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO row-reduce element type"; - - Value srcBuffer = materializeBufferPointer(src, contract.elementType, - getMemorySpace(src), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!srcBuffer || !dstBuffer) - return emitError(loc) << "requires pointer-backed tile buffers for row-reduce lowering"; - - if (contract.validRows == ShapedType::kDynamic || - contract.validCols == ShapedType::kDynamic) - return emitError(loc) << family << " lowering currently requires static valid rows and cols"; - - int64_t srcRowStride = deriveStaticRowStride(src); - int64_t dstRowStride = deriveStaticRowStride(dst); - if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic) - return emitError(loc) << family << " lowering requires static row strides"; - - Attribute initValue = buildRowReduceInitValue(contract.elementType, family, rewriter); - if (!initValue) - return emitError(loc) << family << " lowering supports only f16 and f32 element types"; - - auto getRowReduceStoreDist = [&]() -> StringAttr { - if (contract.elementType.isF16() || contract.elementType.isBF16()) - return rewriter.getStringAttr("1PT_B16"); - if (contract.elementType.isF32()) - return rewriter.getStringAttr("1PT_B32"); - return {}; - }; - StringAttr storeDist = getRowReduceStoreDist(); - if (!storeDist) - return emitError(loc) << family << " lowering supports only f16 and f32 row-reduce stores"; - - int64_t vectorWidth = vecType.getElementCount(); - int64_t repeatTimes = llvm::divideCeil(contract.validCols, vectorWidth); - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value rowsUpper = rewriter.create(loc, contract.validRows); - Value srcRowStrideValue = rewriter.create(loc, srcRowStride); - Value dstRowStrideValue = rewriter.create(loc, dstRowStride); - Value vectorWidthValue = rewriter.create(loc, vectorWidth); - Value initScalar = rewriter.create(loc, cast(initValue)); - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - Value dstPredicate = - buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, c1); - Value validColsValue = - rewriter.create(loc, contract.validCols); - - if (strategy == VPTOLoweringStrategy::PostUpdate) { - auto rowLoop = - rewriter.create(loc, c0, rowsUpper, c1, ValueRange{dstBuffer}); - - OpBuilder::InsertionGuard rowGuard(rewriter); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value dstPtr = rowLoop.getRegionIterArgs().front(); - Value rowBase = rewriter.create(loc, row, srcRowStrideValue); - Value srcPtr = - adjustPointerByElemOffset(srcBuffer, rowBase, getElementByteSize(contract.elementType), - rewriter, loc); - Value acc = rewriter.create(loc, vecType, initScalar); - Value remainingCols = rewriter.create( - loc, contract.validCols, 32); - for (int64_t repeatIndex = 0; repeatIndex < repeatTimes; ++repeatIndex) { - auto predicateState = - buildPredicateForLaneCount(rewriter, loc, contract.elementType, remainingCols); - Value srcPredicate = predicateState.mask; - auto srcVecOp = rewriter.create( - loc, TypeRange{vecType, srcPtr.getType()}, srcPtr, vectorWidthValue, - rewriter.getStringAttr("NORM")); - Value srcVec = srcVecOp.getResult(); - srcPtr = srcVecOp.getUpdatedSource(); - - Value reduced; - if (family == "rowsum") - reduced = rewriter.create( - loc, getVcaddResultVRegType(rewriter.getContext(), vecType), srcVec, - srcPredicate); - else if (family == "rowmax") - reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); - else if (family == "rowmin") - reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); - else - return emitError(loc) << "unsupported VPTO row-reduce family: " << family; - - Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); - if (family == "rowsum") { - if (reduced.getType() != vecType) - reduced = rewriter.create(loc, vecType, reduced); - acc = rewriter.create(loc, vecType, acc, reduced, fullMask); - } else if (family == "rowmax") - acc = rewriter.create(loc, vecType, acc, reduced, fullMask); - else - acc = rewriter.create(loc, vecType, acc, reduced, fullMask); - remainingCols = predicateState.nextScalar; - } - - auto storeOp = rewriter.create(loc, dstPtr.getType(), acc, dstPtr, - dstRowStrideValue, storeDist, - dstPredicate); - Value nextDstPtr = storeOp.getUpdatedDestination(); - rewriter.create(loc, nextDstPtr); - return success(); - } - - auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); - OpBuilder::InsertionGuard rowGuard(rewriter); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value rowBase = rewriter.create(loc, row, srcRowStrideValue); - Value acc = rewriter.create(loc, vecType, initScalar); - for (int64_t repeatIndex = 0; repeatIndex < repeatTimes; ++repeatIndex) { - Value repeat = rewriter.create(loc, repeatIndex); - Value repeatBase = - rewriter.create(loc, repeat, vectorWidthValue); - Value srcOffset = - rewriter.create(loc, rowBase, repeatBase); - Value remainingCols = - rewriter.create(loc, validColsValue, repeatBase); - Value srcPredicate = buildPredicateMaskForLaneCount( - rewriter, loc, contract.elementType, remainingCols); - Value srcVec = - rewriter.create(loc, vecType, srcBuffer, srcOffset, - StringAttr()) - .getResult(); - - Value reduced; - if (family == "rowsum") - reduced = rewriter.create( - loc, getVcaddResultVRegType(rewriter.getContext(), vecType), srcVec, - srcPredicate); - else if (family == "rowmax") - reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); - else if (family == "rowmin") - reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); - else - return emitError(loc) << "unsupported VPTO row-reduce family: " << family; - - Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); - if (family == "rowsum") { - if (reduced.getType() != vecType) - reduced = rewriter.create(loc, vecType, reduced); - acc = rewriter.create(loc, vecType, acc, reduced, fullMask); - } else if (family == "rowmax") - acc = rewriter.create(loc, vecType, acc, reduced, fullMask); - else - acc = rewriter.create(loc, vecType, acc, reduced, fullMask); - } - - Value dstOffset = rewriter.create(loc, row, dstRowStrideValue); - rewriter.create(loc, acc, dstBuffer, dstOffset, storeDist, - dstPredicate); - return success(); -} - -LogicalResult buildColReduceVecScope(StringRef family, - const VPTOColReduceContract &contract, - Value src, Value dst, Value tmp, - PatternRewriter &rewriter, Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO col-reduce element type"; - - Value srcBuffer = materializeBufferPointer(src, contract.elementType, - getMemorySpace(src), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!srcBuffer || !dstBuffer) - return emitError(loc) << "requires pointer-backed tile buffers for col-reduce lowering"; - - Value tmpBuffer; - if (contract.isBinary) { - tmpBuffer = materializeBufferPointer(tmp, contract.elementType, getMemorySpace(tmp), - rewriter, loc); - if (!tmpBuffer) - return emitError(loc) << "binary colsum lowering requires pointer-backed tmp tile"; - } - - int64_t srcRowStride = deriveStaticRowStride(src); - int64_t dstRowStride = deriveStaticRowStride(dst); - int64_t tmpRowStride = - contract.isBinary ? deriveStaticRowStride(tmp) : ShapedType::kDynamic; - if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic || - (contract.isBinary && tmpRowStride == ShapedType::kDynamic)) - return emitError(loc) << family << " lowering requires static row strides"; - - Attribute initValue = buildRowReduceInitValue(contract.elementType, family, rewriter); - if (!initValue) - return emitError(loc) << family << " lowering supports only f16 and f32 element types"; - - int64_t vectorWidth = vecType.getElementCount(); - int64_t repeatTimes = llvm::divideCeil(contract.validCols, vectorWidth); - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value repeatUpper = rewriter.create(loc, repeatTimes); - Value rowUpper = rewriter.create(loc, contract.validRows); - Value srcRowStrideValue = rewriter.create(loc, srcRowStride); - Value dstRowStrideValue = rewriter.create(loc, dstRowStride); - Value vectorWidthValue = rewriter.create(loc, vectorWidth); - Value initScalar = rewriter.create(loc, cast(initValue)); - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); - - OpBuilder::InsertionGuard chunkGuard(rewriter); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value chunk = chunkLoop.getInductionVar(); - Value chunkOffset = rewriter.create(loc, chunk, vectorWidthValue); - - if (!contract.isBinary) { - Value firstRowOffset = chunkOffset; - Value acc0 = - rewriter.create(loc, vecType, srcBuffer, firstRowOffset, StringAttr()).getResult(); - auto rowLoop = rewriter.create(loc, c1, rowUpper, c1, ValueRange{acc0}); - OpBuilder::InsertionGuard rowGuard(rewriter); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value acc = rowLoop.getRegionIterArgs().front(); - Value rowBase = rewriter.create(loc, row, srcRowStrideValue); - Value srcOffset = rewriter.create(loc, rowBase, chunkOffset); - Value srcVec = - rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()).getResult(); - Value nextAcc; - Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); - if (family == "colmax") - nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); - else if (family == "colmin") - nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); - else - nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); - rewriter.create(loc, nextAcc); - - rewriter.setInsertionPointAfter(rowLoop); - Value dstOffset = chunkOffset; - rewriter.create( - loc, rowLoop.getResult(0), dstBuffer, dstOffset, StringAttr(), - buildAllPredicateMask(rewriter, loc, contract.elementType)); - return success(); - } - - Value tmpRowStrideValue = rewriter.create(loc, tmpRowStride); - auto reducePair = [&](Value lhs, Value rhs) -> Value { - return rewriter.create( - loc, vecType, lhs, rhs, buildAllPredicateMask(rewriter, loc, contract.elementType)) - .getResult(); - }; - - int64_t nLoopStatic = contract.validRows / 2; - bool remainStatic = (contract.validRows % 2) != 0; - Value pairUpper = rewriter.create(loc, nLoopStatic); - auto pairLoop = rewriter.create(loc, c0, pairUpper, c1); - { - OpBuilder::InsertionGuard pairGuard(rewriter); - rewriter.setInsertionPointToStart(pairLoop.getBody()); - Value pair = pairLoop.getInductionVar(); - Value row0 = rewriter.create( - loc, rewriter.create(loc, pair, rewriter.create(loc, 2)), - srcRowStrideValue); - Value row1 = rewriter.create( - loc, rewriter.create(loc, - rewriter.create(loc, pair, rewriter.create(loc, 2)), - c1), - srcRowStrideValue); - Value src0Offset = rewriter.create(loc, row0, chunkOffset); - Value src1Offset = rewriter.create(loc, row1, chunkOffset); - Value lhs = rewriter.create(loc, vecType, srcBuffer, src0Offset, StringAttr()).getResult(); - Value rhs = rewriter.create(loc, vecType, srcBuffer, src1Offset, StringAttr()).getResult(); - Value sum = reducePair(lhs, rhs); - Value tmpOffset = rewriter.create(loc, pair, tmpRowStrideValue); - rewriter.create(loc, sum, tmpBuffer, tmpOffset, StringAttr(), - buildAllPredicateMask(rewriter, loc, - contract.elementType)); - } - - if (remainStatic && nLoopStatic > 0) { - Value lastRowOffset = rewriter.create( - loc, - rewriter.create( - loc, rewriter.create(loc, contract.validRows - 1), - srcRowStrideValue), - chunkOffset); - Value tmpOffset = rewriter.create( - loc, rewriter.create(loc, nLoopStatic - 1), tmpRowStrideValue); - Value lhs = rewriter.create(loc, vecType, srcBuffer, lastRowOffset, StringAttr()).getResult(); - Value rhs = rewriter.create(loc, vecType, tmpBuffer, tmpOffset, StringAttr()).getResult(); - Value sum = reducePair(lhs, rhs); - rewriter.create(loc, sum, tmpBuffer, tmpOffset, StringAttr(), - buildAllPredicateMask(rewriter, loc, - contract.elementType)); - } - - int64_t currentRows = nLoopStatic; - while (currentRows > 1) { - int64_t nextRows = currentRows / 2; - bool remain = (currentRows % 2) != 0; - Value nextUpper = rewriter.create(loc, nextRows); - auto foldLoop = rewriter.create(loc, c0, nextUpper, c1); - OpBuilder::InsertionGuard foldGuard(rewriter); - rewriter.setInsertionPointToStart(foldLoop.getBody()); - Value pair = foldLoop.getInductionVar(); - Value idx2 = rewriter.create( - loc, pair, rewriter.create(loc, 2)); - Value idx2p1 = rewriter.create(loc, idx2, c1); - Value lhsOff = rewriter.create(loc, idx2, tmpRowStrideValue); - Value rhsOff = rewriter.create(loc, idx2p1, tmpRowStrideValue); - Value lhs = rewriter.create(loc, vecType, tmpBuffer, lhsOff, StringAttr()).getResult(); - Value rhs = rewriter.create(loc, vecType, tmpBuffer, rhsOff, StringAttr()).getResult(); - Value sum = reducePair(lhs, rhs); - Value outOff = rewriter.create(loc, pair, tmpRowStrideValue); - rewriter.create(loc, sum, tmpBuffer, outOff, StringAttr(), - buildAllPredicateMask(rewriter, loc, - contract.elementType)); - - rewriter.setInsertionPointAfter(foldLoop); - if (remain && nextRows > 0) { - Value lhsOff = rewriter.create( - loc, rewriter.create(loc, nextRows - 1), tmpRowStrideValue); - Value rhsOff = rewriter.create( - loc, rewriter.create(loc, 2 * nextRows), tmpRowStrideValue); - Value lhs = rewriter.create(loc, vecType, tmpBuffer, lhsOff, StringAttr()).getResult(); - Value rhs = rewriter.create(loc, vecType, tmpBuffer, rhsOff, StringAttr()).getResult(); - Value sum = reducePair(lhs, rhs); - rewriter.create(loc, sum, tmpBuffer, lhsOff, StringAttr(), - buildAllPredicateMask(rewriter, loc, - contract.elementType)); - } - currentRows = nextRows; - } - - Value finalVec; - if (currentRows == 0) { - finalVec = rewriter.create(loc, vecType, initScalar).getResult(); - } else { - finalVec = rewriter.create(loc, vecType, tmpBuffer, c0, StringAttr()).getResult(); - } - Value dstOffset = chunkOffset; - rewriter.create(loc, finalVec, dstBuffer, dstOffset, StringAttr(), - buildAllPredicateMask(rewriter, loc, - contract.elementType)); - return success(); -} - -LogicalResult buildPartFill(StringRef family, const VPTOPartContract &contract, - Value dstBuffer, pto::VRegType vecType, - int64_t dstStride, PatternRewriter &rewriter, - Location loc) { - Attribute initValue = buildPartPadValue(contract.elementType, family, rewriter); - if (!initValue) - return emitError(loc) << "unsupported pad value for " << family; - int64_t vectorWidth = vecType.getElementCount(); - int64_t repeatTimes = llvm::divideCeil(contract.dstValidCols, vectorWidth); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value rowsUpper = rewriter.create(loc, contract.dstValidRows); - Value repeatUpper = rewriter.create(loc, repeatTimes); - Value dstStrideValue = rewriter.create(loc, dstStride); - Value vectorWidthValue = rewriter.create(loc, vectorWidth); - Value initScalar = rewriter.create(loc, cast(initValue)); - Value initVec = rewriter.create(loc, vecType, initScalar); - auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); - OpBuilder::InsertionGuard rowGuard(rewriter); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); - OpBuilder::InsertionGuard chunkGuard(rewriter); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value chunk = chunkLoop.getInductionVar(); - Value rowBase = rewriter.create(loc, row, dstStrideValue); - Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); - Value dstOffset = rewriter.create(loc, rowBase, chunkBase); - rewriter.create(loc, initVec, dstBuffer, dstOffset, StringAttr(), - buildAllPredicateMask(rewriter, loc, - vecType.getElementType())); - rewriter.setInsertionPointAfter(chunkLoop); - return success(); -} - -LogicalResult buildPartCopyRegion(Value srcBuffer, Value dstBuffer, pto::VRegType vecType, - int64_t srcStride, int64_t dstStride, - int64_t startRow, int64_t validRows, - int64_t validCols, PatternRewriter &rewriter, - Location loc) { - int64_t vectorWidth = vecType.getElementCount(); - int64_t repeatTimes = llvm::divideCeil(validCols, vectorWidth); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value rowsUpper = rewriter.create(loc, validRows); - Value repeatUpper = rewriter.create(loc, repeatTimes); - Value srcStrideValue = rewriter.create(loc, srcStride); - Value dstStrideValue = rewriter.create(loc, dstStride); - Value vectorWidthValue = rewriter.create(loc, vectorWidth); - Value startRowValue = rewriter.create(loc, startRow); - auto rowLoop = rewriter.create(loc, startRowValue, rowsUpper, c1); - OpBuilder::InsertionGuard rowGuard(rewriter); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); - OpBuilder::InsertionGuard chunkGuard(rewriter); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value chunk = chunkLoop.getInductionVar(); - Value rowSrc = rewriter.create(loc, row, srcStrideValue); - Value rowDst = rewriter.create(loc, row, dstStrideValue); - Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); - Value srcOffset = rewriter.create(loc, rowSrc, chunkBase); - Value dstOffset = rewriter.create(loc, rowDst, chunkBase); - Value vec = rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()).getResult(); - rewriter.create(loc, vec, dstBuffer, dstOffset, StringAttr(), - buildAllPredicateMask(rewriter, loc, - vecType.getElementType())); - rewriter.setInsertionPointAfter(chunkLoop); - return success(); -} - -LogicalResult buildPartBinaryRegion(StringRef family, Value src0Buffer, Value src1Buffer, - Value dstBuffer, pto::VRegType vecType, - int64_t src0Stride, int64_t src1Stride, - int64_t dstStride, int64_t validRows, - int64_t validCols, PatternRewriter &rewriter, - Location loc) { - int64_t vectorWidth = vecType.getElementCount(); - int64_t repeatTimes = llvm::divideCeil(validCols, vectorWidth); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value rowsUpper = rewriter.create(loc, validRows); - Value repeatUpper = rewriter.create(loc, repeatTimes); - Value src0StrideValue = rewriter.create(loc, src0Stride); - Value src1StrideValue = rewriter.create(loc, src1Stride); - Value dstStrideValue = rewriter.create(loc, dstStride); - Value vectorWidthValue = rewriter.create(loc, vectorWidth); - auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); - OpBuilder::InsertionGuard rowGuard(rewriter); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); - OpBuilder::InsertionGuard chunkGuard(rewriter); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value chunk = chunkLoop.getInductionVar(); - Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); - Value rowSrc0 = rewriter.create(loc, row, src0StrideValue); - Value rowSrc1 = rewriter.create(loc, row, src1StrideValue); - Value rowDst = rewriter.create(loc, row, dstStrideValue); - Value src0Offset = rewriter.create(loc, rowSrc0, chunkBase); - Value src1Offset = rewriter.create(loc, rowSrc1, chunkBase); - Value dstOffset = rewriter.create(loc, rowDst, chunkBase); - Value lhs = rewriter.create(loc, vecType, src0Buffer, src0Offset, StringAttr()).getResult(); - Value rhs = rewriter.create(loc, vecType, src1Buffer, src1Offset, StringAttr()).getResult(); - Value fullMask = buildAllPredicateMask(rewriter, loc, vecType.getElementType()); - Value out; - if (family == "partadd") - out = rewriter.create(loc, vecType, lhs, rhs, fullMask); - else if (family == "partmax") - out = rewriter.create(loc, vecType, lhs, rhs, fullMask); - else if (family == "partmin") - out = rewriter.create(loc, vecType, lhs, rhs, fullMask); - else - return emitError(loc) << "unsupported part family: " << family; - rewriter.create(loc, out, dstBuffer, dstOffset, StringAttr(), - buildAllPredicateMask(rewriter, loc, - vecType.getElementType())); - rewriter.setInsertionPointAfter(chunkLoop); - return success(); -} - -LogicalResult buildPartVecScope(StringRef family, const VPTOPartContract &contract, - Value src0, Value src1, Value dst, - PatternRewriter &rewriter, Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO part element type"; - Value src0Buffer = materializeBufferLikeAddress(src0, contract.elementType, - getMemorySpace(src0), rewriter, loc); - Value src1Buffer = materializeBufferLikeAddress(src1, contract.elementType, - getMemorySpace(src1), rewriter, loc); - Value dstBuffer = materializeBufferLikeAddress(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!src0Buffer || !src1Buffer || !dstBuffer) - return emitError(loc) << "requires pointer-backed tile buffers for part lowering"; - int64_t src0Stride = deriveStaticRowStride(src0); - int64_t src1Stride = deriveStaticRowStride(src1); - int64_t dstStride = deriveStaticRowStride(dst); - if (src0Stride == ShapedType::kDynamic || src1Stride == ShapedType::kDynamic || - dstStride == ShapedType::kDynamic) - return emitError(loc) << family << " lowering requires static row strides"; - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - - auto condSrc0EqDst = contract.src0ValidRows == contract.dstValidRows && - contract.src0ValidCols == contract.dstValidCols; - auto condSrc0RowLtDst = contract.src0ValidRows < contract.dstValidRows && - contract.src0ValidCols == contract.dstValidCols; - auto condSrc0ColLtDst = contract.src0ValidRows <= contract.dstValidRows && - contract.src0ValidCols < contract.dstValidCols; - auto condSrc1EqDst = contract.src1ValidRows == contract.dstValidRows && - contract.src1ValidCols == contract.dstValidCols; - auto condSrc1RowLtDst = contract.src1ValidRows < contract.dstValidRows && - contract.src1ValidCols == contract.dstValidCols; - auto condSrc1ColLtDst = contract.src1ValidRows <= contract.dstValidRows && - contract.src1ValidCols < contract.dstValidCols; - - if (family == "partadd") { - if (condSrc0EqDst && condSrc1EqDst) - return buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, - src0Stride, src1Stride, dstStride, - contract.dstValidRows, contract.dstValidCols, - rewriter, loc); - if (condSrc0ColLtDst && condSrc1EqDst) { - if (failed(buildPartCopyRegion(src1Buffer, dstBuffer, vecType, src1Stride, dstStride, - 0, contract.src1ValidRows, contract.dstValidCols, - rewriter, loc))) - return failure(); - if (contract.src0ValidCols != 0) - return buildPartBinaryRegion(family, src0Buffer, dstBuffer, dstBuffer, vecType, - src0Stride, dstStride, dstStride, - contract.src0ValidRows, contract.src0ValidCols, - rewriter, loc); - return success(); - } - if (condSrc0RowLtDst && condSrc1EqDst) { - if (contract.src0ValidRows != 0 && - failed(buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, - src0Stride, src1Stride, dstStride, - contract.src0ValidRows, contract.src0ValidCols, - rewriter, loc))) - return failure(); - return buildPartCopyRegion(src1Buffer, dstBuffer, vecType, src1Stride, dstStride, - contract.src0ValidRows, contract.src1ValidRows, - contract.dstValidCols, rewriter, loc); - } - if (condSrc1ColLtDst && condSrc0EqDst) { - if (failed(buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, - 0, contract.src0ValidRows, contract.dstValidCols, - rewriter, loc))) - return failure(); - if (contract.src1ValidCols != 0) - return buildPartBinaryRegion(family, src1Buffer, dstBuffer, dstBuffer, vecType, - src1Stride, dstStride, dstStride, - contract.src1ValidRows, contract.src1ValidCols, - rewriter, loc); - return success(); - } - if (condSrc1RowLtDst && condSrc0EqDst) { - if (contract.src1ValidRows != 0 && - failed(buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, - src0Stride, src1Stride, dstStride, - contract.src1ValidRows, contract.src1ValidCols, - rewriter, loc))) - return failure(); - return buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, - contract.src1ValidRows, contract.src0ValidRows, - contract.dstValidCols, rewriter, loc); - } - return emitError(loc) << "partadd lowering only supports PTO-covered destination-equality/extension cases"; - } - - bool condDstGeSrc = contract.src0ValidRows <= contract.dstValidRows && - contract.src0ValidCols <= contract.dstValidCols && - contract.src1ValidRows <= contract.dstValidRows && - contract.src1ValidCols <= contract.dstValidCols; - if (!condDstGeSrc) - return emitError(loc) << family << " lowering only supports dst >= src0/src1 shape relation"; - if (failed(buildPartFill(family, contract, dstBuffer, vecType, dstStride, rewriter, loc))) - return failure(); - if (failed(buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, - 0, contract.src0ValidRows, contract.src0ValidCols, - rewriter, loc))) - return failure(); - return buildPartBinaryRegion(family, dstBuffer, src1Buffer, dstBuffer, vecType, - dstStride, src1Stride, dstStride, - contract.src1ValidRows, contract.src1ValidCols, - rewriter, loc); -} - -LogicalResult buildUnaryVecScope(StringRef family, - const VPTOUnaryContract &contract, - VPTOLoweringStrategy strategy, Value src, - Value dst, PatternRewriter &rewriter, - Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO unary element type"; - - Value srcBuffer = materializeBufferPointer(src, contract.elementType, - getMemorySpace(src), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!srcBuffer || !dstBuffer) - return emitError(loc) << "requires pointer-backed tile buffers for unary lowering"; - - int64_t vectorWidth = vecType.getElementCount(); - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - deriveValidShapeValues(dst, validRowsValue, validColsValue); - deriveValidShape(dst, validRows, validCols); - if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, - validCols, rewriter, loc))) - return emitError(loc) << "unary lowering requires valid rows and cols"; - - int64_t srcStride = deriveStaticRowStride(src); - int64_t dstStride = deriveStaticRowStride(dst); - int64_t srcCols = deriveStaticTileCols(src); - int64_t dstCols = deriveStaticTileCols(dst); - if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || - srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) - return emitError(loc) << "unary lowering requires static row strides and cols"; - - auto buildUnaryValue = [&](Value loaded, Value predicate) -> FailureOr { - if (family == "abs") - return rewriter.create(loc, vecType, loaded, predicate).getResult(); - if (family == "exp") - return rewriter.create(loc, vecType, loaded, predicate).getResult(); - if (family == "log") - return rewriter.create(loc, vecType, loaded, predicate).getResult(); - if (family == "sqrt") - return rewriter.create(loc, vecType, loaded, predicate).getResult(); - if (family == "relu") - return rewriter.create(loc, vecType, loaded, predicate).getResult(); - if (family == "not") - return rewriter.create(loc, vecType, loaded, predicate).getResult(); - return failure(); - }; - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value totalElementsValue = - rewriter.create(loc, validRowsValue, validColsValue); - Value vectorStepValue = - rewriter.create(loc, vectorWidth); - Value srcStrideValue = rewriter.create(loc, srcStride); - Value dstStrideValue = rewriter.create(loc, dstStride); - Value scalarInit = rewriter.create(loc, rewriter.getI32Type(), - totalElementsValue); - Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), - validColsValue); - Value fullWidthCond = - buildFullWidthColsCondition({srcCols, dstCols}, validColsValue, rewriter, loc); - if (!fullWidthCond) - return emitError(loc) << "unary lowering could not materialize full-width selector"; - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, - /*withElseRegion=*/true); - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - { - scf::ForOp chunkLoop; - if (strategy == VPTOLoweringStrategy::PostUpdate) { - chunkLoop = rewriter.create( - loc, c0, totalElementsValue, vectorStepValue, - ValueRange{srcBuffer, dstBuffer, scalarInit}); - } else { - chunkLoop = rewriter.create(loc, c0, totalElementsValue, - vectorStepValue, - ValueRange{scalarInit}); - } - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value remaining = chunkLoop.getRegionIterArgs().back(); - PredicateMaterialization predicateState = - buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); - Value loadBase = srcBuffer; - Value storeBase = dstBuffer; - Value loadOffset = chunkLoop.getInductionVar(); - Value storeOffset = chunkLoop.getInductionVar(); - if (strategy == VPTOLoweringStrategy::PostUpdate) { - loadBase = chunkLoop.getRegionIterArgs()[0]; - storeBase = chunkLoop.getRegionIterArgs()[1]; - loadOffset = vectorStepValue; - storeOffset = vectorStepValue; - } - Value loaded; - Value nextSrc = {}; - if (strategy == VPTOLoweringStrategy::PostUpdate) { - auto vlds = rewriter.create(loc, vecType, loadBase.getType(), - loadBase, loadOffset, StringAttr()); - loaded = vlds.getResult(); - nextSrc = vlds.getUpdatedSource(); - } else { - auto vlds = - rewriter.create(loc, vecType, loadBase, loadOffset, StringAttr()); - loaded = vlds.getResult(); - } - FailureOr computed = buildUnaryValue(loaded, predicateState.mask); - if (failed(computed)) - return emitError(loc) << "unsupported VPTO unary family: " << family; - if (strategy == VPTOLoweringStrategy::PostUpdate) { - auto vsts = rewriter.create(loc, storeBase.getType(), *computed, - storeBase, storeOffset, StringAttr(), - predicateState.mask); - Value nextDst = vsts.getUpdatedDestination(); - rewriter.create( - loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); - } else { - rewriter.create(loc, *computed, storeBase, storeOffset, - StringAttr(), predicateState.mask); - rewriter.create(loc, ValueRange{predicateState.nextScalar}); - } - } - - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - { - Value repeatUpper = rewriter.create(loc, validColsValue, - vectorStepValue); - auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value srcRowBase = rewriter.create(loc, row, srcStrideValue); - Value dstRowBase = rewriter.create(loc, row, dstStrideValue); - auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1, - ValueRange{rowScalarInit}); - rewriter.setInsertionPointToStart(repeatLoop.getBody()); - Value remaining = repeatLoop.getRegionIterArgs()[0]; - PredicateMaterialization predicateState = - buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); - Value chunkBase = - rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); - Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); - Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); - auto loaded = - rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); - FailureOr computed = - buildUnaryValue(loaded.getResult(), predicateState.mask); - if (failed(computed)) - return emitError(loc) << "unsupported VPTO unary family: " << family; - rewriter.create(loc, *computed, dstBuffer, dstOffset, - StringAttr(), predicateState.mask); - rewriter.create(loc, ValueRange{predicateState.nextScalar}); - } - rewriter.setInsertionPointAfter(ifOp); - - return success(); -} - -LogicalResult buildBinaryVecScope(StringRef family, - const VPTOBinaryContract &contract, - VPTOLoweringStrategy strategy, Value src0, - Value src1, Value dst, - PatternRewriter &rewriter, Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO binary element type"; - - Value src0Buffer = materializeBufferPointer(src0, contract.elementType, - getMemorySpace(src0), rewriter, loc); - Value src1Buffer = materializeBufferPointer(src1, contract.elementType, - getMemorySpace(src1), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!src0Buffer || !src1Buffer || !dstBuffer) - return emitError(loc) << "requires pointer-backed tile buffers for binary lowering"; - - int64_t vectorWidth = vecType.getElementCount(); - Value validRowsValue = contract.validRowsValue; - Value validColsValue = contract.validColsValue; - int64_t validRows = contract.validRows; - int64_t validCols = contract.validCols; - if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, - validCols, rewriter, loc))) - return emitError(loc) << "binary lowering requires valid rows and cols"; - - int64_t src0Stride = deriveStaticRowStride(src0); - int64_t src1Stride = deriveStaticRowStride(src1); - int64_t dstStride = deriveStaticRowStride(dst); - int64_t src0Cols = deriveStaticTileCols(src0); - int64_t src1Cols = deriveStaticTileCols(src1); - int64_t dstCols = deriveStaticTileCols(dst); - if (src0Stride == ShapedType::kDynamic || src1Stride == ShapedType::kDynamic || - dstStride == ShapedType::kDynamic || src0Cols == ShapedType::kDynamic || - src1Cols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) - return emitError(loc) << "binary lowering requires static row strides and cols"; - - auto buildBinaryValue = [&](Value lhs, Value rhs, Value mask) -> FailureOr { - if (family == "add") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - if (family == "sub") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - if (family == "mul") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - if (family == "div") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - if (family == "max") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - if (family == "min") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - if (family == "and") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - if (family == "or") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - if (family == "xor") - return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); - return failure(); - }; - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value totalElementsValue = - rewriter.create(loc, validRowsValue, validColsValue); - Value vectorStepValue = - rewriter.create(loc, vectorWidth); - Value src0StrideValue = rewriter.create(loc, src0Stride); - Value src1StrideValue = rewriter.create(loc, src1Stride); - Value dstStrideValue = rewriter.create(loc, dstStride); - Value scalarInit = rewriter.create(loc, rewriter.getI32Type(), - totalElementsValue); - Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), - validColsValue); - bool sameShapeLinearPath = src0Stride == dstStride && src1Stride == dstStride && - src0Cols == dstCols && src1Cols == dstCols; - Value fullWidthCond = buildFullWidthColsCondition( - {src0Cols, src1Cols, dstCols}, validColsValue, rewriter, loc); - if (!fullWidthCond) - return emitError(loc) << "binary lowering could not materialize full-width selector"; - Value use1DCond = sameShapeLinearPath ? fullWidthCond : Value(); - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto emit1DBody = [&]() -> LogicalResult { - scf::ForOp chunkLoop; - if (strategy == VPTOLoweringStrategy::PostUpdate) { - chunkLoop = rewriter.create( - loc, c0, totalElementsValue, vectorStepValue, - ValueRange{src0Buffer, src1Buffer, dstBuffer, scalarInit}); - } else { - chunkLoop = rewriter.create(loc, c0, totalElementsValue, - vectorStepValue, - ValueRange{scalarInit}); - } - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value remaining = chunkLoop.getRegionIterArgs().back(); - PredicateMaterialization predicateState = - buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); - Value lhsBase = src0Buffer; - Value rhsBase = src1Buffer; - Value dstBase = dstBuffer; - Value loadOffset = chunkLoop.getInductionVar(); - Value storeOffset = chunkLoop.getInductionVar(); - if (strategy == VPTOLoweringStrategy::PostUpdate) { - lhsBase = chunkLoop.getRegionIterArgs()[0]; - rhsBase = chunkLoop.getRegionIterArgs()[1]; - dstBase = chunkLoop.getRegionIterArgs()[2]; - loadOffset = vectorStepValue; - storeOffset = vectorStepValue; - } - Value lhsValue; - Value rhsValue; - Value nextSrc0 = {}; - Value nextSrc1 = {}; - if (strategy == VPTOLoweringStrategy::PostUpdate) { - auto lhs = rewriter.create(loc, vecType, lhsBase.getType(), - lhsBase, loadOffset, StringAttr()); - auto rhs = rewriter.create(loc, vecType, rhsBase.getType(), - rhsBase, loadOffset, StringAttr()); - lhsValue = lhs.getResult(); - rhsValue = rhs.getResult(); - nextSrc0 = lhs.getUpdatedSource(); - nextSrc1 = rhs.getUpdatedSource(); - } else { - auto lhs = - rewriter.create(loc, vecType, lhsBase, loadOffset, StringAttr()); - auto rhs = - rewriter.create(loc, vecType, rhsBase, loadOffset, StringAttr()); - lhsValue = lhs.getResult(); - rhsValue = rhs.getResult(); - } - FailureOr computed = buildBinaryValue(lhsValue, rhsValue, predicateState.mask); - if (failed(computed)) - return emitError(loc) << "unsupported VPTO binary family: " << family; - if (strategy == VPTOLoweringStrategy::PostUpdate) { - auto vsts = rewriter.create(loc, dstBase.getType(), *computed, - dstBase, storeOffset, StringAttr(), - predicateState.mask); - Value nextDst = vsts.getUpdatedDestination(); - rewriter.create( - loc, - ValueRange{nextSrc0, nextSrc1, nextDst, predicateState.nextScalar}); - } else { - rewriter.create(loc, *computed, dstBase, storeOffset, - StringAttr(), predicateState.mask); - rewriter.create(loc, ValueRange{predicateState.nextScalar}); - } - return success(); - }; - - auto emit2DBody = [&]() -> LogicalResult { - Value repeatUpper = rewriter.create(loc, validColsValue, - vectorStepValue); - auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value src0RowBase = rewriter.create(loc, row, src0StrideValue); - Value src1RowBase = rewriter.create(loc, row, src1StrideValue); - Value dstRowBase = rewriter.create(loc, row, dstStrideValue); - - if (strategy == VPTOLoweringStrategy::NoPostUpdate) { - auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1, - ValueRange{rowScalarInit}); - rewriter.setInsertionPointToStart(repeatLoop.getBody()); - Value remaining = repeatLoop.getRegionIterArgs()[0]; - PredicateMaterialization predicateState = - buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); - Value chunkBase = - rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); - Value src0Offset = rewriter.create(loc, src0RowBase, chunkBase); - Value src1Offset = rewriter.create(loc, src1RowBase, chunkBase); - Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); - auto lhs = rewriter.create(loc, vecType, src0Buffer, src0Offset, - StringAttr()); - auto rhs = rewriter.create(loc, vecType, src1Buffer, src1Offset, - StringAttr()); - FailureOr computed = - buildBinaryValue(lhs.getResult(), rhs.getResult(), predicateState.mask); - if (failed(computed)) - return emitError(loc) << "unsupported VPTO binary family: " << family; - rewriter.create(loc, *computed, dstBuffer, dstOffset, - StringAttr(), predicateState.mask); - rewriter.create(loc, ValueRange{predicateState.nextScalar}); - return success(); - } - - auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); - rewriter.setInsertionPointToStart(repeatLoop.getBody()); - Value chunkBase = - rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); - Value src0Offset = rewriter.create(loc, src0RowBase, chunkBase); - Value src1Offset = rewriter.create(loc, src1RowBase, chunkBase); - Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); - Value nextChunk = rewriter.create(loc, chunkBase, vectorStepValue); - Value exceeds = - rewriter.create(loc, arith::CmpIPredicate::sge, nextChunk, validColsValue); - Value tailCount = rewriter.create(loc, validColsValue, chunkBase); - Value activeLanes = - rewriter.create(loc, exceeds, tailCount, vectorStepValue); - Value predicate = buildPredicateMaskForLaneCount(rewriter, loc, - contract.elementType, activeLanes); - auto lhs = - rewriter.create(loc, vecType, src0Buffer, src0Offset, StringAttr()); - auto rhs = - rewriter.create(loc, vecType, src1Buffer, src1Offset, StringAttr()); - FailureOr computed = buildBinaryValue(lhs.getResult(), rhs.getResult(), predicate); - if (failed(computed)) - return emitError(loc) << "unsupported VPTO binary family: " << family; - rewriter.create(loc, *computed, dstBuffer, dstOffset, - StringAttr(), predicate); - return success(); - }; - - if (use1DCond) { - auto ifOp = rewriter.create(loc, TypeRange{}, use1DCond, - /*withElseRegion=*/true); - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - if (failed(emit1DBody())) - return failure(); - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - if (failed(emit2DBody())) - return failure(); - rewriter.setInsertionPointAfter(ifOp); - } else { - if (failed(emit2DBody())) - return failure(); - } - return success(); -} - -LogicalResult buildExpandScalarVecScope(const VPTOUnaryContract &contract, - Value scalar, Value dst, - PatternRewriter &rewriter, - Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO expands element type"; - - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!dstBuffer) - return emitError(loc) << "requires pointer-backed tile buffer for expands lowering"; - - Value validRowsValue = materializeIndexValue(contract.validRowsValue, - contract.validRows, rewriter, loc); - Value validColsValue = materializeIndexValue(contract.validColsValue, - contract.validCols, rewriter, loc); - if (!validRowsValue || !validColsValue) - return emitError(loc) << "expands lowering requires valid rows and cols"; - - int64_t vectorWidth = vecType.getElementCount(); - int64_t dstStride = deriveStaticRowStride(dst); - int64_t dstCols = deriveStaticTileCols(dst); - if (dstStride == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) - return emitError(loc) << "expands lowering requires static destination row stride and cols"; - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value totalElementsValue = - rewriter.create(loc, validRowsValue, validColsValue); - Value vectorStepValue = - rewriter.create(loc, vectorWidth); - Value dstStrideValue = rewriter.create(loc, dstStride); - Value fullWidthCond = - buildFullWidthColsCondition({dstCols}, validColsValue, rewriter, loc); - if (!fullWidthCond) - return emitError(loc) << "expands lowering could not materialize full-width selector"; - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, - /*withElseRegion=*/true); - - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - { - Value scalarInit = rewriter.create( - loc, rewriter.getI32Type(), totalElementsValue); - auto chunkLoop = rewriter.create( - loc, c0, totalElementsValue, vectorStepValue, - ValueRange{dstBuffer, scalarInit}); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value dstPtr = chunkLoop.getRegionIterArgs()[0]; - Value remaining = chunkLoop.getRegionIterArgs()[1]; - PredicateMaterialization predicateState = buildPredicateForLaneCount( - rewriter, loc, contract.elementType, remaining); - Value computed = - rewriter.create(loc, vecType, scalar, predicateState.mask, StringAttr()); - auto vsts = rewriter.create(loc, dstPtr.getType(), computed, dstPtr, - vectorStepValue, StringAttr(), - predicateState.mask); - Value nextDst = vsts.getUpdatedDestination(); - rewriter.create( - loc, ValueRange{nextDst, predicateState.nextScalar}); - } - - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - { - Value repeatUpper = rewriter.create(loc, validColsValue, - vectorStepValue); - auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value rowBase = rewriter.create(loc, row, dstStrideValue); - auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); - rewriter.setInsertionPointToStart(repeatLoop.getBody()); - Value repeat = repeatLoop.getInductionVar(); - Value chunkBase = rewriter.create(loc, repeat, vectorStepValue); - Value dstOffset = rewriter.create(loc, rowBase, chunkBase); - Value remainingCols = - rewriter.create(loc, validColsValue, chunkBase); - Value activeLanes = - buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); - Value predicate = buildPredicateMaskForLaneCount( - rewriter, loc, contract.elementType, activeLanes); - Value computed = - rewriter.create(loc, vecType, scalar, predicate, StringAttr()); - rewriter.create(loc, computed, dstBuffer, dstOffset, - StringAttr(), predicate); - } - - rewriter.setInsertionPointAfter(ifOp); - return success(); -} - -LogicalResult buildScalarUnaryVecScope(StringRef family, - const VPTOUnaryContract &contract, - VPTOLoweringStrategy strategy, - Value src, Value scalar, Value dst, - PatternRewriter &rewriter, - Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO scalar-unary element type"; - - Value srcBuffer = materializeBufferPointer(src, contract.elementType, - getMemorySpace(src), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!srcBuffer || !dstBuffer) - return emitError(loc) - << "requires pointer-backed tile buffers for scalar-unary lowering"; - - Value validRowsValue = materializeIndexValue(contract.validRowsValue, - contract.validRows, rewriter, loc); - Value validColsValue = materializeIndexValue(contract.validColsValue, - contract.validCols, rewriter, loc); - if (!validRowsValue || !validColsValue) - return emitError(loc) << family << " lowering requires valid rows and cols"; - - int64_t vectorWidth = vecType.getElementCount(); - int64_t srcStride = deriveStaticRowStride(src); - int64_t dstStride = deriveStaticRowStride(dst); - int64_t srcCols = deriveStaticTileCols(src); - int64_t dstCols = deriveStaticTileCols(dst); - if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || - srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) - return emitError(loc) - << family << " lowering requires static src/dst row stride and cols"; - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value totalElementsValue = - rewriter.create(loc, validRowsValue, validColsValue); - Value vectorStepValue = - rewriter.create(loc, vectorWidth); - Value srcStrideValue = rewriter.create(loc, srcStride); - Value dstStrideValue = rewriter.create(loc, dstStride); - Value fullWidthCond = buildFullWidthColsCondition( - {srcCols, dstCols}, validColsValue, rewriter, loc); - if (!fullWidthCond) - return emitError(loc) << family << " lowering could not materialize full-width selector"; - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, - /*withElseRegion=*/true); - - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - { - auto emitComputed = [&](Value loadedVec, Value predicate) -> FailureOr { - if (family == "adds") - return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); - if (family == "maxs") - return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); - if (family == "mins") - return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); - if (family == "muls") - return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); - if (family == "lrelu") - return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); - return failure(); - }; - - if (strategy == VPTOLoweringStrategy::NoPostUpdate) { - auto chunkLoop = - rewriter.create(loc, c0, totalElementsValue, vectorStepValue); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value offset = chunkLoop.getInductionVar(); - Value remaining = rewriter.create(loc, totalElementsValue, offset); - Value activeLanes = - buildMinIndexValue(rewriter, loc, remaining, vectorStepValue); - Value predicate = buildPredicateMaskForLaneCount( - rewriter, loc, contract.elementType, activeLanes); - auto loaded = - rewriter.create(loc, vecType, srcBuffer, offset, StringAttr()); - FailureOr computed = emitComputed(loaded.getResult(), predicate); - if (failed(computed)) - return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; - rewriter.create(loc, *computed, dstBuffer, offset, StringAttr(), - predicate); - } else { - Value scalarInit = rewriter.create( - loc, rewriter.getI32Type(), totalElementsValue); - auto chunkLoop = rewriter.create( - loc, c0, totalElementsValue, vectorStepValue, - ValueRange{srcBuffer, dstBuffer, scalarInit}); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value srcPtr = chunkLoop.getRegionIterArgs()[0]; - Value dstPtr = chunkLoop.getRegionIterArgs()[1]; - Value remaining = chunkLoop.getRegionIterArgs()[2]; - PredicateMaterialization predicateState = buildPredicateForLaneCount( - rewriter, loc, contract.elementType, remaining); - auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, - vectorStepValue, StringAttr()); - FailureOr computed = emitComputed(loaded.getResult(), predicateState.mask); - if (failed(computed)) - return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; - auto vsts = rewriter.create(loc, dstPtr.getType(), *computed, dstPtr, - vectorStepValue, StringAttr(), - predicateState.mask); - Value nextSrc = loaded.getUpdatedSource(); - Value nextDst = vsts.getUpdatedDestination(); - rewriter.create( - loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); - } - } - - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - { - Value repeatUpper = rewriter.create(loc, validColsValue, - vectorStepValue); - auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value srcRowBase = rewriter.create(loc, row, srcStrideValue); - Value dstRowBase = rewriter.create(loc, row, dstStrideValue); - auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); - rewriter.setInsertionPointToStart(repeatLoop.getBody()); - Value repeat = repeatLoop.getInductionVar(); - Value chunkBase = rewriter.create(loc, repeat, vectorStepValue); - Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); - Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); - Value predicate; - if (strategy == VPTOLoweringStrategy::NoPostUpdate) { - predicate = - buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, validColsValue); - } else { - Value remainingCols = - rewriter.create(loc, validColsValue, chunkBase); - Value activeLanes = - buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); - predicate = buildPredicateMaskForLaneCount( - rewriter, loc, contract.elementType, activeLanes); - } - auto loaded = - rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); - Value computed; - if (family == "adds") - computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); - else if (family == "maxs") - computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); - else if (family == "mins") - computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); - else if (family == "muls") - computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); - else if (family == "lrelu") - computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); - else - return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; - rewriter.create(loc, computed, dstBuffer, dstOffset, - StringAttr(), predicate); - } - - rewriter.setInsertionPointAfter(ifOp); - return success(); -} - -LogicalResult buildScalarBitwiseVecScope(StringRef family, - const VPTOUnaryContract &contract, - Value src, Value scalar, Value dst, - PatternRewriter &rewriter, - Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO scalar-bitwise element type"; - - Value srcBuffer = materializeBufferPointer(src, contract.elementType, - getMemorySpace(src), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!srcBuffer || !dstBuffer) - return emitError(loc) - << "requires pointer-backed tile buffers for scalar-bitwise lowering"; - - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - deriveValidShapeValues(dst, validRowsValue, validColsValue); - deriveValidShape(dst, validRows, validCols); - if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, - validCols, rewriter, loc))) - return emitError(loc) << family << " lowering requires valid rows and cols"; - - int64_t vectorWidth = vecType.getElementCount(); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value totalElementsValue = - rewriter.create(loc, validRowsValue, validColsValue); - Value vectorStepValue = - rewriter.create(loc, vectorWidth); - Value vectorWidthValue = - rewriter.create(loc, vectorWidth); - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto chunkLoop = - rewriter.create(loc, c0, totalElementsValue, vectorStepValue); - - OpBuilder::InsertionGuard chunkGuard(rewriter); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value offset = chunkLoop.getInductionVar(); - Value remaining = rewriter.create(loc, totalElementsValue, offset); - Value activeLanes = - buildMinIndexValue(rewriter, loc, remaining, vectorWidthValue); - Value predicate = - buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, activeLanes); - Value scalarVec = - rewriter.create(loc, vecType, scalar, predicate, StringAttr()); - auto loaded = rewriter.create(loc, vecType, srcBuffer, offset, - StringAttr()); - - Value computed; - if (family == "ands") - computed = - rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); - else if (family == "ors") - computed = - rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); - else if (family == "xors") - computed = - rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); - else - return emitError(loc) << "unsupported VPTO scalar-bitwise family: " << family; - rewriter.create(loc, computed, dstBuffer, offset, StringAttr(), - predicate); - return success(); -} - -static bool isVPTOShapedLikeValue(Value value) { - Type type = value.getType(); - return isa(type); -} - -LogicalResult buildScalarDivVecScope(const VPTOUnaryContract &contract, - VPTOLoweringStrategy strategy, - Value src, Value scalar, Value dst, - bool scalarFirst, - PatternRewriter &rewriter, Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO divs element type"; - - Value srcBuffer = materializeBufferPointer(src, contract.elementType, - getMemorySpace(src), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!srcBuffer || !dstBuffer) - return emitError(loc) - << "requires pointer-backed tile buffers for divs lowering"; - - Value validRowsValue = materializeIndexValue(contract.validRowsValue, - contract.validRows, rewriter, loc); - Value validColsValue = materializeIndexValue(contract.validColsValue, - contract.validCols, rewriter, loc); - if (!validRowsValue || !validColsValue) - return emitError(loc) << "divs lowering requires valid rows and cols"; - - int64_t vectorWidth = vecType.getElementCount(); - int64_t srcStride = deriveStaticRowStride(src); - int64_t dstStride = deriveStaticRowStride(dst); - int64_t srcCols = deriveStaticTileCols(src); - int64_t dstCols = deriveStaticTileCols(dst); - if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || - srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) - return emitError(loc) - << "divs lowering requires static src/dst row stride and cols"; - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value totalElementsValue = - rewriter.create(loc, validRowsValue, validColsValue); - Value vectorStepValue = - rewriter.create(loc, vectorWidth); - Value srcStrideValue = rewriter.create(loc, srcStride); - Value dstStrideValue = rewriter.create(loc, dstStride); - Value fullWidthCond = buildFullWidthColsCondition( - {srcCols, dstCols}, validColsValue, rewriter, loc); - if (!fullWidthCond) - return emitError(loc) << "divs lowering could not materialize full-width selector"; - - auto buildDivValue = [&](Value loaded, Value predicate) -> FailureOr { - if (contract.elementType.isF32()) { - if (scalarFirst) { - Value scalarVec = - rewriter.create(loc, vecType, scalar, predicate, StringAttr()); - return rewriter.create(loc, vecType, scalarVec, loaded, predicate) - .getResult(); - } - Value one = rewriter.create( - loc, contract.elementType, - rewriter.getFloatAttr(contract.elementType, 1.0)); - Value reciprocal = rewriter.create(loc, one, scalar); - return rewriter.create(loc, vecType, loaded, reciprocal, predicate).getResult(); - } - if (contract.elementType.isF16()) { - Value scalarVec = - rewriter.create(loc, vecType, scalar, predicate, StringAttr()); - return scalarFirst - ? rewriter.create(loc, vecType, scalarVec, loaded, predicate) - .getResult() - : rewriter.create(loc, vecType, loaded, scalarVec, predicate) - .getResult(); - } - return failure(); - }; - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, - /*withElseRegion=*/true); - - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - { - if (strategy == VPTOLoweringStrategy::NoPostUpdate) { - auto chunkLoop = - rewriter.create(loc, c0, totalElementsValue, vectorStepValue); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value offset = chunkLoop.getInductionVar(); - Value remaining = rewriter.create(loc, totalElementsValue, offset); - Value activeLanes = - buildMinIndexValue(rewriter, loc, remaining, vectorStepValue); - Value predicate = buildPredicateMaskForLaneCount( - rewriter, loc, contract.elementType, activeLanes); - auto loaded = - rewriter.create(loc, vecType, srcBuffer, offset, StringAttr()); - FailureOr computed = buildDivValue(loaded.getResult(), predicate); - if (failed(computed)) - return emitError(loc) - << "divs lowering currently supports only f16 and f32 element types"; - rewriter.create(loc, *computed, dstBuffer, offset, StringAttr(), - predicate); - } else { - Value scalarInit = rewriter.create( - loc, rewriter.getI32Type(), totalElementsValue); - auto chunkLoop = rewriter.create( - loc, c0, totalElementsValue, vectorStepValue, - ValueRange{srcBuffer, dstBuffer, scalarInit}); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value srcPtr = chunkLoop.getRegionIterArgs()[0]; - Value dstPtr = chunkLoop.getRegionIterArgs()[1]; - Value remaining = chunkLoop.getRegionIterArgs()[2]; - PredicateMaterialization predicateState = buildPredicateForLaneCount( - rewriter, loc, contract.elementType, remaining); - auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, - vectorStepValue, StringAttr()); - FailureOr computed = buildDivValue(loaded.getResult(), predicateState.mask); - if (failed(computed)) - return emitError(loc) - << "divs lowering currently supports only f16 and f32 element types"; - auto vsts = rewriter.create(loc, dstPtr.getType(), *computed, dstPtr, - vectorStepValue, StringAttr(), - predicateState.mask); - Value nextSrc = loaded.getUpdatedSource(); - Value nextDst = vsts.getUpdatedDestination(); - rewriter.create( - loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); - } - } - - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - { - Value repeatUpper = rewriter.create(loc, validColsValue, - vectorStepValue); - auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value srcRowBase = rewriter.create(loc, row, srcStrideValue); - Value dstRowBase = rewriter.create(loc, row, dstStrideValue); - auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); - rewriter.setInsertionPointToStart(repeatLoop.getBody()); - Value repeat = repeatLoop.getInductionVar(); - Value chunkBase = rewriter.create(loc, repeat, vectorStepValue); - Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); - Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); - Value predicate; - if (strategy == VPTOLoweringStrategy::NoPostUpdate) { - predicate = - buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, validColsValue); - } else { - Value remainingCols = - rewriter.create(loc, validColsValue, chunkBase); - Value activeLanes = - buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); - predicate = buildPredicateMaskForLaneCount( - rewriter, loc, contract.elementType, activeLanes); - } - auto loaded = - rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); - FailureOr computed = buildDivValue(loaded.getResult(), predicate); - if (failed(computed)) - return emitError(loc) - << "divs lowering currently supports only f16 and f32 element types"; - rewriter.create(loc, *computed, dstBuffer, dstOffset, - StringAttr(), predicate); - } - - rewriter.setInsertionPointAfter(ifOp); - return success(); -} - -LogicalResult checkExpandContract(Operation *op, - const VPTOExpandContract &contract) { - bool hasPrecheckFailure = false; - if (contract.srcDomain != VPTOTileDomain::Vec || - contract.dstDomain != VPTOTileDomain::Vec) { - op->emitOpError() << contract.family - << " lowering requires vec source and destination"; - hasPrecheckFailure = true; - } - if (contract.srcLayout != "row_major" || contract.dstLayout != "row_major") { - op->emitOpError() << contract.family - << " lowering requires row-major source and destination tile layout"; - hasPrecheckFailure = true; - } - if (!contract.elementType || - (!contract.elementType.isF16() && !contract.elementType.isF32())) { - op->emitOpError() << contract.family - << " lowering currently supports only f16 and f32 element types"; - hasPrecheckFailure = true; - } - auto isStatic = [](int64_t value) { return value != ShapedType::kDynamic; }; - if (!isStatic(contract.srcValidRows) || !isStatic(contract.srcValidCols) || - !isStatic(contract.dstValidRows) || !isStatic(contract.dstValidCols)) { - op->emitOpError() << contract.family - << " lowering currently requires static source and destination valid shapes"; - hasPrecheckFailure = true; - } - return failure(hasPrecheckFailure); -} - -LogicalResult buildRowExpandVecScope(const VPTOExpandContract &contract, - VPTOLoweringStrategy strategy, Value src, Value dst, - PatternRewriter &rewriter, Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO rowexpand element type"; - - Value srcBuffer = materializeBufferPointer(src, contract.elementType, - getMemorySpace(src), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!srcBuffer || !dstBuffer) - return emitError(loc) - << "requires pointer-backed tile buffers for rowexpand lowering"; - - auto [srcRows, srcCols] = getStaticTileRowsCols(src); - auto [dstRows, dstCols] = getStaticTileRowsCols(dst); - if (srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic || - srcRows == ShapedType::kDynamic || dstRows == ShapedType::kDynamic) - return emitError(loc) << "rowexpand lowering requires static physical tile shape"; - - int64_t vectorWidth = vecType.getElementCount(); - Value validRowsValue = materializeIndexValue( - contract.dstValidRowsValue, contract.dstValidRows, rewriter, loc); - Value validColsValue = materializeIndexValue( - contract.dstValidColsValue, contract.dstValidCols, rewriter, loc); - if (!validRowsValue || !validColsValue) - return emitError(loc) << "rowexpand lowering requires valid rows and cols"; - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value srcStrideValue = rewriter.create(loc, srcCols); - Value dstStrideValue = rewriter.create(loc, dstCols); - Value vectorStepValue = - rewriter.create(loc, vectorWidth); - Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), - validColsValue); - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - Value repeatUpper = rewriter.create(loc, validColsValue, - vectorStepValue); - if (strategy == VPTOLoweringStrategy::NoPostUpdate) { - auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value srcOffset = rewriter.create(loc, row, srcStrideValue); - Value dstBase = rewriter.create(loc, row, dstStrideValue); - auto loaded = - rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); - Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); - Value expanded = rewriter.create( - loc, vecType, loaded.getResult(), fullMask, rewriter.getStringAttr("LOWEST")); - auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1, - ValueRange{rowScalarInit}); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value remaining = chunkLoop.getRegionIterArgs()[0]; - PredicateMaterialization predicateState = - buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); - Value chunkBase = - rewriter.create(loc, chunkLoop.getInductionVar(), vectorStepValue); - Value dstOffset = rewriter.create(loc, dstBase, chunkBase); - rewriter.create(loc, expanded, dstBuffer, dstOffset, StringAttr(), - predicateState.mask); - rewriter.create(loc, ValueRange{predicateState.nextScalar}); - return success(); - } - - auto rowLoop = - rewriter.create(loc, c0, validRowsValue, c1, ValueRange{srcBuffer, dstBuffer}); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value srcPtr = rowLoop.getRegionIterArgs()[0]; - Value dstPtr = rowLoop.getRegionIterArgs()[1]; - auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, - srcStrideValue, StringAttr()); - Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); - Value expanded = rewriter.create( - loc, vecType, loaded.getResult(), fullMask, rewriter.getStringAttr("LOWEST")); - auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1, - ValueRange{dstPtr, rowScalarInit}); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value dstChunkPtr = chunkLoop.getRegionIterArgs()[0]; - Value remaining = chunkLoop.getRegionIterArgs()[1]; - PredicateMaterialization predicateState = - buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); - auto vsts = rewriter.create(loc, dstChunkPtr.getType(), expanded, - dstChunkPtr, vectorStepValue, StringAttr(), - predicateState.mask); - Value nextDstChunkPtr = vsts.getUpdatedDestination(); - rewriter.create(loc, ValueRange{nextDstChunkPtr, predicateState.nextScalar}); - - rewriter.setInsertionPointAfter(chunkLoop); - Value rowAdvance = rewriter.create(loc, repeatUpper, vectorStepValue); - Value dstPad = rewriter.create(loc, dstStrideValue, rowAdvance); - Value nextDstPtr = - offsetBufferPointer(dstPtr, contract.elementType, dstPad, rewriter, loc); - Value nextSrcPtr = loaded.getUpdatedSource(); - rewriter.create(loc, ValueRange{nextSrcPtr, nextDstPtr}); - return success(); -} - -LogicalResult buildColExpandVecScope(const VPTOExpandContract &contract, - Value src, Value dst, - PatternRewriter &rewriter, Location loc) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << "unsupported VPTO colexpand element type"; - - Value srcBuffer = materializeBufferPointer(src, contract.elementType, - getMemorySpace(src), rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, contract.elementType, - getMemorySpace(dst), rewriter, loc); - if (!srcBuffer || !dstBuffer) - return emitError(loc) - << "requires pointer-backed tile buffers for colexpand lowering"; - - auto [dstRows, dstCols] = getStaticTileRowsCols(dst); - if (dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) - return emitError(loc) - << "colexpand lowering requires static physical destination tile shape"; - - int64_t vectorWidth = vecType.getElementCount(); - Value validRowsValue = materializeIndexValue( - contract.dstValidRowsValue, contract.dstValidRows, rewriter, loc); - Value validColsValue = materializeIndexValue( - contract.dstValidColsValue, contract.dstValidCols, rewriter, loc); - if (!validRowsValue || !validColsValue) - return emitError(loc) << "colexpand lowering requires valid rows and cols"; - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value dstStrideValue = rewriter.create(loc, dstCols); - Value vectorStepValue = - rewriter.create(loc, vectorWidth); - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - auto chunkLoop = - rewriter.create(loc, c0, validColsValue, vectorStepValue); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - - Value dstBase = - rewriter.create(loc, rowLoop.getInductionVar(), dstStrideValue); - Value dstOffset = - rewriter.create(loc, dstBase, chunkLoop.getInductionVar()); - auto loaded = rewriter.create( - loc, vecType, srcBuffer, chunkLoop.getInductionVar(), StringAttr()); - rewriter.create(loc, loaded.getResult(), dstBuffer, dstOffset, - StringAttr(), - buildAllPredicateMask(rewriter, loc, - contract.elementType)); - return success(); -} - -LogicalResult checkGenericUnaryContract(Operation *op, - const VPTOUnaryContract &contract, - Value dst, - function_ref typePredicate, - StringRef supportedTypeText) { - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - deriveValidShape(dst, dstRows, dstCols); - StringRef dstLayout = deriveTileLayout(dst); - VPTOTileDomain dstDomain = deriveTileDomain(getMemorySpace(dst)); - - bool hasPrecheckFailure = false; - if (contract.tileDomain != VPTOTileDomain::Vec || dstDomain != VPTOTileDomain::Vec) { - op->emitOpError() << contract.family << " lowering requires tile domain vec"; - hasPrecheckFailure = true; - } - if (contract.tileLayout != "row_major" || dstLayout != "row_major") { - op->emitOpError() << contract.family << " lowering requires row-major tile layout"; - hasPrecheckFailure = true; - } - if (contract.validRows != ShapedType::kDynamic && - dstRows != ShapedType::kDynamic && dstRows > contract.validRows) { - op->emitOpError() << contract.family - << " lowering requires destination valid rows not to exceed source"; - hasPrecheckFailure = true; - } - if (contract.validCols != ShapedType::kDynamic && - dstCols != ShapedType::kDynamic && dstCols > contract.validCols) { - op->emitOpError() << contract.family - << " lowering requires destination valid cols not to exceed source"; - hasPrecheckFailure = true; - } - if (!contract.elementType || !typePredicate(contract.elementType)) { - op->emitOpError() - << contract.family << " lowering supports only " << supportedTypeText; - hasPrecheckFailure = true; - } - return failure(hasPrecheckFailure); -} - -LogicalResult checkGenericBinaryContract( - Operation *op, const VPTOBinaryContract &contract, Value src1, Value dst, - function_ref typePredicate, StringRef supportedTypeText) { - StringRef src1Layout = deriveTileLayout(src1); - StringRef dstLayout = deriveTileLayout(dst); - VPTOTileDomain src1Domain = deriveTileDomain(getMemorySpace(src1)); - VPTOTileDomain dstDomain = deriveTileDomain(getMemorySpace(dst)); - - bool hasPrecheckFailure = false; - if (contract.tileDomain != VPTOTileDomain::Vec || src1Domain != VPTOTileDomain::Vec || - dstDomain != VPTOTileDomain::Vec) { - op->emitOpError() << contract.family << " lowering requires tile domain vec"; - hasPrecheckFailure = true; - } - if (contract.tileLayout != "row_major" || src1Layout != "row_major" || - dstLayout != "row_major") { - op->emitOpError() << contract.family << " lowering requires row-major tile layout"; - hasPrecheckFailure = true; - } - if (!contract.elementType || !typePredicate(contract.elementType)) { - op->emitOpError() - << contract.family << " lowering supports only " << supportedTypeText; - hasPrecheckFailure = true; - } - return failure(hasPrecheckFailure); -} - -LogicalResult checkRowReduceContract(Operation *op, - const VPTORowReduceContract &contract, - Value dst) { - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - deriveValidShape(dst, dstRows, dstCols); - - bool hasPrecheckFailure = false; - if (contract.srcDomain != VPTOTileDomain::Vec || - contract.dstDomain != VPTOTileDomain::Vec) { - op->emitOpError() << contract.family << " lowering requires vec source and destination"; - hasPrecheckFailure = true; - } - if (contract.srcLayout != "row_major") { - op->emitOpError() << contract.family << " lowering requires row-major source tile layout"; - hasPrecheckFailure = true; - } - if (contract.dstLayout != "row_major" && contract.dstLayout != "col_major") { - op->emitOpError() << contract.family - << " lowering requires row-major or col-major destination tile layout"; - hasPrecheckFailure = true; - } - if (!contract.elementType || (!contract.elementType.isF16() && !contract.elementType.isF32())) { - op->emitOpError() << contract.family << " lowering supports only f16 and f32 element types"; - hasPrecheckFailure = true; - } - if (contract.validRows == ShapedType::kDynamic || - contract.validCols == ShapedType::kDynamic) { - op->emitOpError() << contract.family - << " lowering currently requires static source valid rows and cols"; - hasPrecheckFailure = true; - } - if (contract.validRows != dstRows) { - op->emitOpError() << contract.family - << " lowering requires destination valid rows to match source valid rows"; - hasPrecheckFailure = true; - } - if (dstCols != 1) { - op->emitOpError() << contract.family - << " lowering requires destination valid cols to equal 1"; - hasPrecheckFailure = true; - } - if (contract.dstLayout == "col_major") { - auto [dstRowsPhysical, dstColsPhysical] = getStaticTileRowsCols(dst); - (void)dstRowsPhysical; - if (dstColsPhysical != 1) { - op->emitOpError() << contract.family - << " lowering requires col-major destinations to use physical cols == 1"; - hasPrecheckFailure = true; - } - } - return failure(hasPrecheckFailure); -} - -LogicalResult checkColReduceContract(Operation *op, - const VPTOColReduceContract &contract, - Value dst) { - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - deriveValidShape(dst, dstRows, dstCols); - - bool hasPrecheckFailure = false; - if (contract.srcDomain != VPTOTileDomain::Vec || - contract.dstDomain != VPTOTileDomain::Vec) { - op->emitOpError() << contract.family << " lowering requires vec source and destination"; - hasPrecheckFailure = true; - } - if (contract.srcLayout != "row_major" || contract.dstLayout != "row_major") { - op->emitOpError() << contract.family - << " lowering requires row-major source and destination tile layout"; - hasPrecheckFailure = true; - } - if (!contract.elementType || - (!contract.elementType.isF16() && !contract.elementType.isF32())) { - op->emitOpError() << contract.family << " lowering supports only f16 and f32 element types"; - hasPrecheckFailure = true; - } - if (contract.validRows == ShapedType::kDynamic || - contract.validCols == ShapedType::kDynamic) { - op->emitOpError() << contract.family - << " lowering currently requires static source valid rows and cols"; - hasPrecheckFailure = true; - } - if (dstRows != 1) { - op->emitOpError() << contract.family - << " lowering requires destination valid rows to equal 1"; - hasPrecheckFailure = true; - } - if (dstCols != contract.validCols) { - op->emitOpError() << contract.family - << " lowering requires destination valid cols to match source valid cols"; - hasPrecheckFailure = true; - } - if (contract.isBinary && !contract.tmp) { - op->emitOpError() << contract.family << " lowering requires tmp for binary path"; - hasPrecheckFailure = true; - } - return failure(hasPrecheckFailure); -} - -LogicalResult checkPartContract(Operation *op, const VPTOPartContract &contract) { - bool hasPrecheckFailure = false; - if (contract.src0Domain != VPTOTileDomain::Vec || - contract.src1Domain != VPTOTileDomain::Vec || - contract.dstDomain != VPTOTileDomain::Vec) { - op->emitOpError() << contract.family << " lowering requires vec source and destination"; - hasPrecheckFailure = true; - } - if (contract.src0Layout != "row_major" || contract.src1Layout != "row_major" || - contract.dstLayout != "row_major") { - op->emitOpError() << contract.family - << " lowering requires row-major source and destination tile layout"; - hasPrecheckFailure = true; - } - if (!contract.elementType) - hasPrecheckFailure = true; - else if (contract.family == "partadd") { - bool ok = contract.elementType.isF16() || contract.elementType.isF32() || - contract.elementType.isBF16(); - if (auto intType = dyn_cast(contract.elementType)) - ok = intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - if (!ok) { - op->emitOpError() << contract.family - << " lowering supports f16, f32, bf16, and 8/16/32-bit integers"; - hasPrecheckFailure = true; - } - } else { - bool ok = contract.elementType.isF16() || contract.elementType.isF32() || - contract.elementType.isBF16(); - if (auto intType = dyn_cast(contract.elementType)) - ok = intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - if (!ok) { - op->emitOpError() << contract.family - << " lowering supports f16, f32, bf16, and 8/16/32-bit integers"; - hasPrecheckFailure = true; - } - } - auto allStatic = [&](int64_t a, int64_t b) { - return a != ShapedType::kDynamic && b != ShapedType::kDynamic; - }; - if (!allStatic(contract.src0ValidRows, contract.src0ValidCols) || - !allStatic(contract.src1ValidRows, contract.src1ValidCols) || - !allStatic(contract.dstValidRows, contract.dstValidCols)) { - op->emitOpError() << contract.family - << " lowering currently requires static source and destination valid shapes"; - hasPrecheckFailure = true; - } - return failure(hasPrecheckFailure); -} - -LogicalResult lowerTLOAD(TLoadOp op, PatternRewriter &rewriter) { - VPTOLoadContract contract = extractTLoadContract(op); - if (contract.tileDomain != VPTOTileDomain::Vec) - return op.emitOpError("currently supports only VEC TLOAD lowering"); - - ResolvedTensorView sourceView; - if (!resolveTensorView(op.getSrc(), sourceView, rewriter, op.getLoc())) - return op.emitOpError("requires a recoverable source tensor view for VPTO lowering"); - - StringRef sourceLayout = - inferVecTransferLayoutFromTile(stringifyLayoutAttr(sourceView.layoutAttr), - contract.tileLayout); - bool isNdLoad = contract.tileLayout == "row_major" && sourceLayout == "nd"; - bool isDnLoad = contract.tileLayout == "col_major" && sourceLayout == "dn"; - if (!isNdLoad && !isDnLoad) - return op.emitOpError("currently supports only ND row_major or DN col_major vec TLOAD lowering"); - - Value sourceBuffer = - materializeBufferPointer(sourceView.root, getElementType(sourceView.root), - getGmMemorySpace(rewriter.getContext()), rewriter, - op.getLoc()); - Value destinationBuffer = - materializeBufferPointer(op.getDst(), contract.elementType, - getMemorySpace(op.getDst()), rewriter, op.getLoc()); - if (!sourceBuffer || !destinationBuffer) - return op.emitOpError("requires A5-compatible source and destination buffers"); - - auto [tileRows, tileCols] = getStaticTileRowsCols(op.getDst()); - (void)tileRows; - bool ubPad = contract.padMode != "none" || contract.padValue || - contract.leftPaddingNum || contract.rightPaddingNum; - Value validRowsValue = - materializeI64Value(contract.validRowsValue, contract.validRows, rewriter, - op.getLoc()); - Value validColsValue = - materializeI64Value(contract.validColsValue, contract.validCols, rewriter, - op.getLoc()); - Value sidValue = rewriter.create(op.getLoc(), 0, 64); - int64_t elemBytes = getElementByteSize(contract.elementType); - if ((isNdLoad && tileCols == ShapedType::kDynamic) || - (isDnLoad && tileRows == ShapedType::kDynamic) || elemBytes <= 0) - return op.emitOpError("requires static tile shape for A5-compatible transfer arguments"); - VecNdTransferPlan plan; - LogicalResult planResult = - isNdLoad ? buildVecNdLoadPlan(sourceView.shape, sourceView.strides, tileCols, - contract.validColsValue, contract.validCols, - contract.elementType, rewriter, op.getLoc(), plan) - : buildVecDnLoadPlan(sourceView.shape, sourceView.strides, tileRows, - contract.validRowsValue, contract.validRows, - contract.elementType, rewriter, op.getLoc(), plan); - if (failed(planResult)) - return op.emitOpError("requires PTO-compatible vec copy_gm_to_ubuf arguments"); - Value leftPaddingValue = rewriter.create(op.getLoc(), 0, 64); - Value rightPaddingValue = rewriter.create(op.getLoc(), 0, 64); - Value cacheCtlValue = rewriter.create(op.getLoc(), 0, 64); - if (!validRowsValue || !validColsValue) - return op.emitOpError("requires valid rows and cols for A5-compatible transfer arguments"); - Value sourceOffset = - materializeI64Ofr(sourceView.offsetElems, rewriter, op.getLoc()); - if (!sourceOffset) - return op.emitOpError("requires a materializable source offset for VPTO lowering"); - Value sourceBase = adjustPointerByElemOffset(sourceBuffer, sourceOffset, elemBytes, - rewriter, op.getLoc()); - if (!sourceBase) - return op.emitOpError("failed to materialize source base pointer"); - - rewriter.create( - op.getLoc(), plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); - rewriter.create( - op.getLoc(), plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); - rewriter.create(op.getLoc(), plan.loop2Size, - plan.loop1Size); - - auto emitCopy = [&](Value srcPtr, Value dstPtr) { - Type transferElementType = - getCopyTransferElementType(contract.elementType, rewriter); - Value typedSrcPtr = - castPtrToElementType(srcPtr, transferElementType, rewriter, op.getLoc()); - Value typedDstPtr = - castPtrToElementType(dstPtr, transferElementType, rewriter, op.getLoc()); - if (!typedSrcPtr || !typedDstPtr) - return failure(); - Value dataSelectBitValue = - rewriter.create(op.getLoc(), rewriter.getI1Type(), - rewriter.getBoolAttr(ubPad)); - rewriter.create( - op.getLoc(), typedSrcPtr, typedDstPtr, sidValue, plan.nBurst, - plan.lenBurst, leftPaddingValue, rightPaddingValue, dataSelectBitValue, - cacheCtlValue, plan.firstStrideBytes, plan.secondStrideBytes); - return success(); - }; - - if (std::optional outerConst = getConstInt(plan.outerCount); outerConst && *outerConst == 1) { - return emitCopy(sourceBase, destinationBuffer); - } - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value outerUpper = - rewriter.create(op.getLoc(), rewriter.getIndexType(), - plan.outerCount); - auto outerLoop = rewriter.create(op.getLoc(), c0, outerUpper, c1); - rewriter.setInsertionPointToStart(outerLoop.getBody()); - Value ivI64 = rewriter.create(op.getLoc(), rewriter.getI64Type(), - outerLoop.getInductionVar()); - Value srcStep = createI64Mul(ivI64, plan.outerSrcStrideElems, rewriter, op.getLoc()); - Value dstStep = createI64Mul(ivI64, plan.outerDstStrideElems, rewriter, op.getLoc()); - Value iterSrc = adjustPointerByElemOffset(sourceBase, srcStep, elemBytes, rewriter, - op.getLoc()); - Value iterDst = adjustPointerByElemOffset(destinationBuffer, dstStep, elemBytes, rewriter, - op.getLoc()); - return emitCopy(iterSrc, iterDst); -} - -LogicalResult lowerTABS(TAbsOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = extractTAbsContract(op); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) - return failure(); - - return buildUnaryVecScope("abs", contract, strategy, op.getSrc(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTADD(TAddOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = extractTAddContract(op); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("add", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTSUB(TSubOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = extractTSubContract(op); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("sub", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTMUL(TMulOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = extractTMulContract(op); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("mul", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTDIV(TDivOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = extractTDivContract(op); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 16 || intType.getWidth() == 32; - return false; - }, - "f16, f32, and 16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("div", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTMAX(TMaxOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = buildBinaryContract("max", op.getSrc0()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("max", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTMIN(TMinOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = buildBinaryContract("min", op.getSrc0()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("min", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTAND(TAndOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = buildBinaryContract("and", op.getSrc0()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("and", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTANDS(TAndSOp op, PatternRewriter &rewriter) { - return emitUnresolvedInstalledA5BaselineError(op, "tands"); -} - -LogicalResult lowerTOR(TOrOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = buildBinaryContract("or", op.getSrc0()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("or", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTORS(TOrSOp op, PatternRewriter &rewriter) { - return emitUnresolvedInstalledA5BaselineError(op, "tors"); -} - -LogicalResult lowerTXOR(TXorOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOBinaryContract contract = buildBinaryContract("xor", op.getSrc0()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getSrc1(), op.getDst(), - [](Type type) { - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("xor", contract, strategy, op.getSrc0(), - op.getSrc1(), op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTXORS(TXorSOp op, PatternRewriter &rewriter) { - return emitUnresolvedInstalledA5BaselineError(op, "txors"); -} - -LogicalResult lowerTEXP(TExpOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = extractTExpContract(op); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) - return failure(); - return buildUnaryVecScope("exp", contract, strategy, op.getSrc(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTLOG(TLogOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = extractTLogContract(op); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) - return failure(); - return buildUnaryVecScope("log", contract, strategy, op.getSrc(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTSQRT(TSqrtOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = extractTSqrtContract(op); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) - return failure(); - return buildUnaryVecScope("sqrt", contract, strategy, op.getSrc(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTRSQRT(TRsqrtOp op, PatternRewriter &rewriter) { - VPTOUnaryContract contract = buildUnaryContract("rsqrt", op.getSrc()); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) - return failure(); - - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return op.emitOpError("trsqrt lowering requires a supported VPTO vector element type"); - - Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!srcBuffer || !dstBuffer) - return op.emitOpError("trsqrt lowering requires pointer-backed tile buffers"); - - Value validRowsValue = materializeIndexValue(contract.validRowsValue, - contract.validRows, rewriter, op.getLoc()); - Value validColsValue = materializeIndexValue(contract.validColsValue, - contract.validCols, rewriter, op.getLoc()); - if (!validRowsValue || !validColsValue) - return op.emitOpError("trsqrt lowering requires valid rows and cols"); - - int64_t srcRowStride = deriveStaticRowStride(op.getSrc()); - int64_t dstRowStride = deriveStaticRowStride(op.getDst()); - if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic) - return op.emitOpError("trsqrt lowering requires static row-major row strides"); - - int64_t vectorWidth = vecType.getElementCount(); - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value srcRowStrideValue = - rewriter.create(op.getLoc(), srcRowStride); - Value dstRowStrideValue = - rewriter.create(op.getLoc(), dstRowStride); - Value vectorStepValue = - rewriter.create(op.getLoc(), vectorWidth); - TypedAttr oneAttr = FloatAttr::get(contract.elementType, 1.0); - Value one = rewriter.create(op.getLoc(), oneAttr); - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), vecType.getElementType()); - auto ones = - rewriter.create(op.getLoc(), vecType, one, fullMask, StringAttr()); - auto chunkLoop = - rewriter.create(op.getLoc(), c0, validColsValue, vectorStepValue); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value srcRowBase = rewriter.create( - op.getLoc(), rowLoop.getInductionVar(), srcRowStrideValue); - Value dstRowBase = rewriter.create( - op.getLoc(), rowLoop.getInductionVar(), dstRowStrideValue); - Value chunkOffset = chunkLoop.getInductionVar(); - Value srcOffset = - rewriter.create(op.getLoc(), srcRowBase, chunkOffset); - Value dstOffset = - rewriter.create(op.getLoc(), dstRowBase, chunkOffset); - Value remaining = rewriter.create(op.getLoc(), validColsValue, chunkOffset); - Value predicate = - buildPredicateMaskForLaneCount(rewriter, op.getLoc(), contract.elementType, remaining); - auto loaded = rewriter.create(op.getLoc(), vecType, srcBuffer, - srcOffset, StringAttr()); - auto sqrt = rewriter.create(op.getLoc(), vecType, loaded.getResult(), - predicate); - auto result = rewriter.create(op.getLoc(), vecType, ones.getResult(), - sqrt.getResult(), predicate); - rewriter.create( - op.getLoc(), result.getResult(), dstBuffer, dstOffset, StringAttr(), predicate); - return success(); -} - -LogicalResult lowerTRECIP(TRecipOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = extractTRecipContract(op); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) - return failure(); - return buildUnaryVecScope("recip", contract, strategy, op.getSrc(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTNEG(TNegOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = buildUnaryContract("muls", op.getSrc()); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 16 || intType.getWidth() == 32; - return false; - }, - "f16, f32, and 16/32-bit integer element types"))) - return failure(); - - TypedAttr negOneAttr; - if (contract.elementType.isF16()) - negOneAttr = FloatAttr::get(contract.elementType, -1.0); - else if (contract.elementType.isF32()) - negOneAttr = FloatAttr::get(contract.elementType, -1.0); - else if (auto intType = dyn_cast(contract.elementType)) - negOneAttr = IntegerAttr::get(intType, -1); - else - return op.emitOpError("tneg lowering requires scalar element type"); - - Value negOne = rewriter.create(op.getLoc(), negOneAttr); - return buildScalarUnaryVecScope("muls", contract, strategy, op.getSrc(), negOne, - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTLRELU(TLReluOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = buildUnaryContract("lrelu", op.getSrc()); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32(); }, - "f16 and f32 element types"))) - return failure(); - if (op.getSlope().getType() != contract.elementType) - return op.emitOpError("tlrelu lowering requires slope type to match source element type"); - return buildScalarUnaryVecScope("lrelu", contract, strategy, op.getSrc(), op.getSlope(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter) { - VPTOUnaryContract contract = buildUnaryContract("cvt", op.getSrc()); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32() || type.isBF16(); }, - "f16, f32, or bf16 element type"))) - return failure(); - - Type dstElementType = getElementType(op.getDst()); - FailureOr loweringKind = - classifyA5CvtLowering(contract.elementType, dstElementType); - if (failed(loweringKind)) - return op.emitOpError( - "current tcvt lowering supports only f32->f32, f32->bf16, f16->f32, bf16->f16, and bf16->f32"); - - FailureOr roundMode = stringifyA5RoundMode(op, rewriter); - if (failed(roundMode)) - return op.emitOpError("tcvt lowering does not recognize the requested round mode"); - - auto srcVecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - auto dstVecType = getVPTOVRegType(rewriter.getContext(), dstElementType); - if (!srcVecType || !dstVecType) - return op.emitOpError("tcvt lowering requires legal VPTO vector types"); - - Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), dstElementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!srcBuffer || !dstBuffer) - return op.emitOpError("tcvt lowering requires pointer-backed tile buffers"); - - Value validRowsValue = materializeIndexValue(contract.validRowsValue, - contract.validRows, rewriter, - op.getLoc()); - Value validColsValue = materializeIndexValue(contract.validColsValue, - contract.validCols, rewriter, - op.getLoc()); - if (!validRowsValue || !validColsValue) - return op.emitOpError("tcvt lowering requires valid rows and cols"); - - int64_t vectorWidth = dstVecType.getElementCount(); - if (contract.validRows != ShapedType::kDynamic && - contract.validCols != ShapedType::kDynamic) { - int64_t totalElements = contract.validRows * contract.validCols; - if (totalElements % vectorWidth != 0) - return op.emitOpError( - "tcvt lowering requires total valid elements divisible by vector width"); - } - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value totalElementsValue = - rewriter.create(op.getLoc(), validRowsValue, validColsValue); - Value vectorStepValue = - rewriter.create(op.getLoc(), vectorWidth); - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto chunkLoop = - rewriter.create(op.getLoc(), c0, totalElementsValue, vectorStepValue); - OpBuilder::InsertionGuard chunkGuard(rewriter); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value offset = chunkLoop.getInductionVar(); - switch (*loweringKind) { - case VPTOCvtLoweringKind::Vtrc: { - auto loaded = - rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); - Value mask = buildAllPredicateMask(rewriter, op.getLoc(), dstElementType); - Value converted = rewriter.create(op.getLoc(), dstVecType, - loaded.getResult(), mask, - *roundMode); - rewriter.create( - op.getLoc(), converted, dstBuffer, offset, StringAttr(), - buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); - break; - } - case VPTOCvtLoweringKind::F32ToBF16: { - Value halfStep = rewriter.create( - op.getLoc(), srcVecType.getElementCount()); - Value upperOffset = - rewriter.create(op.getLoc(), offset, halfStep); - auto lower = - rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); - auto upper = rewriter.create(op.getLoc(), srcVecType, srcBuffer, - upperOffset, StringAttr()); - Value odd = rewriter.create( - op.getLoc(), dstVecType, upper.getResult(), *roundMode, - rewriter.getStringAttr("RS_ENABLE"), rewriter.getStringAttr("PART_ODD")); - Value even = rewriter.create( - op.getLoc(), dstVecType, lower.getResult(), *roundMode, - rewriter.getStringAttr("RS_ENABLE"), rewriter.getStringAttr("PART_EVEN")); - Value merged = - rewriter.create( - op.getLoc(), dstVecType, even, odd, - buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); - rewriter.create( - op.getLoc(), merged, dstBuffer, offset, StringAttr(), - buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); - break; - } - case VPTOCvtLoweringKind::F16ToF32: { - auto loaded = rewriter.create( - op.getLoc(), srcVecType, srcBuffer, offset, rewriter.getStringAttr("UNPK_B16")); - Value converted = rewriter.create( - op.getLoc(), dstVecType, loaded.getResult(), StringAttr(), - StringAttr(), rewriter.getStringAttr("PART_EVEN")); - rewriter.create( - op.getLoc(), converted, dstBuffer, offset, StringAttr(), - buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); - break; - } - case VPTOCvtLoweringKind::BF16ToF16: { - auto loaded = - rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); - Value converted = rewriter.create( - op.getLoc(), dstVecType, loaded.getResult(), *roundMode, - rewriter.getStringAttr("RS_ENABLE"), StringAttr()); - rewriter.create( - op.getLoc(), converted, dstBuffer, offset, StringAttr(), - buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); - break; - } - case VPTOCvtLoweringKind::BF16ToF32: { - auto loaded = rewriter.create( - op.getLoc(), srcVecType, srcBuffer, offset, rewriter.getStringAttr("UNPK_B16")); - Value converted = rewriter.create( - op.getLoc(), dstVecType, loaded.getResult(), StringAttr(), - StringAttr(), rewriter.getStringAttr("PART_EVEN")); - rewriter.create( - op.getLoc(), converted, dstBuffer, offset, StringAttr(), - buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); - break; - } - } - return success(); -} - -template -LogicalResult buildPackedCmp32VecScope(StringRef family, - const VPTOBinaryContract &contract, - Value dst, Value dstBuffer, - PatternRewriter &rewriter, Location loc, - CompareEmitter emitCompare) { - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return emitError(loc) << family << " lowering requires a supported vector element type"; - - Value validRowsValue = materializeIndexValue(contract.validRowsValue, - contract.validRows, rewriter, loc); - Value validColsValue = materializeIndexValue(contract.validColsValue, - contract.validCols, rewriter, loc); - if (!validRowsValue || !validColsValue) - return emitError(loc) << family << " lowering requires valid rows and cols"; - if (contract.validRows == ShapedType::kDynamic || - contract.validCols == ShapedType::kDynamic) - return emitError(loc) << family << " lowering currently requires static valid rows and cols"; - - int64_t totalElements = contract.validRows * contract.validCols; - constexpr int64_t repeatElem = 64; - int64_t repeatTimes = (totalElements + repeatElem - 1) / repeatElem; - int64_t pairedRepeats = repeatTimes / 2; - int64_t remainRepeats = repeatTimes % 2; - - auto compareMaskType = - getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType); - auto packedMaskType = getVPTOMaskType(rewriter.getContext(), "b8"); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value pairUpper = rewriter.create(loc, pairedRepeats); - Value repeatStep = rewriter.create(loc, repeatElem); - Value pairSrcStride = rewriter.create(loc, repeatElem * 2); - Value pairDstStride = rewriter.create(loc, 4); - Value laneCount = rewriter.create(loc, repeatElem, 32); - Value totalRemaining = rewriter.create(loc, totalElements, 32); - - FailureOr vecScope = - createLoopScopeRegion(loc, contract.loopScope, rewriter); - if (failed(vecScope)) - return emitError(loc) << "failed to create AIV vector scope region"; - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto pairLoop = - rewriter.create(loc, c0, pairUpper, c1, ValueRange{totalRemaining}); - rewriter.setInsertionPointToStart(pairLoop.getBody()); - Value remaining = pairLoop.getRegionIterArgs().front(); - Value pairBase = rewriter.create(loc, pairLoop.getInductionVar(), - pairSrcStride); - Value pairNext = rewriter.create(loc, pairBase, repeatStep); - Value dstOffset = rewriter.create(loc, pairLoop.getInductionVar(), - pairDstStride); - Value dstBase = adjustPointerByElemOffset(dstBuffer, dstOffset, 4, rewriter, loc); - Value dstZero = rewriter.create(loc, 0); - auto pairMask0 = rewriter.create(loc, compareMaskType, - rewriter.getI32Type(), - remaining); - auto pairMask1 = rewriter.create(loc, compareMaskType, - rewriter.getI32Type(), - pairMask0.getScalarOut()); - Value cmp0 = emitCompare(rewriter, loc, pairBase, pairMask0.getMask()); - Value cmp1 = emitCompare(rewriter, loc, pairNext, pairMask1.getMask()); - Value packedCmp0 = rewriter - .create(loc, packedMaskType, cmp0, - rewriter.getStringAttr("LOWER")) - .getResult(); - Value packedCmp1 = rewriter - .create(loc, packedMaskType, cmp1, - rewriter.getStringAttr("LOWER")) - .getResult(); - auto interleaved = rewriter.create( - loc, packedMaskType, packedMaskType, packedCmp0, packedCmp1); - rewriter.create(loc, interleaved.getLow(), dstBase, dstZero, - "NORM"); - rewriter.create(loc, pairMask1.getScalarOut()); - - if (remainRepeats == 0) - return success(); - - rewriter.setInsertionPointAfter(pairLoop); - Value tailBase = rewriter.create(loc, pairedRepeats * repeatElem * 2); - Value tailDst = rewriter.create(loc, pairedRepeats * 4); - Value tailDstBase = adjustPointerByElemOffset(dstBuffer, tailDst, 4, rewriter, loc); - Value tailDstZero = rewriter.create(loc, 0); - auto tailMask = rewriter.create(loc, compareMaskType, - rewriter.getI32Type(), - pairLoop.getResult(0)); - Value tailCmp = emitCompare(rewriter, loc, tailBase, tailMask.getMask()); - Value packedTail = rewriter - .create(loc, packedMaskType, tailCmp, - rewriter.getStringAttr("LOWER")) - .getResult(); - rewriter.create(loc, packedTail, tailDstBase, tailDstZero, - "NORM"); - return success(); -} - -LogicalResult lowerTCmpS(TCmpSOp op, PatternRewriter &rewriter) { - VPTOBinaryContract contract = buildBinaryContract("cmps", op.getSrc()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - - if (contract.tileDomain != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) - return op.emitOpError("tcmps lowering requires tile domain vec"); - if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") - return op.emitOpError("tcmps lowering requires row-major tile layout"); - if (contract.validRows == ShapedType::kDynamic || - contract.validCols == ShapedType::kDynamic) - return op.emitOpError("tcmps lowering requires static valid shape"); - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - deriveValidShape(op.getDst(), dstRows, dstCols); - if (contract.validRows != dstRows || contract.validCols != dstCols) - return op.emitOpError("tcmps lowering requires matching source and destination valid region"); - if (!isSupportedPackedCmp32ElementType(contract.elementType)) - return op.emitOpError("tcmps lowering currently supports only 32-bit source tiles"); - auto dstElemType = dyn_cast_or_null(getElementType(op.getDst())); - if (!dstElemType || !dstElemType.isUnsignedInteger(8)) - return op.emitOpError("tcmps lowering currently requires ui8 destination tiles"); - if (!isCompatibleScalarForSemanticType(contract.elementType, - op.getScalar().getType())) - return op.emitOpError("tcmps lowering requires scalar type to match source element type"); - - Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), getElementType(op.getDst()), - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!srcBuffer || !dstBuffer) - return op.emitOpError("tcmps lowering requires pointer-backed tile buffers"); - - StringAttr cmpMode = rewriter.getStringAttr(stringifyCmpModeAttr(op.getCmpModeAttr())); - return buildPackedCmp32VecScope( - "tcmps", contract, op.getDst(), dstBuffer, rewriter, op.getLoc(), - [&](PatternRewriter &nestedRewriter, Location nestedLoc, Value offset, - Value mask) -> Value { - auto vecType = - getVPTOVRegType(nestedRewriter.getContext(), contract.elementType); - auto loaded = - nestedRewriter.create(nestedLoc, vecType, srcBuffer, offset, StringAttr()); - return nestedRewriter - .create(nestedLoc, - getVPTOMaskTypeForElementType( - nestedRewriter.getContext(), - contract.elementType), - loaded.getResult(), op.getScalar(), mask, cmpMode) - .getResult(); - }); -} - -LogicalResult lowerTCmp(TCmpOp op, PatternRewriter &rewriter) { - VPTOBinaryContract contract = buildBinaryContract("cmp", op.getSrc0()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - - if (contract.tileDomain != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) - return op.emitOpError("tcmp lowering requires tile domain vec"); - if (contract.tileLayout != "row_major" || deriveTileLayout(op.getSrc1()) != "row_major" || - deriveTileLayout(op.getDst()) != "row_major") - return op.emitOpError("tcmp lowering requires row-major tile layout"); - if (contract.validRows == ShapedType::kDynamic || - contract.validCols == ShapedType::kDynamic) - return op.emitOpError("tcmp lowering requires static valid shape"); - int64_t src1Rows = ShapedType::kDynamic; - int64_t src1Cols = ShapedType::kDynamic; - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - deriveValidShape(op.getSrc1(), src1Rows, src1Cols); - deriveValidShape(op.getDst(), dstRows, dstCols); - if (contract.validRows != src1Rows || contract.validCols != src1Cols || - contract.validRows != dstRows || contract.validCols != dstCols) - return op.emitOpError("tcmp lowering requires matching source and destination valid region"); - if (!isSupportedPackedCmp32ElementType(contract.elementType)) - return op.emitOpError("tcmp lowering currently supports only 32-bit source tiles"); - if (getElementType(op.getSrc1()) != contract.elementType) - return op.emitOpError("tcmp lowering requires src1 element type to match src0"); - auto dstElemType = dyn_cast_or_null(getElementType(op.getDst())); - if (!dstElemType || !dstElemType.isUnsignedInteger(8)) - return op.emitOpError("tcmp lowering currently requires ui8 destination tiles"); - - Value src0Buffer = materializeBufferPointer(op.getSrc0(), contract.elementType, - getMemorySpace(op.getSrc0()), rewriter, - op.getLoc()); - Value src1Buffer = materializeBufferPointer(op.getSrc1(), contract.elementType, - getMemorySpace(op.getSrc1()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), getElementType(op.getDst()), - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!src0Buffer || !src1Buffer || !dstBuffer) - return op.emitOpError("tcmp lowering requires pointer-backed tile buffers"); - - StringAttr cmpMode = rewriter.getStringAttr(stringifyCmpModeAttr(op.getCmpModeAttr())); - return buildPackedCmp32VecScope( - "tcmp", contract, op.getDst(), dstBuffer, rewriter, op.getLoc(), - [&](PatternRewriter &nestedRewriter, Location nestedLoc, Value offset, - Value mask) -> Value { - auto vecType = - getVPTOVRegType(nestedRewriter.getContext(), contract.elementType); - auto lhs = - nestedRewriter.create(nestedLoc, vecType, src0Buffer, offset, StringAttr()); - auto rhs = - nestedRewriter.create(nestedLoc, vecType, src1Buffer, offset, StringAttr()); - return nestedRewriter - .create(nestedLoc, - getVPTOMaskTypeForElementType( - nestedRewriter.getContext(), - contract.elementType), - lhs.getResult(), rhs.getResult(), mask, cmpMode) - .getResult(); - }); -} - -LogicalResult lowerTCI(TCIOp op, PatternRewriter &rewriter) { - Type elementType = getElementType(op.getDst()); - auto intType = dyn_cast_or_null(elementType); - if (!intType || (intType.getWidth() != 16 && intType.getWidth() != 32)) - return op.emitOpError("tci lowering requires i16 or i32 destination element type"); - if (deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) - return op.emitOpError("tci lowering requires tile domain vec"); - if (deriveTileLayout(op.getDst()) != "row_major") - return op.emitOpError("tci lowering requires row-major tile layout"); - - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - Value validRowsValue; - Value validColsValue; - deriveValidShapeValues(op.getDst(), validRowsValue, validColsValue); - deriveValidShape(op.getDst(), validRows, validCols); - if (validRows != 1) - return op.emitOpError("tci lowering currently requires valid rows == 1"); - - Value dstBuffer = materializeBufferPointer(op.getDst(), elementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!dstBuffer) - return op.emitOpError("tci lowering requires pointer-backed destination tile buffer"); - - Value upperBound = materializeIndexValue(validColsValue, validCols, rewriter, op.getLoc()); - if (!upperBound) - return op.emitOpError("tci lowering requires valid cols"); - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - auto loop = rewriter.create(op.getLoc(), c0, upperBound, c1); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(loop.getBody()); - Value iv = loop.getInductionVar(); - Value ivAsElem = rewriter.create(op.getLoc(), intType, iv); - Value stored = - op.getDescending() - ? rewriter.create(op.getLoc(), op.getS(), ivAsElem).getResult() - : rewriter.create(op.getLoc(), op.getS(), ivAsElem).getResult(); - rewriter.create(op.getLoc(), dstBuffer, iv, stored); - return success(); -} - -LogicalResult lowerTRELU(TReluOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = extractTReluContract(op); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { - return type.isF16() || type.isF32() || - (isa(type) && cast(type).getWidth() == 32); - }, - "f16, f32, and i32 element types"))) - return failure(); - return buildUnaryVecScope("relu", contract, strategy, op.getSrc(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTNOT(TNotOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = extractTNotContract(op); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - return buildUnaryVecScope("not", contract, strategy, op.getSrc(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTTRANS(TTransOp op, PatternRewriter &rewriter) { - VPTOUnaryContract contract = buildUnaryContract("trans", op.getSrc()); - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - deriveValidShape(op.getDst(), dstRows, dstCols); - - if (contract.tileDomain != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) - return op.emitOpError("ttrans lowering requires tile domain vec"); - if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") - return op.emitOpError("ttrans lowering requires row-major tile layout"); - if (contract.validRows == ShapedType::kDynamic || contract.validCols == ShapedType::kDynamic || - dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) - return op.emitOpError("ttrans lowering requires static valid shape"); - if (contract.validRows != dstCols || contract.validCols != dstRows) - return op.emitOpError("ttrans lowering requires transposed source/destination valid shape"); - if (contract.elementType != getElementType(op.getDst())) - return op.emitOpError("ttrans lowering requires matching source/destination element type"); - - int64_t elemBytes = getElementByteSize(contract.elementType); - int64_t srcStride = deriveStaticRowStride(op.getSrc()); - int64_t dstStride = deriveStaticRowStride(op.getDst()); - if (elemBytes != 4) - return op.emitOpError("ttrans lowering currently supports only b32 element types"); - if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) - return op.emitOpError("ttrans lowering requires static source/destination row stride"); - - auto dataVecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - auto indexElemType = rewriter.getIntegerType(32); - auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElemType); - if (!dataVecType || !indexVecType) - return op.emitOpError("ttrans lowering requires supported VPTO vector types"); - - Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!srcBuffer || !dstBuffer) - return op.emitOpError("ttrans lowering requires pointer-backed tile buffers"); - - constexpr int64_t repeatBytes = 256; - constexpr int64_t blockBytes = 32; - int64_t elementsPerRepeat = repeatBytes / elemBytes; - int64_t blockSizeElem = blockBytes / elemBytes; - int64_t alignedRows = - llvm::divideCeil(contract.validRows, blockSizeElem) * blockSizeElem; - int64_t repeatTimes = llvm::divideCeil(alignedRows, elementsPerRepeat); - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value colsUpper = rewriter.create(op.getLoc(), contract.validCols); - Value chunkUpper = rewriter.create(op.getLoc(), repeatTimes); - Value elementsPerRepeatValue = - rewriter.create(op.getLoc(), elementsPerRepeat); - Value dstStrideValue = rewriter.create(op.getLoc(), dstStride); - Value srcStrideI32 = rewriter.create(op.getLoc(), srcStride, 32); - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto colLoop = rewriter.create(op.getLoc(), c0, colsUpper, c1); - rewriter.setInsertionPointToStart(colLoop.getBody()); - auto chunkLoop = rewriter.create(op.getLoc(), c0, chunkUpper, c1); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - - Value chunkBase = rewriter.create(op.getLoc(), chunkLoop.getInductionVar(), - elementsPerRepeatValue); - Value colI32 = rewriter.create(op.getLoc(), indexElemType, - colLoop.getInductionVar()); - Value chunkBaseI32 = - rewriter.create(op.getLoc(), indexElemType, chunkBase); - auto indices = - rewriter.create(op.getLoc(), indexVecType, chunkBaseI32, - rewriter.getStringAttr("INC_ORDER")); - Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), indexElemType); - auto scaled = rewriter.create(op.getLoc(), indexVecType, - indices.getResult(), srcStrideI32, fullMask); - auto offsets = rewriter.create(op.getLoc(), indexVecType, - scaled.getResult(), colI32, fullMask); - Value fullActiveLanes = - rewriter.create(op.getLoc(), - dataVecType.getElementCount()); - auto gathered = - rewriter.create(op.getLoc(), dataVecType, srcBuffer, - offsets.getResult(), fullActiveLanes); - Value dstBase = - rewriter.create(op.getLoc(), colLoop.getInductionVar(), dstStrideValue); - Value dstOffset = rewriter.create(op.getLoc(), dstBase, chunkBase); - rewriter.create( - op.getLoc(), gathered.getResult(), dstBuffer, dstOffset, StringAttr(), - buildAllPredicateMask(rewriter, op.getLoc(), contract.elementType)); - return success(); -} - -template -LogicalResult lowerTFillPadCommon(FillPadOpTy op, PatternRewriter &rewriter, - bool allowDstExpand) { - VPTOUnaryContract contract = buildUnaryContract("fillpad", op.getSrc()); - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - deriveValidShape(op.getDst(), dstRows, dstCols); - - if (contract.tileDomain != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) - return op.emitOpError("fillpad lowering requires tile domain vec"); - if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") - return op.emitOpError("fillpad lowering requires row-major tile layout"); - if (contract.validRows == ShapedType::kDynamic || contract.validCols == ShapedType::kDynamic || - dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) - return op.emitOpError("fillpad lowering requires static valid shape"); - if (!allowDstExpand && (contract.validRows != dstRows || contract.validCols != dstCols)) - return op.emitOpError("tfillpad lowering requires matching source/destination valid shape"); - if (allowDstExpand && (dstRows < contract.validRows || dstCols < contract.validCols)) - return op.emitOpError("tfillpad_expand lowering requires dst shape >= src shape"); - if (contract.elementType != getElementType(op.getDst())) - return op.emitOpError("fillpad lowering requires matching source/destination element type"); - - int64_t srcStride = deriveStaticRowStride(op.getSrc()); - int64_t dstStride = deriveStaticRowStride(op.getDst()); - if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) - return op.emitOpError("fillpad lowering requires static source/destination row stride"); - - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return op.emitOpError("fillpad lowering requires supported VPTO vector element type"); - - Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!srcBuffer || !dstBuffer) - return op.emitOpError("fillpad lowering requires pointer-backed tile buffers"); - - auto config = lookupTileConfig(op.getDst()); - PadValueAttr padAttr = config ? dyn_cast(config.getPad()) : PadValueAttr{}; - Attribute padValueAttr = buildFillPadValue(contract.elementType, padAttr, rewriter); - if (!padValueAttr) - return op.emitOpError("fillpad lowering requires a concrete non-null dst pad value"); - Value padScalar = rewriter.create(op.getLoc(), cast(padValueAttr)); - Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), vecType.getElementType()); - auto padVec = - rewriter.create(op.getLoc(), vecType, padScalar, fullMask, StringAttr()); - - int64_t vectorWidth = vecType.getElementCount(); - int64_t padCols = dstCols - contract.validCols; - int64_t padRows = dstRows - contract.validRows; - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value srcRowsUpper = rewriter.create(op.getLoc(), contract.validRows); - Value srcColsUpper = rewriter.create(op.getLoc(), contract.validCols); - Value dstRowsUpper = rewriter.create(op.getLoc(), dstRows); - Value vectorStep = rewriter.create(op.getLoc(), vectorWidth); - Value srcStrideValue = rewriter.create(op.getLoc(), srcStride); - Value dstStrideValue = rewriter.create(op.getLoc(), dstStride); - Value validColsValue = rewriter.create(op.getLoc(), contract.validCols); - Value dstColsValue = rewriter.create(op.getLoc(), dstCols); - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - - auto rowLoop = rewriter.create(op.getLoc(), c0, srcRowsUpper, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value srcRowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), - srcStrideValue); - Value dstRowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), - dstStrideValue); - - auto copyChunkLoop = - rewriter.create(op.getLoc(), c0, srcColsUpper, vectorStep); - rewriter.setInsertionPointToStart(copyChunkLoop.getBody()); - Value copyOffset = - rewriter.create(op.getLoc(), srcRowBase, copyChunkLoop.getInductionVar()); - auto loaded = rewriter.create(op.getLoc(), vecType, srcBuffer, - copyOffset, StringAttr()); - Value copyDstOffset = - rewriter.create(op.getLoc(), dstRowBase, copyChunkLoop.getInductionVar()); - Value copyRemaining = - rewriter.create(op.getLoc(), validColsValue, copyChunkLoop.getInductionVar()); - auto copyNeedsClamp = rewriter.create(op.getLoc(), arith::CmpIPredicate::slt, - copyRemaining, vectorStep); - Value copyActiveLanes = - rewriter.create(op.getLoc(), copyNeedsClamp, copyRemaining, vectorStep); - Value copyMask = buildPredicateMaskForLaneCount( - rewriter, op.getLoc(), contract.elementType, copyActiveLanes); - rewriter.create(op.getLoc(), loaded.getResult(), dstBuffer, - copyDstOffset, StringAttr(), copyMask); - - rewriter.setInsertionPointAfter(copyChunkLoop); - if (padCols > 0) { - Value padColsUpper = rewriter.create(op.getLoc(), padCols); - auto padColLoop = rewriter.create(op.getLoc(), c0, padColsUpper, vectorStep); - rewriter.setInsertionPointToStart(padColLoop.getBody()); - Value padDstStart = rewriter.create(op.getLoc(), dstRowBase, validColsValue); - Value padDstOffset = rewriter.create(op.getLoc(), padDstStart, - padColLoop.getInductionVar()); - Value padRemaining = - rewriter.create(op.getLoc(), padColsUpper, padColLoop.getInductionVar()); - auto padNeedsClamp = rewriter.create(op.getLoc(), arith::CmpIPredicate::slt, - padRemaining, vectorStep); - Value padActiveLanes = - rewriter.create(op.getLoc(), padNeedsClamp, padRemaining, vectorStep); - Value padMask = buildPredicateMaskForLaneCount( - rewriter, op.getLoc(), contract.elementType, padActiveLanes); - rewriter.create(op.getLoc(), padVec.getResult(), dstBuffer, - padDstOffset, StringAttr(), padMask); - } - - rewriter.setInsertionPointAfter(rowLoop); - if (padRows > 0) { - Value bottomStart = rewriter.create(op.getLoc(), srcRowsUpper, dstStrideValue); - Value bottomElements = - rewriter.create(op.getLoc(), - rewriter.create(op.getLoc(), dstRowsUpper, - dstColsValue), - bottomStart); - auto bottomLoop = rewriter.create(op.getLoc(), c0, bottomElements, vectorStep); - rewriter.setInsertionPointToStart(bottomLoop.getBody()); - Value bottomDstOffset = - rewriter.create(op.getLoc(), bottomStart, bottomLoop.getInductionVar()); - Value bottomRemaining = - rewriter.create(op.getLoc(), bottomElements, bottomLoop.getInductionVar()); - auto bottomNeedsClamp = rewriter.create( - op.getLoc(), arith::CmpIPredicate::slt, bottomRemaining, vectorStep); - Value bottomActiveLanes = rewriter.create( - op.getLoc(), bottomNeedsClamp, bottomRemaining, vectorStep); - Value bottomMask = buildPredicateMaskForLaneCount( - rewriter, op.getLoc(), contract.elementType, bottomActiveLanes); - rewriter.create(op.getLoc(), padVec.getResult(), dstBuffer, - bottomDstOffset, StringAttr(), bottomMask); - } - - return success(); -} - -LogicalResult lowerTFILLPAD(TFillPadOp op, PatternRewriter &rewriter) { - return lowerTFillPadCommon(op, rewriter, /*allowDstExpand=*/false); -} - -LogicalResult lowerTFILLPADExpand(TFillPadExpandOp op, PatternRewriter &rewriter) { - return lowerTFillPadCommon(op, rewriter, /*allowDstExpand=*/true); -} - -LogicalResult lowerTExpandS(TExpandsOp op, PatternRewriter &rewriter) { - VPTOUnaryContract contract = extractTExpandSContract(op); - if (contract.tileDomain != VPTOTileDomain::Vec) - return op.emitOpError("expands lowering requires tile domain vec"); - if (contract.tileLayout != "row_major") - return op.emitOpError("expands lowering requires row-major tile layout"); - if (!contract.elementType) - return op.emitOpError("expands lowering requires a concrete element type"); - - Type scalarType = op.getScalar().getType(); - if (!isCompatibleScalarForSemanticType(contract.elementType, scalarType)) - return op.emitOpError("expands lowering requires scalar type to match destination element type"); - - if (!(contract.elementType.isF16() || contract.elementType.isF32() || - contract.elementType.isBF16())) { - if (auto intType = dyn_cast(contract.elementType)) { - unsigned width = intType.getWidth(); - if (width != 8 && width != 16 && width != 32) - return op.emitOpError("expands lowering supports only f16, f32, bf16, and 8/16/32-bit integer element types"); - } else { - return op.emitOpError("expands lowering supports only scalar integer or floating-point element types"); - } - } - - return buildExpandScalarVecScope(contract, op.getScalar(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTGather(TGatherOp op, PatternRewriter &rewriter) { - auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { - if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) - return op.emitOpError() << "tgather lowering requires vec tile domain for " - << role; - if (deriveTileLayout(value) != "row_major") - return op.emitOpError() << "tgather lowering requires row-major layout for " - << role; - return success(); - }; - - if (failed(requireVecRowMajor(op.getSrc(), "src")) || - failed(requireVecRowMajor(op.getDst(), "dst"))) - return failure(); - - Type dataElementType = getElementType(op.getSrc()); - if (dataElementType != getElementType(op.getDst())) - return op.emitOpError("tgather lowering requires matching src/dst element type"); - - auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); - if (!dataVecType) - return op.emitOpError("tgather lowering requires supported VPTO data type"); - - Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!srcBuffer || !dstBuffer) - return op.emitOpError("tgather lowering requires pointer-backed tile buffers"); - - int64_t srcStride = deriveStaticRowStride(op.getSrc()); - int64_t dstStride = deriveStaticRowStride(op.getDst()); - if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) - return op.emitOpError("tgather lowering requires static row stride"); - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - VPTOLoopScopeContract loopScope; - loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - loopScope.loweredAttr = kLoweredLoopScopeAttrName; - loopScope.loopDepth = 0; - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - - if (Value indices = op.getIndices()) { - if (failed(requireVecRowMajor(indices, "indices"))) - return failure(); - - Type indexElementType = getElementType(indices); - auto indexIntegerType = dyn_cast(indexElementType); - auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElementType); - if (!indexIntegerType || !indexVecType) - return op.emitOpError("tgather index lowering requires integer indices with supported VPTO vector type"); - if (indexVecType.getElementCount() != dataVecType.getElementCount()) - return op.emitOpError("tgather index lowering currently requires matching data/index vector widths"); - - Value indexBuffer = materializeBufferPointer(indices, indexElementType, - getMemorySpace(indices), rewriter, - op.getLoc()); - if (!indexBuffer) - return op.emitOpError("tgather index lowering requires pointer-backed indices tile"); - - int64_t indexStride = deriveStaticRowStride(indices); - if (indexStride == ShapedType::kDynamic) - return op.emitOpError("tgather index lowering requires static index row stride"); - - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - deriveValidShapeValues(op.getDst(), validRowsValue, validColsValue); - deriveValidShape(op.getDst(), validRows, validCols); - if (failed(resolveExecutionValidShape(op.getDst(), validRowsValue, validColsValue, - validRows, validCols, rewriter, op.getLoc()))) - return op.emitOpError("tgather index lowering requires valid dst shape"); - - int64_t chunkWidth = indexVecType.getElementCount(); - Value chunkStep = rewriter.create(op.getLoc(), chunkWidth); - Value dstStrideValue = - rewriter.create(op.getLoc(), dstStride); - Value indexStrideValue = - rewriter.create(op.getLoc(), indexStride); - - auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - auto chunkLoop = - rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - - Value row = rowLoop.getInductionVar(); - Value chunkBase = chunkLoop.getInductionVar(); - Value remaining = - rewriter.create(op.getLoc(), validColsValue, chunkBase); - Value activeLanes = - buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); - - Value dstRowBase = - rewriter.create(op.getLoc(), row, dstStrideValue); - Value indexRowBase = - rewriter.create(op.getLoc(), row, indexStrideValue); - Value indexOffset = - rewriter.create(op.getLoc(), indexRowBase, chunkBase); - auto offsetVector = rewriter.create(op.getLoc(), indexVecType, - indexBuffer, indexOffset, - StringAttr()); - auto gathered = rewriter.create( - op.getLoc(), dataVecType, srcBuffer, offsetVector.getResult(), activeLanes); - Value dstOffset = - rewriter.create(op.getLoc(), dstRowBase, chunkBase); - return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), - dstBuffer, dstOffset, activeLanes, chunkWidth); - } - - auto maskPattern = op.getMaskPatternAttr(); - if (!maskPattern) - return op.emitOpError("tgather lowering requires indices or maskPattern"); - if (maskPattern.getValue() != MaskPattern::P1111) - return op.emitOpError("tgather mask lowering currently supports only maskPattern=P1111"); - - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - deriveValidShapeValues(op.getSrc(), validRowsValue, validColsValue); - deriveValidShape(op.getSrc(), validRows, validCols); - if (failed(resolveExecutionValidShape(op.getSrc(), validRowsValue, validColsValue, - validRows, validCols, rewriter, op.getLoc()))) - return op.emitOpError("tgather mask lowering requires valid src shape"); - - int64_t chunkWidth = dataVecType.getElementCount(); - Value chunkStep = rewriter.create(op.getLoc(), chunkWidth); - Value srcStrideValue = - rewriter.create(op.getLoc(), srcStride); - Value dstStrideValue = - rewriter.create(op.getLoc(), dstStride); - - auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - auto chunkLoop = - rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - - Value row = rowLoop.getInductionVar(); - Value chunkBase = chunkLoop.getInductionVar(); - Value remaining = - rewriter.create(op.getLoc(), validColsValue, chunkBase); - Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); - - Value srcRowBase = - rewriter.create(op.getLoc(), row, srcStrideValue); - Value dstRowBase = - rewriter.create(op.getLoc(), row, dstStrideValue); - Value srcOffset = - rewriter.create(op.getLoc(), srcRowBase, chunkBase); - auto loaded = rewriter.create(op.getLoc(), dataVecType, srcBuffer, - srcOffset, StringAttr()); - Value dstOffset = - rewriter.create(op.getLoc(), dstRowBase, chunkBase); - return buildMaskedVectorStore(rewriter, op.getLoc(), loaded.getResult(), dstBuffer, - dstOffset, activeLanes, chunkWidth); -} - -LogicalResult lowerTGatherB(TGatherBOp op, PatternRewriter &rewriter) { - auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { - if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) - return op.emitOpError() << "tgatherb lowering requires vec tile domain for " - << role; - if (deriveTileLayout(value) != "row_major") - return op.emitOpError() << "tgatherb lowering requires row-major layout for " - << role; - return success(); - }; - - if (failed(requireVecRowMajor(op.getSrc(), "src")) || - failed(requireVecRowMajor(op.getOffsets(), "offsets")) || - failed(requireVecRowMajor(op.getDst(), "dst"))) - return failure(); - - Type dataElementType = getElementType(op.getDst()); - if (getElementType(op.getSrc()) != dataElementType) - return op.emitOpError("tgatherb lowering requires matching src/dst element type"); - - auto offsetIntegerType = dyn_cast(getElementType(op.getOffsets())); - if (!offsetIntegerType || offsetIntegerType.getWidth() != 32 || - !offsetIntegerType.isUnsigned()) - return op.emitOpError("tgatherb lowering currently requires unsigned 32-bit offsets"); - - auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); - auto offsetVecType = - getVPTOVRegType(rewriter.getContext(), getElementType(op.getOffsets())); - if (!dataVecType || !offsetVecType) - return op.emitOpError("tgatherb lowering requires supported VPTO vector types"); - - Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - Value offsetBuffer = - materializeBufferPointer(op.getOffsets(), getElementType(op.getOffsets()), - getMemorySpace(op.getOffsets()), rewriter, op.getLoc()); - if (!srcBuffer || !dstBuffer || !offsetBuffer) - return op.emitOpError("tgatherb lowering requires pointer-backed tile buffers"); - - int64_t dstStride = deriveStaticRowStride(op.getDst()); - int64_t offsetStride = deriveStaticRowStride(op.getOffsets()); - int64_t staticRows = deriveStaticShapeDim(op.getDst(), 0); - int64_t staticCols = deriveStaticShapeDim(op.getDst(), 1); - if (dstStride == ShapedType::kDynamic || offsetStride == ShapedType::kDynamic || - staticRows == ShapedType::kDynamic || staticCols == ShapedType::kDynamic) - return op.emitOpError("tgatherb lowering requires static tile shape and row stride"); - - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - deriveValidShapeValues(op.getDst(), validRowsValue, validColsValue); - deriveValidShape(op.getDst(), validRows, validCols); - if (failed(resolveExecutionValidShape(op.getDst(), validRowsValue, validColsValue, - validRows, validCols, rewriter, op.getLoc()))) - return op.emitOpError("tgatherb lowering requires valid dst shape"); - - unsigned elemBytes = dataElementType.getIntOrFloatBitWidth() / 8; - int64_t elementsPerRepeat = 256 / elemBytes; - int64_t blockSizeElem = 32 / elemBytes; - int64_t staticRepeatTimes = llvm::divideCeil(staticCols, elementsPerRepeat); - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value elementsPerRepeatValue = - rewriter.create(op.getLoc(), elementsPerRepeat); - Value blockSizeElemValue = - rewriter.create(op.getLoc(), blockSizeElem); - Value dstStrideValue = - rewriter.create(op.getLoc(), dstStride); - Value offsetStrideValue = - rewriter.create(op.getLoc(), offsetStride); - - VPTOLoopScopeContract loopScope; - loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - loopScope.loweredAttr = kLoweredLoopScopeAttrName; - loopScope.loopDepth = 0; - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - - if (staticRepeatTimes > staticRows) { - auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - auto chunkLoop = rewriter.create(op.getLoc(), c0, validColsValue, - elementsPerRepeatValue); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - - Value row = rowLoop.getInductionVar(); - Value chunkBase = chunkLoop.getInductionVar(); - Value remaining = - rewriter.create(op.getLoc(), validColsValue, chunkBase); - Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, - elementsPerRepeatValue); - Value rowOffsetBase = - rewriter.create(op.getLoc(), row, offsetStrideValue); - Value rowDstBase = - rewriter.create(op.getLoc(), row, dstStrideValue); - Value offsetChunkBase = - rewriter.create(op.getLoc(), chunkBase, - blockSizeElemValue); - Value offsetLoadOffset = - rewriter.create(op.getLoc(), rowOffsetBase, offsetChunkBase); - auto offsets = rewriter.create(op.getLoc(), offsetVecType, - offsetBuffer, offsetLoadOffset, - StringAttr()); - auto gathered = rewriter.create( - op.getLoc(), dataVecType, srcBuffer, offsets.getResult(), activeLanes); - Value dstOffset = - rewriter.create(op.getLoc(), rowDstBase, chunkBase); - return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), - dstBuffer, dstOffset, activeLanes, - dataVecType.getElementCount()); - } - - auto chunkLoop = rewriter.create(op.getLoc(), c0, validColsValue, - elementsPerRepeatValue); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - - Value chunkBase = chunkLoop.getInductionVar(); - Value row = rowLoop.getInductionVar(); - Value remaining = - rewriter.create(op.getLoc(), validColsValue, chunkBase); - Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, - elementsPerRepeatValue); - Value rowOffsetBase = - rewriter.create(op.getLoc(), row, offsetStrideValue); - Value rowDstBase = - rewriter.create(op.getLoc(), row, dstStrideValue); - Value offsetChunkBase = - rewriter.create(op.getLoc(), chunkBase, - blockSizeElemValue); - Value offsetLoadOffset = - rewriter.create(op.getLoc(), rowOffsetBase, offsetChunkBase); - auto offsets = rewriter.create(op.getLoc(), offsetVecType, offsetBuffer, - offsetLoadOffset, StringAttr()); - auto gathered = rewriter.create( - op.getLoc(), dataVecType, srcBuffer, offsets.getResult(), activeLanes); - Value dstOffset = - rewriter.create(op.getLoc(), chunkBase, rowDstBase); - return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), - dstBuffer, dstOffset, activeLanes, - dataVecType.getElementCount()); -} - -LogicalResult lowerTScatter(TScatterOp op, PatternRewriter &rewriter) { - auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { - if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) - return op.emitOpError() << "tscatter lowering requires vec tile domain for " - << role; - if (deriveTileLayout(value) != "row_major") - return op.emitOpError() << "tscatter lowering requires row-major layout for " - << role; - return success(); - }; - - if (failed(requireVecRowMajor(op.getSrc(), "src")) || - failed(requireVecRowMajor(op.getIndexes(), "indexes")) || - failed(requireVecRowMajor(op.getDst(), "dst"))) - return failure(); - - Type dataElementType = getElementType(op.getSrc()); - if (dataElementType != getElementType(op.getDst())) - return op.emitOpError("tscatter lowering requires matching src/dst element type"); - - Type indexElementType = getElementType(op.getIndexes()); - auto indexIntegerType = dyn_cast(indexElementType); - if (!indexIntegerType || indexIntegerType.getWidth() != 32) - return op.emitOpError("tscatter lowering currently requires 32-bit integer indexes"); - - auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); - auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElementType); - if (!dataVecType || !indexVecType || - dataVecType.getElementCount() != indexVecType.getElementCount()) - return op.emitOpError("tscatter lowering currently requires matching data/index vector widths"); - - Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - Value indexBuffer = materializeBufferPointer(op.getIndexes(), indexElementType, - getMemorySpace(op.getIndexes()), rewriter, - op.getLoc()); - if (!srcBuffer || !dstBuffer || !indexBuffer) - return op.emitOpError("tscatter lowering requires pointer-backed tile buffers"); - - int64_t srcStride = deriveStaticRowStride(op.getSrc()); - int64_t indexStride = deriveStaticRowStride(op.getIndexes()); - if (srcStride == ShapedType::kDynamic || indexStride == ShapedType::kDynamic) - return op.emitOpError("tscatter lowering requires static src/index row stride"); - - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - deriveValidShapeValues(op.getIndexes(), validRowsValue, validColsValue); - deriveValidShape(op.getIndexes(), validRows, validCols); - if (failed(resolveExecutionValidShape(op.getIndexes(), validRowsValue, validColsValue, - validRows, validCols, rewriter, op.getLoc()))) - return op.emitOpError("tscatter lowering requires valid index shape"); - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value chunkStep = rewriter.create( - op.getLoc(), indexVecType.getElementCount()); - Value srcStrideValue = - rewriter.create(op.getLoc(), srcStride); - Value indexStrideValue = - rewriter.create(op.getLoc(), indexStride); - - VPTOLoopScopeContract loopScope; - loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - loopScope.loweredAttr = kLoweredLoopScopeAttrName; - loopScope.loopDepth = 0; - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - auto chunkLoop = - rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - - Value row = rowLoop.getInductionVar(); - Value chunkBase = chunkLoop.getInductionVar(); - Value remaining = - rewriter.create(op.getLoc(), validColsValue, chunkBase); - Value activeLanes = - buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); - - Value srcRowBase = - rewriter.create(op.getLoc(), row, srcStrideValue); - Value indexRowBase = - rewriter.create(op.getLoc(), row, indexStrideValue); - Value srcOffset = - rewriter.create(op.getLoc(), srcRowBase, chunkBase); - Value indexOffset = - rewriter.create(op.getLoc(), indexRowBase, chunkBase); - auto srcVector = rewriter.create(op.getLoc(), dataVecType, srcBuffer, - srcOffset, StringAttr()); - auto indexVector = rewriter.create(op.getLoc(), indexVecType, indexBuffer, - indexOffset, StringAttr()); - rewriter.create(op.getLoc(), srcVector.getResult(), dstBuffer, - indexVector.getResult(), activeLanes); - return success(); -} - -LogicalResult lowerTMrgSort(TMrgSortOp op, PatternRewriter &rewriter) { - auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { - if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) - return op.emitOpError() << "tmrgsort lowering requires vec tile domain for " - << role; - if (deriveTileLayout(value) != "row_major") - return op.emitOpError() << "tmrgsort lowering requires row-major layout for " - << role; - return success(); - }; - auto requireOneRow = [&](Value value, StringRef role) -> LogicalResult { - if (deriveStaticShapeDim(value, 0) != 1) - return op.emitOpError() << "tmrgsort lowering requires rows==1 for " << role; - return success(); - }; - - Location loc = op.getLoc(); - if (op.isFormat1()) { - Value src = op.getSrcs().front(); - Value dst = op.getDsts().front(); - if (failed(requireVecRowMajor(src, "src")) || failed(requireVecRowMajor(dst, "dst")) || - failed(requireOneRow(src, "src")) || failed(requireOneRow(dst, "dst"))) - return failure(); - - Type elementType = getElementType(src); - if (elementType != getElementType(dst)) - return op.emitOpError("tmrgsort format1 requires matching src/dst element type"); - if (!(elementType.isF16() || elementType.isF32())) - return op.emitOpError("tmrgsort format1 currently supports only f16/f32"); - - Value srcBuffer = materializeBufferPointer(src, elementType, getMemorySpace(src), - rewriter, loc); - Value dstBuffer = materializeBufferPointer(dst, elementType, getMemorySpace(dst), - rewriter, loc); - if (!srcBuffer || !dstBuffer) - return op.emitOpError("tmrgsort format1 requires pointer-backed tile buffers"); - - Value blockLen = op.getBlockLen(); - if (!blockLen) - return op.emitOpError("tmrgsort format1 requires blockLen"); - Value blockLenI64; - if (blockLen.getType().isIndex()) - blockLenI64 = - rewriter.create(loc, rewriter.getI64Type(), blockLen); - else - blockLenI64 = - rewriter.create(loc, rewriter.getI64Type(), blockLen); - Value blockLenIndex = - rewriter.create(loc, rewriter.getIndexType(), blockLenI64); - - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - deriveValidShapeValues(src, validRowsValue, validColsValue); - deriveValidShape(src, validRows, validCols); - Value validColsI64 = materializeI64Value(validColsValue, validCols, rewriter, loc); - - int64_t elemBytes = getElementByteSize(elementType); - Value numStructures = rewriter.create( - loc, rewriter.getI64Type(), - rewriter.create( - loc, blockLenI64, rewriter.create(loc, elemBytes, 64)), - rewriter.create(loc, 3, 64)); - Value count = buildPackedCountI64(rewriter, loc, - {numStructures, numStructures, numStructures, numStructures}); - Value repeatTimes = rewriter.create( - loc, validColsI64, - rewriter.create( - loc, blockLenI64, rewriter.create(loc, 4, 64))); - Value config = rewriter.create( - loc, repeatTimes, rewriter.create(loc, 0b1111 << 8, 64)); - - Value src0 = srcBuffer; - Value src1 = offsetBufferPointer(srcBuffer, elementType, blockLenIndex, rewriter, loc); - Value src2 = offsetBufferPointer( - srcBuffer, elementType, - rewriter.create(loc, blockLenIndex, - rewriter.create(loc, 2)), - rewriter, loc); - Value src3 = offsetBufferPointer( - srcBuffer, elementType, - rewriter.create(loc, blockLenIndex, - rewriter.create(loc, 3)), - rewriter, loc); - rewriter.create(loc, dstBuffer, src0, src1, src2, src3, count, - config); - return success(); - } - - if (!op.isFormat2()) - return op.emitOpError("unsupported tmrgsort format for current vpto backend"); - if (op.getExhausted()) - return op.emitOpError("tmrgsort format2 exhausted=true is not yet supported"); - if (op.getSrcs().size() != 4 || op.getDsts().size() != 2) - return op.emitOpError("tmrgsort format2 currently requires exactly 4 srcs and 2 dsts"); - - Type elementType = getElementType(op.getSrcs().front()); - if (!(elementType.isF16() || elementType.isF32())) - return op.emitOpError("tmrgsort format2 currently supports only f16/f32"); - - SmallVector srcBuffers; - SmallVector srcCounts; - srcBuffers.reserve(4); - srcCounts.reserve(4); - for (Value src : op.getSrcs()) { - if (failed(requireVecRowMajor(src, "src")) || failed(requireOneRow(src, "src"))) - return failure(); - if (getElementType(src) != elementType) - return op.emitOpError("tmrgsort format2 requires matching source element types"); - - Value srcBuffer = - materializeBufferPointer(src, elementType, getMemorySpace(src), rewriter, loc); - if (!srcBuffer) - return op.emitOpError("tmrgsort format2 requires pointer-backed source tiles"); - srcBuffers.push_back(srcBuffer); - - Value rowsValue; - Value colsValue; - int64_t rows = ShapedType::kDynamic; - int64_t cols = ShapedType::kDynamic; - deriveValidShapeValues(src, rowsValue, colsValue); - deriveValidShape(src, rows, cols); - Value colsI64 = materializeI64Value(colsValue, cols, rewriter, loc); - srcCounts.push_back(rewriter.create( - loc, rewriter.getI64Type(), colsI64, - rewriter.create(loc, elementType.isF32() ? 1 : 2, 64))); - } - - Value dst = op.getDsts()[0]; - Value tmp = op.getDsts()[1]; - if (failed(requireVecRowMajor(dst, "dst")) || failed(requireVecRowMajor(tmp, "tmp")) || - failed(requireOneRow(dst, "dst")) || failed(requireOneRow(tmp, "tmp"))) - return failure(); - if (getElementType(dst) != elementType || getElementType(tmp) != elementType) - return op.emitOpError("tmrgsort format2 requires matching dst/tmp element types"); - - Value dstBuffer = - materializeBufferPointer(dst, elementType, getMemorySpace(dst), rewriter, loc); - Value tmpBuffer = - materializeBufferPointer(tmp, elementType, getMemorySpace(tmp), rewriter, loc); - if (!dstBuffer || !tmpBuffer) - return op.emitOpError("tmrgsort format2 requires pointer-backed dst/tmp tiles"); - - Value count = buildPackedCountI64(rewriter, loc, srcCounts); - Value config = - rewriter.create(loc, 1 | (0b1111 << 8), 64); - rewriter.create(loc, tmpBuffer, srcBuffers[0], srcBuffers[1], - srcBuffers[2], srcBuffers[3], count, config); - - Value dstRowsValue; - Value dstColsValue; - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - deriveValidShapeValues(dst, dstRowsValue, dstColsValue); - deriveValidShape(dst, dstRows, dstCols); - Value dstColsI64 = materializeI64Value(dstColsValue, dstCols, rewriter, loc); - int64_t elemBytes = getElementByteSize(elementType); - Value lenBurst = buildCeilDivPositiveI64( - rewriter, loc, - rewriter.create( - loc, dstColsI64, rewriter.create(loc, elemBytes, 64)), - 32); - Value zeroI64 = rewriter.create(loc, 0, 64); - Value oneI64 = rewriter.create(loc, 1, 64); - rewriter.create(loc, tmpBuffer, dstBuffer, zeroI64, oneI64, - lenBurst, zeroI64, zeroI64); - return success(); -} - -LogicalResult lowerTSort32(TSort32Op op, PatternRewriter &rewriter) { - auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { - if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) - return op.emitOpError() << "tsort32 lowering requires vec tile domain for " - << role; - if (deriveTileLayout(value) != "row_major") - return op.emitOpError() << "tsort32 lowering requires row-major layout for " - << role; - return success(); - }; - - if (failed(requireVecRowMajor(op.getSrc(), "src")) || - failed(requireVecRowMajor(op.getDst(), "dst")) || - failed(requireVecRowMajor(op.getIdx(), "idx"))) - return failure(); - - Type dataType = getElementType(op.getSrc()); - if (dataType != getElementType(op.getDst())) - return op.emitOpError("tsort32 lowering requires matching src/dst element type"); - if (!(dataType.isF16() || dataType.isF32())) - return op.emitOpError("tsort32 lowering currently supports only f16/f32 data"); - auto idxType = dyn_cast(getElementType(op.getIdx())); - if (!idxType || idxType.getWidth() != 32 || !idxType.isUnsigned()) - return op.emitOpError("tsort32 lowering currently requires u32 index tile"); - - Value srcBuffer = - materializeBufferPointer(op.getSrc(), dataType, getMemorySpace(op.getSrc()), - rewriter, op.getLoc()); - Value dstBuffer = - materializeBufferPointer(op.getDst(), dataType, getMemorySpace(op.getDst()), - rewriter, op.getLoc()); - Value idxBuffer = materializeBufferPointer(op.getIdx(), getElementType(op.getIdx()), - getMemorySpace(op.getIdx()), rewriter, - op.getLoc()); - if (!srcBuffer || !dstBuffer || !idxBuffer) - return op.emitOpError("tsort32 lowering requires pointer-backed tiles"); - - int64_t srcStride = deriveStaticRowStride(op.getSrc()); - int64_t dstStride = deriveStaticRowStride(op.getDst()); - int64_t idxStride = deriveStaticRowStride(op.getIdx()); - if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || - idxStride == ShapedType::kDynamic) - return op.emitOpError("tsort32 lowering requires static row stride"); - - Value validRowsValue; - Value validColsValue; - int64_t validRows = ShapedType::kDynamic; - int64_t validCols = ShapedType::kDynamic; - deriveValidShapeValues(op.getSrc(), validRowsValue, validColsValue); - deriveValidShape(op.getSrc(), validRows, validCols); - if (validCols == ShapedType::kDynamic || (validCols % 32) != 0) - return op.emitOpError("tsort32 lowering currently requires static validCol divisible by 32"); - - int64_t idxValidRows = ShapedType::kDynamic; - int64_t idxValidCols = ShapedType::kDynamic; - deriveValidShape(op.getIdx(), idxValidRows, idxValidCols); - bool idxBroadcast = idxValidRows == 1; - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value repeatNumPerRow = - rewriter.create(op.getLoc(), validCols / 32); - Value srcStrideValue = rewriter.create(op.getLoc(), srcStride); - Value dstStrideValue = rewriter.create(op.getLoc(), dstStride); - Value idxStrideValue = - rewriter.create(op.getLoc(), idxBroadcast ? 0 : idxStride); - - auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value srcOffset = rewriter.create(op.getLoc(), row, srcStrideValue); - Value dstOffset = rewriter.create(op.getLoc(), row, dstStrideValue); - Value idxOffset = rewriter.create(op.getLoc(), row, idxStrideValue); - Value rowSrcPtr = - offsetBufferPointer(srcBuffer, dataType, srcOffset, rewriter, op.getLoc()); - Value rowDstPtr = - offsetBufferPointer(dstBuffer, dataType, dstOffset, rewriter, op.getLoc()); - Value rowIdxPtr = offsetBufferPointer(idxBuffer, getElementType(op.getIdx()), idxOffset, - rewriter, op.getLoc()); - rewriter.create(op.getLoc(), rowDstPtr, rowSrcPtr, rowIdxPtr, - repeatNumPerRow); - return success(); -} - -LogicalResult lowerTMulS(TMulSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = buildUnaryContract("muls", op.getSrc0()); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 16 || intType.getWidth() == 32; - return false; - }, - "f16, f32, and 16/32-bit integer element types"))) - return failure(); - if (!isCompatibleScalarForSemanticType(contract.elementType, - op.getScalar().getType())) - return op.emitOpError("tmuls lowering requires scalar type to match source element type"); - return buildScalarUnaryVecScope("muls", contract, strategy, op.getSrc0(), op.getScalar(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTSelS(TSelSOp op, PatternRewriter &rewriter) { - VPTOBinaryContract contract = buildBinaryContract("sels", op.getSrc()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - if (failed(checkGenericBinaryContract( - op, contract, op.getTmp(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - - auto selectModeType = dyn_cast(op.getScalar().getType()); - if (!selectModeType) - return op.emitOpError("tsels lowering requires integer selectMode"); - - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return op.emitOpError("tsels lowering requires a supported VPTO vector element type"); - - Value src0Buffer = materializeBufferPointer(op.getSrc(), contract.elementType, - getMemorySpace(op.getSrc()), rewriter, - op.getLoc()); - Value src1Buffer = materializeBufferPointer(op.getTmp(), contract.elementType, - getMemorySpace(op.getTmp()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!src0Buffer || !src1Buffer || !dstBuffer) - return op.emitOpError("tsels lowering requires pointer-backed tile buffers"); - - Value validRowsValue = materializeIndexValue(contract.validRowsValue, - contract.validRows, rewriter, op.getLoc()); - Value validColsValue = materializeIndexValue(contract.validColsValue, - contract.validCols, rewriter, op.getLoc()); - if (!validRowsValue || !validColsValue) - return op.emitOpError("tsels lowering requires valid rows and cols"); - - int64_t vectorWidth = vecType.getElementCount(); - if (contract.validRows != ShapedType::kDynamic && - contract.validCols != ShapedType::kDynamic) { - int64_t totalElements = contract.validRows * contract.validCols; - if (totalElements % vectorWidth != 0) - return op.emitOpError( - "tsels lowering currently requires total valid elements divisible by vector width"); - } - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value totalElementsValue = - rewriter.create(op.getLoc(), validRowsValue, validColsValue); - Value vectorStepValue = - rewriter.create(op.getLoc(), vectorWidth); - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - - Value selectOne = rewriter.create( - op.getLoc(), IntegerAttr::get(selectModeType, 1)); - Value isAll = rewriter.create(op.getLoc(), arith::CmpIPredicate::eq, - op.getScalar(), selectOne); - auto ifOp = rewriter.create( - op.getLoc(), TypeRange{getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType)}, isAll, - /*withElseRegion=*/true); - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - Value allMask = rewriter - .create(op.getLoc(), - getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType), - rewriter.getStringAttr("PAT_ALL")) - .getResult(); - rewriter.create(op.getLoc(), allMask); - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value allfMask = rewriter - .create(op.getLoc(), - getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType), - rewriter.getStringAttr("PAT_ALLF")) - .getResult(); - rewriter.create(op.getLoc(), allfMask); - - rewriter.setInsertionPointAfter(ifOp); - auto chunkLoop = - rewriter.create(op.getLoc(), c0, totalElementsValue, vectorStepValue); - rewriter.setInsertionPointToStart(chunkLoop.getBody()); - Value offset = chunkLoop.getInductionVar(); - Value mask = ifOp.getResult(0); - auto src0Vec = rewriter.create(op.getLoc(), vecType, src0Buffer, - offset, StringAttr()); - auto src1Vec = rewriter.create(op.getLoc(), vecType, src1Buffer, - offset, StringAttr()); - Value selected = rewriter - .create(op.getLoc(), vecType, src0Vec.getResult(), - src1Vec.getResult(), mask) - .getResult(); - rewriter.create( - op.getLoc(), selected, dstBuffer, offset, StringAttr(), - buildAllPredicateMask(rewriter, op.getLoc(), contract.elementType)); - return success(); -} - -LogicalResult lowerTSel(TSelOp op, PatternRewriter &rewriter) { - VPTOBinaryContract contract = buildBinaryContract("tsel", op.getSrc0()); - deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); - deriveValidShape(op.getDst(), contract.validRows, contract.validCols); - - int64_t src1Rows = ShapedType::kDynamic; - int64_t src1Cols = ShapedType::kDynamic; - int64_t dstRows = ShapedType::kDynamic; - int64_t dstCols = ShapedType::kDynamic; - int64_t maskRows = ShapedType::kDynamic; - int64_t maskCols = ShapedType::kDynamic; - deriveValidShape(op.getSrc1(), src1Rows, src1Cols); - deriveValidShape(op.getDst(), dstRows, dstCols); - deriveValidShape(op.getMask(), maskRows, maskCols); - - if (contract.tileDomain != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getMask())) != VPTOTileDomain::Vec) - return op.emitOpError("tsel lowering requires tile domain vec"); - if (contract.tileLayout != "row_major" || deriveTileLayout(op.getSrc1()) != "row_major" || - deriveTileLayout(op.getDst()) != "row_major" || deriveTileLayout(op.getMask()) != "row_major") - return op.emitOpError("tsel lowering requires row-major tile layout"); - if (contract.validRows == ShapedType::kDynamic || - contract.validCols == ShapedType::kDynamic) - return op.emitOpError("tsel lowering requires static valid shape"); - if (contract.validRows != src1Rows || contract.validCols != src1Cols || - contract.validRows != dstRows || contract.validCols != dstCols || - contract.validRows != maskRows || contract.validCols != maskCols) - return op.emitOpError("tsel lowering requires matching source, mask, and destination valid region"); - if (!contract.elementType || !contract.elementType.isF32()) - return op.emitOpError("tsel lowering currently supports only f32 data tiles"); - auto maskElemType = dyn_cast_or_null(getElementType(op.getMask())); - if (!maskElemType || maskElemType.getWidth() != 8) - return op.emitOpError("tsel lowering currently requires i8 mask tiles"); - - auto [tileRows, tileCols] = getStaticTileRowsCols(op.getDst()); - auto [maskTileRows, maskTileCols] = getStaticTileRowsCols(op.getMask()); - if (tileRows == ShapedType::kDynamic || tileCols == ShapedType::kDynamic || - maskTileRows == ShapedType::kDynamic || maskTileCols == ShapedType::kDynamic) - return op.emitOpError("tsel lowering requires static tile rows and cols"); - Value maskBuffer = materializeBufferPointer(op.getMask(), getElementType(op.getMask()), - getMemorySpace(op.getMask()), rewriter, - op.getLoc()); - Value src0Buffer = materializeBufferPointer(op.getSrc0(), contract.elementType, - getMemorySpace(op.getSrc0()), rewriter, - op.getLoc()); - Value src1Buffer = materializeBufferPointer(op.getSrc1(), contract.elementType, - getMemorySpace(op.getSrc1()), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!maskBuffer || !src0Buffer || !src1Buffer || !dstBuffer) - return op.emitOpError("tsel lowering requires pointer-backed tile buffers"); - - auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); - if (!vecType) - return op.emitOpError("tsel lowering requires a supported VPTO vector element type"); - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value validRowsValue = materializeIndexValue(contract.validRowsValue, contract.validRows, - rewriter, op.getLoc()); - if (!validRowsValue) - return op.emitOpError("tsel lowering requires valid rows"); - Value rowStride = rewriter.create(op.getLoc(), tileCols); - Value maskStride = rewriter.create(op.getLoc(), maskTileCols); - constexpr int64_t elementsPerRepeat = 64; - constexpr int64_t unrollConstant = 2; - int64_t repeatTimes = (contract.validCols + elementsPerRepeat - 1) / elementsPerRepeat; - int64_t pairedRepeatTimes = repeatTimes / unrollConstant; - int64_t remainRepeat = repeatTimes % unrollConstant; - int64_t repeatIdxBase = pairedRepeatTimes * unrollConstant; - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto splitMaskType = getVPTOMaskType(rewriter.getContext(), "b16"); - Value fullMask = rewriter - .create(op.getLoc(), splitMaskType, - rewriter.getStringAttr("PAT_ALL")) - .getResult(); - auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value rowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), rowStride); - Value maskBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), maskStride); - - for (int64_t j = 0; j < pairedRepeatTimes; ++j) { - int64_t repeatIdx = j * unrollConstant; - int64_t colOffset0 = repeatIdx * elementsPerRepeat; - int64_t colOffset1 = colOffset0 + elementsPerRepeat; - int64_t maskOffsetImm = repeatIdx * 8; - int64_t count0 = std::min(elementsPerRepeat, contract.validCols - colOffset0); - int64_t count1 = std::min(elementsPerRepeat, contract.validCols - colOffset1); - - Value maskOffset = rewriter.create( - op.getLoc(), maskBase, - rewriter.create(op.getLoc(), maskOffsetImm)); - Value rawMask = rewriter - .create(op.getLoc(), - splitMaskType, - maskBuffer, maskOffset, - rewriter.getStringAttr("US")) - .getResult(); - auto splitMask = rewriter.create( - op.getLoc(), splitMaskType, splitMaskType, rawMask, fullMask); - - Value dataOffset0 = rewriter.create( - op.getLoc(), rowBase, - rewriter.create(op.getLoc(), colOffset0)); - auto lhs0 = rewriter.create(op.getLoc(), vecType, src0Buffer, - dataOffset0, StringAttr()); - auto rhs0 = rewriter.create(op.getLoc(), vecType, src1Buffer, - dataOffset0, StringAttr()); - Value selected0 = rewriter - .create(op.getLoc(), vecType, lhs0.getResult(), - rhs0.getResult(), splitMask.getLow()) - .getResult(); - Value storeMask0 = buildPredicateMaskForLaneCount( - rewriter, op.getLoc(), contract.elementType, - rewriter.create(op.getLoc(), count0)); - rewriter.create(op.getLoc(), selected0, dstBuffer, dataOffset0, - StringAttr(), storeMask0); - - Value dataOffset1 = rewriter.create( - op.getLoc(), rowBase, - rewriter.create(op.getLoc(), colOffset1)); - auto lhs1 = rewriter.create(op.getLoc(), vecType, src0Buffer, - dataOffset1, StringAttr()); - auto rhs1 = rewriter.create(op.getLoc(), vecType, src1Buffer, - dataOffset1, StringAttr()); - Value selected1 = rewriter - .create(op.getLoc(), vecType, lhs1.getResult(), - rhs1.getResult(), splitMask.getHigh()) - .getResult(); - Value storeMask1 = buildPredicateMaskForLaneCount( - rewriter, op.getLoc(), contract.elementType, - rewriter.create(op.getLoc(), count1)); - rewriter.create(op.getLoc(), selected1, dstBuffer, dataOffset1, - StringAttr(), storeMask1); - } - - for (int64_t j = 0; j < remainRepeat; ++j) { - int64_t repeatIdx = repeatIdxBase + j; - int64_t colOffset = repeatIdx * elementsPerRepeat; - int64_t count = std::max(0, contract.validCols - colOffset); - int64_t maskOffsetImm = repeatIdx * 8; - - Value maskOffset = rewriter.create( - op.getLoc(), maskBase, - rewriter.create(op.getLoc(), maskOffsetImm)); - Value rawMask = rewriter - .create(op.getLoc(), - splitMaskType, - maskBuffer, maskOffset, - rewriter.getStringAttr("US")) - .getResult(); - Value unpackedMask = rewriter - .create( - op.getLoc(), splitMaskType, - rawMask, rewriter.getStringAttr("LOWER")) - .getResult(); - Value dataOffset = rewriter.create( - op.getLoc(), rowBase, - rewriter.create(op.getLoc(), colOffset)); - auto lhs = rewriter.create(op.getLoc(), vecType, src0Buffer, - dataOffset, StringAttr()); - auto rhs = rewriter.create(op.getLoc(), vecType, src1Buffer, - dataOffset, StringAttr()); - Value selected = rewriter - .create(op.getLoc(), vecType, lhs.getResult(), - rhs.getResult(), unpackedMask) - .getResult(); - Value storeMask = buildPredicateMaskForLaneCount( - rewriter, op.getLoc(), contract.elementType, - rewriter.create(op.getLoc(), count)); - rewriter.create(op.getLoc(), selected, dstBuffer, dataOffset, - StringAttr(), storeMask); - } - return success(); -} - -LogicalResult lowerTDivS(TDivSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - Value tileOperand; - Value scalarOperand; - bool scalarFirst = false; - if (isVPTOShapedLikeValue(op.getSrc()) && !isVPTOShapedLikeValue(op.getScalar())) { - tileOperand = op.getSrc(); - scalarOperand = op.getScalar(); - } else if (!isVPTOShapedLikeValue(op.getSrc()) && - isVPTOShapedLikeValue(op.getScalar())) { - tileOperand = op.getScalar(); - scalarOperand = op.getSrc(); - scalarFirst = true; - } else { - return op.emitOpError( - "divs lowering requires exactly one shaped operand and one scalar operand"); - } - - VPTOUnaryContract contract = buildUnaryContract("divs", tileOperand); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF16() || type.isF32(); }, - "f16 and f32 element types"))) - return failure(); - if (!isCompatibleScalarForSemanticType(contract.elementType, - scalarOperand.getType())) - return op.emitOpError( - "divs lowering requires scalar type to match source element type"); - return buildScalarDivVecScope(contract, strategy, tileOperand, scalarOperand, op.getDst(), - scalarFirst, rewriter, op.getLoc()); -} - -LogicalResult lowerTAddS(TAddSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = buildUnaryContract("adds", op.getSrc()); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 16 || intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 16/32-bit integer element types"))) - return failure(); - if (!isCompatibleScalarForSemanticType(contract.elementType, - op.getScalar().getType())) - return op.emitOpError("tadds lowering requires scalar type to match source element type"); - return buildScalarUnaryVecScope("adds", contract, strategy, op.getSrc(), op.getScalar(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTAddC(TAddCOp op, PatternRewriter &rewriter) { - VPTOBinaryContract first = buildBinaryContract("add", op.getSrc0()); - deriveValidShapeValues(op.getDst(), first.validRowsValue, first.validColsValue); - deriveValidShape(op.getDst(), first.validRows, first.validCols); - if (failed(checkGenericBinaryContract( - op, first, op.getSrc1(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - if (failed(buildBinaryVecScope("add", first, VPTOLoweringStrategy::PostUpdate, - op.getSrc0(), op.getSrc1(), op.getDst(), - rewriter, op.getLoc()))) - return failure(); - - VPTOBinaryContract second = buildBinaryContract("add", op.getDst()); - deriveValidShapeValues(op.getDst(), second.validRowsValue, second.validColsValue); - deriveValidShape(op.getDst(), second.validRows, second.validCols); - if (failed(checkGenericBinaryContract( - op, second, op.getSrc2(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("add", second, VPTOLoweringStrategy::PostUpdate, - op.getDst(), op.getSrc2(), op.getDst(), rewriter, - op.getLoc()); -} - -LogicalResult lowerTAddSC(TAddSCOp op, PatternRewriter &rewriter) { - return emitUnresolvedInstalledA5BaselineError(op, "taddsc"); -} - -LogicalResult lowerTSubC(TSubCOp op, PatternRewriter &rewriter) { - VPTOBinaryContract first = buildBinaryContract("sub", op.getSrc0()); - deriveValidShapeValues(op.getDst(), first.validRowsValue, first.validColsValue); - deriveValidShape(op.getDst(), first.validRows, first.validCols); - if (failed(checkGenericBinaryContract( - op, first, op.getSrc1(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - if (failed(buildBinaryVecScope("sub", first, VPTOLoweringStrategy::PostUpdate, - op.getSrc0(), op.getSrc1(), op.getDst(), - rewriter, op.getLoc()))) - return failure(); - - VPTOBinaryContract second = buildBinaryContract("add", op.getDst()); - deriveValidShapeValues(op.getDst(), second.validRowsValue, second.validColsValue); - deriveValidShape(op.getDst(), second.validRows, second.validCols); - if (failed(checkGenericBinaryContract( - op, second, op.getSrc2(), op.getDst(), - [](Type type) { - if (type.isF16() || type.isF32() || type.isBF16()) - return true; - if (auto intType = dyn_cast(type)) - return intType.getWidth() == 8 || intType.getWidth() == 16 || - intType.getWidth() == 32; - return false; - }, - "f16, f32, bf16, and 8/16/32-bit integer element types"))) - return failure(); - return buildBinaryVecScope("add", second, VPTOLoweringStrategy::PostUpdate, - op.getDst(), op.getSrc2(), op.getDst(), rewriter, - op.getLoc()); -} - -LogicalResult lowerTSubS(TSubSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - (void)rewriter; - (void)strategy; - return emitUnresolvedInstalledA5BaselineError(op, "tsubs"); -} - -LogicalResult lowerTSubSC(TSubSCOp op, PatternRewriter &rewriter) { - return emitUnresolvedInstalledA5BaselineError(op, "tsubsc"); -} - -LogicalResult lowerTMaxS(TMaxSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = buildUnaryContract("maxs", op.getSrc()); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF32(); }, "f32 element type"))) - return failure(); - if (!isCompatibleScalarForSemanticType(contract.elementType, - op.getScalar().getType())) - return op.emitOpError("tmaxs lowering requires scalar type to match source element type"); - return buildScalarUnaryVecScope("maxs", contract, strategy, op.getSrc(), op.getScalar(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTMinS(TMinSOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOUnaryContract contract = buildUnaryContract("mins", op.getSrc()); - if (failed(checkGenericUnaryContract( - op, contract, op.getDst(), - [](Type type) { return type.isF32(); }, "f32 element type"))) - return failure(); - if (!isCompatibleScalarForSemanticType(contract.elementType, - op.getScalar().getType())) - return op.emitOpError("tmins lowering requires scalar type to match source element type"); - return buildScalarUnaryVecScope("mins", contract, strategy, op.getSrc(), op.getScalar(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTRowMax(TRowMaxOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTORowReduceContract contract = extractTRowMaxContract(op); - if (failed(checkRowReduceContract(op, contract, op.getDst()))) - return failure(); - return buildRowReduceVecScope("rowmax", contract, strategy, op.getSrc(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTRowMin(TRowMinOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTORowReduceContract contract = extractTRowMinContract(op); - if (failed(checkRowReduceContract(op, contract, op.getDst()))) - return failure(); - return buildRowReduceVecScope("rowmin", contract, strategy, op.getSrc(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTRowSum(TRowSumOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTORowReduceContract contract = extractTRowSumContract(op); - if (failed(checkRowReduceContract(op, contract, op.getDst()))) - return failure(); - return buildRowReduceVecScope("rowsum", contract, strategy, op.getSrc(), op.getDst(), - rewriter, op.getLoc()); -} - -LogicalResult lowerTColMax(TColMaxOp op, PatternRewriter &rewriter) { - VPTOColReduceContract contract = extractTColMaxContract(op); - if (failed(checkColReduceContract(op, contract, op.getDst()))) - return failure(); - return buildColReduceVecScope("colmax", contract, op.getSrc(), op.getDst(), - Value(), rewriter, op.getLoc()); -} - -LogicalResult lowerTColMin(TColMinOp op, PatternRewriter &rewriter) { - VPTOColReduceContract contract = extractTColMinContract(op); - if (failed(checkColReduceContract(op, contract, op.getDst()))) - return failure(); - return buildColReduceVecScope("colmin", contract, op.getSrc(), op.getDst(), - Value(), rewriter, op.getLoc()); -} - -LogicalResult lowerTColSum(TColSumOp op, PatternRewriter &rewriter) { - VPTOColReduceContract contract = extractTColSumContract(op); - if (failed(checkColReduceContract(op, contract, op.getDst()))) - return failure(); - return buildColReduceVecScope("colsum", contract, op.getSrc(), op.getDst(), - op.getTmp(), rewriter, op.getLoc()); -} - -LogicalResult lowerTRowExpand(TRowExpandOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - VPTOExpandContract contract = extractTRowExpandContract(op); - if (failed(checkExpandContract(op, contract))) - return failure(); - if (contract.srcValidRows != contract.dstValidRows) - return op.emitOpError() - << "rowexpand lowering requires source and destination valid rows to match"; - return buildRowExpandVecScope(contract, strategy, op.getSrc(), op.getDst(), rewriter, - op.getLoc()); -} - -LogicalResult lowerTColExpand(TColExpandOp op, PatternRewriter &rewriter) { - VPTOExpandContract contract = extractTColExpandContract(op); - if (failed(checkExpandContract(op, contract))) - return failure(); - if (contract.srcValidCols != contract.dstValidCols) - return op.emitOpError() - << "colexpand lowering requires source and destination valid cols to match"; - return buildColExpandVecScope(contract, op.getSrc(), op.getDst(), rewriter, - op.getLoc()); -} - -template -LogicalResult lowerTRowExpandBinaryLike(OpTy op, PatternRewriter &rewriter, - StringRef family, - VPTOLoweringStrategy strategy) { - Type elementType = getElementType(op.getDst()); - if (!elementType || (!elementType.isF16() && !elementType.isF32())) - return op.emitOpError() << family - << " lowering currently supports only f16 and f32 element types"; - - if (deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getSrc0())) != VPTOTileDomain::Vec || - deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec) - return op.emitOpError() << family << " lowering requires vec tile domain"; - if (deriveTileLayout(op.getDst()) != "row_major") - return op.emitOpError() << family << " lowering requires row-major dst layout"; - - int64_t dstValidRows = ShapedType::kDynamic; - int64_t dstValidCols = ShapedType::kDynamic; - int64_t src0ValidRows = ShapedType::kDynamic; - int64_t src0ValidCols = ShapedType::kDynamic; - int64_t src1ValidRows = ShapedType::kDynamic; - int64_t src1ValidCols = ShapedType::kDynamic; - deriveValidShape(op.getDst(), dstValidRows, dstValidCols); - deriveValidShape(op.getSrc0(), src0ValidRows, src0ValidCols); - deriveValidShape(op.getSrc1(), src1ValidRows, src1ValidCols); - if (dstValidRows == ShapedType::kDynamic || dstValidCols == ShapedType::kDynamic || - src0ValidRows == ShapedType::kDynamic || src0ValidCols == ShapedType::kDynamic || - src1ValidRows == ShapedType::kDynamic || src1ValidCols == ShapedType::kDynamic) - return op.emitOpError() << family - << " lowering currently requires static valid shapes"; - - bool src0EqDst = op.getSrc0().getType() == op.getDst().getType(); - bool src1EqDst = op.getSrc1().getType() == op.getDst().getType(); - if (!src0EqDst && !src1EqDst) - return op.emitOpError() << family - << " lowering requires src0 or src1 to match dst tile type"; - - Value baseSrc = src0EqDst ? op.getSrc0() : op.getSrc1(); - Value expandSrc = src0EqDst ? op.getSrc1() : op.getSrc0(); - StringRef expandLayout = deriveTileLayout(expandSrc); - int64_t expandValidRows = src0EqDst ? src1ValidRows : src0ValidRows; - int64_t expandValidCols = src0EqDst ? src1ValidCols : src0ValidCols; - if (expandValidRows != dstValidRows) - return op.emitOpError() << family - << " lowering requires expand operand valid rows to match dst"; - - int64_t elemBytes = getElementByteSize(elementType); - bool expandIsRowMajor = expandLayout == "row_major" && expandValidCols == 32 / elemBytes; - bool expandIsColMajor = expandLayout == "col_major" && expandValidCols == 1; - if (!expandIsRowMajor && !expandIsColMajor) - return op.emitOpError() << family - << " lowering requires PTO A5-compatible expand operand shape"; - - auto vecType = getVPTOVRegType(rewriter.getContext(), elementType); - if (!vecType) - return op.emitOpError() << family - << " lowering requires a legal VPTO vector type"; - - Value baseBuffer = materializeBufferPointer(baseSrc, elementType, - getMemorySpace(baseSrc), rewriter, - op.getLoc()); - Value expandBuffer = materializeBufferPointer(expandSrc, elementType, - getMemorySpace(expandSrc), rewriter, - op.getLoc()); - Value dstBuffer = materializeBufferPointer(op.getDst(), elementType, - getMemorySpace(op.getDst()), rewriter, - op.getLoc()); - if (!baseBuffer || !expandBuffer || !dstBuffer) - return op.emitOpError() << family - << " lowering requires pointer-backed tile buffers"; - - int64_t dstRowStride = deriveStaticRowStride(op.getDst()); - int64_t baseRowStride = deriveStaticRowStride(baseSrc); - int64_t expandRowStride = deriveStaticRowStride(expandSrc); - if (dstRowStride == ShapedType::kDynamic || baseRowStride == ShapedType::kDynamic || - expandRowStride == ShapedType::kDynamic) - return op.emitOpError() << family << " lowering requires static row strides"; - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value rowsUpper = rewriter.create(op.getLoc(), dstValidRows); - Value colsUpper = rewriter.create(op.getLoc(), dstValidCols); - Value vectorStep = - rewriter.create(op.getLoc(), vecType.getElementCount()); - Value baseStrideValue = - rewriter.create(op.getLoc(), baseRowStride); - Value expandStrideValue = - rewriter.create(op.getLoc(), expandRowStride); - Value dstStrideValue = - rewriter.create(op.getLoc(), dstRowStride); - Value blockSizeValue = - rewriter.create(op.getLoc(), 32 / elemBytes); - - VPTOLoopScopeContract loopScope; - loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; - loopScope.loweredAttr = kLoweredLoopScopeAttrName; - loopScope.loopDepth = 0; - - auto buildRowExpandValue = [&](Value baseVec, Value expandedVec, - Value predicate) -> FailureOr { - if (family == "trowexpandmul") - return rewriter.create(op.getLoc(), vecType, baseVec, - expandedVec, predicate) - .getResult(); - if (family == "trowexpanddiv") { - if (src0EqDst) - return rewriter.create(op.getLoc(), vecType, baseVec, - expandedVec, predicate) - .getResult(); - return rewriter.create(op.getLoc(), vecType, expandedVec, - baseVec, predicate) - .getResult(); - } - if (family == "trowexpandsub") { - if (src0EqDst) - return rewriter.create(op.getLoc(), vecType, baseVec, - expandedVec, predicate) - .getResult(); - return rewriter.create(op.getLoc(), vecType, expandedVec, - baseVec, predicate) - .getResult(); - } - return failure(); - }; - - FailureOr vecScope = - createLoopScopeRegion(op.getLoc(), loopScope, rewriter); - if (failed(vecScope)) - return op.emitOpError("failed to create AIV vector scope region"); - - OpBuilder::InsertionGuard aivGuard(rewriter); - rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); - auto rowLoop = rewriter.create(op.getLoc(), c0, rowsUpper, c1); - rewriter.setInsertionPointToStart(rowLoop.getBody()); - Value row = rowLoop.getInductionVar(); - Value baseRowOffset = rewriter.create(op.getLoc(), row, baseStrideValue); - Value dstRowOffset = rewriter.create(op.getLoc(), row, dstStrideValue); - Value expandRowOffset = expandIsRowMajor - ? rewriter.create(op.getLoc(), row, blockSizeValue) - : rewriter.create(op.getLoc(), row, expandStrideValue); - - Value expandVec; - if (expandIsColMajor) { - Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), elementType); - Value expandScalar = - rewriter.create(op.getLoc(), vecType, expandBuffer, - expandRowOffset); - expandVec = rewriter - .create(op.getLoc(), vecType, expandScalar, fullMask, - StringAttr()) - .getResult(); - } else { - expandVec = rewriter - .create(op.getLoc(), vecType, expandBuffer, expandRowOffset, - rewriter.getStringAttr("BLK")) - .getResult(); - } - - auto colLoop = rewriter.create(op.getLoc(), c0, colsUpper, vectorStep); - rewriter.setInsertionPointToStart(colLoop.getBody()); - Value col = colLoop.getInductionVar(); - Value remainingCols = rewriter.create(op.getLoc(), colsUpper, col); - Value needsTailMask = rewriter.create( - op.getLoc(), arith::CmpIPredicate::slt, remainingCols, vectorStep); - Value activeLanes = rewriter.create(op.getLoc(), needsTailMask, - remainingCols, vectorStep); - Value baseOffset = rewriter.create(op.getLoc(), baseRowOffset, col); - Value dstOffset = rewriter.create(op.getLoc(), dstRowOffset, col); - Value storeMask = - buildPredicateMaskForLaneCount(rewriter, op.getLoc(), elementType, activeLanes); - Value baseVec = - rewriter.create(op.getLoc(), vecType, baseBuffer, baseOffset, StringAttr()); - FailureOr computed = - buildRowExpandValue(baseVec, expandVec, storeMask); - if (failed(computed)) - return op.emitOpError() << "unsupported rowexpand binary family"; - rewriter.create(op.getLoc(), *computed, dstBuffer, dstOffset, - StringAttr(), storeMask); - rewriter.create(op.getLoc()); - return success(); -} - -LogicalResult lowerTRowExpandMul(TRowExpandMulOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowExpandBinaryLike(op, rewriter, "trowexpandmul", strategy); -} - -LogicalResult lowerTRowExpandDiv(TRowExpandDivOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowExpandBinaryLike(op, rewriter, "trowexpanddiv", strategy); -} - -LogicalResult lowerTRowExpandSub(TRowExpandSubOp op, PatternRewriter &rewriter, - VPTOLoweringStrategy strategy) { - return lowerTRowExpandBinaryLike(op, rewriter, "trowexpandsub", strategy); -} - -LogicalResult lowerTPartAdd(TPartAddOp op, PatternRewriter &rewriter) { - VPTOPartContract contract = extractTPartAddContract(op); - if (failed(checkPartContract(op, contract))) - return failure(); - return buildPartVecScope("partadd", contract, op.getSrc0(), op.getSrc1(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTPartMax(TPartMaxOp op, PatternRewriter &rewriter) { - VPTOPartContract contract = extractTPartMaxContract(op); - if (failed(checkPartContract(op, contract))) - return failure(); - return buildPartVecScope("partmax", contract, op.getSrc0(), op.getSrc1(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTPartMin(TPartMinOp op, PatternRewriter &rewriter) { - VPTOPartContract contract = extractTPartMinContract(op); - if (failed(checkPartContract(op, contract))) - return failure(); - return buildPartVecScope("partmin", contract, op.getSrc0(), op.getSrc1(), - op.getDst(), rewriter, op.getLoc()); -} - -LogicalResult lowerTSTORE(TStoreOp op, PatternRewriter &rewriter) { - VPTOStoreContract contract = extractTStoreContract(op); - - switch (contract.srcDomain) { - case VPTOTileDomain::Acc: - return lowerUnsupportedAccStore(op.getLoc()); - case VPTOTileDomain::Mat: - return lowerUnsupportedMatStore(op.getLoc()); - case VPTOTileDomain::Vec: - break; - } - - ResolvedTensorView destinationView; - if (!resolveTensorView(op.getDst(), destinationView, rewriter, op.getLoc())) - return op.emitOpError("requires a recoverable destination tensor view for VPTO lowering"); - - StringRef sourceTileLayout = deriveTileLayout(op.getSrc()); - StringRef destinationLayout = - inferVecTransferLayoutFromTile(stringifyLayoutAttr(destinationView.layoutAttr), - sourceTileLayout); - bool isNdStore = sourceTileLayout == "row_major" && destinationLayout == "nd"; - bool isDnStore = sourceTileLayout == "col_major" && destinationLayout == "dn"; - if (!isNdStore && !isDnStore) - return op.emitOpError("currently supports only ND row_major or DN col_major vec TSTORE lowering"); - - Value sourceBuffer = - materializeBufferPointer(op.getSrc(), contract.elementType, - getMemorySpace(op.getSrc()), rewriter, op.getLoc()); - Value destinationBuffer = - materializeBufferPointer(destinationView.root, getElementType(destinationView.root), - getGmMemorySpace(rewriter.getContext()), rewriter, - op.getLoc()); - if (!sourceBuffer || !destinationBuffer) - return op.emitOpError("requires A5-compatible source and destination buffers"); - - auto [tileRows, tileCols] = getStaticTileRowsCols(op.getSrc()); - Value validRowsValue = - materializeI64Value(contract.validRowsValue, contract.validRows, rewriter, - op.getLoc()); - Value validColsValue = - materializeI64Value(contract.validColsValue, contract.validCols, rewriter, - op.getLoc()); - Value sidValue = rewriter.create(op.getLoc(), 0, 64); - int64_t elemBytes = getElementByteSize(contract.elementType); - if ((isNdStore && tileCols == ShapedType::kDynamic) || - (isDnStore && tileRows == ShapedType::kDynamic) || elemBytes <= 0) - return op.emitOpError("requires static tile shape for A5-compatible transfer arguments"); - VecNdTransferPlan plan; - LogicalResult planResult = - isNdStore ? buildVecNdStorePlan(destinationView.shape, destinationView.strides, - tileCols, contract.validColsValue, - contract.validCols, contract.elementType, - rewriter, op.getLoc(), plan) - : buildVecDnStorePlan(destinationView.shape, destinationView.strides, - tileRows, contract.validRowsValue, - contract.validRows, contract.elementType, - rewriter, op.getLoc(), plan); - if (failed(planResult)) - return op.emitOpError("requires PTO-compatible vec copy_ubuf_to_gm arguments"); - Value reservedValue = rewriter.create(op.getLoc(), 0, 64); - if (!validRowsValue || !validColsValue) - return op.emitOpError("requires valid rows and cols for A5-compatible transfer arguments"); - Value destinationOffset = - materializeI64Ofr(destinationView.offsetElems, rewriter, op.getLoc()); - if (!destinationOffset) - return op.emitOpError("requires a materializable destination offset for VPTO lowering"); - Value destinationBase = - adjustPointerByElemOffset(destinationBuffer, destinationOffset, elemBytes, rewriter, - op.getLoc()); - if (!destinationBase) - return op.emitOpError("failed to materialize destination base pointer"); - - rewriter.create(op.getLoc(), plan.loop2Size, - plan.loop1Size); - rewriter.create( - op.getLoc(), plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); - rewriter.create( - op.getLoc(), plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); - - auto emitCopy = [&](Value srcPtr, Value dstPtr) { - Type transferElementType = - getCopyTransferElementType(contract.elementType, rewriter); - Value typedSrcPtr = - castPtrToElementType(srcPtr, transferElementType, rewriter, op.getLoc()); - Value typedDstPtr = - castPtrToElementType(dstPtr, transferElementType, rewriter, op.getLoc()); - if (!typedSrcPtr || !typedDstPtr) - return failure(); - rewriter.create( - op.getLoc(), typedSrcPtr, typedDstPtr, sidValue, plan.nBurst, - plan.lenBurst, reservedValue, plan.firstStrideBytes, - plan.secondStrideBytes); - return success(); - }; - - if (std::optional outerConst = getConstInt(plan.outerCount); outerConst && *outerConst == 1) { - return emitCopy(sourceBuffer, destinationBase); - } - - Value c0 = rewriter.create(op.getLoc(), 0); - Value c1 = rewriter.create(op.getLoc(), 1); - Value outerUpper = - rewriter.create(op.getLoc(), rewriter.getIndexType(), - plan.outerCount); - auto outerLoop = rewriter.create(op.getLoc(), c0, outerUpper, c1); - rewriter.setInsertionPointToStart(outerLoop.getBody()); - Value ivI64 = rewriter.create(op.getLoc(), rewriter.getI64Type(), - outerLoop.getInductionVar()); - Value srcStep = createI64Mul(ivI64, plan.outerSrcStrideElems, rewriter, op.getLoc()); - Value dstStep = createI64Mul(ivI64, plan.outerDstStrideElems, rewriter, op.getLoc()); - Value iterSrc = adjustPointerByElemOffset(sourceBuffer, srcStep, elemBytes, rewriter, - op.getLoc()); - Value iterDst = adjustPointerByElemOffset(destinationBase, dstStep, elemBytes, rewriter, - op.getLoc()); - return emitCopy(iterSrc, iterDst); -} - -LogicalResult lowerSetFlag(SetFlagOp op, PatternRewriter &rewriter) { - rewriter.create(op.getLoc(), - stringifyPipeAttr(op.getSrcPipe(), rewriter), - stringifyPipeAttr(op.getDstPipe(), rewriter), - stringifyEventAttr(op.getEventId(), rewriter)); - return success(); -} - -LogicalResult lowerWaitFlag(WaitFlagOp op, PatternRewriter &rewriter) { - rewriter.create(op.getLoc(), - stringifyPipeAttr(op.getSrcPipe(), rewriter), - stringifyPipeAttr(op.getDstPipe(), rewriter), - stringifyEventAttr(op.getEventId(), rewriter)); - return success(); -} - -LogicalResult lowerBarrier(BarrierOp op, PatternRewriter &rewriter) { - rewriter.create(op.getLoc(), - stringifyPipeAttr(op.getPipe(), rewriter)); - return success(); -} - -static FailureOr stringifyConcreteSyncPipeAttr(Attribute opTypeAttr, - PatternRewriter &rewriter) { - if (auto pipeAttr = dyn_cast(opTypeAttr)) - return PipeAttr::get(rewriter.getContext(), pipeAttr.getPipe()); - auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); - if (failed(opTypeOr)) - return failure(); - PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return failure(); - return PipeAttr::get(rewriter.getContext(), pipe); -} - -LogicalResult lowerGetBuf(GetBufOp op, PatternRewriter &rewriter) { - FailureOr pipeAttr = - stringifyConcreteSyncPipeAttr(op.getOpTypeAttr(), rewriter); - if (failed(pipeAttr)) - return op.emitOpError("get_buf expects SyncOpType/PipeEventType that maps to a concrete pipe"); - - rewriter.create(op.getLoc(), Attribute(*pipeAttr), - static_cast(op.getBufId()), - static_cast(op.getMode())); - return success(); -} - -LogicalResult lowerRlsBuf(RlsBufOp op, PatternRewriter &rewriter) { - FailureOr pipeAttr = - stringifyConcreteSyncPipeAttr(op.getOpTypeAttr(), rewriter); - if (failed(pipeAttr)) - return op.emitOpError("rls_buf expects SyncOpType/PipeEventType that maps to a concrete pipe"); - - rewriter.create(op.getLoc(), Attribute(*pipeAttr), - static_cast(op.getBufId()), - static_cast(op.getMode())); - return success(); -} - -namespace { - -static Type convertVPTOBoundaryMemRefType(Type type) { - auto memrefType = dyn_cast(type); - if (!memrefType) - return type; - auto memorySpace = dyn_cast_or_null(memrefType.getMemorySpace()); - if (!memorySpace) - return {}; - return PtrType::get(type.getContext(), memrefType.getElementType(), memorySpace); -} - -static LogicalResult eraseDeadVPTOMemRefScaffold(ModuleOp module) { - bool erasedAny = true; - while (erasedAny) { - erasedAny = false; - SmallVector deadOps; - module.walk([&](Operation *op) { - if (!op->use_empty()) - return; - if (isa(op)) - deadOps.push_back(op); - }); - for (Operation *op : deadOps) { - op->erase(); - erasedAny = true; - } - } - return success(); -} - -static LogicalResult verifyNoResidualVPTOMemRefs(ModuleOp module, - llvm::raw_ostream *diagOS) { - for (func::FuncOp func : module.getOps()) { - for (Type input : func.getFunctionType().getInputs()) { - if (!isa(input)) - continue; - if (diagOS) - *diagOS << "VPTO ptr-only boundary failed: residual memref argument in " - << func.getName() << ": " << input << "\n"; - return failure(); - } - for (Type result : func.getFunctionType().getResults()) { - if (!isa(result)) - continue; - if (diagOS) - *diagOS << "VPTO ptr-only boundary failed: residual memref result in " - << func.getName() << ": " << result << "\n"; - return failure(); - } - } - - WalkResult walk = module.walk([&](Operation *op) { - auto hasResidualMemRef = [](TypeRange types) { - return llvm::any_of(types, [](Type type) { - return isa(type); - }); - }; - if (hasResidualMemRef(op->getOperandTypes()) || - hasResidualMemRef(op->getResultTypes())) { - if (diagOS) { - *diagOS << "VPTO ptr-only boundary failed: residual memref-typed op " - << op->getName() << "\n"; - op->print(*diagOS); - *diagOS << "\n"; - } - return WalkResult::interrupt(); - } - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - for (BlockArgument arg : block.getArguments()) { - if (!isa(arg.getType())) - continue; - if (diagOS) - *diagOS << "VPTO ptr-only boundary failed: residual memref block " - << "argument in op " << op->getName() << ": " - << arg.getType() << "\n"; - return WalkResult::interrupt(); - } - } - } - return WalkResult::advance(); - }); - return walk.wasInterrupted() ? failure() : success(); -} - -} // namespace - -LogicalResult convertVPTOFunctionBoundariesToPtr(ModuleOp module, - llvm::raw_ostream *diagOS) { - // VPTO kernels use ptr-only entry semantics: the function ABI keeps only the - // same-space base pointer, while shape/stride/offset stay in live SSA and - // address calculations inside the body. - if (failed(eraseDeadVPTOMemRefScaffold(module))) - return failure(); - - bool sawFailure = false; - for (func::FuncOp func : module.getOps()) { - if (func.isExternal()) - continue; - - FunctionType functionType = func.getFunctionType(); - SmallVector newInputs(functionType.getInputs().begin(), - functionType.getInputs().end()); - bool changed = false; - - for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { - auto memrefType = dyn_cast(inputType); - if (!memrefType) - continue; - - Type newType = convertVPTOBoundaryMemRefType(inputType); - if (!newType) { - if (diagOS) - *diagOS << "VPTO ptr-only boundary failed: unsupported memref " - << "argument type in " << func.getName() << ": " - << inputType << "\n"; - sawFailure = true; - continue; - } - - BlockArgument arg = func.getArgument(idx); - SmallVector users(arg.getUsers().begin(), arg.getUsers().end()); - arg.setType(newType); - newInputs[idx] = newType; - changed = true; - - for (Operation *user : users) { - if (auto cast = dyn_cast(user)) { - if (cast.getInput() != arg) - continue; - if (cast.getResult().getType() == newType) { - cast.getResult().replaceAllUsesWith(arg); - cast.erase(); - } - continue; - } - - if (isa(user) && - user->use_empty()) { - user->erase(); - continue; - } - - if (diagOS) { - *diagOS << "VPTO ptr-only boundary failed: argument " << idx - << " of " << func.getName() - << " still feeds a memref-dependent user after ptr rewrite:\n"; - user->print(*diagOS); - *diagOS << "\n"; - } - sawFailure = true; - } - } - - for (Type resultType : functionType.getResults()) { - if (!isa(resultType)) - continue; - if (diagOS) - *diagOS << "VPTO ptr-only boundary failed: memref result is unsupported " - << "for " << func.getName() << ": " << resultType << "\n"; - sawFailure = true; - } - - if (changed) { - func.setFunctionType( - FunctionType::get(module.getContext(), newInputs, functionType.getResults())); - } - } - - if (sawFailure) - return failure(); - - if (failed(eraseDeadVPTOMemRefScaffold(module))) - return failure(); - return verifyNoResidualVPTOMemRefs(module, diagOS); -} - -} // namespace pto -} // namespace mlir diff --git a/lib/PTO/Transforms/VPTOBufferMaterialization.cpp b/lib/PTO/Transforms/VPTOBufferMaterialization.cpp new file mode 100644 index 000000000..a12fc8771 --- /dev/null +++ b/lib/PTO/Transforms/VPTOBufferMaterialization.cpp @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/VPTOLowering.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::pto { +namespace { + +static AddressSpaceAttr getNormalizedPtrMemorySpace(Attribute memorySpace, + MLIRContext *context) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return AddressSpaceAttr::get(context, + static_cast(intAttr.getInt())); + return AddressSpaceAttr::get(context, AddressSpace::GM); +} + +static Value materializeMemRefView(Value value, ArrayRef shape, + Type elementType, Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + auto memrefType = + MemRefType::get(shape, elementType, AffineMap(), memorySpace); + if (value.getType() == memrefType) + return value; + return rewriter + .create( + loc, TypeRange(ArrayRef{memrefType}), value) + .getResult(0); +} + +static Value materializeTileBufferView(Value value, PatternRewriter &rewriter, + Location loc) { + if (isa(value.getType())) + return value; + + auto tileType = dyn_cast(value.getType()); + if (!tileType) + return {}; + + return materializeMemRefView(value, tileType.getShape(), + tileType.getElementType(), + tileType.getMemorySpace(), rewriter, loc); +} + +} // namespace + +Value materializeBufferPointer(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + if (!value) + return {}; + + auto ptrMemorySpace = + getNormalizedPtrMemorySpace(memorySpace, rewriter.getContext()); + auto ptrType = PtrType::get(rewriter.getContext(), elementType, ptrMemorySpace); + + if (value.getType() == ptrType) + return value; + + if (auto bind = value.getDefiningOp()) + return materializeBufferPointer(bind.getSource(), elementType, memorySpace, + rewriter, loc); + + if (auto cast = value.getDefiningOp()) { + if (cast.getAddrs().empty()) + return {}; + return rewriter.create(loc, ptrType, cast.getAddrs().front()) + .getResult(); + } + + Value memrefValue = materializeTileBufferView(value, rewriter, loc); + auto memrefType = dyn_cast_or_null(memrefValue.getType()); + if (!memrefValue || !memrefType) + return {}; + return rewriter.create(loc, ptrType, memrefValue).getResult(); +} + +} // namespace mlir::pto diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 53ad57c9d..e124a7154 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -7,7 +7,6 @@ // See LICENSE in the root of the software repository for the full text of the License. #include "PTO/IR/PTO.h" -#include "PTO/Transforms/VPTOLowering.h" #include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/BufferizableOpInterfaceImpl.h" From d6634cd7100dfb17236434a6361076ecded36d52 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Fri, 24 Apr 2026 00:05:13 +0800 Subject: [PATCH 147/192] feat: dsl supports for trems --- tilelang-dsl/python/tilelang_dsl/lowering.py | 22 ++++- tilelang-dsl/python/tilelang_dsl/semantic.py | 98 ++++++++++++++++++- .../python/tilelang_dsl/support_matrix.py | 4 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 35 ++++--- 4 files changed, 139 insertions(+), 20 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index b640c7eda..690ffe5c8 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -790,7 +790,7 @@ def _render_multi_result_assign( env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) return lines - if stmt.value.name in {"pdintlv_b8", "pintlv_b16"}: + if stmt.value.name in {"pdintlv_b8", "pdintlv_b16", "pdintlv_b32", "pintlv_b8", "pintlv_b16", "pintlv_b32"}: lines = [] lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) @@ -3053,6 +3053,15 @@ def _lower_call_expr( "vcpadd", "vsort32", }: + if expr.name in {"vsunpack", "vzunpack"}: + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + part = self._lower_to_index(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {value.name}, {part.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) value = self._lower_expr(expr.args[0], env, indent=indent, into=into) mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) into.append( @@ -3090,7 +3099,6 @@ def _lower_call_expr( "vshl", "vshr", "vprelu", - "vpack", "vperm", "vmrgsort", }: @@ -3105,6 +3113,16 @@ def _lower_call_expr( ) return _RenderedValue(name=result_name, type=expr.type) + if expr.name == "vpack": + vector = self._lower_expr(expr.args[0], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vpack {vector.name}, {part} : " + + f"{self._render_type(vector.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if expr.name in {"vshift", "vslide"}: vector = self._lower_expr(expr.args[0], env, indent=indent, into=into) immediate = self._lower_expr(expr.args[1], env, indent=indent, into=into) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 96b15f0c2..c23aff866 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -3970,7 +3970,7 @@ def _analyze_call_expr( return self._analyze_mask_part_op(name, args) if name in {"pnot", "psel", "pand", "por", "pxor"}: return self._analyze_mask_logic_op(name, args) - if name in {"pdintlv_b8", "pintlv_b16"}: + if name in {"pdintlv_b8", "pdintlv_b16", "pdintlv_b32", "pintlv_b8", "pintlv_b16", "pintlv_b32"}: return self._analyze_predicate_reorder_op(name, args) if name in {"vcmp", "vcmps"}: return self._analyze_compare_op(name, args) @@ -3980,6 +3980,8 @@ def _analyze_call_expr( return self._analyze_carry_op(name, args) if name in {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"}: return self._analyze_rearrangement_op(name, args) + if name == "vpack": + return self._analyze_vpack_op(args) if name == "vcvt": return self._analyze_vcvt(args) if name == "vbitcast": @@ -4716,6 +4718,20 @@ def _analyze_unary_vector_op( name: str, args: tuple[SemanticExpr, ...], ) -> SemanticExpr: + if name in {"vsunpack", "vzunpack"}: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") + value, part = args + vreg = self._require_vreg_expr(value, f"pto.{name} value") + self._require_i32_like_expr(part, f"pto.{name} part") + self._validate_unary_dtype(name, vreg.element_dtype) + result_dtype = self._unpack_result_dtype(name, vreg.element_dtype) + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticVRegType(element_dtype=result_dtype, lanes=vreg.lanes // 2), + ) if len(args) != 2: raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") value, mask = args @@ -4903,7 +4919,14 @@ def _analyze_predicate_reorder_op( raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") lhs = self._require_mask_expr(args[0], f"pto.{name} src0") rhs = self._require_mask_expr(args[1], f"pto.{name} src1") - expected_granularity = "b8" if name == "pdintlv_b8" else "b16" + expected_granularity = { + "pdintlv_b8": "b8", + "pdintlv_b16": "b16", + "pdintlv_b32": "b32", + "pintlv_b8": "b8", + "pintlv_b16": "b16", + "pintlv_b32": "b32", + }[name] if lhs.granularity != expected_granularity or rhs.granularity != expected_granularity: raise TypeError(f"pto.{name} expects !pto.mask<{expected_granularity}> operands") return SemanticCallExpr( @@ -5161,6 +5184,23 @@ def _analyze_vcvt( type=self._vreg_type_for_dtype(target_dtype), ) + def _analyze_vpack_op( + self, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.vpack expects exactly 2 positional arguments in TileLang DSL") + vector = self._require_vreg_expr(args[0], "pto.vpack vector") + part = self._normalize_predicate_part(args[1], "pto.vpack part") + self._validate_binary_dtype("vpack", vector.element_dtype) + result_dtype = self._pack_result_dtype(vector.element_dtype) + return SemanticCallExpr( + namespace="pto", + name="vpack", + args=(args[0], part), + type=SemanticVRegType(element_dtype=result_dtype, lanes=vector.lanes * 2), + ) + def _analyze_vtrc( self, args: tuple[SemanticExpr, ...], @@ -6015,6 +6055,54 @@ def _vreg_type_for_dtype(self, dtype: ScalarType) -> SemanticVRegType: raise TypeError(f"dtype `{dtype.name}` is not supported by vlds/vsts in TileLang DSL v1") return SemanticVRegType(element_dtype=dtype, lanes=256 // width) + def _unpack_result_dtype(self, name: str, dtype: ScalarType) -> ScalarType: + if not is_integer_dtype(dtype): + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + width = integer_bitwidth(dtype) + if width not in {8, 16, 32}: + raise TypeError(f"pto.{name} only supports 8/16/32-bit integer vector dtypes in TileLang DSL v1") + + if name == "vzunpack": + mapping = { + "i8": ui16, + "si8": ui16, + "ui8": ui16, + "i16": ui32, + "si16": ui32, + "ui16": ui32, + "i32": ui64, + "si32": ui64, + "ui32": ui64, + } + return mapping[dtype.name] + + mapping = { + "i8": i16, + "si8": si16, + "i16": i32, + "si16": si32, + "i32": i64, + "si32": si64, + } + if dtype.name not in mapping: + raise TypeError(f"pto.{name} requires signed/signless integer vector dtypes in TileLang DSL v1") + return mapping[dtype.name] + + def _pack_result_dtype(self, dtype: ScalarType) -> ScalarType: + if not is_integer_dtype(dtype): + raise TypeError("pto.vpack only supports integer vector dtypes in TileLang DSL v1") + mapping = { + "i32": ui16, + "si32": ui16, + "ui32": ui16, + "i16": ui8, + "si16": ui8, + "ui16": ui8, + } + if dtype.name not in mapping: + raise TypeError("pto.vpack only supports 32->16 and 16->8 integer packing in TileLang DSL v1") + return mapping[dtype.name] + def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: if name in {"vexp", "vln", "vsqrt", "vrec", "vrsqrt"} and dtype.name not in {"f16", "f32"}: raise TypeError(f"pto.{name} only supports f16/f32 in TileLang DSL v1") @@ -6033,8 +6121,10 @@ def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: - if name == "vdiv" and dtype.name not in {"f16", "f32"}: - raise TypeError("pto.vdiv only supports f16/f32 in TileLang DSL v1") + if name == "vdiv" and not ( + dtype.name in {"f16", "f32"} or (is_integer_dtype(dtype) and integer_bitwidth(dtype) == 16) + ): + raise TypeError("pto.vdiv only supports f16/f32/i16/ui16 in TileLang DSL v1") if name == "vprelu" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vprelu only supports f16/f32 in TileLang DSL v1") if name in {"vaddreluconv", "vmulconv"} and dtype.name not in {"f16", "bf16", "f32"}: diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index f4aa6ab41..bca098b02 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -186,7 +186,11 @@ "pst", "psti", "pdintlv_b8", + "pdintlv_b16", + "pdintlv_b32", + "pintlv_b8", "pintlv_b16", + "pintlv_b32", "vaddc", "vsubc", "vaddcs", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index e44c1f2bf..ee1ad42ff 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -4366,10 +4366,10 @@ def kernel(dst: pto.Tile, src: pto.Tile, shift: pto.i16): out = pto.vbcnt(vec0, all_mask) out = pto.vneg(out, all_mask) out = pto.vcls(out, all_mask) - out = pto.vsunpack(out, all_mask) - out = pto.vzunpack(out, all_mask) - out = pto.vusqz(out, all_mask) - out = pto.vsqz(out, all_mask) + pto.vsunpack(vec0, 0) + pto.vzunpack(vec0.astype(pto.ui32), 0) + pto.vusqz(vec0.astype(pto.ui32), pto.make_mask(pto.ui32, pto.PAT.ALL)) + pto.vsqz(vec0, all_mask) out = pto.vshl(out, vec1, all_mask) out = pto.vshr(out, vec1, all_mask) out = pto.vshls(out, shift, all_mask) @@ -5124,18 +5124,20 @@ def kernel(dst: pto.Tile, src: pto.Tile, shift: pto.i32): all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) vec0 = pto.vlds(src, 0) vec1 = pto.vlds(src, 64) - indices = pto.vci(shift, pto.OrderMode.ASC) + packed_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) out = pto.vcgadd(vec0, all_mask) out = pto.vcgmax(out, all_mask) out = pto.vcgmin(out, all_mask) out = pto.vcpadd(out, all_mask) - out = pto.vpack(out, vec1, all_mask) - out = pto.vperm(out, indices, all_mask) - out = pto.vshift(out, shift, all_mask) - out = pto.vslide(out, shift, all_mask) + packed0 = pto.vpack(vec0, pto.PredicatePart.LOWER) + packed1 = pto.vpack(vec1, pto.PredicatePart.HIGHER) + indices = pto.vci(pto.i16(shift), pto.OrderMode.ASC) + packed0 = pto.vperm(packed0, indices, packed_mask) + packed0 = pto.vshift(packed0, pto.i16(shift), packed_mask) + packed0 = pto.vslide(packed0, pto.i16(shift), packed_mask) + packed0 = pto.vmrgsort(packed0, packed1, packed_mask) out = pto.vsort32(out, all_mask) - out = pto.vmrgsort(out, vec1, all_mask) pto.vsts(out, dst, 0, all_mask) return None @@ -6789,19 +6791,24 @@ def test_predicate_reorder_families_lower_to_supported_ops(self) -> None: def kernel(mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB)): mask8 = pto.pset_b8(pto.PAT.ALL) mask16 = pto.pset_b16(pto.PAT.ALL) + mask32 = pto.pset_b32(pto.PAT.ALL) low8, high8 = pto.pdintlv_b8(mask8, mask8) + low8i, high8i = pto.pintlv_b8(mask8, mask8) + low16d, high16d = pto.pdintlv_b16(mask16, mask16) low16, high16 = pto.pintlv_b16(mask16, mask16) - _ = low8 - _ = high8 - _ = low16 - _ = high16 + low32, high32 = pto.pdintlv_b32(mask32, mask32) + low32i, high32i = pto.pintlv_b32(mask32, mask32) all32 = pto.make_mask(pto.ui32, pto.PAT.ALL) pto.psts(all32, mask_dst, 0) return None text = kernel.specialize().mlir_text() self.assertIn("pto.pdintlv_b8", text) + self.assertIn("pto.pintlv_b8", text) + self.assertIn("pto.pdintlv_b16", text) self.assertIn("pto.pintlv_b16", text) + self.assertIn("pto.pdintlv_b32", text) + self.assertIn("pto.pintlv_b32", text) def test_pdintlv_b8_rejects_wrong_mask_granularity(self) -> None: with self.assertRaises(TypeError) as ctx: From b6e09acb753d757bb762364a5e35c60422a96d01 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 24 Apr 2026 11:55:56 +0800 Subject: [PATCH 148/192] align(tilelang-dsl): require enum dist for vlds/vsts --- lib/TileOps/tcvt_template.py | 4 +- .../user_guide/09-vector-memory-operations.md | 52 +++++++- .../11-vector-arithmetic-operations.md | 27 +++++ tilelang-dsl/python/tilelang_dsl/__init__.py | 4 + .../python/tilelang_dsl/frontend_ast.py | 2 + tilelang-dsl/python/tilelang_dsl/semantic.py | 112 ++++++++++-------- tilelang-dsl/python/tilelang_dsl/types.py | 39 ++++++ tilelang-dsl/tests/test_tilelang_dsl_v1.py | 62 ++++++++-- 8 files changed, 233 insertions(+), 69 deletions(-) diff --git a/lib/TileOps/tcvt_template.py b/lib/TileOps/tcvt_template.py index 8a37e00cb..dca49b4e3 100644 --- a/lib/TileOps/tcvt_template.py +++ b/lib/TileOps/tcvt_template.py @@ -66,7 +66,7 @@ def template_tcvt_f32_to_f16(src: pto.Tile, dst: pto.Tile): sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.EVEN, ) - pto.vsts(converted, dst[row, col:], store_mask, dist="PK_B32") + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) return @@ -126,7 +126,7 @@ def template_tcvt_f16_to_f32(src: pto.Tile, dst: pto.Tile): remained = valid_cols for col in range(0, valid_cols, pto.get_lanes(pto.f32)): store_mask, remained = pto.make_mask(pto.f32, remained) - vec = pto.vlds(src[row, col:], dist="UNPK_B16") + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B16) converted = pto.vcvt( vec, pto.f32, diff --git a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md index 68bdb1bde..f7a20fd76 100644 --- a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -3,6 +3,18 @@ The current DSL exposes type-safe Enum operands for the dual load/store distribution families: +- **`VLoadDist`** for `pto.vlds` + - `VLoadDist.NORM`: ordinary load + - `VLoadDist.UNPK_B8`, `VLoadDist.UNPK_B16`, `VLoadDist.UNPK_B32`: unpacking loads + - `VLoadDist.BRC_B8`, `VLoadDist.BRC_B16`, `VLoadDist.BRC_B32`: broadcast loads + - `VLoadDist.US_B8`, `VLoadDist.US_B16`, `VLoadDist.DS_B8`, `VLoadDist.DS_B16`: strided/narrow load families + +- **`VStoreDist`** for `pto.vsts` + - `VStoreDist.NORM_B8`, `VStoreDist.NORM_B16`, `VStoreDist.NORM_B32`: ordinary stores + - `VStoreDist.ONE_POINT_B8`, `VStoreDist.ONE_POINT_B16`, `VStoreDist.ONE_POINT_B32`: one-point stores + - `VStoreDist.PK_B16`, `VStoreDist.PK_B32`, `VStoreDist.PK_B64`: packed stores + - `VStoreDist.PK4_B32`, `VStoreDist.MRG4CHN_B8`, `VStoreDist.MRG2CHN_B8`, `VStoreDist.MRG2CHN_B16`: merged packed stores + - **`DeinterleaveDist`** for `pto.vldsx2` - `DeinterleaveDist.DINTLV`: alternating-element deinterleave - `DeinterleaveDist.BDINTLV`: block deinterleave @@ -20,6 +32,8 @@ distribution families: The canonical VPTO v0.3 spellings are the enum values: +- `VLoadDist.UNPK_B16.value == "UNPK_B16"` +- `VStoreDist.PK_B32.value == "PK_B32"` - `DeinterleaveDist.DINTLV.value == "DINTLV"` - `DeinterleaveDist.BDINTLV.value == "BDINTLV"` - `InterleaveDist.INTLV.value == "INTLV"` @@ -208,9 +222,9 @@ The syntax sugar eliminates manual byte calculations, reduces errors, and makes Operations for loading data from memory into vector registers. -#### `pto.vlds(buf: ptr, offset: Index) -> VRegType` [Advanced Tier] -#### `pto.vlds(tile[row, col:]) -> VRegType` [Basic Tier] -#### `pto.vlds(tile[start:]) -> VRegType` [Basic Tier] +#### `pto.vlds(buf: ptr, offset: Index, dist: pto.VLoadDist | None = None) -> VRegType` [Advanced Tier] +#### `pto.vlds(tile[row, col:], dist: pto.VLoadDist | None = None) -> VRegType` [Basic Tier] +#### `pto.vlds(tile[start:], dist: pto.VLoadDist | None = None) -> VRegType` [Basic Tier] **Description**: Stateless vector load from buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. @@ -219,12 +233,14 @@ Operations for loading data from memory into vector registers. |-----------|------|-------------| | `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | | `offset` | `Index` | Byte offset | +| `dist` | `pto.VLoadDist \| None` | Optional load distribution enum such as `pto.VLoadDist.NORM` or `pto.VLoadDist.UNPK_B16` | **Parameters (element-indexing syntax)**: | Parameter | Type | Description | |-----------|------|-------------| | `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | | `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `pto.VLoadDist \| None` | Optional load distribution enum such as `pto.VLoadDist.NORM` or `pto.VLoadDist.UNPK_B16` | **Returns**: | Return Value | Type | Description | @@ -235,11 +251,14 @@ Operations for loading data from memory into vector registers. - Buffer must be in UB memory space - For byte-offset syntax: offset must be properly aligned based on element type - For element-indexing syntax: the requested vector region must be within tile bounds and satisfy alignment requirements +- `dist` is optional. When omitted, the load uses the backend default layout for the vector family. +- `dist` must be a `pto.VLoadDist` enum value. **Examples**: ```python # Traditional byte-offset syntax vec = pto.vlds(ub_ptr, lane * 256) +vec_unpacked = pto.vlds(ub_ptr, lane * 128, dist=pto.VLoadDist.UNPK_B16) # New element-indexing syntax vec = pto.vlds(tile[i, j:]) # Load from row i, columns j to j+vector_lanes-1 @@ -550,9 +569,9 @@ vec = pto.vsldb(tile[k], control_word, mask) Operations for storing data from vector registers to memory. -#### `pto.vsts(vec: VRegType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] -#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType) -> None` [Basic Tier] -#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType) -> None` [Basic Tier] +#### `pto.vsts(vec: VRegType, buf: ptr, offset: Index, mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Advanced Tier] +#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Basic Tier] +#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Basic Tier] **Description**: Stateless vector store to buffer. Supports both byte-offset and element-indexing syntax. @@ -563,6 +582,7 @@ Operations for storing data from vector registers to memory. | `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | | `offset` | `Index` | Byte offset | | `mask` | `MaskType` | Predicate mask | +| `dist` | `pto.VStoreDist \| None` | Optional store distribution enum such as `pto.VStoreDist.NORM_B32` or `pto.VStoreDist.PK_B32` | **Parameters (element-indexing syntax)**: | Parameter | Type | Description | @@ -571,6 +591,7 @@ Operations for storing data from vector registers to memory. | `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | | `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | | `mask` | `MaskType` | Predicate mask | +| `dist` | `pto.VStoreDist \| None` | Optional store distribution enum such as `pto.VStoreDist.NORM_B32` or `pto.VStoreDist.PK_B32` | **Returns**: None (side-effect operation) @@ -578,6 +599,14 @@ Operations for storing data from vector registers to memory. - Buffer must be in UB memory space - For byte-offset syntax: offset must be properly aligned based on element type - For element-indexing syntax: the destination vector region must be within tile bounds and satisfy alignment requirements +- `dist` is optional. When omitted, the store uses the backend default layout for the vector family. +- Current TileLang DSL v1 accepts exactly one keyword attr on `pto.vsts`: `dist=...`. +- `dist` must be a `pto.VStoreDist` enum value. +- `mask` must match the effective store payload granularity, which may differ from the vector element family when `dist` repacks lanes. +- Common width-changing cases: + default / `NORM_B32` stores expect `mask_b32` for `f32`/`i32`-family vectors; + `PK_B32` also expects `mask_b32` and is used by narrow stores such as `f32 -> f16` `tcvt`; + `PK_B16` expects `mask_b16`. **Examples**: ```python @@ -588,6 +617,17 @@ pto.vsts(vec_f32, ub_ptr, lane * 256, mask32) pto.vsts(vec, tile[i, j:], mask) # Store to row i, columns j to j+vector_lanes-1 pto.vsts(vec, tile[k:], mask) # Store to 1D tile, elements k to k+vector_lanes-1 +# VPTO-aligned packed store +vec_f16 = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, +) +pto.vsts(vec_f16, tile[i, j:], mask32, dist=pto.VStoreDist.PK_B32) + # In a generic kernel @pto.vkernel(target="a5", op="copy", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) def generic_store(src: pto.Tile, dst: pto.Tile): diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 513539b7e..2791b7ce4 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -1414,12 +1414,19 @@ family. `f16 -> si32` requires `rnd` and `part`, and rejects `sat`; `bf16 -> f16` requires `rnd` and `sat`; `f16 -> f32` requires `part`; + `f32 -> f16` requires `rnd`, `sat`, and `part`; `si32 -> f32` requires `rnd`. - VPTO does not define a `mask_b64` form. Conversions that produce `si64` results still use the typed mask granularity of the source vector family. - Width-changing conversions continue to follow VPTO packing semantics even on the simplified DSL surface. For example, `f16 -> f32` uses an `f16`-family `mask_b16`, because the mask is attached to the source vector family. +- A common `tcvt`-style pair is: + `f16 -> f32`: `pto.vlds(..., dist=pto.VLoadDist.UNPK_B16)` + `pto.vcvt(..., part=pto.VcvtPartMode.EVEN)`; + `f32 -> f16`: `pto.vcvt(..., rnd=..., sat=..., part=pto.VcvtPartMode.EVEN)` + `pto.vsts(..., dist=pto.VStoreDist.PK_B32)`. +- In those `tcvt` flows, the `vcvt` mask still follows the source vector family: + `f16 -> f32` uses `mask_b16`, while `f32 -> f16` uses `mask_b32`. +- The follow-on `vsts` mask is checked against the store `dist`, not the narrowed element dtype alone. For example, `pto.vsts(vec_f16, ..., mask32, dist=pto.VStoreDist.PK_B32)` is valid and expected for `f32 -> f16` rowwise `tcvt`. **Example**: ```python @@ -1454,6 +1461,26 @@ vec_f16_narrow = pto.vcvt( sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.ODD, ) + +# Rowwise tcvt-style widening from f16 to f32 +vec_f16_unpacked = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) +vec_f32_from_f16 = pto.vcvt( + vec_f16_unpacked, + pto.f32, + mask16, + part=pto.VcvtPartMode.EVEN, +) + +# Rowwise tcvt-style narrowing from f32 to f16 +vec_f16_packed = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, +) +pto.vsts(vec_f16_packed, dst, 0, mask32, dist=pto.VStoreDist.PK_B32) ``` #### `pto.vbitsort(dest: ptr, src: ptr, indices: ptr, repeat_times: index) -> None` [Advanced Tier] diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py index 8f036b476..8bfe77eb1 100644 --- a/tilelang-dsl/python/tilelang_dsl/__init__.py +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -47,10 +47,12 @@ VcvtPartMode, VcvtRoundMode, VcvtSatMode, + VLoadDist, PointerType, PostUpdateMode, Pipe, PredicateDist, + VStoreDist, ScalarType, SLayout, TensorView, @@ -124,6 +126,8 @@ "EVENT", "MaskPattern", "PredicateDist", + "VLoadDist", + "VStoreDist", "PredicatePart", "CmpMode", "PAT", diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 821b6b27f..49580a289 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -989,6 +989,8 @@ def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNo "EVENT", "MaskPattern", "PredicateDist", + "VLoadDist", + "VStoreDist", "PredicatePart", "CmpMode", "Pipe", diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index c23aff866..43a821459 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -76,7 +76,9 @@ VcvtPartMode, VcvtRoundMode, VcvtSatMode, + VLoadDist, VRegType, + VStoreDist, bf16, bytewidth, f16, @@ -138,6 +140,8 @@ for pad_value in (PadValue.NULL, PadValue.ZERO, PadValue.MAX, PadValue.MIN) } _PREDICATE_DIST_SYMBOLS = {dist.name: dist for dist in PredicateDist} +_VLOAD_DIST_SYMBOLS = {dist.name: dist for dist in VLoadDist} +_VSTORE_DIST_SYMBOLS = {dist.name: dist for dist in VStoreDist} _PREDICATE_PART_SYMBOLS = {part.name: part for part in PredicatePart} _CMP_MODE_SYMBOLS = {mode.name: mode for mode in CmpMode} _DEINTERLEAVE_DIST_SYMBOLS = dict(DeinterleaveDist.__members__) @@ -3344,6 +3348,24 @@ def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: value=predicate_dist, type=SemanticMetaType(kind="predicate_dist"), ) + if expr.namespace in {"VLoadDist", "pto.VLoadDist"}: + vload_dist = _VLOAD_DIST_SYMBOLS.get(expr.name) + if vload_dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=vload_dist, + type=SemanticMetaType(kind="vload_dist"), + ) + if expr.namespace in {"VStoreDist", "pto.VStoreDist"}: + vstore_dist = _VSTORE_DIST_SYMBOLS.get(expr.name) + if vstore_dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=vstore_dist, + type=SemanticMetaType(kind="vstore_dist"), + ) if expr.namespace in {"PredicatePart", "pto.PredicatePart"}: predicate_part = _PREDICATE_PART_SYMBOLS.get(expr.name) if predicate_part is not None: @@ -5831,37 +5853,27 @@ def _normalize_vlds_dist( ) -> SemanticExpr | None: if expr is None: return None - dist = self._require_string_expr(expr, context) - normalized = dist - if normalized not in { - "NORM", - "BRC_B8", - "BRC_B16", - "BRC_B32", - "US_B8", - "US_B16", - "DS_B8", - "DS_B16", - "UNPK_B8", - "UNPK_B16", - "UNPK_B32", - "BRC_BLK", - "E2B_B16", - "E2B_B32", - "UNPK4", - "SPLT4CHN", - "SPLT2CHN_B8", - "SPLT2CHN_B16", - }: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vload_dist" + and isinstance(expr.value, VLoadDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vload_dist" + and isinstance(expr.binding.value, VLoadDist) + ): + dist = expr.binding.value.value + else: raise TypeError( - "pto.vlds dist must be one of " - "\"NORM\", \"BRC_B8\", \"BRC_B16\", \"BRC_B32\", " - "\"US_B8\", \"US_B16\", \"DS_B8\", \"DS_B16\", " - "\"UNPK_B8\", \"UNPK_B16\", \"UNPK_B32\", \"BRC_BLK\", " - "\"E2B_B16\", \"E2B_B32\", \"UNPK4\", \"SPLT4CHN\", " - "\"SPLT2CHN_B8\", or \"SPLT2CHN_B16\" in TileLang DSL v1" + "pto.vlds dist must be a VLoadDist enum such as " + "`pto.VLoadDist.NORM`, `pto.VLoadDist.UNPK_B16`, or " + "`pto.VLoadDist.BRC_B32` in TileLang DSL v1" ) - return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) + return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) def _normalize_vsts_dist( self, @@ -5870,31 +5882,27 @@ def _normalize_vsts_dist( ) -> SemanticExpr | None: if expr is None: return None - dist = self._require_string_expr(expr, context) - normalized = dist - if normalized not in { - "NORM_B8", - "NORM_B16", - "NORM_B32", - "1PT_B8", - "1PT_B16", - "1PT_B32", - "PK_B16", - "PK_B32", - "PK_B64", - "PK4_B32", - "MRG4CHN_B8", - "MRG2CHN_B8", - "MRG2CHN_B16", - }: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vstore_dist" + and isinstance(expr.value, VStoreDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vstore_dist" + and isinstance(expr.binding.value, VStoreDist) + ): + dist = expr.binding.value.value + else: raise TypeError( - "pto.vsts dist must be one of " - "\"NORM_B8\", \"NORM_B16\", \"NORM_B32\", " - "\"1PT_B8\", \"1PT_B16\", \"1PT_B32\", " - "\"PK_B16\", \"PK_B32\", \"PK_B64\", \"PK4_B32\", " - "\"MRG4CHN_B8\", \"MRG2CHN_B8\", or \"MRG2CHN_B16\" in TileLang DSL v1" + "pto.vsts dist must be a VStoreDist enum such as " + "`pto.VStoreDist.NORM_B32`, `pto.VStoreDist.PK_B32`, or " + "`pto.VStoreDist.ONE_POINT_B8` in TileLang DSL v1" ) - return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) + return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) def _require_i1_expr(self, expr: SemanticExpr, context: str) -> None: scalar = self._require_scalar_expr(expr, context) diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py index 0e806f82f..dfcdc8759 100644 --- a/tilelang-dsl/python/tilelang_dsl/types.py +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -205,6 +205,43 @@ class PredicateDist(str, Enum): PK = "PK" +class VLoadDist(str, Enum): + NORM = "NORM" + BRC_B8 = "BRC_B8" + BRC_B16 = "BRC_B16" + BRC_B32 = "BRC_B32" + US_B8 = "US_B8" + US_B16 = "US_B16" + DS_B8 = "DS_B8" + DS_B16 = "DS_B16" + UNPK_B8 = "UNPK_B8" + UNPK_B16 = "UNPK_B16" + UNPK_B32 = "UNPK_B32" + BRC_BLK = "BRC_BLK" + E2B_B16 = "E2B_B16" + E2B_B32 = "E2B_B32" + UNPK4 = "UNPK4" + SPLT4CHN = "SPLT4CHN" + SPLT2CHN_B8 = "SPLT2CHN_B8" + SPLT2CHN_B16 = "SPLT2CHN_B16" + + +class VStoreDist(str, Enum): + NORM_B8 = "NORM_B8" + NORM_B16 = "NORM_B16" + NORM_B32 = "NORM_B32" + ONE_POINT_B8 = "1PT_B8" + ONE_POINT_B16 = "1PT_B16" + ONE_POINT_B32 = "1PT_B32" + PK_B16 = "PK_B16" + PK_B32 = "PK_B32" + PK_B64 = "PK_B64" + PK4_B32 = "PK4_B32" + MRG4CHN_B8 = "MRG4CHN_B8" + MRG2CHN_B8 = "MRG2CHN_B8" + MRG2CHN_B16 = "MRG2CHN_B16" + + class PredicatePart(str, Enum): LOWER = "LOWER" HIGHER = "HIGHER" @@ -710,6 +747,8 @@ def get_op_attr(name: str, default: Any = None) -> Any: "EVENT", "MaskPattern", "PredicateDist", + "VLoadDist", + "VStoreDist", "PredicatePart", "CmpMode", "PAT", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index ee1ad42ff..736ada783 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3614,7 +3614,7 @@ def test_vcvt_i32_to_i64_reuses_b32_mask_and_emits_i64_vreg(self) -> None: def kernel(dst: pto.Tile, src: pto.Tile): src_mask = pto.make_mask(pto.i32, pto.PAT.ALL) dst_mask = pto.make_mask(pto.i64, pto.PAT.ALL) - vec = pto.vlds(src, 0, dist="UNPK_B32") + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B32) out = pto.vcvt( vec, pto.i64, @@ -3642,6 +3642,50 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn('part = "EVEN"', text) self.assertIn("pto.vsts", text) + def test_vlds_dist_requires_vload_dist_enum(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vlds_dist_requires_enum_unique", + dtypes=[(pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist="UNPK_B32") + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("VLoadDist enum", str(ctx.exception)) + + def test_vsts_dist_requires_vstore_dist_enum(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vsts_dist_requires_enum_unique", + dtypes=[(pto.ui8, pto.ui8)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask, dist="NORM_B8") + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("VStoreDist enum", str(ctx.exception)) + def test_vtrc_defaults_to_round_nearest(self) -> None: @pto.vkernel( op="vtrc_default_rnd_unique", @@ -3879,7 +3923,7 @@ def test_vcvt_f16_to_i32_requires_rnd_and_part(self) -> None: def kernel(dst: pto.Tile, src: pto.Tile): src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) - vec = pto.vlds(src, 0, dist="UNPK_B16") + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) out = pto.vcvt( vec, pto.i32, @@ -3907,7 +3951,7 @@ def kernel(dst: pto.Tile, src: pto.Tile): def kernel(dst: pto.Tile, src: pto.Tile): src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) - vec = pto.vlds(src, 0, dist="UNPK_B16") + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) out = pto.vcvt( vec, pto.i32, @@ -3936,7 +3980,7 @@ def test_vcvt_f16_to_i32_rejects_sat(self) -> None: def kernel(dst: pto.Tile, src: pto.Tile): src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) - vec = pto.vlds(src, 0, dist="UNPK_B16") + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) out = pto.vcvt( vec, pto.i32, @@ -3965,7 +4009,7 @@ def test_vcvt_f16_to_i32_accepts_rnd_and_part(self) -> None: def kernel(dst: pto.Tile, src: pto.Tile): src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) - vec = pto.vlds(src, 0, dist="UNPK_B16") + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) out = pto.vcvt( vec, pto.i32, @@ -6184,7 +6228,7 @@ def kernel(src: pto.Tile, dst: pto.Tile): ) result = pto.vselr(converted, v_idx_ui8) pto.mem_bar(pto.BarrierType.VST_VST) - pto.vsts(result, dst[row, col:], store_mask, dist="NORM_B8") + pto.vsts(result, dst[row, col:], store_mask, dist=pto.VStoreDist.NORM_B8) return None @@ -6213,15 +6257,15 @@ def kernel(src: pto.Tile, dst: pto.Tile): b8_mask = pto.make_mask(pto.i8, pto.PAT.ALL) mask_b16, _ = pto.make_mask(pto.i16, valid_cols) mask_b32 = pto.punpack(mask_b16, pto.PredicatePart.LOWER) - vec_si8 = pto.vlds(src[row, 0:], dist="UNPK_B8") + vec_si8 = pto.vlds(src[row, 0:], dist=pto.VLoadDist.UNPK_B8) vec_ui8 = pto.vbitcast(vec_si8, pto.ui8) v_zero_i8 = pto.vdup(pto.i8(0), b8_mask) v_zero = pto.vbitcast(v_zero_i8, pto.ui8) wide_lo, _ = pto.vintlv(vec_ui8, v_zero) narrowed = pto.vbitcast(wide_lo, pto.si8) converted = pto.vcvt(narrowed, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) - pto.vsts(converted, dst[row, 0:], mask_b32, dist="NORM_B32") - pto.vsts(converted, dst[row, lanes_i32:], mask_b32, dist="NORM_B32") + pto.vsts(converted, dst[row, 0:], mask_b32, dist=pto.VStoreDist.NORM_B32) + pto.vsts(converted, dst[row, lanes_i32:], mask_b32, dist=pto.VStoreDist.NORM_B32) return None specialized = kernel.specialize( From 3b229704f399f85b84340d4a26471da5632b0902 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 24 Apr 2026 12:01:13 +0800 Subject: [PATCH 149/192] chore(tileops): add missing license header for tcvt template --- lib/TileOps/tcvt_template.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/TileOps/tcvt_template.py b/lib/TileOps/tcvt_template.py index dca49b4e3..19f177a60 100644 --- a/lib/TileOps/tcvt_template.py +++ b/lib/TileOps/tcvt_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tcvt.""" import tilelang_dsl as pto From 8b5923237ddd9f1f23752e606906fa77d8e347ef Mon Sep 17 00:00:00 2001 From: qukelin Date: Fri, 24 Apr 2026 02:09:39 +0800 Subject: [PATCH 150/192] Add a mixed Tile/VPTO online softmax kernel --- include/PTO/IR/VPTOOps.td | 7 +- lib/PTO/IR/VPTO.cpp | 33 ++- lib/PTO/Transforms/ExpandTileOp.cpp | 21 +- lib/PTO/Transforms/MemrefToTileBuf.cpp | 26 +- lib/TileOps/tload_template.py | 35 ++- lib/TileOps/tstore_template.py | 31 ++- .../npu/a5/src/st/testcase/CMakeLists.txt | 17 ++ .../src/st/testcase/run_ptoas_to_file.cmake | 36 ++- .../a5/src/st/testcase/softmax/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/softmax/cases.py | 25 ++ .../npu/a5/src/st/testcase/softmax/compare.py | 82 ++++++ .../a5/src/st/testcase/softmax/gen_data.py | 64 +++++ .../npu/a5/src/st/testcase/softmax/launch.cpp | 27 ++ .../npu/a5/src/st/testcase/softmax/main.cpp | 197 ++++++++++++++ .../a5/src/st/testcase/softmax/softmax.pto | 247 +++++++++++++++++ .../kernels/online-softmax-update/compare.py | 47 +++- .../kernels/online-softmax-update/kernel.pto | 6 - .../kernel_tload_tstore.pto | 249 ++++++++++++++++++ .../python/tilelang_dsl/expand_helper.py | 34 ++- tools/ptoas/ptoas.cpp | 16 +- 20 files changed, 1160 insertions(+), 49 deletions(-) create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/softmax/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/softmax/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/softmax/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/softmax/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/softmax/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/softmax/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto create mode 100644 test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 35c9d7723..ea72b8c75 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -77,10 +77,13 @@ def TileBufAddrOp : PTO_Op<"tile_buf_addr", [Pure]> { shape and address space. This op is emitted by TileLang DSL templates and resolved by the - FoldTileBufIntrinsics pass after inlining. + FoldTileBufIntrinsics pass after inlining. Hand-written `.pto` may also + use it directly on the memref result of `pto.bind_tile` / lowered + `pto.alloc_tile`. }]; - let arguments = (ins TileBufType:$src); + let arguments = (ins AnyTypeOf<[TileBufType, AnyMemRef], + "tile_buf or tile-bound memref">:$src); let results = (outs PtrOrMemRef:$dst); let hasVerifier = 1; diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 1413e70ef..10da9c322 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -2032,19 +2032,38 @@ LogicalResult TensorViewAddrOp::verify() { } LogicalResult TileBufAddrOp::verify() { - auto srcType = dyn_cast(getSrc().getType()); - if (!srcType) - return emitOpError("source must be a !pto.tile_buf<...>"); - Type dstType = getDst().getType(); - Type elementType = srcType.getElementType(); - auto srcSpace = dyn_cast_or_null(srcType.getMemorySpace()); + Type elementType; + Attribute srcMemorySpace; + int64_t srcRank = 0; + + if (auto srcTileType = dyn_cast(getSrc().getType())) { + elementType = srcTileType.getElementType(); + srcMemorySpace = srcTileType.getMemorySpace(); + srcRank = static_cast(srcTileType.getShape().size()); + } else if (auto srcMemRefType = dyn_cast(getSrc().getType())) { + // Compatibility for the current TileOp expansion pipeline: + // PTOViewToMemref lowers tile_buf producers (for example alloc_tile) to + // memref + pto.bind_tile before MemrefToTileBuf reconstructs tile_buf + // values. Hand-written pto.tile_buf_addr may therefore temporarily see a + // tile-bound memref operand in this intermediate stage. If the pipeline is + // changed to avoid that PTOViewToMemref round-trip, this memref acceptance + // can be removed and TileBufAddrOp can go back to requiring tile_buf-only + // operands. + elementType = srcMemRefType.getElementType(); + srcMemorySpace = srcMemRefType.getMemorySpace(); + srcRank = srcMemRefType.getRank(); + } else { + return emitOpError("source must be a !pto.tile_buf<...> or memref"); + } + + auto srcSpace = dyn_cast_or_null(srcMemorySpace); if (auto dstMemRefType = dyn_cast(dstType)) { if (dstMemRefType.getElementType() != elementType) return emitOpError( "memref result element type must match tile element type"); - if (dstMemRefType.getRank() != static_cast(srcType.getShape().size())) + if (dstMemRefType.getRank() != srcRank) return emitOpError("memref result rank must match tile rank"); auto dstSpace = dyn_cast_or_null(dstMemRefType.getMemorySpace()); diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 674892450..0dd21c5ea 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -477,6 +477,23 @@ static void appendJsonIntArray(std::string &json, ArrayRef arr) { json += "]"; } +/// Serialize a JSON array where dynamic dimensions become `null`. +static void appendJsonDimArray(std::string &json, ArrayRef arr, + bool negativeIsDynamic = false) { + json += "["; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) + json += ","; + int64_t dim = arr[i]; + if (ShapedType::isDynamic(dim) || (negativeIsDynamic && dim < 0)) { + json += "null"; + continue; + } + json += std::to_string(dim); + } + json += "]"; +} + static std::string buildOperandSpecsJson(const SpecKey &key) { std::string json = "["; for (size_t i = 0; i < key.operands.size(); ++i) { @@ -488,7 +505,7 @@ static std::string buildOperandSpecsJson(const SpecKey &key) { json += "{\"kind\":\"tile\",\"dtype\":\"" + op.dtype + "\",\"shape\":"; appendJsonIntArray(json, op.tileShape); json += ",\"valid_shape\":"; - appendJsonIntArray(json, op.tileValidShape); + appendJsonDimArray(json, op.tileValidShape, /*negativeIsDynamic=*/true); json += ",\"memory_space\":\""; json += op.tileMemorySpace; json += "\",\"config\":{"; @@ -506,7 +523,7 @@ static std::string buildOperandSpecsJson(const SpecKey &key) { if (op.kind == OperandKind::View) { json += "{\"kind\":\"view\",\"dtype\":\"" + op.dtype + "\",\"shape\":"; - appendJsonIntArray(json, op.viewShape); + appendJsonDimArray(json, op.viewShape); if (!op.viewStrides.empty()) { json += ",\"strides\":["; for (size_t dim = 0; dim < op.viewStrides.size(); ++dim) { diff --git a/lib/PTO/Transforms/MemrefToTileBuf.cpp b/lib/PTO/Transforms/MemrefToTileBuf.cpp index 40e86006b..d38b98949 100644 --- a/lib/PTO/Transforms/MemrefToTileBuf.cpp +++ b/lib/PTO/Transforms/MemrefToTileBuf.cpp @@ -85,10 +85,15 @@ static pto::TileBufType reconstructTileBufType(pto::BindTileOp bindOp) { } // ============================================================================ -// Helper: check whether an op is a tile-level op (needs tile_buf operands) +// Helper: check whether an op should consume recovered tile_buf operands. +// +// Besides tile ops themselves, tile_buf intrinsics also need to be rewired so +// later FoldTileBufIntrinsics sees the canonical +// unrealized_conversion_cast <- bind_tile anchor. // ============================================================================ -static bool isTileOp(Operation *op) { - return isa(op); +static bool shouldRewriteTileBufOperands(Operation *op) { + return isa(op); } // ============================================================================ @@ -138,10 +143,10 @@ LogicalResult MemrefToTileBufPass::processFunction(func::FuncOp func, memrefToTileBuf[bindOp.getResult()] = cast.getResult(0); } - // Phase 2: For each tile op, replace memref operands that have a - // corresponding tile_buf value. + // Phase 2: For each tile op / tile_buf intrinsic, replace memref operands + // that have a corresponding tile_buf value. func.walk([&](Operation *op) { - if (!isTileOp(op)) + if (!shouldRewriteTileBufOperands(op)) return; for (OpOperand &operand : op->getOpOperands()) { auto it = memrefToTileBuf.find(operand.get()); @@ -174,7 +179,8 @@ LogicalResult MemrefToTileBufPass::processFunction(func::FuncOp func, unsigned idx = blockArg.getArgNumber(); pto::TileBufType tileBufTy = reconstructTileBufType(bindOp); - // Replace all tile op uses of the cast with the block arg directly. + // Replace all tile op / tile_buf intrinsic uses of the cast with the + // block arg directly. auto castIt = memrefToTileBuf.find(bindOp.getResult()); if (castIt == memrefToTileBuf.end()) continue; @@ -203,9 +209,9 @@ LogicalResult MemrefToTileBufPass::processFunction(func::FuncOp func, // But the back-cast's own operand must remain the block arg. backCast.getInputsMutable().assign(ValueRange{blockArg}); - // Now replace tile op uses: they should use the tile_buf block arg - // directly instead of going through the unrealized_conversion_cast - // chain. + // Now replace tile op / tile_buf intrinsic uses: they should use the + // tile_buf block arg directly instead of going through the + // unrealized_conversion_cast chain. tileBufVal.replaceAllUsesWith(blockArg); // Erase the now-dead forward cast (memref → tile_buf). if (auto castOp = diff --git a/lib/TileOps/tload_template.py b/lib/TileOps/tload_template.py index d298bdfac..7be0971a3 100644 --- a/lib/TileOps/tload_template.py +++ b/lib/TileOps/tload_template.py @@ -10,6 +10,27 @@ import tilelang_dsl as pto + +def _constraint_scalar(value): + return value.value if hasattr(value, "value") else value + + +def _known_eq(lhs, rhs) -> bool: + lhs_value = _constraint_scalar(lhs) + rhs_value = _constraint_scalar(rhs) + if lhs_value is None or rhs_value is None: + return True + return lhs_value == rhs_value + + +def _known_le(lhs, rhs) -> bool: + lhs_value = _constraint_scalar(lhs) + rhs_value = _constraint_scalar(rhs) + if lhs_value is None or rhs_value is None: + return True + return lhs_value <= rhs_value + + def _match_tile_layout(dst, *, row_major: bool, s_layout) -> bool: b_layout_ok = ( dst.config.b_layout == pto.BLayout.ROW_MAJOR @@ -22,16 +43,20 @@ def _match_tile_layout(dst, *, row_major: bool, s_layout) -> bool: def _check_load_bounds(src, dst, *, logical_rows, logical_cols=None, stride_axis=None) -> bool: if src.rank != 5: return False - if stride_axis is not None and src.strides[stride_axis] != 1: + if stride_axis is not None and not _known_eq(src.strides[stride_axis], 1): return False - if dst.valid_shape[0] > logical_rows or logical_rows > dst.shape[0]: + if not _known_le(dst.valid_shape[0], logical_rows): return False - if dst.valid_shape[0] > dst.shape[0]: + if not _known_le(logical_rows, dst.shape[0]): + return False + if not _known_le(dst.valid_shape[0], dst.shape[0]): return False if logical_cols is not None: - if dst.valid_shape[1] > logical_cols or logical_cols > dst.shape[1]: + if not _known_le(dst.valid_shape[1], logical_cols): + return False + if not _known_le(logical_cols, dst.shape[1]): return False - if dst.valid_shape[1] > dst.shape[1]: + if not _known_le(dst.valid_shape[1], dst.shape[1]): return False return True diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py index 278d2c25f..37597a299 100644 --- a/lib/TileOps/tstore_template.py +++ b/lib/TileOps/tstore_template.py @@ -10,6 +10,27 @@ import tilelang_dsl as pto + +def _constraint_scalar(value): + return value.value if hasattr(value, "value") else value + + +def _known_eq(lhs, rhs) -> bool: + lhs_value = _constraint_scalar(lhs) + rhs_value = _constraint_scalar(rhs) + if lhs_value is None or rhs_value is None: + return True + return lhs_value == rhs_value + + +def _known_le(lhs, rhs) -> bool: + lhs_value = _constraint_scalar(lhs) + rhs_value = _constraint_scalar(rhs) + if lhs_value is None or rhs_value is None: + return True + return lhs_value <= rhs_value + + def _match_store_tile_layout(src, *, row_major: bool, s_layout) -> bool: b_layout_ok = ( src.config.b_layout == pto.BLayout.ROW_MAJOR @@ -22,15 +43,15 @@ def _match_store_tile_layout(src, *, row_major: bool, s_layout) -> bool: def _check_store_bounds(src, dst, *, logical_rows, logical_cols, stride_axis=None) -> bool: if dst.rank != 5: return False - if stride_axis is not None and dst.strides[stride_axis] != 1: + if stride_axis is not None and not _known_eq(dst.strides[stride_axis], 1): return False - if src.valid_shape[0] != logical_rows: + if not _known_eq(src.valid_shape[0], logical_rows): return False - if src.valid_shape[1] != logical_cols: + if not _known_eq(src.valid_shape[1], logical_cols): return False - if src.valid_shape[0] > src.shape[0]: + if not _known_le(src.valid_shape[0], src.shape[0]): return False - if src.valid_shape[1] > src.shape[1]: + if not _known_le(src.valid_shape[1], src.shape[1]): return False return True diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index db614307b..7d96af8f2 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -20,6 +20,20 @@ set(PTO_TILELANG_ST_TESTCASE_DIR ${CMAKE_CURRENT_LIST_DIR}) function(pto_tilelang_vec_st NAME) + set(options DISABLE_INSERT_SYNC) + set(oneValueArgs PTO_LEVEL) + cmake_parse_arguments(PTO_TILELANG_ST "${options}" "${oneValueArgs}" "" ${ARGN}) + + set(PTOAS_ENABLE_INSERT_SYNC ON) + if(PTO_TILELANG_ST_DISABLE_INSERT_SYNC) + set(PTOAS_ENABLE_INSERT_SYNC OFF) + endif() + + set(PTOAS_PTO_LEVEL "") + if(DEFINED PTO_TILELANG_ST_PTO_LEVEL) + set(PTOAS_PTO_LEVEL "${PTO_TILELANG_ST_PTO_LEVEL}") + endif() + # Step 1: ptoas .pto → kernel.ll set(PTO_SRC ${CMAKE_CURRENT_SOURCE_DIR}/${NAME}.pto) set(KERNEL_LL ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_kernel.ll) @@ -31,6 +45,8 @@ function(pto_tilelang_vec_st NAME) -DPTOAS_BIN=${PTOAS_BIN} -DPTO_SRC=${PTO_SRC} -DKERNEL_LL=${KERNEL_LL} + -DPTOAS_ENABLE_INSERT_SYNC=${PTOAS_ENABLE_INSERT_SYNC} + -DPTOAS_PTO_LEVEL=${PTOAS_PTO_LEVEL} -P ${PTOAS_CAPTURE_SCRIPT} DEPENDS ${PTO_SRC} ${PTOAS_CAPTURE_SCRIPT} COMMENT "ptoas: ${NAME}.pto -> ${NAME}_kernel.ll" @@ -116,6 +132,7 @@ set(ALL_TESTCASES tadd tcvt tload + softmax ) if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake b/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake index 8922fbab8..be7b98468 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake +++ b/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake @@ -13,15 +13,35 @@ endif() get_filename_component(KERNEL_LL_DIR "${KERNEL_LL}" DIRECTORY) file(MAKE_DIRECTORY "${KERNEL_LL_DIR}") +if(NOT DEFINED PTOAS_ENABLE_INSERT_SYNC) + set(PTOAS_ENABLE_INSERT_SYNC ON) +endif() + +set(PTOAS_COMMAND + "${PTOAS_BIN}" + --pto-arch=a5 +) + +if(DEFINED PTOAS_PTO_LEVEL AND NOT PTOAS_PTO_LEVEL STREQUAL "") + list(APPEND PTOAS_COMMAND "--pto-level=${PTOAS_PTO_LEVEL}") +endif() + +list(APPEND PTOAS_COMMAND --pto-backend=vpto) + +if(PTOAS_ENABLE_INSERT_SYNC) + list(APPEND PTOAS_COMMAND --enable-insert-sync) +endif() + +list(APPEND PTOAS_COMMAND + --enable-tile-op-expand + --vpto-emit-hivm-llvm + "${PTO_SRC}" + -o + - +) + execute_process( - COMMAND "${PTOAS_BIN}" - --pto-arch=a5 - --pto-backend=vpto - --enable-insert-sync - --enable-tile-op-expand - --vpto-emit-hivm-llvm - "${PTO_SRC}" - -o - + COMMAND ${PTOAS_COMMAND} OUTPUT_FILE "${KERNEL_LL}" ERROR_VARIABLE PTOAS_STDERR RESULT_VARIABLE PTOAS_RESULT diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/softmax/CMakeLists.txt new file mode 100644 index 000000000..3c5224444 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(softmax DISABLE_INSERT_SYNC PTO_LEVEL level3) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/softmax/cases.py new file mode 100644 index 000000000..8b865c96a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/cases.py @@ -0,0 +1,25 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np + + +CASES = [ + { + "name": "f32_rows24_seq73", + "dtype": np.float32, + "shape": (24, 128), + "valid_shape": (24, 73), + "eps": 1e-4, + "rows": 24, + "cols": 128, + "seq": 73, + "seed": 19, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/softmax/compare.py new file mode 100644 index 000000000..6a5c89eb8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/compare.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def load_array(path, dtype, shape): + if not os.path.exists(path): + raise FileNotFoundError(path) + return np.fromfile(path, dtype=dtype).reshape(shape) + + +def compare_case(case): + case_dir = case["name"] + rows = int(case["rows"]) + cols = int(case["cols"]) + seq = int(case["seq"]) + dtype = case["dtype"] + eps = case["eps"] + + try: + golden_v4 = load_array(os.path.join(case_dir, "golden_v4.bin"), dtype, (rows,)) + output_v4 = load_array(os.path.join(case_dir, "v4.bin"), dtype, (rows,)) + golden_v5 = load_array(os.path.join(case_dir, "golden_v5.bin"), dtype, (rows,)) + output_v5 = load_array(os.path.join(case_dir, "v5.bin"), dtype, (rows,)) + golden_v6 = load_array(os.path.join(case_dir, "golden_v6.bin"), dtype, (rows,)) + output_v6 = load_array(os.path.join(case_dir, "v6.bin"), dtype, (rows,)) + golden_v7 = load_array( + os.path.join(case_dir, "golden_v7.bin"), dtype, (rows, cols) + ) + output_v7 = load_array(os.path.join(case_dir, "v7.bin"), dtype, (rows, cols)) + except FileNotFoundError as exc: + print(style_fail(f"[ERROR] {case['name']}: missing file {exc}")) + return False + + ok = True + ok = result_cmp(golden_v4, output_v4, eps) and ok + ok = result_cmp(golden_v5, output_v5, eps) and ok + ok = result_cmp(golden_v6, output_v6, eps) and ok + ok = result_cmp(golden_v7[:, :seq], output_v7[:, :seq], eps) and ok + return ok + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + matched_case = case_filter is None + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + matched_case = True + ok = compare_case(case) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not matched_case: + print(style_fail(f"[ERROR] unknown case filter: {case_filter}")) + sys.exit(2) + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/softmax/gen_data.py new file mode 100644 index 000000000..05bcef759 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/gen_data.py @@ -0,0 +1,64 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np + +from cases import CASES +from st_common import save_case_data, validate_cases + + +validate_cases(CASES) + +for case in CASES: + rows = int(case["rows"]) + cols = int(case["cols"]) + seq = int(case["seq"]) + seed = int(case["seed"]) + + rng = np.random.default_rng(seed) + oldmax = rng.uniform(-3.0, 1.5, size=(rows,)).astype(np.float32) + oldsum = rng.uniform(0.5, 4.0, size=(rows,)).astype(np.float32) + qk = rng.normal(loc=0.0, scale=1.5, size=(rows, cols)).astype(np.float32) + + qk_active = qk[:, :seq] + qk_rowmax = np.max(qk_active, axis=1) + newmax = np.maximum(qk_rowmax, oldmax) + tmp_active = np.exp(qk_active - newmax[:, None], dtype=np.float32) + cursum = np.sum(tmp_active, axis=1, dtype=np.float32) + raw_expmax = np.exp(oldmax - newmax, dtype=np.float32) + newsum = raw_expmax * oldsum + cursum + expmax = (raw_expmax * oldsum) / newsum + out = np.zeros((rows, cols), dtype=np.float32) + out[:, :seq] = tmp_active / newsum[:, None] + + zeros_state = np.zeros((rows,), dtype=np.float32) + zeros_out = np.zeros((rows, cols), dtype=np.float32) + + save_case_data( + case["name"], + { + "v1": oldmax, + "v2": oldsum, + "v3": qk.reshape(-1), + "v4": zeros_state, + "v5": zeros_state, + "v6": zeros_state, + "v7": zeros_out.reshape(-1), + "v8": np.array([seq], dtype=np.int32), + "v9": np.array([rows], dtype=np.int32), + "golden_v4": newmax, + "golden_v5": newsum, + "golden_v6": expmax, + "golden_v7": out.reshape(-1), + }, + ) + print( + f"[INFO] gen_data: {case['name']} rows={rows} cols={cols} " + f"seq={seq} dtype={case['dtype'].__name__}" + ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/softmax/launch.cpp new file mode 100644 index 000000000..dd702e189 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void online_softmax_update_kernel_2d(__gm__ float *v1, __gm__ float *v2, __gm__ float *v3, __gm__ float *v4, __gm__ float *v5, __gm__ float *v6, __gm__ float *v7, int32_t v8, int32_t v9); + +void LaunchSOFTMAX_f32_rows24_seq73(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream) { + const int32_t blockRows = 8; + const int32_t blocks = (v9 + blockRows - 1) / blockRows; + online_softmax_update_kernel_2d<<>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ float *)v3, + (__gm__ float *)v4, (__gm__ float *)v5, (__gm__ float *)v6, + (__gm__ float *)v7, v8, v9); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/softmax/main.cpp new file mode 100644 index 000000000..4018ecf57 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/main.cpp @@ -0,0 +1,197 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +namespace pto { +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +} // namespace pto +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchSOFTMAX_f32_rows24_seq73(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream); + +using LaunchFn = void (*)(float *, float *, float *, float *, float *, float *, + float *, int32_t, int32_t, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; + size_t cols; +}; + +static const TestCase kCases[] = { + {"f32_rows24_seq73", LaunchSOFTMAX_f32_rows24_seq73, 24, 128}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, aclrtStream stream) { + const size_t scalarBytes = sizeof(int32_t); + const size_t stateElems = tc.rows; + const size_t outElems = tc.rows * tc.cols; + const size_t stateBytes = stateElems * sizeof(float); + const size_t outBytes = outElems * sizeof(float); + std::string caseDir = std::string("./") + tc.name; + + float *v1Host = nullptr, *v2Host = nullptr, *v3Host = nullptr; + float *v4Host = nullptr, *v5Host = nullptr, *v6Host = nullptr, *v7Host = nullptr; + float *v1Device = nullptr, *v2Device = nullptr, *v3Device = nullptr; + float *v4Device = nullptr, *v5Device = nullptr, *v6Device = nullptr, *v7Device = nullptr; + int32_t seqHost = 0; + int32_t rowsHost = 0; + size_t fileSize = 0; + int rc = 0; + + std::printf("[INFO] === case: %s (rows=%zu, cols=%zu) ===\n", + tc.name, tc.rows, tc.cols); + + if (!ReadFile(caseDir + "/v8.bin", fileSize, &seqHost, scalarBytes) || + !ReadFile(caseDir + "/v9.bin", fileSize, &rowsHost, scalarBytes)) { + std::fprintf(stderr, "[ERROR] failed to read scalar inputs for %s\n", tc.name); + return 1; + } + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v5Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v6Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v7Host), outBytes)); + + ACL_CHECK(aclrtMalloc((void **)&v1Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v5Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v6Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v7Device, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + if (!ReadFile(caseDir + "/v1.bin", fileSize, v1Host, stateBytes) || + !ReadFile(caseDir + "/v2.bin", fileSize, v2Host, stateBytes) || + !ReadFile(caseDir + "/v3.bin", fileSize, v3Host, outBytes) || + !ReadFile(caseDir + "/v4.bin", fileSize, v4Host, stateBytes) || + !ReadFile(caseDir + "/v5.bin", fileSize, v5Host, stateBytes) || + !ReadFile(caseDir + "/v6.bin", fileSize, v6Host, stateBytes) || + !ReadFile(caseDir + "/v7.bin", fileSize, v7Host, outBytes)) { + std::fprintf(stderr, "[ERROR] failed to read tensor inputs for %s\n", tc.name); + rc = 1; + goto cleanup; + } + + ACL_CHECK(aclrtMemcpy(v1Device, stateBytes, v1Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, stateBytes, v2Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, outBytes, v3Host, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, stateBytes, v4Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v5Device, stateBytes, v5Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v6Device, stateBytes, v6Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v7Device, outBytes, v7Host, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + + tc.launch(v1Device, v2Device, v3Device, v4Device, v5Device, v6Device, + v7Device, seqHost, rowsHost, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v4Host, stateBytes, v4Device, stateBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v5Host, stateBytes, v5Device, stateBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v6Host, stateBytes, v6Device, stateBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v7Host, outBytes, v7Device, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + + if (!WriteFile(caseDir + "/v4.bin", v4Host, stateBytes) || + !WriteFile(caseDir + "/v5.bin", v5Host, stateBytes) || + !WriteFile(caseDir + "/v6.bin", v6Host, stateBytes) || + !WriteFile(caseDir + "/v7.bin", v7Host, outBytes)) { + std::fprintf(stderr, "[ERROR] failed to write outputs for %s\n", tc.name); + rc = 1; + } + +cleanup: + if (v1Device != nullptr) aclrtFree(v1Device); + if (v2Device != nullptr) aclrtFree(v2Device); + if (v3Device != nullptr) aclrtFree(v3Device); + if (v4Device != nullptr) aclrtFree(v4Device); + if (v5Device != nullptr) aclrtFree(v5Device); + if (v6Device != nullptr) aclrtFree(v6Device); + if (v7Device != nullptr) aclrtFree(v7Device); + if (v1Host != nullptr) aclrtFreeHost(v1Host); + if (v2Host != nullptr) aclrtFreeHost(v2Host); + if (v3Host != nullptr) aclrtFreeHost(v3Host); + if (v4Host != nullptr) aclrtFreeHost(v4Host); + if (v5Host != nullptr) aclrtFreeHost(v5Host); + if (v6Host != nullptr) aclrtFreeHost(v6Host); + if (v7Host != nullptr) aclrtFreeHost(v7Host); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + bool matchedCase = (caseFilter == nullptr); + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) + continue; + matchedCase = true; + if (RunCase(kCases[i], stream) != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (!matchedCase) { + std::fprintf(stderr, "[ERROR] unknown case filter: %s\n", caseFilter); + rc = 1; + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto new file mode 100644 index 000000000..bba0b1dd7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto @@ -0,0 +1,247 @@ +// TileLang ST kernel for online softmax update with mixed pto.tload/pto.tstore +// and raw VPTO vecscope compute. +// This testcase keeps manual sync in the source, so ST compilation disables +// --enable-insert-sync and enables --pto-level=level3 for alloc_tile addr=. + +module attributes {pto.target_arch = "a5"} { + func.func @online_softmax_update_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr, + %arg4: !pto.ptr, + %arg5: !pto.ptr, + %arg6: !pto.ptr, + %arg7: i32, + %arg8: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c8448_i64 = arith.constant 8448 : i64 + %c16640_i64 = arith.constant 16640 : i64 + %c16768_i64 = arith.constant 16768 : i64 + %c16896_i64 = arith.constant 16896 : i64 + + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + + %block = pto.get_block_idx + %block_idx = arith.index_cast %block : i64 to index + %row_base = arith.muli %block_idx, %c8 : index + %block_rows_i32 = arith.index_cast %c8 : index to i32 + %row_base_i32 = arith.index_cast %row_base : index to i32 + %remaining_rows = arith.subi %arg8, %row_base_i32 : i32 + %has_rows = arith.cmpi sgt, %remaining_rows, %c0_i32 : i32 + %too_many_rows = arith.cmpi sgt, %remaining_rows, %c8_i32 : i32 + %row_count_i32 = arith.select %too_many_rows, %c8_i32, %remaining_rows : i32 + %row_count = arith.index_cast %row_count_i32 : i32 to index + %seq = arith.index_cast %arg7 : i32 to index + %rows = arith.index_cast %arg8 : i32 to index + %rows_x_128 = arith.muli %rows, %c128 : index + + scf.if %has_rows { + %oldmax_view = pto.make_tensor_view %arg0, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %oldsum_view = pto.make_tensor_view %arg1, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %qk_view = pto.make_tensor_view %arg2, + shape = [%c1, %c1, %c1, %rows, %c128], + strides = [%rows_x_128, %rows_x_128, %rows_x_128, %c128, %c1] + : !pto.tensor_view + %newmax_view = pto.make_tensor_view %arg3, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %newsum_view = pto.make_tensor_view %arg4, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %expmax_view = pto.make_tensor_view %arg5, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %out_view = pto.make_tensor_view %arg6, + shape = [%c1, %c1, %c1, %rows, %c128], + strides = [%rows_x_128, %rows_x_128, %rows_x_128, %c128, %c1] + : !pto.tensor_view + + %oldmax_part = pto.partition_view %oldmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %oldsum_part = pto.partition_view %oldsum_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %qk_part = pto.partition_view %qk_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %seq] + : !pto.tensor_view -> !pto.partition_tensor_view + %newmax_part = pto.partition_view %newmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %newsum_part = pto.partition_view %newsum_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %expmax_part = pto.partition_view %expmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %out_part = pto.partition_view %out_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %seq] + : !pto.tensor_view -> !pto.partition_tensor_view + + %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + %oldsum_tile = pto.alloc_tile addr = %c128_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + %qk_tile = pto.alloc_tile addr = %c256_i64 valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %out_tile = pto.alloc_tile addr = %c8448_i64 valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %newmax_tile = pto.alloc_tile addr = %c16640_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + %newsum_tile = pto.alloc_tile addr = %c16768_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + %expmax_tile = pto.alloc_tile addr = %c16896_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + + %ub_oldmax = pto.tile_buf_addr %oldmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_oldsum = pto.tile_buf_addr %oldsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_qk = pto.tile_buf_addr %qk_tile + : !pto.tile_buf + -> !pto.ptr + %ub_out = pto.tile_buf_addr %out_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newmax = pto.tile_buf_addr %newmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newsum = pto.tile_buf_addr %newsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_expmax = pto.tile_buf_addr %expmax_tile + : !pto.tile_buf + -> !pto.ptr + + pto.tload ins(%oldmax_part : !pto.partition_tensor_view) + outs(%oldmax_tile : !pto.tile_buf) + pto.tload ins(%oldsum_part : !pto.partition_tensor_view) + outs(%oldsum_tile : !pto.tile_buf) + pto.tload ins(%qk_part : !pto.partition_tensor_view) + outs(%qk_tile : !pto.tile_buf) + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + scf.for %row = %c0 to %row_count step %c1 { + %row_qk = arith.muli %row, %c128 : index + %oldmax_bc = pto.vlds %ub_oldmax[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + %oldsum_bc = pto.vlds %ub_oldsum[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + + %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + %next_max, %next_sum = scf.if %has_chunk -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdif %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdif %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.yield %merged_max, %merged_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + scf.yield %running_max, %running_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + scf.for %chunk = %c0 to %c128 step %c64 { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + scf.if %has_chunk { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.tstore ins(%newmax_tile : !pto.tile_buf) + outs(%newmax_part : !pto.partition_tensor_view) + pto.tstore ins(%newsum_tile : !pto.tile_buf) + outs(%newsum_part : !pto.partition_tensor_view) + pto.tstore ins(%expmax_tile : !pto.tile_buf) + outs(%expmax_part : !pto.partition_tensor_view) + pto.tstore ins(%out_tile : !pto.tile_buf) + outs(%out_part : !pto.partition_tensor_view) + } + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/kernels/online-softmax-update/compare.py b/test/vpto/cases/kernels/online-softmax-update/compare.py index 40eba2276..75774a3a5 100644 --- a/test/vpto/cases/kernels/online-softmax-update/compare.py +++ b/test/vpto/cases/kernels/online-softmax-update/compare.py @@ -42,12 +42,57 @@ def compare_bin(golden_path, output_path, dtype, eps): return True +def compare_matrix_valid(golden_path, output_path, rows, cols, valid_cols, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + expected_elems = rows * cols + if golden.size != expected_elems or output.size != expected_elems: + print( + f"[ERROR] Shape mismatch: expected elems={expected_elems}, " + f"golden={golden.size}, out={output.size}" + ) + return False + golden = golden.reshape(rows, cols) + output = output.reshape(rows, cols) + if not np.allclose( + golden[:, :valid_cols], + output[:, :valid_cols], + atol=eps, + rtol=eps, + equal_nan=True, + ): + abs_diff = np.abs( + golden[:, :valid_cols].astype(np.float64) + - output[:, :valid_cols].astype(np.float64) + ) + flat_idx = int(np.argmax(abs_diff)) + row, col = divmod(flat_idx, valid_cols) + print( + f"[ERROR] Mismatch in valid region: max diff={float(abs_diff[row, col])} " + f"at row={row}, col={col} " + f"(golden={float(golden[row, col])}, out={float(output[row, col])}, dtype={dtype_np})" + ) + return False + return True + + def main(): + rows = int(np.fromfile("v9.bin", dtype=np.int32)[0]) + seq = int(np.fromfile("v8.bin", dtype=np.int32)[0]) ok = True ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 1e-4) and ok ok = compare_bin("golden_v5.bin", "v5.bin", np.float32, 1e-4) and ok ok = compare_bin("golden_v6.bin", "v6.bin", np.float32, 1e-4) and ok - ok = compare_bin("golden_v7.bin", "v7.bin", np.float32, 1e-4) and ok + ok = compare_matrix_valid( + "golden_v7.bin", "v7.bin", rows, 128, seq, np.float32, 1e-4 + ) and ok if not ok: print("[ERROR] compare failed") sys.exit(2) diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel.pto b/test/vpto/cases/kernels/online-softmax-update/kernel.pto index 07bd45279..4cfc35a38 100644 --- a/test/vpto/cases/kernels/online-softmax-update/kernel.pto +++ b/test/vpto/cases/kernels/online-softmax-update/kernel.pto @@ -124,12 +124,6 @@ module attributes {pto.target_arch = "a5"} { pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - %zero = pto.vsub %final_max, %final_max, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - scf.for %chunk = %c0 to %c128 step %c64 { - %chunk_base = arith.addi %row_qk, %chunk : index - pto.vsts %zero, %ub_out[%chunk_base], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - } - scf.for %chunk = %c0 to %c128 step %c64 { %chunk_i32 = arith.index_cast %chunk : index to i32 %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto b/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto new file mode 100644 index 000000000..88c161ff5 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto @@ -0,0 +1,249 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.tload, pto.tstore, pto.tile_buf_addr, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out, valid-region-only-output +// note: compile with --pto-level=level3 because alloc_tile uses explicit addr= +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @online_softmax_update_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr, + %arg4: !pto.ptr, + %arg5: !pto.ptr, + %arg6: !pto.ptr, + %arg7: i32, + %arg8: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c8448_i64 = arith.constant 8448 : i64 + %c16640_i64 = arith.constant 16640 : i64 + %c16768_i64 = arith.constant 16768 : i64 + %c16896_i64 = arith.constant 16896 : i64 + + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + + %block = pto.get_block_idx + %block_idx = arith.index_cast %block : i64 to index + %row_base = arith.muli %block_idx, %c8 : index + %block_rows_i32 = arith.index_cast %c8 : index to i32 + %row_base_i32 = arith.index_cast %row_base : index to i32 + %remaining_rows = arith.subi %arg8, %row_base_i32 : i32 + %has_rows = arith.cmpi sgt, %remaining_rows, %c0_i32 : i32 + %too_many_rows = arith.cmpi sgt, %remaining_rows, %c8_i32 : i32 + %row_count_i32 = arith.select %too_many_rows, %c8_i32, %remaining_rows : i32 + %row_count = arith.index_cast %row_count_i32 : i32 to index + %seq = arith.index_cast %arg7 : i32 to index + %rows = arith.index_cast %arg8 : i32 to index + %rows_x_128 = arith.muli %rows, %c128 : index + + scf.if %has_rows { + %oldmax_view = pto.make_tensor_view %arg0, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %oldsum_view = pto.make_tensor_view %arg1, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %qk_view = pto.make_tensor_view %arg2, + shape = [%c1, %c1, %c1, %rows, %c128], + strides = [%rows_x_128, %rows_x_128, %rows_x_128, %c128, %c1] + : !pto.tensor_view + %newmax_view = pto.make_tensor_view %arg3, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %newsum_view = pto.make_tensor_view %arg4, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %expmax_view = pto.make_tensor_view %arg5, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %out_view = pto.make_tensor_view %arg6, + shape = [%c1, %c1, %c1, %rows, %c128], + strides = [%rows_x_128, %rows_x_128, %rows_x_128, %c128, %c1] + : !pto.tensor_view + + %oldmax_part = pto.partition_view %oldmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %oldsum_part = pto.partition_view %oldsum_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %qk_part = pto.partition_view %qk_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %seq] + : !pto.tensor_view -> !pto.partition_tensor_view + %newmax_part = pto.partition_view %newmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %newsum_part = pto.partition_view %newsum_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %expmax_part = pto.partition_view %expmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %out_part = pto.partition_view %out_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %seq] + : !pto.tensor_view -> !pto.partition_tensor_view + + %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + %oldsum_tile = pto.alloc_tile addr = %c128_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + %qk_tile = pto.alloc_tile addr = %c256_i64 valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %out_tile = pto.alloc_tile addr = %c8448_i64 valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %newmax_tile = pto.alloc_tile addr = %c16640_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + %newsum_tile = pto.alloc_tile addr = %c16768_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + %expmax_tile = pto.alloc_tile addr = %c16896_i64 valid_row = %row_count valid_col = %c1 + : !pto.tile_buf + + %ub_oldmax = pto.tile_buf_addr %oldmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_oldsum = pto.tile_buf_addr %oldsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_qk = pto.tile_buf_addr %qk_tile + : !pto.tile_buf + -> !pto.ptr + %ub_out = pto.tile_buf_addr %out_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newmax = pto.tile_buf_addr %newmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newsum = pto.tile_buf_addr %newsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_expmax = pto.tile_buf_addr %expmax_tile + : !pto.tile_buf + -> !pto.ptr + + pto.tload ins(%oldmax_part : !pto.partition_tensor_view) + outs(%oldmax_tile : !pto.tile_buf) + pto.tload ins(%oldsum_part : !pto.partition_tensor_view) + outs(%oldsum_tile : !pto.tile_buf) + pto.tload ins(%qk_part : !pto.partition_tensor_view) + outs(%qk_tile : !pto.tile_buf) + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + scf.for %row = %c0 to %row_count step %c1 { + %row_qk = arith.muli %row, %c128 : index + %oldmax_bc = pto.vlds %ub_oldmax[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + %oldsum_bc = pto.vlds %ub_oldsum[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + + %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + %next_max, %next_sum = scf.if %has_chunk -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdif %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdif %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.yield %merged_max, %merged_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + scf.yield %running_max, %running_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + scf.for %chunk = %c0 to %c128 step %c64 { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + scf.if %has_chunk { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.tstore ins(%newmax_tile : !pto.tile_buf) + outs(%newmax_part : !pto.partition_tensor_view) + pto.tstore ins(%newsum_tile : !pto.tile_buf) + outs(%newsum_part : !pto.partition_tensor_view) + pto.tstore ins(%expmax_tile : !pto.tile_buf) + outs(%expmax_part : !pto.partition_tensor_view) + pto.tstore ins(%out_tile : !pto.tile_buf) + outs(%out_part : !pto.partition_tensor_view) + } + pto.barrier #pto.pipe + return + } +} diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index e8dd5f8a0..e30be67ed 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -121,6 +121,26 @@ def _match_descriptor( return None +def _parse_optional_int_sequence( + values: list[object], + *, + field_name: str, + index: int, +) -> tuple[int | None, ...]: + parsed: list[int | None] = [] + for dim in values: + if dim is None: + parsed.append(None) + continue + try: + parsed.append(int(dim)) + except (TypeError, ValueError) as exc: + raise ValueError( + f"operand-specs[{index}] {field_name} entries must be integers or null" + ) from exc + return tuple(parsed) + + def _parse_operand_specs(spec_text: str) -> list[dict]: try: raw_specs = json.loads(spec_text) @@ -170,7 +190,13 @@ def _parse_operand_specs(spec_text: str) -> list[dict]: "kind": "tile", "dtype": dtype, "shape": tuple(int(dim) for dim in shape), - "valid_shape": None if valid_shape is None else tuple(int(dim) for dim in valid_shape), + "valid_shape": None + if valid_shape is None + else _parse_optional_int_sequence( + valid_shape, + field_name="tile valid_shape", + index=index, + ), "config": config, "memory_space": memory_space, } @@ -188,7 +214,11 @@ def _parse_operand_specs(spec_text: str) -> list[dict]: view_spec: dict = { "kind": "view", "dtype": dtype, - "shape": tuple(int(dim) for dim in shape), + "shape": _parse_optional_int_sequence( + shape, + field_name="view shape", + index=index, + ), "memory_space": memory_space, } raw_strides = raw.get("strides") diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index e124a7154..422b550c0 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -464,6 +464,17 @@ static bool containsVPTOIR(llvm::StringRef input) { return false; } +static bool hasUnexpandedTileOps(ModuleOp module) { + bool found = false; + module.walk([&](Operation *op) { + if (found) + return; + if (isa(op)) + found = true; + }); + return found; +} + // -------------------------------------------------------------------------- // Post-process C++ output: rewrite marker calls into Tile member calls. // @@ -1467,7 +1478,10 @@ int main(int argc, char **argv) { return 1; } - if (effectiveBackend == PTOBackend::VPTO && inputIsVPTOIR) { + const bool hasTileOpsToExpand = hasUnexpandedTileOps(*module); + + if (effectiveBackend == PTOBackend::VPTO && inputIsVPTOIR && + !hasTileOpsToExpand) { if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { llvm::errs() << "Error: shared pre-backend seam IR is unavailable when " "the input is already VPTO IR.\n"; From a2fafdd0407a7539e1b3a2cc72fa055794be01f3 Mon Sep 17 00:00:00 2001 From: qukelin Date: Fri, 24 Apr 2026 10:06:10 +0800 Subject: [PATCH 151/192] Refine mixed Tile/VPTO softmax kernels and CI fixes --- .../a5/src/st/testcase/softmax/softmax.pto | 101 ++++++++---------- .../kernel_tload_tstore.pto | 100 ++++++++--------- .../python/tilelang_dsl/expand_helper.py | 8 ++ 3 files changed, 99 insertions(+), 110 deletions(-) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto index bba0b1dd7..1a90261c0 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto @@ -111,68 +111,57 @@ module attributes {pto.target_arch = "a5"} { sizes = [%c1, %c1, %c1, %row_count, %seq] : !pto.tensor_view -> !pto.partition_tensor_view - %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf - %oldsum_tile = pto.alloc_tile addr = %c128_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf + // Tile domain: alloc_tile creates UB tile handles; tload/tstore operate + // on tile_buf values before/after the vector scope compute region. + %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count + : !pto.tile_buf + %oldsum_tile = pto.alloc_tile addr = %c128_i64 valid_row = %row_count + : !pto.tile_buf %qk_tile = pto.alloc_tile addr = %c256_i64 valid_row = %row_count valid_col = %seq - : !pto.tile_buf + : !pto.tile_buf %out_tile = pto.alloc_tile addr = %c8448_i64 valid_row = %row_count valid_col = %seq - : !pto.tile_buf - %newmax_tile = pto.alloc_tile addr = %c16640_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf - %newsum_tile = pto.alloc_tile addr = %c16768_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf - %expmax_tile = pto.alloc_tile addr = %c16896_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf - - %ub_oldmax = pto.tile_buf_addr %oldmax_tile - : !pto.tile_buf - -> !pto.ptr - %ub_oldsum = pto.tile_buf_addr %oldsum_tile - : !pto.tile_buf - -> !pto.ptr - %ub_qk = pto.tile_buf_addr %qk_tile - : !pto.tile_buf - -> !pto.ptr - %ub_out = pto.tile_buf_addr %out_tile - : !pto.tile_buf - -> !pto.ptr - %ub_newmax = pto.tile_buf_addr %newmax_tile - : !pto.tile_buf - -> !pto.ptr - %ub_newsum = pto.tile_buf_addr %newsum_tile - : !pto.tile_buf - -> !pto.ptr - %ub_expmax = pto.tile_buf_addr %expmax_tile - : !pto.tile_buf - -> !pto.ptr + : !pto.tile_buf + %newmax_tile = pto.alloc_tile addr = %c16640_i64 valid_row = %row_count + : !pto.tile_buf + %newsum_tile = pto.alloc_tile addr = %c16768_i64 valid_row = %row_count + : !pto.tile_buf + %expmax_tile = pto.alloc_tile addr = %c16896_i64 valid_row = %row_count + : !pto.tile_buf pto.tload ins(%oldmax_part : !pto.partition_tensor_view) - outs(%oldmax_tile : !pto.tile_buf) + outs(%oldmax_tile : !pto.tile_buf) pto.tload ins(%oldsum_part : !pto.partition_tensor_view) - outs(%oldsum_tile : !pto.tile_buf) + outs(%oldsum_tile : !pto.tile_buf) pto.tload ins(%qk_part : !pto.partition_tensor_view) - outs(%qk_tile : !pto.tile_buf) + outs(%qk_tile : !pto.tile_buf) pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.vecscope { + // Boundary into vecscope instructions: tile_buf_addr materializes UB + // pointers from tile handles so vecscope can use vlds/vsts. + %ub_oldmax = pto.tile_buf_addr %oldmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_oldsum = pto.tile_buf_addr %oldsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_qk = pto.tile_buf_addr %qk_tile + : !pto.tile_buf + -> !pto.ptr + %ub_out = pto.tile_buf_addr %out_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newmax = pto.tile_buf_addr %newmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newsum = pto.tile_buf_addr %newsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_expmax = pto.tile_buf_addr %expmax_tile + : !pto.tile_buf + -> !pto.ptr %active = pto.pset_b32 "PAT_ALL" : !pto.mask %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 scf.for %row = %c0 to %row_count step %c1 { @@ -232,13 +221,15 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.tstore ins(%newmax_tile : !pto.tile_buf) + // Back in the tile domain: tstore writes the tile_buf results to GM + // partitions after the VPTO vecscope finishes. + pto.tstore ins(%newmax_tile : !pto.tile_buf) outs(%newmax_part : !pto.partition_tensor_view) - pto.tstore ins(%newsum_tile : !pto.tile_buf) + pto.tstore ins(%newsum_tile : !pto.tile_buf) outs(%newsum_part : !pto.partition_tensor_view) - pto.tstore ins(%expmax_tile : !pto.tile_buf) + pto.tstore ins(%expmax_tile : !pto.tile_buf) outs(%expmax_part : !pto.partition_tensor_view) - pto.tstore ins(%out_tile : !pto.tile_buf) + pto.tstore ins(%out_tile : !pto.tile_buf) outs(%out_part : !pto.partition_tensor_view) } pto.barrier #pto.pipe diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto b/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto index 88c161ff5..cadd1b618 100644 --- a/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto +++ b/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto @@ -113,68 +113,57 @@ module attributes {pto.target_arch = "a5"} { sizes = [%c1, %c1, %c1, %row_count, %seq] : !pto.tensor_view -> !pto.partition_tensor_view - %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf - %oldsum_tile = pto.alloc_tile addr = %c128_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf + // Tile domain: alloc_tile creates UB tile handles; tload/tstore operate + // on tile_buf values before/after the vector scope compute region. + %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count + : !pto.tile_buf + %oldsum_tile = pto.alloc_tile addr = %c128_i64 valid_row = %row_count + : !pto.tile_buf %qk_tile = pto.alloc_tile addr = %c256_i64 valid_row = %row_count valid_col = %seq - : !pto.tile_buf + : !pto.tile_buf %out_tile = pto.alloc_tile addr = %c8448_i64 valid_row = %row_count valid_col = %seq - : !pto.tile_buf - %newmax_tile = pto.alloc_tile addr = %c16640_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf - %newsum_tile = pto.alloc_tile addr = %c16768_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf - %expmax_tile = pto.alloc_tile addr = %c16896_i64 valid_row = %row_count valid_col = %c1 - : !pto.tile_buf - - %ub_oldmax = pto.tile_buf_addr %oldmax_tile - : !pto.tile_buf - -> !pto.ptr - %ub_oldsum = pto.tile_buf_addr %oldsum_tile - : !pto.tile_buf - -> !pto.ptr - %ub_qk = pto.tile_buf_addr %qk_tile - : !pto.tile_buf - -> !pto.ptr - %ub_out = pto.tile_buf_addr %out_tile - : !pto.tile_buf - -> !pto.ptr - %ub_newmax = pto.tile_buf_addr %newmax_tile - : !pto.tile_buf - -> !pto.ptr - %ub_newsum = pto.tile_buf_addr %newsum_tile - : !pto.tile_buf - -> !pto.ptr - %ub_expmax = pto.tile_buf_addr %expmax_tile - : !pto.tile_buf - -> !pto.ptr + : !pto.tile_buf + %newmax_tile = pto.alloc_tile addr = %c16640_i64 valid_row = %row_count + : !pto.tile_buf + %newsum_tile = pto.alloc_tile addr = %c16768_i64 valid_row = %row_count + : !pto.tile_buf + %expmax_tile = pto.alloc_tile addr = %c16896_i64 valid_row = %row_count + : !pto.tile_buf pto.tload ins(%oldmax_part : !pto.partition_tensor_view) - outs(%oldmax_tile : !pto.tile_buf) + outs(%oldmax_tile : !pto.tile_buf) pto.tload ins(%oldsum_part : !pto.partition_tensor_view) - outs(%oldsum_tile : !pto.tile_buf) + outs(%oldsum_tile : !pto.tile_buf) pto.tload ins(%qk_part : !pto.partition_tensor_view) - outs(%qk_tile : !pto.tile_buf) + outs(%qk_tile : !pto.tile_buf) pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.vecscope { + // Boundary into vecscope instructions: tile_buf_addr materializes UB + // pointers from tile handles so vecscope can use vlds/vsts. + %ub_oldmax = pto.tile_buf_addr %oldmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_oldsum = pto.tile_buf_addr %oldsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_qk = pto.tile_buf_addr %qk_tile + : !pto.tile_buf + -> !pto.ptr + %ub_out = pto.tile_buf_addr %out_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newmax = pto.tile_buf_addr %newmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newsum = pto.tile_buf_addr %newsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_expmax = pto.tile_buf_addr %expmax_tile + : !pto.tile_buf + -> !pto.ptr %active = pto.pset_b32 "PAT_ALL" : !pto.mask %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 scf.for %row = %c0 to %row_count step %c1 { @@ -234,13 +223,14 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.tstore ins(%newmax_tile : !pto.tile_buf) + // Leave the VPTO vecscope and switch back to tile-domain DMA ops. + pto.tstore ins(%newmax_tile : !pto.tile_buf) outs(%newmax_part : !pto.partition_tensor_view) - pto.tstore ins(%newsum_tile : !pto.tile_buf) + pto.tstore ins(%newsum_tile : !pto.tile_buf) outs(%newsum_part : !pto.partition_tensor_view) - pto.tstore ins(%expmax_tile : !pto.tile_buf) + pto.tstore ins(%expmax_tile : !pto.tile_buf) outs(%expmax_part : !pto.partition_tensor_view) - pto.tstore ins(%out_tile : !pto.tile_buf) + pto.tstore ins(%out_tile : !pto.tile_buf) outs(%out_part : !pto.partition_tensor_view) } pto.barrier #pto.pipe diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index e30be67ed..4d2a875fd 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """CLI helper invoked by ExpandTileOp to instantiate a tilelang DSL template. Usage: From bc85feaf704c83a241ecaced1b4eb78ecc394199 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 23 Apr 2026 20:44:11 +0800 Subject: [PATCH 152/192] fix(pto): avoid false A5 trowarg vec overflow (#558) --- .../issue558_trowargmax_vec_overflow_a5.pto | 68 +++++++++++++++++++ .../issue558_trowargmin_vec_overflow_a5.pto | 68 +++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 test/basic/issue558_trowargmax_vec_overflow_a5.pto create mode 100644 test/basic/issue558_trowargmin_vec_overflow_a5.pto diff --git a/test/basic/issue558_trowargmax_vec_overflow_a5.pto b/test/basic/issue558_trowargmax_vec_overflow_a5.pto new file mode 100644 index 000000000..0cef1576c --- /dev/null +++ b/test/basic/issue558_trowargmax_vec_overflow_a5.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +// Regression for issue #558: +// A5 TROWARGMAX should not treat tmp as a required scratch write in local +// memory planning, otherwise this shape triggers a false vec overflow. +// CHECK-NOT: vec overflow +// CHECK: __global__ AICORE void TROWARGMAX_TMP_OVERFLOW( + +module { + func.func @TROWARGMAX_TMP_OVERFLOW(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16384_c = arith.constant 16384 : index + %c32768_se = arith.constant 32768 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16384_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x2x16384xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x2x16384xf32> -> !pto.partition_tensor_view<1x1x1x2x16381xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16381xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } +} diff --git a/test/basic/issue558_trowargmin_vec_overflow_a5.pto b/test/basic/issue558_trowargmin_vec_overflow_a5.pto new file mode 100644 index 000000000..47a190d2a --- /dev/null +++ b/test/basic/issue558_trowargmin_vec_overflow_a5.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +// Regression for issue #558: +// A5 TROWARGMIN should not treat tmp as a required scratch write in local +// memory planning, otherwise this shape triggers a false vec overflow. +// CHECK-NOT: vec overflow +// CHECK: __global__ AICORE void TROWARGMIN_TMP_OVERFLOW( + +module { + func.func @TROWARGMIN_TMP_OVERFLOW(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16384_c = arith.constant 16384 : index + %c32768_se = arith.constant 32768 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16384_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x2x16384xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x2x16384xf32> -> !pto.partition_tensor_view<1x1x1x2x16381xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16381xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } +} From 69e6dcd6732c265bbdea97663dac39c932845c11 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 23 Apr 2026 20:53:59 +0800 Subject: [PATCH 153/192] Fix tcolarg verifier element width constraints --- lib/PTO/IR/PTO.cpp | 1 - test/basic/issue554_tcolarg_invalid_width.pto | 16 +++++++++++++ test/basic/issue554_tcolarg_types.pto | 24 +++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 test/basic/issue554_tcolarg_invalid_width.pto create mode 100644 test/basic/issue554_tcolarg_types.pto diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 1e258f009..88a454231 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -3010,7 +3010,6 @@ static LogicalResult verifyVecTileStorage(Operation *op, Type ty, StringRef name return op->emitOpError() << "expects " << name << " to be in the vec address space"; return success(); } - static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, StringRef name) { if (failed(verifyTileBufCommon(op, ty, name))) diff --git a/test/basic/issue554_tcolarg_invalid_width.pto b/test/basic/issue554_tcolarg_invalid_width.pto new file mode 100644 index 000000000..f25559efe --- /dev/null +++ b/test/basic/issue554_tcolarg_invalid_width.pto @@ -0,0 +1,16 @@ +// RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s + +module { + func.func @issue554_tcolargmax_i64_rejected() { + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} + +// CHECK: error: 'pto.tcolargmax' op expects src/tmp element type to be 1, 2, or 4 bytes wide diff --git a/test/basic/issue554_tcolarg_types.pto b/test/basic/issue554_tcolarg_types.pto new file mode 100644 index 000000000..ca7f822da --- /dev/null +++ b/test/basic/issue554_tcolarg_types.pto @@ -0,0 +1,24 @@ +// RUN: ptoas --pto-arch=a3 %s >/dev/null +// RUN: ptoas --pto-arch=a5 %s >/dev/null + +module { + func.func @issue554_tcolargmax_ui16() { + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } + + func.func @issue554_tcolargmin_ui32() { + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} From d05c68f5abde91b012226588e1f1cbbbd4cdaad2 Mon Sep 17 00:00:00 2001 From: liggest <43201720+liggest@users.noreply.github.com> Date: Fri, 24 Apr 2026 16:33:22 +0800 Subject: [PATCH 154/192] feat(tileop): Add unary tileop templates (#168) * Add TABS tileop template * Add TEXP tileop template * Add TLOG tileop template * Add TNEG tileop template * Add TNOT tileop template * Add TRECIP tileop template * Add TRSQRT tileop template * Add TSQRT tileop template * Add TODOs for HIGH_PRECISION type * Specify supported dtypes for TRECIP & TRSQRT * Add license headers for uop tileop templates --- lib/TileOps/tabs_template.py | 29 ++ lib/TileOps/texp_template.py | 29 ++ lib/TileOps/tlog_template.py | 29 ++ lib/TileOps/tneg_template.py | 29 ++ lib/TileOps/tnot_template.py | 30 ++ lib/TileOps/trecip_template.py | 36 +++ lib/TileOps/trsqrt_template.py | 36 +++ lib/TileOps/tsqrt_template.py | 29 ++ test/basic/expand_tile_op_tilelang_tabs.pto | 40 +++ test/basic/expand_tile_op_tilelang_texp.pto | 40 +++ test/basic/expand_tile_op_tilelang_tlog.pto | 40 +++ test/basic/expand_tile_op_tilelang_tneg.pto | 40 +++ test/basic/expand_tile_op_tilelang_tnot.pto | 40 +++ test/basic/expand_tile_op_tilelang_trecip.pto | 40 +++ test/basic/expand_tile_op_tilelang_trsqrt.pto | 41 +++ test/basic/expand_tile_op_tilelang_tsqrt.pto | 40 +++ .../npu/a5/src/st/testcase/CMakeLists.txt | 8 + .../a5/src/st/testcase/tabs/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tabs/cases.py | 55 ++++ .../npu/a5/src/st/testcase/tabs/compare.py | 49 +++ .../npu/a5/src/st/testcase/tabs/gen_data.py | 32 ++ .../npu/a5/src/st/testcase/tabs/launch.cpp | 41 +++ .../npu/a5/src/st/testcase/tabs/main.cpp | 137 ++++++++ .../npu/a5/src/st/testcase/tabs/tabs.pto | 204 ++++++++++++ .../a5/src/st/testcase/texp/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/texp/cases.py | 55 ++++ .../npu/a5/src/st/testcase/texp/compare.py | 49 +++ .../npu/a5/src/st/testcase/texp/gen_data.py | 32 ++ .../npu/a5/src/st/testcase/texp/launch.cpp | 41 +++ .../npu/a5/src/st/testcase/texp/main.cpp | 137 ++++++++ .../npu/a5/src/st/testcase/texp/texp.pto | 204 ++++++++++++ .../a5/src/st/testcase/tlog/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tlog/cases.py | 55 ++++ .../npu/a5/src/st/testcase/tlog/compare.py | 49 +++ .../npu/a5/src/st/testcase/tlog/gen_data.py | 33 ++ .../npu/a5/src/st/testcase/tlog/launch.cpp | 41 +++ .../npu/a5/src/st/testcase/tlog/main.cpp | 137 ++++++++ .../npu/a5/src/st/testcase/tlog/tlog.pto | 204 ++++++++++++ .../a5/src/st/testcase/tneg/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tneg/cases.py | 69 ++++ .../npu/a5/src/st/testcase/tneg/compare.py | 49 +++ .../npu/a5/src/st/testcase/tneg/gen_data.py | 33 ++ .../npu/a5/src/st/testcase/tneg/launch.cpp | 55 ++++ .../npu/a5/src/st/testcase/tneg/main.cpp | 141 ++++++++ .../npu/a5/src/st/testcase/tneg/tneg.pto | 299 +++++++++++++++++ .../a5/src/st/testcase/tnot/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tnot/cases.py | 69 ++++ .../npu/a5/src/st/testcase/tnot/compare.py | 49 +++ .../npu/a5/src/st/testcase/tnot/gen_data.py | 30 ++ .../npu/a5/src/st/testcase/tnot/launch.cpp | 55 ++++ .../npu/a5/src/st/testcase/tnot/main.cpp | 141 ++++++++ .../npu/a5/src/st/testcase/tnot/tnot.pto | 299 +++++++++++++++++ .../a5/src/st/testcase/trecip/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/trecip/cases.py | 69 ++++ .../npu/a5/src/st/testcase/trecip/compare.py | 49 +++ .../npu/a5/src/st/testcase/trecip/gen_data.py | 31 ++ .../npu/a5/src/st/testcase/trecip/launch.cpp | 55 ++++ .../npu/a5/src/st/testcase/trecip/main.cpp | 141 ++++++++ .../npu/a5/src/st/testcase/trecip/trecip.pto | 304 ++++++++++++++++++ .../a5/src/st/testcase/trsqrt/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/trsqrt/cases.py | 55 ++++ .../npu/a5/src/st/testcase/trsqrt/compare.py | 49 +++ .../npu/a5/src/st/testcase/trsqrt/gen_data.py | 34 ++ .../npu/a5/src/st/testcase/trsqrt/launch.cpp | 41 +++ .../npu/a5/src/st/testcase/trsqrt/main.cpp | 137 ++++++++ .../npu/a5/src/st/testcase/trsqrt/trsqrt.pto | 205 ++++++++++++ .../a5/src/st/testcase/tsqrt/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tsqrt/cases.py | 55 ++++ .../npu/a5/src/st/testcase/tsqrt/compare.py | 49 +++ .../npu/a5/src/st/testcase/tsqrt/gen_data.py | 33 ++ .../npu/a5/src/st/testcase/tsqrt/launch.cpp | 41 +++ .../npu/a5/src/st/testcase/tsqrt/main.cpp | 137 ++++++++ .../npu/a5/src/st/testcase/tsqrt/tsqrt.pto | 204 ++++++++++++ 73 files changed, 5181 insertions(+) create mode 100644 lib/TileOps/tabs_template.py create mode 100644 lib/TileOps/texp_template.py create mode 100644 lib/TileOps/tlog_template.py create mode 100644 lib/TileOps/tneg_template.py create mode 100644 lib/TileOps/tnot_template.py create mode 100644 lib/TileOps/trecip_template.py create mode 100644 lib/TileOps/trsqrt_template.py create mode 100644 lib/TileOps/tsqrt_template.py create mode 100644 test/basic/expand_tile_op_tilelang_tabs.pto create mode 100644 test/basic/expand_tile_op_tilelang_texp.pto create mode 100644 test/basic/expand_tile_op_tilelang_tlog.pto create mode 100644 test/basic/expand_tile_op_tilelang_tneg.pto create mode 100644 test/basic/expand_tile_op_tilelang_tnot.pto create mode 100644 test/basic/expand_tile_op_tilelang_trecip.pto create mode 100644 test/basic/expand_tile_op_tilelang_trsqrt.pto create mode 100644 test/basic/expand_tile_op_tilelang_tsqrt.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tabs/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tabs/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tabs/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tabs/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tabs/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tabs/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tabs/tabs.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texp/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texp/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texp/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texp/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texp/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texp/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texp/texp.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlog/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlog/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlog/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlog/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlog/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlog/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlog/tlog.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tneg/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tneg/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tneg/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tneg/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tneg/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tneg/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tneg/tneg.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tnot/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tnot/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tnot/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tnot/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tnot/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tnot/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tnot/tnot.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trecip/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trecip/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trecip/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trecip/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trecip/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trecip/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trecip/trecip.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trsqrt/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trsqrt/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trsqrt/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trsqrt/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trsqrt/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trsqrt/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trsqrt/trsqrt.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsqrt/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsqrt/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsqrt/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsqrt/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsqrt/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsqrt/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsqrt/tsqrt.pto diff --git a/lib/TileOps/tabs_template.py b/lib/TileOps/tabs_template.py new file mode 100644 index 000000000..6c6802ae5 --- /dev/null +++ b/lib/TileOps/tabs_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tabs""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tabs" +) +def template_tabs(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vabs(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/texp_template.py b/lib/TileOps/texp_template.py new file mode 100644 index 000000000..0208596b3 --- /dev/null +++ b/lib/TileOps/texp_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.texp""" + +import tilelang_dsl as pto + +# TODO: Add implementation for HIGH_PRECISION type +@pto.vkernel( + target="a5", + op="pto.texp" +) +def template_texp(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vexp(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tlog_template.py b/lib/TileOps/tlog_template.py new file mode 100644 index 000000000..faf7a63bb --- /dev/null +++ b/lib/TileOps/tlog_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tlog""" + +import tilelang_dsl as pto + +# TODO: Add implementation for HIGH_PRECISION type +@pto.vkernel( + target="a5", + op="pto.tlog" +) +def template_tlog(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vln(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tneg_template.py b/lib/TileOps/tneg_template.py new file mode 100644 index 000000000..8e10ce4ca --- /dev/null +++ b/lib/TileOps/tneg_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tneg""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tneg" +) +def template_tneg(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vneg(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tnot_template.py b/lib/TileOps/tnot_template.py new file mode 100644 index 000000000..f4728e853 --- /dev/null +++ b/lib/TileOps/tnot_template.py @@ -0,0 +1,30 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tnot""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tnot", + dtypes=[(pto.AnyInt, pto.AnyInt)] +) +def template_tnot(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vnot(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trecip_template.py b/lib/TileOps/trecip_template.py new file mode 100644 index 000000000..657706634 --- /dev/null +++ b/lib/TileOps/trecip_template.py @@ -0,0 +1,36 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trecip""" + +import tilelang_dsl as pto + +# TODO: Add implementation for HIGH_PRECISION type +@pto.vkernel( + target="a5", + op="pto.trecip", + dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32)] +) +def template_trecip(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + if pto.constexpr(dtype == pto.f16): + one_scalar = pto.f16(1.0) + else: + one_scalar = pto.f32(1.0) + one = pto.vbr(one_scalar) + # one = pto.vbr(dtype(1.0)) + result = pto.vdiv(one, vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trsqrt_template.py b/lib/TileOps/trsqrt_template.py new file mode 100644 index 000000000..87368adca --- /dev/null +++ b/lib/TileOps/trsqrt_template.py @@ -0,0 +1,36 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trsqrt""" + +import tilelang_dsl as pto + +# TODO: Add implementation for HIGH_PRECISION type +@pto.vkernel( + target="a5", + op="pto.trsqrt", + dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32)] +) +def template_trsqrt(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + if pto.constexpr(dtype == pto.f16): + one_scalar = pto.f16(1.0) + else: + one_scalar = pto.f32(1.0) + one = pto.vbr(one_scalar) + sqrt_result = pto.vsqrt(vinput, mask) + result = pto.vdiv(one, sqrt_result, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tsqrt_template.py b/lib/TileOps/tsqrt_template.py new file mode 100644 index 000000000..45c707fde --- /dev/null +++ b/lib/TileOps/tsqrt_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsqrt""" + +import tilelang_dsl as pto + +# TODO: Add implementation for HIGH_PRECISION type +@pto.vkernel( + target="a5", + op="pto.tsqrt" +) +def template_tsqrt(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vsqrt(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tabs.pto b/test/basic/expand_tile_op_tilelang_tabs.pto new file mode 100644 index 000000000..6c7e1084c --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tabs.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tabs should be lowered to vector-style VPTO IR. +// CHECK: func.func @TABS +// CHECK-NOT: pto.tabs ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vabs +// CHECK: pto.vsts + +module { + func.func @TABS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tabs ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_texp.pto b/test/basic/expand_tile_op_tilelang_texp.pto new file mode 100644 index 000000000..49399dceb --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_texp.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.texp should be lowered to vector-style VPTO IR. +// CHECK: func.func @TEXP +// CHECK-NOT: pto.texp ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vexp +// CHECK: pto.vsts + +module { + func.func @TEXP() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.texp ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tlog.pto b/test/basic/expand_tile_op_tilelang_tlog.pto new file mode 100644 index 000000000..673105a40 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tlog.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tlog should be lowered to vector-style VPTO IR. +// CHECK: func.func @TLOG +// CHECK-NOT: pto.tlog ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vln +// CHECK: pto.vsts + +module { + func.func @TLOG() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tlog ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tneg.pto b/test/basic/expand_tile_op_tilelang_tneg.pto new file mode 100644 index 000000000..939686186 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tneg.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tneg should be lowered to vector-style VPTO IR. +// CHECK: func.func @TNEG +// CHECK-NOT: pto.tneg ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmuls +// CHECK: pto.vsts + +module { + func.func @TNEG() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tneg ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tnot.pto b/test/basic/expand_tile_op_tilelang_tnot.pto new file mode 100644 index 000000000..5b2c75965 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tnot.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tnot should be lowered to vector-style VPTO IR. +// CHECK: func.func @TNOT +// CHECK-NOT: pto.tnot ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vnot +// CHECK: pto.vsts + +module { + func.func @TNOT() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tnot ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_trecip.pto b/test/basic/expand_tile_op_tilelang_trecip.pto new file mode 100644 index 000000000..65c44c21a --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trecip.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trecip should be lowered to vector-style VPTO IR. +// CHECK: func.func @TRECIP +// CHECK-NOT: pto.trecip ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vrec +// CHECK: pto.vsts + +module { + func.func @TRECIP() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.trecip ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_trsqrt.pto b/test/basic/expand_tile_op_tilelang_trsqrt.pto new file mode 100644 index 000000000..5488ba6ae --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trsqrt.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trsqrt should be lowered to vector-style VPTO IR. +// CHECK: func.func @TRSQRT +// CHECK-NOT: pto.trsqrt ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsqrt +// CHECK: pto.vrec +// CHECK: pto.vsts + +module { + func.func @TRSQRT() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.trsqrt ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tsqrt.pto b/test/basic/expand_tile_op_tilelang_tsqrt.pto new file mode 100644 index 000000000..1263aa4a5 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tsqrt.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsqrt should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSQRT +// CHECK-NOT: pto.tsqrt ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsqrt +// CHECK: pto.vsts + +module { + func.func @TSQRT() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tsqrt ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 7d96af8f2..49d3df96b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -133,6 +133,14 @@ set(ALL_TESTCASES tcvt tload softmax + tabs + texp + tlog + tneg + tnot + trecip + trsqrt + tsqrt ) if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tabs/CMakeLists.txt new file mode 100644 index 000000000..b776efb52 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tabs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tabs/cases.py new file mode 100644 index 000000000..d63eb85f8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/cases.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tabs ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tabs/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tabs/gen_data.py new file mode 100644 index 000000000..22bf5d95d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/gen_data.py @@ -0,0 +1,32 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input = np.random.randn(*shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.abs(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tabs/launch.cpp new file mode 100644 index 000000000..dd39abd15 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TABS_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTABS_f32_16x64(void *a, void *b, void *stream) { + TABS_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TABS_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTABS_f32_32x32(void *a, void *b, void *stream) { + TABS_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TABS_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTABS_f16_16x64(void *a, void *b, void *stream) { + TABS_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TABS_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTABS_f16_32x32(void *a, void *b, void *stream) { + TABS_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tabs/main.cpp new file mode 100644 index 000000000..681510ddf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tabs ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTABS_f32_16x64(void *a, void *b, void *stream); +void LaunchTABS_f32_32x32(void *a, void *b, void *stream); +void LaunchTABS_f16_16x64(void *a, void *b, void *stream); +void LaunchTABS_f16_32x32(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTABS_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTABS_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTABS_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTABS_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tabs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/tabs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tabs/tabs.pto new file mode 100644 index 000000000..9391f3f40 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/tabs.pto @@ -0,0 +1,204 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tabs: tload(a) + tabs(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 (1024 elements) + func.func @TABS_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tabs ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TABS_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tabs ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TABS_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tabs ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TABS_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tabs ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/texp/CMakeLists.txt new file mode 100644 index 000000000..6ce5def10 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(texp) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/texp/cases.py new file mode 100644 index 000000000..e2faacbb6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/cases.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for texp ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/texp/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/texp/gen_data.py new file mode 100644 index 000000000..7abaa7a78 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/gen_data.py @@ -0,0 +1,32 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input = np.random.randn(*shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.exp(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/texp/launch.cpp new file mode 100644 index 000000000..4140cb635 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TEXP_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTEXP_f32_16x64(void *a, void *b, void *stream) { + TEXP_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TEXP_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTEXP_f32_32x32(void *a, void *b, void *stream) { + TEXP_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TEXP_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTEXP_f16_16x64(void *a, void *b, void *stream) { + TEXP_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TEXP_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTEXP_f16_32x32(void *a, void *b, void *stream) { + TEXP_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/texp/main.cpp new file mode 100644 index 000000000..3ac2d5de5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang texp ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTEXP_f32_16x64(void *a, void *b, void *stream); +void LaunchTEXP_f32_32x32(void *a, void *b, void *stream); +void LaunchTEXP_f16_16x64(void *a, void *b, void *stream); +void LaunchTEXP_f16_32x32(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTEXP_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTEXP_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTEXP_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTEXP_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./texp [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/texp.pto b/test/tilelang_st/npu/a5/src/st/testcase/texp/texp.pto new file mode 100644 index 000000000..5a3579595 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/texp.pto @@ -0,0 +1,204 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.texp: tload(a) + texp(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 (1024 elements) + func.func @TEXP_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TEXP_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TEXP_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TEXP_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tlog/CMakeLists.txt new file mode 100644 index 000000000..f17ca9cf8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tlog) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tlog/cases.py new file mode 100644 index 000000000..2880e46dd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/cases.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tlog ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tlog/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tlog/gen_data.py new file mode 100644 index 000000000..459d8fb12 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Generate positive random values for log (log requires positive inputs) + input = np.random.uniform(0.1, 10.0, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.log(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tlog/launch.cpp new file mode 100644 index 000000000..5c7a262b1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TLOG_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTLOG_f32_16x64(void *a, void *b, void *stream) { + TLOG_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TLOG_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTLOG_f32_32x32(void *a, void *b, void *stream) { + TLOG_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TLOG_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTLOG_f16_16x64(void *a, void *b, void *stream) { + TLOG_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TLOG_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTLOG_f16_32x32(void *a, void *b, void *stream) { + TLOG_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tlog/main.cpp new file mode 100644 index 000000000..133ff955d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tlog ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTLOG_f32_16x64(void *a, void *b, void *stream); +void LaunchTLOG_f32_32x32(void *a, void *b, void *stream); +void LaunchTLOG_f16_16x64(void *a, void *b, void *stream); +void LaunchTLOG_f16_32x32(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTLOG_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTLOG_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTLOG_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTLOG_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tlog [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/tlog.pto b/test/tilelang_st/npu/a5/src/st/testcase/tlog/tlog.pto new file mode 100644 index 000000000..fbe57f290 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/tlog.pto @@ -0,0 +1,204 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tlog: tload(a) + tlog(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 (1024 elements) + func.func @TLOG_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TLOG_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TLOG_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TLOG_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tneg/CMakeLists.txt new file mode 100644 index 000000000..02a068e9e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tneg) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tneg/cases.py new file mode 100644 index 000000000..f5251d28a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/cases.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tneg ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, + { + "name": "i16_64x16", + "dtype": np.int16, + "shape": (64, 16), + "valid_shape": (64, 16), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tneg/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tneg/gen_data.py new file mode 100644 index 000000000..0c88055b7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Random values (no constraints for neg) + input = np.random.randn(*shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.negative(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tneg/launch.cpp new file mode 100644 index 000000000..ef6121f48 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TNEG_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTNEG_f32_16x64(void *a, void *b, void *stream) { + TNEG_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TNEG_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTNEG_f32_32x32(void *a, void *b, void *stream) { + TNEG_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TNEG_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTNEG_f16_16x64(void *a, void *b, void *stream) { + TNEG_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TNEG_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTNEG_f16_32x32(void *a, void *b, void *stream) { + TNEG_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 4: i32 32x32 +extern "C" __global__ AICORE void TNEG_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b); + +void LaunchTNEG_i32_32x32(void *a, void *b, void *stream) { + TNEG_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b); +} + +// Case 5: i16 64x16 +extern "C" __global__ AICORE void TNEG_i16_64x16(__gm__ int16_t *a, __gm__ int16_t *b); + +void LaunchTNEG_i16_64x16(void *a, void *b, void *stream) { + TNEG_i16_64x16<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tneg/main.cpp new file mode 100644 index 000000000..6a86073fc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tneg ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTNEG_f32_16x64(void *a, void *b, void *stream); +void LaunchTNEG_f32_32x32(void *a, void *b, void *stream); +void LaunchTNEG_f16_16x64(void *a, void *b, void *stream); +void LaunchTNEG_f16_32x32(void *a, void *b, void *stream); +void LaunchTNEG_i32_32x32(void *a, void *b, void *stream); +void LaunchTNEG_i16_64x16(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTNEG_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTNEG_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTNEG_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTNEG_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"i32_32x32", LaunchTNEG_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, + {"i16_64x16", LaunchTNEG_i16_64x16, 64, 16, 64, 16, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tneg [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/tneg.pto b/test/tilelang_st/npu/a5/src/st/testcase/tneg/tneg.pto new file mode 100644 index 000000000..1e370babe --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/tneg.pto @@ -0,0 +1,299 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tneg: tload(a) + tneg(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 (1024 elements) + func.func @TNEG_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TNEG_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TNEG_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TNEG_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 4: i32 32x32 (1024 elements) + func.func @TNEG_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } + + // Case 5: i16 64x16 (1024 elements) + func.func @TNEG_i16_64x16(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c16], + strides = [%c1024, %c1024, %c1024, %c16, %c1] + : !pto.tensor_view<1x1x1x64x16xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c16], + strides = [%c1024, %c1024, %c1024, %c16, %c1] + : !pto.tensor_view<1x1x1x64x16xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c16] + : !pto.tensor_view<1x1x1x64x16xi16> -> !pto.partition_tensor_view<1x1x1x64x16xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c16] + : !pto.tensor_view<1x1x1x64x16xi16> -> !pto.partition_tensor_view<1x1x1x64x16xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x16xi16>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x16xi16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tnot/CMakeLists.txt new file mode 100644 index 000000000..ee5525ac2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tnot) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tnot/cases.py new file mode 100644 index 000000000..b6612d63f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/cases.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tnot ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int8, np.int16, np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol), 0 for exact match. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "int8_64x64", + "dtype": np.int8, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0 + }, + { + "name": "uint8_60x60", + "dtype": np.uint8, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0 + }, + { + "name": "int16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0 + }, + { + "name": "uint16_60x60", + "dtype": np.uint16, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0 + }, + { + "name": "int32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0 + }, + { + "name": "uint32_60x60", + "dtype": np.uint32, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0 + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tnot/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tnot/gen_data.py new file mode 100644 index 000000000..62de58386 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/gen_data.py @@ -0,0 +1,30 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + dtype_info = np.iinfo(dtype) + input = np.random.randint(dtype_info.min, dtype_info.max, size=shape, dtype=dtype) + golden = np.bitwise_not(input).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tnot/launch.cpp new file mode 100644 index 000000000..858f6d181 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: int8 64x64 +extern "C" __global__ AICORE void TNOT_int8_64x64(__gm__ int8_t *a, __gm__ int8_t *b); + +void LaunchTNOT_int8_64x64(void *a, void *b, void *stream) { + TNOT_int8_64x64<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b); +} + +// Case 1: uint8 60x60 +extern "C" __global__ AICORE void TNOT_uint8_60x60(__gm__ uint8_t *a, __gm__ uint8_t *b); + +void LaunchTNOT_uint8_60x60(void *a, void *b, void *stream) { + TNOT_uint8_60x60<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b); +} + +// Case 2: int16 64x64 +extern "C" __global__ AICORE void TNOT_int16_64x64(__gm__ int16_t *a, __gm__ int16_t *b); + +void LaunchTNOT_int16_64x64(void *a, void *b, void *stream) { + TNOT_int16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b); +} + +// Case 3: uint16 60x60 +extern "C" __global__ AICORE void TNOT_uint16_60x60(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTNOT_uint16_60x60(void *a, void *b, void *stream) { + TNOT_uint16_60x60<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 4: int32 64x64 +extern "C" __global__ AICORE void TNOT_int32_64x64(__gm__ int32_t *a, __gm__ int32_t *b); + +void LaunchTNOT_int32_64x64(void *a, void *b, void *stream) { + TNOT_int32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b); +} + +// Case 5: uint32 60x60 +extern "C" __global__ AICORE void TNOT_uint32_60x60(__gm__ uint32_t *a, __gm__ uint32_t *b); + +void LaunchTNOT_uint32_60x60(void *a, void *b, void *stream) { + TNOT_uint32_60x60<<<1, nullptr, stream>>>((__gm__ uint32_t *)a, (__gm__ uint32_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tnot/main.cpp new file mode 100644 index 000000000..55a823be7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tnot ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTNOT_int8_64x64(void *a, void *b, void *stream); +void LaunchTNOT_uint8_60x60(void *a, void *b, void *stream); +void LaunchTNOT_int16_64x64(void *a, void *b, void *stream); +void LaunchTNOT_uint16_60x60(void *a, void *b, void *stream); +void LaunchTNOT_int32_64x64(void *a, void *b, void *stream); +void LaunchTNOT_uint32_60x60(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"int8_64x64", LaunchTNOT_int8_64x64, 64, 64, 64, 64, sizeof(int8_t)}, + {"uint8_60x60", LaunchTNOT_uint8_60x60, 64, 64, 60, 60, sizeof(uint8_t)}, + {"int16_64x64", LaunchTNOT_int16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, + {"uint16_60x60", LaunchTNOT_uint16_60x60, 64, 64, 60, 60, sizeof(uint16_t)}, + {"int32_64x64", LaunchTNOT_int32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, + {"uint32_60x60", LaunchTNOT_uint32_60x60, 64, 64, 60, 60, sizeof(uint32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tnot [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/tnot.pto b/test/tilelang_st/npu/a5/src/st/testcase/tnot/tnot.pto new file mode 100644 index 000000000..6b75c85ba --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/tnot.pto @@ -0,0 +1,299 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tnot: tload(a) + tnot(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: int8 64x64 (valid 64x64) + func.func @TNOT_int8_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x64x64xi8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x64x64xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi8>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi8>) + return + } + + // Case 1: uint8 64x64 (valid 60x60) - partition_view sizes = valid_shape + func.func @TNOT_uint8_60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui8> -> !pto.partition_tensor_view<1x1x1x60x60xui8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui8> -> !pto.partition_tensor_view<1x1x1x60x60xui8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xui8>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x60x60xui8>) + return + } + + // Case 2: int16 64x64 (valid 64x64) + func.func @TNOT_int16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 3: uint16 64x64 (valid 60x60) - partition_view sizes = valid_shape + func.func @TNOT_uint16_60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui16> -> !pto.partition_tensor_view<1x1x1x60x60xui16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui16> -> !pto.partition_tensor_view<1x1x1x60x60xui16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xui16>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x60x60xui16>) + return + } + + // Case 4: int32 64x64 (valid 64x64) + func.func @TNOT_int32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 5: uint32 64x64 (valid 60x60) - partition_view sizes = valid_shape + func.func @TNOT_uint32_60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui32> -> !pto.partition_tensor_view<1x1x1x60x60xui32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui32> -> !pto.partition_tensor_view<1x1x1x60x60xui32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xui32>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x60x60xui32>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trecip/CMakeLists.txt new file mode 100644 index 000000000..9ec69bc60 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trecip) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trecip/cases.py new file mode 100644 index 000000000..b1c2012e2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/cases.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trecip ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "f32_64x64_pad", + "dtype": np.float32, + "shape": (66, 72), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "f32_58x70", + "dtype": np.float32, + "shape": (66, 72), + "valid_shape": (58, 70), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trecip/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trecip/gen_data.py new file mode 100644 index 000000000..81e052958 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/gen_data.py @@ -0,0 +1,31 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Avoid 0 for reciprocal, use range [0.1, 10.0] + input = np.random.uniform(0.1, 10.0, size=shape).astype(dtype) + + # reciprocal = 1/x + golden = np.reciprocal(input).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trecip/launch.cpp new file mode 100644 index 000000000..3cf95d119 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TRECIP_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTRECIP_f32_16x64(void *a, void *b, void *stream) { + TRECIP_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TRECIP_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTRECIP_f32_32x32(void *a, void *b, void *stream) { + TRECIP_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TRECIP_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTRECIP_f16_16x64(void *a, void *b, void *stream) { + TRECIP_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TRECIP_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTRECIP_f16_32x32(void *a, void *b, void *stream) { + TRECIP_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 4: f32 66x72, valid 64x64 (pad) +extern "C" __global__ AICORE void TRECIP_f32_64x64_pad(__gm__ float *a, __gm__ float *b); + +void LaunchTRECIP_f32_64x64_pad(void *a, void *b, void *stream) { + TRECIP_f32_64x64_pad<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 5: f32 66x72, valid 58x70 (non-square valid) +extern "C" __global__ AICORE void TRECIP_f32_58x70(__gm__ float *a, __gm__ float *b); + +void LaunchTRECIP_f32_58x70(void *a, void *b, void *stream) { + TRECIP_f32_58x70<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trecip/main.cpp new file mode 100644 index 000000000..4a8400e6a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trecip ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTRECIP_f32_16x64(void *a, void *b, void *stream); +void LaunchTRECIP_f32_32x32(void *a, void *b, void *stream); +void LaunchTRECIP_f16_16x64(void *a, void *b, void *stream); +void LaunchTRECIP_f16_32x32(void *a, void *b, void *stream); +void LaunchTRECIP_f32_64x64_pad(void *a, void *b, void *stream); +void LaunchTRECIP_f32_58x70(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTRECIP_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTRECIP_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTRECIP_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTRECIP_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"f32_64x64_pad", LaunchTRECIP_f32_64x64_pad, 66, 72, 64, 64, sizeof(float)}, + {"f32_58x70", LaunchTRECIP_f32_58x70, 66, 72, 58, 70, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trecip [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/trecip.pto b/test/tilelang_st/npu/a5/src/st/testcase/trecip/trecip.pto new file mode 100644 index 000000000..d98996639 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/trecip.pto @@ -0,0 +1,304 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trecip: 1/x (reciprocal) +// trecip = vdiv(1.0, x) +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 + func.func @TRECIP_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 + func.func @TRECIP_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 + func.func @TRECIP_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 + func.func @TRECIP_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 4: f32 66x72, valid 64x64 (pad case) + func.func @TRECIP_f32_64x64_pad(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c66 = arith.constant 66 : index + %c72 = arith.constant 72 : index + %c4752 = arith.constant 4752 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c66, %c72], + strides = [%c4752, %c4752, %c4752, %c72, %c1] + : !pto.tensor_view<1x1x1x66x72xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c66, %c72], + strides = [%c4752, %c4752, %c4752, %c72, %c1] + : !pto.tensor_view<1x1x1x66x72xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x66x72xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x66x72xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f32 66x72, valid 58x70 (non-square valid) + func.func @TRECIP_f32_58x70(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c58 = arith.constant 58 : index + %c70 = arith.constant 70 : index + %c66 = arith.constant 66 : index + %c72 = arith.constant 72 : index + %c4752 = arith.constant 4752 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c66, %c72], + strides = [%c4752, %c4752, %c4752, %c72, %c1] + : !pto.tensor_view<1x1x1x66x72xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c66, %c72], + strides = [%c4752, %c4752, %c4752, %c72, %c1] + : !pto.tensor_view<1x1x1x66x72xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c58, %c70] + : !pto.tensor_view<1x1x1x66x72xf32> -> !pto.partition_tensor_view<1x1x1x58x70xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c58, %c70] + : !pto.tensor_view<1x1x1x66x72xf32> -> !pto.partition_tensor_view<1x1x1x58x70xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x58x70xf32>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x58x70xf32>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/CMakeLists.txt new file mode 100644 index 000000000..7209977f8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trsqrt) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/cases.py new file mode 100644 index 000000000..cb8b6a48d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/cases.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trsqrt ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/gen_data.py new file mode 100644 index 000000000..9ca63c976 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/gen_data.py @@ -0,0 +1,34 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Positive values for rsqrt (1/sqrt(x) requires sqrt(x) > 0) + input = np.random.uniform(0.1, 100.0, size=shape).astype(dtype) + + # rsqrt = 1 / sqrt(x) + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.reciprocal(np.sqrt(input[:vr, :vc])).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/launch.cpp new file mode 100644 index 000000000..65a35f3bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TRSQRT_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTRSQRT_f32_16x64(void *a, void *b, void *stream) { + TRSQRT_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TRSQRT_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTRSQRT_f32_32x32(void *a, void *b, void *stream) { + TRSQRT_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TRSQRT_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTRSQRT_f16_16x64(void *a, void *b, void *stream) { + TRSQRT_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TRSQRT_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTRSQRT_f16_32x32(void *a, void *b, void *stream) { + TRSQRT_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/main.cpp new file mode 100644 index 000000000..20c955070 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trsqrt ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTRSQRT_f32_16x64(void *a, void *b, void *stream); +void LaunchTRSQRT_f32_32x32(void *a, void *b, void *stream); +void LaunchTRSQRT_f16_16x64(void *a, void *b, void *stream); +void LaunchTRSQRT_f16_32x32(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTRSQRT_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTRSQRT_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTRSQRT_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTRSQRT_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trsqrt [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/trsqrt.pto b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/trsqrt.pto new file mode 100644 index 000000000..607afc8e9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/trsqrt.pto @@ -0,0 +1,205 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trsqrt: 1/sqrt(x) +// trsqrt = vsqrt(x) -> vdiv(1.0, sqrt_result) +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 (1024 elements) + func.func @TRSQRT_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.trsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TRSQRT_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.trsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TRSQRT_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.trsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TRSQRT_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.trsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/CMakeLists.txt new file mode 100644 index 000000000..83de2cda8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsqrt) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/cases.py new file mode 100644 index 000000000..0e05491d0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/cases.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsqrt ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/gen_data.py new file mode 100644 index 000000000..bc2301baa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Generate positive random values for sqrt + input = np.random.uniform(0.1, 100.0, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.sqrt(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/launch.cpp new file mode 100644 index 000000000..b6ffbdcb4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TSQRT_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTSQRT_f32_16x64(void *a, void *b, void *stream) { + TSQRT_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TSQRT_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTSQRT_f32_32x32(void *a, void *b, void *stream) { + TSQRT_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TSQRT_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTSQRT_f16_16x64(void *a, void *b, void *stream) { + TSQRT_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TSQRT_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTSQRT_f16_32x32(void *a, void *b, void *stream) { + TSQRT_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/main.cpp new file mode 100644 index 000000000..25bea9592 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsqrt ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSQRT_f32_16x64(void *a, void *b, void *stream); +void LaunchTSQRT_f32_32x32(void *a, void *b, void *stream); +void LaunchTSQRT_f16_16x64(void *a, void *b, void *stream); +void LaunchTSQRT_f16_32x32(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTSQRT_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTSQRT_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTSQRT_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTSQRT_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tsqrt [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/tsqrt.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/tsqrt.pto new file mode 100644 index 000000000..ba1118e90 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/tsqrt.pto @@ -0,0 +1,204 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsqrt: tload(a) + tsqrt(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 (1024 elements) + func.func @TSQRT_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TSQRT_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + +pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TSQRT_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TSQRT_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } +} From f7447e357fe587b4f76fbdf690db406eaaab823a Mon Sep 17 00:00:00 2001 From: FangRui Date: Tue, 21 Apr 2026 10:00:20 +0800 Subject: [PATCH 155/192] feat(vpto): add vmatmal op and dependent copy op --- docs/isa/16-cube-matmul.md | 104 + docs/vpto-spec.md | 44 + include/PTO/IR/VPTOOps.td | 319 +- lib/PTO/IR/PTO.cpp | 25 +- lib/PTO/IR/VPTO.cpp | 69 + lib/PTO/Transforms/HIVMIntrinsicNaming.cpp | 346 +- lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 5 +- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 3560 ++++++++++++----- .../cube-matmul/mad_bf16bf16f32/compare.py | 46 + .../cube-matmul/mad_bf16bf16f32/golden.py | 41 + .../cube-matmul/mad_bf16bf16f32/kernel.pto | 46 + .../cube-matmul/mad_bf16bf16f32/launch.cpp | 49 + .../cube-matmul/mad_bf16bf16f32/main.cpp | 127 + .../cube-matmul/mad_bf16bf16f32/stub.cpp | 25 + .../cube-matmul/mad_f16f16f32/compare.py | 46 + .../cube-matmul/mad_f16f16f32/golden.py | 41 + .../cube-matmul/mad_f16f16f32/kernel.pto | 46 + .../cube-matmul/mad_f16f16f32/launch.cpp | 49 + .../cube-matmul/mad_f16f16f32/main.cpp | 127 + .../cube-matmul/mad_f16f16f32/stub.cpp | 25 + .../cube-matmul/mad_f32f32f32/compare.py | 46 + .../cube-matmul/mad_f32f32f32/golden.py | 41 + .../cube-matmul/mad_f32f32f32/kernel.pto | 46 + .../cube-matmul/mad_f32f32f32/launch.cpp | 49 + .../cube-matmul/mad_f32f32f32/main.cpp | 127 + .../cube-matmul/mad_f32f32f32/stub.cpp | 25 + .../micro-op/cube-matmul/mad_mx/compare.py | 46 + .../micro-op/cube-matmul/mad_mx/golden.py | 42 + .../micro-op/cube-matmul/mad_mx/kernel.pto | 47 + .../micro-op/cube-matmul/mad_mx/launch.cpp | 49 + .../micro-op/cube-matmul/mad_mx/main.cpp | 127 + .../micro-op/cube-matmul/mad_mx/stub.cpp | 25 + test/vpto/scripts/run_host_vpto_validation.sh | 61 +- tools/ptoas/ptoas.cpp | 45 +- 34 files changed, 4890 insertions(+), 1026 deletions(-) create mode 100644 docs/isa/16-cube-matmul.md create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/compare.py create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/golden.py create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/main.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/stub.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/compare.py create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/golden.py create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/main.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/stub.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/compare.py create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/golden.py create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/main.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/stub.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_mx/compare.py create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_mx/golden.py create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_mx/kernel.pto create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_mx/launch.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_mx/main.cpp create mode 100644 test/vpto/cases/micro-op/cube-matmul/mad_mx/stub.cpp diff --git a/docs/isa/16-cube-matmul.md b/docs/isa/16-cube-matmul.md new file mode 100644 index 000000000..50ccd327e --- /dev/null +++ b/docs/isa/16-cube-matmul.md @@ -0,0 +1,104 @@ +# 16. Cube Matrix Multiply (MAT) + +> **Category:** Cube unit — GM/L1 staging, L0A/L0B loads, L0C accumulate, and matrix-side side-buffer moves +> **Pipelines:** MTE2 (GM→L1 / cbuf), Cube (L0A/L0B→L0C), MTE3/FIX (L0C→{GM,L1,UB}, L1→{BT,FB}) + +This group documents **buffer-pointer** PTO ops used to express a minimal **cube matmul data path** on A5: data is moved from GM into L1-aligned buffers (`cbuf`), loaded into L0A/L0B, multiplied into L0C (`cc`), then written back or redistributed to GM/L1/UB and related matrix-side buffers. These ops are distinct from the vector `!pto.vreg<…>` surface in groups 3–13. + +Typical usage keeps the body inside `pto.vecscope { … }` (or another enclosing region required by the PTO verifier) so cube-side effects remain ordered with respect to other PTO work. + +--- + +## `pto.copy_gm_to_cbuf` + +- **syntax:** `pto.copy_gm_to_cbuf %src, %dst, %n_burst, %len_burst, %src_stride, %dst_stride : !pto.ptr<…, gm>, !pto.ptr<…, ub>, i64, i64, i64, i64` +- **semantics:** GM→L1 (`cbuf`) aligned copy. `%src` is GM; `%dst` is UB-backed L1 (`cbuf`) staging. + +Operands `%n_burst`, `%len_burst`, `%src_stride`, and `%dst_stride` configure the transfer shape; they are lowered to packed `i64` configuration tokens for the target `llvm.hivm` MOV family intrinsic. + +--- + +## `pto.load_cbuf_to_ca` + +- **syntax:** `pto.load_cbuf_to_ca %src, %dst, %m, %k : !pto.ptr<…, ub>, !pto.ptr<…, ub>, i64, i64` +- **semantics:** L1 (`cbuf`) → L0A load. `%src` is `cbuf`; `%dst` is UB-backed L0A staging. + +--- + +## `pto.load_cbuf_to_cb` + +- **syntax:** `pto.load_cbuf_to_cb %src, %dst, %k, %n : !pto.ptr<…, ub>, !pto.ptr<…, ub>, i64, i64` +- **semantics:** L1 (`cbuf`) → L0B load. `%src` is `cbuf`; `%dst` is UB-backed L0B staging. + +--- + +## `pto.mad` + +- **syntax:** `pto.mad %lhs, %rhs, %dst, %m, %n, %k : !pto.ptr<…, ub>, !pto.ptr<…, ub>, !pto.ptr<…, ub>, i64, i64, i64` +- **semantics:** Cube **multiply** on L0A (`%lhs`) and L0B (`%rhs`) into L0C (`%dst`). All three pointers must be UB-backed **buffer** pointers (`!pto.ptr<…, ub>`) classified as left / right / accumulator roles in lowering. + +Supported element-type combinations follow the HIVM intrinsic selection in the compiler (for example `f16` × `f16` → `f32` accumulation, and MX-dtyped paths where applicable). + +--- + +## `pto.copy_matrix_cc_to_gm` + +- **syntax:** `pto.copy_matrix_cc_to_gm %src, %dst, %m, %n : !pto.ptr<…, ub>, !pto.ptr<…, gm>, i64, i64` +- **semantics:** L0C (`cc`) → GM matrix writeback. `%src` is UB-backed `cc`; `%dst` is GM. + +--- + +## `pto.copy_gm_to_cbuf_multi_nd2nz` / `pto.copy_gm_to_cbuf_multi_dn2nz` + +- **syntax:** `pto.copy_gm_to_cbuf_multi_* %src, %dst, ... : !pto.ptr<…, gm>, !pto.ptr<…, ub>, ...` +- **semantics:** GM→L1 (`cbuf`) multi-fractal staging paths for cube data layout conversion variants (`ND2NZ` / `DN2NZ`), lowered to `llvm.hivm.MOV.OUT.TO.L1.MULTI.*` families. + +--- + +## `pto.copy_matrix_cc_to_cbuf` / `pto.copy_matrix_cc_to_ub` + +- **syntax:** `pto.copy_matrix_cc_to_* %src, %dst, %config0, %config1 : !pto.ptr<…, ub>, !pto.ptr<…, ub>, i64, i64` +- **semantics:** L0C (`cc`) redistribution to L1 (`cbuf`) or UB destinations. These are post-matmul movement ops typically used before follow-up fusion or output formatting. + +--- + +## `pto.copy_cbuf_to_bt` / `pto.copy_cbuf_to_fbuf` + +- **syntax:** `pto.copy_cbuf_to_bt ...` and `pto.copy_cbuf_to_fbuf ...` +- **semantics:** L1 (`cbuf`) to bias/scaling-side buffers for matrix post-processing setup. Lowered to `llvm.hivm.MOV.L1.TO.BT.f16` and `llvm.hivm.MOV.L1.TO.FB.V2`. + +--- + +## Verified A5 Op Set (Current Batch) + +The following PTO ops have been verified in the current A5 VPTO→LLVM/HIVM and Bisheng cube-flow validation batch: + +- `pto.copy_cbuf_to_bt` +- `pto.copy_cbuf_to_fbuf` +- `pto.copy_gm_to_cbuf_multi_dn2nz` +- `pto.copy_gm_to_cbuf_multi_nd2nz` +- `pto.copy_matrix_cc_to_cbuf` +- `pto.copy_matrix_cc_to_ub` +- `pto.load_cbuf_to_ca_mx` +- `pto.load_cbuf_to_ca_s4` +- `pto.load_cbuf_to_cb_mx` +- `pto.load_cbuf_to_cb_s4` +- `pto.set_atomic_s32` +- `pto.set_atomic_s8` +- `pto.set_channel_para` +- `pto.set_fpc` +- `pto.set_loop1_stride_outtol1` +- `pto.set_loop2_stride_outtol1` +- `pto.set_loop3_para` +- `pto.set_loop_size_outtol1` +- `pto.set_mte2_nz_para` +- `pto.set_pad_val_outtol1` +- `pto.set_quant_pre` + +--- + +## Current PTOAS Coverage + +- VPTO → LLVM (`--vpto-emit-hivm-llvm`) lowers the ops in this group to target-specific `llvm.hivm.*` intrinsics with explicit address spaces for GM / cbuf / L0A / L0B / L0C and matrix-side side buffers. +- FileCheck coverage lives under `test/basic/vpto_mad_*.pto` and `test/basic/vpto_cube_dma_matmul_*.pto`. +- TileLang ST builds a cube-linked host testcase via `pto_tilelang_cube_st(tmatmul)` in `test/tilelang_st/npu/a5/src/st/testcase/tmatmul/`. diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index c4d748a55..76182d26c 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -900,6 +900,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | 13 | [DSA/SFU Ops](isa/13-dsa-sfu-ops.md) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdif`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | | 14 | [Arith (Shared MLIR Dialect)](isa/14-shared-arith.md) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | | 15 | [SCF (Shared MLIR Dialect)](isa/15-shared-scf.md) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | +| 16 | [Cube Matrix Multiply (MAT)](isa/16-cube-matmul.md) | GM↔L1 cube staging, L0A/L0B loads, L0C matmul, and L0C/L1 side-buffer moves | 10+ | `pto.copy_gm_to_cbuf`, `pto.copy_gm_to_cbuf_multi_nd2nz`, `pto.copy_gm_to_cbuf_multi_dn2nz`, `pto.load_cbuf_to_ca`, `pto.load_cbuf_to_cb`, `pto.mad`, `pto.copy_matrix_cc_to_gm`, `pto.copy_matrix_cc_to_cbuf`, `pto.copy_matrix_cc_to_ub`, `pto.copy_cbuf_to_bt`, `pto.copy_cbuf_to_fbuf` | --- @@ -911,6 +912,12 @@ This section provides a categorized overview of all PTO micro Instruction operat |-----------|-------|-------------| | GM→UB DMA | 2 | `pto.dma_load` | | UB→GM DMA | 2 | `pto.dma_store` | +| GM→L1 (cube staging) | 16 | `pto.copy_gm_to_cbuf` | +| GM→L1 (multi layout staging) | 16 | `pto.copy_gm_to_cbuf_multi_nd2nz`, `pto.copy_gm_to_cbuf_multi_dn2nz` | +| L1→L0A / L1→L0B | 16 | `pto.load_cbuf_to_ca`, `pto.load_cbuf_to_cb` | +| L0C→GM (cube writeback) | 16 | `pto.copy_matrix_cc_to_gm` | +| L0C→L1 / L0C→UB | 16 | `pto.copy_matrix_cc_to_cbuf`, `pto.copy_matrix_cc_to_ub` | +| L1→BT / L1→FB | 16 | `pto.copy_cbuf_to_bt`, `pto.copy_cbuf_to_fbuf` | | Contiguous Load | 3 | `pto.vlds` with `NORM` dist | | Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | | Gather | 3 | `pto.vgather2`, `pto.vgatherb` | @@ -925,6 +932,7 @@ This section provides a categorized overview of all PTO micro Instruction operat | Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | | Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | | Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Cube matmul (L0A×L0B→L0C) | 16 | `pto.mad` | | Comparison | 11 | `pto.vcmp`, `pto.vcmps` | | Selection | 11 | `pto.vsel`, `pto.vselr` | @@ -959,6 +967,42 @@ Group 14 covers the full scalar `arith` surface. The rows below list common PTO | Conditional Regions | 15 | `scf.if`, `scf.yield` | | Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | +### Recent A5 Additions (Implemented) + +- `pto.set_quant_pre` (lowered to `llvm.hivm.SET.QUANT.PRE.v300`) +- `pto.set_atomic_s32`, `pto.set_atomic_s8` (A5-selectable atomic mode controls) +- Cube-side movement additions: + - `pto.copy_gm_to_cbuf_multi_nd2nz` + - `pto.copy_gm_to_cbuf_multi_dn2nz` + - `pto.copy_matrix_cc_to_cbuf` + - `pto.copy_matrix_cc_to_ub` + - `pto.copy_cbuf_to_bt` + - `pto.copy_cbuf_to_fbuf` + +### Verified Op List (Current Batch) + +- `pto.copy_cbuf_to_bt` +- `pto.copy_cbuf_to_fbuf` +- `pto.copy_gm_to_cbuf_multi_dn2nz` +- `pto.copy_gm_to_cbuf_multi_nd2nz` +- `pto.copy_matrix_cc_to_cbuf` +- `pto.copy_matrix_cc_to_ub` +- `pto.load_cbuf_to_ca_mx` +- `pto.load_cbuf_to_ca_s4` +- `pto.load_cbuf_to_cb_mx` +- `pto.load_cbuf_to_cb_s4` +- `pto.set_atomic_s32` +- `pto.set_atomic_s8` +- `pto.set_channel_para` +- `pto.set_fpc` +- `pto.set_loop1_stride_outtol1` +- `pto.set_loop2_stride_outtol1` +- `pto.set_loop3_para` +- `pto.set_loop_size_outtol1` +- `pto.set_mte2_nz_para` +- `pto.set_pad_val_outtol1` +- `pto.set_quant_pre` + --- ## Supported Data Types diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index ea72b8c75..b5ac18d41 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -183,6 +183,12 @@ class PTO_NullaryI64PureOp : PTO_Op { }]; } +class PTO_NullaryConfigOp : PTO_Op { + let arguments = (ins); + let results = (outs); + let assemblyFormat = [{ attr-dict }]; +} + def PTO_SetMovPadValOp : PTO_Op<"set_mov_pad_val"> { let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat], "integer/float scalar">:$value); @@ -197,13 +203,24 @@ def PTO_SetMovPadValOp : PTO_Op<"set_mov_pad_val"> { def PTO_SetLoop2StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_outtoub">; def PTO_SetLoop1StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_outtoub">; def PTO_SetLoopSizeOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop_size_outtoub">; +def PTO_SetLoop2StrideOutToL1Op : PTO_UnaryI64ConfigOp<"set_loop2_stride_outtol1">; +def PTO_SetLoop1StrideOutToL1Op : PTO_UnaryI64ConfigOp<"set_loop1_stride_outtol1">; +def PTO_SetLoopSizeOutToL1Op : PTO_UnaryI64ConfigOp<"set_loop_size_outtol1">; def PTO_SetLoop2StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_ubtoout">; def PTO_SetLoop1StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_ubtoout">; def PTO_SetLoopSizeUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop_size_ubtoout">; +def PTO_SetLoop3ParaOp : PTO_BinaryI64ConfigOp<"set_loop3_para">; +def PTO_SetChannelParaOp : PTO_BinaryI64ConfigOp<"set_channel_para">; +def PTO_SetMte2NzParaOp : PTO_UnaryI64ConfigOp<"set_mte2_nz_para">; +def PTO_SetPadValOutToL1Op : PTO_UnaryI64ConfigOp<"set_pad_val_outtol1">; def PTO_GetCtrlOp : PTO_NullaryI64PureOp<"get_ctrl">; def PTO_SetCtrlOp : PTO_UnaryI64ConfigOp<"set_ctrl">; def PTO_Sbitset0Op : PTO_BinaryI64PureOp<"sbitset0">; def PTO_Sbitset1Op : PTO_BinaryI64PureOp<"sbitset1">; +def PTO_SetQuantPreOp : PTO_UnaryI64ConfigOp<"set_quant_pre">; +def PTO_SetFpcOp : PTO_UnaryI64ConfigOp<"set_fpc">; +def PTO_SetAtomicS32Op : PTO_NullaryConfigOp<"set_atomic_s32">; +def PTO_SetAtomicS8Op : PTO_NullaryConfigOp<"set_atomic_s8">; def PTO_CopyGmToUbufOp : PTO_Op<"copy_gm_to_ubuf", [ DeclareOpInterfaceMethods @@ -328,6 +345,266 @@ def PTO_DmaCopyOp : PTO_Op<"dma_copy", [ }]; } +def PTO_CopyGmToCbufOp : PTO_Op<"copy_gm_to_cbuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $n_burst `,` $len_burst `,` $src_stride `,` $dst_stride + attr-dict `:` type($source) `,` type($destination) `,` type($n_burst) `,` type($len_burst) `,` type($src_stride) `,` type($dst_stride) + }]; +} + +def PTO_CopyGmToCbufMultiNd2NzOp : PTO_Op<"copy_gm_to_cbuf_multi_nd2nz"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$loop1_src_stride, + I64:$l2_cache_ctrl, + I64:$n_value, + I64:$d_value, + I64:$loop4_src_stride, + I1:$smallc0_en + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $loop1_src_stride `,` $l2_cache_ctrl `,` + $n_value `,` $d_value `,` $loop4_src_stride `,` $smallc0_en + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` + type($loop1_src_stride) `,` type($l2_cache_ctrl) `,` type($n_value) `,` + type($d_value) `,` type($loop4_src_stride) `,` type($smallc0_en) + }]; +} + +def PTO_CopyGmToCbufMultiDn2NzOp : PTO_Op<"copy_gm_to_cbuf_multi_dn2nz"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$loop1_src_stride, + I64:$l2_cache_ctrl, + I64:$n_value, + I64:$d_value, + I64:$loop4_src_stride, + I1:$smallc0_en + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $loop1_src_stride `,` $l2_cache_ctrl `,` + $n_value `,` $d_value `,` $loop4_src_stride `,` $smallc0_en + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` + type($loop1_src_stride) `,` type($l2_cache_ctrl) `,` type($n_value) `,` + type($d_value) `,` type($loop4_src_stride) `,` type($smallc0_en) + }]; +} + +def PTO_CopyCbufToBtOp : PTO_Op<"copy_cbuf_to_bt"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I1:$conv_control, + I64:$n_burst, + I64:$len_burst, + I64:$source_gap, + I64:$dst_gap + ); + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $conv_control `,` $n_burst `,` $len_burst `,` + $source_gap `,` $dst_gap + attr-dict `:` type($source) `,` type($destination) `,` type($conv_control) `,` + type($n_burst) `,` type($len_burst) `,` type($source_gap) `,` type($dst_gap) + }]; +} + +def PTO_CopyCbufToFbufOp : PTO_Op<"copy_cbuf_to_fbuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$n_burst, + I64:$len_burst, + I64:$source_gap, + I64:$dst_gap + ); + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $n_burst `,` $len_burst `,` + $source_gap `,` $dst_gap + attr-dict `:` type($source) `,` type($destination) `,` type($n_burst) `,` + type($len_burst) `,` type($source_gap) `,` type($dst_gap) + }]; +} + +def PTO_LoadCbufToCaOp : PTO_Op<"load_cbuf_to_ca"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m, + I64:$k + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m `,` $k attr-dict `:` type($source) `,` type($destination) `,` type($m) `,` type($k) + }]; +} + +def PTO_LoadCbufToCaS4Op : PTO_Op<"load_cbuf_to_ca_s4"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m_start, + I64:$k_start, + I64:$m_step, + I64:$k_step, + I64:$src_stride, + I64:$dst_stride, + I64:$transpose + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m_start `,` $k_start `,` $m_step `,` + $k_step `,` $src_stride `,` $dst_stride `,` $transpose + attr-dict `:` type($source) `,` type($destination) `,` type($m_start) `,` + type($k_start) `,` type($m_step) `,` type($k_step) `,` type($src_stride) `,` + type($dst_stride) `,` type($transpose) + }]; +} + +def PTO_LoadCbufToCbOp : PTO_Op<"load_cbuf_to_cb"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$k, + I64:$n + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $k `,` $n attr-dict `:` type($source) `,` type($destination) `,` type($k) `,` type($n) + }]; +} + +def PTO_LoadCbufToCbS4Op : PTO_Op<"load_cbuf_to_cb_s4"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m_start, + I64:$k_start, + I64:$m_step, + I64:$k_step, + I64:$src_stride, + I64:$dst_stride, + I64:$transpose + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m_start `,` $k_start `,` $m_step `,` + $k_step `,` $src_stride `,` $dst_stride `,` $transpose + attr-dict `:` type($source) `,` type($destination) `,` type($m_start) `,` + type($k_start) `,` type($m_step) `,` type($k_step) `,` type($src_stride) `,` + type($dst_stride) `,` type($transpose) + }]; +} + +def PTO_LoadCbufToCaMxOp : PTO_Op<"load_cbuf_to_ca_mx"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m, + I64:$k + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m `,` $k attr-dict `:` type($source) `,` type($destination) `,` type($m) `,` type($k) + }]; +} + +def PTO_LoadCbufToCbMxOp : PTO_Op<"load_cbuf_to_cb_mx"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$k, + I64:$n + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $k `,` $n attr-dict `:` type($source) `,` type($destination) `,` type($k) `,` type($n) + }]; +} + +def PTO_CopyMatrixCcToGmOp : PTO_Op<"copy_matrix_cc_to_gm"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m, + I64:$n + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m `,` $n attr-dict `:` type($source) `,` type($destination) `,` type($m) `,` type($n) + }]; +} + +def PTO_CopyMatrixCcToCbufOp : PTO_Op<"copy_matrix_cc_to_cbuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$config0, + I64:$config1 + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $config0 `,` $config1 attr-dict `:` type($source) `,` + type($destination) `,` type($config0) `,` type($config1) + }]; +} + +def PTO_CopyMatrixCcToUbOp : PTO_Op<"copy_matrix_cc_to_ub"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$config0, + I64:$config1 + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $config0 `,` $config1 attr-dict `:` type($source) `,` + type($destination) `,` type($config0) `,` type($config1) + }]; +} + def PTO_VldsOp : PTO_Op<"vlds", [ DeclareOpInterfaceMethods ]> { @@ -776,6 +1053,47 @@ def PTO_VnegOp : PTO_UnaryVecOp<"vneg">; def PTO_VreluOp : PTO_UnaryVecOp<"vrelu">; def PTO_VnotOp : PTO_UnaryVecOp<"vnot">; def PTO_VcaddOp : PTO_UnaryVecOp<"vcadd">; + +def PTO_MadOp : PTO_Op<"mad", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + I64:$m, + I64:$n, + I64:$k + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $dst `,` $m `,` $n `,` $k attr-dict `:` type($lhs) `,` type($rhs) `,` type($dst) `,` type($m) `,` type($n) `,` type($k) + }]; +} + +def PTO_MadMxOp : PTO_Op<"mad_mx", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + I64:$m, + I64:$n, + I64:$k + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $dst `,` $m `,` $n `,` $k attr-dict `:` type($lhs) `,` type($rhs) `,` type($dst) `,` type($m) `,` type($n) `,` type($k) + }]; +} + def PTO_VcmaxOp : PTO_UnaryVecOp<"vcmax">; def PTO_VcminOp : PTO_UnaryVecOp<"vcmin">; def PTO_VcgaddOp : PTO_UnaryVecOp<"vcgadd">; @@ -1292,7 +1610,6 @@ def PTO_DmaStoreOp : PTO_Op<"dma_store", [ "::std::optional<::mlir::pto::DmaLoopConfig>":$loop2 )> ]; - } // NOTE: Unvalidated new x2 / pair / align-store-family abstractions. Added to diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 88a454231..a8bb595b1 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -382,7 +382,14 @@ static LogicalResult dispatchVerifierByArch(Operation *op, FnA2A3 &&verifyA2A3, static std::optional parsePtrAddressSpaceKeyword(StringRef keyword) { return llvm::StringSwitch>(keyword) .Case("gm", pto::AddressSpace::GM) + .Case("mat", pto::AddressSpace::MAT) + .Case("left", pto::AddressSpace::LEFT) + .Case("right", pto::AddressSpace::RIGHT) + .Case("acc", pto::AddressSpace::ACC) + .Case("vec", pto::AddressSpace::VEC) .Case("ub", pto::AddressSpace::VEC) + .Case("bias", pto::AddressSpace::BIAS) + .Case("scaling", pto::AddressSpace::SCALING) .Default(std::nullopt); } @@ -391,8 +398,20 @@ static StringRef printPtrAddressSpaceKeyword(pto::AddressSpace space) { case pto::AddressSpace::GM: case pto::AddressSpace::Zero: return "gm"; + case pto::AddressSpace::MAT: + return "mat"; + case pto::AddressSpace::LEFT: + return "left"; + case pto::AddressSpace::RIGHT: + return "right"; + case pto::AddressSpace::ACC: + return "acc"; case pto::AddressSpace::VEC: return "ub"; + case pto::AddressSpace::BIAS: + return "bias"; + case pto::AddressSpace::SCALING: + return "scaling"; default: return {}; } @@ -514,7 +533,8 @@ static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { auto parsed = parsePtrAddressSpaceKeyword(memorySpaceKeyword); if (!parsed) { parser.emitError(parser.getCurrentLocation(), - "!pto.ptr address space must be `gm` or `ub`"); + "!pto.ptr address space must be one of " + "`gm|mat|left|right|acc|vec|ub|bias|scaling`"); return mlir::Type(); } memorySpace = pto::AddressSpaceAttr::get(ctx, *parsed); @@ -561,7 +581,8 @@ mlir::Type PtrType::parse(::mlir::AsmParser &parser) { auto parsed = parsePtrAddressSpaceKeyword(memorySpaceKeyword); if (!parsed) { parser.emitError(parser.getCurrentLocation(), - "!pto.ptr address space must be `gm` or `ub`"); + "!pto.ptr address space must be one of " + "`gm|mat|left|right|acc|vec|ub|bias|scaling`"); return {}; } memorySpace = pto::AddressSpaceAttr::get(parser.getContext(), *parsed); diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 10da9c322..93dd9a6e9 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -144,6 +144,8 @@ static bool isSupportedMovPadScalarType(Type type) { return false; } +static bool isMxElementType(Type type) { return isa(type); } + static std::optional getVdupMaskGranularity(Type elementType) { if (auto intType = dyn_cast(elementType)) { switch (intType.getWidth()) { @@ -1539,6 +1541,73 @@ LogicalResult SetMovPadValOp::verify() { << "expects i8/i16/i32 or f16/bf16/f32 scalar operand, but got " << valueType; } +void MadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +LogicalResult MadOp::verify() { + auto lhsType = dyn_cast(getLhs().getType()); + auto rhsType = dyn_cast(getRhs().getType()); + auto dstType = dyn_cast(getDst().getType()); + if (!lhsType || !rhsType || !dstType) + return emitOpError("requires typed !pto.ptr lhs/rhs/dst operands"); + + const auto lhsAS = lhsType.getMemorySpace().getAddressSpace(); + const auto rhsAS = rhsType.getMemorySpace().getAddressSpace(); + const auto dstAS = dstType.getMemorySpace().getAddressSpace(); + + // Keep legacy low-level VPTO syntax working (ub/vec for all operands), while + // also accepting strong cube spaces used by matmul pipelines. + const bool isLegacyUB = + lhsAS == pto::AddressSpace::VEC && rhsAS == pto::AddressSpace::VEC && + dstAS == pto::AddressSpace::VEC; + const bool isStrongCube = + lhsAS == pto::AddressSpace::LEFT && rhsAS == pto::AddressSpace::RIGHT && + dstAS == pto::AddressSpace::ACC; + if (!isLegacyUB && !isStrongCube) { + return emitOpError( + "requires either UB-backed lhs/rhs/dst pointers or " + "left/right/acc-typed lhs/rhs/dst pointers"); + } + + return success(); +} + +void MadMxOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +LogicalResult MadMxOp::verify() { + auto lhsType = dyn_cast(getLhs().getType()); + auto rhsType = dyn_cast(getRhs().getType()); + auto dstType = dyn_cast(getDst().getType()); + if (!lhsType || !rhsType || !dstType) + return emitOpError("requires typed !pto.ptr lhs/rhs/dst operands"); + + const auto lhsAS = lhsType.getMemorySpace().getAddressSpace(); + const auto rhsAS = rhsType.getMemorySpace().getAddressSpace(); + const auto dstAS = dstType.getMemorySpace().getAddressSpace(); + const bool isStrongCube = + lhsAS == pto::AddressSpace::LEFT && rhsAS == pto::AddressSpace::RIGHT && + dstAS == pto::AddressSpace::ACC; + if (!isStrongCube) + return emitOpError("requires left/right/acc-typed lhs/rhs/dst pointers"); + + if (!isMxElementType(lhsType.getElementType()) || + !isMxElementType(rhsType.getElementType())) { + return emitOpError( + "requires MX lhs/rhs element types (currently f8E4M3FN)"); + } + return success(); +} static bool isCompatibleScalarForSemanticType(Type semanticType, Type scalarType) { diff --git a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp index 1ae923dac..6eb78bd43 100644 --- a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp +++ b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp @@ -82,27 +82,32 @@ static std::string getCopyElementFragment(Type type) { if (!ptrType) return {}; Type elementType = ptrType.getElementType(); - if (auto floatType = dyn_cast(elementType)) { - switch ((floatType.getWidth() + 7) / 8) { - case 1: - return "u8"; - case 2: - return "u16"; - case 4: - case 8: - return "u32"; - default: - return {}; - } - } + if (elementType.isF16()) + return "f16"; + if (elementType.isBF16()) + return "bf16"; + if (elementType.isF32()) + return "f32"; + std::string typeText; + llvm::raw_string_ostream os(typeText); + elementType.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3")) + return "e4m3"; + if (StringRef(lower).contains("e5m2")) + return "e5m2"; + if (StringRef(lower).contains("e8m0")) + return "e8m0"; + if (StringRef(lower).contains("hif8")) + return "hif8"; if (auto intType = dyn_cast(elementType)) { - switch ((intType.getWidth() + 7) / 8) { - case 1: + switch (intType.getWidth()) { + case 8: return "u8"; - case 2: + case 16: return "u16"; - case 4: - case 8: + case 32: return "u32"; default: return {}; @@ -111,6 +116,112 @@ static std::string getCopyElementFragment(Type type) { return {}; } +static bool isMxElementType(Type type) { + if (auto floatType = dyn_cast(type)) + return floatType.getWidth() == 8; + std::string typeText; + llvm::raw_string_ostream os(typeText); + type.print(os); + os.flush(); + return StringRef(typeText).starts_with("f8"); +} + +static std::string getMadMxElementFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + + std::string typeText; + llvm::raw_string_ostream os(typeText); + type.print(os); + os.flush(); + + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3")) + return "e4m3"; + if (StringRef(lower).contains("e5m2")) + return "e5m2"; + if (StringRef(lower).contains("hif4")) + return "hif4"; + if (StringRef(lower).contains("e2m1x2")) + return "e2m1x2"; + if (StringRef(lower).contains("e1m2x2")) + return "e1m2x2"; + return {}; +} + +static std::string buildMadMxIntrinsicName(Type lhsType, Type rhsType) { + std::string lhs = getMadMxElementFragment(lhsType); + std::string rhs = getMadMxElementFragment(rhsType); + if (lhs.empty() || rhs.empty()) + return {}; + return "llvm.hivm.MMAD.MX." + lhs + rhs; +} + +static std::string getMadRhsFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) { + if (intType.isSigned() && intType.getWidth() == 4) + return "s4"; + if (intType.isSigned() && intType.getWidth() == 8) + return "s8"; + if (intType.isUnsigned() && intType.getWidth() == 2) + return "u2"; + } + + std::string typeText; + llvm::raw_string_ostream os(typeText); + type.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e8m0")) + return "e8m0"; + return {}; +} + +static std::string getMadDstFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) { + if (intType.isSigned() && intType.getWidth() == 32) + return "s32"; + } + return {}; +} + +static std::string buildMadIntrinsicName(Type lhsType, Type rhsType, + Type dstType) { + std::string rhs = getMadRhsFragment(rhsType); + std::string dst = getMadDstFragment(dstType); + if (lhsType.isF16() && rhs == "f16" && dst == "f32") + return "llvm.hivm.MAD.f162f32.c310"; + if (lhsType.isF16() && rhs == "f16" && dst == "f16") + return "llvm.hivm.MAD.f162f16"; + if (lhsType.isF16() && rhs == "f16" && dst == "s32") + return "llvm.hivm.MAD.f162s32.1952"; + if (lhsType.isBF16() && rhs == "bf16" && dst == "f32") + return "llvm.hivm.MAD.bf162f32.c310"; + if (lhsType.isF32() && rhs == "f32" && dst == "f32") + return "llvm.hivm.MAD.f322f32.c310"; + if (lhsType.isF16() && rhs == "s4") + return "llvm.hivm.MAD.f16s4.c310"; + if (lhsType.isF16() && rhs == "s8") + return "llvm.hivm.MAD.f16s8.c310"; + if (lhsType.isF16() && rhs == "u2") + return "llvm.hivm.MAD.f16u2"; + if (lhsType.isF16() && rhs == "e8m0") + return "llvm.hivm.MAD.f16e8m0.c310"; + return {}; +} + static std::string getOpMnemonic(Operation *op) { return op->getName().stripDialect().str(); } @@ -229,6 +340,14 @@ static FailureOr selectConfigLike(Operation *op) { usedFields, ""); if (isa(op)) return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.OUTTOUB", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOL1", usedFields, + ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOL1", usedFields, + ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.OUTTOL1", usedFields, ""); if (isa(op)) return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT", usedFields, ""); @@ -237,8 +356,18 @@ static FailureOr selectConfigLike(Operation *op) { ""); if (isa(op)) return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.MTE2.NZ.PARA", usedFields, ""); if (isa(op)) return makeResolved(op, "llvm.hivm.SET.MOV.PAD.VAL", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.PAD.VAL.OUTTOL1", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.FPC", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.ATOMIC.S32", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.ATOMIC.S8", usedFields, ""); llvm::SmallVector missingFields = {"confirmed_hivm_name"}; return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, @@ -567,6 +696,178 @@ FailureOr selectStoreIntrinsic(Operation *op) { missingFields, ""); } + if (auto copy = dyn_cast(op)) { + usedFields = {"family=copy_gm_to_cbuf_multi_nd2nz"}; + return makeResolved(op, "llvm.hivm.MOV.OUT.TO.L1.MULTI.ND2NZ", usedFields, + ""); + } + + if (auto copy = dyn_cast(op)) { + usedFields = {"family=copy_gm_to_cbuf_multi_dn2nz"}; + return makeResolved(op, "llvm.hivm.MOV.OUT.TO.L1.MULTI.DN2NZ", usedFields, + ""); + } + + if (auto matmul = dyn_cast(op)) { + std::string lhsElem = getElementTypeFragment( + cast(matmul.getLhs().getType()).getElementType()); + std::string rhsElem = getElementTypeFragment( + cast(matmul.getRhs().getType()).getElementType()); + std::string dstElem = getElementTypeFragment( + cast(matmul.getDst().getType()).getElementType()); + usedFields = {"family=mad", "lhs=" + lhsElem, "rhs=" + rhsElem, + "dst=" + dstElem, "shape=i64xm_n_k"}; + Type lhsType = cast(matmul.getLhs().getType()).getElementType(); + Type rhsType = cast(matmul.getRhs().getType()).getElementType(); + Type dstType = cast(matmul.getDst().getType()).getElementType(); + std::string madName = buildMadIntrinsicName(lhsType, rhsType, dstType); + if (!madName.empty()) + return makeResolved(op, madName, usedFields, ""); + if (isMxElementType(lhsType) && isMxElementType(rhsType)) { + std::string mxName = buildMadMxIntrinsicName(lhsType, rhsType); + if (!mxName.empty()) + return makeResolved(op, mxName, usedFields, ""); + } + missingFields.push_back("lhs/rhs_element_type_mapping"); + return makeUnresolved(op, "mad", "llvm.hivm.MAD/llvm.hivm.MMAD.MX.*", + usedFields, + missingFields, ""); + } + + if (auto matmulMx = dyn_cast(op)) { + std::string lhsElem = getElementTypeFragment( + cast(matmulMx.getLhs().getType()).getElementType()); + std::string rhsElem = getElementTypeFragment( + cast(matmulMx.getRhs().getType()).getElementType()); + std::string dstElem = getElementTypeFragment( + cast(matmulMx.getDst().getType()).getElementType()); + usedFields = {"family=mad_mx", "lhs=" + lhsElem, "rhs=" + rhsElem, + "dst=" + dstElem, "shape=i64xm_n_k"}; + Type lhsType = + cast(matmulMx.getLhs().getType()).getElementType(); + Type rhsType = + cast(matmulMx.getRhs().getType()).getElementType(); + if (isMxElementType(lhsType) && isMxElementType(rhsType)) { + std::string mxName = buildMadMxIntrinsicName(lhsType, rhsType); + if (!mxName.empty()) + return makeResolved(op, mxName, usedFields, ""); + } + missingFields.push_back("lhs/rhs_mx_element_type"); + return makeUnresolved(op, "mad_mx", "llvm.hivm.MMAD.MX.*", usedFields, + missingFields, ""); + } + + if (auto copy = dyn_cast(op)) { + std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); + usedFields = {"family=copy_gm_to_cbuf"}; + if (!elemFragment.empty()) + usedFields.push_back("element=" + elemFragment); + if (elemFragment.empty()) { + missingFields.push_back("element_type_mapping"); + return makeUnresolved(op, "copy_gm_to_cbuf", + "llvm.hivm.MOV.OUT.TO.L1.ALIGN.V2..DV", + usedFields, missingFields, ""); + } + return makeResolved(op, "llvm.hivm.MOV.OUT.TO.L1.ALIGN.V2." + elemFragment + + ".DV", + usedFields, ""); + } + + if (auto load = dyn_cast(op)) { + std::string srcElem = getElementTypeFragment( + cast(load.getSource().getType()).getElementType()); + std::string dstElem = getElementTypeFragment( + cast(load.getDestination().getType()).getElementType()); + usedFields = {"family=load_cbuf_to_ca", "src=" + srcElem, "dst=" + dstElem, + "shape=i64xm_k"}; + if (srcElem.empty()) { + missingFields.push_back("src_element_type_mapping"); + return makeUnresolved(op, "load_cbuf_to_ca", + "llvm.hivm.LOAD.L1.TO.L0A.2Dv2.", usedFields, + missingFields, ""); + } + return makeResolved(op, "llvm.hivm.LOAD.L1.TO.L0A.2Dv2." + srcElem, + usedFields, ""); + } + + if (auto load = dyn_cast(op)) { + std::string srcElem = getElementTypeFragment( + cast(load.getSource().getType()).getElementType()); + std::string dstElem = getElementTypeFragment( + cast(load.getDestination().getType()).getElementType()); + usedFields = {"family=load_cbuf_to_cb", "src=" + srcElem, "dst=" + dstElem, + "shape=i64xk_n"}; + if (srcElem.empty()) { + missingFields.push_back("src_element_type_mapping"); + return makeUnresolved(op, "load_cbuf_to_cb", + "llvm.hivm.LOAD.L1.TO.L0B.2Dv2.", usedFields, + missingFields, ""); + } + return makeResolved(op, "llvm.hivm.LOAD.L1.TO.L0B.2Dv2." + srcElem, + usedFields, ""); + } + + if (auto copy = dyn_cast(op)) { + std::string srcElem = getElementTypeFragment( + cast(copy.getSource().getType()).getElementType()); + std::string dstElem = getElementTypeFragment( + cast(copy.getDestination().getType()).getElementType()); + usedFields = {"family=copy_matrix_cc_to_gm", "src=" + srcElem, + "dst=" + dstElem, "shape=i64xm_n"}; + return makeResolved(op, "llvm.hivm.FIX.L0C.TO.OUT.f32.EXT", usedFields, ""); + } + + if (auto copy = dyn_cast(op)) { + std::string srcElem = getElementTypeFragment( + cast(copy.getSource().getType()).getElementType()); + std::string dstElem = getElementTypeFragment( + cast(copy.getDestination().getType()).getElementType()); + usedFields = {"family=copy_matrix_cc_to_cbuf", "src=" + srcElem, + "dst=" + dstElem}; + return makeResolved(op, "llvm.hivm.FIX.L0C.TO.L1.f32.EXT", usedFields, ""); + } + + if (auto copy = dyn_cast(op)) { + std::string srcElem = getElementTypeFragment( + cast(copy.getSource().getType()).getElementType()); + std::string dstElem = getElementTypeFragment( + cast(copy.getDestination().getType()).getElementType()); + usedFields = {"family=copy_matrix_cc_to_ub", "src=" + srcElem, + "dst=" + dstElem}; + if (dstElem == "f16") + return makeResolved(op, "llvm.hivm.MOV.L0CDPF32.TO.UB.f322f16", + usedFields, ""); + if (dstElem == "f32") + return makeResolved(op, "llvm.hivm.MOV.L0CDPF32.TO.UB.f322f32", + usedFields, ""); + missingFields.push_back("dst_element_type_mapping"); + return makeUnresolved(op, "copy_matrix_cc_to_ub", + "llvm.hivm.MOV.L0CDPF32.TO.UB.f322f{16|32}", + usedFields, missingFields, ""); + } + + if (auto copy = dyn_cast(op)) { + usedFields = {"family=copy_cbuf_to_bt", "src=f16"}; + return makeResolved(op, "llvm.hivm.MOV.L1.TO.BT.f16", usedFields, ""); + } + + if (auto copy = dyn_cast(op)) { + usedFields = {"family=copy_cbuf_to_fbuf"}; + return makeResolved(op, "llvm.hivm.MOV.L1.TO.FB.V2", usedFields, ""); + } + + if (auto load = dyn_cast(op)) { + usedFields = {"family=load_cbuf_to_ca_s4"}; + return makeResolved(op, "llvm.hivm.LOAD.L1.TO.L0A.2Dv2.s4", usedFields, + ""); + } + + if (auto load = dyn_cast(op)) { + usedFields = {"family=load_cbuf_to_cb_s4"}; + return makeResolved(op, "llvm.hivm.LOAD.L1.TO.L0B.2Dv2.s4", usedFields, + ""); + } + if (auto copy = dyn_cast(op)) { std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); usedFields = {"family=copy_ubuf_to_gm"}; @@ -590,9 +891,12 @@ FailureOr selectIntrinsic(Operation *op) { return selectSyncLike(op); if (isa(op)) + pto::SetLoopSizeOutToUbOp, pto::SetLoop2StrideOutToL1Op, + pto::SetLoop1StrideOutToL1Op, pto::SetLoopSizeOutToL1Op, + pto::SetLoop2StrideUbToOutOp, pto::SetLoop1StrideUbToOutOp, + pto::SetLoopSizeUbToOutOp, pto::SetMte2NzParaOp, + pto::SetMovPadValOp, pto::SetPadValOutToL1Op, pto::SetFpcOp, + pto::SetAtomicS32Op, pto::SetAtomicS8Op>(op)) return selectConfigLike(op); if (succeeded(selectLoadIntrinsic(op))) diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp index 172cade39..d17e7b6c8 100644 --- a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -198,8 +198,9 @@ class VPTOLegalityHelper { if (isa(op)) return VPTOBufferAddressFamily::Copy; - if (isa(op)) + if (isa(op)) return VPTOBufferAddressFamily::PtrOnly; if (isa( + loc, builder.getIntegerAttr(builder.getI1Type(), value ? 1 : 0)) + .getResult(); +} + +static bool isMxElementType(Type ty) { + if (auto floatType = dyn_cast(ty)) + return floatType.getWidth() == 8; + std::string typeText; + llvm::raw_string_ostream os(typeText); + ty.print(os); + os.flush(); + return StringRef(typeText).starts_with("f8"); +} + +static std::string getMadMxElementFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + + std::string typeText; + llvm::raw_string_ostream os(typeText); + type.print(os); + os.flush(); + + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3")) + return "e4m3"; + if (StringRef(lower).contains("e5m2")) + return "e5m2"; + if (StringRef(lower).contains("hif4")) + return "hif4"; + if (StringRef(lower).contains("e2m1x2")) + return "e2m1x2"; + if (StringRef(lower).contains("e1m2x2")) + return "e1m2x2"; + return {}; +} + +static FailureOr buildMadMxCalleeName(MLIRContext *context, + Type lhsElem, Type rhsElem) { + std::string lhs = getMadMxElementFragment(lhsElem); + std::string rhs = getMadMxElementFragment(rhsElem); + if (lhs.empty() || rhs.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.MMAD.MX." + lhs + rhs).getValue(); +} + +static std::string getMadRhsFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) { + if (intType.isSigned() && intType.getWidth() == 4) + return "s4"; + if (intType.isSigned() && intType.getWidth() == 8) + return "s8"; + if (intType.isUnsigned() && intType.getWidth() == 2) + return "u2"; + } + + std::string typeText; + llvm::raw_string_ostream os(typeText); + type.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e8m0")) + return "e8m0"; + return {}; +} + +static std::string getMadDstFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) { + if (intType.isSigned() && intType.getWidth() == 32) + return "s32"; + } + return {}; +} + +static FailureOr buildMadTypedCalleeName(MLIRContext *context, + Type lhsElem, Type rhsElem, + Type dstElem) { + std::string rhs = getMadRhsFragment(rhsElem); + std::string dst = getMadDstFragment(dstElem); + if (lhsElem.isF16() && rhs == "f16" && dst == "f32") + return StringAttr::get(context, "llvm.hivm.MAD.f162f32.c310").getValue(); + if (lhsElem.isF16() && rhs == "f16" && dst == "f16") + return StringAttr::get(context, "llvm.hivm.MAD.f162f16").getValue(); + if (lhsElem.isF16() && rhs == "f16" && dst == "s32") + return StringAttr::get(context, "llvm.hivm.MAD.f162s32.1952").getValue(); + if (lhsElem.isBF16() && rhs == "bf16" && dst == "f32") + return StringAttr::get(context, "llvm.hivm.MAD.bf162f32.c310").getValue(); + if (lhsElem.isF32() && rhs == "f32" && dst == "f32") + return StringAttr::get(context, "llvm.hivm.MAD.f322f32.c310").getValue(); + if (lhsElem.isF16() && rhs == "s4") + return StringAttr::get(context, "llvm.hivm.MAD.f16s4.c310").getValue(); + if (lhsElem.isF16() && rhs == "s8") + return StringAttr::get(context, "llvm.hivm.MAD.f16s8.c310").getValue(); + if (lhsElem.isF16() && rhs == "u2") + return StringAttr::get(context, "llvm.hivm.MAD.f16u2").getValue(); + if (lhsElem.isF16() && rhs == "e8m0") + return StringAttr::get(context, "llvm.hivm.MAD.f16e8m0.c310").getValue(); + return failure(); +} + static FailureOr buildLaneTypedCallee(MLIRContext *context, Type resultType, StringRef stem, @@ -260,6 +376,24 @@ static Value castIntegerLikeTo(Operation *anchor, Value value, Type targetType) return {}; } +static FailureOr reinterpretPointerToAddrSpace(Operation *anchor, + Value value, + unsigned targetAddressSpace) { + auto sourcePtrType = dyn_cast(value.getType()); + if (!sourcePtrType) + return failure(); + if (sourcePtrType.getAddressSpace() == targetAddressSpace) + return value; + + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + Value asInt = builder.create(loc, builder.getI64Type(), value); + Type targetPtrType = + LLVM::LLVMPointerType::get(anchor->getContext(), targetAddressSpace); + return builder.create(loc, targetPtrType, asInt).getResult(); +} + static FailureOr normalizeVdupScalarOperand(OpBuilder &builder, Location loc, pto::VdupOp op) { Value input = op.getInput(); @@ -318,6 +452,20 @@ static std::string getCopyElementFragment(Type elementType) { return "bf16"; if (elementType.isF32()) return "f32"; + // Handle FP8 family (e4m3/e5m2/e8m0/hif8) used by cube-matmul/mad_mx. + std::string typeText; + llvm::raw_string_ostream os(typeText); + elementType.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3")) + return "e4m3"; + if (StringRef(lower).contains("e5m2")) + return "e5m2"; + if (StringRef(lower).contains("e8m0")) + return "e8m0"; + if (StringRef(lower).contains("hif8")) + return "hif8"; if (auto intType = dyn_cast(elementType)) { switch (intType.getWidth()) { case 8: @@ -973,7 +1121,6 @@ static FailureOr packCopyUbToUbConfig(Operation *anchor, ValueRange operands) { if (operands.size() != 7) return failure(); - OpBuilder builder(anchor); builder.setInsertionPoint(anchor); Location loc = anchor->getLoc(); @@ -1004,999 +1151,2025 @@ packCopyUbToUbConfig(Operation *anchor, ValueRange operands) { return config; } -static FailureOr packVbitsortConfig(Operation *anchor, Value repeatTimes) { +static FailureOr +packCopyGmToCbufConfig0(Operation *anchor, Value nBurst, Value lenBurst) { OpBuilder builder(anchor); builder.setInsertionPoint(anchor); Location loc = anchor->getLoc(); - Value repeatI64 = castIntegerLikeTo(anchor, repeatTimes, builder.getI64Type()); - if (!repeatI64) + Value nBurstI64 = castIntegerLikeTo(anchor, nBurst, builder.getI64Type()); + Value lenBurstI64 = castIntegerLikeTo(anchor, lenBurst, builder.getI64Type()); + if (!nBurstI64 || !lenBurstI64) return failure(); - return builder - .create(loc, repeatI64, getI64Constant(builder, loc, 56)) - .getResult(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config0 = getI64Constant(builder, loc, 0); // sid + config0 = bitOr(config0, shl(nBurstI64, 4)); // burst_num[24:4] + config0 = bitOr(config0, shl(lenBurstI64, 25)); // burst_len[45:25] + return config0; } -static FailureOr convertElementOffsetToBytes(Operation *anchor, Value offset, - Type elementType) { +static FailureOr +packCopyGmToCbufConfig1(Operation *anchor, Value srcStride, + Value dstStride) { OpBuilder builder(anchor); builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); - Value offsetI32 = castIntegerLikeTo(anchor, offset, builder.getI32Type()); - if (!offsetI32) + Value srcStrideI64 = castIntegerLikeTo(anchor, srcStride, builder.getI64Type()); + Value dstStrideI64 = castIntegerLikeTo(anchor, dstStride, builder.getI64Type()); + if (!srcStrideI64 || !dstStrideI64) return failure(); - unsigned bitWidth = 0; - if (auto intType = dyn_cast(elementType)) - bitWidth = intType.getWidth(); - else if (auto floatType = dyn_cast(elementType)) - bitWidth = floatType.getWidth(); - if (bitWidth == 0 || bitWidth % 8 != 0) - return failure(); + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; - Value scale = builder.create( - anchor->getLoc(), builder.getI32IntegerAttr(bitWidth / 8)); - return builder.create(anchor->getLoc(), offsetI32, scale) - .getResult(); + // config1 packs burst_src_stride[39:0] and burst_dst_stride[60:40]. + return bitOr(srcStrideI64, shl(dstStrideI64, 40)); } -static FailureOr materializeDynamicPltMask(ConversionPatternRewriter &rewriter, - LoweringState &state, - Location loc, - Value laneCount, - Type vectorElemType) { - Type i32Type = rewriter.getI32Type(); - Value laneCountI32 = laneCount; - if (laneCountI32.getType() != i32Type) { - laneCountI32 = castIntegerLikeTo(rewriter.getInsertionBlock()->getParentOp(), - laneCountI32, i32Type); - if (!laneCountI32) - return failure(); - } +static FailureOr +packCopyGmToCbufMultiConfig0(Operation *anchor, Value sid, + Value loop1SrcStride, Value l2CacheCtl, + Value nValue) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); - StringRef calleeName; - if (vectorElemType.isF32()) { - calleeName = StringRef("llvm.hivm.plt.b32.v300"); - } else if (vectorElemType.isF16() || vectorElemType.isBF16()) { - calleeName = StringRef("llvm.hivm.plt.b16.v300"); - } else if (auto intType = dyn_cast(vectorElemType)) { - if (intType.getWidth() == 32) - calleeName = StringRef("llvm.hivm.plt.b32.v300"); - else if (intType.getWidth() == 16) - calleeName = StringRef("llvm.hivm.plt.b16.v300"); - else if (intType.getWidth() == 8) - calleeName = StringRef("llvm.hivm.plt.b8.v300"); - } - if (calleeName.empty()) + Value sidI64 = castIntegerLikeTo(anchor, sid, builder.getI64Type()); + Value loop1SrcStrideI64 = + castIntegerLikeTo(anchor, loop1SrcStride, builder.getI64Type()); + Value l2CacheCtlI64 = castIntegerLikeTo(anchor, l2CacheCtl, builder.getI64Type()); + Value nValueI64 = castIntegerLikeTo(anchor, nValue, builder.getI64Type()); + if (!sidI64 || !loop1SrcStrideI64 || !l2CacheCtlI64 || !nValueI64) return failure(); - Type maskType = VectorType::get({256}, rewriter.getI1Type()); - auto funcType = - rewriter.getFunctionType(TypeRange{i32Type}, TypeRange{maskType, i32Type}); - auto call = rewriter.create(loc, calleeName, funcType.getResults(), - ValueRange{laneCountI32}); - state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); - return call.getResult(0); -} + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; -static FailureOr buildCarryBinaryCallee(MLIRContext *context, - Type resultType, - StringRef stem) { - std::string vec = - getElementTypeFragment(cast(resultType).getElementType()); - auto lanes = getElementCountFromVectorLike(resultType); - if (vec.empty() || !lanes) - return failure(); - return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + - std::to_string(*lanes) + vec) - .getValue(); + Value config0 = sidI64; + config0 = bitOr(config0, shl(loop1SrcStrideI64, 4)); + config0 = bitOr(config0, shl(l2CacheCtlI64, 44)); + config0 = bitOr(config0, shl(nValueI64, 48)); + return config0; } -template -static StringRef getUnaryMaskedStem() { - if constexpr (std::is_same_v) - return "vabs"; - if constexpr (std::is_same_v) - return "vexp"; - if constexpr (std::is_same_v) - return "vln"; - if constexpr (std::is_same_v) - return "vneg"; - if constexpr (std::is_same_v) - return "vsqrt"; - if constexpr (std::is_same_v) - return "vrelu"; - if constexpr (std::is_same_v) - return "vnot"; - return {}; -} +static FailureOr +packCopyGmToCbufMultiConfig1(Operation *anchor, Value dValue, + Value loop4SrcStride, Value smallC0En) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -template -static StringRef getBinaryMaskedStem() { - if constexpr (std::is_same_v) - return "vadd"; - if constexpr (std::is_same_v) - return "vsub"; - if constexpr (std::is_same_v) - return "vmul"; - if constexpr (std::is_same_v) - return "vdiv"; - if constexpr (std::is_same_v) - return "vmax"; - if constexpr (std::is_same_v) - return "vmin"; - if constexpr (std::is_same_v) - return "vand"; - if constexpr (std::is_same_v) - return "vor"; - if constexpr (std::is_same_v) - return "vxor"; - if constexpr (std::is_same_v) - return "vshl"; - if constexpr (std::is_same_v) - return "vshr"; - return {}; -} + Value dValueI64 = castIntegerLikeTo(anchor, dValue, builder.getI64Type()); + Value loop4SrcStrideI64 = + castIntegerLikeTo(anchor, loop4SrcStride, builder.getI64Type()); + Value smallC0EnI64 = castIntegerLikeTo(anchor, smallC0En, builder.getI64Type()); + if (!dValueI64 || !loop4SrcStrideI64 || !smallC0EnI64) + return failure(); -template -static StringRef getCarryBinaryStem() { - if constexpr (std::is_same_v) - return "vaddc"; - if constexpr (std::is_same_v) - return "vsubc"; - if constexpr (std::is_same_v) - return "vaddcs"; - if constexpr (std::is_same_v) - return "vsubcs"; - return {}; -} + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; -template -static constexpr bool hasCarryInput() { - return std::is_same_v || - std::is_same_v; + Value config1 = dValueI64; + config1 = bitOr(config1, shl(loop4SrcStrideI64, 21)); + config1 = bitOr(config1, shl(smallC0EnI64, 61)); + return config1; } -static FailureOr buildVselCallee(MLIRContext *context, - Type resultType) { - std::string vec = - getElementTypeFragment(cast(resultType).getElementType()); - auto lanes = getElementCountFromVectorLike(resultType); - if (vec.empty() || !lanes) - return failure(); - return StringAttr::get(context, "llvm.hivm.vsel.v" + std::to_string(*lanes) + - vec) - .getValue(); -} +static FailureOr packCopyCbufToBtConfig(Operation *anchor, + Value convControl, + Value nBurst, Value lenBurst, + Value sourceGap, + Value dstGap) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -static FailureOr buildVselrCallee(MLIRContext *context, - Type resultType) { - Type elemType = getElementTypeFromVectorLike(resultType); - auto lanes = getElementCountFromVectorLike(resultType); - if (!elemType || !lanes) + Value convControlI64 = + castIntegerLikeTo(anchor, convControl, builder.getI64Type()); + Value nBurstI64 = castIntegerLikeTo(anchor, nBurst, builder.getI64Type()); + Value lenBurstI64 = castIntegerLikeTo(anchor, lenBurst, builder.getI64Type()); + Value sourceGapI64 = castIntegerLikeTo(anchor, sourceGap, builder.getI64Type()); + Value dstGapI64 = castIntegerLikeTo(anchor, dstGap, builder.getI64Type()); + if (!convControlI64 || !nBurstI64 || !lenBurstI64 || !sourceGapI64 || + !dstGapI64) return failure(); - std::string vec = getElementTypeFragment(elemType); - if (auto floatType = dyn_cast(elemType); - floatType && floatType.isF32()) - vec = "u32"; - if (vec.empty()) - return failure(); + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; - return StringAttr::get(context, "llvm.hivm.vselr.v" + std::to_string(*lanes) + - vec) - .getValue(); + Value config = shl(convControlI64, 3); + config = bitOr(config, shl(nBurstI64, 4)); + config = bitOr(config, shl(lenBurstI64, 16)); + config = bitOr(config, shl(sourceGapI64, 32)); + config = bitOr(config, shl(dstGapI64, 48)); + return config; } -static FailureOr buildVdupCallee(MLIRContext *context, pto::VdupOp op) { - Type inputType = op.getInput().getType(); - Type resultType = op.getResult().getType(); - std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); - auto lanes = getElementCountFromVectorLike(resultType); - if (vec.empty() || !lanes) - return failure(); - - if (isa(inputType)) { - StringRef position = op.getPosition().value_or("LOWEST"); - StringRef family = position == "HIGHEST" ? "vdupm" : "vdup"; - return StringAttr::get(context, "llvm.hivm." + family.str() + ".v" + - std::to_string(*lanes) + vec + ".z") - .getValue(); - } - - return StringAttr::get(context, "llvm.hivm.vdups.v" + std::to_string(*lanes) + - vec + ".z") - .getValue(); -} +static FailureOr packCopyCbufToFbufConfig(Operation *anchor, Value nBurst, + Value lenBurst, + Value sourceGap, + Value dstGap) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -static FailureOr buildVbrCallee(MLIRContext *context, - Type semanticElementType) { - std::string scalar = getVbrScalarFragment(semanticElementType); - if (scalar.empty()) + Value nBurstI64 = castIntegerLikeTo(anchor, nBurst, builder.getI64Type()); + Value lenBurstI64 = castIntegerLikeTo(anchor, lenBurst, builder.getI64Type()); + Value sourceGapI64 = castIntegerLikeTo(anchor, sourceGap, builder.getI64Type()); + Value dstGapI64 = castIntegerLikeTo(anchor, dstGap, builder.getI64Type()); + if (!nBurstI64 || !lenBurstI64 || !sourceGapI64 || !dstGapI64) return failure(); - return StringAttr::get(context, "llvm.hivm.vbr." + scalar + ".v300").getValue(); -} -static FailureOr buildPstuCallee(MLIRContext *context, pto::PstuOp op) { - if (auto maskType = dyn_cast(op.getValue().getType())) { - if (maskType.isB16()) - return StringAttr::get(context, "llvm.hivm.pstu.b16").getValue(); - if (maskType.isB32()) - return StringAttr::get(context, "llvm.hivm.pstu.b32").getValue(); - } - return failure(); -} + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; -static StringRef buildVstusCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.vstus").getValue(); + Value config = shl(nBurstI64, 4); + config = bitOr(config, shl(lenBurstI64, 16)); + config = bitOr(config, shl(sourceGapI64, 32)); + config = bitOr(config, shl(dstGapI64, 48)); + return config; } -static StringRef buildVsturCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.vstur").getValue(); -} +static FailureOr +packLoadCbufToS4Config0(Operation *anchor, Value mStart, Value kStart, + Value mStep, Value kStep) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -static StringRef buildInitAlignCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.init.vector.align.data").getValue(); -} + Value mStartI64 = castIntegerLikeTo(anchor, mStart, builder.getI64Type()); + Value kStartI64 = castIntegerLikeTo(anchor, kStart, builder.getI64Type()); + Value mStepI64 = castIntegerLikeTo(anchor, mStep, builder.getI64Type()); + Value kStepI64 = castIntegerLikeTo(anchor, kStep, builder.getI64Type()); + if (!mStartI64 || !kStartI64 || !mStepI64 || !kStepI64) + return failure(); -template -static StringRef buildRuntimeQueryCallee(MLIRContext *context); + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; -template <> -StringRef buildRuntimeQueryCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.GET.CTRL").getValue(); + Value config0 = mStartI64; + config0 = bitOr(config0, shl(kStartI64, 16)); + config0 = bitOr(config0, shl(mStepI64, 32)); + config0 = bitOr(config0, shl(kStepI64, 40)); + return config0; } -static StringRef buildSprclrCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.sprclr").getValue(); -} +static FailureOr +packLoadCbufToS4Config1(Operation *anchor, Value srcStride, Value dstStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -template -static StringRef buildUnaryConfigCallee(MLIRContext *context); + Value srcStrideI64 = castIntegerLikeTo(anchor, srcStride, builder.getI64Type()); + Value dstStrideI64 = castIntegerLikeTo(anchor, dstStride, builder.getI64Type()); + if (!srcStrideI64 || !dstStrideI64) + return failure(); -template <> -StringRef buildUnaryConfigCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.CTRL").getValue(); + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + return builder.create(loc, srcStrideI64, shl(dstStrideI64, 16)) + .getResult(); } -static StringRef buildVstarCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.vstar").getValue(); -} +static FailureOr +packCopyMatrixCcToGmConfig0(Operation *anchor, Value m, Value n) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -static StringRef buildVstasCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.vstas").getValue(); -} + Value mI64 = castIntegerLikeTo(anchor, m, builder.getI64Type()); + Value nI64 = castIntegerLikeTo(anchor, n, builder.getI64Type()); + if (!mI64 || !nI64) + return failure(); -template -static StringRef buildBinaryI64PureCallee(MLIRContext *context); + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; -template <> -StringRef buildBinaryI64PureCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SBITSET0").getValue(); -} + Value sid = getI64Constant(builder, loc, 0); + Value loopDstStride = nI64; -template <> -StringRef buildBinaryI64PureCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SBITSET1").getValue(); + Value config0 = sid; + config0 = bitOr(config0, shl(nI64, 4)); // n_size[15:4] + config0 = bitOr(config0, shl(mI64, 16)); // m_size[31:16] + config0 = bitOr(config0, shl(loopDstStride, 32)); // loop_dst_stride[63:32] + return config0; } -static FailureOr buildVldsPostCallee(MLIRContext *context, - Type resultType) { - std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); - auto lanes = getElementCountFromVectorLike(resultType); - if (vec.empty() || !lanes) - return failure(); - return StringAttr::get(context, "llvm.hivm.vldsx1.post.v" + - std::to_string(*lanes) + vec) - .getValue(); -} +static FailureOr +packCopyMatrixCcToGmConfig1(Operation *anchor, Value n) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); -static FailureOr buildVstsPostCallee(MLIRContext *context, - Type valueType) { - std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); - auto lanes = getElementCountFromVectorLike(valueType); - if (vec.empty() || !lanes) + Value nI64 = castIntegerLikeTo(anchor, n, builder.getI64Type()); + if (!nI64) return failure(); - return StringAttr::get(context, "llvm.hivm.vstsx1.post.v" + - std::to_string(*lanes) + vec) - .getValue(); -} -static StringRef buildVldasCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.vldas").getValue(); + // config1 currently enables the minimal default behavior: + // loop_src_stride = n, all control bits = 0. + return nI64; } -static FailureOr buildVldusCallee(MLIRContext *context, - Type resultType) { - std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); - auto lanes = getElementCountFromVectorLike(resultType); - if (vec.empty() || !lanes) - return failure(); - return StringAttr::get(context, "llvm.hivm.vldus.v" + - std::to_string(*lanes) + vec) - .getValue(); -} +static FailureOr +packLoadCbufToCaConfig0(Operation *anchor, Value m, Value k) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -static FailureOr buildVcmpCallee(MLIRContext *context, Type inputType, - StringRef cmpMode, - bool isScalarCompare) { - std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(inputType)); - if (elem.empty()) + Value mI64 = castIntegerLikeTo(anchor, m, builder.getI64Type()); + Value kI64 = castIntegerLikeTo(anchor, k, builder.getI64Type()); + if (!mI64 || !kI64) return failure(); - StringRef stem = isScalarCompare ? "vcmps" : "vcmp"; - return StringAttr::get(context, "llvm.hivm." + stem.str() + "." + - cmpMode.str() + "." + elem + ".z") - .getValue(); -} -template -static StringRef getVecScalarMaskedStem() { - if constexpr (std::is_same_v) - return "vmuls"; - if constexpr (std::is_same_v) - return "vadds"; - if constexpr (std::is_same_v) - return "vmaxs"; - if constexpr (std::is_same_v) - return "vmins"; - if constexpr (std::is_same_v) - return "vlrelu"; - if constexpr (std::is_same_v) - return "vshls"; - if constexpr (std::is_same_v) - return "vshrs"; - return {}; -} + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; -template -static StringRef getReductionUnaryStem() { - if constexpr (std::is_same_v) - return "vcadd"; - if constexpr (std::is_same_v) - return "vcmax"; - if constexpr (std::is_same_v) - return "vcmin"; - if constexpr (std::is_same_v) - return "vcgadd"; - if constexpr (std::is_same_v) - return "vcgmax"; - if constexpr (std::is_same_v) - return "vcgmin"; - if constexpr (std::is_same_v) - return "vcpadd"; - return {}; + Value mStart = getI64Constant(builder, loc, 0); + Value kStart = getI64Constant(builder, loc, 0); + Value mStep = mI64; + Value kStep = getI64Constant(builder, loc, 1); + + Value config0 = mStart; + config0 = bitOr(config0, shl(kStart, 16)); + config0 = bitOr(config0, shl(mStep, 32)); + config0 = bitOr(config0, shl(kStep, 40)); + return config0; } -static FailureOr buildCopyGmToUbCallee(MLIRContext *context, - Type sourceType) { - auto ptrType = dyn_cast(sourceType); - if (!ptrType) - return failure(); - Type elementType = ptrType.getElementType(); - std::string elem = getCopyElementFragment(elementType); - if (elem.empty()) +static FailureOr +packLoadCbufToCaConfig1(Operation *anchor, Value k) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value kI64 = castIntegerLikeTo(anchor, k, builder.getI64Type()); + if (!kI64) return failure(); - return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2." + elem + - ".DV") - .getValue(); -} -static StringRef buildCopyUbToGmCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV") - .getValue(); + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + return builder.create(loc, shl(kI64, 16), kI64).getResult(); } -static StringRef buildCopyUbToUbCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.UB.v310").getValue(); -} +static FailureOr +packLoadCbufToCbConfig0(Operation *anchor, Value k, Value n) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -static StringRef buildPstiCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.psti.b8").getValue(); -} + Value kI64 = castIntegerLikeTo(anchor, k, builder.getI64Type()); + Value nI64 = castIntegerLikeTo(anchor, n, builder.getI64Type()); + if (!kI64 || !nI64) + return failure(); -static StringRef buildPstsCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.psts.b8").getValue(); -} + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; -static StringRef buildPldiCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pldi.b8").getValue(); -} + Value mStart = getI64Constant(builder, loc, 0); + Value kStart = getI64Constant(builder, loc, 0); + Value mStep = kI64; + Value kStep = getI64Constant(builder, loc, 1); -static StringRef buildPldsCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.plds.b8").getValue(); + Value config0 = mStart; + config0 = bitOr(config0, shl(kStart, 16)); + config0 = bitOr(config0, shl(mStep, 32)); + config0 = bitOr(config0, shl(kStep, 40)); + return config0; } -static StringRef buildPnotCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pnot.z").getValue(); -} +static FailureOr +packLoadCbufToCbConfig1(Operation *anchor, Value n) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -static StringRef buildPselCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.psel").getValue(); -} + Value nI64 = castIntegerLikeTo(anchor, n, builder.getI64Type()); + if (!nI64) + return failure(); -static StringRef buildPandCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pand.z").getValue(); + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + return builder.create(loc, shl(nI64, 16), nI64).getResult(); } -static StringRef buildPorCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.por.z").getValue(); -} +static FailureOr +packMadConfig(Operation *anchor, Value m, Value n, Value k) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -static StringRef buildPxorCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pxor.z").getValue(); -} + Value mI64 = castIntegerLikeTo(anchor, m, builder.getI64Type()); + Value nI64 = castIntegerLikeTo(anchor, n, builder.getI64Type()); + Value kI64 = castIntegerLikeTo(anchor, k, builder.getI64Type()); + if (!mI64 || !nI64 || !kI64) + return failure(); -static StringRef buildPpackCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.ppack.z").getValue(); -} + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; -static StringRef buildPunpackCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.punpack").getValue(); + Value config = mI64; + config = bitOr(config, shl(kI64, 16)); + config = bitOr(config, shl(nI64, 32)); + config = bitOr(config, shl(getI64Constant(builder, loc, 0), 48)); // unitFlag + config = bitOr(config, shl(getI64Constant(builder, loc, 0), 56)); // gemvCtrl + config = bitOr(config, shl(getI64Constant(builder, loc, 0), 57)); // btBufCtrl + config = + bitOr(config, shl(getI64Constant(builder, loc, 1), 58)); // zeroCmatrixCtrl + return config; } -template -static StringRef buildPredicatePairReorderCallee(MLIRContext *context); +static FailureOr packVbitsortConfig(Operation *anchor, Value repeatTimes) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); -template <> -StringRef buildPredicatePairReorderCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pdintlv.b8").getValue(); + Value repeatI64 = castIntegerLikeTo(anchor, repeatTimes, builder.getI64Type()); + if (!repeatI64) + return failure(); + return builder + .create(loc, repeatI64, getI64Constant(builder, loc, 56)) + .getResult(); } -template <> -StringRef buildPredicatePairReorderCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pdintlv.b16").getValue(); -} +static FailureOr convertElementOffsetToBytes(Operation *anchor, Value offset, + Type elementType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); -template <> -StringRef buildPredicatePairReorderCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pdintlv.b32").getValue(); -} + Value offsetI32 = castIntegerLikeTo(anchor, offset, builder.getI32Type()); + if (!offsetI32) + return failure(); -template <> -StringRef buildPredicatePairReorderCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pintlv.b8").getValue(); -} + unsigned bitWidth = 0; + if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + else if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + if (bitWidth == 0 || bitWidth % 8 != 0) + return failure(); -template <> -StringRef buildPredicatePairReorderCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pintlv.b16").getValue(); + Value scale = builder.create( + anchor->getLoc(), builder.getI32IntegerAttr(bitWidth / 8)); + return builder.create(anchor->getLoc(), offsetI32, scale) + .getResult(); } -template <> -StringRef buildPredicatePairReorderCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pintlv.b32").getValue(); -} +static FailureOr materializeDynamicPltMask(ConversionPatternRewriter &rewriter, + LoweringState &state, + Location loc, + Value laneCount, + Type vectorElemType) { + Type i32Type = rewriter.getI32Type(); + Value laneCountI32 = laneCount; + if (laneCountI32.getType() != i32Type) { + laneCountI32 = castIntegerLikeTo(rewriter.getInsertionBlock()->getParentOp(), + laneCountI32, i32Type); + if (!laneCountI32) + return failure(); + } -static FailureOr buildInterleaveCallee(MLIRContext *context, - Type resultType, - StringRef stem) { - return buildLaneTypedCallee(context, resultType, stem, ""); + StringRef calleeName; + if (vectorElemType.isF32()) { + calleeName = StringRef("llvm.hivm.plt.b32.v300"); + } else if (vectorElemType.isF16() || vectorElemType.isBF16()) { + calleeName = StringRef("llvm.hivm.plt.b16.v300"); + } else if (auto intType = dyn_cast(vectorElemType)) { + if (intType.getWidth() == 32) + calleeName = StringRef("llvm.hivm.plt.b32.v300"); + else if (intType.getWidth() == 16) + calleeName = StringRef("llvm.hivm.plt.b16.v300"); + else if (intType.getWidth() == 8) + calleeName = StringRef("llvm.hivm.plt.b8.v300"); + } + if (calleeName.empty()) + return failure(); + + Type maskType = VectorType::get({256}, rewriter.getI1Type()); + auto funcType = + rewriter.getFunctionType(TypeRange{i32Type}, TypeRange{maskType, i32Type}); + auto call = rewriter.create(loc, calleeName, funcType.getResults(), + ValueRange{laneCountI32}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + return call.getResult(0); } -static FailureOr buildUnpackCallee(MLIRContext *context, - Type inputType, - Type resultType, - StringRef stem) { - std::string input = - getElementTypeFragment(getElementTypeFromVectorLike(inputType)); - std::string result = - getElementTypeFragment(getElementTypeFromVectorLike(resultType)); - if (input.empty() || result.empty()) +static FailureOr buildCarryBinaryCallee(MLIRContext *context, + Type resultType, + StringRef stem) { + std::string vec = + getElementTypeFragment(cast(resultType).getElementType()); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) return failure(); - return StringAttr::get(context, - "llvm.hivm." + stem.str() + "." + input + "2" + result) + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec) .getValue(); } -static FailureOr buildVpackCallee(MLIRContext *context, Type inputType, - Type resultType) { - std::string input = - getElementTypeFragment(getElementTypeFromVectorLike(inputType)); - std::string result = - getElementTypeFragment(getElementTypeFromVectorLike(resultType)); - if (input.empty() || result.empty()) - return failure(); +template +static StringRef getUnaryMaskedStem() { + if constexpr (std::is_same_v) + return "vabs"; + if constexpr (std::is_same_v) + return "vexp"; + if constexpr (std::is_same_v) + return "vln"; + if constexpr (std::is_same_v) + return "vneg"; + if constexpr (std::is_same_v) + return "vsqrt"; + if constexpr (std::is_same_v) + return "vrelu"; + if constexpr (std::is_same_v) + return "vnot"; + return {}; +} - return StringAttr::get(context, "llvm.hivm.vpack." + input + "2" + result + ".x") - .getValue(); +template +static StringRef getBinaryMaskedStem() { + if constexpr (std::is_same_v) + return "vadd"; + if constexpr (std::is_same_v) + return "vsub"; + if constexpr (std::is_same_v) + return "vmul"; + if constexpr (std::is_same_v) + return "vdiv"; + if constexpr (std::is_same_v) + return "vmax"; + if constexpr (std::is_same_v) + return "vmin"; + if constexpr (std::is_same_v) + return "vand"; + if constexpr (std::is_same_v) + return "vor"; + if constexpr (std::is_same_v) + return "vxor"; + if constexpr (std::is_same_v) + return "vshl"; + if constexpr (std::is_same_v) + return "vshr"; + return {}; } -static FailureOr buildVsqzCallee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vsqz", ".x.v300"); +template +static StringRef getCarryBinaryStem() { + if constexpr (std::is_same_v) + return "vaddc"; + if constexpr (std::is_same_v) + return "vsubc"; + if constexpr (std::is_same_v) + return "vaddcs"; + if constexpr (std::is_same_v) + return "vsubcs"; + return {}; } -static FailureOr buildVusqzCallee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vusqz", ".m"); +template +static constexpr bool hasCarryInput() { + return std::is_same_v || + std::is_same_v; } -static FailureOr buildVmulaCallee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vmula", ".m"); +static FailureOr buildVselCallee(MLIRContext *context, + Type resultType) { + std::string vec = + getElementTypeFragment(cast(resultType).getElementType()); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vsel.v" + std::to_string(*lanes) + + vec) + .getValue(); } -static FailureOr buildVmullCallee(MLIRContext *context, +static FailureOr buildVselrCallee(MLIRContext *context, Type resultType) { - return buildLaneTypedCallee(context, resultType, "vmull", ""); -} + Type elemType = getElementTypeFromVectorLike(resultType); + auto lanes = getElementCountFromVectorLike(resultType); + if (!elemType || !lanes) + return failure(); -template -static StringRef getPredicateStoreCallee(MLIRContext *context); + std::string vec = getElementTypeFragment(elemType); + if (auto floatType = dyn_cast(elemType); + floatType && floatType.isF32()) + vec = "u32"; + if (vec.empty()) + return failure(); -template <> -StringRef getPredicateStoreCallee(MLIRContext *context) { - return buildPstiCallee(context); -} - -template <> -StringRef getPredicateStoreCallee(MLIRContext *context) { - return buildPstsCallee(context); -} - -template -static StringRef getPredicateLoadCallee(MLIRContext *context); - -template <> -StringRef getPredicateLoadCallee(MLIRContext *context) { - return buildPldiCallee(context); -} - -template <> -StringRef getPredicateLoadCallee(MLIRContext *context) { - return buildPldsCallee(context); + return StringAttr::get(context, "llvm.hivm.vselr.v" + std::to_string(*lanes) + + vec) + .getValue(); } -template -static StringRef getPredicateMaskCallee(MLIRContext *context); +static FailureOr buildVdupCallee(MLIRContext *context, pto::VdupOp op) { + Type inputType = op.getInput().getType(); + Type resultType = op.getResult().getType(); + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); -template <> -StringRef getPredicateMaskCallee(MLIRContext *context) { - return buildPnotCallee(context); -} + if (isa(inputType)) { + StringRef position = op.getPosition().value_or("LOWEST"); + StringRef family = position == "HIGHEST" ? "vdupm" : "vdup"; + return StringAttr::get(context, "llvm.hivm." + family.str() + ".v" + + std::to_string(*lanes) + vec + ".z") + .getValue(); + } -template <> -StringRef getPredicateMaskCallee(MLIRContext *context) { - return buildPselCallee(context); + return StringAttr::get(context, "llvm.hivm.vdups.v" + std::to_string(*lanes) + + vec + ".z") + .getValue(); } -template <> -StringRef getPredicateMaskCallee(MLIRContext *context) { - return buildPandCallee(context); +static FailureOr buildVbrCallee(MLIRContext *context, + Type semanticElementType) { + std::string scalar = getVbrScalarFragment(semanticElementType); + if (scalar.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.vbr." + scalar + ".v300").getValue(); } -template <> -StringRef getPredicateMaskCallee(MLIRContext *context) { - return buildPorCallee(context); +static FailureOr buildPstuCallee(MLIRContext *context, pto::PstuOp op) { + if (auto maskType = dyn_cast(op.getValue().getType())) { + if (maskType.isB16()) + return StringAttr::get(context, "llvm.hivm.pstu.b16").getValue(); + if (maskType.isB32()) + return StringAttr::get(context, "llvm.hivm.pstu.b32").getValue(); + } + return failure(); } -template <> -StringRef getPredicateMaskCallee(MLIRContext *context) { - return buildPxorCallee(context); +static StringRef buildVstusCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstus").getValue(); } -template -static StringRef getPredicatePackCallee(MLIRContext *context); - -template <> -StringRef getPredicatePackCallee(MLIRContext *context) { - return buildPpackCallee(context); +static StringRef buildVsturCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstur").getValue(); } -template <> -StringRef getPredicatePackCallee(MLIRContext *context) { - return buildPunpackCallee(context); +static StringRef buildInitAlignCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.init.vector.align.data").getValue(); } -template -static StringRef buildPltCallee(MLIRContext *context); - -template <> -StringRef buildPltCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.plt.b8.v300").getValue(); -} +template +static StringRef buildRuntimeQueryCallee(MLIRContext *context); template <> -StringRef buildPltCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.plt.b16.v300").getValue(); +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.CTRL").getValue(); } -template <> -StringRef buildPltCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.plt.b32.v300").getValue(); +static StringRef buildSprclrCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.sprclr").getValue(); } -template -static StringRef buildPsetCallee(MLIRContext *context); +template +static StringRef buildUnaryConfigCallee(MLIRContext *context); template <> -StringRef buildPsetCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pset.b8").getValue(); +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.CTRL").getValue(); } -template <> -StringRef buildPsetCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pset.b16").getValue(); +static StringRef buildVstarCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstar").getValue(); } -template <> -StringRef buildPsetCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pset.b32").getValue(); +static StringRef buildVstasCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstas").getValue(); } -template -static StringRef buildPgeCallee(MLIRContext *context); - -template <> -StringRef buildPgeCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pge.b8").getValue(); -} +template +static StringRef buildBinaryI64PureCallee(MLIRContext *context); template <> -StringRef buildPgeCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pge.b16").getValue(); +StringRef buildBinaryI64PureCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SBITSET0").getValue(); } template <> -StringRef buildPgeCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.pge.b32").getValue(); +StringRef buildBinaryI64PureCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SBITSET1").getValue(); } -static FailureOr buildVldsCallee(MLIRContext *context, Type resultType) { +static FailureOr buildVldsPostCallee(MLIRContext *context, + Type resultType) { std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); auto lanes = getElementCountFromVectorLike(resultType); if (vec.empty() || !lanes) return failure(); - return StringAttr::get(context, "llvm.hivm.vldsx1.v" + std::to_string(*lanes) + - vec) + return StringAttr::get(context, "llvm.hivm.vldsx1.post.v" + + std::to_string(*lanes) + vec) .getValue(); } -static FailureOr buildVldsx2Callee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vldsx2", ""); -} - -static StringRef buildVsldbCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.vsldb").getValue(); -} - -static FailureOr buildVstsCallee(MLIRContext *context, Type valueType) { +static FailureOr buildVstsPostCallee(MLIRContext *context, + Type valueType) { std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); auto lanes = getElementCountFromVectorLike(valueType); if (vec.empty() || !lanes) return failure(); - return StringAttr::get(context, "llvm.hivm.vstsx1.v" + std::to_string(*lanes) + - vec) + return StringAttr::get(context, "llvm.hivm.vstsx1.post.v" + + std::to_string(*lanes) + vec) .getValue(); } -static FailureOr buildVstsx2Callee(MLIRContext *context, Type valueType) { - return buildLaneTypedCallee(context, valueType, "vstsx2", ""); -} - -static StringRef buildVsstbCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.vsstb").getValue(); +static StringRef buildVldasCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vldas").getValue(); } -static FailureOr buildVgather2Callee(MLIRContext *context, - Type resultType) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(resultType)); +static FailureOr buildVldusCallee(MLIRContext *context, + Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); auto lanes = getElementCountFromVectorLike(resultType); if (vec.empty() || !lanes) return failure(); - return StringAttr::get(context, "llvm.hivm.vgather2.v300.v" + + return StringAttr::get(context, "llvm.hivm.vldus.v" + std::to_string(*lanes) + vec) .getValue(); } -static FailureOr buildVgather2BcCallee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); +static FailureOr buildVcmpCallee(MLIRContext *context, Type inputType, + StringRef cmpMode, + bool isScalarCompare) { + std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + if (elem.empty()) + return failure(); + StringRef stem = isScalarCompare ? "vcmps" : "vcmp"; + return StringAttr::get(context, "llvm.hivm." + stem.str() + "." + + cmpMode.str() + "." + elem + ".z") + .getValue(); } -static FailureOr buildVgatherbCallee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vgatherb.v310", ""); +template +static StringRef getVecScalarMaskedStem() { + if constexpr (std::is_same_v) + return "vmuls"; + if constexpr (std::is_same_v) + return "vadds"; + if constexpr (std::is_same_v) + return "vmaxs"; + if constexpr (std::is_same_v) + return "vmins"; + if constexpr (std::is_same_v) + return "vlrelu"; + if constexpr (std::is_same_v) + return "vshls"; + if constexpr (std::is_same_v) + return "vshrs"; + return {}; } -static FailureOr buildVscatterCallee(MLIRContext *context, - Type valueType) { - return buildLaneTypedCallee(context, valueType, "vscatter", ".v300"); +template +static StringRef getReductionUnaryStem() { + if constexpr (std::is_same_v) + return "vcadd"; + if constexpr (std::is_same_v) + return "vcmax"; + if constexpr (std::is_same_v) + return "vcmin"; + if constexpr (std::is_same_v) + return "vcgadd"; + if constexpr (std::is_same_v) + return "vcgmax"; + if constexpr (std::is_same_v) + return "vcgmin"; + if constexpr (std::is_same_v) + return "vcpadd"; + return {}; } -static FailureOr buildVpreluCallee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vprelu", ".x"); +static FailureOr buildCopyGmToUbCallee(MLIRContext *context, + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + Type elementType = ptrType.getElementType(); + std::string elem = getCopyElementFragment(elementType); + if (elem.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2." + elem + + ".DV") + .getValue(); } -static FailureOr buildVaxpyCallee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vaxpy", ".m"); +static StringRef buildCopyUbToGmCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV") + .getValue(); } -static FailureOr buildVciCallee(MLIRContext *context, Type resultType) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(resultType)); - auto lanes = getElementCountFromVectorLike(resultType); - if (vec.empty() || !lanes) +static StringRef buildCopyUbToUbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.UB.v310").getValue(); +} + +static FailureOr buildMadCallee(MLIRContext *context, + pto::MadOp op) { + auto lhsType = dyn_cast(op.getLhs().getType()); + auto rhsType = dyn_cast(op.getRhs().getType()); + auto dstType = dyn_cast(op.getDst().getType()); + if (!lhsType || !rhsType || !dstType) return failure(); - if (vec == "f16" || vec == "f32") - return StringAttr::get(context, "llvm.hivm.vci.v" + std::to_string(*lanes) + - vec + "." + vec) - .getValue(); - return StringAttr::get(context, - "llvm.hivm.vci.v" + std::to_string(*lanes) + vec) - .getValue(); + + Type lhsElem = lhsType.getElementType(); + Type rhsElem = rhsType.getElementType(); + Type dstElem = dstType.getElementType(); + if (auto typed = buildMadTypedCalleeName(context, lhsElem, rhsElem, dstElem); + succeeded(typed)) + return typed; + if (isMxElementType(lhsElem) && isMxElementType(rhsElem)) + return buildMadMxCalleeName(context, lhsElem, rhsElem); + return failure(); } -static FailureOr buildVtrcCallee(MLIRContext *context, Type resultType) { - std::string vec = - getElementTypeFragment(getElementTypeFromVectorLike(resultType)); - auto lanes = getElementCountFromVectorLike(resultType); - if (vec.empty() || !lanes) +static FailureOr buildMadMxCallee(MLIRContext *context, + pto::MadMxOp op) { + auto lhsType = dyn_cast(op.getLhs().getType()); + auto rhsType = dyn_cast(op.getRhs().getType()); + if (!lhsType || !rhsType) return failure(); - return StringAttr::get(context, "llvm.hivm.vtrc." + vec + ".x").getValue(); + if (isMxElementType(lhsType.getElementType()) && + isMxElementType(rhsType.getElementType())) { + return buildMadMxCalleeName(context, lhsType.getElementType(), + rhsType.getElementType()); + } + return failure(); } -static FailureOr buildVexpdifCallee(MLIRContext *context, - Type inputType, - Type resultType) { - std::string srcVec = - getElementTypeFragment(getElementTypeFromVectorLike(inputType)); - auto srcLanes = getElementCountFromVectorLike(inputType); - std::string dstElem = - getElementTypeFragment(getElementTypeFromVectorLike(resultType)); - if (srcVec.empty() || dstElem.empty() || !srcLanes) +static FailureOr buildCopyGmToCbufCallee(MLIRContext *context, + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) return failure(); - return StringAttr::get(context, "llvm.hivm.vexpdif.v" + - std::to_string(*srcLanes) + srcVec + - dstElem) + std::string elem = getCopyElementFragment(ptrType.getElementType()); + if (elem.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.L1.ALIGN.V2." + elem + + ".DV") .getValue(); } -static FailureOr buildVbitsortCallee(MLIRContext *context, - pto::VbitsortOp op) { - Type sourceElemType = cast(op.getSource().getType()).getElementType(); - if (sourceElemType.isF16()) - return StringAttr::get(context, "llvm.hivm.VBS32.V300.f16").getValue(); - if (sourceElemType.isF32()) - return StringAttr::get(context, "llvm.hivm.VBS32.V300.f32").getValue(); - return failure(); +static StringRef buildCopyGmToCbufMultiNd2NzCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.L1.MULTI.ND2NZ") + .getValue(); } -static FailureOr buildVcvtContract(pto::VcvtOp op) { - Type inputElemType = getElementTypeFromVectorLike(op.getInput().getType()); - Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); - if (!inputElemType || !resultElemType) +static StringRef buildCopyGmToCbufMultiDn2NzCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.L1.MULTI.DN2NZ") + .getValue(); +} + +static FailureOr buildLoadCbufToCaCallee(MLIRContext *context, + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) return failure(); - auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), - classifyVcvtElemType(resultElemType)); - if (!contract) + std::string elem = getElementTypeFragment(ptrType.getElementType()); + if (elem.empty()) return failure(); - return *contract; + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0A.2Dv2." + elem) + .getValue(); } -template -static StringRef buildSetLoopCallee(MLIRContext *context); +static StringRef buildLoadCbufToCaS4Callee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0A.2Dv2.s4") + .getValue(); +} -template -static StringRef buildUnaryConfigCallee(MLIRContext *context); +static FailureOr buildLoadCbufToCbCallee(MLIRContext *context, + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + std::string elem = getElementTypeFragment(ptrType.getElementType()); + if (elem.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0B.2Dv2." + elem) + .getValue(); +} -template <> -StringRef buildSetLoopCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB") +static StringRef buildLoadCbufToCbS4Callee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0B.2Dv2.s4") .getValue(); } -template <> -StringRef buildSetLoopCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB") +static StringRef buildLoadCbufToCaMxCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0A.MX.2Dv2.v") .getValue(); } -template <> -StringRef buildSetLoopCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.OUTTOUB") +static StringRef buildLoadCbufToCbMxCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0B.MX.2Dv2.v") .getValue(); } -template <> -StringRef buildSetLoopCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT") +static StringRef buildCopyMatrixCcToGmCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.FIX.L0C.TO.OUT.f32.EXT") .getValue(); } -template <> -StringRef buildSetLoopCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT") +static StringRef buildCopyMatrixCcToCbufCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.FIX.L0C.TO.L1.f32.EXT") .getValue(); } -template <> -StringRef buildSetLoopCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT") +static FailureOr buildCopyMatrixCcToUbCallee(MLIRContext *context, + Type destinationType) { + auto ptrType = dyn_cast(destinationType); + if (!ptrType) + return failure(); + Type dstElem = ptrType.getElementType(); + if (dstElem.isF16()) + return StringAttr::get(context, "llvm.hivm.MOV.L0CDPF32.TO.UB.f322f16") + .getValue(); + if (dstElem.isF32()) + return StringAttr::get(context, "llvm.hivm.MOV.L0CDPF32.TO.UB.f322f32") + .getValue(); + return failure(); +} + +static StringRef buildCopyCbufToBtCallee(pto::CopyCbufToBtOp op) { + return StringAttr::get(op.getContext(), "llvm.hivm.MOV.L1.TO.BT.f16") .getValue(); } -template <> -StringRef buildUnaryConfigCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.MOV.PAD.VAL").getValue(); +static StringRef buildCopyCbufToFbufCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.L1.TO.FB.V2").getValue(); } -static FailureOr encodeMovPadValue(Location loc, Value value, - ConversionPatternRewriter &rewriter) { - Type type = value.getType(); - Value payload = value; - unsigned bitWidth = 0; +static StringRef buildPstiCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psti.b8").getValue(); +} - if (auto intType = dyn_cast(type)) { - bitWidth = intType.getWidth(); - } else if (auto floatType = dyn_cast(type)) { - bitWidth = floatType.getWidth(); - auto intType = rewriter.getIntegerType(bitWidth); - payload = rewriter.create(loc, intType, value); - } else { - return failure(); - } +static StringRef buildPstsCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psts.b8").getValue(); +} - if (bitWidth != 8 && bitWidth != 16 && bitWidth != 32) - return failure(); +static StringRef buildPldiCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pldi.b8").getValue(); +} - return rewriter.create(loc, rewriter.getI64Type(), payload) - .getResult(); +static StringRef buildPldsCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plds.b8").getValue(); } -template -static StringRef buildSyncCallee(MLIRContext *context); +static StringRef buildPnotCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pnot.z").getValue(); +} -template <> -StringRef buildSyncCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.SET.FLAG.IMM").getValue(); +static StringRef buildPselCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psel").getValue(); } -template <> -StringRef buildSyncCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.WAIT.FLAG.IMM").getValue(); +static StringRef buildPandCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pand.z").getValue(); } -template <> -StringRef buildSyncCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.BARRIER").getValue(); +static StringRef buildPorCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.por.z").getValue(); } -static StringRef buildMemBarCallee(MemBarKind kind, MLIRContext *context) { - switch (kind) { - case MemBarKind::VV_ALL: - return StringAttr::get(context, "llvm.hivm.mem.bar.vv.all").getValue(); - case MemBarKind::VST_VLD: - return StringAttr::get(context, "llvm.hivm.mem.bar.vst.vld").getValue(); - case MemBarKind::VLD_VST: - return StringAttr::get(context, "llvm.hivm.mem.bar.vld.vst").getValue(); - case MemBarKind::VST_VST: - return StringAttr::get(context, "llvm.hivm.mem.bar.vst.vst").getValue(); - case MemBarKind::VS_ALL: - return StringAttr::get(context, "llvm.hivm.mem.bar.vs.all").getValue(); - case MemBarKind::VST_LD: - return StringAttr::get(context, "llvm.hivm.mem.bar.vst.ld").getValue(); - case MemBarKind::VLD_ST: - return StringAttr::get(context, "llvm.hivm.mem.bar.vld.st").getValue(); - case MemBarKind::VST_ST: - return StringAttr::get(context, "llvm.hivm.mem.bar.vst.st").getValue(); - case MemBarKind::SV_ALL: - return StringAttr::get(context, "llvm.hivm.mem.bar.sv.all").getValue(); - case MemBarKind::ST_VLD: - return StringAttr::get(context, "llvm.hivm.mem.bar.st.vld").getValue(); - case MemBarKind::LD_VST: - return StringAttr::get(context, "llvm.hivm.mem.bar.ld.vst").getValue(); - case MemBarKind::ST_VST: - return StringAttr::get(context, "llvm.hivm.mem.bar.st.vst").getValue(); - case MemBarKind::SS_ALL: - return StringAttr::get(context, "llvm.hivm.mem.bar.ss.all").getValue(); - case MemBarKind::ST_LD: - return StringAttr::get(context, "llvm.hivm.mem.bar.st.ld").getValue(); - case MemBarKind::LD_ST: - return StringAttr::get(context, "llvm.hivm.mem.bar.ld.st").getValue(); - case MemBarKind::ST_ST: - return StringAttr::get(context, "llvm.hivm.mem.bar.st.st").getValue(); - } - llvm_unreachable("unexpected membar kind"); +static StringRef buildPxorCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pxor.z").getValue(); +} + +static StringRef buildPpackCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.ppack.z").getValue(); +} + +static StringRef buildPunpackCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.punpack").getValue(); } +template +static StringRef buildPredicatePairReorderCallee(MLIRContext *context); + template <> -StringRef buildSyncCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.GET.BUFI.mode").getValue(); +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b8").getValue(); } template <> -StringRef buildSyncCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.RLS.BUFI.mode").getValue(); +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b16").getValue(); } -template -static StringRef buildRuntimeQueryCallee(MLIRContext *context); +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b32").getValue(); +} template <> -StringRef buildRuntimeQueryCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.GET.BLOCK.IDX").getValue(); +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b8").getValue(); } template <> -StringRef buildRuntimeQueryCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKID").getValue(); +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b16").getValue(); } template <> -StringRef buildRuntimeQueryCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.GET.BLOCK.NUM").getValue(); +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b32").getValue(); +} + +static FailureOr buildInterleaveCallee(MLIRContext *context, + Type resultType, + StringRef stem) { + return buildLaneTypedCallee(context, resultType, stem, ""); +} + +static FailureOr buildUnpackCallee(MLIRContext *context, + Type inputType, + Type resultType, + StringRef stem) { + std::string input = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + std::string result = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (input.empty() || result.empty()) + return failure(); + return StringAttr::get(context, + "llvm.hivm." + stem.str() + "." + input + "2" + result) + .getValue(); } -template <> -StringRef buildRuntimeQueryCallee(MLIRContext *context) { - return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKDIM").getValue(); -} +static FailureOr buildVpackCallee(MLIRContext *context, Type inputType, + Type resultType) { + std::string input = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + std::string result = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (input.empty() || result.empty()) + return failure(); + + return StringAttr::get(context, "llvm.hivm.vpack." + input + "2" + result + ".x") + .getValue(); +} + +static FailureOr buildVsqzCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vsqz", ".x.v300"); +} + +static FailureOr buildVusqzCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vusqz", ".m"); +} + +static FailureOr buildVmulaCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vmula", ".m"); +} + +static FailureOr buildVmullCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vmull", ""); +} + +template +static StringRef getPredicateStoreCallee(MLIRContext *context); + +template <> +StringRef getPredicateStoreCallee(MLIRContext *context) { + return buildPstiCallee(context); +} + +template <> +StringRef getPredicateStoreCallee(MLIRContext *context) { + return buildPstsCallee(context); +} + +template +static StringRef getPredicateLoadCallee(MLIRContext *context); + +template <> +StringRef getPredicateLoadCallee(MLIRContext *context) { + return buildPldiCallee(context); +} + +template <> +StringRef getPredicateLoadCallee(MLIRContext *context) { + return buildPldsCallee(context); +} + +template +static StringRef getPredicateMaskCallee(MLIRContext *context); + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPnotCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPselCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPandCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPorCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPxorCallee(context); +} + +template +static StringRef getPredicatePackCallee(MLIRContext *context); + +template <> +StringRef getPredicatePackCallee(MLIRContext *context) { + return buildPpackCallee(context); +} + +template <> +StringRef getPredicatePackCallee(MLIRContext *context) { + return buildPunpackCallee(context); +} + +template +static StringRef buildPltCallee(MLIRContext *context); + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b8.v300").getValue(); +} + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b16.v300").getValue(); +} + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b32.v300").getValue(); +} + +template +static StringRef buildPsetCallee(MLIRContext *context); + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b8").getValue(); +} + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b16").getValue(); +} + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b32").getValue(); +} + +template +static StringRef buildPgeCallee(MLIRContext *context); + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b8").getValue(); +} + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b16").getValue(); +} + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b32").getValue(); +} + +static FailureOr buildVldsCallee(MLIRContext *context, Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldsx1.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVldsx2Callee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vldsx2", ""); +} + +static StringRef buildVsldbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsldb").getValue(); +} + +static FailureOr buildVstsCallee(MLIRContext *context, Type valueType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); + auto lanes = getElementCountFromVectorLike(valueType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vstsx1.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVstsx2Callee(MLIRContext *context, Type valueType) { + return buildLaneTypedCallee(context, valueType, "vstsx2", ""); +} + +static StringRef buildVsstbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsstb").getValue(); +} + +static FailureOr buildVgather2Callee(MLIRContext *context, + Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vgather2.v300.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVgather2BcCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); +} + +static FailureOr buildVgatherbCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vgatherb.v310", ""); +} + +static FailureOr buildVscatterCallee(MLIRContext *context, + Type valueType) { + return buildLaneTypedCallee(context, valueType, "vscatter", ".v300"); +} + +static FailureOr buildVpreluCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vprelu", ".x"); +} + +static FailureOr buildVaxpyCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vaxpy", ".m"); +} + +static FailureOr buildVciCallee(MLIRContext *context, Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + if (vec == "f16" || vec == "f32") + return StringAttr::get(context, "llvm.hivm.vci.v" + std::to_string(*lanes) + + vec + "." + vec) + .getValue(); + return StringAttr::get(context, + "llvm.hivm.vci.v" + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVtrcCallee(MLIRContext *context, Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vtrc." + vec + ".x").getValue(); +} + +static FailureOr buildVexpdifCallee(MLIRContext *context, + Type inputType, + Type resultType) { + std::string srcVec = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + auto srcLanes = getElementCountFromVectorLike(inputType); + std::string dstElem = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (srcVec.empty() || dstElem.empty() || !srcLanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vexpdif.v" + + std::to_string(*srcLanes) + srcVec + + dstElem) + .getValue(); +} + +static FailureOr buildVbitsortCallee(MLIRContext *context, + pto::VbitsortOp op) { + Type sourceElemType = cast(op.getSource().getType()).getElementType(); + if (sourceElemType.isF16()) + return StringAttr::get(context, "llvm.hivm.VBS32.V300.f16").getValue(); + if (sourceElemType.isF32()) + return StringAttr::get(context, "llvm.hivm.VBS32.V300.f32").getValue(); + return failure(); +} + +static FailureOr buildVcvtContract(pto::VcvtOp op) { + Type inputElemType = getElementTypeFromVectorLike(op.getInput().getType()); + Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!inputElemType || !resultElemType) + return failure(); + auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), + classifyVcvtElemType(resultElemType)); + if (!contract) + return failure(); + return *contract; +} + +template +static StringRef buildSetLoopCallee(MLIRContext *context); + +template +static StringRef buildUnaryConfigCallee(MLIRContext *context); + +template +static StringRef buildNullaryConfigCallee(MLIRContext *context); + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP3.PARA").getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.CHANNEL.PARA").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.MOV.PAD.VAL").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.QUANT.PRE.v300").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOL1") + .getValue(); +} + +template <> +StringRef buildUnaryConfigCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOL1") + .getValue(); +} + +template <> +StringRef buildUnaryConfigCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.OUTTOL1") + .getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.MTE2.NZ.PARA").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.PAD.VAL.OUTTOL1") + .getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.FPC").getValue(); +} + +template <> +StringRef buildNullaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.ATOMIC.S32").getValue(); +} + +template <> +StringRef buildNullaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.ATOMIC.S8").getValue(); +} + +static FailureOr encodeMovPadValue(Location loc, Value value, + ConversionPatternRewriter &rewriter) { + Type type = value.getType(); + Value payload = value; + unsigned bitWidth = 0; + + if (auto intType = dyn_cast(type)) { + bitWidth = intType.getWidth(); + } else if (auto floatType = dyn_cast(type)) { + bitWidth = floatType.getWidth(); + auto intType = rewriter.getIntegerType(bitWidth); + payload = rewriter.create(loc, intType, value); + } else { + return failure(); + } + + if (bitWidth != 8 && bitWidth != 16 && bitWidth != 32) + return failure(); + + return rewriter.create(loc, rewriter.getI64Type(), payload) + .getResult(); +} + +template +static StringRef buildSyncCallee(MLIRContext *context); + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.FLAG.IMM").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.WAIT.FLAG.IMM").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.BARRIER").getValue(); +} + +static StringRef buildMemBarCallee(MemBarKind kind, MLIRContext *context) { + switch (kind) { + case MemBarKind::VV_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.vv.all").getValue(); + case MemBarKind::VST_VLD: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.vld").getValue(); + case MemBarKind::VLD_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vld.vst").getValue(); + case MemBarKind::VST_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.vst").getValue(); + case MemBarKind::VS_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.vs.all").getValue(); + case MemBarKind::VST_LD: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.ld").getValue(); + case MemBarKind::VLD_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vld.st").getValue(); + case MemBarKind::VST_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.st").getValue(); + case MemBarKind::SV_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.sv.all").getValue(); + case MemBarKind::ST_VLD: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.vld").getValue(); + case MemBarKind::LD_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.ld.vst").getValue(); + case MemBarKind::ST_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.vst").getValue(); + case MemBarKind::SS_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.ss.all").getValue(); + case MemBarKind::ST_LD: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.ld").getValue(); + case MemBarKind::LD_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.ld.st").getValue(); + case MemBarKind::ST_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.st").getValue(); + } + llvm_unreachable("unexpected membar kind"); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BUFI.mode").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.RLS.BUFI.mode").getValue(); +} + +template +static StringRef buildRuntimeQueryCallee(MLIRContext *context); + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BLOCK.IDX").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKID").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BLOCK.NUM").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKDIM").getValue(); +} + +static LogicalResult +materializeDecls(ModuleOp module, ArrayRef plannedDecls, + llvm::raw_ostream &diagOS) { + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(&module.getBodyRegion().front()); + for (const PlannedDecl &decl : plannedDecls) { + if (func::FuncOp existing = module.lookupSymbol(decl.name)) { + if (existing.getFunctionType() != decl.type) { + diagOS << "VPTO LLVM emission failed: conflicting declaration for " + << decl.name << "\n"; + return failure(); + } + continue; + } + auto func = + builder.create(module.getLoc(), decl.name, decl.type); + func.setPrivate(); + } + return success(); +} + +template +class LowerUnaryMaskedOpPattern final : public OpConversionPattern { +public: + explicit LowerUnaryMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(UnaryOp op, typename UnaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getUnaryMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported unary VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert unary result type"); + + Value input = adaptor.getOperands()[0]; + Value mask = adaptor.getOperands()[1]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(1).getType()); + if (!input || !mask || input.getType() != resultType || + mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted unary VPTO operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsqzOpPattern final : public OpConversionPattern { +public: + explicit LowerVsqzOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsqzOp op, pto::VsqzOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVsqzCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsqz VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vsqz types"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vsqz operand types"); + } + + Value storeHint = + getI32Constant(rewriter, op.getLoc(), determineVsqzStoreHint(op)); + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, maskType, storeHint.getType()}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{input, mask, storeHint}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVusqzOpPattern final : public OpConversionPattern { +public: + explicit LowerVusqzOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VusqzOp op, pto::VusqzOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVusqzCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vusqz VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vusqz types"); + + Value src = adaptor.getSrc(); + Value mask = adaptor.getMask(); + if (!src || !mask || src.getType() != resultType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vusqz operand types"); + } + + auto funcType = + rewriter.getFunctionType(TypeRange{resultType, maskType}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmulaOpPattern final : public OpConversionPattern { +public: + explicit LowerVmulaOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VmulaOp op, pto::VmulaOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVmulaCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmula VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vmula types"); + + Value acc = adaptor.getAcc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + Value mask = adaptor.getMask(); + if (!acc || !lhs || !rhs || !mask || acc.getType() != resultType || + lhs.getType() != resultType || rhs.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vmula operand types"); + } + + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, resultType, resultType, maskType}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{acc, lhs, rhs, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmullOpPattern final : public OpConversionPattern { +public: + explicit LowerVmullOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VmullOp op, pto::VmullOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVmullCallee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmull VPTO signature"); + + Type inputType = this->getTypeConverter()->convertType(op.getLhs().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + SmallVector resultTypes; + if (!inputType || !maskType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) { + return rewriter.notifyMatchFailure(op, "failed to convert vmull types"); + } + if (resultTypes.size() != 2 || resultTypes[0] != resultTypes[1]) + return rewriter.notifyMatchFailure(op, "unexpected converted vmull results"); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + Value mask = adaptor.getMask(); + if (!lhs || !rhs || !mask || lhs.getType() != inputType || + rhs.getType() != inputType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vmull operand types"); + } + + auto funcType = rewriter.getFunctionType(TypeRange{inputType, inputType, maskType}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, resultTypes, + ValueRange{lhs, rhs, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerBinaryMaskedOpPattern final : public OpConversionPattern { +public: + explicit LowerBinaryMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BinaryOp op, typename BinaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getBinaryMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert binary result type"); + + Value lhs = adaptor.getOperands()[0]; + Value rhs = adaptor.getOperands()[1]; + Value mask = adaptor.getOperands()[2]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(2).getType()); + if (!lhs || !rhs || !mask || lhs.getType() != resultType || + rhs.getType() != resultType || mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted binary VPTO operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{lhs, rhs, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCarryBinaryOpPattern final : public OpConversionPattern { +public: + explicit LowerCarryBinaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CarryOp op, typename CarryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getCarryBinaryStem(); + FailureOr calleeName = + buildCarryBinaryCallee(op.getContext(), op.getResult().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported carry VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type carryType = + this->getTypeConverter()->convertType(op->getResult(1).getType()); + if (!resultType || !carryType) + return rewriter.notifyMatchFailure(op, + "failed to convert carry result types"); + + SmallVector callArgs; + callArgs.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); + const size_t expectedArgCount = hasCarryInput() ? 4 : 3; + if (callArgs.size() != expectedArgCount || callArgs[0].getType() != resultType || + callArgs[1].getType() != resultType || callArgs.back().getType() != carryType) + return rewriter.notifyMatchFailure(op, + "unexpected converted carry operand types"); + if constexpr (hasCarryInput()) { + if (callArgs[2].getType() != carryType) + return rewriter.notifyMatchFailure( + op, "unexpected converted carry input operand type"); + } + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType, carryType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCopyOpPattern final : public OpConversionPattern { +public: + explicit LowerCopyOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = failure(); + if constexpr (std::is_same_v) + calleeName = buildCopyGmToUbCallee(op.getContext(), op.getSource().getType()); + else + calleeName = buildCopyUbToGmCallee(op.getContext()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported copy VPTO signature"); + + auto llvmSourceType = + dyn_cast(adaptor.getOperands()[0].getType()); + auto llvmDestType = + dyn_cast(adaptor.getOperands()[1].getType()); + if (!llvmSourceType || !llvmDestType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); + + FailureOr config0 = failure(); + FailureOr config1 = failure(); + if constexpr (std::is_same_v) { + config0 = packCopyGmToUbConfig0(op, adaptor.getOperands()); + config1 = packCopyGmToUbConfig1(op, adaptor.getOperands()); + } else { + config0 = packCopyUbToGmConfig0(op, adaptor.getOperands()); + config1 = packCopyUbToGmConfig1(op, adaptor.getOperands()); + } + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + + SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], + *config0, *config1}; + auto funcType = rewriter.getFunctionType( + TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + (void)call; + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyUbufToUbufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyUbufToUbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::CopyUbufToUbufOp op, + pto::CopyUbufToUbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmSourceType = + dyn_cast(adaptor.getOperands()[0].getType()); + auto llvmDestType = + dyn_cast(adaptor.getOperands()[1].getType()); + if (!llvmSourceType || !llvmDestType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); + + FailureOr config = packCopyUbToUbConfig(op, adaptor.getOperands()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); -static LogicalResult -materializeDecls(ModuleOp module, ArrayRef plannedDecls, - llvm::raw_ostream &diagOS) { - OpBuilder builder(module.getBodyRegion()); - builder.setInsertionPointToStart(&module.getBodyRegion().front()); - for (const PlannedDecl &decl : plannedDecls) { - if (func::FuncOp existing = module.lookupSymbol(decl.name)) { - if (existing.getFunctionType() != decl.type) { - diagOS << "VPTO LLVM emission failed: conflicting declaration for " - << decl.name << "\n"; - return failure(); - } - continue; - } - auto func = - builder.create(module.getLoc(), decl.name, decl.type); - func.setPrivate(); + StringRef calleeName = buildCopyUbToUbCallee(op.getContext()); + SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], + *config}; + auto funcType = rewriter.getFunctionType( + TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type()}, + TypeRange{}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + (void)call; + return success(); } - return success(); -} -template -class LowerUnaryMaskedOpPattern final : public OpConversionPattern { +private: + LoweringState &state; +}; + +class LowerMadOpPattern final + : public OpConversionPattern { public: - explicit LowerUnaryMaskedOpPattern(TypeConverter &typeConverter, - MLIRContext *context, - LoweringState &state) - : OpConversionPattern(typeConverter, context), state(state) {} + explicit LowerMadOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} LogicalResult - matchAndRewrite(UnaryOp op, typename UnaryOp::Adaptor adaptor, + matchAndRewrite(pto::MadOp op, pto::MadOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - StringRef stem = getUnaryMaskedStem(); - FailureOr calleeName = - buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported unary VPTO signature"); - - Type resultType = - this->getTypeConverter()->convertType(op.getResult().getType()); - if (!resultType) + Value lhsRaw = adaptor.getLhs(); + Value rhsRaw = adaptor.getRhs(); + Value dstRaw = adaptor.getDst(); + Value m = adaptor.getM(); + Value n = adaptor.getN(); + Value k = adaptor.getK(); + if (!lhsRaw || !rhsRaw || !dstRaw || !m || !n || !k) + return rewriter.notifyMatchFailure(op, "expected converted mad operands"); + + if (!isa(lhsRaw.getType()) || + !isa(rhsRaw.getType()) || + !isa(dstRaw.getType())) { return rewriter.notifyMatchFailure(op, - "failed to convert unary result type"); + "expected LLVM pointer lhs/rhs/dst"); + } - Value input = adaptor.getOperands()[0]; - Value mask = adaptor.getOperands()[1]; - Type expectedMaskType = - this->getTypeConverter()->convertType(op->getOperand(1).getType()); - if (!input || !mask || input.getType() != resultType || - mask.getType() != expectedMaskType) { + Type i64Ty = rewriter.getI64Type(); + + constexpr unsigned caAddressSpace = + static_cast(pto::AddressSpace::LEFT); + constexpr unsigned cbAddressSpace = + static_cast(pto::AddressSpace::RIGHT); + constexpr unsigned ccAddressSpace = + static_cast(pto::AddressSpace::ACC); + FailureOr lhs = reinterpretPointerToAddrSpace(op, lhsRaw, caAddressSpace); + FailureOr rhs = reinterpretPointerToAddrSpace(op, rhsRaw, cbAddressSpace); + FailureOr dst = reinterpretPointerToAddrSpace(op, dstRaw, ccAddressSpace); + if (failed(lhs) || failed(rhs) || failed(dst)) + return rewriter.notifyMatchFailure(op, "failed to map cube pointer spaces"); + + FailureOr calleeName = buildMadCallee(op.getContext(), op); + if (failed(calleeName)) { return rewriter.notifyMatchFailure( - op, "unexpected converted unary VPTO operand types"); + op, "unsupported mad element types for mad/mad_mx dispatch"); } + FailureOr config = packMadConfig(op, m, n, k); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack mad config"); - auto call = rewriter.create(op.getLoc(), *calleeName, - TypeRange{resultType}, - ValueRange{input, mask}); - state.plannedDecls.push_back( - PlannedDecl{calleeName->str(), call.getCalleeType()}); - rewriter.replaceOp(op, call.getResults()); + auto funcType = rewriter.getFunctionType( + TypeRange{dst->getType(), lhs->getType(), rhs->getType(), i64Ty}, + TypeRange{}); + auto call = + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*dst, *lhs, *rhs, *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + (void)call; return success(); } @@ -2004,43 +3177,69 @@ class LowerUnaryMaskedOpPattern final : public OpConversionPattern { LoweringState &state; }; -class LowerVsqzOpPattern final : public OpConversionPattern { +class LowerMadMxOpPattern final + : public OpConversionPattern { public: - explicit LowerVsqzOpPattern(TypeConverter &typeConverter, MLIRContext *context, - LoweringState &state) - : OpConversionPattern(typeConverter, context), state(state) {} + explicit LowerMadMxOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} LogicalResult - matchAndRewrite(pto::VsqzOp op, pto::VsqzOp::Adaptor adaptor, + matchAndRewrite(pto::MadMxOp op, pto::MadMxOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - FailureOr calleeName = - buildVsqzCallee(op.getContext(), op.getResult().getType()); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported vsqz VPTO signature"); - - Type resultType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); - if (!resultType || !maskType) - return rewriter.notifyMatchFailure(op, "failed to convert vsqz types"); + Value lhsRaw = adaptor.getLhs(); + Value rhsRaw = adaptor.getRhs(); + Value dstRaw = adaptor.getDst(); + Value m = adaptor.getM(); + Value n = adaptor.getN(); + Value k = adaptor.getK(); + if (!lhsRaw || !rhsRaw || !dstRaw || !m || !n || !k) { + return rewriter.notifyMatchFailure(op, + "expected converted mad_mx operands"); + } - Value input = adaptor.getInput(); - Value mask = adaptor.getMask(); - if (!input || !mask || input.getType() != resultType || - mask.getType() != maskType) { + if (!isa(lhsRaw.getType()) || + !isa(rhsRaw.getType()) || + !isa(dstRaw.getType())) { return rewriter.notifyMatchFailure(op, - "unexpected converted vsqz operand types"); + "expected LLVM pointer lhs/rhs/dst"); } - Value storeHint = - getI32Constant(rewriter, op.getLoc(), determineVsqzStoreHint(op)); + Type i64Ty = rewriter.getI64Type(); + constexpr unsigned caAddressSpace = + static_cast(pto::AddressSpace::LEFT); + constexpr unsigned cbAddressSpace = + static_cast(pto::AddressSpace::RIGHT); + constexpr unsigned ccAddressSpace = + static_cast(pto::AddressSpace::ACC); + FailureOr lhs = + reinterpretPointerToAddrSpace(op, lhsRaw, caAddressSpace); + FailureOr rhs = + reinterpretPointerToAddrSpace(op, rhsRaw, cbAddressSpace); + FailureOr dst = + reinterpretPointerToAddrSpace(op, dstRaw, ccAddressSpace); + if (failed(lhs) || failed(rhs) || failed(dst)) + return rewriter.notifyMatchFailure(op, "failed to map cube pointer spaces"); + + FailureOr calleeName = buildMadMxCallee(op.getContext(), op); + if (failed(calleeName)) { + return rewriter.notifyMatchFailure( + op, "unsupported mad_mx element types for mad_mx dispatch"); + } + FailureOr config = packMadConfig(op, m, n, k); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack mad config"); + auto funcType = rewriter.getFunctionType( - TypeRange{resultType, maskType, storeHint.getType()}, TypeRange{resultType}); - auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{input, mask, storeHint}); + TypeRange{dst->getType(), lhs->getType(), rhs->getType(), i64Ty}, + TypeRange{}); + auto call = + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*dst, *lhs, *rhs, *config}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - rewriter.replaceOp(op, call.getResults()); + rewriter.eraseOp(op); + (void)call; return success(); } @@ -2048,39 +3247,72 @@ class LowerVsqzOpPattern final : public OpConversionPattern { LoweringState &state; }; -class LowerVusqzOpPattern final : public OpConversionPattern { +class LowerCopyGmToCbufOpPattern final + : public OpConversionPattern { public: - explicit LowerVusqzOpPattern(TypeConverter &typeConverter, - MLIRContext *context, LoweringState &state) - : OpConversionPattern(typeConverter, context), state(state) {} + explicit LowerCopyGmToCbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} - LogicalResult - matchAndRewrite(pto::VusqzOp op, pto::VusqzOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr calleeName = - buildVusqzCallee(op.getContext(), op.getResult().getType()); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported vusqz VPTO signature"); + LogicalResult matchAndRewrite( + pto::CopyGmToCbufOp op, + pto::CopyGmToCbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + Value nBurst = adaptor.getNBurst(); + Value lenBurst = adaptor.getLenBurst(); + Value srcStride = adaptor.getSrcStride(); + Value dstStride = adaptor.getDstStride(); + if (!sourceRaw || !destinationRaw || !nBurst || !lenBurst || !srcStride || + !dstStride) { + return rewriter.notifyMatchFailure(op, "expected converted operands"); + } - Type resultType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); - if (!resultType || !maskType) - return rewriter.notifyMatchFailure(op, "failed to convert vusqz types"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) { + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + } - Value src = adaptor.getSrc(); - Value mask = adaptor.getMask(); - if (!src || !mask || src.getType() != resultType || mask.getType() != maskType) { - return rewriter.notifyMatchFailure(op, - "unexpected converted vusqz operand types"); + Type i64Ty = rewriter.getI64Type(); + if (nBurst.getType() != i64Ty || lenBurst.getType() != i64Ty || + srcStride.getType() != i64Ty || dstStride.getType() != i64Ty) { + return rewriter.notifyMatchFailure(op, "expected i64 config operands"); } - auto funcType = - rewriter.getFunctionType(TypeRange{resultType, maskType}, TypeRange{resultType}); - auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, mask}); + constexpr unsigned gmAddressSpace = + static_cast(pto::AddressSpace::GM); + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + FailureOr source = reinterpretPointerToAddrSpace(op, sourceRaw, gmAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, cbufAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/gm pointer spaces"); + + FailureOr calleeName = + buildCopyGmToCbufCallee(op.getContext(), op.getSource().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported copy_gm_to_cbuf element type"); + FailureOr config0 = + packCopyGmToCbufConfig0(op, nBurst, lenBurst); + FailureOr config1 = + packCopyGmToCbufConfig1(op, srcStride, dstStride); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, + "failed to pack copy_gm_to_cbuf config"); + + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, *config1}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - rewriter.replaceOp(op, call.getResults()); + rewriter.eraseOp(op); return success(); } @@ -2088,45 +3320,230 @@ class LowerVusqzOpPattern final : public OpConversionPattern { LoweringState &state; }; -class LowerVmulaOpPattern final : public OpConversionPattern { +template +class LowerCopyGmToCbufMultiOpPattern final + : public OpConversionPattern { public: - explicit LowerVmulaOpPattern(TypeConverter &typeConverter, MLIRContext *context, - LoweringState &state) - : OpConversionPattern(typeConverter, context), state(state) {} + explicit LowerCopyGmToCbufMultiOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} LogicalResult - matchAndRewrite(pto::VmulaOp op, pto::VmulaOp::Adaptor adaptor, + matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - FailureOr calleeName = - buildVmulaCallee(op.getContext(), op.getResult().getType()); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported vmula VPTO signature"); + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned gmAddressSpace = + static_cast(pto::AddressSpace::GM); + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, gmAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, cbufAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/gm pointer spaces"); + + FailureOr config0 = packCopyGmToCbufMultiConfig0( + op, adaptor.getSid(), adaptor.getLoop1SrcStride(), + adaptor.getL2CacheCtrl(), adaptor.getNValue()); + FailureOr config1 = + packCopyGmToCbufMultiConfig1(op, adaptor.getDValue(), + adaptor.getLoop4SrcStride(), + adaptor.getSmallc0En()); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to pack multi copy config"); - Type resultType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); - if (!resultType || !maskType) - return rewriter.notifyMatchFailure(op, "failed to convert vmula types"); + StringRef calleeName = [] (MLIRContext *ctx) -> StringRef { + if constexpr (std::is_same_v) + return buildCopyGmToCbufMultiNd2NzCallee(ctx); + return buildCopyGmToCbufMultiDn2NzCallee(ctx); + }(op.getContext()); - Value acc = adaptor.getAcc(); - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - Value mask = adaptor.getMask(); - if (!acc || !lhs || !rhs || !mask || acc.getType() != resultType || - lhs.getType() != resultType || rhs.getType() != resultType || - mask.getType() != maskType) { - return rewriter.notifyMatchFailure(op, - "unexpected converted vmula operand types"); + Type i64Ty = rewriter.getI64Type(); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create( + op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, *config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyCbufToBtOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyCbufToBtOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite(pto::CopyCbufToBtOp op, + pto::CopyCbufToBtOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned btAddressSpace = + static_cast(pto::AddressSpace::BIAS); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destinationPtr = + reinterpretPointerToAddrSpace(op, destinationRaw, btAddressSpace); + if (failed(source) || failed(destinationPtr)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/bt pointer spaces"); + + FailureOr config = packCopyCbufToBtConfig( + op, adaptor.getConvControl(), adaptor.getNBurst(), adaptor.getLenBurst(), + adaptor.getSourceGap(), adaptor.getDstGap()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack copy_cbuf_to_bt config"); + + Type i64Ty = rewriter.getI64Type(); + Value destination = + rewriter.create(op.getLoc(), i64Ty, *destinationPtr); + StringRef calleeName = buildCopyCbufToBtCallee(op); + auto funcType = rewriter.getFunctionType( + TypeRange{i64Ty, source->getType(), i64Ty}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{destination, *source, *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyCbufToFbufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyCbufToFbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite(pto::CopyCbufToFbufOp op, + pto::CopyCbufToFbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned fbufAddressSpace = + static_cast(pto::AddressSpace::BIAS); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, fbufAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/fbuf pointer spaces"); + + FailureOr config = packCopyCbufToFbufConfig( + op, adaptor.getNBurst(), adaptor.getLenBurst(), adaptor.getSourceGap(), + adaptor.getDstGap()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack copy_cbuf_to_fbuf config"); + + Type i64Ty = rewriter.getI64Type(); + StringRef calleeName = buildCopyCbufToFbufCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerLoadCbufToCaOpPattern final + : public OpConversionPattern { +public: + explicit LowerLoadCbufToCaOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite(pto::LoadCbufToCaOp op, + pto::LoadCbufToCaOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + Value m = adaptor.getM(); + Value k = adaptor.getK(); + if (!sourceRaw || !destinationRaw || !m || !k) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) { + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); } + Type i64Ty = rewriter.getI64Type(); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned caAddressSpace = + static_cast(pto::AddressSpace::LEFT); + FailureOr source = reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, caAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/ca pointer spaces"); + + FailureOr config0 = packLoadCbufToCaConfig0(op, m, k); + FailureOr config1 = packLoadCbufToCaConfig1(op, k); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to pack load_cbuf_to_ca config"); + Value transpose = getI64Constant(rewriter, op.getLoc(), 0); + + FailureOr calleeName = + buildLoadCbufToCaCallee(op.getContext(), op.getSource().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported load_cbuf_to_ca element type"); auto funcType = rewriter.getFunctionType( - TypeRange{resultType, resultType, resultType, maskType}, - TypeRange{resultType}); - auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{acc, lhs, rhs, mask}); + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty, + i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, + *config1, transpose}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - rewriter.replaceOp(op, call.getResults()); + rewriter.eraseOp(op); return success(); } @@ -2134,45 +3551,65 @@ class LowerVmulaOpPattern final : public OpConversionPattern { LoweringState &state; }; -class LowerVmullOpPattern final : public OpConversionPattern { +template +class LowerLoadCbufToS4OpPattern final : public OpConversionPattern { public: - explicit LowerVmullOpPattern(TypeConverter &typeConverter, MLIRContext *context, - LoweringState &state) - : OpConversionPattern(typeConverter, context), state(state) {} + explicit LowerLoadCbufToS4OpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} LogicalResult - matchAndRewrite(pto::VmullOp op, pto::VmullOp::Adaptor adaptor, + matchAndRewrite(LoadOp op, typename LoadOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - FailureOr calleeName = - buildVmullCallee(op.getContext(), op.getLow().getType()); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported vmull VPTO signature"); - - Type inputType = this->getTypeConverter()->convertType(op.getLhs().getType()); - Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); - SmallVector resultTypes; - if (!inputType || !maskType || - failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) { - return rewriter.notifyMatchFailure(op, "failed to convert vmull types"); - } - if (resultTypes.size() != 2 || resultTypes[0] != resultTypes[1]) - return rewriter.notifyMatchFailure(op, "unexpected converted vmull results"); + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned targetAddressSpace = + std::is_same_v + ? static_cast(pto::AddressSpace::LEFT) + : static_cast(pto::AddressSpace::RIGHT); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, targetAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/cube pointer spaces"); + + FailureOr config0 = packLoadCbufToS4Config0( + op, adaptor.getMStart(), adaptor.getKStart(), adaptor.getMStep(), + adaptor.getKStep()); + FailureOr config1 = + packLoadCbufToS4Config1(op, adaptor.getSrcStride(), + adaptor.getDstStride()); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to pack load_cbuf_to_*_s4 config"); - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - Value mask = adaptor.getMask(); - if (!lhs || !rhs || !mask || lhs.getType() != inputType || - rhs.getType() != inputType || mask.getType() != maskType) { - return rewriter.notifyMatchFailure(op, - "unexpected converted vmull operand types"); - } + Value transpose = + castIntegerLikeTo(op, adaptor.getTranspose(), rewriter.getI64Type()); + if (!transpose) + return rewriter.notifyMatchFailure(op, "failed to cast transpose to i64"); - auto funcType = rewriter.getFunctionType(TypeRange{inputType, inputType, maskType}, - resultTypes); - auto call = rewriter.create(op.getLoc(), *calleeName, resultTypes, - ValueRange{lhs, rhs, mask}); - state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - rewriter.replaceOp(op, call.getResults()); + StringRef calleeName = std::is_same_v + ? buildLoadCbufToCaS4Callee(op.getContext()) + : buildLoadCbufToCbS4Callee(op.getContext()); + Type i64Ty = rewriter.getI64Type(); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty, + i64Ty}, + TypeRange{}); + rewriter.create( + op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, *config1, transpose}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); return success(); } @@ -2180,46 +3617,62 @@ class LowerVmullOpPattern final : public OpConversionPattern { LoweringState &state; }; -template -class LowerBinaryMaskedOpPattern final : public OpConversionPattern { +class LowerLoadCbufToCbOpPattern final + : public OpConversionPattern { public: - explicit LowerBinaryMaskedOpPattern(TypeConverter &typeConverter, + explicit LowerLoadCbufToCbOpPattern(TypeConverter &typeConverter, MLIRContext *context, LoweringState &state) - : OpConversionPattern(typeConverter, context), state(state) {} + : OpConversionPattern(typeConverter, context), + state(state) {} - LogicalResult - matchAndRewrite(BinaryOp op, typename BinaryOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - StringRef stem = getBinaryMaskedStem(); - FailureOr calleeName = - buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); + LogicalResult matchAndRewrite(pto::LoadCbufToCbOp op, + pto::LoadCbufToCbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + Value k = adaptor.getK(); + Value n = adaptor.getN(); + if (!sourceRaw || !destinationRaw || !k || !n) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) { + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + } - Type resultType = - this->getTypeConverter()->convertType(op.getResult().getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, - "failed to convert binary result type"); + Type i64Ty = rewriter.getI64Type(); - Value lhs = adaptor.getOperands()[0]; - Value rhs = adaptor.getOperands()[1]; - Value mask = adaptor.getOperands()[2]; - Type expectedMaskType = - this->getTypeConverter()->convertType(op->getOperand(2).getType()); - if (!lhs || !rhs || !mask || lhs.getType() != resultType || - rhs.getType() != resultType || mask.getType() != expectedMaskType) { - return rewriter.notifyMatchFailure( - op, "unexpected converted binary VPTO operand types"); - } + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned cbAddressSpace = + static_cast(pto::AddressSpace::RIGHT); + FailureOr source = reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, cbAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/cb pointer spaces"); - auto call = rewriter.create(op.getLoc(), *calleeName, - TypeRange{resultType}, - ValueRange{lhs, rhs, mask}); - state.plannedDecls.push_back( - PlannedDecl{calleeName->str(), call.getCalleeType()}); - rewriter.replaceOp(op, call.getResults()); + FailureOr config0 = packLoadCbufToCbConfig0(op, k, n); + FailureOr config1 = packLoadCbufToCbConfig1(op, n); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to pack load_cbuf_to_cb config"); + Value transpose = getI64Constant(rewriter, op.getLoc(), 0); + + FailureOr calleeName = + buildLoadCbufToCbCallee(op.getContext(), op.getSource().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported load_cbuf_to_cb element type"); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty, + i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, + *config1, transpose}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); return success(); } @@ -2227,48 +3680,51 @@ class LowerBinaryMaskedOpPattern final : public OpConversionPattern { LoweringState &state; }; -template -class LowerCarryBinaryOpPattern final : public OpConversionPattern { +class LowerLoadCbufToCaMxOpPattern final + : public OpConversionPattern { public: - explicit LowerCarryBinaryOpPattern(TypeConverter &typeConverter, - MLIRContext *context, LoweringState &state) - : OpConversionPattern(typeConverter, context), state(state) {} + explicit LowerLoadCbufToCaMxOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} LogicalResult - matchAndRewrite(CarryOp op, typename CarryOp::Adaptor adaptor, + matchAndRewrite(pto::LoadCbufToCaMxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - StringRef stem = getCarryBinaryStem(); - FailureOr calleeName = - buildCarryBinaryCallee(op.getContext(), op.getResult().getType(), stem); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported carry VPTO signature"); - - Type resultType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Type carryType = - this->getTypeConverter()->convertType(op->getResult(1).getType()); - if (!resultType || !carryType) - return rewriter.notifyMatchFailure(op, - "failed to convert carry result types"); - - SmallVector callArgs; - callArgs.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); - const size_t expectedArgCount = hasCarryInput() ? 4 : 3; - if (callArgs.size() != expectedArgCount || callArgs[0].getType() != resultType || - callArgs[1].getType() != resultType || callArgs.back().getType() != carryType) + Value srcRaw = adaptor.getSource(); + Value dstRaw = adaptor.getDestination(); + if (!srcRaw || !dstRaw || !adaptor.getM() || !adaptor.getK()) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(srcRaw.getType()) || + !isa(dstRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned caAddressSpace = + static_cast(pto::AddressSpace::LEFT); + FailureOr src = reinterpretPointerToAddrSpace(op, srcRaw, cbufAddressSpace); + FailureOr dst = reinterpretPointerToAddrSpace(op, dstRaw, caAddressSpace); + if (failed(src) || failed(dst)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/ca pointer spaces"); + + FailureOr config0 = packLoadCbufToCaConfig0(op, adaptor.getM(), adaptor.getK()); + FailureOr config1 = packLoadCbufToCaConfig1(op, adaptor.getK()); + if (failed(config0) || failed(config1)) return rewriter.notifyMatchFailure(op, - "unexpected converted carry operand types"); - if constexpr (hasCarryInput()) { - if (callArgs[2].getType() != carryType) - return rewriter.notifyMatchFailure( - op, "unexpected converted carry input operand type"); - } + "failed to pack load_cbuf_to_ca_mx config"); + auto i64Ty = rewriter.getI64Type(); + Value dstAddr = rewriter.create(op.getLoc(), i64Ty, *dst); - auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultType, carryType}, callArgs); - state.plannedDecls.push_back( - PlannedDecl{calleeName->str(), call.getCalleeType()}); - rewriter.replaceOp(op, call.getResults()); + StringRef calleeName = buildLoadCbufToCaMxCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{i64Ty, src->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{dstAddr, *src, *config0, *config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); return success(); } @@ -2276,54 +3732,109 @@ class LowerCarryBinaryOpPattern final : public OpConversionPattern { LoweringState &state; }; -template -class LowerCopyOpPattern final : public OpConversionPattern { +class LowerLoadCbufToCbMxOpPattern final + : public OpConversionPattern { public: - explicit LowerCopyOpPattern(TypeConverter &typeConverter, MLIRContext *context, - LoweringState &state) - : OpConversionPattern(typeConverter, context), state(state) {} + explicit LowerLoadCbufToCbMxOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} LogicalResult - matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, + matchAndRewrite(pto::LoadCbufToCbMxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - FailureOr calleeName = failure(); - if constexpr (std::is_same_v) - calleeName = buildCopyGmToUbCallee(op.getContext(), op.getSource().getType()); - else - calleeName = buildCopyUbToGmCallee(op.getContext()); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported copy VPTO signature"); + Value srcRaw = adaptor.getSource(); + Value dstRaw = adaptor.getDestination(); + if (!srcRaw || !dstRaw || !adaptor.getK() || !adaptor.getN()) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(srcRaw.getType()) || + !isa(dstRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned cbAddressSpace = + static_cast(pto::AddressSpace::RIGHT); + FailureOr src = reinterpretPointerToAddrSpace(op, srcRaw, cbufAddressSpace); + FailureOr dst = reinterpretPointerToAddrSpace(op, dstRaw, cbAddressSpace); + if (failed(src) || failed(dst)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/cb pointer spaces"); + + FailureOr config0 = packLoadCbufToCbConfig0(op, adaptor.getK(), adaptor.getN()); + FailureOr config1 = packLoadCbufToCbConfig1(op, adaptor.getN()); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, + "failed to pack load_cbuf_to_cb_mx config"); + auto i64Ty = rewriter.getI64Type(); + Value dstAddr = rewriter.create(op.getLoc(), i64Ty, *dst); - auto llvmSourceType = - dyn_cast(adaptor.getOperands()[0].getType()); - auto llvmDestType = - dyn_cast(adaptor.getOperands()[1].getType()); - if (!llvmSourceType || !llvmDestType) - return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); + StringRef calleeName = buildLoadCbufToCbMxCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{i64Ty, src->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{dstAddr, *src, *config0, *config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } - FailureOr config0 = failure(); - FailureOr config1 = failure(); - if constexpr (std::is_same_v) { - config0 = packCopyGmToUbConfig0(op, adaptor.getOperands()); - config1 = packCopyGmToUbConfig1(op, adaptor.getOperands()); - } else { - config0 = packCopyUbToGmConfig0(op, adaptor.getOperands()); - config1 = packCopyUbToGmConfig1(op, adaptor.getOperands()); +private: + LoweringState &state; +}; + +class LowerCopyMatrixCcToGmOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyMatrixCcToGmOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite( + pto::CopyMatrixCcToGmOp op, pto::CopyMatrixCcToGmOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + Value m = adaptor.getM(); + Value n = adaptor.getN(); + if (!sourceRaw || !destinationRaw || !m || !n) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) { + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); } + + Type i64Ty = rewriter.getI64Type(); + if (m.getType() != i64Ty || n.getType() != i64Ty) + return rewriter.notifyMatchFailure(op, "expected i64 m/n operands"); + + constexpr unsigned gmAddressSpace = + static_cast(pto::AddressSpace::GM); + constexpr unsigned ccAddressSpace = + static_cast(pto::AddressSpace::ACC); + FailureOr source = reinterpretPointerToAddrSpace(op, sourceRaw, ccAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, gmAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cc/gm pointer spaces"); + + FailureOr config0 = packCopyMatrixCcToGmConfig0(op, m, n); + FailureOr config1 = packCopyMatrixCcToGmConfig1(op, n); if (failed(config0) || failed(config1)) - return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + return rewriter.notifyMatchFailure(op, "failed to pack copy_matrix_cc_to_gm config"); - SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], - *config0, *config1}; + StringRef calleeName = buildCopyMatrixCcToGmCallee(op.getContext()); auto funcType = rewriter.getFunctionType( - TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type(), - rewriter.getI64Type()}, + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty}, TypeRange{}); - auto call = rewriter.create(op.getLoc(), *calleeName, - TypeRange{}, args); - state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, *config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); rewriter.eraseOp(op); - (void)call; return success(); } @@ -2331,41 +3842,61 @@ class LowerCopyOpPattern final : public OpConversionPattern { LoweringState &state; }; -class LowerCopyUbufToUbufOpPattern final - : public OpConversionPattern { +template +class LowerCopyMatrixCcToBufOpPattern final + : public OpConversionPattern { public: - explicit LowerCopyUbufToUbufOpPattern(TypeConverter &typeConverter, - MLIRContext *context, - LoweringState &state) - : OpConversionPattern(typeConverter, context), - state(state) {} + explicit LowerCopyMatrixCcToBufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} LogicalResult - matchAndRewrite(pto::CopyUbufToUbufOp op, - pto::CopyUbufToUbufOp::Adaptor adaptor, + matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto llvmSourceType = - dyn_cast(adaptor.getOperands()[0].getType()); - auto llvmDestType = - dyn_cast(adaptor.getOperands()[1].getType()); - if (!llvmSourceType || !llvmDestType) - return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); - - FailureOr config = packCopyUbToUbConfig(op, adaptor.getOperands()); - if (failed(config)) - return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned ccAddressSpace = + static_cast(pto::AddressSpace::ACC); + constexpr unsigned targetAddressSpace = + std::is_same_v + ? static_cast(pto::AddressSpace::MAT) + : static_cast(pto::AddressSpace::VEC); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, ccAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, targetAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cc->buf pointer spaces"); + + Type i64Ty = rewriter.getI64Type(); + Value config0 = castIntegerLikeTo(op, adaptor.getConfig0(), i64Ty); + Value config1 = castIntegerLikeTo(op, adaptor.getConfig1(), i64Ty); + if (!config0 || !config1) + return rewriter.notifyMatchFailure(op, "failed to cast config operands to i64"); - StringRef calleeName = buildCopyUbToUbCallee(op.getContext()); - SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], - *config}; + FailureOr calleeName = + std::is_same_v + ? FailureOr(buildCopyMatrixCcToCbufCallee(op.getContext())) + : buildCopyMatrixCcToUbCallee(op.getContext(), + op.getDestination().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure( + op, "unsupported copy_matrix_cc_to_{cbuf,ub} element type"); auto funcType = rewriter.getFunctionType( - TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type()}, + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty}, TypeRange{}); - auto call = rewriter.create(op.getLoc(), calleeName, - TypeRange{}, args); - state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, config0, + config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.eraseOp(op); - (void)call; return success(); } @@ -2373,7 +3904,6 @@ class LowerCopyUbufToUbufOpPattern final LoweringState &state; }; - template class LowerVecScalarMaskedOpPattern final : public OpConversionPattern { @@ -4700,6 +6230,31 @@ class LowerUnaryI64ConfigOpPattern final : public OpConversionPattern LoweringState &state; }; +template +class LowerNullaryConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerNullaryConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ConfigOp op, typename ConfigOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + StringRef calleeName = buildNullaryConfigCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + template class LowerPipeEventSyncOpPattern final : public OpConversionPattern { public: @@ -5197,8 +6752,19 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, LowerUnaryI64ConfigOpPattern, LowerUnaryConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerNullaryConfigOpPattern, + LowerNullaryConfigOpPattern, LowerPipeEventSyncOpPattern, LowerPipeEventSyncOpPattern, LowerBarrierOpPattern, LowerMemBarOpPattern, @@ -5226,6 +6792,18 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerPredicateStoreOpPattern, LowerPredicateStoreOpPattern, LowerPstuOpPattern, LowerVstusOpPattern, LowerVsturOpPattern, + LowerCopyGmToCbufOpPattern, LowerLoadCbufToCaOpPattern, + LowerLoadCbufToCbOpPattern, + LowerLoadCbufToS4OpPattern, + LowerLoadCbufToS4OpPattern, + LowerLoadCbufToCaMxOpPattern, + LowerLoadCbufToCbMxOpPattern, LowerCopyMatrixCcToGmOpPattern, + LowerCopyMatrixCcToBufOpPattern, + LowerCopyMatrixCcToBufOpPattern, + LowerCopyCbufToBtOpPattern, LowerCopyCbufToFbufOpPattern, + LowerCopyGmToCbufMultiOpPattern, + LowerCopyGmToCbufMultiOpPattern, + LowerMadOpPattern, LowerMadMxOpPattern, LowerCopyOpPattern, LowerCopyOpPattern, LowerCopyUbufToUbufOpPattern>( @@ -5247,7 +6825,12 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, target.addIllegalOp(); + pto::SetLoop3ParaOp, pto::SetChannelParaOp, + pto::SetLoop2StrideOutToL1Op, pto::SetLoop1StrideOutToL1Op, + pto::SetLoopSizeOutToL1Op, pto::SetMte2NzParaOp, + pto::SetPadValOutToL1Op, pto::SetFpcOp, + pto::SetAtomicS32Op, pto::SetAtomicS8Op, pto::SetCtrlOp, + pto::SetMovPadValOp, pto::SetQuantPreOp>(); target.addIllegalOp(); target.addIllegalOp(); + pto::CopyUbufToUbufOp, + pto::CopyGmToCbufOp, pto::LoadCbufToCaOp, + pto::LoadCbufToCbOp, pto::LoadCbufToCaS4Op, + pto::LoadCbufToCbS4Op, pto::LoadCbufToCaMxOp, + pto::LoadCbufToCbMxOp, pto::CopyMatrixCcToGmOp, + pto::CopyMatrixCcToCbufOp, pto::CopyMatrixCcToUbOp, + pto::CopyCbufToBtOp, pto::CopyCbufToFbufOp, + pto::CopyGmToCbufMultiNd2NzOp, + pto::CopyGmToCbufMultiDn2NzOp, + pto::MadOp, pto::MadMxOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); } diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/golden.py new file mode 100644 index 000000000..b294da63a --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/golden.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def generate(output_dir: Path) -> None: + a = np.zeros((M, K), dtype=np.uint16) + b = np.zeros((K, N), dtype=np.uint16) + c = np.zeros((M, N), dtype=np.float32) + golden_c = np.zeros((M, N), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/kernel.pto new file mode 100644 index 000000000..724fc0889 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/kernel.pto @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5"} { + func.func @mad_bf16bf16f32_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %a_gm, %l1_a, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.copy_gm_to_cbuf %b_gm, %l1_b, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_M", "EVENT_ID0"] + + pto.load_cbuf_to_ca %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.load_cbuf_to_cb %l1_b, %l0b, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.copy_matrix_cc_to_gm %l0c, %c_gm, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/launch.cpp new file mode 100644 index 000000000..95d872f6d --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_bf16bf16f32_kernel(__gm__ __bf16 *a, + __gm__ __bf16 *b, + __gm__ float *c); + +void LaunchMad_bf16bf16f32_kernel(__bf16 *a, __bf16 *b, float *c, void *stream) { + mad_bf16bf16f32_kernel<<<1, nullptr, stream>>>((__gm__ __bf16 *)a, + (__gm__ __bf16 *)b, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/main.cpp new file mode 100644 index 000000000..64edd74f8 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/main.cpp @@ -0,0 +1,127 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_bf16bf16f32_kernel(__bf16 *a, __bf16 *b, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(__bf16); + constexpr size_t bSize = bElem * sizeof(__bf16); + constexpr size_t cSize = cElem * sizeof(float); + + __bf16 *aHost = nullptr; + __bf16 *bHost = nullptr; + float *cHost = nullptr; + __bf16 *aDevice = nullptr; + __bf16 *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_bf16bf16f32_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/stub.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/stub.cpp new file mode 100644 index 000000000..3b32c0f85 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void mad_bf16bf16f32_kernel(__gm__ __bf16 *a, + __gm__ __bf16 *b, + __gm__ float *c) { + (void)a; + (void)b; + (void)c; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/golden.py new file mode 100644 index 000000000..044c06618 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/golden.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def generate(output_dir: Path) -> None: + a = np.zeros((M, K), dtype=np.float16) + b = np.zeros((K, N), dtype=np.float16) + c = np.zeros((M, N), dtype=np.float32) + golden_c = np.zeros((M, N), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/kernel.pto new file mode 100644 index 000000000..c56a78148 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/kernel.pto @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5"} { + func.func @mad_f16f16f32_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %a_gm, %l1_a, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.copy_gm_to_cbuf %b_gm, %l1_b, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_M", "EVENT_ID0"] + + pto.load_cbuf_to_ca %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.load_cbuf_to_cb %l1_b, %l0b, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.copy_matrix_cc_to_gm %l0c, %c_gm, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/launch.cpp new file mode 100644 index 000000000..16aa85c76 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_f16f16f32_kernel(__gm__ __fp16 *a, + __gm__ __fp16 *b, + __gm__ float *c); + +void LaunchMad_f16f16f32_kernel(__fp16 *a, __fp16 *b, float *c, void *stream) { + mad_f16f16f32_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)a, + (__gm__ __fp16 *)b, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/main.cpp new file mode 100644 index 000000000..f2f352beb --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/main.cpp @@ -0,0 +1,127 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_f16f16f32_kernel(__fp16 *a, __fp16 *b, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(__fp16); + constexpr size_t bSize = bElem * sizeof(__fp16); + constexpr size_t cSize = cElem * sizeof(float); + + __fp16 *aHost = nullptr; + __fp16 *bHost = nullptr; + float *cHost = nullptr; + __fp16 *aDevice = nullptr; + __fp16 *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_f16f16f32_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/stub.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/stub.cpp new file mode 100644 index 000000000..eecbe74d6 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void mad_f16f16f32_kernel(__gm__ __fp16 *a, + __gm__ __fp16 *b, + __gm__ float *c) { + (void)a; + (void)b; + (void)c; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/golden.py new file mode 100644 index 000000000..bb7011e62 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/golden.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def generate(output_dir: Path) -> None: + a = np.zeros((M, K), dtype=np.float32) + b = np.zeros((K, N), dtype=np.float32) + c = np.zeros((M, N), dtype=np.float32) + golden_c = np.zeros((M, N), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/kernel.pto new file mode 100644 index 000000000..7bf137c86 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/kernel.pto @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5"} { + func.func @mad_f32f32f32_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %a_gm, %l1_a, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.copy_gm_to_cbuf %b_gm, %l1_b, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_M", "EVENT_ID0"] + + pto.load_cbuf_to_ca %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.load_cbuf_to_cb %l1_b, %l0b, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.copy_matrix_cc_to_gm %l0c, %c_gm, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/launch.cpp new file mode 100644 index 000000000..92efd7f89 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_f32f32f32_kernel(__gm__ float *a, + __gm__ float *b, + __gm__ float *c); + +void LaunchMad_f32f32f32_kernel(float *a, float *b, float *c, void *stream) { + mad_f32f32f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)a, + (__gm__ float *)b, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/main.cpp new file mode 100644 index 000000000..4268b9fea --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/main.cpp @@ -0,0 +1,127 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_f32f32f32_kernel(float *a, float *b, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(float); + constexpr size_t bSize = bElem * sizeof(float); + constexpr size_t cSize = cElem * sizeof(float); + + float *aHost = nullptr; + float *bHost = nullptr; + float *cHost = nullptr; + float *aDevice = nullptr; + float *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_f32f32f32_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/stub.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/stub.cpp new file mode 100644 index 000000000..02acc8852 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void mad_f32f32f32_kernel(__gm__ float *a, + __gm__ float *b, + __gm__ float *c) { + (void)a; + (void)b; + (void)c; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_mx/compare.py new file mode 100644 index 000000000..e01389171 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_mx/golden.py new file mode 100644 index 000000000..6e5badcf9 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/golden.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def generate(output_dir: Path) -> None: + a = np.zeros((M, K), dtype=np.uint8) + b = np.zeros((K, N), dtype=np.uint8) + c = np.zeros((M, N), dtype=np.float32) + golden_c = np.zeros((M, N), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_mx/kernel.pto new file mode 100644 index 000000000..418572106 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/kernel.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5"} { + func.func @mad_mx_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %a_gm, %l1_a, %c1_i64, %c8_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.copy_gm_to_cbuf %b_gm, %l1_b, %c1_i64, %c8_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_M", "EVENT_ID0"] + + pto.load_cbuf_to_ca_mx %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.load_cbuf_to_cb_mx %l1_b, %l0b, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mad_mx %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.copy_matrix_cc_to_gm %l0c, %c_gm, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx/launch.cpp new file mode 100644 index 000000000..eddfd1033 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_mx_kernel(__gm__ uint8_t *a, + __gm__ uint8_t *b, + __gm__ float *c); + +void LaunchMad_mx_kernel(uint8_t *a, uint8_t *b, float *c, void *stream) { + mad_mx_kernel<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, + (__gm__ uint8_t *)b, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx/main.cpp new file mode 100644 index 000000000..42773b0ce --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/main.cpp @@ -0,0 +1,127 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_mx_kernel(uint8_t *a, uint8_t *b, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(uint8_t); + constexpr size_t bSize = bElem * sizeof(uint8_t); + constexpr size_t cSize = cElem * sizeof(float); + + uint8_t *aHost = nullptr; + uint8_t *bHost = nullptr; + float *cHost = nullptr; + uint8_t *aDevice = nullptr; + uint8_t *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_mx_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/stub.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx/stub.cpp new file mode 100644 index 000000000..b2e381d39 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/stub.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void mad_mx_kernel(__gm__ uint8_t *a, + __gm__ uint8_t *b, + __gm__ float *c) { + (void)a; + (void)b; + (void)c; +} diff --git a/test/vpto/scripts/run_host_vpto_validation.sh b/test/vpto/scripts/run_host_vpto_validation.sh index a748ac95a..19ea53fda 100755 --- a/test/vpto/scripts/run_host_vpto_validation.sh +++ b/test/vpto/scripts/run_host_vpto_validation.sh @@ -21,6 +21,8 @@ PTOAS_BIN="${PTOAS_BIN:-${ROOT_DIR}/build/tools/ptoas/ptoas}" PTOAS_FLAGS="${PTOAS_FLAGS:---pto-arch a5}" VPTO_FLAGS="${VPTO_FLAGS:---pto-backend=vpto --vpto-emit-hivm-llvm}" AICORE_ARCH="${AICORE_ARCH:-dav-c310-vec}" +CUBE_AICORE_ARCH="${CUBE_AICORE_ARCH:-dav-c310-cube}" +CUBE_CASES="${CUBE_CASES:-mad_mx mad_f16f16f32 mad_f32f32f32 mad_bf16bf16f32}" # set he HOST_RUNNER to "ssh root@localhost" if must change user to root to access the device HOST_RUNNER="${HOST_RUNNER:-}" CASE_NAME="${CASE_NAME:-}" @@ -29,6 +31,11 @@ DEVICE="${DEVICE:-SIM}" SIM_LIB_DIR="${SIM_LIB_DIR:-}" COMPILE_ONLY="${COMPILE_ONLY:-0}" +declare -a CUBE_CASE_LIST=() +if [[ -n "${CUBE_CASES}" ]]; then + read -r -a CUBE_CASE_LIST <<< "${CUBE_CASES//,/ }" +fi + log() { echo "[$(date +'%F %T')] $*" } @@ -196,9 +203,33 @@ discover_cases() { readarray -t CASES < <(discover_cases) [[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" +case_uses_cube_mode() { + local case_name="$1" + local case_base="${case_name##*/}" + for item in "${CUBE_CASE_LIST[@]}"; do + [[ -n "${item}" ]] || continue + if [[ "${case_name}" == "${item}" || "${case_name}" == */"${item}" || + "${case_base}" == "${item}" ]]; then + return 0 + fi + done + return 1 +} + +build_case_vpto_flags() { + local case_name="$1" + local base_flags="$2" + if case_uses_cube_mode "${case_name}"; then + echo "${base_flags} --vpto-march ${CUBE_AICORE_ARCH} --vpto-cce-aicore-arch ${CUBE_AICORE_ARCH}" + return + fi + echo "${base_flags}" +} + build_launch_object() { local case_dir="$1" local out_obj="$2" + local case_arch="$3" "${BISHENG_BIN}" \ -c -fPIC -xcce -fenable-matrix --cce-aicore-enable-tl \ @@ -208,7 +239,7 @@ build_launch_object() { -mllvm -cce-aicore-record-overflow=true \ -mllvm -cce-aicore-addr-transform \ -mllvm -cce-aicore-dcci-insert-for-scalar=false \ - --cce-aicore-arch="${AICORE_ARCH}" \ + --cce-aicore-arch="${case_arch}" \ -DREGISTER_BASE \ -std=c++17 \ -Wno-macro-redefined -Wno-ignored-attributes \ @@ -225,6 +256,7 @@ build_host_stub() { local device_obj="$2" local stub_obj="$3" local module_id="$4" + local case_arch="$5" local host_target_args=( -triple "${HOST_TRIPLE}" -target-cpu "${HOST_TARGET_CPU}" @@ -241,7 +273,7 @@ build_host_stub() { -fcce-aicpu-legacy-launch \ -fcce-is-host \ -cce-launch-with-flagv2-impl \ - -fcce-aicore-arch "${AICORE_ARCH}" \ + -fcce-aicore-arch "${case_arch}" \ -fcce-fatobj-compile \ -emit-obj \ --mrelax-relocations \ @@ -308,6 +340,7 @@ link_kernel_so() { local repack_obj="$4" local repack_so="$5" local module_id="$6" + local case_arch="$7" local extra_lib_dirs=() local extra_link_libs=() @@ -315,7 +348,7 @@ link_kernel_so() { "${LD_LLD_BIN}" \ -x \ -cce-lite-bin-module-id "${module_id}" \ - -cce-aicore-arch="${AICORE_ARCH}" \ + -cce-aicore-arch="${case_arch}" \ -r \ -o "${repack_obj}" \ -cce-stub-dir "${CCE_STUB_DIR}" \ @@ -390,6 +423,12 @@ build_one_impl() { local host_stub_obj="${out_dir}/kernel_host_from_llvm.o" local repack_obj="${out_dir}/${case_token}_stub.cpp.o" local repack_so="${out_dir}/lib${case_token}_kernel.so" + local case_arch="${AICORE_ARCH}" + if case_uses_cube_mode "${case_name}"; then + case_arch="${CUBE_AICORE_ARCH}" + fi + local case_vpto_flags + case_vpto_flags="$(build_case_vpto_flags "${case_name}" "${VPTO_FLAGS}")" [[ -f "${case_dir}/kernel.pto" ]] || die "missing kernel.pto for ${case_name}" [[ -f "${case_dir}/stub.cpp" ]] || die "missing stub.cpp for ${case_name}" @@ -398,26 +437,27 @@ build_one_impl() { [[ -f "${case_dir}/golden.py" ]] || die "missing golden.py for ${case_name}" [[ -f "${case_dir}/compare.py" ]] || die "missing compare.py for ${case_name}" + log "[$case_name] mode: $(case_uses_cube_mode "${case_name}" && echo cube || echo vec) (aicore_arch=${case_arch})" log "[$case_name] step 1/6: lower VPTO MLIR to LLVM IR" - "${PTOAS_BIN}" ${PTOAS_FLAGS} ${VPTO_FLAGS} \ + "${PTOAS_BIN}" ${PTOAS_FLAGS} ${case_vpto_flags} \ "${case_dir}/kernel.pto" -o "${llvm_ir}" log "[$case_name] step 2/6: compile LLVM IR to device object" "${BISHENG_BIN}" \ --target=hiipu64-hisilicon-cce \ - -march="${AICORE_ARCH}" \ - --cce-aicore-arch="${AICORE_ARCH}" \ + -march="${case_arch}" \ + --cce-aicore-arch="${case_arch}" \ --cce-aicore-only \ -O2 \ -c -x ir "${llvm_ir}" \ -o "${device_obj}" log "[$case_name] step 3/6: build launch object and host fatobj stub" - build_launch_object "${case_dir}" "${launch_obj}" - build_host_stub "${case_dir}" "${device_obj}" "${host_stub_obj}" "${case_module_id}" + build_launch_object "${case_dir}" "${launch_obj}" "${case_arch}" + build_host_stub "${case_dir}" "${device_obj}" "${host_stub_obj}" "${case_module_id}" "${case_arch}" log "[$case_name] step 4/6: link kernel shared library" - link_kernel_so "${case_token}" "${host_stub_obj}" "${launch_obj}" "${repack_obj}" "${repack_so}" "${case_module_id}" + link_kernel_so "${case_token}" "${host_stub_obj}" "${launch_obj}" "${repack_obj}" "${repack_so}" "${case_module_id}" "${case_arch}" if [[ "${COMPILE_ONLY}" == "1" ]]; then log "[$case_name] compile-only mode: stop after kernel shared library" @@ -486,6 +526,9 @@ log "ASCEND_HOME_PATH=${ASCEND_HOME_PATH}" log "PTOAS_BIN=${PTOAS_BIN}" log "PTOAS_FLAGS=${PTOAS_FLAGS}" log "VPTO_FLAGS=${VPTO_FLAGS}" +log "AICORE_ARCH(default)=${AICORE_ARCH}" +log "CUBE_AICORE_ARCH=${CUBE_AICORE_ARCH}" +log "CUBE_CASES=${CUBE_CASES:-}" log "COMPILE_ONLY=${COMPILE_ONLY}" log "CASE_NAME=${CASE_NAME:-}" diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 422b550c0..1b0bb054a 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -313,6 +313,18 @@ static llvm::cl::opt vptoUnresolvedReport( llvm::cl::desc("Write unresolved VPTO mappings to a sidecar report"), llvm::cl::value_desc("path"), llvm::cl::init("")); +static llvm::cl::opt vptoMarch( + "vpto-march", + llvm::cl::desc("Bisheng -march for VPTO HIVM LLVM emission (default: " + "dav-c310-vec). Use dav-c310-cube for cube MAT kernels."), + llvm::cl::value_desc("march"), llvm::cl::init("")); + +static llvm::cl::opt vptoCceAicoreArch( + "vpto-cce-aicore-arch", + llvm::cl::desc("Bisheng --cce-aicore-arch for VPTO HIVM target-attribute " + "queries (default: same as --vpto-march)."), + llvm::cl::value_desc("arch"), llvm::cl::init("")); + static llvm::cl::opt hivmUnresolvedReport( "hivm-unresolved-report", llvm::cl::desc("Write unresolved HIVM mappings to a sidecar report"), @@ -1192,13 +1204,32 @@ static pto::VPTOEmissionOptions buildVPTOEmissionOptions() { options.unresolvedReportPath = !hivmUnresolvedReport.empty() ? hivmUnresolvedReport : vptoUnresolvedReport; options.targetTriple = "hiipu64-hisilicon-cce"; - options.march = "dav-c310-vec"; - options.aicoreArch = "dav-c310-vec"; - options.defaultTargetCPU = "dav-c310-vec"; - options.defaultTargetFeatures = - "+ATOMIC,+ArchV130,+AregRedefinable,+ArithmeticBf16,+AtomicForB8 ," - "+F8e4m3,+F8e5m2,+F8e8m0,+FFTSBlk,+Fp4e1m2x2,+Fp4e2m1x2,+LDExtRefine," - "+MOVX8,+SPR7bits,+SyncV,+dav-c310-vec"; + + const std::string kVecMarch = "dav-c310-vec"; + const std::string kCubeMarch = "dav-c310-cube"; + std::string march = vptoMarch.empty() ? kVecMarch : std::string(vptoMarch); + std::string aicore = + vptoCceAicoreArch.empty() ? march : std::string(vptoCceAicoreArch); + options.march = march; + options.aicoreArch = aicore; + + // When bisheng target-attribute probing fails (e.g. ptoas subprocess without + // CANN in PATH), LLVMFuncOp attrs fall back to these defaults. They must + // match the HIVM intrinsics (vec vs cube) or bisheng will CannotSelect. + if (aicore.find("cube") != std::string::npos || + march.find("cube") != std::string::npos) { + options.defaultTargetCPU = kCubeMarch; + options.defaultTargetFeatures = + "+ATOMIC,+ArchV130,+AregRedefinable,+ArithmeticBf16,+AtomicForB8 ," + "+F8e4m3,+F8e5m2,+F8e8m0,+FFTSBlk,+Fp4e1m2x2,+Fp4e2m1x2,+LDExtRefine," + "+MOVX8,+SPR7bits,+SyncV,+dav-c310-cube"; + } else { + options.defaultTargetCPU = kVecMarch; + options.defaultTargetFeatures = + "+ATOMIC,+ArchV130,+AregRedefinable,+ArithmeticBf16,+AtomicForB8 ," + "+F8e4m3,+F8e5m2,+F8e8m0,+FFTSBlk,+Fp4e1m2x2,+Fp4e2m1x2,+LDExtRefine," + "+MOVX8,+SPR7bits,+SyncV,+dav-c310-vec"; + } return options; } From 3b03da313cb9e3a252747e011701882d83d0dcba Mon Sep 17 00:00:00 2001 From: mly <978226558@qq.com> Date: Fri, 24 Apr 2026 23:36:24 +0800 Subject: [PATCH 156/192] feature dma remove sid (#243) * feat: remove sid operand of dma op * feat: make variadic loops for dma op --------- Co-authored-by: mouliangyu --- docs/isa/02-dma-copy.md | 235 +++++----- docs/vpto-spec.md | 10 +- include/PTO/IR/VPTOOps.td | 47 +- lib/PTO/IR/VPTO.cpp | 412 ++++++++++-------- lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp | 159 +++++-- .../binary-vector/vadd-bf16/kernel.pto | 12 +- .../binary-vector/vadd-f16/kernel.pto | 12 +- .../vadd-f32-exceptional/kernel.pto | 12 +- .../vadd-i16-signed-overflow/kernel.pto | 12 +- .../binary-vector/vadd-i16-signed/kernel.pto | 12 +- .../vadd-i16-unsigned-overflow/kernel.pto | 12 +- .../vadd-i16-unsigned/kernel.pto | 12 +- .../binary-vector/vadd-tail/kernel.pto | 12 +- .../micro-op/binary-vector/vadd/kernel.pto | 12 +- .../vaddc-carry-boundary/kernel.pto | 16 +- .../micro-op/binary-vector/vaddc/kernel.pto | 16 +- .../binary-vector/vand-mask-edge/kernel.pto | 12 +- .../micro-op/binary-vector/vand/kernel.pto | 12 +- .../binary-vector/vdiv-f16/kernel.pto | 12 +- .../vdiv-f32-exceptional/kernel.pto | 12 +- .../binary-vector/vdiv-tail/kernel.pto | 12 +- .../micro-op/binary-vector/vdiv/kernel.pto | 12 +- .../binary-vector/vmax-tail/kernel.pto | 12 +- .../micro-op/binary-vector/vmax/kernel.pto | 12 +- .../binary-vector/vmin-bf16/kernel.pto | 12 +- .../binary-vector/vmin-f16/kernel.pto | 12 +- .../vmin-f32-exceptional/kernel.pto | 12 +- .../vmin-f32-exceptional/vmin/kernel.pto | 12 +- .../binary-vector/vmin-i16-signed/kernel.pto | 12 +- .../vmin-i16-unsigned/kernel.pto | 12 +- .../binary-vector/vmin-tail/kernel.pto | 12 +- .../micro-op/binary-vector/vmin/kernel.pto | 12 +- .../binary-vector/vmul-tail/kernel.pto | 12 +- .../micro-op/binary-vector/vmul/kernel.pto | 12 +- .../micro-op/binary-vector/vor-f16/kernel.pto | 12 +- .../binary-vector/vor-mask-edge/kernel.pto | 12 +- .../micro-op/binary-vector/vor/kernel.pto | 12 +- .../vshl-i32-unsigned/kernel.pto | 12 +- .../vshl-shift-boundary/kernel.pto | 12 +- .../micro-op/binary-vector/vshl/kernel.pto | 12 +- .../binary-vector/vshr-i16-signed/kernel.pto | 12 +- .../vshr-shift-boundary/kernel.pto | 12 +- .../micro-op/binary-vector/vshr/kernel.pto | 12 +- .../binary-vector/vsub-tail/kernel.pto | 12 +- .../micro-op/binary-vector/vsub/kernel.pto | 12 +- .../vsubc-borrow-boundary/kernel.pto | 16 +- .../micro-op/binary-vector/vsubc/kernel.pto | 16 +- .../binary-vector/vxor-mask-edge/kernel.pto | 12 +- .../micro-op/binary-vector/vxor/kernel.pto | 12 +- .../compare-select/vcmp-eq/kernel.pto | 12 +- .../vcmp-f32-exceptional/kernel.pto | 12 +- .../compare-select/vcmp-i16-signed/kernel.pto | 12 +- .../vcmp-i16-unsigned/kernel.pto | 12 +- .../compare-select/vcmp-lt/kernel.pto | 12 +- .../compare-select/vcmp-tail/kernel.pto | 12 +- .../vcmps-f32-exceptional/kernel.pto | 8 +- .../compare-select/vcmps-f32/kernel.pto | 8 +- .../vcmps-i16-signed/kernel.pto | 8 +- .../vcmps-i16-unsigned/kernel.pto | 8 +- .../compare-select/vcmps-i8-signed/kernel.pto | 8 +- .../vcmps-i8-unsigned/kernel.pto | 8 +- .../compare-select/vcmps-tail/kernel.pto | 8 +- .../kernel.pto | 16 +- .../compare-select/vsel-i16/kernel.pto | 12 +- .../vsel-predicate-edge/kernel.pto | 12 +- .../compare-select/vsel-tail/kernel.pto | 16 +- .../micro-op/compare-select/vsel/kernel.pto | 12 +- .../compare-select/vselr-f16/kernel.pto | 12 +- .../compare-select/vselr-u8/kernel.pto | 12 +- .../micro-op/compare-select/vselr/kernel.pto | 12 +- .../conversion/vcvt-f16-special/kernel.pto | 8 +- .../vcvt-f16-to-f32-part-even/kernel.pto | 8 +- .../vcvt-f16-to-f32-part-odd/kernel.pto | 8 +- .../conversion/vcvt-f16-to-f32/kernel.pto | 8 +- .../conversion/vcvt-f32-special/kernel.pto | 8 +- .../vcvt-f32-to-f16-pk-b32/kernel.pto | 8 +- .../conversion/vcvt-f32-to-f16/kernel.pto | 8 +- .../vcvt-i32-to-i16-overflow/kernel.pto | 8 +- .../conversion/vcvt-tail-special/kernel.pto | 8 +- .../micro-op/conversion/vcvt-tail/kernel.pto | 8 +- .../conversion/vtrc-f16-rounding/kernel.pto | 8 +- .../conversion/vtrc-f32-rounding/kernel.pto | 16 +- .../conversion/vtrc-f32-special/kernel.pto | 8 +- .../vtrc-rounding-boundary/kernel.pto | 16 +- .../micro-op/dsa-sfu/vaxpy-f32/kernel.pto | 12 +- .../micro-op/dsa-sfu/vbitsort/kernel.pto | 12 +- .../cases/micro-op/dsa-sfu/vci-f16/kernel.pto | 4 +- .../cases/micro-op/dsa-sfu/vci-si8/kernel.pto | 4 +- .../cases/micro-op/dsa-sfu/vci/kernel.pto | 4 +- .../dsa-sfu/vexpdiff-boundary/kernel.pto | 12 +- .../dsa-sfu/vexpdiff-f16-part/kernel.pto | 12 +- .../micro-op/dsa-sfu/vexpdiff-f32/kernel.pto | 8 +- .../micro-op/dsa-sfu/vlrelu-f16/kernel.pto | 8 +- .../dsa-sfu/vlrelu-f32-exceptional/kernel.pto | 8 +- .../micro-op/dsa-sfu/vlrelu-f32/kernel.pto | 8 +- .../micro-op/dsa-sfu/vlrelu-tail/kernel.pto | 8 +- .../vmula-accumulator-boundary/kernel.pto | 8 +- .../cases/micro-op/dsa-sfu/vmula/kernel.pto | 8 +- .../cases/micro-op/dsa-sfu/vmull/kernel.pto | 8 +- .../micro-op/dsa-sfu/vprelu-f32/kernel.pto | 12 +- .../micro-op/dsa-sfu/vprelu-tail/kernel.pto | 12 +- .../vgather2-duplicate-index/kernel.pto | 12 +- .../gather-scatter/vgather2/kernel.pto | 12 +- .../vgather2_bc-sparse-mask/kernel.pto | 12 +- .../gather-scatter/vgather2_bc/kernel.pto | 12 +- .../vgatherb-block-boundary/kernel.pto | 12 +- .../gather-scatter/vgatherb/kernel.pto | 12 +- .../vscatter-out-of-order-index/kernel.pto | 16 +- .../gather-scatter/vscatter/kernel.pto | 16 +- .../materialization-predicate/pand/kernel.pto | 4 +- .../pdintlv_b16-nontrivial/kernel.pto | 4 +- .../pdintlv_b16/kernel.pto | 4 +- .../pdintlv_b32-nontrivial/kernel.pto | 4 +- .../pdintlv_b32/kernel.pto | 4 +- .../pdintlv_b8-nontrivial/kernel.pto | 4 +- .../pdintlv_b8/kernel.pto | 4 +- .../pge-tail-mask-boundary/kernel.pto | 4 +- .../pge-tail-mask/kernel.pto | 4 +- .../pintlv_b16-nontrivial/kernel.pto | 4 +- .../pintlv_b16/kernel.pto | 4 +- .../pintlv_b32-nontrivial/kernel.pto | 4 +- .../pintlv_b32/kernel.pto | 4 +- .../pintlv_b8-nontrivial/kernel.pto | 4 +- .../pintlv_b8/kernel.pto | 4 +- .../plt-tail-mask-boundary/kernel.pto | 4 +- .../plt-tail-mask/kernel.pto | 4 +- .../materialization-predicate/pnot/kernel.pto | 4 +- .../materialization-predicate/por/kernel.pto | 4 +- .../ppack-punpack-nontrivial/kernel.pto | 4 +- .../ppack-punpack/kernel.pto | 4 +- .../psel-tail-predicate/kernel.pto | 4 +- .../materialization-predicate/psel/kernel.pto | 4 +- .../pset-pattern-fragment/kernel.pto | 4 +- .../pset-pattern/kernel.pto | 4 +- .../materialization-predicate/pxor/kernel.pto | 4 +- .../vbr-f32/kernel.pto | 4 +- .../vbr-i32/kernel.pto | 4 +- .../vbr-i8/kernel.pto | 4 +- .../vbr-u8/kernel.pto | 4 +- .../vdup-lane/kernel.pto | 12 +- .../vdup-scalar-f16/kernel.pto | 4 +- .../vdup-scalar-i8/kernel.pto | 4 +- .../vdup-scalar-u8/kernel.pto | 4 +- .../vdup-scalar/kernel.pto | 4 +- .../predicate-load-store/pldi-norm/kernel.pto | 8 +- .../predicate-load-store/plds-norm/kernel.pto | 8 +- .../psti-norm-pldi-ds/kernel.pto | 8 +- .../psti-pk-pldi-us/kernel.pto | 8 +- .../predicate-load-store/psti-pk/kernel.pto | 4 +- .../psts-norm-plds-ds/kernel.pto | 8 +- .../kernel.pto | 8 +- .../psts-pk-plds-us/kernel.pto | 8 +- .../pstu-init-align-outside-loop/kernel.pto | 4 +- .../pstu-state-advance-boundary/kernel.pto | 4 +- .../predicate-load-store/pstu/kernel.pto | 4 +- .../vintlv-vdintlv-lane-boundary/kernel.pto | 8 +- .../rearrangement/vintlv-vdintlv/kernel.pto | 8 +- .../rearrangement/vpack-higher/kernel.pto | 8 +- .../rearrangement/vpack-lower/kernel.pto | 8 +- .../vsqz-nontrivial-mask/kernel.pto | 12 +- .../micro-op/rearrangement/vsqz/kernel.pto | 8 +- .../rearrangement/vsunpack/kernel.pto | 8 +- .../vusqz-nontrivial-mask/kernel.pto | 12 +- .../micro-op/rearrangement/vusqz/kernel.pto | 12 +- .../rearrangement/vzunpack/kernel.pto | 8 +- .../micro-op/reduction/vcadd-tail/kernel.pto | 8 +- .../cases/micro-op/reduction/vcadd/kernel.pto | 8 +- .../micro-op/reduction/vcgadd-tail/kernel.pto | 8 +- .../micro-op/reduction/vcgadd/kernel.pto | 8 +- .../micro-op/reduction/vcgmax-tie/kernel.pto | 8 +- .../micro-op/reduction/vcgmax/kernel.pto | 8 +- .../micro-op/reduction/vcgmin-tie/kernel.pto | 8 +- .../micro-op/reduction/vcgmin/kernel.pto | 8 +- .../cases/micro-op/reduction/vcmax/kernel.pto | 8 +- .../cases/micro-op/reduction/vcmin/kernel.pto | 8 +- .../micro-op/reduction/vcpadd-tail/kernel.pto | 8 +- .../micro-op/reduction/vcpadd/kernel.pto | 8 +- .../load-store-scalar-ub/kernel.pto | 8 +- .../micro-op/unary-vector/vabs-f16/kernel.pto | 8 +- .../vabs-f32-exceptional/kernel.pto | 8 +- .../vabs-i16-signed-overflow-edge/kernel.pto | 8 +- .../unary-vector/vabs-i16-signed/kernel.pto | 8 +- .../unary-vector/vabs-i16-unsigned/kernel.pto | 8 +- .../vabs-loop-carried-vreg/kernel.pto | 8 +- .../unary-vector/vabs-tail/kernel.pto | 8 +- .../micro-op/unary-vector/vabs/kernel.pto | 8 +- .../micro-op/unary-vector/vexp-f16/kernel.pto | 8 +- .../vexp-f32-exceptional/kernel.pto | 8 +- .../vexp-f32-over-underflow/kernel.pto | 8 +- .../unary-vector/vexp-tail/kernel.pto | 8 +- .../micro-op/unary-vector/vexp/kernel.pto | 8 +- .../vln-domain-boundary/kernel.pto | 8 +- .../micro-op/unary-vector/vln/kernel.pto | 8 +- .../vneg-f32-exceptional/kernel.pto | 8 +- .../micro-op/unary-vector/vneg/kernel.pto | 8 +- .../micro-op/unary-vector/vnot/kernel.pto | 8 +- .../micro-op/unary-vector/vrelu/kernel.pto | 8 +- .../vsqrt-domain-boundary/kernel.pto | 8 +- .../micro-op/unary-vector/vsqrt/kernel.pto | 8 +- .../vaddcs-carry-boundary/kernel.pto | 16 +- .../micro-op/vec-scalar/vaddcs/kernel.pto | 16 +- .../micro-op/vec-scalar/vadds-bf16/kernel.pto | 8 +- .../micro-op/vec-scalar/vadds-f16/kernel.pto | 8 +- .../vadds-f32-exceptional/kernel.pto | 8 +- .../vadds-i16-signed-overflow/kernel.pto | 8 +- .../vec-scalar/vadds-i16-signed/kernel.pto | 8 +- .../vadds-i16-unsigned-overflow/kernel.pto | 8 +- .../vec-scalar/vadds-i16-unsigned/kernel.pto | 8 +- .../micro-op/vec-scalar/vadds-tail/kernel.pto | 8 +- .../micro-op/vec-scalar/vadds/kernel.pto | 8 +- .../micro-op/vec-scalar/vmaxs-tail/kernel.pto | 8 +- .../micro-op/vec-scalar/vmaxs/kernel.pto | 8 +- .../micro-op/vec-scalar/vmins-tail/kernel.pto | 8 +- .../micro-op/vec-scalar/vmins/kernel.pto | 8 +- .../micro-op/vec-scalar/vmuls-tail/kernel.pto | 8 +- .../micro-op/vec-scalar/vmuls/kernel.pto | 8 +- .../vshls-shift-boundary/kernel.pto | 8 +- .../micro-op/vec-scalar/vshls/kernel.pto | 8 +- .../vshrs-shift-boundary/kernel.pto | 8 +- .../micro-op/vec-scalar/vshrs/kernel.pto | 8 +- .../vsubcs-borrow-boundary/kernel.pto | 16 +- .../micro-op/vec-scalar/vsubcs/kernel.pto | 16 +- .../vldas-vldus-state-chain/kernel.pto | 8 +- .../vector-load-store/vldas-vldus/kernel.pto | 8 +- .../vlds-brc-b16-f32/kernel.pto | 8 +- .../vector-load-store/vlds-brc-b16/kernel.pto | 8 +- .../vector-load-store/vlds-brc-b32/kernel.pto | 8 +- .../vlds-brc-b8-f32/kernel.pto | 8 +- .../vector-load-store/vlds-brc-blk/kernel.pto | 8 +- .../vlds-dma-loop/compare.py | 53 +++ .../vector-load-store/vlds-dma-loop/golden.py | 47 ++ .../vlds-dma-loop/kernel.pto | 67 +++ .../vlds-dma-loop/launch.cpp | 41 ++ .../vector-load-store/vlds-dma-loop/main.cpp | 103 +++++ .../vector-load-store/vlds-dma-loop/stub.cpp | 21 + .../vector-load-store/vlds-ds-b16/kernel.pto | 8 +- .../vector-load-store/vlds-tail/kernel.pto | 8 +- .../vlds-unpk-b16/kernel.pto | 8 +- .../vector-load-store/vlds-us-b16/kernel.pto | 8 +- .../vector-load-store/vlds/kernel.pto | 8 +- .../vldsx2-layout-check/kernel.pto | 8 +- .../vldsx2-vstsx2-b8-f32/kernel.pto | 8 +- .../vldsx2-vstsx2/kernel.pto | 8 +- .../vector-load-store/vsldb/kernel.pto | 8 +- .../vector-load-store/vsstb/kernel.pto | 8 +- .../vector-load-store/vstar/kernel.pto | 8 +- .../vstas-vstus-offset-update/kernel.pto | 12 +- .../vector-load-store/vsts-1pt-b16/kernel.pto | 8 +- .../vector-load-store/vsts-pk-b16/kernel.pto | 8 +- .../vsts-pk-b64-f32/kernel.pto | 12 +- .../vector-load-store/vsts-tail/kernel.pto | 8 +- .../vector-load-store/vsts/kernel.pto | 8 +- .../vstsx2-layout-check/kernel.pto | 8 +- .../vstur-init-align-outside-loop/kernel.pto | 8 +- .../vector-load-store/vstur/kernel.pto | 8 +- .../cases/vpto/dma-copy-rearrange/kernel.pto | 16 +- 256 files changed, 1948 insertions(+), 1443 deletions(-) create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/compare.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/golden.py create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/stub.cpp diff --git a/docs/isa/02-dma-copy.md b/docs/isa/02-dma-copy.md index 845df918b..ab9f6fec5 100644 --- a/docs/isa/02-dma-copy.md +++ b/docs/isa/02-dma-copy.md @@ -24,18 +24,15 @@ surface op. - **syntax:** ```mlir -pto.dma_load %gm_src, %ub_dst, %sid, %l2_cache_ctl, %len_burst +pto.dma_load %gm_src, %ub_dst, %l2_cache_ctl, %len_burst nburst(%n_burst, %src_stride, %dst_stride) - [loop1(%loop1_count, %loop1_src_stride, %loop1_dst_stride)] - [loop2(%loop2_count, %loop2_src_stride, %loop2_dst_stride)] + [loop(%loop_count, %loop_src_stride, %loop_dst_stride)]* [pad(%pad_value[, %left_padding_count, %right_padding_count])] : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64, - [loop1 i64, i64, i64,] - [loop2 i64, i64, i64,] + [loop i64, i64, i64,]* [pad T[, i64, i64]] ``` -- **semantics:** Grouped GM→UB DMA transfer. It carries the burst, optional HW loop, and optional padding configuration on the copy op itself. +- **semantics:** Grouped GM→UB DMA transfer. `nburst(...)` defines the innermost repeated burst transfer, optional `loop(...)` groups add outer repetition levels, and `pad(...)` controls UB row padding. **Parameter Table:** @@ -43,34 +40,34 @@ pto.dma_load %gm_src, %ub_dst, %sid, %l2_cache_ctl, %len_burst |-----------|-------|-------------| | `%gm_src` | ptr | GM source pointer (`!pto.ptr`) | | `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | -| `%sid` | 32 bits | Stream ID | | `%l2_cache_ctl` | 2 bits | L2 cache allocate control | | `%len_burst` | 16 bits | Contiguous bytes transferred per burst row | -| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 40 bits / 21 bits | Required innermost burst loop: count, GM source stride, UB destination stride | -| `loop1(%loop1_count, %loop1_src_stride, %loop1_dst_stride)` | 21 bits / 40 bits / 21 bits | Optional inner HW loop: count, GM source stride, UB destination stride | -| `loop2(%loop2_count, %loop2_src_stride, %loop2_dst_stride)` | 21 bits / 40 bits / 21 bits | Optional outer HW loop: count, GM source stride, UB destination stride | +| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 40 bits / 21 bits | Required innermost burst group: count, GM source stride, UB destination stride | +| `loop(%loop_count, %loop_src_stride, %loop_dst_stride)` | 21 bits / 40 bits / 21 bits | Optional outer repetition group: count, GM source stride, UB destination stride | | `pad(%pad_value[, %left_padding_count, %right_padding_count])` | scalar / 8 bits / 8 bits | Optional padding: fill value, optional left padding count, optional right padding count | **Constraints:** - `nburst(...)` is always required. -- `loop1(...)` and `loop2(...)` must each be provided as a complete group when present. +- Each `loop(...)` group must be provided as a complete triple when present. +- `nburst(...)` is the innermost group. +- `loop(...)` groups are ordered from inner to outer. +- The first `loop(...)` group wraps `nburst(...)`. +- Each additional `loop(...)` group wraps all earlier groups. - `pad(...)` may contain only `%pad_value`; omitted left and right padding counts default to 0. - If either left or right padding count is provided, both counts must be provided. -- `loop1(...)` may be used without `loop2(...)`; in that case `loop2_count` is treated as 1 when programming the loop-size register. -- `loop2(...)` requires `loop1(...)`; `loop2` without `loop1` is rejected by the verifier. -- `pad(...)` is independent of `loop1(...)` and `loop2(...)`. -- A DMA load may use `nburst(...) pad(...)` without any HW loop group. +- `pad(...)` is independent of the optional `loop(...)` groups. +- A DMA load may use `nburst(...) pad(...)` without any `loop(...)` group. **Example:** ```mlir -pto.dma_load %gm_in, %ub_out, %sid, %cache, %len_burst +pto.dma_load %gm_in, %ub_out, %cache, %len_burst nburst(%rows, %gm_row_stride, %ub_row_stride) - loop1(%tiles, %gm_tile_stride, %ub_tile_stride) + loop(%tiles, %gm_tile_stride, %ub_tile_stride) pad(%pad) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64, loop1 i64, i64, i64, pad f16 + : !pto.ptr, !pto.ptr, i64, i64, + loop i64, i64, i64, pad f16 ``` --- @@ -79,16 +76,13 @@ pto.dma_load %gm_in, %ub_out, %sid, %cache, %len_burst - **syntax:** ```mlir -pto.dma_store %ub_src, %gm_dst, %sid, %reserved, %len_burst +pto.dma_store %ub_src, %gm_dst, %len_burst nburst(%n_burst, %src_stride, %dst_stride) - [loop1(%loop1_count, %loop1_src_stride, %loop1_dst_stride)] - [loop2(%loop2_count, %loop2_src_stride, %loop2_dst_stride)] - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64, - [loop1 i64, i64, i64,] - [loop2 i64, i64, i64] + [loop(%loop_count, %loop_src_stride, %loop_dst_stride)]* + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + [loop i64, i64, i64,]* ``` -- **semantics:** Grouped UB→GM DMA transfer. It carries the burst and optional HW loop configuration on the copy op itself. +- **semantics:** Grouped UB→GM DMA transfer. `nburst(...)` defines the innermost repeated burst transfer, and optional `loop(...)` groups add outer repetition levels. **Parameter Table:** @@ -96,29 +90,28 @@ pto.dma_store %ub_src, %gm_dst, %sid, %reserved, %len_burst |-----------|-------|-------------| | `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | | `%gm_dst` | ptr | GM destination pointer (`!pto.ptr`) | -| `%sid` | 32 bits | Stream ID | -| `%reserved` | 8 bits | Reserved field, normally 0 | | `%len_burst` | 16 bits | Contiguous bytes transferred per burst row | -| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 21 bits / 40 bits | Required innermost burst loop: count, UB source stride, GM destination stride | -| `loop1(%loop1_count, %loop1_src_stride, %loop1_dst_stride)` | 21 bits / 21 bits / 40 bits | Optional inner HW loop: count, UB source stride, GM destination stride | -| `loop2(%loop2_count, %loop2_src_stride, %loop2_dst_stride)` | 21 bits / 21 bits / 40 bits | Optional outer HW loop: count, UB source stride, GM destination stride | +| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 21 bits / 40 bits | Required innermost burst group: count, UB source stride, GM destination stride | +| `loop(%loop_count, %loop_src_stride, %loop_dst_stride)` | 21 bits / 21 bits / 40 bits | Optional outer repetition group: count, UB source stride, GM destination stride | **Constraints:** - `nburst(...)` is always required. -- `loop1(...)` and `loop2(...)` must each be provided as a complete group when present. -- `loop1(...)` may be used without `loop2(...)`; in that case `loop2_count` is treated as 1 when programming the loop-size register. -- `loop2(...)` requires `loop1(...)`; `loop2` without `loop1` is rejected by the verifier. +- Each `loop(...)` group must be provided as a complete triple when present. +- `nburst(...)` is the innermost group. +- `loop(...)` groups are ordered from inner to outer. +- The first `loop(...)` group wraps `nburst(...)`. +- Each additional `loop(...)` group wraps all earlier groups. **Example:** ```mlir -pto.dma_store %ub_in, %gm_out, %sid, %zero, %len_burst +pto.dma_store %ub_in, %gm_out, %len_burst nburst(%rows, %ub_row_stride, %gm_row_stride) - loop1(%tiles, %ub_tile_stride, %gm_tile_stride) - loop2(%batches, %ub_batch_stride, %gm_batch_stride) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64, loop1 i64, i64, i64, loop2 i64, i64, i64 + loop(%tiles, %ub_tile_stride, %gm_tile_stride) + loop(%batches, %ub_batch_stride, %gm_batch_stride) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + loop i64, i64, i64, loop i64, i64, i64 ``` --- @@ -127,9 +120,9 @@ pto.dma_store %ub_in, %gm_out, %sid, %zero, %len_burst - **syntax:** ```mlir -pto.dma_copy %ub_src, %ub_dst, %sid, %len_burst +pto.dma_copy %ub_src, %ub_dst, %len_burst nburst(%n_burst, %src_gap, %dst_gap) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 ``` - **semantics:** Grouped UB→UB raw copy.. @@ -139,7 +132,6 @@ pto.dma_copy %ub_src, %ub_dst, %sid, %len_burst |-----------|-------|-------------| | `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | | `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | -| `%sid` | 16 bits | Stream ID | | `%len_burst` | 16 bits | Burst length in units of 32 bytes | | `nburst(%n_burst, %src_gap, %dst_gap)` | 16 bits / 16 bits / 16 bits | Required UB→UB outer burst group: count, source gap, destination gap | @@ -151,9 +143,9 @@ pto.dma_copy %ub_src, %ub_dst, %sid, %len_burst **Example:** ```mlir -pto.dma_copy %ub_src, %ub_dst, %sid, %len32b +pto.dma_copy %ub_src, %ub_dst, %len32b nburst(%rows, %src_gap, %dst_gap) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 ``` --- @@ -260,56 +252,88 @@ Only len_burst bytes are written to each GM row. --- -## Multi-Level Loop Semantics (C Code) +## Multi-Level Loop Semantics -The full DMA transfer is a nested loop. `loop1(...)` / `loop2(...)` control the -outer levels, and `nburst(...)` controls the innermost burst level. +The full DMA transfer is a nested loop. `nburst(...)` is the innermost group. +If one or more `loop(...)` groups are present, they wrap `nburst(...)` in the +same order they appear in the op: the first `loop(...)` is the innermost outer +group, the second `loop(...)` wraps the first one, and so on. ### GM→UB Full Loop +For a form + +```mlir +pto.dma_load %gm_src, %ub_dst, %l2_cache_ctl, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + loop(%c0, %s0, %d0) + loop(%c1, %s1, %d1) + ... + loop(%cN, %sN, %dN) + [pad(%pad_value[, %left_padding_count, %right_padding_count])] +``` + +the transfer is equivalent to: + ```c -// C equivalent of what the HW executes: -for (int j = 0; j < loop2_count; j++) { // HW outer loop - uint8_t *gm1 = gm_src + j * loop2_src_stride; - uint8_t *ub1 = ub_dst + j * loop2_dst_stride; - - for (int k = 0; k < loop1_count; k++) { // HW inner loop - uint8_t *gm2 = gm1 + k * loop1_src_stride; - uint8_t *ub2 = ub1 + k * loop1_dst_stride; - - for (int r = 0; r < n_burst; r++) { // burst engine - memcpy(ub2 + r * dst_stride, // UB dest row - gm2 + r * src_stride, // GM src row - len_burst); // contiguous bytes - if (pad_enabled) - memset(ub2 + r * dst_stride + len_burst, - pad_val, dst_stride - len_burst); - } +for (int lN = 0; lN < cN; ++lN) { + ... + for (int l1 = 0; l1 < c1; ++l1) { + for (int l0 = 0; l0 < c0; ++l0) { + uint8_t *gm_base = gm_src + l0 * s0 + l1 * s1 + ... + lN * sN; + uint8_t *ub_base = ub_dst + l0 * d0 + l1 * d1 + ... + lN * dN; + for (int r = 0; r < n_burst; ++r) { + memcpy(ub_base + r * dst_stride, + gm_base + r * src_stride, + len_burst); + if (pad_enabled) + memset(ub_base + r * dst_stride + len_burst, + pad_val, + dst_stride - len_burst); + } } + } } ``` +If no `loop(...)` group is present, only the innermost `nburst(...)` loop +remains. + ### UB→GM Full Loop +For a form + +```mlir +pto.dma_store %ub_src, %gm_dst, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + loop(%c0, %s0, %d0) + loop(%c1, %s1, %d1) + ... + loop(%cN, %sN, %dN) +``` + +the transfer is equivalent to: + ```c -// C equivalent: -for (int j = 0; j < loop2_count; j++) { - uint8_t *ub1 = ub_src + j * loop2_src_stride; - uint8_t *gm1 = gm_dst + j * loop2_dst_stride; - - for (int k = 0; k < loop1_count; k++) { - uint8_t *ub2 = ub1 + k * loop1_src_stride; - uint8_t *gm2 = gm1 + k * loop1_dst_stride; - - for (int r = 0; r < n_burst; r++) { - memcpy(gm2 + r * dst_stride, // GM dest row - ub2 + r * src_stride, // UB src row - len_burst); // contiguous bytes - } +for (int lN = 0; lN < cN; ++lN) { + ... + for (int l1 = 0; l1 < c1; ++l1) { + for (int l0 = 0; l0 < c0; ++l0) { + uint8_t *ub_base = ub_src + l0 * s0 + l1 * s1 + ... + lN * sN; + uint8_t *gm_base = gm_dst + l0 * d0 + l1 * d1 + ... + lN * dN; + for (int r = 0; r < n_burst; ++r) { + memcpy(gm_base + r * dst_stride, + ub_base + r * src_stride, + len_burst); + } } + } } ``` +If no `loop(...)` group is present, only the innermost `nburst(...)` loop +remains. + --- ## Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) @@ -341,10 +365,9 @@ UB layout (32 × 32 f32, 32B-aligned, contiguous): ```mlir // Simple 2D load — only nburst(...) is needed -pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 +pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64 ``` --- @@ -380,10 +403,9 @@ UB layout (64 × 128 f16, 32B-aligned, contiguous): ``` ```mlir -pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c0_i64, %c256_i64 +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c256_i64 nburst(%c64_i64, %c1024_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64 ``` --- @@ -419,11 +441,10 @@ UB (128 cols wide, 32B-aligned, padded): ```mlir %pad = arith.constant 0 : i16 -pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c0_i64, %c200_i64 +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c200_i64 nburst(%c64_i64, %c200_i64, %c256_i64) pad(%pad, %c0_i64, %c0_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64, pad i16, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, pad i16, i64, i64 ``` --- @@ -455,10 +476,9 @@ GM (dest, 32 × 32 f32): ``` ```mlir -pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 +pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 ``` --- @@ -494,17 +514,17 @@ GM (dest, into 1024 × 512 matrix): ``` ```mlir -pto.dma_store %ub_ptr, %gm_ptr, %c0_i64, %c0_i64, %c256_i64 +pto.dma_store %ub_ptr, %gm_ptr, %c256_i64 nburst(%c64_i64, %c256_i64, %c1024_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 ``` --- ## Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) -Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using one outer +`loop(...)` group. ``` GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): @@ -512,25 +532,24 @@ GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] batch 2: 8 rows × 256 bytes - batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) - loop1 dst_stride = 2048 bytes (8 × 256) - Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) + batch 3: 8 rows × 256 bytes outer loop src_stride = 2048 bytes (8 × 256) + outer loop dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes outer loop count = 4 (iterate over batches) ``` ```mlir -// loop1_count = 4 batches, loop2 omitted -pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c0_i64, %c256_i64 +// One outer loop group over 4 batches +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c256_i64 nburst(%c8_i64, %c256_i64, %c256_i64) - loop1(%c4_i64, %c2048_i64, %c2048_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64, loop1 i64, i64, i64 + loop(%c4_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, loop i64, i64, i64 ``` Execution trace: ``` -loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B -loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B -loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B -loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +loop iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B ``` diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 76182d26c..3aa56b275 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -250,10 +250,9 @@ pto.strict_vecscope(%ub, %ub_out, %lane) { ### Example: VecScope ```mlir -pto.dma_load %7, %2, %c0_i64, %c0_i64, %c128_i64 +pto.dma_load %7, %2, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -269,10 +268,9 @@ pto.vecscope { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] -pto.dma_store %8, %14, %c0_i64, %c0_i64, %c128_i64 +pto.dma_store %8, %14, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, - i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 ``` ### Example: Strict VecScope diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index b5ac18d41..07fa649c7 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -259,18 +259,14 @@ def PTO_DmaLoadOp : PTO_Op<"dma_load", [ let arguments = (ins PTO_BufferType:$source, PTO_BufferType:$destination, - I64:$sid, I64:$l2_cache_ctl, I64:$len_burst, I64:$n_burst, I64:$nburst_src_stride, I64:$nburst_dst_stride, - Optional:$loop1_count, - Optional:$loop1_src_stride, - Optional:$loop1_dst_stride, - Optional:$loop2_count, - Optional:$loop2_src_stride, - Optional:$loop2_dst_stride, + Variadic:$loop_counts, + Variadic:$loop_src_strides, + Variadic:$loop_dst_strides, Optional>:$pad_value, Optional:$left_padding_count, Optional:$right_padding_count @@ -285,7 +281,15 @@ def PTO_DmaLoadOp : PTO_Op<"dma_load", [ OpBuilder<(ins "::mlir::Value":$source, "::mlir::Value":$destination, - "::mlir::Value":$sid, + "::mlir::Value":$l2CacheCtl, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::llvm::ArrayRef<::mlir::pto::DmaLoopConfig>":$loops, + "::std::optional<::mlir::pto::DmaPadConfig>":$pad + )>, + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, "::mlir::Value":$l2CacheCtl, "::mlir::Value":$lenBurst, "::mlir::pto::DmaLoopConfig":$nburst, @@ -325,7 +329,6 @@ def PTO_DmaCopyOp : PTO_Op<"dma_copy", [ let arguments = (ins PTO_BufferType:$source, PTO_BufferType:$destination, - I64:$sid, I64:$n_burst, I64:$len_burst, I64:$src_stride, @@ -337,10 +340,10 @@ def PTO_DmaCopyOp : PTO_Op<"dma_copy", [ let hasVerifier = 1; let assemblyFormat = [{ - $source `,` $destination `,` $sid `,` $len_burst + $source `,` $destination `,` $len_burst `nburst` `(` $n_burst `,` $src_stride `,` $dst_stride `)` - attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` - type($n_burst) `,` type($len_burst) `,` type($src_stride) `,` + attr-dict `:` type($source) `,` type($destination) `,` type($n_burst) `,` + type($len_burst) `,` type($src_stride) `,` type($dst_stride) }]; } @@ -1579,18 +1582,13 @@ def PTO_DmaStoreOp : PTO_Op<"dma_store", [ let arguments = (ins PTO_BufferType:$source, PTO_BufferType:$destination, - I64:$sid, - I64:$reserved, I64:$len_burst, I64:$n_burst, I64:$nburst_src_stride, I64:$nburst_dst_stride, - Optional:$loop1_count, - Optional:$loop1_src_stride, - Optional:$loop1_dst_stride, - Optional:$loop2_count, - Optional:$loop2_src_stride, - Optional:$loop2_dst_stride + Variadic:$loop_counts, + Variadic:$loop_src_strides, + Variadic:$loop_dst_strides ); let results = (outs); @@ -1602,8 +1600,13 @@ def PTO_DmaStoreOp : PTO_Op<"dma_store", [ OpBuilder<(ins "::mlir::Value":$source, "::mlir::Value":$destination, - "::mlir::Value":$sid, - "::mlir::Value":$reserved, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::llvm::ArrayRef<::mlir::pto::DmaLoopConfig>":$loops + )>, + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, "::mlir::Value":$lenBurst, "::mlir::pto::DmaLoopConfig":$nburst, "::std::optional<::mlir::pto::DmaLoopConfig>":$loop1, diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 93dd9a6e9..cbb89e90d 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1032,45 +1032,38 @@ static ParseResult parseDmaTripleGroup( return parser.parseRParen(); } -static ParseResult parseOptionalDmaTripleGroup( - OpAsmParser &parser, StringRef keyword, +static ParseResult parseOptionalDmaTripleGroupAlias( + OpAsmParser &parser, ArrayRef keywords, + StringRef &parsedKeyword, SmallVectorImpl &operands) { - if (failed(parser.parseOptionalKeyword(keyword))) - return success(); - if (parser.parseLParen()) - return failure(); - for (int i = 0; i < 3; ++i) { - OpAsmParser::UnresolvedOperand operand; - if (parser.parseOperand(operand)) - return failure(); - operands.push_back(operand); - if (i != 2 && parser.parseComma()) + parsedKeyword = {}; + for (StringRef keyword : keywords) { + if (failed(parser.parseOptionalKeyword(keyword))) + continue; + parsedKeyword = keyword; + if (parser.parseLParen()) return failure(); + for (int i = 0; i < 3; ++i) { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand)) + return failure(); + operands.push_back(operand); + if (i != 2 && parser.parseComma()) + return failure(); + } + return parser.parseRParen(); } - return parser.parseRParen(); + return success(); } -static ParseResult parseOptionalDmaPadGroup( - OpAsmParser &parser, - SmallVectorImpl &operands) { - if (failed(parser.parseOptionalKeyword("pad"))) - return success(); - if (parser.parseLParen()) - return failure(); - OpAsmParser::UnresolvedOperand value; - if (parser.parseOperand(value)) - return failure(); - operands.push_back(value); - if (succeeded(parser.parseOptionalComma())) { - OpAsmParser::UnresolvedOperand left; - OpAsmParser::UnresolvedOperand right; - if (parser.parseOperand(left) || parser.parseComma() || - parser.parseOperand(right)) - return failure(); - operands.push_back(left); - operands.push_back(right); - } - return parser.parseRParen(); +static bool isDmaLoopKeyword(StringRef keyword) { + if (keyword == "loop") + return true; + if (!keyword.consume_front("loop")) + return false; + if (keyword.empty()) + return false; + return llvm::all_of(keyword, llvm::isDigit); } static ParseResult parseDmaTripleTypes(OpAsmParser &parser, @@ -1176,24 +1169,14 @@ static LogicalResult verifyOptionalDmaLoopGroup(DmaOp op, Value count, return success(); } -static LogicalResult verifyDmaLoadStoreLoopGroups(Operation *op, Value loop1Count, - Value loop1SrcStride, - Value loop1DstStride, - Value loop2Count, - Value loop2SrcStride, - Value loop2DstStride) { - auto emitError = [&]() { return op->emitOpError(); }; - if (hasAny(loop1Count, loop1SrcStride, loop1DstStride) && - !hasAll(loop1Count, loop1SrcStride, loop1DstStride)) - return emitError() - << "requires loop1 group to provide count, src stride, and dst stride together"; - if (hasAny(loop2Count, loop2SrcStride, loop2DstStride) && - !hasAll(loop2Count, loop2SrcStride, loop2DstStride)) - return emitError() - << "requires loop2 group to provide count, src stride, and dst stride together"; - if (hasAll(loop2Count, loop2SrcStride, loop2DstStride) && - !hasAll(loop1Count, loop1SrcStride, loop1DstStride)) - return emitError() << "requires loop1 when loop2 is present"; +static LogicalResult verifyDmaLoadStoreLoopGroups(Operation *op, + ValueRange loopCounts, + ValueRange loopSrcStrides, + ValueRange loopDstStrides) { + if (loopCounts.size() != loopSrcStrides.size() || + loopCounts.size() != loopDstStrides.size()) + return op->emitOpError() + << "requires each loop group to provide count, src stride, and dst stride together"; return success(); } @@ -1355,17 +1338,18 @@ LogicalResult CopyGmToUbufOp::verify() { } void DmaLoadOp::build(OpBuilder &builder, OperationState &state, Value source, - Value destination, Value sid, Value l2CacheCtl, - Value lenBurst, pto::DmaLoopConfig nburst, - std::optional loop1, - std::optional loop2, + Value destination, Value l2CacheCtl, Value lenBurst, + pto::DmaLoopConfig nburst, + llvm::ArrayRef loops, std::optional pad) { - state.addOperands({source, destination, sid, l2CacheCtl, lenBurst, - nburst.count, nburst.srcStride, nburst.dstStride}); - if (loop1) - state.addOperands({loop1->count, loop1->srcStride, loop1->dstStride}); - if (loop2) - state.addOperands({loop2->count, loop2->srcStride, loop2->dstStride}); + state.addOperands({source, destination, l2CacheCtl, lenBurst, nburst.count, + nburst.srcStride, nburst.dstStride}); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.count); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.srcStride); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.dstStride); bool hasPadCounts = pad && pad->leftCount && pad->rightCount; assert((!pad || static_cast(pad->leftCount) == static_cast(pad->rightCount)) && @@ -1379,37 +1363,83 @@ void DmaLoadOp::build(OpBuilder &builder, OperationState &state, Value source, state.addAttribute( getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr( - {1, 1, 1, 1, 1, 1, 1, 1, - loop1 ? 1 : 0, loop1 ? 1 : 0, loop1 ? 1 : 0, - loop2 ? 1 : 0, loop2 ? 1 : 0, loop2 ? 1 : 0, + {1, 1, 1, 1, 1, 1, 1, + static_cast(loops.size()), + static_cast(loops.size()), + static_cast(loops.size()), pad ? 1 : 0, hasPadCounts ? 1 : 0, hasPadCounts ? 1 : 0})); } +void DmaLoadOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value l2CacheCtl, Value lenBurst, + pto::DmaLoopConfig nburst, + std::optional loop1, + std::optional loop2, + std::optional pad) { + SmallVector loops; + if (loop1) + loops.push_back(*loop1); + if (loop2) + loops.push_back(*loop2); + build(builder, state, source, destination, l2CacheCtl, lenBurst, nburst, + loops, pad); +} + ParseResult DmaLoadOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand source, destination, sid, l2CacheCtl, lenBurst; + OpAsmParser::UnresolvedOperand source, destination, l2CacheCtl, lenBurst; SmallVector nburstOperands; - SmallVector loop1Operands; - SmallVector loop2Operands; + SmallVector loopCountOperands; + SmallVector loopSrcStrideOperands; + SmallVector loopDstStrideOperands; SmallVector padOperands; if (parseRequiredOperandWithComma(parser, source) || parseRequiredOperandWithComma(parser, destination) || - parseRequiredOperandWithComma(parser, sid) || parseRequiredOperandWithComma(parser, l2CacheCtl) || parser.parseOperand(lenBurst) || - parseDmaTripleGroup(parser, "nburst", nburstOperands) || - parseOptionalDmaTripleGroup(parser, "loop1", loop1Operands) || - parseOptionalDmaTripleGroup(parser, "loop2", loop2Operands) || - parseOptionalDmaPadGroup(parser, padOperands)) + parseDmaTripleGroup(parser, "nburst", nburstOperands)) return failure(); + while (true) { + if (succeeded(parser.parseOptionalKeyword("pad"))) { + if (parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand value; + if (parser.parseOperand(value)) + return failure(); + padOperands.push_back(value); + if (succeeded(parser.parseOptionalComma())) { + OpAsmParser::UnresolvedOperand left; + OpAsmParser::UnresolvedOperand right; + if (parser.parseOperand(left) || parser.parseComma() || + parser.parseOperand(right)) + return failure(); + padOperands.push_back(left); + padOperands.push_back(right); + } + if (parser.parseRParen()) + return failure(); + break; + } + + StringRef parsedKeyword; + SmallVector loopGroupOperands; + if (parseOptionalDmaTripleGroupAlias(parser, {"loop", "loop1", "loop2"}, + parsedKeyword, loopGroupOperands)) + return failure(); + if (parsedKeyword.empty()) + break; + loopCountOperands.push_back(loopGroupOperands[0]); + loopSrcStrideOperands.push_back(loopGroupOperands[1]); + loopDstStrideOperands.push_back(loopGroupOperands[2]); + } if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) return failure(); - Type sourceType, destinationType, sidType, l2CacheCtlType, lenBurstType; - SmallVector nburstTypes, loop1Types, loop2Types, padTypes; + Type sourceType, destinationType, l2CacheCtlType, lenBurstType; + SmallVector nburstTypes, loopCountTypes, loopSrcStrideTypes, + loopDstStrideTypes, padTypes; if (parser.parseType(sourceType) || parser.parseComma() || parser.parseType(destinationType) || parser.parseComma() || - parser.parseType(sidType) || parser.parseComma() || parser.parseType(l2CacheCtlType) || parser.parseComma() || parser.parseType(lenBurstType) || parser.parseComma() || parseDmaTripleTypes(parser, nburstTypes)) @@ -1418,14 +1448,13 @@ ParseResult DmaLoadOp::parse(OpAsmParser &parser, OperationState &result) { StringRef keyword; if (parser.parseKeyword(&keyword)) return failure(); - if (keyword == "loop1") { - if (!loop1Types.empty() || parseDmaTripleTypes(parser, loop1Types)) - return failure(); - continue; - } - if (keyword == "loop2") { - if (!loop2Types.empty() || parseDmaTripleTypes(parser, loop2Types)) + if (isDmaLoopKeyword(keyword)) { + SmallVector loopGroupTypes; + if (parseDmaTripleTypes(parser, loopGroupTypes)) return failure(); + loopCountTypes.push_back(loopGroupTypes[0]); + loopSrcStrideTypes.push_back(loopGroupTypes[1]); + loopDstStrideTypes.push_back(loopGroupTypes[2]); continue; } if (keyword == "pad") { @@ -1434,18 +1463,24 @@ ParseResult DmaLoadOp::parse(OpAsmParser &parser, OperationState &result) { continue; } return parser.emitError(parser.getCurrentLocation(), - "expected one of 'loop1', 'loop2', or 'pad'"); + "expected one of 'loop' or 'pad'"); } + int32_t loopGroupCount = static_cast(loopCountOperands.size()); + if (loopCountOperands.size() != loopSrcStrideOperands.size() || + loopCountOperands.size() != loopDstStrideOperands.size() || + loopCountTypes.size() != loopSrcStrideTypes.size() || + loopCountTypes.size() != loopDstStrideTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires each loop group to provide count, src stride, and dst stride"); + if (loopCountOperands.size() != loopCountTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires loop operand and type groups to match"); + auto &segments = result.getOrAddProperties().operandSegmentSizes; - llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, 1, 1, - static_cast(loop1Operands.size() ? 1 : 0), - static_cast(loop1Operands.size() ? 1 : 0), - static_cast(loop1Operands.size() ? 1 : 0), - static_cast(loop2Operands.size() ? 1 : 0), - static_cast(loop2Operands.size() ? 1 : 0), - static_cast(loop2Operands.size() ? 1 : 0), + llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, 1, + loopGroupCount, loopGroupCount, loopGroupCount, static_cast(padOperands.size() ? 1 : 0), static_cast(padOperands.size() == 3 ? 1 : 0), static_cast(padOperands.size() == 3 ? 1 : 0)}, @@ -1453,14 +1488,16 @@ ParseResult DmaLoadOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.resolveOperand(source, sourceType, result.operands) || parser.resolveOperand(destination, destinationType, result.operands) || - parser.resolveOperand(sid, sidType, result.operands) || parser.resolveOperand(l2CacheCtl, l2CacheCtlType, result.operands) || parser.resolveOperand(lenBurst, lenBurstType, result.operands) || parser.resolveOperands(nburstOperands, nburstTypes, parser.getCurrentLocation(), result.operands) || - parser.resolveOperands(loop1Operands, loop1Types, parser.getCurrentLocation(), - result.operands) || - parser.resolveOperands(loop2Operands, loop2Types, parser.getCurrentLocation(), + parser.resolveOperands(loopCountOperands, loopCountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopSrcStrideOperands, loopSrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopDstStrideOperands, loopDstStrideTypes, + parser.getCurrentLocation(), result.operands) || parser.resolveOperands(padOperands, padTypes, parser.getCurrentLocation(), result.operands)) @@ -1469,33 +1506,26 @@ ParseResult DmaLoadOp::parse(OpAsmParser &parser, OperationState &result) { } void DmaLoadOp::print(OpAsmPrinter &printer) { - printer << " " << getSource() << ", " << getDestination() << ", " << getSid() - << ", " << getL2CacheCtl() << ", " << getLenBurst(); + printer << " " << getSource() << ", " << getDestination() << ", " + << getL2CacheCtl() << ", " << getLenBurst(); printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcStride(), getNburstDstStride()); - if (hasAll(getLoop1Count(), getLoop1SrcStride(), getLoop1DstStride())) - printDmaTripleGroup(printer, "loop1", getLoop1Count(), getLoop1SrcStride(), - getLoop1DstStride()); - if (hasAll(getLoop2Count(), getLoop2SrcStride(), getLoop2DstStride())) - printDmaTripleGroup(printer, "loop2", getLoop2Count(), getLoop2SrcStride(), - getLoop2DstStride()); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleGroup(printer, "loop", count, srcStride, dstStride); if (getPadValue()) printDmaPadGroup(printer, getPadValue(), getLeftPaddingCount(), getRightPaddingCount()); printer.printOptionalAttrDict((*this)->getAttrs()); printer << " : " << getSource().getType() << ", " << getDestination().getType() - << ", " << getSid().getType() << ", " << getL2CacheCtl().getType() - << ", " << getLenBurst().getType() << ", " << getNBurst().getType() - << ", " << getNburstSrcStride().getType() << ", " + << ", " << getL2CacheCtl().getType() << ", " << getLenBurst().getType() + << ", " << getNBurst().getType() << ", " << getNburstSrcStride().getType() + << ", " << getNburstDstStride().getType(); - if (hasAll(getLoop1Count(), getLoop1SrcStride(), getLoop1DstStride())) - printDmaTripleTypes(printer, "loop1", getLoop1Count().getType(), - getLoop1SrcStride().getType(), - getLoop1DstStride().getType()); - if (hasAll(getLoop2Count(), getLoop2SrcStride(), getLoop2DstStride())) - printDmaTripleTypes(printer, "loop2", getLoop2Count().getType(), - getLoop2SrcStride().getType(), - getLoop2DstStride().getType()); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleTypes(printer, "loop", count.getType(), srcStride.getType(), + dstStride.getType()); if (getPadValue()) printDmaPadTypes(printer, getPadValue().getType(), getLeftPaddingCount() ? getLeftPaddingCount().getType() : Type{}, @@ -1513,9 +1543,8 @@ LogicalResult DmaLoadOp::verify() { if (failed(verifyCopyGmToUbufOp(*this, true))) return failure(); if (failed(verifyDmaLoadStoreLoopGroups( - getOperation(), getLoop1Count(), getLoop1SrcStride(), - getLoop1DstStride(), getLoop2Count(), getLoop2SrcStride(), - getLoop2DstStride()))) + getOperation(), getLoopCounts(), getLoopSrcStrides(), + getLoopDstStrides()))) return failure(); if (!getPadValue() && (getLeftPaddingCount() || getRightPaddingCount())) return emitOpError() << "requires pad group to provide a pad value"; @@ -3530,49 +3559,70 @@ LogicalResult CopyUbufToGmOp::verify() { } void DmaStoreOp::build(OpBuilder &builder, OperationState &state, Value source, - Value destination, Value sid, Value reserved, - Value lenBurst, pto::DmaLoopConfig nburst, - std::optional loop1, - std::optional loop2) { - state.addOperands({source, destination, sid, reserved, lenBurst, nburst.count, + Value destination, Value lenBurst, pto::DmaLoopConfig nburst, + llvm::ArrayRef loops) { + state.addOperands({source, destination, lenBurst, nburst.count, nburst.srcStride, nburst.dstStride}); - if (loop1) - state.addOperands({loop1->count, loop1->srcStride, loop1->dstStride}); - if (loop2) - state.addOperands({loop2->count, loop2->srcStride, loop2->dstStride}); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.count); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.srcStride); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.dstStride); state.addAttribute( getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr( - {1, 1, 1, 1, 1, 1, 1, 1, - loop1 ? 1 : 0, loop1 ? 1 : 0, loop1 ? 1 : 0, - loop2 ? 1 : 0, loop2 ? 1 : 0, loop2 ? 1 : 0})); + {1, 1, 1, 1, 1, 1, + static_cast(loops.size()), + static_cast(loops.size()), + static_cast(loops.size())})); +} + +void DmaStoreOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, pto::DmaLoopConfig nburst, + std::optional loop1, + std::optional loop2) { + SmallVector loops; + if (loop1) + loops.push_back(*loop1); + if (loop2) + loops.push_back(*loop2); + build(builder, state, source, destination, lenBurst, nburst, loops); } ParseResult DmaStoreOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand source, destination, sid, reserved, lenBurst; + OpAsmParser::UnresolvedOperand source, destination, lenBurst; SmallVector nburstOperands; - SmallVector loop1Operands; - SmallVector loop2Operands; + SmallVector loopCountOperands; + SmallVector loopSrcStrideOperands; + SmallVector loopDstStrideOperands; if (parseRequiredOperandWithComma(parser, source) || parseRequiredOperandWithComma(parser, destination) || - parseRequiredOperandWithComma(parser, sid) || - parseRequiredOperandWithComma(parser, reserved) || parser.parseOperand(lenBurst) || - parseDmaTripleGroup(parser, "nburst", nburstOperands) || - parseOptionalDmaTripleGroup(parser, "loop1", loop1Operands) || - parseOptionalDmaTripleGroup(parser, "loop2", loop2Operands)) + parseDmaTripleGroup(parser, "nburst", nburstOperands)) return failure(); + while (true) { + StringRef parsedKeyword; + SmallVector loopGroupOperands; + if (parseOptionalDmaTripleGroupAlias(parser, {"loop", "loop1", "loop2"}, + parsedKeyword, loopGroupOperands)) + return failure(); + if (parsedKeyword.empty()) + break; + loopCountOperands.push_back(loopGroupOperands[0]); + loopSrcStrideOperands.push_back(loopGroupOperands[1]); + loopDstStrideOperands.push_back(loopGroupOperands[2]); + } if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) return failure(); - Type sourceType, destinationType, sidType, reservedType, lenBurstType; - SmallVector nburstTypes, loop1Types, loop2Types; + Type sourceType, destinationType, lenBurstType; + SmallVector nburstTypes, loopCountTypes, loopSrcStrideTypes, + loopDstStrideTypes; if (parser.parseType(sourceType) || parser.parseComma() || parser.parseType(destinationType) || parser.parseComma() || - parser.parseType(sidType) || parser.parseComma() || - parser.parseType(reservedType) || parser.parseComma() || parser.parseType(lenBurstType) || parser.parseComma() || parseDmaTripleTypes(parser, nburstTypes)) return failure(); @@ -3580,71 +3630,70 @@ ParseResult DmaStoreOp::parse(OpAsmParser &parser, OperationState &result) { StringRef keyword; if (parser.parseKeyword(&keyword)) return failure(); - if (keyword == "loop1") { - if (!loop1Types.empty() || parseDmaTripleTypes(parser, loop1Types)) - return failure(); - continue; - } - if (keyword == "loop2") { - if (!loop2Types.empty() || parseDmaTripleTypes(parser, loop2Types)) + if (isDmaLoopKeyword(keyword)) { + SmallVector loopGroupTypes; + if (parseDmaTripleTypes(parser, loopGroupTypes)) return failure(); + loopCountTypes.push_back(loopGroupTypes[0]); + loopSrcStrideTypes.push_back(loopGroupTypes[1]); + loopDstStrideTypes.push_back(loopGroupTypes[2]); continue; } return parser.emitError(parser.getCurrentLocation(), - "expected one of 'loop1' or 'loop2'"); + "expected 'loop'"); } + int32_t loopGroupCount = static_cast(loopCountOperands.size()); + if (loopCountOperands.size() != loopSrcStrideOperands.size() || + loopCountOperands.size() != loopDstStrideOperands.size() || + loopCountTypes.size() != loopSrcStrideTypes.size() || + loopCountTypes.size() != loopDstStrideTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires each loop group to provide count, src stride, and dst stride"); + if (loopCountOperands.size() != loopCountTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires loop operand and type groups to match"); + auto &segments = result.getOrAddProperties().operandSegmentSizes; - llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, 1, 1, - static_cast(loop1Operands.size() ? 1 : 0), - static_cast(loop1Operands.size() ? 1 : 0), - static_cast(loop1Operands.size() ? 1 : 0), - static_cast(loop2Operands.size() ? 1 : 0), - static_cast(loop2Operands.size() ? 1 : 0), - static_cast(loop2Operands.size() ? 1 : 0)}, + llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, + loopGroupCount, loopGroupCount, loopGroupCount}, segments.begin()); if (parser.resolveOperand(source, sourceType, result.operands) || parser.resolveOperand(destination, destinationType, result.operands) || - parser.resolveOperand(sid, sidType, result.operands) || - parser.resolveOperand(reserved, reservedType, result.operands) || parser.resolveOperand(lenBurst, lenBurstType, result.operands) || parser.resolveOperands(nburstOperands, nburstTypes, parser.getCurrentLocation(), result.operands) || - parser.resolveOperands(loop1Operands, loop1Types, parser.getCurrentLocation(), - result.operands) || - parser.resolveOperands(loop2Operands, loop2Types, parser.getCurrentLocation(), + parser.resolveOperands(loopCountOperands, loopCountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopSrcStrideOperands, loopSrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopDstStrideOperands, loopDstStrideTypes, + parser.getCurrentLocation(), result.operands)) return failure(); return success(); } void DmaStoreOp::print(OpAsmPrinter &printer) { - printer << " " << getSource() << ", " << getDestination() << ", " << getSid() - << ", " << getReserved() << ", " << getLenBurst(); + printer << " " << getSource() << ", " << getDestination() << ", " + << getLenBurst(); printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcStride(), getNburstDstStride()); - if (hasAll(getLoop1Count(), getLoop1SrcStride(), getLoop1DstStride())) - printDmaTripleGroup(printer, "loop1", getLoop1Count(), getLoop1SrcStride(), - getLoop1DstStride()); - if (hasAll(getLoop2Count(), getLoop2SrcStride(), getLoop2DstStride())) - printDmaTripleGroup(printer, "loop2", getLoop2Count(), getLoop2SrcStride(), - getLoop2DstStride()); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleGroup(printer, "loop", count, srcStride, dstStride); printer.printOptionalAttrDict((*this)->getAttrs()); printer << " : " << getSource().getType() << ", " << getDestination().getType() - << ", " << getSid().getType() << ", " << getReserved().getType() << ", " << getLenBurst().getType() << ", " << getNBurst().getType() - << ", " << getNburstSrcStride().getType() << ", " + << ", " << getNburstSrcStride().getType() + << ", " << getNburstDstStride().getType(); - if (hasAll(getLoop1Count(), getLoop1SrcStride(), getLoop1DstStride())) - printDmaTripleTypes(printer, "loop1", getLoop1Count().getType(), - getLoop1SrcStride().getType(), - getLoop1DstStride().getType()); - if (hasAll(getLoop2Count(), getLoop2SrcStride(), getLoop2DstStride())) - printDmaTripleTypes(printer, "loop2", getLoop2Count().getType(), - getLoop2SrcStride().getType(), - getLoop2DstStride().getType()); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleTypes(printer, "loop", count.getType(), srcStride.getType(), + dstStride.getType()); } void DmaStoreOp::getEffects( @@ -3658,7 +3707,6 @@ LogicalResult DmaStoreOp::verify() { if (failed(verifyCopyUbufToGmOp(*this, false))) return failure(); return verifyDmaLoadStoreLoopGroups( - getOperation(), getLoop1Count(), getLoop1SrcStride(), - getLoop1DstStride(), getLoop2Count(), getLoop2SrcStride(), - getLoop2DstStride()); + getOperation(), getLoopCounts(), getLoopSrcStrides(), + getLoopDstStrides()); } diff --git a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp index 2580a9cab..3d2366349 100644 --- a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp +++ b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -84,6 +85,74 @@ static bool shouldRestoreDmaLoopSize(Value loop1Count, Value loop2Count) { return !isKnownOne(loop1Count) || !isKnownOne(loop2Count); } +static SmallVector collectLoopConfigs(ValueRange counts, + ValueRange srcStrides, + ValueRange dstStrides) { + SmallVector loops; + loops.reserve(counts.size()); + for (auto [count, srcStride, dstStride] : + llvm::zip(counts, srcStrides, dstStrides)) + loops.push_back({count, srcStride, dstStride}); + return loops; +} + +static Value offsetPointerByBytes(Value basePtr, Value byteOffset, + PatternRewriter &rewriter, Location loc) { + if (!basePtr) + return {}; + + APInt constOffset; + if (matchPattern(byteOffset, m_ConstantInt(&constOffset)) && constOffset.isZero()) + return basePtr; + + Value baseInt = + rewriter.create(loc, rewriter.getI64Type(), basePtr); + Value offsetI64 = byteOffset; + if (!offsetI64.getType().isInteger(64)) + offsetI64 = + rewriter.create(loc, rewriter.getI64Type(), + offsetI64); + Value sum = rewriter.create(loc, baseInt, offsetI64); + return rewriter.create(loc, basePtr.getType(), sum); +} + +static Value buildAccumulatedByteOffset(Location loc, Value baseOffset, + Value indexI64, Value stride, + PatternRewriter &rewriter) { + Value delta = rewriter.create(loc, indexI64, stride); + return rewriter.create(loc, baseOffset, delta); +} + +template +static void buildSoftwareLoopNest(PatternRewriter &rewriter, Location loc, + ArrayRef loops, + Value srcOffset, Value dstOffset, + BodyBuilder &&buildLeaf) { + if (loops.empty()) { + buildLeaf(srcOffset, dstOffset); + return; + } + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value count = rewriter.create(loc, rewriter.getIndexType(), + loops.front().count); + scf::ForOp forOp = rewriter.create(loc, c0, count, c1); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(forOp.getBody()); + Value ivI64 = + rewriter.create(loc, rewriter.getI64Type(), + forOp.getInductionVar()); + Value nextSrcOffset = buildAccumulatedByteOffset( + loc, srcOffset, ivI64, loops.front().srcStride, rewriter); + Value nextDstOffset = buildAccumulatedByteOffset( + loc, dstOffset, ivI64, loops.front().dstStride, rewriter); + buildSoftwareLoopNest(rewriter, loc, loops.drop_front(), nextSrcOffset, + nextDstOffset, buildLeaf); + } +} + struct ExpandUvldPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -117,17 +186,28 @@ struct ExpandDmaLoadPattern : public OpRewritePattern { LogicalResult matchAndRewrite(pto::DmaLoadOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0, 64); Value one = rewriter.create(loc, 1, 64); - Value loop2Size = op.getLoop2Count(); - if (!loop2Size) - loop2Size = one; - if (Value loop2Count = op.getLoop2Count()) + SmallVector loops = + collectLoopConfigs(op.getLoopCounts(), op.getLoopSrcStrides(), + op.getLoopDstStrides()); + ArrayRef hwLoops = ArrayRef(loops).take_front(2); + ArrayRef swLoops = ArrayRef(loops).drop_front(hwLoops.size()); + + Value loop1Count; + Value loop2Size = one; + if (hwLoops.size() == 2) { rewriter.create( - loc, op.getLoop2SrcStride(), op.getLoop2DstStride()); - - if (Value loop1Count = op.getLoop1Count()) { + loc, hwLoops[0].srcStride, hwLoops[0].dstStride); + loop2Size = hwLoops[0].count; + loop1Count = hwLoops[1].count; rewriter.create( - loc, op.getLoop1SrcStride(), op.getLoop1DstStride()); + loc, hwLoops[1].srcStride, hwLoops[1].dstStride); + rewriter.create(loc, loop2Size, loop1Count); + } else if (hwLoops.size() == 1) { + loop1Count = hwLoops[0].count; + rewriter.create( + loc, hwLoops[0].srcStride, hwLoops[0].dstStride); rewriter.create(loc, loop2Size, loop1Count); } @@ -144,11 +224,18 @@ struct ExpandDmaLoadPattern : public OpRewritePattern { if (Value padValue = op.getPadValue()) rewriter.create(loc, padValue); - rewriter.create( - loc, op.getSource(), op.getDestination(), op.getSid(), op.getNBurst(), - op.getLenBurst(), leftPadding, rightPadding, dataSelect, - op.getL2CacheCtl(), op.getNburstSrcStride(), op.getNburstDstStride()); - if (shouldRestoreDmaLoopSize(op.getLoop1Count(), loop2Size)) + buildSoftwareLoopNest( + rewriter, loc, swLoops, zero, zero, + [&](Value srcOffset, Value dstOffset) { + Value source = offsetPointerByBytes(op.getSource(), srcOffset, rewriter, loc); + Value destination = + offsetPointerByBytes(op.getDestination(), dstOffset, rewriter, loc); + rewriter.create( + loc, source, destination, zero, op.getNBurst(), op.getLenBurst(), + leftPadding, rightPadding, dataSelect, op.getL2CacheCtl(), + op.getNburstSrcStride(), op.getNburstDstStride()); + }); + if (shouldRestoreDmaLoopSize(loop1Count, loop2Size)) rewriter.create(loc, one, one); rewriter.eraseOp(op); return success(); @@ -161,25 +248,42 @@ struct ExpandDmaStorePattern : public OpRewritePattern { LogicalResult matchAndRewrite(pto::DmaStoreOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0, 64); Value one = rewriter.create(loc, 1, 64); - Value loop2Size = op.getLoop2Count(); - if (!loop2Size) - loop2Size = one; - if (Value loop2Count = op.getLoop2Count()) + SmallVector loops = + collectLoopConfigs(op.getLoopCounts(), op.getLoopSrcStrides(), + op.getLoopDstStrides()); + ArrayRef hwLoops = ArrayRef(loops).take_front(2); + ArrayRef swLoops = ArrayRef(loops).drop_front(hwLoops.size()); + + Value loop1Count; + Value loop2Size = one; + if (hwLoops.size() == 2) { rewriter.create( - loc, op.getLoop2SrcStride(), op.getLoop2DstStride()); - - if (Value loop1Count = op.getLoop1Count()) { + loc, hwLoops[0].srcStride, hwLoops[0].dstStride); + loop2Size = hwLoops[0].count; + loop1Count = hwLoops[1].count; + rewriter.create( + loc, hwLoops[1].srcStride, hwLoops[1].dstStride); + rewriter.create(loc, loop2Size, loop1Count); + } else if (hwLoops.size() == 1) { + loop1Count = hwLoops[0].count; rewriter.create( - loc, op.getLoop1SrcStride(), op.getLoop1DstStride()); + loc, hwLoops[0].srcStride, hwLoops[0].dstStride); rewriter.create(loc, loop2Size, loop1Count); } - rewriter.create( - loc, op.getSource(), op.getDestination(), op.getSid(), op.getNBurst(), - op.getLenBurst(), op.getReserved(), op.getNburstDstStride(), - op.getNburstSrcStride()); - if (shouldRestoreDmaLoopSize(op.getLoop1Count(), loop2Size)) + buildSoftwareLoopNest( + rewriter, loc, swLoops, zero, zero, + [&](Value srcOffset, Value dstOffset) { + Value source = offsetPointerByBytes(op.getSource(), srcOffset, rewriter, loc); + Value destination = + offsetPointerByBytes(op.getDestination(), dstOffset, rewriter, loc); + rewriter.create( + loc, source, destination, zero, op.getNBurst(), op.getLenBurst(), + zero, op.getNburstDstStride(), op.getNburstSrcStride()); + }); + if (shouldRestoreDmaLoopSize(loop1Count, loop2Size)) rewriter.create(loc, one, one); rewriter.eraseOp(op); return success(); @@ -191,8 +295,9 @@ struct ExpandDmaCopyPattern : public OpRewritePattern { LogicalResult matchAndRewrite(pto::DmaCopyOp op, PatternRewriter &rewriter) const override { + Value zero = rewriter.create(op.getLoc(), 0, 64); rewriter.replaceOpWithNewOp( - op, op.getSource(), op.getDestination(), op.getSid(), op.getNBurst(), + op, op.getSource(), op.getDestination(), zero, op.getNBurst(), op.getLenBurst(), op.getSrcStride(), op.getDstStride()); return success(); } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto index bf31d5ab5..be07f40d3 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto index b70ba1dd8..d0e101305 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto index 8eeb85a2e..f6d4ca961 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto index ea2ca63da..a95d5a391 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto index 87659b202..48f75a5ea 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto index 234e09bf8..a5db0b306 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto index 774431150..67443f0be 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto index f41776c3b..2070e0554 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto index bdc4892d1..5b6f075d1 100644 --- a/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto @@ -20,14 +20,14 @@ module attributes {pto.target_arch = "a5"} { %false = arith.constant false pto.get_buf "PIPE_MTE2", 0, 0 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.rls_buf "PIPE_MTE2", 0, 0 pto.get_buf "PIPE_MTE2", 1, 0 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.rls_buf "PIPE_MTE2", 1, 0 pto.get_buf "PIPE_V", 0, 0 pto.get_buf "PIPE_V", 1, 0 @@ -48,9 +48,9 @@ module attributes {pto.target_arch = "a5"} { pto.rls_buf "PIPE_V", 2, 0 pto.get_buf "PIPE_MTE3", 2, 0 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.rls_buf "PIPE_MTE3", 2, 0 pto.barrier #pto.pipe return diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto index 7131fc4c8..b61a0084e 100644 --- a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,12 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_carry, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_carry, %arg3, %c128_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto index 3aef43a9d..fe0149024 100644 --- a/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto @@ -23,12 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,12 +44,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_carry, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_carry, %arg3, %c128_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto index 2a6085ba9..ecaf59684 100644 --- a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto index b2a66c186..87c8dd9fa 100644 --- a/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto index dbe3913f4..38a2214dc 100644 --- a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto @@ -23,12 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto index 657f79491..e39f3bd8c 100644 --- a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto index 1af9192c2..2a1ec00ba 100644 --- a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto index 62b41c796..e62ebc826 100644 --- a/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto index 52c6db41e..9745b54e8 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto index b24b18bdc..82cd84e87 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto index ebe54983a..03279fa25 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto index 88a3bf770..4c6c97a99 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto index 66241fe64..3b2887fd0 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto index 93043db06..3b5085b1f 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/vmin/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto index 5f3b0852c..b4d6d04bb 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto index 2ed31bbe1..e566b31fd 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto index a6270b077..7d28503ff 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto index 93043db06..3b5085b1f 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto index b890b58b2..21104fa52 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto index d0940b673..2295dd723 100644 --- a/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto index 4544d4f73..bbbe84b47 100644 --- a/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto index 7ed862f26..14f73c2e7 100644 --- a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto index e531dc633..4de5829cf 100644 --- a/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto index 4a1d178b0..c58a97430 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto @@ -26,12 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -48,9 +48,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto index 977665238..80a5fc5ce 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto index 0a405c63d..aee5db18c 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto index 98516150a..13f12c24f 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto index bc5156d50..7d445e3da 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto index eefb2416c..c38711afe 100644 --- a/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto index 460e5f486..22f26bafb 100644 --- a/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto index df7837848..330cfeb4f 100644 --- a/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto index 19df47975..a1819ac91 100644 --- a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,12 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_borrow, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_borrow, %arg3, %c128_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto index c555dd4f4..50149443f 100644 --- a/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,12 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_borrow, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_borrow, %arg3, %c128_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto index 42a7ebe87..0904f144b 100644 --- a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto index fff570fd6..fda88bc55 100644 --- a/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto +++ b/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto @@ -24,12 +24,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg2, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto index 0b0b9be36..e99311b86 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto index 382b2e455..774c3367a 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto index 291ad31a7..76935bae2 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,9 +43,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64_data + pto.dma_store %ub_out, %arg2, %c32_i64_data nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto index 833388e7d..04910e234 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,9 +43,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64_data + pto.dma_store %ub_out, %arg2, %c32_i64_data nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto index 0601a6bb4..dbf4c12e2 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto index bb63cbb1f..25bd933d4 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto index 631e5c875..1d80c5cf7 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto @@ -16,9 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +37,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg1, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto index 1d7396289..20f091e15 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto @@ -16,9 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +37,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg1, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto index c8fc83432..76153909a 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg1, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto index 4df6d097d..146ce50dc 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg1, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto index be566c5ec..289973d92 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg1, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto index a051e39a6..836b58df9 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg1, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto index c2abf5c19..7d6ff410c 100644 --- a/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg1, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto index 66ec875f8..f213d1120 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto @@ -33,15 +33,15 @@ module attributes {pto.target_arch = "a5"} { pto.set_loop1_stride_outtoub %c128_i64, %c128_i64 : i64, i64 pto.set_loop2_stride_outtoub %c128_i64, %c128_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg2, %ub_mask, %c0_i64, %c0_i64, %c32_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg2, %ub_mask, %c0_i64, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -68,9 +68,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c128_i64, %c128_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c128_i64, %c128_i64 : i64, i64 - pto.dma_store %ub_out, %arg3, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg3, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto index c2f616389..2a975830e 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto @@ -16,12 +16,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto index 8b25e4070..66d40e7bc 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto @@ -17,12 +17,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,9 +43,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto index 46f470aa7..088cf183d 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto @@ -16,15 +16,15 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg2, %ub_out, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg2, %ub_out, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto index 2993878a7..f54e993f7 100644 --- a/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto @@ -16,12 +16,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c64_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto index 8c5fa6e07..02d5e19dc 100644 --- a/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_idx = pto.castptr %c2048_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c8_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_idx, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_idx, %c0_i64, %c256_i64 nburst(%c8_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c8_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto index 6f9fa6f6e..2084bde41 100644 --- a/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto @@ -22,12 +22,12 @@ module attributes {pto.target_arch = "a5"} { %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_idx = pto.castptr %c1024_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c256_i64 nburst(%c4_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_idx, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_idx, %c0_i64, %c256_i64 nburst(%c4_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c4_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto index aea856716..69a216b22 100644 --- a/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto +++ b/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto @@ -24,18 +24,18 @@ module attributes {pto.target_arch = "a5"} { pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 %6 = pto.castptr %5 : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.dma_load %6, %0, %c0_i64, %c0_i64, %2 + pto.dma_load %6, %0, %c0_i64, %2 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 %7 = pto.castptr %c4096_i64 : i64 -> !pto.ptr %8 = pto.castptr %arg1 : !pto.ptr -> !pto.ptr %9 = pto.addptr %8, %4 : -> pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 %10 = pto.castptr %9 : !pto.ptr -> !pto.ptr - pto.dma_load %10, %7, %c0_i64, %c0_i64, %2 + pto.dma_load %10, %7, %c0_i64, %2 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] %11 = pto.castptr %c8192_i64 : i64 -> !pto.ptr @@ -65,9 +65,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 %15 = pto.castptr %14 : !pto.ptr -> !pto.ptr - pto.dma_store %11, %15, %c0_i64, %c0_i64, %12 + pto.dma_store %11, %15, %12 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto index 5c4766df6..e688fdfa5 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto index 3829a35bc..a2eba0eb2 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +37,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c16_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto index f23a6b4b4..c89d17f43 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -37,9 +37,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c16_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto index b0b5558e9..9e9db6afb 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto index 67682c1da..33a833c7b 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto @@ -18,9 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto index 45d1582e4..ce8ed8c09 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +39,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto index b9a88ed04..110ea1e19 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto @@ -18,9 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto index 75a9d625b..f635b34d7 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto index f1c8b0c80..e08e2523a 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto @@ -18,9 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto index fc77a7652..ec42555e4 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto @@ -18,9 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto index 6d55727f1..8e20ffe7b 100644 --- a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto index 7451aa8a2..1be8276b3 100644 --- a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { %ub_f = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -64,15 +64,15 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_r, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_r, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_z, %arg2, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_z, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_f, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_f, %arg3, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto index ff0e6eb4c..4c3008d0f 100644 --- a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto @@ -14,9 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto index 7451aa8a2..1be8276b3 100644 --- a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { %ub_f = pto.castptr %c12288_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -64,15 +64,15 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_r, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_r, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_z, %arg2, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_z, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_f, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_f, %arg3, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto index eb4f8570b..ebc2e8226 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto @@ -26,12 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_addend, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_addend, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -48,9 +48,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto index f45b2cb9b..9c82000a5 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto @@ -18,12 +18,12 @@ module attributes {pto.target_arch = "a5"} { %ub_scores = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_indices = pto.castptr %c128_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_scores, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_scores, %c0_i64, %c128_i64 nburst(%c1_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_indices, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_indices, %c0_i64, %c128_i64 nburst(%c1_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -32,9 +32,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto index f2b1d360e..75afac8a7 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c2_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto index 59397a667..5b4ef4046 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c8_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto index 3286f2a14..6ebeaf6c2 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto index da20f83ab..e12d6718b 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto @@ -25,12 +25,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_max, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_max, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto index 3c4e4e4fa..af2480002 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto @@ -27,12 +27,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_max, %c0_i64, %c0_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_max, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -54,9 +54,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto index e9c83b6d7..618e90e67 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto index 3e58304af..71b093e04 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto @@ -24,9 +24,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto index 0f864b522..718529e22 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto index 0f864b522..718529e22 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto index 13309185c..67d4150ce 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto @@ -16,9 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto index 4ef16ba54..47c420635 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,9 +43,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto index ef2debd20..02c8c2fda 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,9 +43,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto index 5ce7fd332..312d89953 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %gm_out = pto.castptr %arg1 : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.dma_load %gm_in, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %gm_in, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %gm_out, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto index 49e7afa73..207b17691 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto @@ -25,12 +25,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_alpha, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_alpha, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto index 14c1a1271..854b70201 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto @@ -26,12 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_alpha, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_alpha, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,9 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto index e623a95c5..5d750ea07 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto @@ -25,12 +25,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto index 07abdcd98..fb04ffbbc 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto @@ -44,12 +44,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -66,9 +66,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto index aa0ffe7a6..831282357 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto @@ -26,12 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,9 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto index 5e396a036..4249bcb95 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto @@ -26,12 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,9 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto index a25a6f38b..119d99580 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto @@ -26,12 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,9 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto index 8a6de533a..5416730f0 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto @@ -26,12 +26,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,9 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto index 63275d2b8..b4de75a50 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto @@ -25,15 +25,15 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg2, %ub_out, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg2, %ub_out, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -50,9 +50,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto index 021e3de6d..7fee06681 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto @@ -44,15 +44,15 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_offsets, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_offsets, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg2, %ub_out, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg2, %ub_out, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -69,9 +69,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto index 6ce0c6b76..687cce24f 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %gm_out, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto index 97928963f..c0d540561 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto index 4dc743519..eccc4f8b9 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto index eaff46fb8..9a02f603b 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto index f6716db6e..42290fe7b 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto index 907a1f831..142339044 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto index 5bdcd1e62..2b1256673 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto index 36e8153e0..e0d779ca5 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto @@ -30,9 +30,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + pto.dma_store %ub_out, %gm_out, %c96_i64 nburst(%c1_i64, %c96_i64, %c96_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto index c35613534..d0abaf810 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto @@ -30,9 +30,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + pto.dma_store %ub_out, %gm_out, %c96_i64 nburst(%c1_i64, %c96_i64, %c96_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto index bd0b898b3..e2db7dce7 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto index 09468bda1..811bbda8a 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto index 073e994bb..9b861c19c 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto index 21ad9e9f0..6c72c42c1 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto index 33bf4a3a5..fbe80933a 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto index 576e7f712..ea02076cb 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto index 7a035d937..102e37991 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto @@ -31,9 +31,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + pto.dma_store %ub_out, %gm_out, %c96_i64 nburst(%c1_i64, %c96_i64, %c96_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto index 03a64089b..e2ecbf6b7 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + pto.dma_store %ub_out, %gm_out, %c96_i64 nburst(%c1_i64, %c96_i64, %c96_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto index e7cd4dc89..60be64076 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto @@ -28,9 +28,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %gm_out, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto index b0074ccb1..04e58bdc7 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %gm_out, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto index cd68b8024..7fa6a5d9c 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto @@ -28,9 +28,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto index 37eae598b..9d065da1a 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto @@ -28,9 +28,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto index f3966c735..9600945a7 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto @@ -32,9 +32,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %gm_out, %c64_i64 nburst(%c1_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto index 1dbfd52f5..071109605 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto @@ -28,9 +28,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %gm_out, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto index 4ebd9389e..72b50e2e0 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto @@ -30,9 +30,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + pto.dma_store %ub_out, %gm_out, %c96_i64 nburst(%c1_i64, %c96_i64, %c96_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto index 9a409c85c..18d37e4a2 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto @@ -30,9 +30,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c96_i64 + pto.dma_store %ub_out, %gm_out, %c96_i64 nburst(%c1_i64, %c96_i64, %c96_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto index 925d6a810..a3da9f913 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %gm_out, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto index f31fe25fd..1f83c43ea 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg0, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto index 5e1cfe7f2..171c006a5 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg0, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto index 80b6d376d..006e3fe6d 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg0, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto index 23da4ecce..ebfea9f23 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg0, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto index c9ba40fc0..233d71cec 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,12 +42,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_low, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_low, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_high, %arg2, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_high, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto index a5dc03db4..785ae61fd 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto @@ -27,9 +27,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg0, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto index 83729c53e..f7a6abf74 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto @@ -27,9 +27,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg0, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto index 9dc6debf7..d2f99a5e5 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto @@ -27,9 +27,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg0, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto index 4c736a0cc..27410fa17 100644 --- a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg0, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg0, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto index c9c206be4..7cef3d6d3 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto @@ -25,9 +25,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c256_i64 nburst(%c4_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -50,9 +50,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg1, %c256_i64 nburst(%c4_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto index 05c0727b3..18fa1ea04 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c256_i64 nburst(%c4_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -49,9 +49,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg1, %c256_i64 nburst(%c4_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto index 1ce7b1f7b..c8d37df93 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto index 297c5a42d..cf17771b9 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto index 696e33544..316f2c7b6 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto @@ -25,9 +25,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %gm_out, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto index ec078bc62..eae3b4b46 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto index 7394d1d98..45860ebfd 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto index ca010a7d5..e48994eb7 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr - pto.dma_load %arg1, %ub_mid, %c0_i64, %c0_i64, %c32_i64 + pto.dma_load %arg1, %ub_mid, %c0_i64, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,9 +43,9 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_out, %arg2, %c32_i64 nburst(%c32_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto index 7f984fd38..d0736b38e 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto @@ -37,9 +37,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_mask, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_mask, %arg2, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto index f98e46330..f63161bf2 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto @@ -32,9 +32,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_mask, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_mask, %arg2, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto index 3c6d66504..1298f2b93 100644 --- a/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_mask, %arg2, %c0_i64, %c0_i64, %c32_i64 + pto.dma_store %ub_mask, %arg2, %c32_i64 nburst(%c1_i64, %c32_i64, %c32_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto index d9441c369..2a2551aac 100644 --- a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto index 923366521..216c59cd9 100644 --- a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto index 5964dc138..45f5f7827 100644 --- a/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto index 956b3f46a..a0cea3dc6 100644 --- a/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto index 77adf7d9f..d8ea645b1 100644 --- a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto @@ -33,12 +33,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -56,9 +56,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto index 4bf1d2a0c..f3b8ba660 100644 --- a/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto index db726fa47..0a09e5c56 100644 --- a/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %gm_in, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %gm_in, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -63,9 +63,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %gm_out, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto index 4feb0d6a6..d0991617f 100644 --- a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto @@ -27,12 +27,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -50,9 +50,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto index 6f890a9ea..8810df858 100644 --- a/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto @@ -27,12 +27,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_src, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_src, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_mask_seed, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -50,9 +50,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg2, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto index 174baf184..72dc647d7 100644 --- a/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %gm_in, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %gm_in, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -63,9 +63,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %gm_out, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %gm_out, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto index fa7d8818f..d1cf038f4 100644 --- a/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto @@ -14,9 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto index ae8062a97..3f54ab42d 100644 --- a/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto index 255bf3b56..dd0729bbe 100644 --- a/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto index b3d7f4a18..9d477bd88 100644 --- a/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto index e9e35c09e..a69d2bdec 100644 --- a/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto index 177639bdf..ce97a56b9 100644 --- a/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto index 6f8c02864..c780d92d0 100644 --- a/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto index cd4597eb9..4bb2c8d7f 100644 --- a/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto b/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto index 0079b3724..b18c382f8 100644 --- a/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto b/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto index 171d088f8..40ec8d911 100644 --- a/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto index 87930eb3f..05c11b7f0 100644 --- a/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto index 02217c761..43066b6b9 100644 --- a/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto +++ b/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto index 6b9b69a30..3ec47b2ed 100644 --- a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto index d27d811d6..46c0d0b27 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto index 5e70d9aa0..af321c360 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto @@ -14,9 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto index a30c4171f..4050c6f7d 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto index bd15cdeff..2fceef4fc 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto index b06a30b8c..32d45ed32 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto index 289fff56b..5aca615a2 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto @@ -18,9 +18,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto index f40c918dc..d51bbcf77 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto @@ -14,9 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto index dece39ab6..da3a5fb76 100644 --- a/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -53,9 +53,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto index e393f27c0..101190d43 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto index 677c46ed3..6961c7252 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto @@ -14,9 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto index 3a0cfa12b..219266020 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto @@ -14,9 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto index 417b59631..f43f28687 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto @@ -14,9 +14,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto index 361a3f405..f87d315cd 100644 --- a/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -53,9 +53,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto index 4dc052fbf..58535c556 100644 --- a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto index 2c4b24b4d..bf9c0a8a4 100644 --- a/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto index f42d91a17..d56307776 100644 --- a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto index a3b0db15f..3acf5001e 100644 --- a/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto index 638c376db..b44e4a63e 100644 --- a/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto index 7a3b10eac..78366a6ef 100644 --- a/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto index e1b5d4f0f..e1000c4a6 100644 --- a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto index dd97fb468..97563248b 100644 --- a/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -34,9 +34,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto index ee4012a99..078a4333c 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto @@ -23,12 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,12 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_carry, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_carry, %arg3, %c128_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto index 76bf73839..d44f812cb 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto @@ -21,12 +21,12 @@ module attributes {pto.target_arch = "a5"} { %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,12 +43,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_carry, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_carry, %arg3, %c128_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto index f4c829c50..b63099ee0 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto index 7e79f1862..ab518ba32 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto @@ -15,9 +15,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto index 17839109d..2a7f8900b 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto @@ -16,9 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto index a450713a8..fbc57fbdd 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +39,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto index 0e319b6cf..99f51b0d2 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -39,9 +39,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto index c2b7cf7a2..e3ab82d7d 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto index ae7984b3c..b6af5f53a 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto index f6dde0983..b29daad22 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto @@ -16,9 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto index 4bb2cbfbf..4cfbd7bb2 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto index 91b3c6cfc..91eea46b6 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto @@ -16,9 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto index ec3fa1106..f5014e1c9 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto index 59092b972..959fe0a0d 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto @@ -16,9 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto index d2fa7b4e7..8320f73a3 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto index 6c8527f8c..4ba5be0e6 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto @@ -16,9 +16,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto index 9992d7887..41ff4f793 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto index 2a6a4f6b5..54c447757 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto index 184d1a1f5..ea2820b4e 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto index 07bfbed77..030694c12 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto index 783e5cd09..053c0fe89 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c64_i64 + pto.dma_store %ub_out, %arg1, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto index 8bb3028e4..6f5071841 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto @@ -23,12 +23,12 @@ module attributes {pto.target_arch = "a5"} { %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,12 +45,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_borrow, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_borrow, %arg3, %c128_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto index 0969c996a..f7f4c87b6 100644 --- a/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto @@ -21,12 +21,12 @@ module attributes {pto.target_arch = "a5"} { %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_lhs, %c0_i64, %c0_i64, %c256_i64 + pto.dma_load %arg0, %ub_lhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_rhs, %c0_i64, %c0_i64, %c256_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_rhs, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -43,12 +43,12 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg2, %c0_i64, %c0_i64, %c256_i64 + pto.dma_store %ub_out, %arg2, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_store %ub_borrow, %arg3, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_store %ub_borrow, %arg3, %c128_i64 nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto index 1592a02fa..7aaac2724 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto @@ -27,9 +27,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -52,9 +52,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto index e49e98352..fc57605dc 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -62,9 +62,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto index df14c972c..73205a244 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto index 9aa74cd03..dee879c0f 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto @@ -27,9 +27,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto index c7d4982e7..5610109c0 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,9 +60,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto index 5d0375db2..fcbd66d5f 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto index ff5b07da8..8a4e998bf 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/compare.py new file mode 100644 index 000000000..81cfc5edb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(abs_diff)) + print( + f"[ERROR] Mismatch at idx={idx}: golden={golden[idx]}, out={output[idx]}, " + f"diff={abs_diff[idx]}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/golden.py new file mode 100644 index 000000000..ee6929e7f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 8 +INPUT_COLS = 56 +OUTPUT_COLS = 64 +PAD_VALUE = 1.0 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, INPUT_COLS)).astype(np.float32) + v2 = np.zeros((ROWS, OUTPUT_COLS), dtype=np.float32) + golden_v2 = np.full((ROWS, OUTPUT_COLS), PAD_VALUE, dtype=np.float32) + golden_v2[:, :INPUT_COLS] = v1 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate inputs/golden for VPTO micro-op vlds dma loop validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/kernel.pto new file mode 100644 index 000000000..c40c721ad --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/kernel.pto @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-dma-loop +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, dma-loop-load-store, sw-loop-plus-hw-loop, full-mask, aligned, dist-norm +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @vlds_dma_loop_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c224_i64 = arith.constant 224 : i64 + %c256_i64 = arith.constant 256 : i64 + %c448_i64 = arith.constant 448 : i64 + %c512_i64 = arith.constant 512 : i64 + %c896_i64 = arith.constant 896 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c512_i32 = arith.constant 512 : i32 + %pad = arith.constant 1.000000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.dma_load %arg0, %ub_in, %c0_i64, %c224_i64 + nburst(%c1_i64, %c224_i64, %c256_i64) + loop(%c2_i64, %c224_i64, %c256_i64) + loop(%c2_i64, %c448_i64, %c512_i64) + loop(%c2_i64, %c896_i64, %c1024_i64) + pad(%pad, %c0_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + loop i64, i64, i64, loop i64, i64, i64, loop i64, i64, i64, pad f32, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c64 iter_args(%remaining = %c512_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %value = pto.vlds %ub_in[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %value, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.dma_store %ub_out, %arg1, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + loop(%c2_i64, %c256_i64, %c256_i64) + loop(%c2_i64, %c512_i64, %c512_i64) + loop(%c2_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + loop i64, i64, i64, loop i64, i64, i64, loop i64, i64, i64 + + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/launch.cpp new file mode 100644 index 000000000..3f59702eb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vlds_dma_loop_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVlds_dma_loop_kernel(float *v1, float *v2, void *stream) { + vlds_dma_loop_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/main.cpp new file mode 100644 index 000000000..42d510325 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVlds_dma_loop_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 448; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 512; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVlds_dma_loop_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/stub.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/stub.cpp new file mode 100644 index 000000000..ec3c09629 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/stub.cpp @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vlds_dma_loop_kernel(__gm__ float *v1, + __gm__ float *v2) { + (void)v1; + (void)v2; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto index 56a9938ef..35a62c84f 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto index 1270dffc5..b6249bd2a 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto @@ -43,9 +43,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -61,9 +61,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto index 7ce1a5b81..81bbef156 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto @@ -29,9 +29,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c64_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c64_i64 nburst(%c32_i64, %c64_i64, %c64_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto index d7ced22df..c21b86fa7 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto index 56373cafc..ac7442714 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,9 +60,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto index 0eeefab4e..ecbed6870 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto @@ -22,9 +22,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -48,9 +48,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto index 899859fd9..0a1eefd7e 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto index e3fe71b6f..bac9c91a6 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto @@ -21,9 +21,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto index 454684cf1..a993f31e6 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto index defe42ef5..578b9f09c 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_in, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_in, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto index 89d0b856c..7587447cf 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto @@ -32,9 +32,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out1 = pto.addptr %ub_out, %c1_elem : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -54,9 +54,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto index 34dc63560..e23527a5e 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto @@ -30,12 +30,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_out, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_out, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -52,9 +52,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto index 592b12b7c..e65a7e71d 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto index e6e2ee5c0..0471589ae 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto @@ -23,9 +23,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto index 273b38bda..f7b50637a 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto @@ -20,12 +20,12 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 - pto.dma_load %arg1, %ub_out, %c0_i64, %c0_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.dma_load %arg1, %ub_out, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -41,9 +41,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto index 5ea2df48c..6bf21f48f 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto @@ -17,9 +17,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto index f5eaba126..b0399c76e 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -60,9 +60,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto index 6dcfb7139..b447dcbbe 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto @@ -19,9 +19,9 @@ module attributes {pto.target_arch = "a5"} { %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -40,9 +40,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto index 38a90af23..795812f72 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto @@ -20,9 +20,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out1 = pto.addptr %ub_out, %c1 : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -44,9 +44,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto index 104be5a26..43acc9ec3 100644 --- a/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto @@ -30,9 +30,9 @@ module attributes {pto.target_arch = "a5"} { %ub_out1 = pto.addptr %ub_out, %c1 : !pto.ptr -> !pto.ptr %false = arith.constant false - pto.dma_load %arg0, %ub_in, %c0_i64, %c0_i64, %c128_i64 + pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] @@ -52,9 +52,9 @@ module attributes {pto.target_arch = "a5"} { pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.dma_store %ub_out, %arg1, %c0_i64, %c0_i64, %c128_i64 + pto.dma_store %ub_out, %arg1, %c128_i64 nburst(%c32_i64, %c128_i64, %c128_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe return } diff --git a/test/vpto/cases/vpto/dma-copy-rearrange/kernel.pto b/test/vpto/cases/vpto/dma-copy-rearrange/kernel.pto index d8e2429aa..b4ae545c9 100644 --- a/test/vpto/cases/vpto/dma-copy-rearrange/kernel.pto +++ b/test/vpto/cases/vpto/dma-copy-rearrange/kernel.pto @@ -43,18 +43,18 @@ module attributes {pto.target_arch = "a5"} { %dst_row2 = pto.castptr %c192_i64 : i64 -> !pto.ptr %dst_row3 = pto.castptr %c224_i64 : i64 -> !pto.ptr - pto.dma_copy %src_row2, %dst_row0, %c0_i64, %arg3 + pto.dma_copy %src_row2, %dst_row0, %arg3 nburst(%arg2, %arg4, %arg5) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 - pto.dma_copy %src_row0, %dst_row1, %c0_i64, %arg3 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_copy %src_row0, %dst_row1, %arg3 nburst(%arg2, %arg4, %arg5) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 - pto.dma_copy %src_row3, %dst_row2, %c0_i64, %arg3 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_copy %src_row3, %dst_row2, %arg3 nburst(%arg2, %arg4, %arg5) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 - pto.dma_copy %src_row1, %dst_row3, %c0_i64, %arg3 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.dma_copy %src_row1, %dst_row3, %arg3 nburst(%arg2, %arg4, %arg5) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe From 8e5b7d8dd380738a14b4ce26d12ebd84774b9e7f Mon Sep 17 00:00:00 2001 From: Zhang Zhendong Date: Sat, 25 Apr 2026 00:16:04 +0800 Subject: [PATCH 157/192] fix(dsl): forward trandom rounds in tile expansion (#245) (#250) --- lib/PTO/Transforms/ExpandTileOp.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 0dd21c5ea..b10285727 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -247,6 +247,10 @@ static std::optional getTCvtRoundModeString(pto::TCvtOp op) { return std::nullopt; } +static std::string getTRandomRoundsString(pto::TRandomOp op) { + return std::to_string(op.getRounds()); +} + static void appendOpContextAttrs( Operation *op, SmallVectorImpl> &attrs) { @@ -255,6 +259,8 @@ static void appendOpContextAttrs( if (roundMode) attrs.emplace_back("round_mode", *roundMode); } + if (auto trandom = dyn_cast(op)) + attrs.emplace_back("rounds", getTRandomRoundsString(trandom)); if (auto tcmp = dyn_cast(op)) { if (auto cmpModeAttr = tcmp.getCmpModeAttr()) { attrs.emplace_back("cmp_mode", From 3b60e536ccb602c6d7535294925e59ad5f93c70c Mon Sep 17 00:00:00 2001 From: cj Date: Sat, 25 Apr 2026 09:17:27 +0800 Subject: [PATCH 158/192] Add OP for TPartAdd & TPartMul (#213) * Add OP for TPartAdd & TPartMul * docs: add license headers to TPartAdd/TPartMul templates Add PR386 OAT.3 license headers to tpartadd_template.py and tpartmul_template.py. * fix-review: update license * Add OP for TPartAdd & TPartMul: update cmakelist --------- Co-authored-by: caojian5 --- lib/TileOps/tpartadd_template.py | 79 +++ lib/TileOps/tpartmul_template.py | 79 +++ .../expand_tile_op_tilelang_tpartadd.pto | 47 ++ .../expand_tile_op_tilelang_tpartmul.pto | 47 ++ .../npu/a5/src/st/testcase/CMakeLists.txt | 2 + .../src/st/testcase/tpartadd/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tpartadd/cases.py | 122 ++++ .../a5/src/st/testcase/tpartadd/compare.py | 53 ++ .../a5/src/st/testcase/tpartadd/gen_data.py | 96 +++ .../a5/src/st/testcase/tpartadd/launch.cpp | 76 +++ .../npu/a5/src/st/testcase/tpartadd/main.cpp | 164 +++++ .../a5/src/st/testcase/tpartadd/tpartadd.pto | 616 ++++++++++++++++++ .../src/st/testcase/tpartmul/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tpartmul/cases.py | 122 ++++ .../a5/src/st/testcase/tpartmul/compare.py | 53 ++ .../a5/src/st/testcase/tpartmul/gen_data.py | 96 +++ .../a5/src/st/testcase/tpartmul/launch.cpp | 76 +++ .../npu/a5/src/st/testcase/tpartmul/main.cpp | 164 +++++ .../a5/src/st/testcase/tpartmul/tpartmul.pto | 616 ++++++++++++++++++ 19 files changed, 2526 insertions(+) create mode 100644 lib/TileOps/tpartadd_template.py create mode 100644 lib/TileOps/tpartmul_template.py create mode 100644 test/basic/expand_tile_op_tilelang_tpartadd.pto create mode 100644 test/basic/expand_tile_op_tilelang_tpartmul.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartadd/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartadd/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartadd/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartadd/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartadd/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartadd/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartadd/tpartadd.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmul/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmul/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmul/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmul/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmul/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmul/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmul/tpartmul.pto diff --git a/lib/TileOps/tpartadd_template.py b/lib/TileOps/tpartadd_template.py new file mode 100644 index 000000000..8e9716389 --- /dev/null +++ b/lib/TileOps/tpartadd_template.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tpartadd""" + +import tilelang_dsl as pto + + +@pto.inline_proc +def tpart_op_instr(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, valid_rows, valid_cols): + dtype = dst.element_type + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None + +@pto.inline_proc +def tpart_copy_instr(dst: pto.Tile, src: pto.Tile, valid_rows, valid_cols, start_row): + dtype = dst.element_type + for row in range(start_row, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + val = pto.vlds(src[row, col:]) + pto.vsts(val, dst[row, col:], mask) + return None + +@pto.inline_proc +def tpart_op(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, + dst_valid_rows, dst_valid_cols, + src1_valid_rows, src1_valid_cols): + + src1_eq_dst = (src1_valid_rows == dst_valid_rows and src1_valid_cols == dst_valid_cols) + src1_row_lt_dst = (src1_valid_rows < dst_valid_rows and src1_valid_cols == dst_valid_cols) + src1_col_lt_dst = (src1_valid_rows <= dst_valid_rows and src1_valid_cols < dst_valid_cols) + + if src1_eq_dst: + tpart_op_instr(dst, src0, src1, dst_valid_rows, dst_valid_cols) + elif src1_col_lt_dst: + tpart_copy_instr(dst, src0, dst_valid_rows, dst_valid_cols, 0) + if src1_valid_cols > 0: + tpart_op_instr(dst, src0, src1, src1_valid_rows, src1_valid_cols) + elif src1_row_lt_dst: + if src1_valid_cols > 0: + tpart_op_instr(dst, src0, src1, src1_valid_rows, src1_valid_cols) + tpart_copy_instr(dst, src0, dst_valid_rows, dst_valid_cols, src1_valid_rows) + + return + +@pto.vkernel( + target="a5", + op="pto.tpartadd" +) +def template_tpartadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dst_valid_rows, dst_valid_cols = dst.valid_shape + src0_valid_rows, src0_valid_cols = src0.valid_shape + src1_valid_rows, src1_valid_cols = src1.valid_shape + + src0_eq_dst = (src0_valid_rows == dst_valid_rows and src0_valid_cols == dst_valid_cols) + src1_eq_dst = (src1_valid_rows == dst_valid_rows and src1_valid_cols == dst_valid_cols) + + if src0_eq_dst or src1_eq_dst: + if src0_eq_dst: + tpart_op(dst, src0, src1, dst_valid_rows, dst_valid_cols, src1_valid_rows, src1_valid_cols) + elif src1_eq_dst: + tpart_op(dst, src1, src0, dst_valid_rows, dst_valid_cols, src0_valid_rows, src0_valid_cols) + # TODO: raise an error later + + return \ No newline at end of file diff --git a/lib/TileOps/tpartmul_template.py b/lib/TileOps/tpartmul_template.py new file mode 100644 index 000000000..fe39597ca --- /dev/null +++ b/lib/TileOps/tpartmul_template.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tpartmul""" + +import tilelang_dsl as pto + + +@pto.inline_proc +def tpart_op_instr(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, valid_rows, valid_cols): + dtype = dst.element_type + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vmul(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None + +@pto.inline_proc +def tpart_copy_instr(dst: pto.Tile, src: pto.Tile, valid_rows, valid_cols, start_row): + dtype = dst.element_type + for row in range(start_row, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + val = pto.vlds(src[row, col:]) + pto.vsts(val, dst[row, col:], mask) + return None + +@pto.inline_proc +def tpart_op(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, + dst_valid_rows, dst_valid_cols, + src1_valid_rows, src1_valid_cols): + + src1_eq_dst = (src1_valid_rows == dst_valid_rows and src1_valid_cols == dst_valid_cols) + src1_row_lt_dst = (src1_valid_rows < dst_valid_rows and src1_valid_cols == dst_valid_cols) + src1_col_lt_dst = (src1_valid_rows <= dst_valid_rows and src1_valid_cols < dst_valid_cols) + + if src1_eq_dst: + tpart_op_instr(dst, src0, src1, dst_valid_rows, dst_valid_cols) + elif src1_col_lt_dst: + tpart_copy_instr(dst, src0, dst_valid_rows, dst_valid_cols, 0) + if src1_valid_cols > 0: + tpart_op_instr(dst, src0, src1, src1_valid_rows, src1_valid_cols) + elif src1_row_lt_dst: + if src1_valid_cols > 0: + tpart_op_instr(dst, src0, src1, src1_valid_rows, src1_valid_cols) + tpart_copy_instr(dst, src0, dst_valid_rows, dst_valid_cols, src1_valid_rows) + + return + +@pto.vkernel( + target="a5", + op="pto.tpartmul" +) +def template_tpartmul(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dst_valid_rows, dst_valid_cols = dst.valid_shape + src0_valid_rows, src0_valid_cols = src0.valid_shape + src1_valid_rows, src1_valid_cols = src1.valid_shape + + src0_eq_dst = (src0_valid_rows == dst_valid_rows and src0_valid_cols == dst_valid_cols) + src1_eq_dst = (src1_valid_rows == dst_valid_rows and src1_valid_cols == dst_valid_cols) + + if src0_eq_dst or src1_eq_dst: + if src0_eq_dst: + tpart_op(dst, src0, src1, dst_valid_rows, dst_valid_cols, src1_valid_rows, src1_valid_cols) + elif src1_eq_dst: + tpart_op(dst, src1, src0, dst_valid_rows, dst_valid_cols, src0_valid_rows, src0_valid_cols) + # TODO: raise an error later + + return \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tpartadd.pto b/test/basic/expand_tile_op_tilelang_tpartadd.pto new file mode 100644 index 000000000..894f88de9 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tpartadd.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tpartadd/tpartmul/tpartmax/tpartmin via TileLang Python DSL templates. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile ops should be lowered to vector-style VPTO IR. + +// TPartAdd checks +// CHECK-LABEL: func.func @TPARTADD +// CHECK-NOT: pto.tpartadd ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module { + func.func @TPARTADD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tpartmul.pto b/test/basic/expand_tile_op_tilelang_tpartmul.pto new file mode 100644 index 000000000..e1e66c9e5 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tpartmul.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tpartadd/tpartmul/tpartmax/tpartmin via TileLang Python DSL templates. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile ops should be lowered to vector-style VPTO IR. + +// TPartMul checks +// CHECK-LABEL: func.func @TPARTMUL +// CHECK-NOT: pto.tpartmul ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vsts + +module { + func.func @TPARTMUL() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 49d3df96b..b1b1af182 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -138,6 +138,8 @@ set(ALL_TESTCASES tlog tneg tnot + tpartadd + tpartmul trecip trsqrt tsqrt diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/CMakeLists.txt new file mode 100644 index 000000000..4eb1affcf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tpartadd) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/cases.py new file mode 100644 index 000000000..6ec74d95d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/cases.py @@ -0,0 +1,122 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tpartadd ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src0/src1/dst). + - valid_shape: (valid_rows, valid_cols) — src0 valid region (src0_eq_dst scenario). + - src1_vshape: (src1_valid_rows, src1_valid_cols) — src1 valid region. + May be smaller than dst valid region for partial add cases. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + - eps: tolerance for numpy.allclose (atol and rtol). + +tpartadd semantics: + - If src0_valid == dst_valid: dst[:src1_rows,:src1_cols] = src0[:src1_rows,:src1_cols] + src1[:src1_rows,:src1_cols] + dst[src1_rows:,:] = src0[src1_rows:,:] (copy remaining rows) + OR (for col_less) dst[:,:src1_cols] = src0[:,:src1_cols] + src1[:,:src1_cols] + dst[:,src1_cols:] = src0[:,src1_cols:] (copy remaining cols) + - If src1_valid == dst_valid: similar logic with src1 as the full operand. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # float32 cases + { + "name": "f32_64x64_full", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region + "src1_vshape": (64, 64), # src1 valid region (same as dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src0_row_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (8, 64), # src0 valid region (row_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src0_col_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 8), # src0 valid region (col_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src1_row_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region (equals dst) + "src1_vshape": (8, 64), # src1 valid region (row_less) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src1_col_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region (equals dst) + "src1_vshape": (64, 8), # src1 valid region (col_less) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + # float16 cases + { + "name": "f16_8x48_src0_col_less", + "dtype": np.float16, + "shape": (8, 48), + "valid_shape": (8, 16), # src0 valid region (col_less) + "src1_vshape": (8, 48), # src1 valid region (equals dst) + "dst_vshape": (8, 48), # dst valid region + "eps": 1e-3, + }, + { + "name": "f16_8x768_src0_col_less", + "dtype": np.float16, + "shape": (8, 768), + "valid_shape": (8, 512), # src0 valid region (col_less) + "src1_vshape": (8, 768), # src1 valid region (equals dst) + "dst_vshape": (8, 768), # dst valid region + "eps": 1e-3, + }, + # int16 cases + { + "name": "i16_8x48_src1_col_less", + "dtype": np.int16, + "shape": (8, 48), + "valid_shape": (8, 48), # src0 valid region (equals dst) + "src1_vshape": (8, 16), # src1 valid region (col_less) + "dst_vshape": (8, 48), # dst valid region + "eps": 0, # exact match for int + }, + # int32 cases + { + "name": "i32_64x64_src0_row_less", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (8, 64), # src0 valid region (row_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 0, # exact match for int + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/compare.py new file mode 100644 index 000000000..283ee788a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dtype = case["dtype"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/gen_data.py new file mode 100644 index 000000000..9ecaf30fa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/gen_data.py @@ -0,0 +1,96 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = _to_tuple(case["shape"]) + src0_valid = _to_tuple(case["valid_shape"]) + src1_valid = _to_tuple(case["src1_vshape"]) + dst_valid = _to_tuple(case["dst_vshape"]) + + rows, cols = shape + src0_vr, src0_vc = src0_valid + src1_vr, src1_vc = src1_valid + dst_vr, dst_vc = dst_valid + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + + # Compute golden according to tpartadd semantics from template: + # If src0_valid == dst_valid: use tpart_op with src0 as full operand + # - If src1 row_less: add for src1 region, copy src0 for remaining rows + # - If src1 col_less: copy src0 full, then add for overlapping region + # If src1_valid == dst_valid: use tpart_op with src1 as full operand (swap src0/src1) + + src0_eq_dst = (src0_vr == dst_vr and src0_vc == dst_vc) + src1_eq_dst = (src1_vr == dst_vr and src1_vc == dst_vc) + + if src0_eq_dst: + # src0 is the full operand matching dst + src1_row_lt_dst = (src1_vr < dst_vr and src1_vc == dst_vc) + src1_col_lt_dst = (src1_vr <= dst_vr and src1_vc < dst_vc) + + if src1_eq_dst: + # Full add: dst[:] = src0[:] + src1[:] + golden[:dst_vr, :dst_vc] = (input1[:dst_vr, :dst_vc] + input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src1_col_lt_dst: + # Col_less: first copy src0, then add in overlapping region + golden[:dst_vr, :dst_vc] = input1[:dst_vr, :dst_vc].copy() + if src1_vc > 0: + golden[:src1_vr, :src1_vc] = (input1[:src1_vr, :src1_vc] + input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + elif src1_row_lt_dst: + # Row_less: add for src1 region, copy src0 for remaining rows + if src1_vc > 0: + golden[:src1_vr, :src1_vc] = (input1[:src1_vr, :src1_vc] + input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + golden[src1_vr:dst_vr, :dst_vc] = input1[src1_vr:dst_vr, :dst_vc].copy() + elif src1_eq_dst: + # src1 is the full operand matching dst, swap src0/src1 in the logic + src0_row_lt_dst = (src0_vr < dst_vr and src0_vc == dst_vc) + src0_col_lt_dst = (src0_vr <= dst_vr and src0_vc < dst_vc) + + if src0_eq_dst: + # Full add: dst[:] = src0[:] + src1[:] + golden[:dst_vr, :dst_vc] = (input1[:dst_vr, :dst_vc] + input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src0_col_lt_dst: + # Col_less: first copy src1, then add in overlapping region + golden[:dst_vr, :dst_vc] = input2[:dst_vr, :dst_vc].copy() + if src0_vc > 0: + golden[:src0_vr, :src0_vc] = (input1[:src0_vr, :src0_vc] + input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + elif src0_row_lt_dst: + # Row_less: add for src0 region, copy src1 for remaining rows + if src0_vc > 0: + golden[:src0_vr, :src0_vc] = (input1[:src0_vr, :src0_vc] + input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + golden[src0_vr:dst_vr, :dst_vc] = input2[src0_vr:dst_vr, :dst_vc].copy() + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} src0_valid={src0_valid} src1_valid={src1_valid} dst_valid={dst_valid} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/launch.cpp new file mode 100644 index 000000000..02d725199 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/launch.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 64x64 full +extern "C" __global__ AICORE void TPARTADD_f32_64x64_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_full(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 64x64 src0 row less +extern "C" __global__ AICORE void TPARTADD_f32_64x64_src0_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_src0_row_less(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_src0_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32 64x64 src0 col less +extern "C" __global__ AICORE void TPARTADD_f32_64x64_src0_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_src0_col_less(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_src0_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: f32 64x64 src1 row less +extern "C" __global__ AICORE void TPARTADD_f32_64x64_src1_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_src1_row_less(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_src1_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 4: f32 64x64 src1 col less +extern "C" __global__ AICORE void TPARTADD_f32_64x64_src1_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_src1_col_less(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_src1_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 5: f16 8x48 src0 col less +extern "C" __global__ AICORE void TPARTADD_f16_8x48_src0_col_less(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTADD_f16_8x48_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTADD_f16_8x48_src0_col_less<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case 6: f16 8x768 src0 col less +extern "C" __global__ AICORE void TPARTADD_f16_8x768_src0_col_less(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTADD_f16_8x768_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTADD_f16_8x768_src0_col_less<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case 7: i16 8x48 src1 col less +extern "C" __global__ AICORE void TPARTADD_i16_8x48_src1_col_less(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTPARTADD_i16_8x48_src1_col_less(int16_t *a, int16_t *b, int16_t *c, void *stream) { + TPARTADD_i16_8x48_src1_col_less<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 8: i32 64x64 src0 row less +extern "C" __global__ AICORE void TPARTADD_i32_64x64_src0_row_less(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTPARTADD_i32_64x64_src0_row_less(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TPARTADD_i32_64x64_src0_row_less<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/main.cpp new file mode 100644 index 000000000..34e013b77 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/main.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tpartadd ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPARTADD_f32_64x64_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f32_64x64_src0_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f32_64x64_src0_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f32_64x64_src1_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f32_64x64_src1_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f16_8x48_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTADD_f16_8x768_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTADD_i16_8x48_src1_col_less(int16_t *a, int16_t *b, int16_t *c, void *stream); +void LaunchTPARTADD_i32_64x64_src0_row_less(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t src0ValidRows; // src0 effective rows + size_t src0ValidCols; // src0 effective cols + size_t src1ValidRows; // src1 effective rows + size_t src1ValidCols; // src1 effective cols + size_t dstValidRows; // dst effective rows + size_t dstValidCols; // dst effective cols + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64_full", reinterpret_cast(LaunchTPARTADD_f32_64x64_full), 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src0_row_less", reinterpret_cast(LaunchTPARTADD_f32_64x64_src0_row_less), 64, 64, 8, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src0_col_less", reinterpret_cast(LaunchTPARTADD_f32_64x64_src0_col_less), 64, 64, 64, 8, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src1_row_less", reinterpret_cast(LaunchTPARTADD_f32_64x64_src1_row_less), 64, 64, 64, 64, 8, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src1_col_less", reinterpret_cast(LaunchTPARTADD_f32_64x64_src1_col_less), 64, 64, 64, 64, 64, 8, 64, 64, sizeof(float)}, + {"f16_8x48_src0_col_less", reinterpret_cast(LaunchTPARTADD_f16_8x48_src0_col_less), 8, 48, 8, 16, 8, 48, 8, 48, sizeof(uint16_t)}, + {"f16_8x768_src0_col_less", reinterpret_cast(LaunchTPARTADD_f16_8x768_src0_col_less), 8,768, 8,512, 8,768, 8,768, sizeof(uint16_t)}, + {"i16_8x48_src1_col_less", reinterpret_cast(LaunchTPARTADD_i16_8x48_src1_col_less), 8, 48, 8, 48, 8, 16, 8, 48, sizeof(int16_t)}, + {"i32_64x64_src0_row_less", reinterpret_cast(LaunchTPARTADD_i32_64x64_src0_row_less), 64, 64, 8, 64, 64, 64, 64, 64, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, src0_valid=%zux%zu, src1_valid=%zux%zu, dst_valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.src0ValidRows, tc.src0ValidCols, + tc.src1ValidRows, tc.src1ValidCols, tc.dstValidRows, tc.dstValidCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tpartadd [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/tpartadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/tpartadd.pto new file mode 100644 index 000000000..61fbd5021 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/tpartadd.pto @@ -0,0 +1,616 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use the file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tpartadd: partial elementwise add with valid region handling. +// Multiple cases with different valid_shape combinations in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 64x64 full (src0/src1/dst all have same valid_shape 64x64) + func.func @TPARTADD_f32_64x64_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 1: f32 64x64 src0 row less (src0 valid 8x64, src1/dst valid 64x64) + func.func @TPARTADD_f32_64x64_src0_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: partial valid region (8,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 2: f32 64x64 src0 col less (src0 valid 64x8, src1/dst valid 64x64) + func.func @TPARTADD_f32_64x64_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: partial valid region (64,8) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 3: f32 64x64 src1 row less (src0/dst valid 64x64, src1 valid 8x64) + func.func @TPARTADD_f32_64x64_src1_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: full valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (8,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 4: f32 64x64 src1 col less (src0/dst valid 64x64, src1 valid 64x8) + func.func @TPARTADD_f32_64x64_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: full valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (64,8) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f16 8x48 src0 col less (src0 valid 8x16, src1/dst valid 8x48) + func.func @TPARTADD_f16_8x48_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c384 = arith.constant 384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + + // src0: partial valid region (8,16) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (8,48) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,48) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + return + } + + // Case 6: f16 8x768 src0 col less (src0 valid 8x512, src1/dst valid 8x768) + func.func @TPARTADD_f16_8x768_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c512 = arith.constant 512 : index + %c768 = arith.constant 768 : index + %c6144 = arith.constant 6144 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + + // src0: partial valid region (8,512) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (8,768) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,768) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + return + } + + // Case 7: i16 8x48 src1 col less (src0/dst valid 8x48, src1 valid 8x16) + func.func @TPARTADD_i16_8x48_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c384 = arith.constant 384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + + // src0: full valid region (8,48) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (8,16) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,48) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + return + } + + // Case 8: i32 64x64 src0 row less (src0 valid 8x64, src1/dst valid 64x64) + func.func @TPARTADD_i32_64x64_src0_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + // src0: partial valid region (8,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/CMakeLists.txt new file mode 100644 index 000000000..190439e25 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tpartmul) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/cases.py new file mode 100644 index 000000000..ad892fa2b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/cases.py @@ -0,0 +1,122 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tpartmul ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src0/src1/dst). + - valid_shape: (valid_rows, valid_cols) — src0 valid region (src0_eq_dst scenario). + - src1_vshape: (src1_valid_rows, src1_valid_cols) — src1 valid region. + May be smaller than dst valid region for partial mul cases. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + - eps: tolerance for numpy.allclose (atol and rtol). + +tpartmul semantics: + - If src0_valid == dst_valid: dst[:src1_rows,:src1_cols] = src0[:src1_rows,:src1_cols] * src1[:src1_rows,:src1_cols] + dst[src1_rows:,:] = src0[src1_rows:,:] (copy remaining rows) + OR (for col_less) dst[:,:src1_cols] = src0[:,:src1_cols] * src1[:,:src1_cols] + dst[:,src1_cols:] = src0[:,src1_cols:] (copy remaining cols) + - If src1_valid == dst_valid: similar logic with src1 as the full operand. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # float32 cases + { + "name": "f32_64x64_full", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region + "src1_vshape": (64, 64), # src1 valid region (same as dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src0_row_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (8, 64), # src0 valid region (row_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src0_col_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 8), # src0 valid region (col_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src1_row_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region (equals dst) + "src1_vshape": (8, 64), # src1 valid region (row_less) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src1_col_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region (equals dst) + "src1_vshape": (64, 8), # src1 valid region (col_less) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + # float16 cases + { + "name": "f16_8x48_src0_col_less", + "dtype": np.float16, + "shape": (8, 48), + "valid_shape": (8, 16), # src0 valid region (col_less) + "src1_vshape": (8, 48), # src1 valid region (equals dst) + "dst_vshape": (8, 48), # dst valid region + "eps": 1e-3, + }, + { + "name": "f16_8x768_src0_col_less", + "dtype": np.float16, + "shape": (8, 768), + "valid_shape": (8, 512), # src0 valid region (col_less) + "src1_vshape": (8, 768), # src1 valid region (equals dst) + "dst_vshape": (8, 768), # dst valid region + "eps": 1e-3, + }, + # int16 cases + { + "name": "i16_8x48_src1_col_less", + "dtype": np.int16, + "shape": (8, 48), + "valid_shape": (8, 48), # src0 valid region (equals dst) + "src1_vshape": (8, 16), # src1 valid region (col_less) + "dst_vshape": (8, 48), # dst valid region + "eps": 0, # exact match for int + }, + # int32 cases + { + "name": "i32_64x64_src0_row_less", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (8, 64), # src0 valid region (row_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 0, # exact match for int + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/compare.py new file mode 100644 index 000000000..283ee788a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dtype = case["dtype"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/gen_data.py new file mode 100644 index 000000000..5ca965d0e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/gen_data.py @@ -0,0 +1,96 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = _to_tuple(case["shape"]) + src0_valid = _to_tuple(case["valid_shape"]) + src1_valid = _to_tuple(case["src1_vshape"]) + dst_valid = _to_tuple(case["dst_vshape"]) + + rows, cols = shape + src0_vr, src0_vc = src0_valid + src1_vr, src1_vc = src1_valid + dst_vr, dst_vc = dst_valid + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + + # Compute golden according to tpartmul semantics from template: + # If src0_valid == dst_valid: use tpart_op with src0 as full operand + # - If src1 row_less: mul for src1 region, copy src0 for remaining rows + # - If src1 col_less: copy src0 full, then mul for overlapping region + # If src1_valid == dst_valid: use tpart_op with src1 as full operand (swap src0/src1) + + src0_eq_dst = (src0_vr == dst_vr and src0_vc == dst_vc) + src1_eq_dst = (src1_vr == dst_vr and src1_vc == dst_vc) + + if src0_eq_dst: + # src0 is the full operand matching dst + src1_row_lt_dst = (src1_vr < dst_vr and src1_vc == dst_vc) + src1_col_lt_dst = (src1_vr <= dst_vr and src1_vc < dst_vc) + + if src1_eq_dst: + # Full mul: dst[:] = src0[:] * src1[:] + golden[:dst_vr, :dst_vc] = (input1[:dst_vr, :dst_vc] * input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src1_col_lt_dst: + # Col_less: first copy src0, then mul in overlapping region + golden[:dst_vr, :dst_vc] = input1[:dst_vr, :dst_vc].copy() + if src1_vc > 0: + golden[:src1_vr, :src1_vc] = (input1[:src1_vr, :src1_vc] * input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + elif src1_row_lt_dst: + # Row_less: mul for src1 region, copy src0 for remaining rows + if src1_vc > 0: + golden[:src1_vr, :src1_vc] = (input1[:src1_vr, :src1_vc] * input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + golden[src1_vr:dst_vr, :dst_vc] = input1[src1_vr:dst_vr, :dst_vc].copy() + elif src1_eq_dst: + # src1 is the full operand matching dst, swap src0/src1 in the logic + src0_row_lt_dst = (src0_vr < dst_vr and src0_vc == dst_vc) + src0_col_lt_dst = (src0_vr <= dst_vr and src0_vc < dst_vc) + + if src0_eq_dst: + # Full mul: dst[:] = src0[:] * src1[:] + golden[:dst_vr, :dst_vc] = (input1[:dst_vr, :dst_vc] * input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src0_col_lt_dst: + # Col_less: first copy src1, then mul in overlapping region + golden[:dst_vr, :dst_vc] = input2[:dst_vr, :dst_vc].copy() + if src0_vc > 0: + golden[:src0_vr, :src0_vc] = (input1[:src0_vr, :src0_vc] * input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + elif src0_row_lt_dst: + # Row_less: mul for src0 region, copy src1 for remaining rows + if src0_vc > 0: + golden[:src0_vr, :src0_vc] = (input1[:src0_vr, :src0_vc] * input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + golden[src0_vr:dst_vr, :dst_vc] = input2[src0_vr:dst_vr, :dst_vc].copy() + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} src0_valid={src0_valid} src1_valid={src1_valid} dst_valid={dst_valid} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/launch.cpp new file mode 100644 index 000000000..fb00bb99f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/launch.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 64x64 full +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_full(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 64x64 src0 row less +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_src0_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_src0_row_less(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_src0_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32 64x64 src0 col less +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_src0_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_src0_col_less(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_src0_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: f32 64x64 src1 row less +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_src1_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_src1_row_less(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_src1_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 4: f32 64x64 src1 col less +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_src1_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_src1_col_less(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_src1_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 5: f16 8x48 src0 col less +extern "C" __global__ AICORE void TPARTMUL_f16_8x48_src0_col_less(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMUL_f16_8x48_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMUL_f16_8x48_src0_col_less<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case 6: f16 8x768 src0 col less +extern "C" __global__ AICORE void TPARTMUL_f16_8x768_src0_col_less(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMUL_f16_8x768_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMUL_f16_8x768_src0_col_less<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case 7: i16 8x48 src1 col less +extern "C" __global__ AICORE void TPARTMUL_i16_8x48_src1_col_less(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTPARTMUL_i16_8x48_src1_col_less(int16_t *a, int16_t *b, int16_t *c, void *stream) { + TPARTMUL_i16_8x48_src1_col_less<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 8: i32 64x64 src0 row less +extern "C" __global__ AICORE void TPARTMUL_i32_64x64_src0_row_less(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTPARTMUL_i32_64x64_src0_row_less(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TPARTMUL_i32_64x64_src0_row_less<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/main.cpp new file mode 100644 index 000000000..d281d8710 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/main.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tpartmul ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPARTMUL_f32_64x64_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f32_64x64_src0_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f32_64x64_src0_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f32_64x64_src1_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f32_64x64_src1_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f16_8x48_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMUL_f16_8x768_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMUL_i16_8x48_src1_col_less(int16_t *a, int16_t *b, int16_t *c, void *stream); +void LaunchTPARTMUL_i32_64x64_src0_row_less(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t src0ValidRows; // src0 effective rows + size_t src0ValidCols; // src0 effective cols + size_t src1ValidRows; // src1 effective rows + size_t src1ValidCols; // src1 effective cols + size_t dstValidRows; // dst effective rows + size_t dstValidCols; // dst effective cols + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64_full", reinterpret_cast(LaunchTPARTMUL_f32_64x64_full), 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src0_row_less", reinterpret_cast(LaunchTPARTMUL_f32_64x64_src0_row_less), 64, 64, 8, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src0_col_less", reinterpret_cast(LaunchTPARTMUL_f32_64x64_src0_col_less), 64, 64, 64, 8, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src1_row_less", reinterpret_cast(LaunchTPARTMUL_f32_64x64_src1_row_less), 64, 64, 64, 64, 8, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src1_col_less", reinterpret_cast(LaunchTPARTMUL_f32_64x64_src1_col_less), 64, 64, 64, 64, 64, 8, 64, 64, sizeof(float)}, + {"f16_8x48_src0_col_less", reinterpret_cast(LaunchTPARTMUL_f16_8x48_src0_col_less), 8, 48, 8, 16, 8, 48, 8, 48, sizeof(uint16_t)}, + {"f16_8x768_src0_col_less", reinterpret_cast(LaunchTPARTMUL_f16_8x768_src0_col_less), 8,768, 8,512, 8,768, 8,768, sizeof(uint16_t)}, + {"i16_8x48_src1_col_less", reinterpret_cast(LaunchTPARTMUL_i16_8x48_src1_col_less), 8, 48, 8, 48, 8, 16, 8, 48, sizeof(int16_t)}, + {"i32_64x64_src0_row_less", reinterpret_cast(LaunchTPARTMUL_i32_64x64_src0_row_less), 64, 64, 8, 64, 64, 64, 64, 64, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, src0_valid=%zux%zu, src1_valid=%zux%zu, dst_valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.src0ValidRows, tc.src0ValidCols, + tc.src1ValidRows, tc.src1ValidCols, tc.dstValidRows, tc.dstValidCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tpartmul [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/tpartmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/tpartmul.pto new file mode 100644 index 000000000..103bd6b13 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/tpartmul.pto @@ -0,0 +1,616 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tpartmul: partial elementwise mul with valid region handling. +// Multiple cases with different valid_shape combinations in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 64x64 full (src0/src1/dst all have same valid_shape 64x64) + func.func @TPARTMUL_f32_64x64_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 1: f32 64x64 src0 row less (src0 valid 8x64, src1/dst valid 64x64) + func.func @TPARTMUL_f32_64x64_src0_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: partial valid region (8,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 2: f32 64x64 src0 col less (src0 valid 64x8, src1/dst valid 64x64) + func.func @TPARTMUL_f32_64x64_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: partial valid region (64,8) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 3: f32 64x64 src1 row less (src0/dst valid 64x64, src1 valid 8x64) + func.func @TPARTMUL_f32_64x64_src1_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: full valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (8,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 4: f32 64x64 src1 col less (src0/dst valid 64x64, src1 valid 64x8) + func.func @TPARTMUL_f32_64x64_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: full valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (64,8) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f16 8x48 src0 col less (src0 valid 8x16, src1/dst valid 8x48) + func.func @TPARTMUL_f16_8x48_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c384 = arith.constant 384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + + // src0: partial valid region (8,16) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (8,48) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,48) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + return + } + + // Case 6: f16 8x768 src0 col less (src0 valid 8x512, src1/dst valid 8x768) + func.func @TPARTMUL_f16_8x768_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c512 = arith.constant 512 : index + %c768 = arith.constant 768 : index + %c6144 = arith.constant 6144 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + + // src0: partial valid region (8,512) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (8,768) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,768) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + return + } + + // Case 7: i16 8x48 src1 col less (src0/dst valid 8x48, src1 valid 8x16) + func.func @TPARTMUL_i16_8x48_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c384 = arith.constant 384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + + // src0: full valid region (8,48) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (8,16) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,48) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + return + } + + // Case 8: i32 64x64 src0 row less (src0 valid 8x64, src1/dst valid 64x64) + func.func @TPARTMUL_i32_64x64_src0_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + // src0: partial valid region (8,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } +} \ No newline at end of file From 6c45966665e2fb5630e56f5a6a03e4d1a9c1fde4 Mon Sep 17 00:00:00 2001 From: cj Date: Sat, 25 Apr 2026 09:48:07 +0800 Subject: [PATCH 159/192] Add OP for TPartMin & TPartMax (#230) * Add OP for TPartMin & TPartMax * review-fix: i8 & ui8 has supported * fix-review: update license * Add OP for TPartMin & TPartMax: update cmakelist --------- Co-authored-by: caojian5 Co-authored-by: Zhang Zhendong --- lib/TileOps/tpartmax_template.py | 50 ++ lib/TileOps/tpartmin_template.py | 50 ++ .../expand_tile_op_tilelang_tpartmax.pto | 48 ++ .../expand_tile_op_tilelang_tpartmin.pto | 47 ++ .../npu/a5/src/st/testcase/CMakeLists.txt | 2 + .../src/st/testcase/tpartmax/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tpartmax/cases.py | 153 ++++ .../a5/src/st/testcase/tpartmax/compare.py | 53 ++ .../a5/src/st/testcase/tpartmax/gen_data.py | 127 ++++ .../a5/src/st/testcase/tpartmax/launch.cpp | 97 +++ .../npu/a5/src/st/testcase/tpartmax/main.cpp | 230 ++++++ .../a5/src/st/testcase/tpartmax/tpartmax.pto | 717 +++++++++++++++++ .../src/st/testcase/tpartmin/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tpartmin/cases.py | 153 ++++ .../a5/src/st/testcase/tpartmin/compare.py | 53 ++ .../a5/src/st/testcase/tpartmin/gen_data.py | 127 ++++ .../a5/src/st/testcase/tpartmin/launch.cpp | 97 +++ .../npu/a5/src/st/testcase/tpartmin/main.cpp | 230 ++++++ .../a5/src/st/testcase/tpartmin/tpartmin.pto | 718 ++++++++++++++++++ 19 files changed, 2970 insertions(+) create mode 100644 lib/TileOps/tpartmax_template.py create mode 100644 lib/TileOps/tpartmin_template.py create mode 100644 test/basic/expand_tile_op_tilelang_tpartmax.pto create mode 100644 test/basic/expand_tile_op_tilelang_tpartmin.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmax/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmax/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmax/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmax/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmax/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmax/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmax/tpartmax.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmin/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmin/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmin/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmin/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmin/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmin/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tpartmin/tpartmin.pto diff --git a/lib/TileOps/tpartmax_template.py b/lib/TileOps/tpartmax_template.py new file mode 100644 index 000000000..733616738 --- /dev/null +++ b/lib/TileOps/tpartmax_template.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tpartmax""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tpartmax", + advanced=True, +) +def template_tpartmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + src0_valid_rows, src0_valid_cols = src0.valid_shape + src1_valid_rows, src1_valid_cols = src1.valid_shape + lanes = pto.get_lanes(dtype) + + pad_scalar = pto.PadValue.MIN.eval(dtype) + pad_vec = pto.vbr(pad_scalar) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + pto.vsts(pad_vec, dst[row, col:], mask) + + for row in range(0, src0_valid_rows, 1): + remained = src0_valid_cols + for col in range(0, src0_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec0 = pto.vlds(src0[row, col:]) + pto.vsts(vec0, dst[row, col:], mask) + + for row in range(0, src1_valid_rows, 1): + remained = src1_valid_cols + for col in range(0, src1_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec_dst = pto.vlds(dst[row, col:]) + vec1 = pto.vlds(src1[row, col:]) + result = pto.vmax(vec_dst, vec1, mask) + pto.vsts(result, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tpartmin_template.py b/lib/TileOps/tpartmin_template.py new file mode 100644 index 000000000..05364ba4a --- /dev/null +++ b/lib/TileOps/tpartmin_template.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tpartmin""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tpartmin", + advanced=True, +) +def template_tpartmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + src0_valid_rows, src0_valid_cols = src0.valid_shape + src1_valid_rows, src1_valid_cols = src1.valid_shape + lanes = pto.get_lanes(dtype) + + pad_scalar = pto.PadValue.MAX.eval(dtype) + pad_vec = pto.vbr(pad_scalar) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + pto.vsts(pad_vec, dst[row, col:], mask) + + for row in range(0, src0_valid_rows, 1): + remained = src0_valid_cols + for col in range(0, src0_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec0 = pto.vlds(src0[row, col:]) + pto.vsts(vec0, dst[row, col:], mask) + + for row in range(0, src1_valid_rows, 1): + remained = src1_valid_cols + for col in range(0, src1_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec_dst = pto.vlds(dst[row, col:]) + vec1 = pto.vlds(src1[row, col:]) + result = pto.vmin(vec_dst, vec1, mask) + pto.vsts(result, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tpartmax.pto b/test/basic/expand_tile_op_tilelang_tpartmax.pto new file mode 100644 index 000000000..5e03fd19f --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tpartmax.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tpartadd/tpartmul/tpartmax/tpartmin via TileLang Python DSL templates. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile ops should be lowered to vector-style VPTO IR. + +// TPartMax checks +// CHECK-LABEL: func.func @TPARTMAX +// CHECK-NOT: pto.tpartmax ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmax +// CHECK: pto.vsts + + +module { + func.func @TPARTMAX() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tpartmin.pto b/test/basic/expand_tile_op_tilelang_tpartmin.pto new file mode 100644 index 000000000..2727ccf06 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tpartmin.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tpartadd/tpartmul/tpartmax/tpartmin via TileLang Python DSL templates. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile ops should be lowered to vector-style VPTO IR. + +// TPartMin checks +// CHECK-LABEL: func.func @TPARTMIN +// CHECK-NOT: pto.tpartmin ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmin +// CHECK: pto.vsts + +module { + func.func @TPARTMIN() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index b1b1af182..b8604dbc3 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -138,6 +138,8 @@ set(ALL_TESTCASES tlog tneg tnot + tpartmax + tpartmin tpartadd tpartmul trecip diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/CMakeLists.txt new file mode 100644 index 000000000..04f947e55 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tpartmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/cases.py new file mode 100644 index 000000000..e6e6dbfe5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/cases.py @@ -0,0 +1,153 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tpartmax ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src0/src1/dst). + - valid_shape: (valid_rows, valid_cols) — src0 valid region (src0_eq_dst scenario). + - src1_vshape: (src1_valid_rows, src1_valid_cols) — src1 valid region. + May be smaller than dst valid region for partial max cases. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + - eps: tolerance for numpy.allclose (atol and rtol). + +tpartmax semantics: + - If src0_valid == dst_valid: dst[:src1_rows,:src1_cols] = max(src0[:src1_rows,:src1_cols], src1[:src1_rows,:src1_cols]) + dst[src1_rows:,:] = src0[src1_rows:,:] (copy remaining rows) + OR (for col_less) dst[:,:src1_cols] = max(src0[:,:src1_cols], src1[:,:src1_cols]) + dst[:,src1_cols:] = src0[:,src1_cols:] (copy remaining cols) + - If src1_valid == dst_valid: similar logic with src1 as the full operand. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # float32 cases from pto-isa + { + "name": "f32_64x64_full", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region + "src1_vshape": (64, 64), # src1 valid region (same as dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_2x24_src1_col_less", + "dtype": np.float32, + "shape": (2, 24), + "valid_shape": (2, 24), # src0 valid region (equals dst) + "src1_vshape": (2, 8), # src1 valid region (col_less) + "dst_vshape": (2, 24), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_128x64_src1_row_less", + "dtype": np.float32, + "shape": (128, 64), + "valid_shape": (128, 64), # src0 valid region (equals dst) + "src1_vshape": (96, 64), # src1 valid region (row_less) + "dst_vshape": (128, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_95x95_full", + "dtype": np.float32, + "shape": (95, 95), + "valid_shape": (95, 95), # src0 valid region + "src1_vshape": (95, 95), # src1 valid region (same as dst) + "dst_vshape": (95, 95), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_122x123_complex", + "dtype": np.float32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region (src1 rows, src0 cols) + "eps": 1e-6, + }, + # float16 cases from pto-isa + { + "name": "f16_122x123_complex", + "dtype": np.float16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 1e-3, + }, + # int16 cases from pto-isa + { + "name": "i16_122x123_complex", + "dtype": np.int16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # int32 cases from pto-isa + { + "name": "i32_122x123_complex", + "dtype": np.int32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint16 cases from pto-isa + { + "name": "u16_122x123_complex", + "dtype": np.uint16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint32 cases from pto-isa + { + "name": "u32_122x123_complex", + "dtype": np.uint32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # int8 cases from pto-isa + { + "name": "i8_122x123_complex", + "dtype": np.int8, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint8 cases from pto-isa + { + "name": "u8_122x123_complex", + "dtype": np.uint8, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/compare.py new file mode 100644 index 000000000..283ee788a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dtype = case["dtype"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/gen_data.py new file mode 100644 index 000000000..700de5895 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/gen_data.py @@ -0,0 +1,127 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = _to_tuple(case["shape"]) + src0_valid = _to_tuple(case["valid_shape"]) + src1_valid = _to_tuple(case["src1_vshape"]) + dst_valid = _to_tuple(case["dst_vshape"]) + + rows, cols = shape + src0_vr, src0_vc = src0_valid + src1_vr, src1_vc = src1_valid + dst_vr, dst_vc = dst_valid + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + + # tpartmax semantics (based on pto-isa TPartBinOps.hpp TCopyPadOp): + # Algorithm: + # 1. dst[:] = Min (padding for max operation) + # 2. dst[0:src0_vr, 0:src0_vc] = src0[0:src0_vr, 0:src0_vc] (copy src0 to dst) + # 3. dst[0:src1_vr, 0:src1_vc] = max(dst[0:src1_vr, 0:src1_vc], src1[0:src1_vr, 0:src1_vc]) + # (apply max in src1 valid region) + + src0_eq_dst = (src0_vr == dst_vr and src0_vc == dst_vc) + src1_eq_dst = (src1_vr == dst_vr and src1_vc == dst_vc) + + if src0_eq_dst and src1_eq_dst: + # Full max: both src0 and src1 cover entire dst + golden[:dst_vr, :dst_vc] = np.maximum(input1[:dst_vr, :dst_vc], input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src0_eq_dst: + # src0 covers dst, src1 is partial + # dst = src0 (copy), then max(dst, src1) in src1 region = max(src0, src1) in src1 region, src0 in rest + golden[:src1_vr, :src1_vc] = np.maximum(input1[:src1_vr, :src1_vc], input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + if src1_vc < dst_vc: + golden[:src1_vr, src1_vc:dst_vc] = input1[:src1_vr, src1_vc:dst_vc].copy() + if src1_vr < dst_vr: + golden[src1_vr:dst_vr, :dst_vc] = input1[src1_vr:dst_vr, :dst_vc].copy() + elif src1_eq_dst: + # src1 covers dst, src0 is partial + # dst = Min, then copy src0 in src0 region, then max(dst, src1) in src1 region + golden[:src0_vr, :src0_vc] = np.maximum(input1[:src0_vr, :src0_vc], input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + if src0_vc < dst_vc: + golden[:src0_vr, src0_vc:dst_vc] = input2[:src0_vr, src0_vc:dst_vc].copy() + if src0_vr < dst_vr: + golden[src0_vr:dst_vr, :dst_vc] = input2[src0_vr:dst_vr, :dst_vc].copy() + else: + min_vr = min(src0_vr, src1_vr) + min_vc = min(src0_vc, src1_vc) + + # Region 1: [0:min_vr, 0:min_vc] - overlapping region (both src0 and src1 valid) + golden[:min_vr, :min_vc] = np.maximum(input1[:min_vr, :min_vc], input2[:min_vr, :min_vc]).astype(dtype, copy=False) + + # Region 2: [0:src0_vr, min_vc:src0_vc] if src0_vc > min_vc + if src0_vc > min_vc: + golden[:src0_vr, min_vc:src0_vc] = input1[:src0_vr, min_vc:src0_vc].copy() + + # Region 3: [min_vr:src1_vr, 0:min_vc] if src1_vr > min_vr + if src1_vr > min_vr: + golden[min_vr:src1_vr, :min_vc] = input2[min_vr:src1_vr, :min_vc].copy() + + # Region 4: [min_vr:src1_vr, min_vc:src1_vc] if src1_vr > min_vr AND src1_vc > min_vc + if src1_vr > min_vr and src1_vc > min_vc: + golden[min_vr:src1_vr, min_vc:src1_vc] = input2[min_vr:src1_vr, min_vc:src1_vc].copy() + + # Region 5: [0:min_vr, src1_vc:src0_vc] if src0_vc > src1_vc + if src0_vc > src1_vc and min_vr > 0: + # Already handled in Region 2 if rows are [0:src0_vr] + pass # Region 2 covers this + + if src1_vr > src0_vr and src0_vc > src1_vc: + # Region [src0_vr:src1_vr, src1_vc:src0_vc] = Min (neither covers) + # This is correct for tpartmax - padding value is Min + # For floats, we use -np.inf. For integers, use dtype min. + if dtype == np.float32: + min_val = np.finfo(np.float32).min + elif dtype == np.float16: + min_val = np.finfo(np.float16).min + elif dtype == np.int8: + min_val = np.iinfo(np.int8).min + elif dtype == np.uint8: + min_val = np.iinfo(np.uint8).min + elif dtype == np.int16: + min_val = np.iinfo(np.int16).min + elif dtype == np.uint16: + min_val = np.iinfo(np.uint16).min + elif dtype == np.int32: + min_val = np.iinfo(np.int32).min + elif dtype == np.uint32: + min_val = np.iinfo(np.uint32).min + else: + min_val = np.iinfo(dtype).min + golden[src0_vr:src1_vr, src1_vc:src0_vc] = min_val + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} src0_valid={src0_valid} src1_valid={src1_valid} dst_valid={dst_valid} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/launch.cpp new file mode 100644 index 000000000..98a1a76d7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/launch.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case: f32 64x64 full +extern "C" __global__ AICORE void TPARTMAX_f32_64x64_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_64x64_full(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_64x64_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 2x24 src1 col less +extern "C" __global__ AICORE void TPARTMAX_f32_2x24_src1_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_2x24_src1_col_less(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_2x24_src1_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 128x64 src1 row less +extern "C" __global__ AICORE void TPARTMAX_f32_128x64_src1_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_128x64_src1_row_less(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_128x64_src1_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 95x95 full +extern "C" __global__ AICORE void TPARTMAX_f32_95x95_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_95x95_full(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_95x95_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_f32_122x123_complex(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_122x123_complex(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_122x123_complex<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f16 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_f16_122x123_complex(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMAX_f16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMAX_f16_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case: i16 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_i16_122x123_complex(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTPARTMAX_i16_122x123_complex(int16_t *a, int16_t *b, int16_t *c, void *stream) { + TPARTMAX_i16_122x123_complex<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case: i32 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_i32_122x123_complex(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTPARTMAX_i32_122x123_complex(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TPARTMAX_i32_122x123_complex<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case: u16 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_u16_122x123_complex(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMAX_u16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMAX_u16_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case: u32 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_u32_122x123_complex(__gm__ uint32_t *a, __gm__ uint32_t *b, __gm__ uint32_t *c); + +void LaunchTPARTMAX_u32_122x123_complex(uint32_t *a, uint32_t *b, uint32_t *c, void *stream) { + TPARTMAX_u32_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint32_t *)a, (__gm__ uint32_t *)b, (__gm__ uint32_t *)c); +} + +// Case: i8 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_i8_122x123_complex(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int8_t *c); + +void LaunchTPARTMAX_i8_122x123_complex(int8_t *a, int8_t *b, int8_t *c, void *stream) { + TPARTMAX_i8_122x123_complex<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int8_t *)c); +} + +// Case: u8 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_u8_122x123_complex(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ uint8_t *c); + +void LaunchTPARTMAX_u8_122x123_complex(uint8_t *a, uint8_t *b, uint8_t *c, void *stream) { + TPARTMAX_u8_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ uint8_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/main.cpp new file mode 100644 index 000000000..c81ab0e62 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/main.cpp @@ -0,0 +1,230 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tpartmax ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPARTMAX_f32_64x64_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f32_2x24_src1_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f32_128x64_src1_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f32_95x95_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f32_122x123_complex(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMAX_i16_122x123_complex(int16_t *a, int16_t *b, int16_t *c, void *stream); +void LaunchTPARTMAX_i32_122x123_complex(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTPARTMAX_u16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMAX_u32_122x123_complex(uint32_t *a, uint32_t *b, uint32_t *c, void *stream); +void LaunchTPARTMAX_i8_122x123_complex(int8_t *a, int8_t *b, int8_t *c, void *stream); +void LaunchTPARTMAX_u8_122x123_complex(uint8_t *a, uint8_t *b, uint8_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols (valid cols) + size_t src0ValidRows; // src0 effective rows + size_t src0ValidCols; // src0 effective cols + size_t src1ValidRows; // src1 effective rows + size_t src1ValidCols; // src1 effective cols + size_t dstValidRows; // dst effective rows + size_t dstValidCols; // dst effective cols + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64_full", reinterpret_cast(LaunchTPARTMAX_f32_64x64_full), 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_2x24_src1_col_less", reinterpret_cast(LaunchTPARTMAX_f32_2x24_src1_col_less), 2, 24, 2, 24, 2, 8, 2, 24, sizeof(float)}, + {"f32_128x64_src1_row_less", reinterpret_cast(LaunchTPARTMAX_f32_128x64_src1_row_less), 128, 64,128, 64, 96, 64,128, 64, sizeof(float)}, + {"f32_95x95_full", reinterpret_cast(LaunchTPARTMAX_f32_95x95_full), 95, 95, 95, 95, 95, 95, 95, 95, sizeof(float)}, + {"f32_122x123_complex", reinterpret_cast(LaunchTPARTMAX_f32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(float)}, + {"f16_122x123_complex", reinterpret_cast(LaunchTPARTMAX_f16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint16_t)}, + {"i16_122x123_complex", reinterpret_cast(LaunchTPARTMAX_i16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int16_t)}, + {"i32_122x123_complex", reinterpret_cast(LaunchTPARTMAX_i32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int32_t)}, + {"u16_122x123_complex", reinterpret_cast(LaunchTPARTMAX_u16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint16_t)}, + {"u32_122x123_complex", reinterpret_cast(LaunchTPARTMAX_u32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint32_t)}, + {"i8_122x123_complex", reinterpret_cast(LaunchTPARTMAX_i8_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int8_t)}, + {"u8_122x123_complex", reinterpret_cast(LaunchTPARTMAX_u8_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +// Calculate aligned cols for 32-byte alignment +static size_t CalcAlignedCols(size_t cols, size_t elemSize) { + size_t totalBytes = cols * elemSize; + size_t alignedBytes = ((totalBytes + 31) / 32) * 32; + return alignedBytes / elemSize; +} + +// Helper to pad data with stride +static void PadDataWithStride(const void *src, void *dst, size_t rows, size_t cols, + size_t alignedCols, size_t elemSize) { + const char *srcPtr = static_cast(src); + char *dstPtr = static_cast(dst); + for (size_t r = 0; r < rows; ++r) { + memcpy(dstPtr + r * alignedCols * elemSize, + srcPtr + r * cols * elemSize, + cols * elemSize); + // Zero-fill padding region (optional, data will be overwritten by kernel) + memset(dstPtr + r * alignedCols * elemSize + cols * elemSize, + 0, + (alignedCols - cols) * elemSize); + } +} + +// Helper to unpad data (extract valid cols) +static void UnpadDataWithStride(const void *src, void *dst, size_t rows, size_t cols, + size_t alignedCols, size_t elemSize) { + const char *srcPtr = static_cast(src); + char *dstPtr = static_cast(dst); + for (size_t r = 0; r < rows; ++r) { + memcpy(dstPtr + r * cols * elemSize, + srcPtr + r * alignedCols * elemSize, + cols * elemSize); + } +} + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + const size_t alignedCols = CalcAlignedCols(tc.cols, tc.elemSize); + const size_t paddedSize = tc.rows * alignedCols * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, src0_valid=%zux%zu, src1_valid=%zux%zu, dst_valid=%zux%zu, alignedCols=%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.src0ValidRows, tc.src0ValidCols, + tc.src1ValidRows, tc.src1ValidCols, tc.dstValidRows, tc.dstValidCols, alignedCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *src0HostOrig = nullptr, *src1HostOrig = nullptr, *dstHostOrig = nullptr; + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + // Allocate host buffers for original data (contiguous) + aclrtMallocHost((void **)(&src0HostOrig), fileSize); + aclrtMallocHost((void **)(&src1HostOrig), fileSize); + aclrtMallocHost((void **)(&dstHostOrig), fileSize); + + // Allocate host buffers for padded data + aclrtMallocHost((void **)(&src0Host), paddedSize); + aclrtMallocHost((void **)(&src1Host), paddedSize); + aclrtMallocHost((void **)(&dstHost), paddedSize); + + // Allocate device buffers with padded size + aclrtMalloc((void **)&src0Device, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (rc == 0) { + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0HostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1HostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + } + + if (rc == 0) { + // Pad input data with stride + PadDataWithStride(src0HostOrig, src0Host, tc.rows, tc.cols, alignedCols, tc.elemSize); + PadDataWithStride(src1HostOrig, src1Host, tc.rows, tc.cols, alignedCols, tc.elemSize); + + aclrtMemcpy(src0Device, paddedSize, src0Host, paddedSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, paddedSize, src1Host, paddedSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, paddedSize, dstDevice, paddedSize, ACL_MEMCPY_DEVICE_TO_HOST); + + // Unpad output data + UnpadDataWithStride(dstHost, dstHostOrig, tc.rows, tc.cols, alignedCols, tc.elemSize); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + if (src0HostOrig != nullptr) + aclrtFreeHost(src0HostOrig); + if (src1HostOrig != nullptr) + aclrtFreeHost(src1HostOrig); + if (dstHostOrig != nullptr) + aclrtFreeHost(dstHostOrig); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tpartmax [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/tpartmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/tpartmax.pto new file mode 100644 index 000000000..7cd8e20b1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/tpartmax.pto @@ -0,0 +1,717 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tpartmax: partial elementwise max with valid region handling. +// Multiple cases with different valid_shape combinations in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case: f32_64x64_full (src0 valid 64x64, src1 valid 64x64, dst valid 64x64) + func.func @TPARTMAX_f32_64x64_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case: f32_2x24_src1_col_less (src0 valid 2x24, src1 valid 2x8, dst valid 2x24) + func.func @TPARTMAX_f32_2x24_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c24 = arith.constant 24 : index + %c48 = arith.constant 48 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + + // src0: valid region (2,24) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (2,8) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (2,24) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + return + } + + // Case: f32_128x64_src1_row_less (src0 valid 128x64, src1 valid 96x64, dst valid 128x64) + func.func @TPARTMAX_f32_128x64_src1_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + + // src0: valid region (128,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (96,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (128,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + return + } + + // Case: f32_95x95_full (src0 valid 95x95, src1 valid 95x95, dst valid 95x95) + func.func @TPARTMAX_f32_95x95_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c95 = arith.constant 95 : index + %c96 = arith.constant 96 : index + %c9120 = arith.constant 9120 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c96] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x96xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c96] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x96xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c95] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x95xf32> + + // src0: valid region (95,95) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (95,95) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (95,95) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x95x96xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x95x96xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x95x95xf32>) + return + } + + // Case: f32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_f32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x128xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x128xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x123xf32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xf32>) + return + } + + // Case: f16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_f16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x128xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x128xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x123xf16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xf16>) + return + } + + // Case: i16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_i16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x128xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x128xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x123xi16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi16>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi16>) + return + } + + // Case: i32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_i32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x128xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x128xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x123xi32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi32>) + return + } + + // Case: u16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_u16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x128xui16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x128xui16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x123xui16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui16>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui16>) + return + } + + // Case: u32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_u32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x128xui32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x128xui32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x123xui32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui32>) + return + } + + // Case: i8_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_i8_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x128xi8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x128xi8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x123xi8> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi8>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi8>) + return + } + + // Case: u8_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_u8_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x128xui8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x128xui8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x123xui8> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui8>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/CMakeLists.txt new file mode 100644 index 000000000..cfb480147 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tpartmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/cases.py new file mode 100644 index 000000000..50976fbd1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/cases.py @@ -0,0 +1,153 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tpartmin ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src0/src1/dst). + - valid_shape: (valid_rows, valid_cols) — src0 valid region (src0_eq_dst scenario). + - src1_vshape: (src1_valid_rows, src1_valid_cols) — src1 valid region. + May be smaller than dst valid region for partial min cases. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + - eps: tolerance for numpy.allclose (atol and rtol). + +tpartmin semantics: + - If src0_valid == dst_valid: dst[:src1_rows,:src1_cols] = min(src0[:src1_rows,:src1_cols], src1[:src1_rows,:src1_cols]) + dst[src1_rows:,:] = src0[src1_rows:,:] (copy remaining rows) + OR (for col_less) dst[:,:src1_cols] = min(src0[:,:src1_cols], src1[:,:src1_cols]) + dst[:,src1_cols:] = src0[:,src1_cols:] (copy remaining cols) + - If src1_valid == dst_valid: similar logic with src1 as the full operand. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # float32 cases from pto-isa + { + "name": "f32_64x64_full", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region + "src1_vshape": (64, 64), # src1 valid region (same as dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_2x24_src1_col_less", + "dtype": np.float32, + "shape": (2, 24), + "valid_shape": (2, 24), # src0 valid region (equals dst) + "src1_vshape": (2, 8), # src1 valid region (col_less) + "dst_vshape": (2, 24), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_128x64_src1_row_less", + "dtype": np.float32, + "shape": (128, 64), + "valid_shape": (128, 64), # src0 valid region (equals dst) + "src1_vshape": (96, 64), # src1 valid region (row_less) + "dst_vshape": (128, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_95x95_full", + "dtype": np.float32, + "shape": (95, 95), + "valid_shape": (95, 95), # src0 valid region + "src1_vshape": (95, 95), # src1 valid region (same as dst) + "dst_vshape": (95, 95), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_122x123_complex", + "dtype": np.float32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region (src1 rows, src0 cols) + "eps": 1e-6, + }, + # float16 cases from pto-isa + { + "name": "f16_122x123_complex", + "dtype": np.float16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 1e-3, + }, + # int16 cases from pto-isa + { + "name": "i16_122x123_complex", + "dtype": np.int16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # int32 cases from pto-isa + { + "name": "i32_122x123_complex", + "dtype": np.int32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint16 cases from pto-isa + { + "name": "u16_122x123_complex", + "dtype": np.uint16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint32 cases from pto-isa + { + "name": "u32_122x123_complex", + "dtype": np.uint32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # int8 cases from pto-isa + { + "name": "i8_122x123_complex", + "dtype": np.int8, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint8 cases from pto-isa + { + "name": "u8_122x123_complex", + "dtype": np.uint8, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/compare.py new file mode 100644 index 000000000..283ee788a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dtype = case["dtype"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/gen_data.py new file mode 100644 index 000000000..fb3766f42 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/gen_data.py @@ -0,0 +1,127 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = _to_tuple(case["shape"]) + src0_valid = _to_tuple(case["valid_shape"]) + src1_valid = _to_tuple(case["src1_vshape"]) + dst_valid = _to_tuple(case["dst_vshape"]) + + rows, cols = shape + src0_vr, src0_vc = src0_valid + src1_vr, src1_vc = src1_valid + dst_vr, dst_vc = dst_valid + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + + # tpartmin semantics (based on pto-isa TPartBinOps.hpp TCopyPadOp): + # Algorithm: + # 1. dst[:] = Max (padding for min operation) + # 2. dst[0:src0_vr, 0:src0_vc] = src0[0:src0_vr, 0:src0_vc] (copy src0 to dst) + # 3. dst[0:src1_vr, 0:src1_vc] = min(dst[0:src1_vr, 0:src1_vc], src1[0:src1_vr, 0:src1_vc]) + # (apply min in src1 valid region) + + src0_eq_dst = (src0_vr == dst_vr and src0_vc == dst_vc) + src1_eq_dst = (src1_vr == dst_vr and src1_vc == dst_vc) + + if src0_eq_dst and src1_eq_dst: + # Full min: both src0 and src1 cover entire dst + golden[:dst_vr, :dst_vc] = np.minimum(input1[:dst_vr, :dst_vc], input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src0_eq_dst: + # src0 covers dst, src1 is partial + # dst = src0 (copy), then min(dst, src1) in src1 region = min(src0, src1) in src1 region, src0 in rest + golden[:src1_vr, :src1_vc] = np.minimum(input1[:src1_vr, :src1_vc], input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + if src1_vc < dst_vc: + golden[:src1_vr, src1_vc:dst_vc] = input1[:src1_vr, src1_vc:dst_vc].copy() + if src1_vr < dst_vr: + golden[src1_vr:dst_vr, :dst_vc] = input1[src1_vr:dst_vr, :dst_vc].copy() + elif src1_eq_dst: + # src1 covers dst, src0 is partial + # dst = Max, then copy src0 in src0 region, then min(dst, src1) in src1 region + golden[:src0_vr, :src0_vc] = np.minimum(input1[:src0_vr, :src0_vc], input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + if src0_vc < dst_vc: + golden[:src0_vr, src0_vc:dst_vc] = input2[:src0_vr, src0_vc:dst_vc].copy() + if src0_vr < dst_vr: + golden[src0_vr:dst_vr, :dst_vc] = input2[src0_vr:dst_vr, :dst_vc].copy() + else: + min_vr = min(src0_vr, src1_vr) + min_vc = min(src0_vc, src1_vc) + + # Region 1: [0:min_vr, 0:min_vc] - overlapping region (both src0 and src1 valid) + golden[:min_vr, :min_vc] = np.minimum(input1[:min_vr, :min_vc], input2[:min_vr, :min_vc]).astype(dtype, copy=False) + + # Region 2: [0:src0_vr, min_vc:src0_vc] if src0_vc > min_vc + if src0_vc > min_vc: + golden[:src0_vr, min_vc:src0_vc] = input1[:src0_vr, min_vc:src0_vc].copy() + + # Region 3: [min_vr:src1_vr, 0:min_vc] if src1_vr > min_vr + if src1_vr > min_vr: + golden[min_vr:src1_vr, :min_vc] = input2[min_vr:src1_vr, :min_vc].copy() + + # Region 4: [min_vr:src1_vr, min_vc:src1_vc] if src1_vr > min_vr AND src1_vc > min_vc + if src1_vr > min_vr and src1_vc > min_vc: + golden[min_vr:src1_vr, min_vc:src1_vc] = input2[min_vr:src1_vr, min_vc:src1_vc].copy() + + # Region 5: [0:min_vr, src1_vc:src0_vc] if src0_vc > src1_vc + if src0_vc > src1_vc and min_vr > 0: + # Already handled in Region 2 if rows are [0:src0_vr] + pass # Region 2 covers this + + if src1_vr > src0_vr and src0_vc > src1_vc: + # Region [src0_vr:src1_vr, src1_vc:src0_vc] = Max (neither covers) + # This is correct for tpartmin - padding value is Max + # For floats, we use np.inf. For integers, use dtype max. + if dtype == np.float32: + max_val = np.finfo(np.float32).max + elif dtype == np.float16: + max_val = np.finfo(np.float16).max + elif dtype == np.int8: + max_val = np.iinfo(np.int8).max + elif dtype == np.uint8: + max_val = np.iinfo(np.uint8).max + elif dtype == np.int16: + max_val = np.iinfo(np.int16).max + elif dtype == np.uint16: + max_val = np.iinfo(np.uint16).max + elif dtype == np.int32: + max_val = np.iinfo(np.int32).max + elif dtype == np.uint32: + max_val = np.iinfo(np.uint32).max + else: + max_val = np.iinfo(dtype).max + golden[src0_vr:src1_vr, src1_vc:src0_vc] = max_val + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} src0_valid={src0_valid} src1_valid={src1_valid} dst_valid={dst_valid} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/launch.cpp new file mode 100644 index 000000000..4fdee00b6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/launch.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case: f32 64x64 full +extern "C" __global__ AICORE void TPARTMIN_f32_64x64_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_64x64_full(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_64x64_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 2x24 src1 col less +extern "C" __global__ AICORE void TPARTMIN_f32_2x24_src1_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_2x24_src1_col_less(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_2x24_src1_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 128x64 src1 row less +extern "C" __global__ AICORE void TPARTMIN_f32_128x64_src1_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_128x64_src1_row_less(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_128x64_src1_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 95x95 full +extern "C" __global__ AICORE void TPARTMIN_f32_95x95_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_95x95_full(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_95x95_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_f32_122x123_complex(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_122x123_complex(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_122x123_complex<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f16 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_f16_122x123_complex(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMIN_f16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMIN_f16_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case: i16 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_i16_122x123_complex(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTPARTMIN_i16_122x123_complex(int16_t *a, int16_t *b, int16_t *c, void *stream) { + TPARTMIN_i16_122x123_complex<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case: i32 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_i32_122x123_complex(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTPARTMIN_i32_122x123_complex(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TPARTMIN_i32_122x123_complex<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case: u16 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_u16_122x123_complex(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMIN_u16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMIN_u16_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case: u32 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_u32_122x123_complex(__gm__ uint32_t *a, __gm__ uint32_t *b, __gm__ uint32_t *c); + +void LaunchTPARTMIN_u32_122x123_complex(uint32_t *a, uint32_t *b, uint32_t *c, void *stream) { + TPARTMIN_u32_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint32_t *)a, (__gm__ uint32_t *)b, (__gm__ uint32_t *)c); +} + +// Case: i8 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_i8_122x123_complex(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int8_t *c); + +void LaunchTPARTMIN_i8_122x123_complex(int8_t *a, int8_t *b, int8_t *c, void *stream) { + TPARTMIN_i8_122x123_complex<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int8_t *)c); +} + +// Case: u8 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_u8_122x123_complex(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ uint8_t *c); + +void LaunchTPARTMIN_u8_122x123_complex(uint8_t *a, uint8_t *b, uint8_t *c, void *stream) { + TPARTMIN_u8_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ uint8_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/main.cpp new file mode 100644 index 000000000..5251f0149 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/main.cpp @@ -0,0 +1,230 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tpartmin ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPARTMIN_f32_64x64_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f32_2x24_src1_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f32_128x64_src1_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f32_95x95_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f32_122x123_complex(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMIN_i16_122x123_complex(int16_t *a, int16_t *b, int16_t *c, void *stream); +void LaunchTPARTMIN_i32_122x123_complex(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTPARTMIN_u16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMIN_u32_122x123_complex(uint32_t *a, uint32_t *b, uint32_t *c, void *stream); +void LaunchTPARTMIN_i8_122x123_complex(int8_t *a, int8_t *b, int8_t *c, void *stream); +void LaunchTPARTMIN_u8_122x123_complex(uint8_t *a, uint8_t *b, uint8_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols (valid cols) + size_t src0ValidRows; // src0 effective rows + size_t src0ValidCols; // src0 effective cols + size_t src1ValidRows; // src1 effective rows + size_t src1ValidCols; // src1 effective cols + size_t dstValidRows; // dst effective rows + size_t dstValidCols; // dst effective cols + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64_full", reinterpret_cast(LaunchTPARTMIN_f32_64x64_full), 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_2x24_src1_col_less", reinterpret_cast(LaunchTPARTMIN_f32_2x24_src1_col_less), 2, 24, 2, 24, 2, 8, 2, 24, sizeof(float)}, + {"f32_128x64_src1_row_less", reinterpret_cast(LaunchTPARTMIN_f32_128x64_src1_row_less), 128, 64,128, 64, 96, 64,128, 64, sizeof(float)}, + {"f32_95x95_full", reinterpret_cast(LaunchTPARTMIN_f32_95x95_full), 95, 95, 95, 95, 95, 95, 95, 95, sizeof(float)}, + {"f32_122x123_complex", reinterpret_cast(LaunchTPARTMIN_f32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(float)}, + {"f16_122x123_complex", reinterpret_cast(LaunchTPARTMIN_f16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint16_t)}, + {"i16_122x123_complex", reinterpret_cast(LaunchTPARTMIN_i16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int16_t)}, + {"i32_122x123_complex", reinterpret_cast(LaunchTPARTMIN_i32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int32_t)}, + {"u16_122x123_complex", reinterpret_cast(LaunchTPARTMIN_u16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint16_t)}, + {"u32_122x123_complex", reinterpret_cast(LaunchTPARTMIN_u32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint32_t)}, + {"i8_122x123_complex", reinterpret_cast(LaunchTPARTMIN_i8_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int8_t)}, + {"u8_122x123_complex", reinterpret_cast(LaunchTPARTMIN_u8_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +// Calculate aligned cols for 32-byte alignment +static size_t CalcAlignedCols(size_t cols, size_t elemSize) { + size_t totalBytes = cols * elemSize; + size_t alignedBytes = ((totalBytes + 31) / 32) * 32; + return alignedBytes / elemSize; +} + +// Helper to pad data with stride +static void PadDataWithStride(const void *src, void *dst, size_t rows, size_t cols, + size_t alignedCols, size_t elemSize) { + const char *srcPtr = static_cast(src); + char *dstPtr = static_cast(dst); + for (size_t r = 0; r < rows; ++r) { + memcpy(dstPtr + r * alignedCols * elemSize, + srcPtr + r * cols * elemSize, + cols * elemSize); + // Zero-fill padding region (optional, data will be overwritten by kernel) + memset(dstPtr + r * alignedCols * elemSize + cols * elemSize, + 0, + (alignedCols - cols) * elemSize); + } +} + +// Helper to unpad data (extract valid cols) +static void UnpadDataWithStride(const void *src, void *dst, size_t rows, size_t cols, + size_t alignedCols, size_t elemSize) { + const char *srcPtr = static_cast(src); + char *dstPtr = static_cast(dst); + for (size_t r = 0; r < rows; ++r) { + memcpy(dstPtr + r * cols * elemSize, + srcPtr + r * alignedCols * elemSize, + cols * elemSize); + } +} + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + const size_t alignedCols = CalcAlignedCols(tc.cols, tc.elemSize); + const size_t paddedSize = tc.rows * alignedCols * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, src0_valid=%zux%zu, src1_valid=%zux%zu, dst_valid=%zux%zu, alignedCols=%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.src0ValidRows, tc.src0ValidCols, + tc.src1ValidRows, tc.src1ValidCols, tc.dstValidRows, tc.dstValidCols, alignedCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *src0HostOrig = nullptr, *src1HostOrig = nullptr, *dstHostOrig = nullptr; + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + // Allocate host buffers for original data (contiguous) + aclrtMallocHost((void **)(&src0HostOrig), fileSize); + aclrtMallocHost((void **)(&src1HostOrig), fileSize); + aclrtMallocHost((void **)(&dstHostOrig), fileSize); + + // Allocate host buffers for padded data + aclrtMallocHost((void **)(&src0Host), paddedSize); + aclrtMallocHost((void **)(&src1Host), paddedSize); + aclrtMallocHost((void **)(&dstHost), paddedSize); + + // Allocate device buffers with padded size + aclrtMalloc((void **)&src0Device, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (rc == 0) { + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0HostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1HostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + } + + if (rc == 0) { + // Pad input data with stride + PadDataWithStride(src0HostOrig, src0Host, tc.rows, tc.cols, alignedCols, tc.elemSize); + PadDataWithStride(src1HostOrig, src1Host, tc.rows, tc.cols, alignedCols, tc.elemSize); + + aclrtMemcpy(src0Device, paddedSize, src0Host, paddedSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, paddedSize, src1Host, paddedSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, paddedSize, dstDevice, paddedSize, ACL_MEMCPY_DEVICE_TO_HOST); + + // Unpad output data + UnpadDataWithStride(dstHost, dstHostOrig, tc.rows, tc.cols, alignedCols, tc.elemSize); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + if (src0HostOrig != nullptr) + aclrtFreeHost(src0HostOrig); + if (src1HostOrig != nullptr) + aclrtFreeHost(src1HostOrig); + if (dstHostOrig != nullptr) + aclrtFreeHost(dstHostOrig); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tpartmin [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/tpartmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/tpartmin.pto new file mode 100644 index 000000000..f7583c72c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/tpartmin.pto @@ -0,0 +1,718 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use the file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tpartmin: partial elementwise min with valid region handling. +// Multiple cases with different valid_shape combinations in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case: f32_64x64_full (src0 valid 64x64, src1 valid 64x64, dst valid 64x64) + func.func @TPARTMIN_f32_64x64_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case: f32_2x24_src1_col_less (src0 valid 2x24, src1 valid 2x8, dst valid 2x24) + func.func @TPARTMIN_f32_2x24_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c24 = arith.constant 24 : index + %c48 = arith.constant 48 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + + // src0: valid region (2,24) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (2,8) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (2,24) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + return + } + + // Case: f32_128x64_src1_row_less (src0 valid 128x64, src1 valid 96x64, dst valid 128x64) + func.func @TPARTMIN_f32_128x64_src1_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + + // src0: valid region (128,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (96,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (128,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + return + } + + // Case: f32_95x95_full (src0 valid 95x95, src1 valid 95x95, dst valid 95x95) + func.func @TPARTMIN_f32_95x95_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c95 = arith.constant 95 : index + %c96 = arith.constant 96 : index + %c9120 = arith.constant 9120 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c96] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x96xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c96] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x96xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c95] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x95xf32> + + // src0: valid region (95,95) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (95,95) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (95,95) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x95x96xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x95x96xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x95x95xf32>) + return + } + + // Case: f32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_f32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x128xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x128xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x123xf32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xf32>) + return + } + + // Case: f16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_f16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x128xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x128xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x123xf16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xf16>) + return + } + + // Case: i16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_i16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x128xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x128xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x123xi16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi16>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi16>) + return + } + + // Case: i32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_i32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x128xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x128xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x123xi32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi32>) + return + } + + // Case: u16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_u16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x128xui16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x128xui16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x123xui16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui16>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui16>) + return + } + + // Case: u32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_u32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x128xui32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x128xui32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x123xui32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui32>) + return + } + + // Case: i8_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_i8_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x128xi8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x128xi8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x123xi8> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi8>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi8>) + return + } + + // Case: u8_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_u8_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x128xui8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x128xui8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x123xui8> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui8>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui8>) + return + } +} + From 31a863ab8614b39161eac1941d64a38ffe29768b Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 24 Apr 2026 16:20:23 +0800 Subject: [PATCH 160/192] fix(vpto): legalize integer cast ops after VPTO lowering (#240) --- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 20 ++++--- tilelang-dsl/python/tilelang_dsl/semantic.py | 2 +- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 63 ++++++++++++++++++-- 3 files changed, 70 insertions(+), 15 deletions(-) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 76821b9ae..679a761b1 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -395,21 +395,24 @@ static FailureOr reinterpretPointerToAddrSpace(Operation *anchor, } static FailureOr normalizeVdupScalarOperand(OpBuilder &builder, Location loc, - pto::VdupOp op) { - Value input = op.getInput(); + Value input, + Type resultType) { auto intType = dyn_cast(input.getType()); if (!intType || intType.getWidth() != 8) return input; - Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); + Type resultElemType = getElementTypeFromVectorLike(resultType); std::string resultElemFragment = getElementTypeFragment(resultElemType); if (resultElemFragment != "s8" && resultElemFragment != "u8") return input; - Type i16Type = builder.getIntegerType(16); - if (resultElemFragment == "u8") - return builder.create(loc, i16Type, input).getResult(); - return builder.create(loc, i16Type, input).getResult(); + if (intType.isSignless()) + return input; + + Type signlessType = builder.getIntegerType(intType.getWidth()); + return builder + .create(loc, TypeRange{signlessType}, input) + .getResult(0); } static Value normalizeByteScalarOperandForHivmCall(OpBuilder &builder, Location loc, @@ -4133,7 +4136,8 @@ class LowerVdupOpPattern final : public OpConversionPattern { "unexpected scalar-input vdup type"); } FailureOr normalizedScalar = - normalizeVdupScalarOperand(rewriter, op.getLoc(), op); + normalizeVdupScalarOperand(rewriter, op.getLoc(), adaptor.getInput(), + op.getResult().getType()); if (failed(normalizedScalar)) return rewriter.notifyMatchFailure(op, "failed to normalize scalar vdup input"); diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 43a821459..bda55f221 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -1076,7 +1076,7 @@ def _stmt_can_continue_inferred_vecscope_run( return False return self._frontend_stmt_can_live_in_inferred_vecscope( stmt - ) or self._frontend_stmt_is_neutral_vecscope_stmt(stmt) + ) or self._frontend_stmt_is_scalar_vecscope_stmt(stmt) def _stmt_allows_inferred_vecscope(self, allow_inferred_vecscope: bool) -> bool: if self._has_explicit_vecscope: diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index 736ada783..d7535e6fd 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -5455,7 +5455,7 @@ def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): self.assertRegex(cols_dynamic_text, r"scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") self.assertRegex(cols_dynamic_text, r"scf\.for %col_\d+ = %c0 to %valid_cols_\d+ step %c128") - def test_advanced_mode_scalar_boundaries_split_inferred_vecscope_runs(self) -> None: + def test_advanced_mode_scalar_assignments_stay_inside_inferred_vecscope_runs(self) -> None: @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) def kernel(src: pto.Tile, dst: pto.Tile): dtype = src.element_type @@ -5475,13 +5475,17 @@ def kernel(src: pto.Tile, dst: pto.Tile): semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] - self.assertEqual(len(vecscope_stmts), 2) + self.assertEqual(len(vecscope_stmts), 1) text = specialized.mlir_text() - self.assertEqual(text.count("pto.vecscope {"), 2) - self.assertLess(text.index("pto.vecscope {"), text.index("%boundary_")) - self.assertLess(text.index("%boundary_"), text.index("return")) - self.assertLess(text.index("%boundary_"), text.rindex("pto.vecscope {")) + self.assertEqual(text.count("pto.vecscope {"), 1) + boundary_index = text.index("%boundary_") + first_vsts = text.index("pto.vsts") + second_vsts = text.rindex("pto.vsts") + self.assertLess(text.index("pto.vecscope {"), boundary_index) + self.assertLess(first_vsts, boundary_index) + self.assertLess(boundary_index, second_vsts) + self.assertLess(boundary_index, text.index("return")) def test_explicit_vecscope_is_supported_in_stable_mode(self) -> None: @pto.vkernel(op="explicit_vecscope_stable_unique", dtypes=[(pto.f32, pto.f32)]) @@ -6248,6 +6252,53 @@ def kernel(src: pto.Tile, dst: pto.Tile): self.assertIn("pto.vselr", text) self.assertIn("pto.vsts", text) + def test_inferred_vecscope_keeps_scalar_get_lanes_between_vector_def_and_use(self) -> None: + @pto.vkernel(op="issue_240_vecscope", dtypes=[(pto.si8, pto.i32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + b8_mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + v_zero = pto.vdup(pto.ui8(0), b8_mask) + lanes_i32 = pto.get_lanes(pto.i32) + lanes_i16 = pto.get_lanes(pto.i16) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes_i16): + mask_b16_cur, remained = pto.make_mask(pto.i16, remained) + mask_b16_next, remained2 = pto.make_mask(pto.i16, remained) + mask_b32_cur = pto.punpack(mask_b16_cur, pto.PredicatePart.LOWER) + mask_b32_next = pto.punpack(mask_b16_next, pto.PredicatePart.LOWER) + vec_si8 = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B8) + vec_ui8 = pto.vbitcast(vec_si8, pto.ui8) + vec_ui8_lo, vec_ui8_hi = pto.vintlv(vec_ui8, v_zero) + vec_si8_lo = pto.vbitcast(vec_ui8_lo, pto.si8) + vec_si8_hi = pto.vbitcast(vec_ui8_hi, pto.si8) + out_lo = pto.vcvt(vec_si8_lo, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) + out_hi = pto.vcvt(vec_si8_hi, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) + pto.vsts(out_lo, dst[row, col:], mask_b32_cur, dist=pto.VStoreDist.NORM_B32) + pto.vsts( + out_hi, + dst[row, col + lanes_i32:], + mask_b32_next, + dist=pto.VStoreDist.NORM_B32, + ) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertIn(" = arith.constant 64 : index", text) + self.assertIn(" = arith.constant 128 : index", text) + self.assertIn(" = pto.vdup ", text) + self.assertIn(" = pto.vintlv ", text) + def test_punpack_widens_b16_mask_for_norm_b32_store_in_advanced_mode(self) -> None: @pto.vkernel(op="punpack_widen_b16_to_b32_unique", dtypes=[(pto.si8, pto.i32)], advanced=True) def kernel(src: pto.Tile, dst: pto.Tile): From 980e6f5bcc2f086543955d83af556209322391af Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 25 Apr 2026 09:30:17 +0800 Subject: [PATCH 161/192] fix(vpto): restore byte-scalar vdup HIVM ABI --- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 679a761b1..283f9486c 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -4141,7 +4141,9 @@ class LowerVdupOpPattern final : public OpConversionPattern { if (failed(normalizedScalar)) return rewriter.notifyMatchFailure(op, "failed to normalize scalar vdup input"); - callArgs.push_back(*normalizedScalar); + Value scalarForCall = normalizeByteScalarOperandForHivmCall( + rewriter, op.getLoc(), *normalizedScalar, scalarType); + callArgs.push_back(scalarForCall); } callArgs.push_back(mask); From 677395ee59b11b7348ab21d8c36c8236aa4d315f Mon Sep 17 00:00:00 2001 From: mly <978226558@qq.com> Date: Sat, 25 Apr 2026 12:58:46 +0800 Subject: [PATCH 162/192] feat: cancel ci on update (#260) Co-authored-by: mouliangyu --- .github/workflows/ci.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 299c9c7b4..e907aed69 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,7 +1,10 @@ name: CI +concurrency: + group: ci-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + on: - push: pull_request: # Nightly remote-board validation (GitHub cron is UTC). # 22:00 CST (UTC+8) == 14:00 UTC. @@ -67,7 +70,7 @@ permissions: jobs: license-header-check: - if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' }} + if: ${{ github.event_name == 'pull_request' }} runs-on: ubuntu-22.04 steps: - name: Checkout From 2d289c800590fcf9b215a7162483e245783bf00b Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 25 Apr 2026 09:04:04 +0800 Subject: [PATCH 163/192] fix(vpto): normalize scalar memref access before llvm emit (#247) --- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 12 +- lib/PTO/Transforms/VPTOPtrNormalize.cpp | 113 ++++++++++++++++++ .../issue_247_load_scalar_ptr_normalize.pto | 26 ++++ 3 files changed, 148 insertions(+), 3 deletions(-) create mode 100644 test/basic/issue_247_load_scalar_ptr_normalize.pto diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 283f9486c..cadcd96d8 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -6576,6 +6576,12 @@ class ConvertPtoLoadScalarOp final if (!llvmPtrType) return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + Type convertedValueType = + getTypeConverter()->convertType(op.getValue().getType()); + if (!convertedValueType) + return rewriter.notifyMatchFailure(op, + "could not convert load_scalar result type"); + Value offset = adaptor.getOffset(); if (offset.getType().isIndex()) offset = rewriter.create(op.getLoc(), @@ -6584,7 +6590,7 @@ class ConvertPtoLoadScalarOp final Value elemPtr = adaptor.getPtr(); if (!matchPattern(offset, m_Zero())) { elemPtr = rewriter.create(op.getLoc(), llvmPtrType, - op.getValue().getType(), adaptor.getPtr(), + convertedValueType, adaptor.getPtr(), ValueRange{offset}); } @@ -6602,8 +6608,8 @@ class ConvertPtoLoadScalarOp final }; rewriter.replaceOpWithNewOp( - op, op.getValue().getType(), elemPtr, - getNaturalAlignment(op.getValue().getType())); + op, convertedValueType, elemPtr, + getNaturalAlignment(convertedValueType)); return success(); } }; diff --git a/lib/PTO/Transforms/VPTOPtrNormalize.cpp b/lib/PTO/Transforms/VPTOPtrNormalize.cpp index cba0bf027..df6e0613d 100644 --- a/lib/PTO/Transforms/VPTOPtrNormalize.cpp +++ b/lib/PTO/Transforms/VPTOPtrNormalize.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" @@ -154,6 +162,64 @@ static Value materializeSubviewInputPtr(Value source, PatternRewriter &rewriter, return rewriter.create(loc, ptrType, source); } +static Value materializeScalarAccessPtr(Value source, PatternRewriter &rewriter, + Location loc) { + if (!source) + return {}; + if (isa(source.getType())) + return source; + + if (auto cast = source.getDefiningOp()) { + if (cast->getNumOperands() != 1 || cast->getNumResults() != 1) + return {}; + Value input = cast.getOperands().front(); + if (isa(input.getType())) + return input; + return materializeScalarAccessPtr(input, rewriter, loc); + } + + if (auto cast = source.getDefiningOp()) + return materializeScalarAccessPtr(cast.getSource(), rewriter, loc); + + if (auto subview = source.getDefiningOp()) { + if (!needsSubviewPtrConversion(subview)) + return {}; + + Value basePtr = + materializeScalarAccessPtr(subview.getSource(), rewriter, loc); + if (!basePtr) + return {}; + + Value offset; + if (failed(computeSubviewElementOffset(subview, rewriter, offset))) + return {}; + + auto ptrType = dyn_cast(convertSubviewResultType(source.getType())); + if (!ptrType) + return {}; + if (basePtr.getType() != ptrType) + basePtr = rewriter.create(loc, ptrType, basePtr); + return rewriter.create(loc, ptrType, basePtr, offset); + } + + if (auto bind = source.getDefiningOp()) + return materializeScalarAccessPtr(bind.getSource(), rewriter, loc); + + if (auto pointerCast = source.getDefiningOp()) { + if (pointerCast.getAddrs().empty()) + return {}; + Value addr = pointerCast.getAddrs().front(); + if (isa(addr.getType())) + return addr; + return materializeScalarAccessPtr(addr, rewriter, loc); + } + + // Restrict normalization to memref views that already sit on top of a ptr-like + // boundary bridge. Materializing fresh memref -> ptr casts here would leave + // illegal pto.castptr(memref) behind in this pass. + return {}; +} + struct ConvertTileBufAddrToPtrPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -320,6 +386,47 @@ struct ConvertVstsSubviewOperandPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = materializeScalarAccessPtr(adaptor.getPtr(), rewriter, op.getLoc()); + if (!ptr) + return rewriter.notifyMatchFailure(op, + "failed to materialize load_scalar ptr"); + if (!isa(ptr.getType())) + return rewriter.notifyMatchFailure(op, "expected ptr-form load_scalar input"); + + rewriter.replaceOpWithNewOp(op, op.getValue().getType(), + ptr, adaptor.getOffset()); + return success(); + } +}; + +struct ConvertStoreScalarOperandToPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = materializeScalarAccessPtr(adaptor.getPtr(), rewriter, op.getLoc()); + if (!ptr) + return rewriter.notifyMatchFailure(op, + "failed to materialize store_scalar ptr"); + if (!isa(ptr.getType())) + return rewriter.notifyMatchFailure(op, "expected ptr-form store_scalar input"); + + rewriter.replaceOpWithNewOp(op, ptr, + adaptor.getOffset(), + adaptor.getValue()); + return success(); + } +}; + struct ConvertPtrNormalizeUnrealizedCastOp final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -403,6 +510,10 @@ struct VPTOPtrNormalizePass target.addDynamicallyLegalOp([](pto::VstsOp op) { return isa(op.getDestination().getType()); }); + target.addDynamicallyLegalOp( + [](pto::LoadScalarOp op) { return isa(op.getPtr().getType()); }); + target.addDynamicallyLegalOp( + [](pto::StoreScalarOp op) { return isa(op.getPtr().getType()); }); target.addDynamicallyLegalOp( [](memref::SubViewOp op) { return !needsSubviewPtrConversion(op); }); @@ -416,6 +527,8 @@ struct VPTOPtrNormalizePass ConvertBindTileToPtrPattern, ConvertSubviewToAddPtrPattern, ConvertVldsSubviewOperandPattern, ConvertVstsSubviewOperandPattern, + ConvertLoadScalarOperandToPtrPattern, + ConvertStoreScalarOperandToPtrPattern, ConvertPtrNormalizeUnrealizedCastOp>( typeConverter, context); diff --git a/test/basic/issue_247_load_scalar_ptr_normalize.pto b/test/basic/issue_247_load_scalar_ptr_normalize.pto new file mode 100644 index 000000000..b5e0b8fc8 --- /dev/null +++ b/test/basic/issue_247_load_scalar_ptr_normalize.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-ptr-normalize %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=NORMALIZE +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --vpto-emit-hivm-llvm %s -o /dev/null + +module attributes {pto.target_arch = "a5"} { + func.func @load_scalar_ui32(%ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %v = pto.load_scalar %ptr[%c0] : !pto.ptr -> ui32 + pto.store_scalar %v, %ptr[%c0] : !pto.ptr, ui32 + return + } +} + +// NORMALIZE: // -----// IR Dump After +// NORMALIZE-SAME: (vpto-ptr-normalize) +// NORMALIZE-LABEL: func.func @load_scalar_ui32(%arg0: !pto.ptr) { +// NORMALIZE: %[[C0:.*]] = arith.constant 0 : index +// NORMALIZE: %[[VAL:.*]] = pto.load_scalar %arg0[%[[C0]]] : !pto.ptr -> ui32 +// NORMALIZE: pto.store_scalar %[[VAL]], %arg0[%[[C0]]] : !pto.ptr, ui32 From 43174771f7d2fb887e663185adc1a48de14b3714 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sat, 25 Apr 2026 15:13:27 +0800 Subject: [PATCH 164/192] feat: disable vpto ci --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e907aed69..abe08cbd5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -437,6 +437,7 @@ jobs: echo "PTOAS_BIN=${GITHUB_WORKSPACE}/build/tools/ptoas/ptoas" >> "${GITHUB_ENV}" - name: Run VPTO SIM validation + if: ${{ false }} shell: bash run: | set -euo pipefail @@ -455,7 +456,7 @@ jobs: mkdir -p "${TILELANG_DSL_WORKSPACE}" ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ PTOAS_BIN="${PTOAS_BIN}" \ - bash test/tilelang_st/script/run_ci.sh -r sim -v a5 --jobs 32 \ + bash test/tilelang_st/script/run_ci.sh -r sim -v a5 --jobs 64 \ 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/run_ci.log" - name: Upload TileLang DSL logs From 9c09a83adf21e498a9d499b751c052e3f51ed38e Mon Sep 17 00:00:00 2001 From: Happybot Date: Sat, 25 Apr 2026 14:39:28 +0800 Subject: [PATCH 165/192] docs(vpto): add Cluster Programming Model section to Part I Adds a new '### Cluster Programming Model' section to the Architecture Overview (Part I) covering: - A5 cluster topology: 1 Cube (AIC) + 2 Vector (AIV0/AIV1) blocks, independent programs per SU with their own issue queues - Intra-cluster sync primitives: pto.set_flag/wait_flag (intra-core pipeline) vs pto.set_intra_block/wait_intra_core (inter-block IPC); clarifies that set_cross_core is multi-cluster scope only - C->V data path: fixpipe (L0C->UB) with NZ->ND layout conversion and dual-destination mode for 1:2 tile split (Split-M / Split-N) - V->C data path: TMOV ub2l1 (UB->L1) with ND->NZ layout conversion and 2:1 sub-tile assembly into full NZ Mat tile for MMAD - GM-staged fallback path - Programming model note: ISA exposes primitives directly; higher-level FIFO abstractions are software libraries, not part of the ISA --- docs/vpto-spec.md | 99 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 3aa56b275..e9195d21b 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -287,6 +287,105 @@ pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. +### Cluster Programming Model + +#### Overview + +An A5 cluster contains one **Cube block** (AIC) and two **Vector blocks** (AIV0, AIV1). Each +block runs an **independent program** under its own Scalar Unit (SU), with its own issue queues: + +| Block | Issue Queues | +|---|---| +| Cube (AIC) | MTE2, MTE1, CUBE, FIXP | +| Vector (AIV) | MTE2, VEC, MTE3 | + +There is no implicit synchronization between blocks. All coordination between the Cube and Vector +programs is **explicit**, via the primitives described below. + +#### Intra-Cluster Synchronization + +Within a cluster, the PTO micro ISA provides two levels of synchronization: + +**Intra-core pipeline sync** (`pto.set_flag` / `pto.wait_flag`): coordinates the asynchronous +pipelines *within a single block* — for example, ensuring MTE2 completes a GM→UB load before +the VEC pipeline begins computation. This does not cross block boundaries. + +**Inter-block sync** (`pto.set_intra_block` / `pto.wait_intra_core`): coordinates between the +Cube block and a Vector block within the same cluster. The sender specifies which **local +pipeline** commits the signal, ensuring the preceding operation on that pipeline has completed +before the signal is issued. The receiver specifies which **local pipeline** should stall until +the signal arrives. This is the fundamental IPC primitive for Cube–Vector cooperation on A5. + +> **Note:** `pto.set_cross_core` / `pto.wait_cross_core` operate at **multi-cluster** scope and +> are not used for intra-cluster communication. + +#### Intra-Cluster Data Paths + +A5 provides dedicated on-chip data paths between the Cube and Vector blocks, bypassing Global +Memory entirely. These are the **recommended high-performance paths** for intra-cluster tile +exchange. + +##### C→V: Cube L0C → Vector UB (fixpipe) + +The **fixpipe** instruction transfers data directly from Cube's L0C buffer to a Vector block's UB. +Because Cube natively produces results in **NZ fractal layout** and Vector operates on **ND +(row-major) layout**, fixpipe performs the layout conversion in hardware: + +``` +Cube L0C (NZ layout) ──[fixpipe, NZ2ND]──▶ Vector UB (ND layout) +``` + +Fixpipe supports a **dual-destination mode**: a single transfer can write to *both* AIV0's UB and +AIV1's UB simultaneously, with the tile split in hardware along either the row axis +(`DualModeSplitM`) or the column axis (`DualModeSplitN`): + +| Split | AIV0 receives | AIV1 receives | +|---|---|---| +| Split-M (rows) | Upper `[M/2, N]` in ND | Lower `[M/2, N]` in ND | +| Split-N (cols) | Left `[M, N/2]` in ND | Right `[M, N/2]` in ND | + +This 1→2 broadcast with in-hardware tile split is the architectural basis for 1:2 +Cube-to-Vector tile distribution. + +##### V→C: Vector UB → Cube L1 (TMOV ub2l1) + +The reverse path uses `TMOV ub2l1` to transfer data from a Vector block's UB into Cube's L1 +buffer. A key architectural constraint: Cube's L1 stores tiles in **NZ fractal layout** (e.g. +`K1M1M0K0` — for fp16: `K0=16`, `M0=16`) so they can be loaded into L0A/L0B for MMAD +computation. Since Vector produces tiles in **ND layout**, the layout conversion from ND to NZ +must be applied as part of the V→C transfer: + +``` +Vector UB (ND layout) ──[TMOV ub2l1, ND→NZ]──▶ Cube L1 (NZ K1M1M0K0) +``` + +For 1:2 mode, both AIV0 and AIV1 each transfer a sub-tile into Cube's L1. The two sub-tiles are +assembled into a single contiguous NZ Mat tile in L1, ready for use as a LeftTile or RightTile +input to MMAD: + +| Split | AIV0 writes to L1 | AIV1 writes to L1 | Assembled in L1 | +|---|---|---|---| +| Split-M (rows) | `[K/2, N]` NZ at base | `[K/2, N]` NZ at offset | Full `[K, N]` NZ Mat tile | +| Split-N (cols) | `[K, N/2]` NZ at base | `[K, N/2]` NZ at offset | Full `[K, N]` NZ Mat tile | + +##### Fallback: GM-Staged Transfer + +When the local data path is not applicable, data can be exchanged via a **Global Memory staging +buffer**: the producer DMAs data to GM, and the consumer DMAs from GM. This path incurs off-chip +bandwidth cost and higher latency, but serves as a general fallback. + +#### Programming Model + +The common pattern for Cube–Vector co-programming is a **software pipeline**: the Cube and Vector +programs run a coordinated loop where each iteration the Cube produces a tile and the Vector +consumes it (or vice versa), with explicit `pto.set_intra_block` / `pto.wait_intra_core` +handshakes at each step to maintain correct data ordering. + +The PTO micro ISA exposes all the hardware primitives above directly. Higher-level constructs +that simplify this pattern (such as in-order FIFO abstractions) can be implemented as software +libraries on top of these primitives; they are not part of the ISA itself. + + ### Scope This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. From 55f192120ce491e3003bd135b85c4483f53dedb6 Mon Sep 17 00:00:00 2001 From: Happybot Date: Sat, 25 Apr 2026 15:35:29 +0800 Subject: [PATCH 166/192] docs(vpto): add ASCII cluster topology diagram to Cluster Programming Model section Shows A5 cluster layout with: - Cube core (AIC): SU + MTE2/MTE1/CUBE/FIXP issue queues - Vector 0 (AIV0, subblock_id=0): SU + MTE2/VEC/MTE3 issue queues - Vector 1 (AIV1, subblock_id=1): SU + MTE2/VEC/MTE3 issue queues - SC (System Controller): 32 semaphores total (16 C->V + 16 V->C), 4-bit counters, sema_id 0-7 for AIV0, 8-15 for AIV1 --- docs/vpto-spec.md | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index e9195d21b..caf279988 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -302,6 +302,57 @@ block runs an **independent program** under its own Scalar Unit (SU), with its o There is no implicit synchronization between blocks. All coordination between the Cube and Vector programs is **explicit**, via the primitives described below. + +``` +┌─────────────────────────────────────── A5 CLUSTER ───────────────────────────────────────┐ +│ │ +│ ┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐ │ +│ │ CUBE CORE (AIC) │ │ VECTOR 0 (AIV0) │ │ VECTOR 1 (AIV1) │ │ +│ │ │ │ subblock_id = 0 │ │ subblock_id = 1 │ │ +│ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ +│ │ │ Scalar Unit │ │ │ │ Scalar Unit │ │ │ │ Scalar Unit │ │ │ +│ │ │ (SU) │ │ │ │ (SU) │ │ │ │ (SU) │ │ │ +│ │ │ runs cube │ │ │ │ runs vec │ │ │ │ runs vec │ │ │ +│ │ │ program │ │ │ │ program │ │ │ │ program │ │ │ +│ │ └───────────────┘ │ │ └───────────────┘ │ │ └───────────────┘ │ │ +│ │ ── Issue Queues ─ │ │ ── Issue Queues ─ │ │ ── Issue Queues ─ │ │ +│ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ +│ │ │ MTE2 │ │ │ │ MTE2 │ │ │ │ MTE2 │ │ │ +│ │ │ GM → L1 │ │ │ │ GM → UB │ │ │ │ GM → UB │ │ │ +│ │ ├───────────────┤ │ │ ├───────────────┤ │ │ ├───────────────┤ │ │ +│ │ │ MTE1 │ │ │ │ VEC │ │ │ │ VEC │ │ │ +│ │ │ L1 → L0A/B │ │ │ │ SIMD compute │ │ │ │ SIMD compute │ │ │ +│ │ ├───────────────┤ │ │ ├───────────────┤ │ │ ├───────────────┤ │ │ +│ │ │ CUBE │ │ │ │ MTE3 │ │ │ │ MTE3 │ │ │ +│ │ │ MMAD (L0C) │ │ │ │ UB → GM │ │ │ │ UB → GM │ │ │ +│ │ ├───────────────┤ │ │ └───────────────┘ │ │ └───────────────┘ │ │ +│ │ │ FIXP │ │ │ │ │ │ │ +│ │ │ L0C → UB │ │ │ │ │ │ │ +│ │ │ (fixpipe) │ │ │ │ │ │ │ +│ │ └───────────────┘ │ │ │ │ │ │ +│ └─────────────────────┘ └─────────────────────┘ └─────────────────────┘ │ +│ │ +│ ┌────────────────────── SC (System Controller) ──────────────────────────────────────┐ │ +│ │ │ │ +│ │ C→V direction · 16 semaphores · 4-bit counter each │ │ +│ │ ┌─────────────────────────────────────────────────────────────────────────────┐ │ │ +│ │ │ sema_id: [ 0][ 1][ 2][ 3][ 4][ 5][ 6][ 7] ←── AIV0 (subblock_id=0) │ │ │ +│ │ │ [ 8][ 9][10][11][12][13][14][15] ←── AIV1 (subblock_id=1) │ │ │ +│ │ └─────────────────────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ V→C direction · 16 semaphores · 4-bit counter each │ │ +│ │ ┌─────────────────────────────────────────────────────────────────────────────┐ │ │ +│ │ │ sema_id: [ 0][ 1][ 2][ 3][ 4][ 5][ 6][ 7] ←── AIV0 (subblock_id=0) │ │ │ +│ │ │ [ 8][ 9][10][11][12][13][14][15] ←── AIV1 (subblock_id=1) │ │ │ +│ │ └─────────────────────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ set_intra_block(trigger_pipe, sema_id) ──► increments semaphore │ │ +│ │ wait_intra_core(wait_pipe, sema_id) ──► stalls pipe until semaphore > 0 │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────────────────────┘ │ +└───────────────────────────────────────────────────────────────────────────────────────────┘ +``` + #### Intra-Cluster Synchronization Within a cluster, the PTO micro ISA provides two levels of synchronization: From ca281196380608f763abb57881bcd78f3d58e061 Mon Sep 17 00:00:00 2001 From: Happybot Date: Sat, 25 Apr 2026 15:38:53 +0800 Subject: [PATCH 167/192] docs(vpto): fix SC semaphore layout in cluster diagram 32 semaphores are shared for both C->V and V->C directions: - sema_id 0-15: communicate with AIV0 (subblock_id=0) - sema_id 16-31: communicate with AIV1 (subblock_id=1) -> 16 sema_id pairs available for 1:2 C:V sync per slot --- docs/vpto-spec.md | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index caf279988..c644d109e 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -334,17 +334,19 @@ programs is **explicit**, via the primitives described below. │ │ │ ┌────────────────────── SC (System Controller) ──────────────────────────────────────┐ │ │ │ │ │ -│ │ C→V direction · 16 semaphores · 4-bit counter each │ │ -│ │ ┌─────────────────────────────────────────────────────────────────────────────┐ │ │ -│ │ │ sema_id: [ 0][ 1][ 2][ 3][ 4][ 5][ 6][ 7] ←── AIV0 (subblock_id=0) │ │ │ -│ │ │ [ 8][ 9][10][11][12][13][14][15] ←── AIV1 (subblock_id=1) │ │ │ -│ │ └─────────────────────────────────────────────────────────────────────────────┘ │ │ +│ │ 32 semaphores · 4-bit counter each · shared for C→V and V→C directions │ │ │ │ │ │ -│ │ V→C direction · 16 semaphores · 4-bit counter each │ │ -│ │ ┌─────────────────────────────────────────────────────────────────────────────┐ │ │ -│ │ │ sema_id: [ 0][ 1][ 2][ 3][ 4][ 5][ 6][ 7] ←── AIV0 (subblock_id=0) │ │ │ -│ │ │ [ 8][ 9][10][11][12][13][14][15] ←── AIV1 (subblock_id=1) │ │ │ -│ │ └─────────────────────────────────────────────────────────────────────────────┘ │ │ +│ │ ┌──────────────────────────────────────────────────────────────────────────────┐ │ │ +│ │ │ sema_id 0 –15 │ [ 0][ 1][ 2][ 3][ 4][ 5][ 6][ 7][ 8][ 9][10][11][12][13][14][15] │ │ │ +│ │ │ │ ↕ C→V / V→C ↕ │ │ │ +│ │ │ │ communicate with AIV0 (subblock_id=0) │ │ │ +│ │ ├──────────────────────────────────────────────────────────────────────────────┤ │ │ +│ │ │ sema_id 16–31 │ [16][17][18][19][20][21][22][23][24][25][26][27][28][29][30][31] │ │ │ +│ │ │ │ ↕ C→V / V→C ↕ │ │ │ +│ │ │ │ communicate with AIV1 (subblock_id=1) │ │ │ +│ │ └──────────────────────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ → 16 sema_id pairs (0–15) available for 1:2 C:V sync per slot │ │ │ │ │ │ │ │ set_intra_block(trigger_pipe, sema_id) ──► increments semaphore │ │ │ │ wait_intra_core(wait_pipe, sema_id) ──► stalls pipe until semaphore > 0 │ │ From f8592d9ddd44896d0f073f904c08b40056e2a3f3 Mon Sep 17 00:00:00 2001 From: lwwang Date: Sat, 25 Apr 2026 16:51:12 +0800 Subject: [PATCH 168/192] [feat] add rowsum rowmin rowmax rowargmax rowargmin (#255) * [feat] add rowsum rowmin rowmax rowargmax rowargmin * feat(tileop): add TROWARGMAX, TROWARGMIN, TROWMAX, TROWMIN, TROWPROD, and TROWSUM test cases * refactor(tileops): simplify initialization of min/max values and accumulator in templates Co-authored-by: Copilot * feat(tileops): implement one-point store distance selection in row templates Co-authored-by: Copilot --------- Co-authored-by: Copilot --- lib/TileOps/trowargmax_template.py | 74 + lib/TileOps/trowargmin_template.py | 74 + lib/TileOps/trowmax_template.py | 61 + lib/TileOps/trowmin_template.py | 63 + lib/TileOps/trowprod_template.py | 71 + lib/TileOps/trowsum_template.py | 75 + .../expand_tile_op_tilelang_trowargmax.pto | 45 + .../expand_tile_op_tilelang_trowargmin.pto | 45 + .../basic/expand_tile_op_tilelang_trowmax.pto | 44 + .../basic/expand_tile_op_tilelang_trowmin.pto | 44 + .../expand_tile_op_tilelang_trowprod.pto | 45 + .../basic/expand_tile_op_tilelang_trowsum.pto | 44 + .../npu/a5/src/st/testcase/CMakeLists.txt | 6 + .../src/st/testcase/trowargmax/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowargmax/cases.py | 215 +++ .../a5/src/st/testcase/trowargmax/compare.py | 52 + .../a5/src/st/testcase/trowargmax/gen_data.py | 38 + .../a5/src/st/testcase/trowargmax/launch.cpp | 133 ++ .../a5/src/st/testcase/trowargmax/main.cpp | 210 +++ .../src/st/testcase/trowargmax/trowargmax.pto | 1205 +++++++++++++ .../src/st/testcase/trowargmin/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowargmin/cases.py | 215 +++ .../a5/src/st/testcase/trowargmin/compare.py | 52 + .../a5/src/st/testcase/trowargmin/gen_data.py | 38 + .../a5/src/st/testcase/trowargmin/launch.cpp | 133 ++ .../a5/src/st/testcase/trowargmin/main.cpp | 206 +++ .../src/st/testcase/trowargmin/trowargmin.pto | 1205 +++++++++++++ .../a5/src/st/testcase/trowmax/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/trowmax/cases.py | 224 +++ .../npu/a5/src/st/testcase/trowmax/compare.py | 49 + .../a5/src/st/testcase/trowmax/gen_data.py | 41 + .../npu/a5/src/st/testcase/trowmax/launch.cpp | 183 ++ .../npu/a5/src/st/testcase/trowmax/main.cpp | 207 +++ .../a5/src/st/testcase/trowmax/trowmax.pto | 1545 +++++++++++++++++ .../a5/src/st/testcase/trowmin/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/trowmin/cases.py | 224 +++ .../npu/a5/src/st/testcase/trowmin/compare.py | 49 + .../a5/src/st/testcase/trowmin/gen_data.py | 41 + .../npu/a5/src/st/testcase/trowmin/launch.cpp | 155 ++ .../npu/a5/src/st/testcase/trowmin/main.cpp | 207 +++ .../a5/src/st/testcase/trowmin/trowmin.pto | 1545 +++++++++++++++++ .../src/st/testcase/trowprod/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/trowprod/cases.py | 153 ++ .../a5/src/st/testcase/trowprod/compare.py | 49 + .../a5/src/st/testcase/trowprod/gen_data.py | 42 + .../a5/src/st/testcase/trowprod/launch.cpp | 105 ++ .../npu/a5/src/st/testcase/trowprod/main.cpp | 186 ++ .../a5/src/st/testcase/trowprod/trowprod.pto | 999 +++++++++++ .../a5/src/st/testcase/trowsum/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/trowsum/cases.py | 165 ++ .../npu/a5/src/st/testcase/trowsum/compare.py | 49 + .../a5/src/st/testcase/trowsum/gen_data.py | 45 + .../npu/a5/src/st/testcase/trowsum/launch.cpp | 111 ++ .../npu/a5/src/st/testcase/trowsum/main.cpp | 195 +++ .../a5/src/st/testcase/trowsum/trowsum.pto | 1092 ++++++++++++ 55 files changed, 12108 insertions(+) create mode 100644 lib/TileOps/trowargmax_template.py create mode 100644 lib/TileOps/trowargmin_template.py create mode 100644 lib/TileOps/trowmax_template.py create mode 100644 lib/TileOps/trowmin_template.py create mode 100644 lib/TileOps/trowprod_template.py create mode 100644 lib/TileOps/trowsum_template.py create mode 100644 test/basic/expand_tile_op_tilelang_trowargmax.pto create mode 100644 test/basic/expand_tile_op_tilelang_trowargmin.pto create mode 100644 test/basic/expand_tile_op_tilelang_trowmax.pto create mode 100644 test/basic/expand_tile_op_tilelang_trowmin.pto create mode 100644 test/basic/expand_tile_op_tilelang_trowprod.pto create mode 100644 test/basic/expand_tile_op_tilelang_trowsum.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmax/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmax/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmax/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmax/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmax/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmax/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmax/trowargmax.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmin/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmin/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmin/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmin/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmin/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmin/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowargmin/trowargmin.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmax/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmax/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmax/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmax/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmax/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmax/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmax/trowmax.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmin/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmin/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmin/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmin/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmin/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmin/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowmin/trowmin.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowprod/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowprod/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowprod/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowprod/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowprod/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowprod/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowprod/trowprod.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowsum/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowsum/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowsum/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowsum/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowsum/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowsum/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowsum/trowsum.pto diff --git a/lib/TileOps/trowargmax_template.py b/lib/TileOps/trowargmax_template.py new file mode 100644 index 000000000..c0df597ed --- /dev/null +++ b/lib/TileOps/trowargmax_template.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowargmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowargmax", + advanced=True, +) +def template_trowargmax(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + idx_dtype = dst.element_type + lanes = pto.get_lanes(src_dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(idx_dtype) + + # Initialize with dtype-specific minimum value (aligned with pto-isa Padding::Min) + init_val = pto.PadValue.MIN.eval(src_dtype) + + # Select one-point store dist based on index dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + for row in range(0, valid_rows, 1): + remained = valid_cols + + v_val_acc = pto.vbr(init_val) + init_zero_idx = idx_dtype(0) + v_idx_acc = pto.vbr(init_zero_idx) + + # Masks: src_dtype for data ops and final store (matches pto-isa CreatePredicate) + # idx_dtype for index arithmetic operations + mask_1, _ = pto.make_mask(src_dtype, 1) + mask_1_idx, _ = pto.make_mask(idx_dtype, 1) + + # Process all column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(src_dtype, remained) + v_src = pto.vlds(src[row, col:]) + v_reduced = pto.vcmax(v_src, mask) + + v_val, v_idx = pto.vdintlv(v_reduced, pto.vbr(src_dtype(0))) + v_idx = pto.vbitcast(v_idx, idx_dtype) + + # Add absolute col offset to the chunk's local index + col_offset = idx_dtype(col) + v_idx = pto.vadds(v_idx, col_offset, mask_1_idx) + + # Compare current chunk max with global max so far + cmp_mask = pto.vcmp(v_val_acc, v_val, mask_1, "lt") + + # Update global max and global argmax + v_val_acc = pto.vsel(v_val, v_val_acc, cmp_mask) + # v_idx_acc is ui32, requires b32 mask; convert cmp_mask from src_dtype's mask to b32 + cmp_mask_b32 = pto.pbitcast(cmp_mask, pto.mask_b32) + v_idx_acc = pto.vsel(v_idx, v_idx_acc, cmp_mask_b32) + + # Store index accumulator to destination tile using one-point mode + pto.vsts(v_idx_acc, dst[row, 0:], mask_1_idx, dist=store_dist) + return \ No newline at end of file diff --git a/lib/TileOps/trowargmin_template.py b/lib/TileOps/trowargmin_template.py new file mode 100644 index 000000000..f23ab7137 --- /dev/null +++ b/lib/TileOps/trowargmin_template.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowargmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowargmin", + advanced=True, +) +def template_trowargmin(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + idx_dtype = dst.element_type + lanes = pto.get_lanes(src_dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(idx_dtype) + + # Initialize with dtype-specific maximum value (aligned with pto-isa Padding::Max) + init_val = pto.PadValue.MAX.eval(src_dtype) + + # Select one-point store dist based on index dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + for row in range(0, valid_rows, 1): + remained = valid_cols + + v_val_acc = pto.vbr(init_val) + init_zero_idx = idx_dtype(0) + v_idx_acc = pto.vbr(init_zero_idx) + + # Masks: src_dtype for data ops and final store (matches pto-isa CreatePredicate) + # idx_dtype for index arithmetic operations + mask_1, _ = pto.make_mask(src_dtype, 1) + mask_1_idx, _ = pto.make_mask(idx_dtype, 1) + + # Process all column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(src_dtype, remained) + v_src = pto.vlds(src[row, col:]) + v_reduced = pto.vcmin(v_src, mask) + + v_val, v_idx = pto.vdintlv(v_reduced, pto.vbr(src_dtype(0))) + v_idx = pto.vbitcast(v_idx, idx_dtype) + + # Add absolute col offset to the chunk's local index + col_offset = idx_dtype(col) + v_idx = pto.vadds(v_idx, col_offset, mask_1_idx) + + # Compare current chunk min with global min so far + cmp_mask = pto.vcmp(v_val_acc, v_val, mask_1, "gt") + + # Update global min and global argmin + v_val_acc = pto.vsel(v_val, v_val_acc, cmp_mask) + # v_idx_acc is ui32, requires b32 mask; cast cmp_mask from src_dtype's mask to b32 + cmp_mask_b32 = pto.pbitcast(cmp_mask, pto.mask_b32) + v_idx_acc = pto.vsel(v_idx, v_idx_acc, cmp_mask_b32) + + # Store index accumulator to destination tile using one-point mode + pto.vsts(v_idx_acc, dst[row, 0:], mask_1_idx, dist=store_dist) + return diff --git a/lib/TileOps/trowmax_template.py b/lib/TileOps/trowmax_template.py new file mode 100644 index 000000000..522ce699e --- /dev/null +++ b/lib/TileOps/trowmax_template.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowmax", + advanced=True, +) +def template_trowmax(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + lanes = pto.get_lanes(dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(dtype) + + # Initialize with dtype-specific minimum value (aligned with pto-isa Padding::Min) + init_val = pto.PadValue.MIN.eval(dtype) + + # Select one-point store dist based on dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + for row in range(0, valid_rows, 1): + remained = valid_cols + + mask_1, _ = pto.make_mask(dtype, 1) + + # Initialize the accumulator for ROWMAX + v_acc = pto.vbr(init_val) + + # Process column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + v_src = pto.vlds(src[row, col:]) + + # vcmax reduces src_dtype to acc_dtype + v_reduced = pto.vcmax(v_src, mask) + + # Clear masked lanes to init_val for float types so vmax doesn't see NaN + if pto.constexpr(dtype == pto.f32 or dtype == pto.f16): + v_reduced = pto.vsel(v_reduced, v_acc, mask) + + v_acc = pto.vmax(v_acc, v_reduced, mask_1) + + # Write final reduction to dest buffer once using one-point mode + pto.vsts(v_acc, dst[row, 0:], mask_1, dist=store_dist) + return diff --git a/lib/TileOps/trowmin_template.py b/lib/TileOps/trowmin_template.py new file mode 100644 index 000000000..f74b798c4 --- /dev/null +++ b/lib/TileOps/trowmin_template.py @@ -0,0 +1,63 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trowmin", + advanced=True, +) +def template_trowmin(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + lanes = pto.get_lanes(dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(dtype) + + # Initialize with dtype-specific maximum value (aligned with pto-isa Padding::Max) + init_val = pto.PadValue.MAX.eval(dtype) + + # Select one-point store dist based on dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + mask_1, _ = pto.make_mask(dtype, 1) + + for row in range(0, valid_rows, 1): + remained = valid_cols + + # Initialize the accumulator for ROWMIN + v_acc = pto.vbr(init_val) + + # Process column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + v_src = pto.vlds(src[row, col:]) + + # vcmin reduces src_dtype to acc_dtype + v_reduced = pto.vcmin(v_src, mask) + + # Clear masked lanes to init_val for float types so vmin doesn't see NaN + if pto.constexpr(dtype == pto.f32 or dtype == pto.f16): + v_reduced = pto.vsel(v_reduced, v_acc, mask) + + # accumulate using the accumulator's mask logic + v_acc = pto.vmin(v_acc, v_reduced, mask_1) + + # Write final reduction to dest buffer once using one-point mode + pto.vsts(v_acc, dst[row, 0:], mask_1, dist=store_dist) + return diff --git a/lib/TileOps/trowprod_template.py b/lib/TileOps/trowprod_template.py new file mode 100644 index 000000000..55208d328 --- /dev/null +++ b/lib/TileOps/trowprod_template.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowprod""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowprod", + advanced=True, +) +def template_trowprod(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + lanes = pto.get_lanes(dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(dtype) + + # nLoop from C++ constants: TROW_PROD_LOOP_B16=7, TROW_PROD_LOOP_B32=6 + TROW_PROD_LOOP_B16 = 7 + TROW_PROD_LOOP_B32 = 6 + if pto.constexpr(dtype == pto.f16 or dtype == pto.i16): + n_loop = TROW_PROD_LOOP_B16 + else: + n_loop = TROW_PROD_LOOP_B32 + + # Select one-point store dist based on dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + mask_1, _ = pto.make_mask(dtype, 1) + + for row in range(0, valid_rows, 1): + remained = valid_cols + + one_val = dtype(1) + v_acc = pto.vbr(one_val) + v_one = pto.vbr(one_val) + + # Multiply across column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + v_src = pto.vlds(src[row, col:]) + + # Element-wise product + v_prod = pto.vmul(v_acc, v_src, mask) + + # Simulate MODE_MERGING with vsel (keep v_acc outside mask) + v_acc = pto.vsel(v_prod, v_acc, mask) + + # Log2 reduction phase across the vector + reduce_mask, _ = pto.make_mask(dtype, lanes) # all lanes active for inner reduction + + for k in range(0, n_loop, 1): + v_intlv1, v_intlv2 = pto.vintlv(v_acc, v_one) + v_acc = pto.vmul(v_intlv1, v_intlv2, reduce_mask) + + # Write final result at lane 0 using one-point mode + pto.vsts(v_acc, dst[row, 0:], mask_1, dist=store_dist) + return diff --git a/lib/TileOps/trowsum_template.py b/lib/TileOps/trowsum_template.py new file mode 100644 index 000000000..0dacd1cd2 --- /dev/null +++ b/lib/TileOps/trowsum_template.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowsum""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trowsum", +) +def template_trowsum(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + dst_dtype = dst.element_type + + # vcadd widens i16 -> i32; floats/i32 unchanged + if pto.constexpr(src_dtype == pto.i16): + acc_dtype = pto.i32 + else: + acc_dtype = src_dtype + + lanes = pto.get_lanes(src_dtype) + valid_rows, valid_cols = src.valid_shape + + # Use type-appropriate zero for accumulator initialization + zero_val = acc_dtype(0) + + # Select one-point store dist based on dst dtype size + elem_bytes = pto.bytewidth(dst_dtype) + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + dst_mask_1, _ = pto.make_mask(dst_dtype, 1) + + for row in range(0, valid_rows, 1): + remained = valid_cols + + acc_mask_1, _ = pto.make_mask(acc_dtype, 1) + + # Initialize the accumulator with type-appropriate zero + v_acc = pto.vbr(zero_val) + + # Process column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(src_dtype, remained) + v_src = pto.vlds(src[row, col:]) + + # vcadd widens src_dtype to acc_dtype for integer types + v_reduced = pto.vcadd(v_src, mask) + + # accumulate using the accumulator's mask logic + v_acc = pto.vadd(v_acc, v_reduced, acc_mask_1) + + # Store the accumulated result safely once per row using one-point mode + if pto.constexpr(acc_dtype != dst_dtype): + # Truncate / Type cast before storing + # Note: For int32 -> int16 mapping, vcvt processes it. + acc_mask_for_cvt, _ = pto.make_mask(acc_dtype, 1) + v_acc_casted = pto.vcvt(v_acc, dst_dtype, acc_mask_for_cvt, sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.EVEN) + pto.vsts(v_acc_casted, dst[row, 0:], dst_mask_1, dist=store_dist) + else: + pto.vsts(v_acc, dst[row, 0:], dst_mask_1, dist=store_dist) + return diff --git a/test/basic/expand_tile_op_tilelang_trowargmax.pto b/test/basic/expand_tile_op_tilelang_trowargmax.pto new file mode 100644 index 000000000..c77a3f526 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trowargmax.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWARGMAX +// CHECK-NOT: pto.trowargmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcmax +// CHECK: pto.vdintlv +// CHECK: pto.vsts + +module { + func.func @TROWARGMAX() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_trowargmin.pto b/test/basic/expand_tile_op_tilelang_trowargmin.pto new file mode 100644 index 000000000..85981c0b6 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trowargmin.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWARGMIN +// CHECK-NOT: pto.trowargmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcmin +// CHECK: pto.vdintlv +// CHECK: pto.vsts + +module { + func.func @TROWARGMIN() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_trowmax.pto b/test/basic/expand_tile_op_tilelang_trowmax.pto new file mode 100644 index 000000000..fdd7607dd --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trowmax.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWMAX +// CHECK-NOT: pto.trowmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcmax +// CHECK: pto.vsts + +module { + func.func @TROWMAX() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_trowmin.pto b/test/basic/expand_tile_op_tilelang_trowmin.pto new file mode 100644 index 000000000..8ef1f7ee7 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trowmin.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWMIN +// CHECK-NOT: pto.trowmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcmin +// CHECK: pto.vsts + +module { + func.func @TROWMIN() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_trowprod.pto b/test/basic/expand_tile_op_tilelang_trowprod.pto new file mode 100644 index 000000000..8bfc55c1c --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trowprod.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWPROD +// CHECK-NOT: pto.trowprod ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vintlv +// CHECK: pto.vsts + +module { + func.func @TROWPROD() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_trowsum.pto b/test/basic/expand_tile_op_tilelang_trowsum.pto new file mode 100644 index 000000000..3f842f8d9 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trowsum.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWSUM +// CHECK-NOT: pto.trowsum ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcadd +// CHECK: pto.vsts + +module { + func.func @TROWSUM() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index b8604dbc3..352667eb0 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -143,6 +143,12 @@ set(ALL_TESTCASES tpartadd tpartmul trecip + trowargmax + trowargmin + trowsum + trowmax + trowmin + trowprod trsqrt tsqrt ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/CMakeLists.txt new file mode 100644 index 000000000..42aec9129 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowargmax) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/cases.py new file mode 100644 index 000000000..c67114409 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/cases.py @@ -0,0 +1,215 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowargmax ST test cases — aligned with pto-isa.""" + +import numpy as np + +CASES = [ + # uint32_dst + float32_src + { + "name": "uint32_float_8x1_8x8_8x8", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (8, 8), + "valid_shape": (8, 8), + "eps": 0, + }, + { + "name": "uint32_float_1024x1_1024x8_1024x8", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1024, 8), + "valid_shape": (1024, 8), + "eps": 0, + }, + { + "name": "uint32_float_16x1_13x16_13x13", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + { + "name": "uint32_float_1024x1_1023x24_1023x17", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1023, 24), + "valid_shape": (1023, 17), + "eps": 0, + }, + { + "name": "uint32_float_8x1_8x64_8x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (8, 64), + "valid_shape": (8, 64), + "eps": 0, + }, + { + "name": "uint32_float_264x1_260x64_260x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_float_8x1_1x128_1x128", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1, 128), + "valid_shape": (1, 128), + "eps": 0, + }, + { + "name": "uint32_float_64x1_32x128_32x128", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "uint32_float_8x1_3x4096_3x4095", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (3, 4096), + "valid_shape": (3, 4095), + "eps": 0, + }, + { + "name": "uint32_float_8x1_2x16384_2x16381", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (2, 16384), + "valid_shape": (2, 16381), + "eps": 0, + }, + # uint32_dst + float16_src + { + "name": "uint32_half_16x1_2x16_2x16", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (2, 16), + "valid_shape": (2, 16), + "eps": 0, + }, + { + "name": "uint32_half_16x1_13x16_13x13", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + { + "name": "uint32_half_272x1_260x64_260x64", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_half_16x1_3x8192_3x8191", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (3, 8192), + "valid_shape": (3, 8191), + "eps": 0, + }, + { + "name": "uint32_half_16x1_1x16384_1x16381", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1, 16384), + "valid_shape": (1, 16381), + "eps": 0, + }, + { + "name": "uint32_half_16x1_1x32768_1x32761", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1, 32768), + "valid_shape": (1, 32761), + "eps": 0, + }, + # int32_dst + float32_src + { + "name": "int32_float_16x1_13x16_13x13", + "dtype": np.float32, + "dst_dtype": np.int32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + # int32_dst + float16_src + { + "name": "int32_half_16x1_13x16_13x13", + "dtype": np.float16, + "dst_dtype": np.int32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + # uint32_dst + float32_src (dst col > 1) + { + "name": "uint32_float_3x8_3x3480_3x3473", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (3, 3480), + "valid_shape": (3, 3473), + "eps": 0, + }, + { + "name": "uint32_float_260x8_260x64_260x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_float_1023x8_1023x24_1023x17", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1023, 24), + "valid_shape": (1023, 17), + "eps": 0, + }, + # uint32_dst + float16_src (dst col > 1) + { + "name": "uint32_half_3x16_3x3488_3x3473", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (3, 3488), + "valid_shape": (3, 3473), + "eps": 0, + }, + { + "name": "uint32_half_260x16_260x64_260x64", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_half_1023x16_1023x32_1023x17", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1023, 32), + "valid_shape": (1023, 17), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/compare.py new file mode 100644 index 000000000..4cd015fd3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr, 1) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dst_dtype"], count=np.prod(out_shape)).reshape(out_shape) + + output_full = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dst_dtype"]) + dst_cols = len(output_full) // vr + output = output_full.reshape(vr, dst_cols)[:, 0:1] + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/gen_data.py new file mode 100644 index 000000000..3016b948f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_dtype = case["dst_dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if dtype in (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32): + dtype_info = np.iinfo(dtype) + input1 = np.random.randint(dtype_info.min, dtype_info.max, size=shape).astype(dtype) + else: + dtype_info = np.finfo(dtype) + input1 = np.random.uniform(low=dtype_info.min, high=dtype_info.max, size=shape).astype(dtype) + + out_shape = (valid_shape[0], 1) + golden = np.zeros(out_shape, dtype=dst_dtype) + golden[:, 0:1] = np.argmax(input1[:, :valid_shape[1]], axis=1, keepdims=True).astype(dst_dtype) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/launch.cpp new file mode 100644 index 000000000..1da9eee23 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/launch.cpp @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_8x8_8x8(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_8x8_8x8(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_8x8_8x8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_1024x1_1024x8_1024x8(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_1024x1_1024x8_1024x8(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_1024x1_1024x8_1024x8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_16x1_13x16_13x13(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_16x1_13x16_13x13(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_1024x1_1023x24_1023x17(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_1024x1_1023x24_1023x17(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_1024x1_1023x24_1023x17<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_8x64_8x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_8x64_8x64(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_8x64_8x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_264x1_260x64_260x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_264x1_260x64_260x64(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_264x1_260x64_260x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_1x128_1x128(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_1x128_1x128(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_1x128_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_64x1_32x128_32x128(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_64x1_32x128_32x128(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_64x1_32x128_32x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_3x4096_3x4095(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_3x4096_3x4095(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_3x4096_3x4095<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_2x16384_2x16381(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_2x16384_2x16381(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_2x16384_2x16381<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_2x16_2x16(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_2x16_2x16(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_13x16_13x13(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_13x16_13x13(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_272x1_260x64_260x64(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_272x1_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_272x1_260x64_260x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_3x8192_3x8191(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_3x8192_3x8191(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_3x8192_3x8191<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_1x16384_1x16381(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_1x16384_1x16381(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_1x16384_1x16381<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_1x32768_1x32761(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_1x32768_1x32761(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_1x32768_1x32761<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_int32_float_16x1_13x16_13x13(__gm__ float *src, __gm__ int32_t *dst); +void LaunchTROWARGMAX_int32_float_16x1_13x16_13x13(float *src, int32_t *dst, void *stream) { + TROWARGMAX_int32_float_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_int32_half_16x1_13x16_13x13(__gm__ uint16_t *src, __gm__ int32_t *dst); +void LaunchTROWARGMAX_int32_half_16x1_13x16_13x13(uint16_t *src, int32_t *dst, void *stream) { + TROWARGMAX_int32_half_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_3x8_3x3480_3x3473(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_3x8_3x3480_3x3473(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_3x8_3x3480_3x3473<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_260x8_260x64_260x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_260x8_260x64_260x64(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_260x8_260x64_260x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_1023x8_1023x24_1023x17(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_1023x8_1023x24_1023x17(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_1023x8_1023x24_1023x17<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_3x16_3x3488_3x3473(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_3x16_3x3488_3x3473(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_3x16_3x3488_3x3473<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_260x16_260x64_260x64(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_260x16_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_260x16_260x64_260x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_1023x16_1023x32_1023x17(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_1023x16_1023x32_1023x17(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_1023x16_1023x32_1023x17<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/main.cpp new file mode 100644 index 000000000..908e57820 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/main.cpp @@ -0,0 +1,210 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowargmax ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWARGMAX_uint32_float_8x1_8x8_8x8(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_1024x1_1024x8_1024x8(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_16x1_13x16_13x13(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_1024x1_1023x24_1023x17(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_8x1_8x64_8x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_264x1_260x64_260x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_8x1_1x128_1x128(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_64x1_32x128_32x128(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_8x1_3x4096_3x4095(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_8x1_2x16384_2x16381(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_2x16_2x16(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_13x16_13x13(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_272x1_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_3x8192_3x8191(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_1x16384_1x16381(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_1x32768_1x32761(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_int32_float_16x1_13x16_13x13(float *src, int32_t *dst, void *stream); +void LaunchTROWARGMAX_int32_half_16x1_13x16_13x13(uint16_t *src, int32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_3x8_3x3480_3x3473(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_260x8_260x64_260x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_1023x8_1023x24_1023x17(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_3x16_3x3488_3x3473(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_260x16_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_1023x16_1023x32_1023x17(uint16_t *src, uint32_t *dst, void *stream); + +using LaunchFnF32U32 = void (*)(float *, uint32_t *, void *); +using LaunchFnF16U32 = void (*)(uint16_t *, uint32_t *, void *); +using LaunchFnF32S32 = void (*)(float *, int32_t *, void *); +using LaunchFnF16S32 = void (*)(uint16_t *, int32_t *, void *); + +enum class DType { F32U32, F16U32, F32S32, F16S32 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32U32 launchF32U32; + LaunchFnF16U32 launchF16U32; + LaunchFnF32S32 launchF32S32; + LaunchFnF16S32 launchF16S32; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t srcElemSize; // bytes per src element + size_t dstElemSize; // bytes per dst element + size_t dstCols; // dst tile cols +}; + +static const TestCase kCases[] = { + {"uint32_float_8x1_8x8_8x8", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_8x8_8x8, 8, 8, 8, 8, 4, 4, 1}, + {"uint32_float_1024x1_1024x8_1024x8", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_1024x1_1024x8_1024x8, 1024, 8, 1024, 8, 4, 4, 1}, + {"uint32_float_16x1_13x16_13x13", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_16x1_13x16_13x13, 13, 16, 13, 13, 4, 4, 1}, + {"uint32_float_1024x1_1023x24_1023x17", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_1024x1_1023x24_1023x17, 1023, 24, 1023, 17, 4, 4, 1}, + {"uint32_float_8x1_8x64_8x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_8x64_8x64, 8, 64, 8, 64, 4, 4, 1}, + {"uint32_float_264x1_260x64_260x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_264x1_260x64_260x64, 260, 64, 260, 64, 4, 4, 1}, + {"uint32_float_8x1_1x128_1x128", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_1x128_1x128, 1, 128, 1, 128, 4, 4, 1}, + {"uint32_float_64x1_32x128_32x128", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_64x1_32x128_32x128, 32, 128, 32, 128, 4, 4, 1}, + {"uint32_float_8x1_3x4096_3x4095", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_3x4096_3x4095, 3, 4096, 3, 4095, 4, 4, 1}, + {"uint32_float_8x1_2x16384_2x16381", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_2x16384_2x16381, 2, 16384, 2, 16381, 4, 4, 1}, + {"uint32_half_16x1_2x16_2x16", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_2x16_2x16, 2, 16, 2, 16, 2, 4, 1}, + {"uint32_half_16x1_13x16_13x13", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_13x16_13x13, 13, 16, 13, 13, 2, 4, 1}, + {"uint32_half_272x1_260x64_260x64", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_272x1_260x64_260x64, 260, 64, 260, 64, 2, 4, 1}, + {"uint32_half_16x1_3x8192_3x8191", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_3x8192_3x8191, 3, 8192, 3, 8191, 2, 4, 1}, + {"uint32_half_16x1_1x16384_1x16381", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_1x16384_1x16381, 1, 16384, 1, 16381, 2, 4, 1}, + {"uint32_half_16x1_1x32768_1x32761", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_1x32768_1x32761, 1, 32768, 1, 32761, 2, 4, 1}, + {"int32_float_16x1_13x16_13x13", DType::F32S32, .launchF32S32 = LaunchTROWARGMAX_int32_float_16x1_13x16_13x13, 13, 16, 13, 13, 4, 4, 1}, + {"int32_half_16x1_13x16_13x13", DType::F16S32, .launchF16S32 = LaunchTROWARGMAX_int32_half_16x1_13x16_13x13, 13, 16, 13, 13, 2, 4, 1}, + {"uint32_float_3x8_3x3480_3x3473", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_3x8_3x3480_3x3473, 3, 3480, 3, 3473, 4, 4, 8}, + {"uint32_float_260x8_260x64_260x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_260x8_260x64_260x64, 260, 64, 260, 64, 4, 4, 8}, + {"uint32_float_1023x8_1023x24_1023x17", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_1023x8_1023x24_1023x17, 1023, 24, 1023, 17, 4, 4, 8}, + {"uint32_half_3x16_3x3488_3x3473", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_3x16_3x3488_3x3473, 3, 3488, 3, 3473, 2, 4, 16}, + {"uint32_half_260x16_260x64_260x64", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_260x16_260x64_260x64, 260, 64, 260, 64, 2, 4, 16}, + {"uint32_half_1023x16_1023x32_1023x17", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_1023x16_1023x32_1023x17, 1023, 32, 1023, 17, 2, 4, 16}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.srcElemSize; + const size_t dstElemCount = tc.validRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (rc == 0) { + aclrtMemset(dstDevice, dstFileSize, 0, dstFileSize); + } + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32U32: + tc.launchF32U32((float *)src0Device, (uint32_t *)dstDevice, stream); + break; + case DType::F16U32: + tc.launchF16U32((uint16_t *)src0Device, (uint32_t *)dstDevice, stream); + break; + case DType::F32S32: + tc.launchF32S32((float *)src0Device, (int32_t *)dstDevice, stream); + break; + case DType::F16S32: + tc.launchF16S32((uint16_t *)src0Device, (int32_t *)dstDevice, stream); + break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0) { + mkdir(caseDir.c_str(), 0755); + if (!WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowargmax [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/trowargmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/trowargmax.pto new file mode 100644 index 000000000..249a5b7c7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/trowargmax.pto @@ -0,0 +1,1205 @@ +// Auto-generated trowargmax ST testcases + +module { + + func.func @TROWARGMAX_uint32_float_8x1_8x8_8x8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8_r = arith.constant 8 : index + %c8_c = arith.constant 8 : index + %c64_se = arith.constant 64 : index + %c8_de = arith.constant 8 : index + %c1_dc = arith.constant 1 : index + %c8_vr = arith.constant 8 : index + %c8_vc = arith.constant 8 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c8_c], + strides = [%c64_se, %c64_se, %c64_se, %c8_c, %c1] + : !pto.tensor_view<1x1x1x8x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c1_dc], + strides = [%c8_de, %c8_de, %c8_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c8_vc] + : !pto.tensor_view<1x1x1x8x8xf32> -> !pto.partition_tensor_view<1x1x1x8x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> -> !pto.partition_tensor_view<1x1x1x8x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x8xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_1024x1_1024x8_1024x8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024_r = arith.constant 1024 : index + %c8_c = arith.constant 8 : index + %c8192_se = arith.constant 8192 : index + %c1024_de = arith.constant 1024 : index + %c1_dc = arith.constant 1 : index + %c1024_vr = arith.constant 1024 : index + %c8_vc = arith.constant 8 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1024_r, %c8_c], + strides = [%c8192_se, %c8192_se, %c8192_se, %c8_c, %c1] + : !pto.tensor_view<1x1x1x1024x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1024_r, %c1_dc], + strides = [%c1024_de, %c1024_de, %c1024_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1024x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1024_vr, %c8_vc] + : !pto.tensor_view<1x1x1x1024x8xf32> -> !pto.partition_tensor_view<1x1x1x1024x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1024_vr, %c1] + : !pto.tensor_view<1x1x1x1024x1xui32> -> !pto.partition_tensor_view<1x1x1x1024x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1024x8xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1024x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf32> -> !pto.partition_tensor_view<1x1x1x13x13xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> -> !pto.partition_tensor_view<1x1x1x13x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_1024x1_1023x24_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c24_c = arith.constant 24 : index + %c24552_se = arith.constant 24552 : index + %c1023_de = arith.constant 1023 : index + %c1_dc = arith.constant 1 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c24_c], + strides = [%c24552_se, %c24552_se, %c24552_se, %c24_c, %c1] + : !pto.tensor_view<1x1x1x1023x24xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c1_dc], + strides = [%c1023_de, %c1023_de, %c1023_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1023x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x24xf32> -> !pto.partition_tensor_view<1x1x1x1023x17xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x1xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_8x1_8x64_8x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8_r = arith.constant 8 : index + %c64_c = arith.constant 64 : index + %c512_se = arith.constant 512 : index + %c8_de = arith.constant 8 : index + %c1_dc = arith.constant 1 : index + %c8_vr = arith.constant 8 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c64_c], + strides = [%c512_se, %c512_se, %c512_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c1_dc], + strides = [%c8_de, %c8_de, %c8_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c64_vc] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> -> !pto.partition_tensor_view<1x1x1x8x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_264x1_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c260_de = arith.constant 260 : index + %c1_dc = arith.constant 1 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c1_dc], + strides = [%c260_de, %c260_de, %c260_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf32> -> !pto.partition_tensor_view<1x1x1x260x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_8x1_1x128_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c128_c = arith.constant 128 : index + %c128_se = arith.constant 128 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c128_vc = arith.constant 128 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c128_c], + strides = [%c128_se, %c128_se, %c128_se, %c128_c, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c128_vc] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_64x1_32x128_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32_r = arith.constant 32 : index + %c128_c = arith.constant 128 : index + %c4096_se = arith.constant 4096 : index + %c32_de = arith.constant 32 : index + %c1_dc = arith.constant 1 : index + %c32_vr = arith.constant 32 : index + %c128_vc = arith.constant 128 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32_r, %c128_c], + strides = [%c4096_se, %c4096_se, %c4096_se, %c128_c, %c1] + : !pto.tensor_view<1x1x1x32x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32_r, %c1_dc], + strides = [%c32_de, %c32_de, %c32_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x32x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32_vr, %c128_vc] + : !pto.tensor_view<1x1x1x32x128xf32> -> !pto.partition_tensor_view<1x1x1x32x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32_vr, %c1] + : !pto.tensor_view<1x1x1x32x1xui32> -> !pto.partition_tensor_view<1x1x1x32x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_8x1_3x4096_3x4095(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c4096_c = arith.constant 4096 : index + %c12288_se = arith.constant 12288 : index + %c3_de = arith.constant 3 : index + %c1_dc = arith.constant 1 : index + %c3_vr = arith.constant 3 : index + %c4095_vc = arith.constant 4095 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c4096_c], + strides = [%c12288_se, %c12288_se, %c12288_se, %c4096_c, %c1] + : !pto.tensor_view<1x1x1x3x4096xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c1_dc], + strides = [%c3_de, %c3_de, %c3_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c4095_vc] + : !pto.tensor_view<1x1x1x3x4096xf32> -> !pto.partition_tensor_view<1x1x1x3x4095xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x4095xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_8x1_2x16384_2x16381(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16384_c = arith.constant 16384 : index + %c32768_se = arith.constant 32768 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16384_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x2x16384xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x2x16384xf32> -> !pto.partition_tensor_view<1x1x1x2x16381xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16381xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_2x16_2x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16_c = arith.constant 16 : index + %c32_se = arith.constant 32 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16_vc = arith.constant 16 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16_c], + strides = [%c32_se, %c32_se, %c32_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x2x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16_vc] + : !pto.tensor_view<1x1x1x2x16xf16> -> !pto.partition_tensor_view<1x1x1x2x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf16> -> !pto.partition_tensor_view<1x1x1x13x13xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> -> !pto.partition_tensor_view<1x1x1x13x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_272x1_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c260_de = arith.constant 260 : index + %c1_dc = arith.constant 1 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c1_dc], + strides = [%c260_de, %c260_de, %c260_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf16> -> !pto.partition_tensor_view<1x1x1x260x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_3x8192_3x8191(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c8192_c = arith.constant 8192 : index + %c24576_se = arith.constant 24576 : index + %c3_de = arith.constant 3 : index + %c1_dc = arith.constant 1 : index + %c3_vr = arith.constant 3 : index + %c8191_vc = arith.constant 8191 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c8192_c], + strides = [%c24576_se, %c24576_se, %c24576_se, %c8192_c, %c1] + : !pto.tensor_view<1x1x1x3x8192xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c1_dc], + strides = [%c3_de, %c3_de, %c3_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c8191_vc] + : !pto.tensor_view<1x1x1x3x8192xf16> -> !pto.partition_tensor_view<1x1x1x3x8191xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x8191xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_1x16384_1x16381(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c16384_c = arith.constant 16384 : index + %c16384_se = arith.constant 16384 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c16384_c], + strides = [%c16384_se, %c16384_se, %c16384_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x1x16384xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x1x16384xf16> -> !pto.partition_tensor_view<1x1x1x1x16381xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x16381xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_1x32768_1x32761(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c32768_c = arith.constant 32768 : index + %c32768_se = arith.constant 32768 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c32761_vc = arith.constant 32761 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c32768_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c32768_c, %c1] + : !pto.tensor_view<1x1x1x1x32768xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c32761_vc] + : !pto.tensor_view<1x1x1x1x32768xf16> -> !pto.partition_tensor_view<1x1x1x1x32761xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x32761xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMAX_int32_float_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf32> -> !pto.partition_tensor_view<1x1x1x13x13xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> -> !pto.partition_tensor_view<1x1x1x13x1xi32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xi32>) + return + } + + func.func @TROWARGMAX_int32_half_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf16> -> !pto.partition_tensor_view<1x1x1x13x13xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> -> !pto.partition_tensor_view<1x1x1x13x1xi32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xi32>) + return + } + + func.func @TROWARGMAX_uint32_float_3x8_3x3480_3x3473(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c3480_c = arith.constant 3480 : index + %c10440_se = arith.constant 10440 : index + %c24_de = arith.constant 24 : index + %c8_dc = arith.constant 8 : index + %c3_vr = arith.constant 3 : index + %c3473_vc = arith.constant 3473 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c3480_c], + strides = [%c10440_se, %c10440_se, %c10440_se, %c3480_c, %c1] + : !pto.tensor_view<1x1x1x3x3480xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c8_dc], + strides = [%c24_de, %c24_de, %c24_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x3x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c3473_vc] + : !pto.tensor_view<1x1x1x3x3480xf32> -> !pto.partition_tensor_view<1x1x1x3x3473xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x8xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x3473xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_260x8_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c2080_de = arith.constant 2080 : index + %c8_dc = arith.constant 8 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c8_dc], + strides = [%c2080_de, %c2080_de, %c2080_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x260x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf32> -> !pto.partition_tensor_view<1x1x1x260x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x8xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_1023x8_1023x24_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c24_c = arith.constant 24 : index + %c24552_se = arith.constant 24552 : index + %c8184_de = arith.constant 8184 : index + %c8_dc = arith.constant 8 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c24_c], + strides = [%c24552_se, %c24552_se, %c24552_se, %c24_c, %c1] + : !pto.tensor_view<1x1x1x1023x24xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c8_dc], + strides = [%c8184_de, %c8184_de, %c8184_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x1023x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x24xf32> -> !pto.partition_tensor_view<1x1x1x1023x17xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x8xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_3x16_3x3488_3x3473(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c3488_c = arith.constant 3488 : index + %c10464_se = arith.constant 10464 : index + %c48_de = arith.constant 48 : index + %c16_dc = arith.constant 16 : index + %c3_vr = arith.constant 3 : index + %c3473_vc = arith.constant 3473 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c3488_c], + strides = [%c10464_se, %c10464_se, %c10464_se, %c3488_c, %c1] + : !pto.tensor_view<1x1x1x3x3488xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c16_dc], + strides = [%c48_de, %c48_de, %c48_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x3x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c3473_vc] + : !pto.tensor_view<1x1x1x3x3488xf16> -> !pto.partition_tensor_view<1x1x1x3x3473xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x16xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x3473xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_260x16_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c4160_de = arith.constant 4160 : index + %c16_dc = arith.constant 16 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c16_dc], + strides = [%c4160_de, %c4160_de, %c4160_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x260x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf16> -> !pto.partition_tensor_view<1x1x1x260x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x16xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_1023x16_1023x32_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c32_c = arith.constant 32 : index + %c32736_se = arith.constant 32736 : index + %c16368_de = arith.constant 16368 : index + %c16_dc = arith.constant 16 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c32_c], + strides = [%c32736_se, %c32736_se, %c32736_se, %c32_c, %c1] + : !pto.tensor_view<1x1x1x1023x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c16_dc], + strides = [%c16368_de, %c16368_de, %c16368_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x1023x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x32xf16> -> !pto.partition_tensor_view<1x1x1x1023x17xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x16xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/CMakeLists.txt new file mode 100644 index 000000000..a6a8925b5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowargmin) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/cases.py new file mode 100644 index 000000000..2614b130e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/cases.py @@ -0,0 +1,215 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowargmin ST test cases — aligned with pto-isa.""" + +import numpy as np + +CASES = [ + # uint32_dst + float32_src + { + "name": "uint32_float_8x1_8x8_8x8", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (8, 8), + "valid_shape": (8, 8), + "eps": 0, + }, + { + "name": "uint32_float_1024x1_1024x8_1024x8", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1024, 8), + "valid_shape": (1024, 8), + "eps": 0, + }, + { + "name": "uint32_float_16x1_13x16_13x13", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + { + "name": "uint32_float_1024x1_1023x24_1023x17", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1023, 24), + "valid_shape": (1023, 17), + "eps": 0, + }, + { + "name": "uint32_float_8x1_8x64_8x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (8, 64), + "valid_shape": (8, 64), + "eps": 0, + }, + { + "name": "uint32_float_264x1_260x64_260x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_float_8x1_1x128_1x128", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1, 128), + "valid_shape": (1, 128), + "eps": 0, + }, + { + "name": "uint32_float_64x1_32x128_32x128", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "uint32_float_8x1_3x4096_3x4095", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (3, 4096), + "valid_shape": (3, 4095), + "eps": 0, + }, + { + "name": "uint32_float_8x1_2x16384_2x16381", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (2, 16384), + "valid_shape": (2, 16381), + "eps": 0, + }, + # uint32_dst + float16_src + { + "name": "uint32_half_16x1_2x16_2x16", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (2, 16), + "valid_shape": (2, 16), + "eps": 0, + }, + { + "name": "uint32_half_16x1_13x16_13x13", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + { + "name": "uint32_half_272x1_260x64_260x64", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_half_16x1_3x8192_3x8191", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (3, 8192), + "valid_shape": (3, 8191), + "eps": 0, + }, + { + "name": "uint32_half_16x1_1x16384_1x16381", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1, 16384), + "valid_shape": (1, 16381), + "eps": 0, + }, + { + "name": "uint32_half_16x1_1x32768_1x32761", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1, 32768), + "valid_shape": (1, 32761), + "eps": 0, + }, + # int32_dst + float32_src + { + "name": "int32_float_16x1_13x16_13x13", + "dtype": np.float32, + "dst_dtype": np.int32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + # int32_dst + float16_src + { + "name": "int32_half_16x1_13x16_13x13", + "dtype": np.float16, + "dst_dtype": np.int32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + # uint32_dst + float32_src (dst col > 1) + { + "name": "uint32_float_3x8_3x3480_3x3473", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (3, 3480), + "valid_shape": (3, 3473), + "eps": 0, + }, + { + "name": "uint32_float_260x8_260x64_260x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_float_1023x8_1023x24_1023x17", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1023, 24), + "valid_shape": (1023, 17), + "eps": 0, + }, + # uint32_dst + float16_src (dst col > 1) + { + "name": "uint32_half_3x16_3x3488_3x3473", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (3, 3488), + "valid_shape": (3, 3473), + "eps": 0, + }, + { + "name": "uint32_half_260x16_260x64_260x64", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_half_1023x16_1023x32_1023x17", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1023, 32), + "valid_shape": (1023, 17), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/compare.py new file mode 100644 index 000000000..4cd015fd3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr, 1) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dst_dtype"], count=np.prod(out_shape)).reshape(out_shape) + + output_full = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dst_dtype"]) + dst_cols = len(output_full) // vr + output = output_full.reshape(vr, dst_cols)[:, 0:1] + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/gen_data.py new file mode 100644 index 000000000..6c103094c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_dtype = case["dst_dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if dtype in (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32): + dtype_info = np.iinfo(dtype) + input1 = np.random.randint(dtype_info.min, dtype_info.max, size=shape).astype(dtype) + else: + dtype_info = np.finfo(dtype) + input1 = np.random.uniform(low=dtype_info.min, high=dtype_info.max, size=shape).astype(dtype) + + out_shape = (valid_shape[0], 1) + golden = np.zeros(out_shape, dtype=dst_dtype) + golden[:, 0:1] = np.argmin(input1[:, :valid_shape[1]], axis=1, keepdims=True).astype(dst_dtype) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/launch.cpp new file mode 100644 index 000000000..d87134237 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/launch.cpp @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_8x8_8x8(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_8x8_8x8(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_8x8_8x8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_1024x1_1024x8_1024x8(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_1024x1_1024x8_1024x8(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_1024x1_1024x8_1024x8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_16x1_13x16_13x13(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_16x1_13x16_13x13(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_1024x1_1023x24_1023x17(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_1024x1_1023x24_1023x17(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_1024x1_1023x24_1023x17<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_8x64_8x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_8x64_8x64(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_8x64_8x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_264x1_260x64_260x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_264x1_260x64_260x64(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_264x1_260x64_260x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_1x128_1x128(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_1x128_1x128(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_1x128_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_64x1_32x128_32x128(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_64x1_32x128_32x128(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_64x1_32x128_32x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_3x4096_3x4095(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_3x4096_3x4095(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_3x4096_3x4095<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_2x16384_2x16381(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_2x16384_2x16381(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_2x16384_2x16381<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_2x16_2x16(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_2x16_2x16(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_13x16_13x13(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_13x16_13x13(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_272x1_260x64_260x64(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_272x1_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_272x1_260x64_260x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_3x8192_3x8191(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_3x8192_3x8191(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_3x8192_3x8191<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_1x16384_1x16381(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_1x16384_1x16381(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_1x16384_1x16381<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_1x32768_1x32761(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_1x32768_1x32761(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_1x32768_1x32761<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_int32_float_16x1_13x16_13x13(__gm__ float *src, __gm__ int32_t *dst); +void LaunchTROWARGMIN_int32_float_16x1_13x16_13x13(float *src, int32_t *dst, void *stream) { + TROWARGMIN_int32_float_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_int32_half_16x1_13x16_13x13(__gm__ uint16_t *src, __gm__ int32_t *dst); +void LaunchTROWARGMIN_int32_half_16x1_13x16_13x13(uint16_t *src, int32_t *dst, void *stream) { + TROWARGMIN_int32_half_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_3x8_3x3480_3x3473(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_3x8_3x3480_3x3473(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_3x8_3x3480_3x3473<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_260x8_260x64_260x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_260x8_260x64_260x64(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_260x8_260x64_260x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_1023x8_1023x24_1023x17(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_1023x8_1023x24_1023x17(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_1023x8_1023x24_1023x17<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_3x16_3x3488_3x3473(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_3x16_3x3488_3x3473(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_3x16_3x3488_3x3473<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_260x16_260x64_260x64(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_260x16_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_260x16_260x64_260x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_1023x16_1023x32_1023x17(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_1023x16_1023x32_1023x17(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_1023x16_1023x32_1023x17<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/main.cpp new file mode 100644 index 000000000..997db22a2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/main.cpp @@ -0,0 +1,206 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowargmin ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWARGMIN_uint32_float_8x1_8x8_8x8(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_1024x1_1024x8_1024x8(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_16x1_13x16_13x13(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_1024x1_1023x24_1023x17(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_8x1_8x64_8x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_264x1_260x64_260x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_8x1_1x128_1x128(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_64x1_32x128_32x128(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_8x1_3x4096_3x4095(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_8x1_2x16384_2x16381(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_2x16_2x16(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_13x16_13x13(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_272x1_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_3x8192_3x8191(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_1x16384_1x16381(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_1x32768_1x32761(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_int32_float_16x1_13x16_13x13(float *src, int32_t *dst, void *stream); +void LaunchTROWARGMIN_int32_half_16x1_13x16_13x13(uint16_t *src, int32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_3x8_3x3480_3x3473(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_260x8_260x64_260x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_1023x8_1023x24_1023x17(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_3x16_3x3488_3x3473(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_260x16_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_1023x16_1023x32_1023x17(uint16_t *src, uint32_t *dst, void *stream); + +using LaunchFnF32U32 = void (*)(float *, uint32_t *, void *); +using LaunchFnF16U32 = void (*)(uint16_t *, uint32_t *, void *); +using LaunchFnF32S32 = void (*)(float *, int32_t *, void *); +using LaunchFnF16S32 = void (*)(uint16_t *, int32_t *, void *); + +enum class DType { F32U32, F16U32, F32S32, F16S32 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32U32 launchF32U32; + LaunchFnF16U32 launchF16U32; + LaunchFnF32S32 launchF32S32; + LaunchFnF16S32 launchF16S32; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t srcElemSize; // bytes per src element + size_t dstElemSize; // bytes per dst element + size_t dstCols; // dst tile cols +}; + +static const TestCase kCases[] = { + {"uint32_float_8x1_8x8_8x8", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_8x8_8x8, 8, 8, 8, 8, 4, 4, 1}, + {"uint32_float_1024x1_1024x8_1024x8", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_1024x1_1024x8_1024x8, 1024, 8, 1024, 8, 4, 4, 1}, + {"uint32_float_16x1_13x16_13x13", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_16x1_13x16_13x13, 13, 16, 13, 13, 4, 4, 1}, + {"uint32_float_1024x1_1023x24_1023x17", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_1024x1_1023x24_1023x17, 1023, 24, 1023, 17, 4, 4, 1}, + {"uint32_float_8x1_8x64_8x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_8x64_8x64, 8, 64, 8, 64, 4, 4, 1}, + {"uint32_float_264x1_260x64_260x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_264x1_260x64_260x64, 260, 64, 260, 64, 4, 4, 1}, + {"uint32_float_8x1_1x128_1x128", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_1x128_1x128, 1, 128, 1, 128, 4, 4, 1}, + {"uint32_float_64x1_32x128_32x128", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_64x1_32x128_32x128, 32, 128, 32, 128, 4, 4, 1}, + {"uint32_float_8x1_3x4096_3x4095", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_3x4096_3x4095, 3, 4096, 3, 4095, 4, 4, 1}, + {"uint32_float_8x1_2x16384_2x16381", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_2x16384_2x16381, 2, 16384, 2, 16381, 4, 4, 1}, + {"uint32_half_16x1_2x16_2x16", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_2x16_2x16, 2, 16, 2, 16, 2, 4, 1}, + {"uint32_half_16x1_13x16_13x13", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_13x16_13x13, 13, 16, 13, 13, 2, 4, 1}, + {"uint32_half_272x1_260x64_260x64", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_272x1_260x64_260x64, 260, 64, 260, 64, 2, 4, 1}, + {"uint32_half_16x1_3x8192_3x8191", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_3x8192_3x8191, 3, 8192, 3, 8191, 2, 4, 1}, + {"uint32_half_16x1_1x16384_1x16381", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_1x16384_1x16381, 1, 16384, 1, 16381, 2, 4, 1}, + {"uint32_half_16x1_1x32768_1x32761", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_1x32768_1x32761, 1, 32768, 1, 32761, 2, 4, 1}, + {"int32_float_16x1_13x16_13x13", DType::F32S32, .launchF32S32 = LaunchTROWARGMIN_int32_float_16x1_13x16_13x13, 13, 16, 13, 13, 4, 4, 1}, + {"int32_half_16x1_13x16_13x13", DType::F16S32, .launchF16S32 = LaunchTROWARGMIN_int32_half_16x1_13x16_13x13, 13, 16, 13, 13, 2, 4, 1}, + {"uint32_float_3x8_3x3480_3x3473", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_3x8_3x3480_3x3473, 3, 3480, 3, 3473, 4, 4, 8}, + {"uint32_float_260x8_260x64_260x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_260x8_260x64_260x64, 260, 64, 260, 64, 4, 4, 8}, + {"uint32_float_1023x8_1023x24_1023x17", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_1023x8_1023x24_1023x17, 1023, 24, 1023, 17, 4, 4, 8}, + {"uint32_half_3x16_3x3488_3x3473", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_3x16_3x3488_3x3473, 3, 3488, 3, 3473, 2, 4, 16}, + {"uint32_half_260x16_260x64_260x64", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_260x16_260x64_260x64, 260, 64, 260, 64, 2, 4, 16}, + {"uint32_half_1023x16_1023x32_1023x17", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_1023x16_1023x32_1023x17, 1023, 32, 1023, 17, 2, 4, 16}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.srcElemSize; + const size_t dstElemCount = tc.validRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32U32: + tc.launchF32U32((float *)src0Device, (uint32_t *)dstDevice, stream); + break; + case DType::F16U32: + tc.launchF16U32((uint16_t *)src0Device, (uint32_t *)dstDevice, stream); + break; + case DType::F32S32: + tc.launchF32S32((float *)src0Device, (int32_t *)dstDevice, stream); + break; + case DType::F16S32: + tc.launchF16S32((uint16_t *)src0Device, (int32_t *)dstDevice, stream); + break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0) { + mkdir(caseDir.c_str(), 0755); + if (!WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowargmin [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/trowargmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/trowargmin.pto new file mode 100644 index 000000000..25bc9c751 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/trowargmin.pto @@ -0,0 +1,1205 @@ +// Auto-generated trowargmin ST testcases + +module { + + func.func @TROWARGMIN_uint32_float_8x1_8x8_8x8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8_r = arith.constant 8 : index + %c8_c = arith.constant 8 : index + %c64_se = arith.constant 64 : index + %c8_de = arith.constant 8 : index + %c1_dc = arith.constant 1 : index + %c8_vr = arith.constant 8 : index + %c8_vc = arith.constant 8 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c8_c], + strides = [%c64_se, %c64_se, %c64_se, %c8_c, %c1] + : !pto.tensor_view<1x1x1x8x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c1_dc], + strides = [%c8_de, %c8_de, %c8_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c8_vc] + : !pto.tensor_view<1x1x1x8x8xf32> -> !pto.partition_tensor_view<1x1x1x8x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> -> !pto.partition_tensor_view<1x1x1x8x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x8xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_1024x1_1024x8_1024x8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024_r = arith.constant 1024 : index + %c8_c = arith.constant 8 : index + %c8192_se = arith.constant 8192 : index + %c1024_de = arith.constant 1024 : index + %c1_dc = arith.constant 1 : index + %c1024_vr = arith.constant 1024 : index + %c8_vc = arith.constant 8 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1024_r, %c8_c], + strides = [%c8192_se, %c8192_se, %c8192_se, %c8_c, %c1] + : !pto.tensor_view<1x1x1x1024x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1024_r, %c1_dc], + strides = [%c1024_de, %c1024_de, %c1024_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1024x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1024_vr, %c8_vc] + : !pto.tensor_view<1x1x1x1024x8xf32> -> !pto.partition_tensor_view<1x1x1x1024x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1024_vr, %c1] + : !pto.tensor_view<1x1x1x1024x1xui32> -> !pto.partition_tensor_view<1x1x1x1024x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1024x8xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1024x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf32> -> !pto.partition_tensor_view<1x1x1x13x13xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> -> !pto.partition_tensor_view<1x1x1x13x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_1024x1_1023x24_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c24_c = arith.constant 24 : index + %c24552_se = arith.constant 24552 : index + %c1023_de = arith.constant 1023 : index + %c1_dc = arith.constant 1 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c24_c], + strides = [%c24552_se, %c24552_se, %c24552_se, %c24_c, %c1] + : !pto.tensor_view<1x1x1x1023x24xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c1_dc], + strides = [%c1023_de, %c1023_de, %c1023_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1023x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x24xf32> -> !pto.partition_tensor_view<1x1x1x1023x17xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x1xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_8x1_8x64_8x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8_r = arith.constant 8 : index + %c64_c = arith.constant 64 : index + %c512_se = arith.constant 512 : index + %c8_de = arith.constant 8 : index + %c1_dc = arith.constant 1 : index + %c8_vr = arith.constant 8 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c64_c], + strides = [%c512_se, %c512_se, %c512_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c1_dc], + strides = [%c8_de, %c8_de, %c8_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c64_vc] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> -> !pto.partition_tensor_view<1x1x1x8x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_264x1_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c260_de = arith.constant 260 : index + %c1_dc = arith.constant 1 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c1_dc], + strides = [%c260_de, %c260_de, %c260_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf32> -> !pto.partition_tensor_view<1x1x1x260x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_8x1_1x128_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c128_c = arith.constant 128 : index + %c128_se = arith.constant 128 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c128_vc = arith.constant 128 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c128_c], + strides = [%c128_se, %c128_se, %c128_se, %c128_c, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c128_vc] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_64x1_32x128_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32_r = arith.constant 32 : index + %c128_c = arith.constant 128 : index + %c4096_se = arith.constant 4096 : index + %c32_de = arith.constant 32 : index + %c1_dc = arith.constant 1 : index + %c32_vr = arith.constant 32 : index + %c128_vc = arith.constant 128 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32_r, %c128_c], + strides = [%c4096_se, %c4096_se, %c4096_se, %c128_c, %c1] + : !pto.tensor_view<1x1x1x32x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32_r, %c1_dc], + strides = [%c32_de, %c32_de, %c32_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x32x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32_vr, %c128_vc] + : !pto.tensor_view<1x1x1x32x128xf32> -> !pto.partition_tensor_view<1x1x1x32x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32_vr, %c1] + : !pto.tensor_view<1x1x1x32x1xui32> -> !pto.partition_tensor_view<1x1x1x32x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_8x1_3x4096_3x4095(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c4096_c = arith.constant 4096 : index + %c12288_se = arith.constant 12288 : index + %c3_de = arith.constant 3 : index + %c1_dc = arith.constant 1 : index + %c3_vr = arith.constant 3 : index + %c4095_vc = arith.constant 4095 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c4096_c], + strides = [%c12288_se, %c12288_se, %c12288_se, %c4096_c, %c1] + : !pto.tensor_view<1x1x1x3x4096xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c1_dc], + strides = [%c3_de, %c3_de, %c3_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c4095_vc] + : !pto.tensor_view<1x1x1x3x4096xf32> -> !pto.partition_tensor_view<1x1x1x3x4095xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x4095xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_8x1_2x16384_2x16381(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16384_c = arith.constant 16384 : index + %c32768_se = arith.constant 32768 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16384_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x2x16384xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x2x16384xf32> -> !pto.partition_tensor_view<1x1x1x2x16381xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16381xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_2x16_2x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16_c = arith.constant 16 : index + %c32_se = arith.constant 32 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16_vc = arith.constant 16 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16_c], + strides = [%c32_se, %c32_se, %c32_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x2x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16_vc] + : !pto.tensor_view<1x1x1x2x16xf16> -> !pto.partition_tensor_view<1x1x1x2x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf16> -> !pto.partition_tensor_view<1x1x1x13x13xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> -> !pto.partition_tensor_view<1x1x1x13x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_272x1_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c260_de = arith.constant 260 : index + %c1_dc = arith.constant 1 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c1_dc], + strides = [%c260_de, %c260_de, %c260_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf16> -> !pto.partition_tensor_view<1x1x1x260x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_3x8192_3x8191(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c8192_c = arith.constant 8192 : index + %c24576_se = arith.constant 24576 : index + %c3_de = arith.constant 3 : index + %c1_dc = arith.constant 1 : index + %c3_vr = arith.constant 3 : index + %c8191_vc = arith.constant 8191 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c8192_c], + strides = [%c24576_se, %c24576_se, %c24576_se, %c8192_c, %c1] + : !pto.tensor_view<1x1x1x3x8192xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c1_dc], + strides = [%c3_de, %c3_de, %c3_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c8191_vc] + : !pto.tensor_view<1x1x1x3x8192xf16> -> !pto.partition_tensor_view<1x1x1x3x8191xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x8191xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_1x16384_1x16381(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c16384_c = arith.constant 16384 : index + %c16384_se = arith.constant 16384 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c16384_c], + strides = [%c16384_se, %c16384_se, %c16384_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x1x16384xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x1x16384xf16> -> !pto.partition_tensor_view<1x1x1x1x16381xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x16381xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_1x32768_1x32761(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c32768_c = arith.constant 32768 : index + %c32768_se = arith.constant 32768 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c32761_vc = arith.constant 32761 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c32768_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c32768_c, %c1] + : !pto.tensor_view<1x1x1x1x32768xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c32761_vc] + : !pto.tensor_view<1x1x1x1x32768xf16> -> !pto.partition_tensor_view<1x1x1x1x32761xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x32761xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMIN_int32_float_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf32> -> !pto.partition_tensor_view<1x1x1x13x13xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> -> !pto.partition_tensor_view<1x1x1x13x1xi32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xi32>) + return + } + + func.func @TROWARGMIN_int32_half_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf16> -> !pto.partition_tensor_view<1x1x1x13x13xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> -> !pto.partition_tensor_view<1x1x1x13x1xi32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xi32>) + return + } + + func.func @TROWARGMIN_uint32_float_3x8_3x3480_3x3473(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c3480_c = arith.constant 3480 : index + %c10440_se = arith.constant 10440 : index + %c24_de = arith.constant 24 : index + %c8_dc = arith.constant 8 : index + %c3_vr = arith.constant 3 : index + %c3473_vc = arith.constant 3473 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c3480_c], + strides = [%c10440_se, %c10440_se, %c10440_se, %c3480_c, %c1] + : !pto.tensor_view<1x1x1x3x3480xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c8_dc], + strides = [%c24_de, %c24_de, %c24_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x3x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c3473_vc] + : !pto.tensor_view<1x1x1x3x3480xf32> -> !pto.partition_tensor_view<1x1x1x3x3473xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x8xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x3473xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_260x8_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c2080_de = arith.constant 2080 : index + %c8_dc = arith.constant 8 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c8_dc], + strides = [%c2080_de, %c2080_de, %c2080_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x260x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf32> -> !pto.partition_tensor_view<1x1x1x260x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x8xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_1023x8_1023x24_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c24_c = arith.constant 24 : index + %c24552_se = arith.constant 24552 : index + %c8184_de = arith.constant 8184 : index + %c8_dc = arith.constant 8 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c24_c], + strides = [%c24552_se, %c24552_se, %c24552_se, %c24_c, %c1] + : !pto.tensor_view<1x1x1x1023x24xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c8_dc], + strides = [%c8184_de, %c8184_de, %c8184_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x1023x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x24xf32> -> !pto.partition_tensor_view<1x1x1x1023x17xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x8xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_3x16_3x3488_3x3473(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c3488_c = arith.constant 3488 : index + %c10464_se = arith.constant 10464 : index + %c48_de = arith.constant 48 : index + %c16_dc = arith.constant 16 : index + %c3_vr = arith.constant 3 : index + %c3473_vc = arith.constant 3473 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c3488_c], + strides = [%c10464_se, %c10464_se, %c10464_se, %c3488_c, %c1] + : !pto.tensor_view<1x1x1x3x3488xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c16_dc], + strides = [%c48_de, %c48_de, %c48_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x3x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c3473_vc] + : !pto.tensor_view<1x1x1x3x3488xf16> -> !pto.partition_tensor_view<1x1x1x3x3473xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x16xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x3473xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_260x16_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c4160_de = arith.constant 4160 : index + %c16_dc = arith.constant 16 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c16_dc], + strides = [%c4160_de, %c4160_de, %c4160_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x260x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf16> -> !pto.partition_tensor_view<1x1x1x260x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x16xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_1023x16_1023x32_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c32_c = arith.constant 32 : index + %c32736_se = arith.constant 32736 : index + %c16368_de = arith.constant 16368 : index + %c16_dc = arith.constant 16 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c32_c], + strides = [%c32736_se, %c32736_se, %c32736_se, %c32_c, %c1] + : !pto.tensor_view<1x1x1x1023x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c16_dc], + strides = [%c16368_de, %c16368_de, %c16368_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x1023x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x32xf16> -> !pto.partition_tensor_view<1x1x1x1023x17xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x16xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/CMakeLists.txt new file mode 100644 index 000000000..62291cfb6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowmax) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/cases.py new file mode 100644 index 000000000..f6db377f2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/cases.py @@ -0,0 +1,224 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowmax ST test cases. + +Aligned with pto-isa tests/npu/a2a3/src/st/testcase/trowmax (28 cases). +""" + +import numpy as np + +CASES = [ + # f32 cases (case1-case5 from pto-isa) + { + "name": "f32_127x64_valid127x63", + "dtype": np.float32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 1e-5, + }, + { + "name": "f32_63x64", + "dtype": np.float32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-5, + }, + { + "name": "f32_31x128_valid31x127", + "dtype": np.float32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 1e-5, + }, + { + "name": "f32_15x192", + "dtype": np.float32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 1e-5, + }, + { + "name": "f32_7x448_valid7x447", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 1e-5, + }, + # f16 case (case6 from pto-isa) + { + "name": "f16_256x16_valid256x15", + "dtype": np.float16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 1e-2, + }, + # f32 more cases (case7-case14 from pto-isa) + { + "name": "f32_30x216", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (30, 216), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid30x24", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (30, 24), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid11x216", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (11, 216), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid11x24", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (11, 24), + "eps": 1e-5, + }, + { + "name": "f32_238x40", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (238, 40), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid238x16", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (238, 16), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid121x40", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (121, 40), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid121x16", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (121, 16), + "eps": 1e-5, + }, + # f32 DN dst cases (case15-case18 from pto-isa) + { + "name": "f32_64x128", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-5, + }, + { + "name": "f32_32x256", + "dtype": np.float32, + "shape": (32, 256), + "valid_shape": (32, 256), + "eps": 1e-5, + }, + { + "name": "f32_16x512", + "dtype": np.float32, + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-5, + }, + { + "name": "f32_8x1024", + "dtype": np.float32, + "shape": (8, 1024), + "valid_shape": (8, 1024), + "eps": 1e-5, + }, + + # int32 cases (case19-case23 from pto-isa) + { + "name": "i32_127x64_valid127x63", + "dtype": np.int32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 0, + }, + { + "name": "i32_63x64", + "dtype": np.int32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128_valid31x127", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, + { + "name": "i32_15x192", + "dtype": np.int32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "i32_7x448_valid7x447", + "dtype": np.int32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 0, + }, + + # int16 cases (case24-case28 from pto-isa) + { + "name": "i16_128x64", + "dtype": np.int16, + "shape": (128, 64), + "valid_shape": (128, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_32x128", + "dtype": np.int16, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "i16_16x192", + "dtype": np.int16, + "shape": (16, 192), + "valid_shape": (16, 192), + "eps": 0, + }, + { + "name": "i16_8x448", + "dtype": np.int16, + "shape": (8, 448), + "valid_shape": (8, 448), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/compare.py new file mode 100644 index 000000000..12d4207bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr,) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/gen_data.py new file mode 100644 index 000000000..97495c982 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/gen_data.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if np.issubdtype(dtype, np.integer): + if dtype == np.int32: + input1 = np.random.randint(low=-100, high=100, size=shape).astype(dtype) + else: + input1 = np.random.randint(low=-50, high=50, size=shape).astype(dtype) + else: + input1 = np.random.uniform(low=-16, high=16, size=shape).astype(dtype) + + out_shape = (valid_shape[0],) + golden = np.zeros(out_shape, dtype=dtype) + vr, vc = valid_shape + for i in range(vr): + golden[i] = np.max(input1[i, :vc]) + + golden = golden.astype(dtype, copy=False) + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/launch.cpp new file mode 100644 index 000000000..a5d840da9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/launch.cpp @@ -0,0 +1,183 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWMAX_f32_127x64_valid127x63(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_127x64_valid127x63(float *src, float *dst, void *stream) { + TROWMAX_f32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_63x64(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_63x64(float *src, float *dst, void *stream) { + TROWMAX_f32_63x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_31x128_valid31x127(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_31x128_valid31x127(float *src, float *dst, void *stream) { + TROWMAX_f32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_15x192(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_15x192(float *src, float *dst, void *stream) { + TROWMAX_f32_15x192<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_7x448_valid7x447(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_7x448_valid7x447(float *src, float *dst, void *stream) { + TROWMAX_f32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f16_256x16_valid256x15(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTROWMAX_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream) { + TROWMAX_f16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_30x216(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_30x216(float *src, float *dst, void *stream) { + TROWMAX_f32_30x216<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_30x216_valid30x24(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_30x216_valid30x24(float *src, float *dst, void *stream) { + TROWMAX_f32_30x216_valid30x24<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_30x216_valid11x216(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_30x216_valid11x216(float *src, float *dst, void *stream) { + TROWMAX_f32_30x216_valid11x216<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_30x216_valid11x24(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_30x216_valid11x24(float *src, float *dst, void *stream) { + TROWMAX_f32_30x216_valid11x24<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_238x40(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_238x40(float *src, float *dst, void *stream) { + TROWMAX_f32_238x40<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_238x40_valid238x16(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_238x40_valid238x16(float *src, float *dst, void *stream) { + TROWMAX_f32_238x40_valid238x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_238x40_valid121x40(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_238x40_valid121x40(float *src, float *dst, void *stream) { + TROWMAX_f32_238x40_valid121x40<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_238x40_valid121x16(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_238x40_valid121x16(float *src, float *dst, void *stream) { + TROWMAX_f32_238x40_valid121x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_64x128(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_64x128(float *src, float *dst, void *stream) { + TROWMAX_f32_64x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_32x256(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_32x256(float *src, float *dst, void *stream) { + TROWMAX_f32_32x256<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_16x512(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_16x512(float *src, float *dst, void *stream) { + TROWMAX_f32_16x512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_8x1024(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_8x1024(float *src, float *dst, void *stream) { + TROWMAX_f32_8x1024<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// int32 cases +extern "C" __global__ AICORE void TROWMAX_i32_127x64_valid127x63(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i32_63x64(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_63x64(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_63x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i32_31x128_valid31x127(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i32_15x192(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_15x192(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_15x192<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i32_7x448_valid7x447(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// int16 cases +extern "C" __global__ AICORE void TROWMAX_i16_128x64(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_128x64(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_128x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i16_64x64(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_64x64(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i16_32x128(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_32x128(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_32x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i16_16x192(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_16x192(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_16x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i16_8x448(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_8x448(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_8x448<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/main.cpp new file mode 100644 index 000000000..b89132b41 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/main.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowmax ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWMAX_f32_127x64_valid127x63(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_63x64(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_31x128_valid31x127(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_15x192(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_7x448_valid7x447(float *src, float *dst, void *stream); +void LaunchTROWMAX_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTROWMAX_f32_30x216(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_30x216_valid30x24(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_30x216_valid11x216(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_30x216_valid11x24(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_238x40(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_238x40_valid238x16(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_238x40_valid121x40(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_238x40_valid121x16(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_64x128(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_32x256(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_16x512(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_8x1024(float *src, float *dst, void *stream); +void LaunchTROWMAX_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i32_63x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i32_15x192(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i16_128x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMAX_i16_64x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMAX_i16_32x128(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMAX_i16_16x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMAX_i16_8x448(int16_t *src, int16_t *dst, void *stream); + +using LaunchFnF32 = void (*)(float *, float *, void *); +using LaunchFnF16 = void (*)(uint16_t *, uint16_t *, void *); +using LaunchFnI32 = void (*)(int32_t *, int32_t *, void *); +using LaunchFnI16 = void (*)(int16_t *, int16_t *, void *); + +enum class DType { F32, F16, I32, I16 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI32 launchI32; + LaunchFnI16 launchI16; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_127x64_valid127x63", DType::F32, .launchF32 = LaunchTROWMAX_f32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"f32_63x64", DType::F32, .launchF32 = LaunchTROWMAX_f32_63x64, 63, 64, 63, 64, 4}, + {"f32_31x128_valid31x127", DType::F32, .launchF32 = LaunchTROWMAX_f32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"f32_15x192", DType::F32, .launchF32 = LaunchTROWMAX_f32_15x192, 15, 192, 15, 192, 4}, + {"f32_7x448_valid7x447", DType::F32, .launchF32 = LaunchTROWMAX_f32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // f16 case + {"f16_256x16_valid256x15", DType::F16, .launchF16 = LaunchTROWMAX_f16_256x16_valid256x15, 256, 16, 256, 15, 2}, + // f32 more cases + {"f32_30x216", DType::F32, .launchF32 = LaunchTROWMAX_f32_30x216, 30, 216, 30, 216, 4}, + {"f32_30x216_valid30x24", DType::F32, .launchF32 = LaunchTROWMAX_f32_30x216_valid30x24, 30, 216, 30, 24, 4}, + {"f32_30x216_valid11x216", DType::F32, .launchF32 = LaunchTROWMAX_f32_30x216_valid11x216, 30, 216, 11, 216, 4}, + {"f32_30x216_valid11x24", DType::F32, .launchF32 = LaunchTROWMAX_f32_30x216_valid11x24, 30, 216, 11, 24, 4}, + {"f32_238x40", DType::F32, .launchF32 = LaunchTROWMAX_f32_238x40, 238, 40, 238, 40, 4}, + {"f32_238x40_valid238x16", DType::F32, .launchF32 = LaunchTROWMAX_f32_238x40_valid238x16, 238, 40, 238, 16, 4}, + {"f32_238x40_valid121x40", DType::F32, .launchF32 = LaunchTROWMAX_f32_238x40_valid121x40, 238, 40, 121, 40, 4}, + {"f32_238x40_valid121x16", DType::F32, .launchF32 = LaunchTROWMAX_f32_238x40_valid121x16, 238, 40, 121, 16, 4}, + // f32 DN dst cases + {"f32_64x128", DType::F32, .launchF32 = LaunchTROWMAX_f32_64x128, 64, 128, 64, 128, 4}, + {"f32_32x256", DType::F32, .launchF32 = LaunchTROWMAX_f32_32x256, 32, 256, 32, 256, 4}, + {"f32_16x512", DType::F32, .launchF32 = LaunchTROWMAX_f32_16x512, 16, 512, 16, 512, 4}, + {"f32_8x1024", DType::F32, .launchF32 = LaunchTROWMAX_f32_8x1024, 8, 1024,8, 1024,4}, + // int32 cases + {"i32_127x64_valid127x63", DType::I32, .launchI32 = LaunchTROWMAX_i32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"i32_63x64", DType::I32, .launchI32 = LaunchTROWMAX_i32_63x64, 63, 64, 63, 64, 4}, + {"i32_31x128_valid31x127", DType::I32, .launchI32 = LaunchTROWMAX_i32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"i32_15x192", DType::I32, .launchI32 = LaunchTROWMAX_i32_15x192, 15, 192, 15, 192, 4}, + {"i32_7x448_valid7x447", DType::I32, .launchI32 = LaunchTROWMAX_i32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // int16 cases + {"i16_128x64", DType::I16, .launchI16 = LaunchTROWMAX_i16_128x64, 128, 64, 128, 64, 2}, + {"i16_64x64", DType::I16, .launchI16 = LaunchTROWMAX_i16_64x64, 64, 64, 64, 64, 2}, + {"i16_32x128", DType::I16, .launchI16 = LaunchTROWMAX_i16_32x128, 32, 128, 32, 128, 2}, + {"i16_16x192", DType::I16, .launchI16 = LaunchTROWMAX_i16_16x192, 16, 192, 16, 192, 2}, + {"i16_8x448", DType::I16, .launchI16 = LaunchTROWMAX_i16_8x448, 8, 448, 8, 448, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.validRows * 1; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32: tc.launchF32((float *)src0Device, (float *)dstDevice, stream); break; + case DType::F16: tc.launchF16((uint16_t *)src0Device, (uint16_t *)dstDevice, stream); break; + case DType::I32: tc.launchI32((int32_t *)src0Device, (int32_t *)dstDevice, stream); break; + case DType::I16: tc.launchI16((int16_t *)src0Device, (int16_t *)dstDevice, stream); break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowmax [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/trowmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/trowmax.pto new file mode 100644 index 000000000..4b658ee84 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/trowmax.pto @@ -0,0 +1,1545 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowmax: tload(src) + trowmax(src, tmp)->dst + tstore(dst). + +module { + + // Case 0: f32 127x64 (valid=127x63) + func.func @TROWMAX_f32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xf32> -> !pto.partition_tensor_view<1x1x1x127x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> -> !pto.partition_tensor_view<1x1x1x127x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xf32>) + return + } + + // Case 1: f32 63x64 (valid=63x64) + func.func @TROWMAX_f32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf32> -> !pto.partition_tensor_view<1x1x1x63x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> -> !pto.partition_tensor_view<1x1x1x63x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xf32>) + return + } + + // Case 2: f32 31x128 (valid=31x127) + func.func @TROWMAX_f32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xf32> -> !pto.partition_tensor_view<1x1x1x31x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> -> !pto.partition_tensor_view<1x1x1x31x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xf32>) + return + } + + // Case 3: f32 15x192 (valid=15x192) + func.func @TROWMAX_f32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xf32> -> !pto.partition_tensor_view<1x1x1x15x192xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> -> !pto.partition_tensor_view<1x1x1x15x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xf32>) + return + } + + // Case 4: f32 7x448 (valid=7x447) + func.func @TROWMAX_f32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x447xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> -> !pto.partition_tensor_view<1x1x1x7x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xf32>) + return + } + + // Case 5: f16 256x16 (valid=256x15) + func.func @TROWMAX_f16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xf16> -> !pto.partition_tensor_view<1x1x1x256x15xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> -> !pto.partition_tensor_view<1x1x1x256x1xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xf16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xf16>) + return + } + + // Case 6: f32 30x216 (valid=30x216) + func.func @TROWMAX_f32_30x216(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c30, %c1], + strides = [%c30, %c30, %c30, %c1, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c216] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x30x216xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> -> !pto.partition_tensor_view<1x1x1x30x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x30x216xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x30x1xf32>) + return + } + + // Case 7: f32 30x216 (valid=30x24) + func.func @TROWMAX_f32_30x216_valid30x24(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c24 = arith.constant 24 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c30, %c1], + strides = [%c30, %c30, %c30, %c1, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c24] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x30x24xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> -> !pto.partition_tensor_view<1x1x1x30x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x30x24xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x30x1xf32>) + return + } + + // Case 8: f32 30x216 (valid=11x216) + func.func @TROWMAX_f32_30x216_valid11x216(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c11 = arith.constant 11 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c11, %c1], + strides = [%c11, %c11, %c11, %c1, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c216] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x11x216xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> -> !pto.partition_tensor_view<1x1x1x11x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x11x216xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x11x1xf32>) + return + } + + // Case 9: f32 30x216 (valid=11x24) + func.func @TROWMAX_f32_30x216_valid11x24(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c11 = arith.constant 11 : index + %c24 = arith.constant 24 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c11, %c1], + strides = [%c11, %c11, %c11, %c1, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c24] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x11x24xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> -> !pto.partition_tensor_view<1x1x1x11x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x11x24xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x11x1xf32>) + return + } + + // Case 10: f32 238x40 (valid=238x40) + func.func @TROWMAX_f32_238x40(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c238, %c1], + strides = [%c238, %c238, %c238, %c1, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c40] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x238x40xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> -> !pto.partition_tensor_view<1x1x1x238x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x238x40xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x238x1xf32>) + return + } + + // Case 11: f32 238x40 (valid=238x16) + func.func @TROWMAX_f32_238x40_valid238x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c238, %c1], + strides = [%c238, %c238, %c238, %c1, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c16] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x238x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> -> !pto.partition_tensor_view<1x1x1x238x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x238x16xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x238x1xf32>) + return + } + + // Case 12: f32 238x40 (valid=121x40) + func.func @TROWMAX_f32_238x40_valid121x40(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c121 = arith.constant 121 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c121, %c1], + strides = [%c121, %c121, %c121, %c1, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c40] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x121x40xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> -> !pto.partition_tensor_view<1x1x1x121x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x121x40xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x121x1xf32>) + return + } + + // Case 13: f32 238x40 (valid=121x16) + func.func @TROWMAX_f32_238x40_valid121x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %c121 = arith.constant 121 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c121, %c1], + strides = [%c121, %c121, %c121, %c1, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c16] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x121x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> -> !pto.partition_tensor_view<1x1x1x121x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x121x16xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x121x1xf32>) + return + } + + // Case 14: f32 64x128 (valid=64x128) + func.func @TROWMAX_f32_64x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> -> !pto.partition_tensor_view<1x1x1x64x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xf32>) + return + } + + // Case 15: f32 32x256 (valid=32x256) + func.func @TROWMAX_f32_32x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c256] + : !pto.tensor_view<1x1x1x32x256xf32> -> !pto.partition_tensor_view<1x1x1x32x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> -> !pto.partition_tensor_view<1x1x1x32x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x256xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xf32>) + return + } + + // Case 16: f32 16x512 (valid=16x512) + func.func @TROWMAX_f32_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf32> -> !pto.partition_tensor_view<1x1x1x16x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> -> !pto.partition_tensor_view<1x1x1x16x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x512xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xf32>) + return + } + + // Case 17: f32 8x1024 (valid=8x1024) + func.func @TROWMAX_f32_8x1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c1024], + strides = [%c8192, %c8192, %c8192, %c1024, %c1] + : !pto.tensor_view<1x1x1x8x1024xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1024] + : !pto.tensor_view<1x1x1x8x1024xf32> -> !pto.partition_tensor_view<1x1x1x8x1024xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> -> !pto.partition_tensor_view<1x1x1x8x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x1024xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xf32>) + return + } + + // ======================================================================== + // int32 cases (case19-case23) + // ======================================================================== + + // case19: i32 127x64 valid=127x63 + func.func @TROWMAX_i32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xi32> -> !pto.partition_tensor_view<1x1x1x127x63xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> -> !pto.partition_tensor_view<1x1x1x127x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xi32>) + return + } + + // case20: i32 63x64 valid=63x64 + func.func @TROWMAX_i32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi32> -> !pto.partition_tensor_view<1x1x1x63x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> -> !pto.partition_tensor_view<1x1x1x63x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi32>) + return + } + + // case21: i32 31x128 valid=31x127 + func.func @TROWMAX_i32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> -> !pto.partition_tensor_view<1x1x1x31x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi32>) + return + } + + // case22: i32 15x192 valid=15x192 + func.func @TROWMAX_i32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi32> -> !pto.partition_tensor_view<1x1x1x15x192xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> -> !pto.partition_tensor_view<1x1x1x15x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xi32>) + return + } + + // case23: i32 7x448 valid=7x447 + func.func @TROWMAX_i32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xi32> -> !pto.partition_tensor_view<1x1x1x7x447xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> -> !pto.partition_tensor_view<1x1x1x7x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xi32>) + return + } + + // ======================================================================== + // int16 cases (case24-case28) + // ======================================================================== + + // case24: i16 128x64 valid=128x64 + func.func @TROWMAX_i16_128x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c1], + strides = [%c128, %c128, %c128, %c1, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xi16> -> !pto.partition_tensor_view<1x1x1x128x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> -> !pto.partition_tensor_view<1x1x1x128x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x1xi16>) + return + } + + // case25: i16 64x64 valid=64x64 + func.func @TROWMAX_i16_64x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> -> !pto.partition_tensor_view<1x1x1x64x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xi16>) + return + } + + // case26: i16 32x128 valid=32x128 + func.func @TROWMAX_i16_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> -> !pto.partition_tensor_view<1x1x1x32x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xi16>) + return + } + + // case27: i16 16x192 valid=16x192 + func.func @TROWMAX_i16_16x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c16r = arith.constant 16 : index + %c192 = arith.constant 192 : index + %c3072 = arith.constant 3072 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16r, %c192], + strides = [%c3072, %c3072, %c3072, %c192, %c1] + : !pto.tensor_view<1x1x1x16x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16r, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16r, %c192] + : !pto.tensor_view<1x1x1x16x192xi16> -> !pto.partition_tensor_view<1x1x1x16x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16r, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> -> !pto.partition_tensor_view<1x1x1x16x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x192xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xi16>) + return + } + + // case28: i16 8x448 valid=8x448 + func.func @TROWMAX_i16_8x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %c448 = arith.constant 448 : index + %c3584 = arith.constant 3584 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c448], + strides = [%c3584, %c3584, %c3584, %c448, %c1] + : !pto.tensor_view<1x1x1x8x448xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c448] + : !pto.tensor_view<1x1x1x8x448xi16> -> !pto.partition_tensor_view<1x1x1x8x448xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> -> !pto.partition_tensor_view<1x1x1x8x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x448xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/CMakeLists.txt new file mode 100644 index 000000000..e88611a82 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowmin) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/cases.py new file mode 100644 index 000000000..903509084 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/cases.py @@ -0,0 +1,224 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowmin ST test cases. + +Aligned with pto-isa tests/npu/a2a3/src/st/testcase/trowmin (28 cases). +""" + +import numpy as np + +CASES = [ + # f32 cases (case1-case5 from pto-isa) + { + "name": "f32_127x64_valid127x63", + "dtype": np.float32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 1e-5, + }, + { + "name": "f32_63x64", + "dtype": np.float32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-5, + }, + { + "name": "f32_31x128_valid31x127", + "dtype": np.float32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 1e-5, + }, + { + "name": "f32_15x192", + "dtype": np.float32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 1e-5, + }, + { + "name": "f32_7x448_valid7x447", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 1e-5, + }, + # f16 case (case6 from pto-isa) + { + "name": "f16_256x16_valid256x15", + "dtype": np.float16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 1e-2, + }, + # f32 more cases (case7-case14 from pto-isa) + { + "name": "f32_30x216", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (30, 216), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid30x24", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (30, 24), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid11x216", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (11, 216), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid11x24", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (11, 24), + "eps": 1e-5, + }, + { + "name": "f32_238x40", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (238, 40), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid238x16", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (238, 16), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid121x40", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (121, 40), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid121x16", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (121, 16), + "eps": 1e-5, + }, + # f32 DN dst cases (case15-case18 from pto-isa) + { + "name": "f32_64x128", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-5, + }, + { + "name": "f32_32x256", + "dtype": np.float32, + "shape": (32, 256), + "valid_shape": (32, 256), + "eps": 1e-5, + }, + { + "name": "f32_16x512", + "dtype": np.float32, + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-5, + }, + { + "name": "f32_8x1024", + "dtype": np.float32, + "shape": (8, 1024), + "valid_shape": (8, 1024), + "eps": 1e-5, + }, + + # int32 cases (case19-case23 from pto-isa) + { + "name": "i32_127x64_valid127x63", + "dtype": np.int32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 0, + }, + { + "name": "i32_63x64", + "dtype": np.int32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128_valid31x127", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, + { + "name": "i32_15x192", + "dtype": np.int32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "i32_7x448_valid7x447", + "dtype": np.int32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 0, + }, + + # int16 cases (case24-case28 from pto-isa) + { + "name": "i16_128x64", + "dtype": np.int16, + "shape": (128, 64), + "valid_shape": (128, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_32x128", + "dtype": np.int16, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "i16_16x192", + "dtype": np.int16, + "shape": (16, 192), + "valid_shape": (16, 192), + "eps": 0, + }, + { + "name": "i16_8x448", + "dtype": np.int16, + "shape": (8, 448), + "valid_shape": (8, 448), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/compare.py new file mode 100644 index 000000000..12d4207bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr,) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/gen_data.py new file mode 100644 index 000000000..cf1bed8ac --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/gen_data.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if np.issubdtype(dtype, np.integer): + if dtype == np.int32: + input1 = np.random.randint(low=-100, high=100, size=shape).astype(dtype) + else: + input1 = np.random.randint(low=-50, high=50, size=shape).astype(dtype) + else: + input1 = np.random.uniform(low=-16, high=16, size=shape).astype(dtype) + + out_shape = (valid_shape[0],) + golden = np.zeros(out_shape, dtype=dtype) + vr, vc = valid_shape + for i in range(vr): + golden[i] = np.min(input1[i, :vc]) + + golden = golden.astype(dtype, copy=False) + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/launch.cpp new file mode 100644 index 000000000..e4a8f8bde --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/launch.cpp @@ -0,0 +1,155 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWMIN_f32_127x64_valid127x63(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_127x64_valid127x63(float *src, float *dst, void *stream) { + TROWMIN_f32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_63x64(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_63x64(float *src, float *dst, void *stream) { + TROWMIN_f32_63x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_31x128_valid31x127(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_31x128_valid31x127(float *src, float *dst, void *stream) { + TROWMIN_f32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_15x192(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_15x192(float *src, float *dst, void *stream) { + TROWMIN_f32_15x192<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_7x448_valid7x447(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_7x448_valid7x447(float *src, float *dst, void *stream) { + TROWMIN_f32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f16_256x16_valid256x15(__gm__ uint16_t *src, __gm__ uint16_t *dst); +void LaunchTROWMIN_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream) { + TROWMIN_f16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_30x216(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_30x216(float *src, float *dst, void *stream) { + TROWMIN_f32_30x216<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_30x216_valid30x24(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_30x216_valid30x24(float *src, float *dst, void *stream) { + TROWMIN_f32_30x216_valid30x24<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_30x216_valid11x216(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_30x216_valid11x216(float *src, float *dst, void *stream) { + TROWMIN_f32_30x216_valid11x216<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_30x216_valid11x24(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_30x216_valid11x24(float *src, float *dst, void *stream) { + TROWMIN_f32_30x216_valid11x24<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_238x40(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_238x40(float *src, float *dst, void *stream) { + TROWMIN_f32_238x40<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_238x40_valid238x16(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_238x40_valid238x16(float *src, float *dst, void *stream) { + TROWMIN_f32_238x40_valid238x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_238x40_valid121x40(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_238x40_valid121x40(float *src, float *dst, void *stream) { + TROWMIN_f32_238x40_valid121x40<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_238x40_valid121x16(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_238x40_valid121x16(float *src, float *dst, void *stream) { + TROWMIN_f32_238x40_valid121x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_64x128(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_64x128(float *src, float *dst, void *stream) { + TROWMIN_f32_64x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_32x256(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_32x256(float *src, float *dst, void *stream) { + TROWMIN_f32_32x256<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_16x512(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_16x512(float *src, float *dst, void *stream) { + TROWMIN_f32_16x512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_8x1024(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_8x1024(float *src, float *dst, void *stream) { + TROWMIN_f32_8x1024<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// int32 cases +extern "C" __global__ AICORE void TROWMIN_i32_127x64_valid127x63(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i32_63x64(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_63x64(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_63x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i32_31x128_valid31x127(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i32_15x192(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_15x192(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_15x192<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i32_7x448_valid7x447(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// int16 cases +extern "C" __global__ AICORE void TROWMIN_i16_128x64(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_128x64(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_128x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i16_64x64(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_64x64(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i16_32x128(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_32x128(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_32x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i16_16x192(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_16x192(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_16x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i16_8x448(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_8x448(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_8x448<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/main.cpp new file mode 100644 index 000000000..f0b9f0025 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/main.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowmin ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWMIN_f32_127x64_valid127x63(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_63x64(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_31x128_valid31x127(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_15x192(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_7x448_valid7x447(float *src, float *dst, void *stream); +void LaunchTROWMIN_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTROWMIN_f32_30x216(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_30x216_valid30x24(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_30x216_valid11x216(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_30x216_valid11x24(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_238x40(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_238x40_valid238x16(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_238x40_valid121x40(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_238x40_valid121x16(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_64x128(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_32x256(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_16x512(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_8x1024(float *src, float *dst, void *stream); +void LaunchTROWMIN_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i32_63x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i32_15x192(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i16_128x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMIN_i16_64x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMIN_i16_32x128(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMIN_i16_16x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMIN_i16_8x448(int16_t *src, int16_t *dst, void *stream); + +using LaunchFnF32 = void (*)(float *, float *, void *); +using LaunchFnF16 = void (*)(uint16_t *, uint16_t *, void *); +using LaunchFnI32 = void (*)(int32_t *, int32_t *, void *); +using LaunchFnI16 = void (*)(int16_t *, int16_t *, void *); + +enum class DType { F32, F16, I32, I16 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI32 launchI32; + LaunchFnI16 launchI16; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_127x64_valid127x63", DType::F32, .launchF32 = LaunchTROWMIN_f32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"f32_63x64", DType::F32, .launchF32 = LaunchTROWMIN_f32_63x64, 63, 64, 63, 64, 4}, + {"f32_31x128_valid31x127", DType::F32, .launchF32 = LaunchTROWMIN_f32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"f32_15x192", DType::F32, .launchF32 = LaunchTROWMIN_f32_15x192, 15, 192, 15, 192, 4}, + {"f32_7x448_valid7x447", DType::F32, .launchF32 = LaunchTROWMIN_f32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // f16 case + {"f16_256x16_valid256x15", DType::F16, .launchF16 = LaunchTROWMIN_f16_256x16_valid256x15, 256, 16, 256, 15, 2}, + // f32 more cases + {"f32_30x216", DType::F32, .launchF32 = LaunchTROWMIN_f32_30x216, 30, 216, 30, 216, 4}, + {"f32_30x216_valid30x24", DType::F32, .launchF32 = LaunchTROWMIN_f32_30x216_valid30x24, 30, 216, 30, 24, 4}, + {"f32_30x216_valid11x216", DType::F32, .launchF32 = LaunchTROWMIN_f32_30x216_valid11x216, 30, 216, 11, 216, 4}, + {"f32_30x216_valid11x24", DType::F32, .launchF32 = LaunchTROWMIN_f32_30x216_valid11x24, 30, 216, 11, 24, 4}, + {"f32_238x40", DType::F32, .launchF32 = LaunchTROWMIN_f32_238x40, 238, 40, 238, 40, 4}, + {"f32_238x40_valid238x16", DType::F32, .launchF32 = LaunchTROWMIN_f32_238x40_valid238x16, 238, 40, 238, 16, 4}, + {"f32_238x40_valid121x40", DType::F32, .launchF32 = LaunchTROWMIN_f32_238x40_valid121x40, 238, 40, 121, 40, 4}, + {"f32_238x40_valid121x16", DType::F32, .launchF32 = LaunchTROWMIN_f32_238x40_valid121x16, 238, 40, 121, 16, 4}, + // f32 DN dst cases + {"f32_64x128", DType::F32, .launchF32 = LaunchTROWMIN_f32_64x128, 64, 128, 64, 128, 4}, + {"f32_32x256", DType::F32, .launchF32 = LaunchTROWMIN_f32_32x256, 32, 256, 32, 256, 4}, + {"f32_16x512", DType::F32, .launchF32 = LaunchTROWMIN_f32_16x512, 16, 512, 16, 512, 4}, + {"f32_8x1024", DType::F32, .launchF32 = LaunchTROWMIN_f32_8x1024, 8, 1024,8, 1024,4}, + // int32 cases + {"i32_127x64_valid127x63", DType::I32, .launchI32 = LaunchTROWMIN_i32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"i32_63x64", DType::I32, .launchI32 = LaunchTROWMIN_i32_63x64, 63, 64, 63, 64, 4}, + {"i32_31x128_valid31x127", DType::I32, .launchI32 = LaunchTROWMIN_i32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"i32_15x192", DType::I32, .launchI32 = LaunchTROWMIN_i32_15x192, 15, 192, 15, 192, 4}, + {"i32_7x448_valid7x447", DType::I32, .launchI32 = LaunchTROWMIN_i32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // int16 cases + {"i16_128x64", DType::I16, .launchI16 = LaunchTROWMIN_i16_128x64, 128, 64, 128, 64, 2}, + {"i16_64x64", DType::I16, .launchI16 = LaunchTROWMIN_i16_64x64, 64, 64, 64, 64, 2}, + {"i16_32x128", DType::I16, .launchI16 = LaunchTROWMIN_i16_32x128, 32, 128, 32, 128, 2}, + {"i16_16x192", DType::I16, .launchI16 = LaunchTROWMIN_i16_16x192, 16, 192, 16, 192, 2}, + {"i16_8x448", DType::I16, .launchI16 = LaunchTROWMIN_i16_8x448, 8, 448, 8, 448, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.validRows * 1; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32: tc.launchF32((float *)src0Device, (float *)dstDevice, stream); break; + case DType::F16: tc.launchF16((uint16_t *)src0Device, (uint16_t *)dstDevice, stream); break; + case DType::I32: tc.launchI32((int32_t *)src0Device, (int32_t *)dstDevice, stream); break; + case DType::I16: tc.launchI16((int16_t *)src0Device, (int16_t *)dstDevice, stream); break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowmin [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/trowmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/trowmin.pto new file mode 100644 index 000000000..4c336b629 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/trowmin.pto @@ -0,0 +1,1545 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowmin: tload(src) + trowmin(src, tmp)->dst + tstore(dst). + +module { + + // Case 0: f32 127x64 (valid=127x63) + func.func @TROWMIN_f32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xf32> -> !pto.partition_tensor_view<1x1x1x127x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> -> !pto.partition_tensor_view<1x1x1x127x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xf32>) + return + } + + // Case 1: f32 63x64 (valid=63x64) + func.func @TROWMIN_f32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf32> -> !pto.partition_tensor_view<1x1x1x63x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> -> !pto.partition_tensor_view<1x1x1x63x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xf32>) + return + } + + // Case 2: f32 31x128 (valid=31x127) + func.func @TROWMIN_f32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xf32> -> !pto.partition_tensor_view<1x1x1x31x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> -> !pto.partition_tensor_view<1x1x1x31x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xf32>) + return + } + + // Case 3: f32 15x192 (valid=15x192) + func.func @TROWMIN_f32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xf32> -> !pto.partition_tensor_view<1x1x1x15x192xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> -> !pto.partition_tensor_view<1x1x1x15x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xf32>) + return + } + + // Case 4: f32 7x448 (valid=7x447) + func.func @TROWMIN_f32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x447xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> -> !pto.partition_tensor_view<1x1x1x7x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xf32>) + return + } + + // Case 5: f16 256x16 (valid=256x15) + func.func @TROWMIN_f16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xf16> -> !pto.partition_tensor_view<1x1x1x256x15xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> -> !pto.partition_tensor_view<1x1x1x256x1xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xf16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xf16>) + return + } + + // Case 6: f32 30x216 (valid=30x216) + func.func @TROWMIN_f32_30x216(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c30, %c1], + strides = [%c30, %c30, %c30, %c1, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c216] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x30x216xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> -> !pto.partition_tensor_view<1x1x1x30x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x30x216xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x30x1xf32>) + return + } + + // Case 7: f32 30x216 (valid=30x24) + func.func @TROWMIN_f32_30x216_valid30x24(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c24 = arith.constant 24 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c30, %c1], + strides = [%c30, %c30, %c30, %c1, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c24] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x30x24xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> -> !pto.partition_tensor_view<1x1x1x30x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x30x24xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x30x1xf32>) + return + } + + // Case 8: f32 30x216 (valid=11x216) + func.func @TROWMIN_f32_30x216_valid11x216(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c11 = arith.constant 11 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c11, %c1], + strides = [%c11, %c11, %c11, %c1, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c216] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x11x216xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> -> !pto.partition_tensor_view<1x1x1x11x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x11x216xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x11x1xf32>) + return + } + + // Case 9: f32 30x216 (valid=11x24) + func.func @TROWMIN_f32_30x216_valid11x24(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c11 = arith.constant 11 : index + %c24 = arith.constant 24 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c11, %c1], + strides = [%c11, %c11, %c11, %c1, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c24] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x11x24xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> -> !pto.partition_tensor_view<1x1x1x11x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x11x24xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x11x1xf32>) + return + } + + // Case 10: f32 238x40 (valid=238x40) + func.func @TROWMIN_f32_238x40(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c238, %c1], + strides = [%c238, %c238, %c238, %c1, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c40] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x238x40xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> -> !pto.partition_tensor_view<1x1x1x238x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x238x40xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x238x1xf32>) + return + } + + // Case 11: f32 238x40 (valid=238x16) + func.func @TROWMIN_f32_238x40_valid238x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c238, %c1], + strides = [%c238, %c238, %c238, %c1, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c16] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x238x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> -> !pto.partition_tensor_view<1x1x1x238x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x238x16xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x238x1xf32>) + return + } + + // Case 12: f32 238x40 (valid=121x40) + func.func @TROWMIN_f32_238x40_valid121x40(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c121 = arith.constant 121 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c121, %c1], + strides = [%c121, %c121, %c121, %c1, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c40] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x121x40xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> -> !pto.partition_tensor_view<1x1x1x121x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x121x40xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x121x1xf32>) + return + } + + // Case 13: f32 238x40 (valid=121x16) + func.func @TROWMIN_f32_238x40_valid121x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %c121 = arith.constant 121 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c121, %c1], + strides = [%c121, %c121, %c121, %c1, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c16] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x121x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> -> !pto.partition_tensor_view<1x1x1x121x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x121x16xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x121x1xf32>) + return + } + + // Case 14: f32 64x128 (valid=64x128) + func.func @TROWMIN_f32_64x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> -> !pto.partition_tensor_view<1x1x1x64x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xf32>) + return + } + + // Case 15: f32 32x256 (valid=32x256) + func.func @TROWMIN_f32_32x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c256] + : !pto.tensor_view<1x1x1x32x256xf32> -> !pto.partition_tensor_view<1x1x1x32x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> -> !pto.partition_tensor_view<1x1x1x32x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x256xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xf32>) + return + } + + // Case 16: f32 16x512 (valid=16x512) + func.func @TROWMIN_f32_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf32> -> !pto.partition_tensor_view<1x1x1x16x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> -> !pto.partition_tensor_view<1x1x1x16x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x512xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xf32>) + return + } + + // Case 17: f32 8x1024 (valid=8x1024) + func.func @TROWMIN_f32_8x1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c1024], + strides = [%c8192, %c8192, %c8192, %c1024, %c1] + : !pto.tensor_view<1x1x1x8x1024xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1024] + : !pto.tensor_view<1x1x1x8x1024xf32> -> !pto.partition_tensor_view<1x1x1x8x1024xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> -> !pto.partition_tensor_view<1x1x1x8x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x1024xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xf32>) + return + } + + // ======================================================================== + // int32 cases (case19-case23) + // ======================================================================== + + // case19: i32 127x64 valid=127x63 + func.func @TROWMIN_i32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xi32> -> !pto.partition_tensor_view<1x1x1x127x63xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> -> !pto.partition_tensor_view<1x1x1x127x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xi32>) + return + } + + // case20: i32 63x64 valid=63x64 + func.func @TROWMIN_i32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi32> -> !pto.partition_tensor_view<1x1x1x63x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> -> !pto.partition_tensor_view<1x1x1x63x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi32>) + return + } + + // case21: i32 31x128 valid=31x127 + func.func @TROWMIN_i32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> -> !pto.partition_tensor_view<1x1x1x31x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi32>) + return + } + + // case22: i32 15x192 valid=15x192 + func.func @TROWMIN_i32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi32> -> !pto.partition_tensor_view<1x1x1x15x192xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> -> !pto.partition_tensor_view<1x1x1x15x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xi32>) + return + } + + // case23: i32 7x448 valid=7x447 + func.func @TROWMIN_i32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xi32> -> !pto.partition_tensor_view<1x1x1x7x447xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> -> !pto.partition_tensor_view<1x1x1x7x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xi32>) + return + } + + // ======================================================================== + // int16 cases (case24-case28) + // ======================================================================== + + // case24: i16 128x64 valid=128x64 + func.func @TROWMIN_i16_128x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c1], + strides = [%c128, %c128, %c128, %c1, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xi16> -> !pto.partition_tensor_view<1x1x1x128x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> -> !pto.partition_tensor_view<1x1x1x128x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x1xi16>) + return + } + + // case25: i16 64x64 valid=64x64 + func.func @TROWMIN_i16_64x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> -> !pto.partition_tensor_view<1x1x1x64x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xi16>) + return + } + + // case26: i16 32x128 valid=32x128 + func.func @TROWMIN_i16_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> -> !pto.partition_tensor_view<1x1x1x32x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xi16>) + return + } + + // case27: i16 16x192 valid=16x192 + func.func @TROWMIN_i16_16x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c16r = arith.constant 16 : index + %c192 = arith.constant 192 : index + %c3072 = arith.constant 3072 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16r, %c192], + strides = [%c3072, %c3072, %c3072, %c192, %c1] + : !pto.tensor_view<1x1x1x16x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16r, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16r, %c192] + : !pto.tensor_view<1x1x1x16x192xi16> -> !pto.partition_tensor_view<1x1x1x16x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16r, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> -> !pto.partition_tensor_view<1x1x1x16x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x192xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xi16>) + return + } + + // case28: i16 8x448 valid=8x448 + func.func @TROWMIN_i16_8x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %c448 = arith.constant 448 : index + %c3584 = arith.constant 3584 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c448], + strides = [%c3584, %c3584, %c3584, %c448, %c1] + : !pto.tensor_view<1x1x1x8x448xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c448] + : !pto.tensor_view<1x1x1x8x448xi16> -> !pto.partition_tensor_view<1x1x1x8x448xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> -> !pto.partition_tensor_view<1x1x1x8x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x448xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/CMakeLists.txt new file mode 100644 index 000000000..6a30d1293 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowprod) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/cases.py new file mode 100644 index 000000000..66f7176a9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/cases.py @@ -0,0 +1,153 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowprod ST test cases. + +Aligned with pto-isa tests/npu/a5/src/st/testcase/trowprod (18 cases). +""" + +import numpy as np + +CASES = [ + # f32 cases (case1-case5 from pto-isa) + { + "name": "f32_127x64_valid127x63", + "dtype": np.float32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 1e-3, + }, + { + "name": "f32_63x64", + "dtype": np.float32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "f32_31x128_valid31x127", + "dtype": np.float32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 1e-3, + }, + { + "name": "f32_15x192", + "dtype": np.float32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 1e-3, + }, + { + "name": "f32_7x448_valid7x447", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 1e-3, + }, + # f16 case (case6 from pto-isa) + { + "name": "f16_256x16_valid256x15", + "dtype": np.float16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 1e-1, + }, + # f32 DN dst cases (case7-case10 from pto-isa) + { + "name": "f32_64x128", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-3, + }, + { + "name": "f32_32x256", + "dtype": np.float32, + "shape": (32, 256), + "valid_shape": (32, 256), + "eps": 1e-3, + }, + { + "name": "f32_16x512", + "dtype": np.float32, + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-3, + }, + { + "name": "f32_8x1024", + "dtype": np.float32, + "shape": (8, 1024), + "valid_shape": (8, 1024), + "eps": 1e-3, + }, + + # int32 cases (case11-case15 from pto-isa) + { + "name": "i32_127x64_valid127x63", + "dtype": np.int32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 0, + }, + { + "name": "i32_63x64", + "dtype": np.int32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128_valid31x127", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, + { + "name": "i32_15x192", + "dtype": np.int32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "i32_7x448_valid7x447", + "dtype": np.int32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 0, + }, + + # int16 cases (case16-case18 from pto-isa) + { + "name": "i16_256x16_valid256x15", + "dtype": np.int16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i16_31x128_valid31x127", + "dtype": np.int16, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/compare.py new file mode 100644 index 000000000..12d4207bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr,) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/gen_data.py new file mode 100644 index 000000000..b1f6092af --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/gen_data.py @@ -0,0 +1,42 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if np.issubdtype(dtype, np.integer): + if dtype == np.int32: + input1 = np.random.randint(low=-3, high=4, size=shape).astype(dtype) + else: + input1 = np.random.randint(low=-2, high=3, size=shape).astype(dtype) + else: + input1 = np.random.uniform(low=0.9, high=1.1, size=shape).astype(dtype) + + out_shape = (valid_shape[0],) + golden = np.ones(out_shape, dtype=dtype) + vr, vc = valid_shape + for i in range(vr): + for j in range(vc): + golden[i] *= input1[i, j] + + golden = golden.astype(dtype, copy=False) + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/launch.cpp new file mode 100644 index 000000000..0533066cd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/launch.cpp @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWPROD_f32_127x64_valid127x63(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_127x64_valid127x63(float *src, float *dst, void *stream) { + TROWPROD_f32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_63x64(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_63x64(float *src, float *dst, void *stream) { + TROWPROD_f32_63x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_31x128_valid31x127(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_31x128_valid31x127(float *src, float *dst, void *stream) { + TROWPROD_f32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_15x192(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_15x192(float *src, float *dst, void *stream) { + TROWPROD_f32_15x192<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_7x448_valid7x447(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_7x448_valid7x447(float *src, float *dst, void *stream) { + TROWPROD_f32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f16_256x16_valid256x15(__gm__ uint16_t *src, __gm__ uint16_t *dst); +void LaunchTROWPROD_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream) { + TROWPROD_f16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_64x128(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_64x128(float *src, float *dst, void *stream) { + TROWPROD_f32_64x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_32x256(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_32x256(float *src, float *dst, void *stream) { + TROWPROD_f32_32x256<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_16x512(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_16x512(float *src, float *dst, void *stream) { + TROWPROD_f32_16x512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_8x1024(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_8x1024(float *src, float *dst, void *stream) { + TROWPROD_f32_8x1024<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// int32 cases +extern "C" __global__ AICORE void TROWPROD_i32_127x64_valid127x63(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i32_63x64(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_63x64(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_63x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i32_31x128_valid31x127(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i32_15x192(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_15x192(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_15x192<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i32_7x448_valid7x447(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// int16 cases +extern "C" __global__ AICORE void TROWPROD_i16_256x16_valid256x15(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWPROD_i16_256x16_valid256x15(int16_t *src, int16_t *dst, void *stream) { + TROWPROD_i16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWPROD_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TROWPROD_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i16_31x128_valid31x127(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWPROD_i16_31x128_valid31x127(int16_t *src, int16_t *dst, void *stream) { + TROWPROD_i16_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/main.cpp new file mode 100644 index 000000000..32981566a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/main.cpp @@ -0,0 +1,186 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowprod ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWPROD_f32_127x64_valid127x63(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_63x64(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_31x128_valid31x127(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_15x192(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_7x448_valid7x447(float *src, float *dst, void *stream); +void LaunchTROWPROD_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTROWPROD_f32_64x128(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_32x256(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_16x512(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_8x1024(float *src, float *dst, void *stream); +void LaunchTROWPROD_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i32_63x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i32_15x192(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i16_256x16_valid256x15(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWPROD_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWPROD_i16_31x128_valid31x127(int16_t *src, int16_t *dst, void *stream); + +using LaunchFnF32 = void (*)(float *, float *, void *); +using LaunchFnF16 = void (*)(uint16_t *, uint16_t *, void *); +using LaunchFnI32 = void (*)(int32_t *, int32_t *, void *); +using LaunchFnI16 = void (*)(int16_t *, int16_t *, void *); + +enum class DType { F32, F16, I32, I16 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI32 launchI32; + LaunchFnI16 launchI16; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_127x64_valid127x63", DType::F32, .launchF32 = LaunchTROWPROD_f32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"f32_63x64", DType::F32, .launchF32 = LaunchTROWPROD_f32_63x64, 63, 64, 63, 64, 4}, + {"f32_31x128_valid31x127", DType::F32, .launchF32 = LaunchTROWPROD_f32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"f32_15x192", DType::F32, .launchF32 = LaunchTROWPROD_f32_15x192, 15, 192, 15, 192, 4}, + {"f32_7x448_valid7x447", DType::F32, .launchF32 = LaunchTROWPROD_f32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // f16 case + {"f16_256x16_valid256x15", DType::F16, .launchF16 = LaunchTROWPROD_f16_256x16_valid256x15, 256, 16, 256, 15, 2}, + // f32 DN dst cases + {"f32_64x128", DType::F32, .launchF32 = LaunchTROWPROD_f32_64x128, 64, 128, 64, 128, 4}, + {"f32_32x256", DType::F32, .launchF32 = LaunchTROWPROD_f32_32x256, 32, 256, 32, 256, 4}, + {"f32_16x512", DType::F32, .launchF32 = LaunchTROWPROD_f32_16x512, 16, 512, 16, 512, 4}, + {"f32_8x1024", DType::F32, .launchF32 = LaunchTROWPROD_f32_8x1024, 8, 1024,8, 1024,4}, + // int32 cases + {"i32_127x64_valid127x63", DType::I32, .launchI32 = LaunchTROWPROD_i32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"i32_63x64", DType::I32, .launchI32 = LaunchTROWPROD_i32_63x64, 63, 64, 63, 64, 4}, + {"i32_31x128_valid31x127", DType::I32, .launchI32 = LaunchTROWPROD_i32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"i32_15x192", DType::I32, .launchI32 = LaunchTROWPROD_i32_15x192, 15, 192, 15, 192, 4}, + {"i32_7x448_valid7x447", DType::I32, .launchI32 = LaunchTROWPROD_i32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // int16 cases + {"i16_256x16_valid256x15", DType::I16, .launchI16 = LaunchTROWPROD_i16_256x16_valid256x15, 256, 16, 256, 15, 2}, + {"i16_63x64", DType::I16, .launchI16 = LaunchTROWPROD_i16_63x64, 63, 64, 63, 64, 2}, + {"i16_31x128_valid31x127", DType::I16, .launchI16 = LaunchTROWPROD_i16_31x128_valid31x127, 31, 128, 31, 127, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.validRows * 1; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32: tc.launchF32((float *)src0Device, (float *)dstDevice, stream); break; + case DType::F16: tc.launchF16((uint16_t *)src0Device, (uint16_t *)dstDevice, stream); break; + case DType::I32: tc.launchI32((int32_t *)src0Device, (int32_t *)dstDevice, stream); break; + case DType::I16: tc.launchI16((int16_t *)src0Device, (int16_t *)dstDevice, stream); break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowprod [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/trowprod.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/trowprod.pto new file mode 100644 index 000000000..90e6a93fd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/trowprod.pto @@ -0,0 +1,999 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowprod: tload(src) + trowprod(src, tmp)->dst + tstore(dst). + +module { + + // Case 0: f32 127x64 (valid=127x63) + func.func @TROWPROD_f32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xf32> -> !pto.partition_tensor_view<1x1x1x127x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> -> !pto.partition_tensor_view<1x1x1x127x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xf32>) + return + } + + // Case 1: f32 63x64 (valid=63x64) + func.func @TROWPROD_f32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf32> -> !pto.partition_tensor_view<1x1x1x63x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> -> !pto.partition_tensor_view<1x1x1x63x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xf32>) + return + } + + // Case 2: f32 31x128 (valid=31x127) + func.func @TROWPROD_f32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xf32> -> !pto.partition_tensor_view<1x1x1x31x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> -> !pto.partition_tensor_view<1x1x1x31x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xf32>) + return + } + + // Case 3: f32 15x192 (valid=15x192) + func.func @TROWPROD_f32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xf32> -> !pto.partition_tensor_view<1x1x1x15x192xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> -> !pto.partition_tensor_view<1x1x1x15x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xf32>) + return + } + + // Case 4: f32 7x448 (valid=7x447) + func.func @TROWPROD_f32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x447xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> -> !pto.partition_tensor_view<1x1x1x7x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xf32>) + return + } + + // Case 5: f16 256x16 (valid=256x15) + func.func @TROWPROD_f16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xf16> -> !pto.partition_tensor_view<1x1x1x256x15xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> -> !pto.partition_tensor_view<1x1x1x256x1xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xf16>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xf16>) + return + } + + // Case 6: f32 64x128 (valid=64x128) + func.func @TROWPROD_f32_64x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> -> !pto.partition_tensor_view<1x1x1x64x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xf32>) + return + } + + // Case 7: f32 32x256 (valid=32x256) + func.func @TROWPROD_f32_32x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c256] + : !pto.tensor_view<1x1x1x32x256xf32> -> !pto.partition_tensor_view<1x1x1x32x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> -> !pto.partition_tensor_view<1x1x1x32x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x256xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xf32>) + return + } + + // Case 8: f32 16x512 (valid=16x512) + func.func @TROWPROD_f32_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf32> -> !pto.partition_tensor_view<1x1x1x16x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> -> !pto.partition_tensor_view<1x1x1x16x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x512xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xf32>) + return + } + + // Case 9: f32 8x1024 (valid=8x1024) + func.func @TROWPROD_f32_8x1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c1024], + strides = [%c8192, %c8192, %c8192, %c1024, %c1] + : !pto.tensor_view<1x1x1x8x1024xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1024] + : !pto.tensor_view<1x1x1x8x1024xf32> -> !pto.partition_tensor_view<1x1x1x8x1024xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> -> !pto.partition_tensor_view<1x1x1x8x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x1024xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xf32>) + return + } + + // ======================================================================== + // int32 cases (case11-case15) + // ======================================================================== + + // case11: i32 127x64 valid=127x63 + func.func @TROWPROD_i32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xi32> -> !pto.partition_tensor_view<1x1x1x127x63xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> -> !pto.partition_tensor_view<1x1x1x127x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xi32>) + return + } + + // case12: i32 63x64 valid=63x64 + func.func @TROWPROD_i32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi32> -> !pto.partition_tensor_view<1x1x1x63x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> -> !pto.partition_tensor_view<1x1x1x63x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi32>) + return + } + + // case13: i32 31x128 valid=31x127 + func.func @TROWPROD_i32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> -> !pto.partition_tensor_view<1x1x1x31x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi32>) + return + } + + // case14: i32 15x192 valid=15x192 + func.func @TROWPROD_i32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi32> -> !pto.partition_tensor_view<1x1x1x15x192xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> -> !pto.partition_tensor_view<1x1x1x15x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xi32>) + return + } + + // case15: i32 7x448 valid=7x447 + func.func @TROWPROD_i32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xi32> -> !pto.partition_tensor_view<1x1x1x7x447xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> -> !pto.partition_tensor_view<1x1x1x7x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xi32>) + return + } + + // ======================================================================== + // int16 cases (case16-case18) + // ======================================================================== + + // case16: i16 256x16 valid=256x15 + func.func @TROWPROD_i16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c15 = arith.constant 15 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xi16> -> !pto.partition_tensor_view<1x1x1x256x15xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xi16> -> !pto.partition_tensor_view<1x1x1x256x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xi16>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xi16>) + return + } + + // case17: i16 63x64 valid=63x64 + func.func @TROWPROD_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi16> -> !pto.partition_tensor_view<1x1x1x63x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi16>) + return + } + + // case18: i16 31x128 valid=31x127 + func.func @TROWPROD_i16_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi16> -> !pto.partition_tensor_view<1x1x1x31x127xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi16> -> !pto.partition_tensor_view<1x1x1x31x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi16>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/CMakeLists.txt new file mode 100644 index 000000000..bcb316bcc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowsum) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/cases.py new file mode 100644 index 000000000..b2b9e96b8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/cases.py @@ -0,0 +1,165 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowsum ST test cases. + +Aligned with pto-isa tests/npu/a5/src/st/testcase/trowsum (20 cases). +""" + +import numpy as np + +CASES = [ + # f32 cases (case1-case10 from pto-isa) + { + "name": "case1", + "dtype": np.float32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 1e-3, + }, + { + "name": "case2", + "dtype": np.float32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "case3", + "dtype": np.float32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 1e-3, + }, + { + "name": "case4", + "dtype": np.float32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 1e-3, + }, + { + "name": "case5", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 1e-3, + }, + { + "name": "case6", + "dtype": np.float16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 5e-3, + }, + { + "name": "case7", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-3, + }, + { + "name": "case8", + "dtype": np.float32, + "shape": (32, 256), + "valid_shape": (32, 256), + "eps": 1e-3, + }, + { + "name": "case9", + "dtype": np.float32, + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-3, + }, + { + "name": "case10", + "dtype": np.float32, + "shape": (8, 1024), + "valid_shape": (8, 1024), + "eps": 1e-3, + }, + + # int32 cases (case11-case15 from pto-isa) + { + "name": "case11", + "dtype": np.int32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 0, + }, + { + "name": "case12", + "dtype": np.int32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "case13", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, + { + "name": "case14", + "dtype": np.int32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "case15", + "dtype": np.int32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 0, + }, + + # int16 cases (case16-case20 from pto-isa) + { + "name": "case16", + "dtype": np.int16, + "shape": (128, 64), + "valid_shape": (128, 64), + "eps": 0, + }, + { + "name": "case17", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "case18", + "dtype": np.int16, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "case19", + "dtype": np.int16, + "shape": (16, 192), + "valid_shape": (16, 192), + "eps": 0, + }, + { + "name": "case20", + "dtype": np.int16, + "shape": (8, 448), + "valid_shape": (8, 448), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/compare.py new file mode 100644 index 000000000..b80e2549b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr, 1) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/gen_data.py new file mode 100644 index 000000000..26b0c9f31 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/gen_data.py @@ -0,0 +1,45 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import numpy as np +from cases import CASES +from st_common import validate_cases, save_case_data + +validate_cases(CASES) + +np.random.seed(42) + +for case in CASES: + dtype = case["dtype"] + row = case["shape"][0] + valid_row = case["valid_shape"][0] + col = case["shape"][1] + valid_col = case["valid_shape"][1] + + if np.issubdtype(dtype, np.integer): + if dtype == np.int32: + input_arr = np.random.randint(low=-100, high=100, size=(row, col)).astype(dtype) + elif dtype == np.int16: + input_arr = np.random.randint(low=-50, high=50, size=(row, col)).astype(dtype) + else: + input_arr = np.random.randint(low=-10, high=10, size=(row, col)).astype(dtype) + else: + input_arr = np.random.uniform(low=-1, high=1, size=(row, col)).astype(dtype) + + output_arr = np.zeros((row,)) + for i in range(valid_row): + for j in range(valid_col): + output_arr[i] += input_arr[i, j] + output_arr = output_arr.astype(dtype) + + save_case_data(case["name"], {"input": input_arr, "golden": output_arr}) + print(f"[INFO] gen_data: {case['name']} shape=({row},{col}) valid=({valid_row},{valid_col}) dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/launch.cpp new file mode 100644 index 000000000..18e1a09c2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/launch.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ======================================================================== +// f32 kernels (case1-case10) +// ======================================================================== + +extern "C" __global__ AICORE void TROWSUM_case1(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_case2(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_case3(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_case4(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_case5(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_case6(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWSUM_case7(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_case8(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_case9(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_case10(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWSUM_case1(float *src, float *dst, void *stream) { + TROWSUM_case1<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_case2(float *src, float *dst, void *stream) { + TROWSUM_case2<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_case3(float *src, float *dst, void *stream) { + TROWSUM_case3<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_case4(float *src, float *dst, void *stream) { + TROWSUM_case4<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_case5(float *src, float *dst, void *stream) { + TROWSUM_case5<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_case6(uint16_t *src, uint16_t *dst, void *stream) { + TROWSUM_case6<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} +void LaunchTROWSUM_case7(float *src, float *dst, void *stream) { + TROWSUM_case7<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_case8(float *src, float *dst, void *stream) { + TROWSUM_case8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_case9(float *src, float *dst, void *stream) { + TROWSUM_case9<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_case10(float *src, float *dst, void *stream) { + TROWSUM_case10<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ======================================================================== +// i32 kernels (case11-case15) +// ======================================================================== + +extern "C" __global__ AICORE void TROWSUM_case11(__gm__ int32_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TROWSUM_case12(__gm__ int32_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TROWSUM_case13(__gm__ int32_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TROWSUM_case14(__gm__ int32_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TROWSUM_case15(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWSUM_case11(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_case11<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} +void LaunchTROWSUM_case12(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_case12<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} +void LaunchTROWSUM_case13(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_case13<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} +void LaunchTROWSUM_case14(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_case14<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} +void LaunchTROWSUM_case15(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_case15<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// ======================================================================== +// i16 kernels (case16-case20) +// ======================================================================== + +extern "C" __global__ AICORE void TROWSUM_case16(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_case17(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_case18(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_case19(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_case20(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWSUM_case16(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_case16<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_case17(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_case17<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_case18(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_case18<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_case19(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_case19<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_case20(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_case20<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/main.cpp new file mode 100644 index 000000000..170b432fb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/main.cpp @@ -0,0 +1,195 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowsum ST — aligned with pto-isa 20 cases. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 launch wrappers +void LaunchTROWSUM_case1(float *src, float *dst, void *stream); +void LaunchTROWSUM_case2(float *src, float *dst, void *stream); +void LaunchTROWSUM_case3(float *src, float *dst, void *stream); +void LaunchTROWSUM_case4(float *src, float *dst, void *stream); +void LaunchTROWSUM_case5(float *src, float *dst, void *stream); +void LaunchTROWSUM_case6(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTROWSUM_case7(float *src, float *dst, void *stream); +void LaunchTROWSUM_case8(float *src, float *dst, void *stream); +void LaunchTROWSUM_case9(float *src, float *dst, void *stream); +void LaunchTROWSUM_case10(float *src, float *dst, void *stream); + +// i32 launch wrappers +void LaunchTROWSUM_case11(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_case12(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_case13(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_case14(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_case15(int32_t *src, int32_t *dst, void *stream); + +// i16 launch wrappers +void LaunchTROWSUM_case16(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_case17(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_case18(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_case19(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_case20(int16_t *src, int16_t *dst, void *stream); + +using LaunchFnF32 = void (*)(float *, float *, void *); +using LaunchFnF16 = void (*)(uint16_t *, uint16_t *, void *); +using LaunchFnI32 = void (*)(int32_t *, int32_t *, void *); +using LaunchFnI16 = void (*)(int16_t *, int16_t *, void *); + +enum class DType { F32, F16, I32, I16 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI32 launchI32; + LaunchFnI16 launchI16; + }; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"case1", DType::F32, .launchF32 = LaunchTROWSUM_case1, 127, 64, 127, 63, 4}, + {"case2", DType::F32, .launchF32 = LaunchTROWSUM_case2, 63, 64, 63, 64, 4}, + {"case3", DType::F32, .launchF32 = LaunchTROWSUM_case3, 31, 128, 31, 127, 4}, + {"case4", DType::F32, .launchF32 = LaunchTROWSUM_case4, 15, 192, 15, 192, 4}, + {"case5", DType::F32, .launchF32 = LaunchTROWSUM_case5, 7, 448, 7, 447, 4}, + {"case6", DType::F16, .launchF16 = LaunchTROWSUM_case6, 256, 16, 256, 15, 2}, + {"case7", DType::F32, .launchF32 = LaunchTROWSUM_case7, 64, 128, 64, 128, 4}, + {"case8", DType::F32, .launchF32 = LaunchTROWSUM_case8, 32, 256, 32, 256, 4}, + {"case9", DType::F32, .launchF32 = LaunchTROWSUM_case9, 16, 512, 16, 512, 4}, + {"case10", DType::F32, .launchF32 = LaunchTROWSUM_case10, 8, 1024, 8, 1024, 4}, + + // i32 cases + {"case11", DType::I32, .launchI32 = LaunchTROWSUM_case11, 127, 64, 127, 63, 4}, + {"case12", DType::I32, .launchI32 = LaunchTROWSUM_case12, 63, 64, 63, 64, 4}, + {"case13", DType::I32, .launchI32 = LaunchTROWSUM_case13, 31, 128, 31, 127, 4}, + {"case14", DType::I32, .launchI32 = LaunchTROWSUM_case14, 15, 192, 15, 192, 4}, + {"case15", DType::I32, .launchI32 = LaunchTROWSUM_case15, 7, 448, 7, 447, 4}, + + // i16 cases + {"case16", DType::I16, .launchI16 = LaunchTROWSUM_case16, 128, 64, 128, 64, 2}, + {"case17", DType::I16, .launchI16 = LaunchTROWSUM_case17, 64, 64, 64, 64, 2}, + {"case18", DType::I16, .launchI16 = LaunchTROWSUM_case18, 32, 128, 32, 128, 2}, + {"case19", DType::I16, .launchI16 = LaunchTROWSUM_case19, 16, 192, 16, 192, 2}, + {"case20", DType::I16, .launchI16 = LaunchTROWSUM_case20, 8, 448, 8, 448, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.rows; + const size_t dstFileSize = dstElemCount * tc.elemSize; + size_t actualFileSize = 0; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), actualFileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.dtype == DType::F32) { + tc.launchF32((float *)src0Device, (float *)dstDevice, stream); + } else if (tc.dtype == DType::F16) { + tc.launchF16((uint16_t *)src0Device, (uint16_t *)dstDevice, stream); + } else if (tc.dtype == DType::I32) { + tc.launchI32((int32_t *)src0Device, (int32_t *)dstDevice, stream); + } else { + tc.launchI16((int16_t *)src0Device, (int16_t *)dstDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/trowsum.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/trowsum.pto new file mode 100644 index 000000000..64c778f16 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/trowsum.pto @@ -0,0 +1,1092 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowsum: tload(src) + trowsum(src, tmp)->dst + tstore(dst). +// Aligned with pto-isa tests/npu/a5/src/st/testcase/trowsum (20 cases). + +module { + + // ======================================================================== + // f32 cases (case1-case10) + // ======================================================================== + + // case1: f32 127x64 valid=127x63 + func.func @TROWSUM_case1(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xf32> -> !pto.partition_tensor_view<1x1x1x127x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> -> !pto.partition_tensor_view<1x1x1x127x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xf32>) + return + } + + // case2: f32 63x64 valid=63x64 + func.func @TROWSUM_case2(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf32> -> !pto.partition_tensor_view<1x1x1x63x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> -> !pto.partition_tensor_view<1x1x1x63x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xf32>) + return + } + + // case3: f32 31x128 valid=31x127 + func.func @TROWSUM_case3(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xf32> -> !pto.partition_tensor_view<1x1x1x31x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> -> !pto.partition_tensor_view<1x1x1x31x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xf32>) + return + } + + // case4: f32 15x192 valid=15x192 + func.func @TROWSUM_case4(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xf32> -> !pto.partition_tensor_view<1x1x1x15x192xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> -> !pto.partition_tensor_view<1x1x1x15x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xf32>) + return + } + + // case5: f32 7x448 valid=7x447 + func.func @TROWSUM_case5(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x447xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> -> !pto.partition_tensor_view<1x1x1x7x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xf32>) + return + } + + // case6: f16 256x16 valid=256x15 + func.func @TROWSUM_case6(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xf16> -> !pto.partition_tensor_view<1x1x1x256x15xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> -> !pto.partition_tensor_view<1x1x1x256x1xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xf16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xf16>) + return + } + + // case7: f32 64x128 valid=64x128 (DN dst) + func.func @TROWSUM_case7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> -> !pto.partition_tensor_view<1x1x1x64x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xf32>) + return + } + + // case8: f32 32x256 valid=32x256 (DN dst) + func.func @TROWSUM_case8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c256] + : !pto.tensor_view<1x1x1x32x256xf32> -> !pto.partition_tensor_view<1x1x1x32x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> -> !pto.partition_tensor_view<1x1x1x32x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x256xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xf32>) + return + } + + // case9: f32 16x512 valid=16x512 (DN dst) + func.func @TROWSUM_case9(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf32> -> !pto.partition_tensor_view<1x1x1x16x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> -> !pto.partition_tensor_view<1x1x1x16x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x512xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xf32>) + return + } + + // case10: f32 8x1024 valid=8x1024 (DN dst) + func.func @TROWSUM_case10(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c1024], + strides = [%c8192, %c8192, %c8192, %c1024, %c1] + : !pto.tensor_view<1x1x1x8x1024xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1024] + : !pto.tensor_view<1x1x1x8x1024xf32> -> !pto.partition_tensor_view<1x1x1x8x1024xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> -> !pto.partition_tensor_view<1x1x1x8x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x1024xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xf32>) + return + } + + // ======================================================================== + // i32 cases (case11-case15) + // ======================================================================== + + // case11: i32 127x64 valid=127x63 + func.func @TROWSUM_case11(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xi32> -> !pto.partition_tensor_view<1x1x1x127x63xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> -> !pto.partition_tensor_view<1x1x1x127x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xi32>) + return + } + + // case12: i32 63x64 valid=63x64 + func.func @TROWSUM_case12(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi32> -> !pto.partition_tensor_view<1x1x1x63x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> -> !pto.partition_tensor_view<1x1x1x63x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi32>) + return + } + + // case13: i32 31x128 valid=31x127 + func.func @TROWSUM_case13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> -> !pto.partition_tensor_view<1x1x1x31x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi32>) + return + } + + // case14: i32 15x192 valid=15x192 + func.func @TROWSUM_case14(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi32> -> !pto.partition_tensor_view<1x1x1x15x192xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> -> !pto.partition_tensor_view<1x1x1x15x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xi32>) + return + } + + // case15: i32 7x448 valid=7x447 + func.func @TROWSUM_case15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xi32> -> !pto.partition_tensor_view<1x1x1x7x447xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> -> !pto.partition_tensor_view<1x1x1x7x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xi32>) + return + } + + // ======================================================================== + // i16 cases (case16-case20) + // ======================================================================== + + // case16: i16 128x64 valid=128x64 + func.func @TROWSUM_case16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c1], + strides = [%c128, %c128, %c128, %c1, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xi16> -> !pto.partition_tensor_view<1x1x1x128x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> -> !pto.partition_tensor_view<1x1x1x128x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x1xi16>) + return + } + + // case17: i16 64x64 valid=64x64 + func.func @TROWSUM_case17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> -> !pto.partition_tensor_view<1x1x1x64x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xi16>) + return + } + + // case18: i16 32x128 valid=32x128 + func.func @TROWSUM_case18(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> -> !pto.partition_tensor_view<1x1x1x32x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xi16>) + return + } + + // case19: i16 16x192 valid=16x192 + func.func @TROWSUM_case19(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c192 = arith.constant 192 : index + %c3072 = arith.constant 3072 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c192], + strides = [%c3072, %c3072, %c3072, %c192, %c1] + : !pto.tensor_view<1x1x1x16x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c192] + : !pto.tensor_view<1x1x1x16x192xi16> -> !pto.partition_tensor_view<1x1x1x16x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> -> !pto.partition_tensor_view<1x1x1x16x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x192xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xi16>) + return + } + + // case20: i16 8x448 valid=8x448 + func.func @TROWSUM_case20(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c448 = arith.constant 448 : index + %c3584 = arith.constant 3584 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c448], + strides = [%c3584, %c3584, %c3584, %c448, %c1] + : !pto.tensor_view<1x1x1x8x448xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c448] + : !pto.tensor_view<1x1x1x8x448xi16> -> !pto.partition_tensor_view<1x1x1x8x448xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> -> !pto.partition_tensor_view<1x1x1x8x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x448xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xi16>) + return + } + +} From 5312763f2238f6340a98f15c3b05a931882d85f4 Mon Sep 17 00:00:00 2001 From: ChaoyangJi Date: Sat, 25 Apr 2026 16:51:46 +0800 Subject: [PATCH 169/192] add tcolmax tcolmin tileops lib implementation (#221) * add tcolmax tcolmin tileops lib implementation * add tcolsum tcolprod tileops lib implementation * fix license check in lib template * fix col mask processing and add constraints check for tcolmax/min/sum/prod * fix license check for tcolmax/min/sum/prod --- lib/TileOps/tcolmax_template.py | 56 + lib/TileOps/tcolmin_template.py | 56 + lib/TileOps/tcolprod_template.py | 56 + lib/TileOps/tcolsum_template.py | 57 + test/basic/tcolmax.pto | 42 + test/basic/tcolmin.pto | 42 + test/basic/tcolprod.pto | 42 + test/basic/tcolsum.pto | 42 + .../npu/a5/src/st/testcase/CMakeLists.txt | 4 + .../a5/src/st/testcase/tcolmax/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tcolmax/cases.py | 245 ++++ .../npu/a5/src/st/testcase/tcolmax/compare.py | 50 + .../a5/src/st/testcase/tcolmax/gen_data.py | 35 + .../npu/a5/src/st/testcase/tcolmax/launch.cpp | 181 +++ .../npu/a5/src/st/testcase/tcolmax/main.cpp | 189 +++ .../a5/src/st/testcase/tcolmax/tcolmax.pto | 1182 +++++++++++++++++ .../a5/src/st/testcase/tcolmin/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tcolmin/cases.py | 245 ++++ .../npu/a5/src/st/testcase/tcolmin/compare.py | 50 + .../a5/src/st/testcase/tcolmin/gen_data.py | 35 + .../npu/a5/src/st/testcase/tcolmin/launch.cpp | 181 +++ .../npu/a5/src/st/testcase/tcolmin/main.cpp | 189 +++ .../a5/src/st/testcase/tcolmin/tcolmin.pto | 1182 +++++++++++++++++ .../src/st/testcase/tcolprod/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tcolprod/cases.py | 164 +++ .../a5/src/st/testcase/tcolprod/compare.py | 50 + .../a5/src/st/testcase/tcolprod/gen_data.py | 35 + .../a5/src/st/testcase/tcolprod/launch.cpp | 118 ++ .../npu/a5/src/st/testcase/tcolprod/main.cpp | 170 +++ .../a5/src/st/testcase/tcolprod/tcolprod.pto | 744 +++++++++++ .../a5/src/st/testcase/tcolsum/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tcolsum/cases.py | 173 +++ .../npu/a5/src/st/testcase/tcolsum/compare.py | 50 + .../a5/src/st/testcase/tcolsum/gen_data.py | 35 + .../npu/a5/src/st/testcase/tcolsum/launch.cpp | 125 ++ .../npu/a5/src/st/testcase/tcolsum/main.cpp | 173 +++ .../a5/src/st/testcase/tcolsum/tcolsum.pto | 793 +++++++++++ 37 files changed, 6827 insertions(+) create mode 100644 lib/TileOps/tcolmax_template.py create mode 100644 lib/TileOps/tcolmin_template.py create mode 100644 lib/TileOps/tcolprod_template.py create mode 100644 lib/TileOps/tcolsum_template.py create mode 100644 test/basic/tcolmax.pto create mode 100644 test/basic/tcolmin.pto create mode 100644 test/basic/tcolprod.pto create mode 100644 test/basic/tcolsum.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmax/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmax/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmax/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmax/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmax/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmax/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmax/tcolmax.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmin/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmin/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmin/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmin/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmin/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmin/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolmin/tcolmin.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolprod/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolprod/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolprod/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolprod/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolprod/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolprod/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolprod/tcolprod.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolsum/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolsum/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolsum/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolsum/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolsum/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolsum/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolsum/tcolsum.pto diff --git a/lib/TileOps/tcolmax_template.py b/lib/TileOps/tcolmax_template.py new file mode 100644 index 000000000..f08df7eda --- /dev/null +++ b/lib/TileOps/tcolmax_template.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + +def _validate_tcolmax( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + +@pto.vkernel( + target="a5", + op="pto.tcolmax", + constraints=[_validate_tcolmax] +) +def template_tcolmax(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vmax(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) + + return diff --git a/lib/TileOps/tcolmin_template.py b/lib/TileOps/tcolmin_template.py new file mode 100644 index 000000000..2a36dcdd5 --- /dev/null +++ b/lib/TileOps/tcolmin_template.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + +def _validate_tcolmin( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + +@pto.vkernel( + target="a5", + op="pto.tcolmin", + constraints=[_validate_tcolmin] +) +def template_tcolmin(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vmin(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) + + return diff --git a/lib/TileOps/tcolprod_template.py b/lib/TileOps/tcolprod_template.py new file mode 100644 index 000000000..4ebb99f48 --- /dev/null +++ b/lib/TileOps/tcolprod_template.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + +def _validate_tcolprod( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + +@pto.vkernel( + target="a5", + op="pto.tcolprod", + constraints=[_validate_tcolprod] +) +def template_tcolprod(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vmul(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) + + return diff --git a/lib/TileOps/tcolsum_template.py b/lib/TileOps/tcolsum_template.py new file mode 100644 index 000000000..b187b45a3 --- /dev/null +++ b/lib/TileOps/tcolsum_template.py @@ -0,0 +1,57 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + +def _validate_tcolsum( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + +# Todo: This is the basic implementation. Later the binary colsum algorithm should be implemented also. +@pto.vkernel( + target="a5", + op="pto.tcolsum", + constraints=[_validate_tcolsum] +) +def template_tcolsum(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vadd(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) + + return diff --git a/test/basic/tcolmax.pto b/test/basic/tcolmax.pto new file mode 100644 index 000000000..f117ad856 --- /dev/null +++ b/test/basic/tcolmax.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolmax via the TileLang Python DSL template +// lib/TileOps/tcolmax_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLMAX +// CHECK-NOT: pto.tcolmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmax +// CHECK: pto.vsts + +module { + func.func @TCOLMAX() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/basic/tcolmin.pto b/test/basic/tcolmin.pto new file mode 100644 index 000000000..dd071ca73 --- /dev/null +++ b/test/basic/tcolmin.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolmin via the TileLang Python DSL template +// lib/TileOps/tcolmin_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLMIN +// CHECK-NOT: pto.tcolmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmin +// CHECK: pto.vsts + +module { + func.func @TCOLMIN() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/basic/tcolprod.pto b/test/basic/tcolprod.pto new file mode 100644 index 000000000..4ca1b8e93 --- /dev/null +++ b/test/basic/tcolprod.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolprod via the TileLang Python DSL template +// lib/TileOps/tcolprod_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolprod should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLPROD +// CHECK-NOT: pto.tcolprod ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vsts + +module { + func.func @TCOLPROD() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/basic/tcolsum.pto b/test/basic/tcolsum.pto new file mode 100644 index 000000000..f63ac6525 --- /dev/null +++ b/test/basic/tcolsum.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolsum via the TileLang Python DSL template +// lib/TileOps/tcolsum_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolsum should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLSUM +// CHECK-NOT: pto.tcolsum ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module { + func.func @TCOLSUM() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 352667eb0..347354bb0 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -132,6 +132,10 @@ set(ALL_TESTCASES tadd tcvt tload + tcolmax + tcolmin + tcolsum + tcolprod softmax tabs texp diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/CMakeLists.txt new file mode 100644 index 000000000..5afae033c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/cases.py new file mode 100644 index 000000000..f6fa36d9e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/cases.py @@ -0,0 +1,245 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolmax ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input. + - valid_shape: (valid_rows, valid_cols) — effective computation region for input. + - dst_shape: (1, cols) — allocated tile dimensions for output. + - dst_valid_shape: (1, valid_cols) — effective computation region for output. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f16_1x256", + "dtype": np.float16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "f16_16x128", + "dtype": np.float16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-3, + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "i8_1x256", + "dtype": np.int8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i8_16x128", + "dtype": np.int8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i8_16x256", + "dtype": np.int8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_1x256", + "dtype": np.int16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_16x128", + "dtype": np.int16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i16_16x256", + "dtype": np.int16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_1x256", + "dtype": np.int32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_16x128", + "dtype": np.int32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i32_16x256", + "dtype": np.int32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui8_1x256", + "dtype": np.uint8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui8_16x128", + "dtype": np.uint8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui8_16x256", + "dtype": np.uint8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_1x256", + "dtype": np.uint16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_16x128", + "dtype": np.uint16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui16_16x256", + "dtype": np.uint16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_1x256", + "dtype": np.uint32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_16x128", + "dtype": np.uint32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui32_16x256", + "dtype": np.uint32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/compare.py new file mode 100644 index 000000000..06a17bbda --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/gen_data.py new file mode 100644 index 000000000..4dcd83b95 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dtype) + golden_result = np.max(input1[:vr, :vc], axis=0, keepdims=True).astype(dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/launch.cpp new file mode 100644 index 000000000..cb5325a5b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/launch.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_f32_1x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMAX_f32_1x256(float *dst, float *src, void *stream) { + TCOLMAX_f32_1x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_f32_16x128(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMAX_f32_16x128(float *dst, float *src, void *stream) { + TCOLMAX_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_f32_16x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMAX_f32_16x256(float *dst, float *src, void *stream) { + TCOLMAX_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 3: f16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_f16_1x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMAX_f16_1x256(void *dst, void *src, void *stream) { + TCOLMAX_f16_1x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 4: f16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_f16_16x128(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMAX_f16_16x128(void *dst, void *src, void *stream) { + TCOLMAX_f16_16x128<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 5: f16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_f16_16x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMAX_f16_16x256(void *dst, void *src, void *stream) { + TCOLMAX_f16_16x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 6: i8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i8_1x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMAX_i8_1x256(void *dst, void *src, void *stream) { + TCOLMAX_i8_1x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 7: i8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_i8_16x128(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMAX_i8_16x128(void *dst, void *src, void *stream) { + TCOLMAX_i8_16x128<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 8: i8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i8_16x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMAX_i8_16x256(void *dst, void *src, void *stream) { + TCOLMAX_i8_16x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 9: i16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i16_1x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMAX_i16_1x256(void *dst, void *src, void *stream) { + TCOLMAX_i16_1x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 10: i16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_i16_16x128(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMAX_i16_16x128(void *dst, void *src, void *stream) { + TCOLMAX_i16_16x128<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 11: i16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i16_16x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMAX_i16_16x256(void *dst, void *src, void *stream) { + TCOLMAX_i16_16x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 12: i32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i32_1x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMAX_i32_1x256(void *dst, void *src, void *stream) { + TCOLMAX_i32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 13: i32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_i32_16x128(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMAX_i32_16x128(void *dst, void *src, void *stream) { + TCOLMAX_i32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 14: i32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i32_16x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMAX_i32_16x256(void *dst, void *src, void *stream) { + TCOLMAX_i32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 15: ui8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui8_1x256(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMAX_ui8_1x256(void *dst, void *src, void *stream) { + TCOLMAX_ui8_1x256<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 16: ui8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_ui8_16x128(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMAX_ui8_16x128(void *dst, void *src, void *stream) { + TCOLMAX_ui8_16x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 17: ui8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui8_16x256(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMAX_ui8_16x256(void *dst, void *src, void *stream) { + TCOLMAX_ui8_16x256<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 18: ui16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui16_1x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMAX_ui16_1x256(void *dst, void *src, void *stream) { + TCOLMAX_ui16_1x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 19: ui16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_ui16_16x128(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMAX_ui16_16x128(void *dst, void *src, void *stream) { + TCOLMAX_ui16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 20: ui16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui16_16x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMAX_ui16_16x256(void *dst, void *src, void *stream) { + TCOLMAX_ui16_16x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 21: ui32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui32_1x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMAX_ui32_1x256(void *dst, void *src, void *stream) { + TCOLMAX_ui32_1x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 22: ui32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_ui32_16x128(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMAX_ui32_16x128(void *dst, void *src, void *stream) { + TCOLMAX_ui32_16x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 23: ui32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui32_16x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMAX_ui32_16x256(void *dst, void *src, void *stream) { + TCOLMAX_ui32_16x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/main.cpp new file mode 100644 index 000000000..aaa0b9505 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/main.cpp @@ -0,0 +1,189 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolmax ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLMAX_f32_1x256(float *dst, float *src, void *stream); +void LaunchTCOLMAX_f32_16x128(float *dst, float *src, void *stream); +void LaunchTCOLMAX_f32_16x256(float *dst, float *src, void *stream); +void LaunchTCOLMAX_f16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_f16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_f16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i32_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui32_16x256(void *dst, void *src, void *stream); + +using LaunchFnFloat = void (*)(float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t elemSize; + bool isFp16; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLMAX_f32_1x256, 1, 256, 1, 255, 1, 256, 255, sizeof(float), false}, + {"f32_16x128", (void*)LaunchTCOLMAX_f32_16x128, 16, 128, 16, 127, 1, 128, 127, sizeof(float), false}, + {"f32_16x256", (void*)LaunchTCOLMAX_f32_16x256, 16, 256, 15, 255, 1, 256, 255, sizeof(float), false}, + {"f16_1x256", (void*)LaunchTCOLMAX_f16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"f16_16x128", (void*)LaunchTCOLMAX_f16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"f16_16x256", (void*)LaunchTCOLMAX_f16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i8_1x256", (void*)LaunchTCOLMAX_i8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"i8_16x128", (void*)LaunchTCOLMAX_i8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"i8_16x256", (void*)LaunchTCOLMAX_i8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"i16_1x256", (void*)LaunchTCOLMAX_i16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"i16_16x128", (void*)LaunchTCOLMAX_i16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"i16_16x256", (void*)LaunchTCOLMAX_i16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i32_1x256", (void*)LaunchTCOLMAX_i32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"i32_16x128", (void*)LaunchTCOLMAX_i32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"i32_16x256", (void*)LaunchTCOLMAX_i32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, + {"ui8_1x256", (void*)LaunchTCOLMAX_ui8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"ui8_16x128", (void*)LaunchTCOLMAX_ui8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"ui8_16x256", (void*)LaunchTCOLMAX_ui8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"ui16_1x256", (void*)LaunchTCOLMAX_ui16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"ui16_16x128", (void*)LaunchTCOLMAX_ui16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"ui16_16x256", (void*)LaunchTCOLMAX_ui16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"ui32_1x256", (void*)LaunchTCOLMAX_ui32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"ui32_16x128", (void*)LaunchTCOLMAX_ui32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"ui32_16x256", (void*)LaunchTCOLMAX_ui32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((float*)dstDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/tcolmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/tcolmax.pto new file mode 100644 index 000000000..2265bce5a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/tcolmax.pto @@ -0,0 +1,1182 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolmax: tload(src) + tcolmax(dst, src) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_f32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_f32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_f32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 3: f16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_f16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 4: f16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_f16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x127xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf16>) + return + } + + // Case 5: f16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_f16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 6: i8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_i8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 7: i8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_i8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x127xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi8>) + return + } + + // Case 8: i8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_i8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 9: i16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_i16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 10: i16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_i16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi16> -> !pto.partition_tensor_view<1x1x1x16x127xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x127xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi16>) + return + } + + // Case 11: i16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_i16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi16> -> !pto.partition_tensor_view<1x1x1x15x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 12: i32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_i32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 13: i32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_i32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi32> -> !pto.partition_tensor_view<1x1x1x16x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 14: i32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_i32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi32> -> !pto.partition_tensor_view<1x1x1x15x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 15: ui8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_ui8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + return + } + + // Case 16: ui8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_ui8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui8> -> !pto.partition_tensor_view<1x1x1x16x127xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x127xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui8>) + return + } + + // Case 17: ui8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_ui8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui8> -> !pto.partition_tensor_view<1x1x1x15x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + return + } + + // Case 18: ui16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_ui16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 19: ui16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_ui16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x127xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui16>) + return + } + + // Case 20: ui16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_ui16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 21: ui32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_ui32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } + + // Case 22: ui32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_ui32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x127xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui32>) + return + } + + // Case 23: ui32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_ui32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/CMakeLists.txt new file mode 100644 index 000000000..6de952f0b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/cases.py new file mode 100644 index 000000000..ba16bbd3f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/cases.py @@ -0,0 +1,245 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolmin ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input. + - valid_shape: (valid_rows, valid_cols) — effective computation region for input. + - dst_shape: (1, cols) — allocated tile dimensions for output. + - dst_valid_shape: (1, valid_cols) — effective computation region for output. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f16_1x256", + "dtype": np.float16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "f16_16x128", + "dtype": np.float16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-3, + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "i8_1x256", + "dtype": np.int8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i8_16x128", + "dtype": np.int8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i8_16x256", + "dtype": np.int8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_1x256", + "dtype": np.int16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_16x128", + "dtype": np.int16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i16_16x256", + "dtype": np.int16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_1x256", + "dtype": np.int32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_16x128", + "dtype": np.int32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i32_16x256", + "dtype": np.int32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui8_1x256", + "dtype": np.uint8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui8_16x128", + "dtype": np.uint8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui8_16x256", + "dtype": np.uint8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_1x256", + "dtype": np.uint16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_16x128", + "dtype": np.uint16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui16_16x256", + "dtype": np.uint16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_1x256", + "dtype": np.uint32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_16x128", + "dtype": np.uint32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui32_16x256", + "dtype": np.uint32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/compare.py new file mode 100644 index 000000000..06a17bbda --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/gen_data.py new file mode 100644 index 000000000..152c58370 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dtype) + golden_result = np.min(input1[:vr, :vc], axis=0, keepdims=True).astype(dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/launch.cpp new file mode 100644 index 000000000..7e43609cc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/launch.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_f32_1x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMIN_f32_1x256(float *dst, float *src, void *stream) { + TCOLMIN_f32_1x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_f32_16x128(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMIN_f32_16x128(float *dst, float *src, void *stream) { + TCOLMIN_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_f32_16x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMIN_f32_16x256(float *dst, float *src, void *stream) { + TCOLMIN_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 3: f16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_f16_1x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMIN_f16_1x256(void *dst, void *src, void *stream) { + TCOLMIN_f16_1x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 4: f16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_f16_16x128(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMIN_f16_16x128(void *dst, void *src, void *stream) { + TCOLMIN_f16_16x128<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 5: f16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_f16_16x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMIN_f16_16x256(void *dst, void *src, void *stream) { + TCOLMIN_f16_16x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 6: i8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i8_1x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMIN_i8_1x256(void *dst, void *src, void *stream) { + TCOLMIN_i8_1x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 7: i8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_i8_16x128(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMIN_i8_16x128(void *dst, void *src, void *stream) { + TCOLMIN_i8_16x128<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 8: i8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i8_16x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMIN_i8_16x256(void *dst, void *src, void *stream) { + TCOLMIN_i8_16x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 9: i16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i16_1x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMIN_i16_1x256(void *dst, void *src, void *stream) { + TCOLMIN_i16_1x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 10: i16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_i16_16x128(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMIN_i16_16x128(void *dst, void *src, void *stream) { + TCOLMIN_i16_16x128<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 11: i16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i16_16x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMIN_i16_16x256(void *dst, void *src, void *stream) { + TCOLMIN_i16_16x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 12: i32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i32_1x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMIN_i32_1x256(void *dst, void *src, void *stream) { + TCOLMIN_i32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 13: i32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_i32_16x128(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMIN_i32_16x128(void *dst, void *src, void *stream) { + TCOLMIN_i32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 14: i32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i32_16x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMIN_i32_16x256(void *dst, void *src, void *stream) { + TCOLMIN_i32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 15: ui8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui8_1x256(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMIN_ui8_1x256(void *dst, void *src, void *stream) { + TCOLMIN_ui8_1x256<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 16: ui8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_ui8_16x128(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMIN_ui8_16x128(void *dst, void *src, void *stream) { + TCOLMIN_ui8_16x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 17: ui8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui8_16x256(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMIN_ui8_16x256(void *dst, void *src, void *stream) { + TCOLMIN_ui8_16x256<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 18: ui16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui16_1x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMIN_ui16_1x256(void *dst, void *src, void *stream) { + TCOLMIN_ui16_1x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 19: ui16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_ui16_16x128(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMIN_ui16_16x128(void *dst, void *src, void *stream) { + TCOLMIN_ui16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 20: ui16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui16_16x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMIN_ui16_16x256(void *dst, void *src, void *stream) { + TCOLMIN_ui16_16x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 21: ui32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui32_1x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMIN_ui32_1x256(void *dst, void *src, void *stream) { + TCOLMIN_ui32_1x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 22: ui32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_ui32_16x128(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMIN_ui32_16x128(void *dst, void *src, void *stream) { + TCOLMIN_ui32_16x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 23: ui32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui32_16x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMIN_ui32_16x256(void *dst, void *src, void *stream) { + TCOLMIN_ui32_16x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/main.cpp new file mode 100644 index 000000000..c24bc5316 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/main.cpp @@ -0,0 +1,189 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolmin ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLMIN_f32_1x256(float *dst, float *src, void *stream); +void LaunchTCOLMIN_f32_16x128(float *dst, float *src, void *stream); +void LaunchTCOLMIN_f32_16x256(float *dst, float *src, void *stream); +void LaunchTCOLMIN_f16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_f16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_f16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i32_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui32_16x256(void *dst, void *src, void *stream); + +using LaunchFnFloat = void (*)(float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t elemSize; + bool isFp16; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLMIN_f32_1x256, 1, 256, 1, 255, 1, 256, 255, sizeof(float), false}, + {"f32_16x128", (void*)LaunchTCOLMIN_f32_16x128, 16, 128, 16, 127, 1, 128, 127, sizeof(float), false}, + {"f32_16x256", (void*)LaunchTCOLMIN_f32_16x256, 16, 256, 15, 255, 1, 256, 255, sizeof(float), false}, + {"f16_1x256", (void*)LaunchTCOLMIN_f16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"f16_16x128", (void*)LaunchTCOLMIN_f16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"f16_16x256", (void*)LaunchTCOLMIN_f16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i8_1x256", (void*)LaunchTCOLMIN_i8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"i8_16x128", (void*)LaunchTCOLMIN_i8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"i8_16x256", (void*)LaunchTCOLMIN_i8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"i16_1x256", (void*)LaunchTCOLMIN_i16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"i16_16x128", (void*)LaunchTCOLMIN_i16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"i16_16x256", (void*)LaunchTCOLMIN_i16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i32_1x256", (void*)LaunchTCOLMIN_i32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"i32_16x128", (void*)LaunchTCOLMIN_i32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"i32_16x256", (void*)LaunchTCOLMIN_i32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, + {"ui8_1x256", (void*)LaunchTCOLMIN_ui8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"ui8_16x128", (void*)LaunchTCOLMIN_ui8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"ui8_16x256", (void*)LaunchTCOLMIN_ui8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"ui16_1x256", (void*)LaunchTCOLMIN_ui16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"ui16_16x128", (void*)LaunchTCOLMIN_ui16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"ui16_16x256", (void*)LaunchTCOLMIN_ui16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"ui32_1x256", (void*)LaunchTCOLMIN_ui32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"ui32_16x128", (void*)LaunchTCOLMIN_ui32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"ui32_16x256", (void*)LaunchTCOLMIN_ui32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((float*)dstDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/tcolmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/tcolmin.pto new file mode 100644 index 000000000..a6a0a0ac3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/tcolmin.pto @@ -0,0 +1,1182 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolmin: tload(src) + tcolmin(dst, src) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_f32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 1: f32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_f32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_f32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 3: f16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_f16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 4: f16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_f16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x127xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf16>) + return + } + + // Case 5: f16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_f16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 6: i8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_i8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 7: i8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_i8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x127xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi8>) + return + } + + // Case 8: i8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_i8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 9: i16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_i16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 10: i16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_i16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi16> -> !pto.partition_tensor_view<1x1x1x16x127xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x127xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi16>) + return + } + + // Case 11: i16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_i16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi16> -> !pto.partition_tensor_view<1x1x1x15x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 12: i32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_i32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 13: i32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_i32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi32> -> !pto.partition_tensor_view<1x1x1x16x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 14: i32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_i32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi32> -> !pto.partition_tensor_view<1x1x1x15x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 15: ui8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_ui8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + return + } + + // Case 16: ui8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_ui8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui8> -> !pto.partition_tensor_view<1x1x1x16x127xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x127xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui8>) + return + } + + // Case 17: ui8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_ui8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui8> -> !pto.partition_tensor_view<1x1x1x15x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + return + } + + // Case 18: ui16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_ui16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 19: ui16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_ui16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x127xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui16>) + return + } + + // Case 20: ui16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_ui16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 21: ui32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_ui32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } + + // Case 22: ui32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_ui32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x127xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui32>) + return + } + + // Case 23: ui32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_ui32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/CMakeLists.txt new file mode 100644 index 000000000..02b874532 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolprod) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/cases.py new file mode 100644 index 000000000..e95d300ec --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/cases.py @@ -0,0 +1,164 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolprod ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input. + - valid_shape: (valid_rows, valid_cols) — effective computation region for input. + - dst_shape: (1, cols) — allocated tile dimensions for output. + - dst_valid_shape: (1, valid_cols) — effective computation region for output. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "i16_1x256", + "dtype": np.int16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_16x128", + "dtype": np.int16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i16_16x256", + "dtype": np.int16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_1x256", + "dtype": np.uint16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_16x128", + "dtype": np.uint16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui16_16x256", + "dtype": np.uint16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_1x256", + "dtype": np.int32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_16x128", + "dtype": np.int32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i32_16x256", + "dtype": np.int32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_1x256", + "dtype": np.uint32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_16x128", + "dtype": np.uint32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui32_16x256", + "dtype": np.uint32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/compare.py new file mode 100644 index 000000000..06a17bbda --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/gen_data.py new file mode 100644 index 000000000..ff3b740eb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dtype) + golden_result = np.prod(input1[:vr, :vc], axis=0, keepdims=True).astype(dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/launch.cpp new file mode 100644 index 000000000..158d2a06e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/launch.cpp @@ -0,0 +1,118 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_f32_1x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLPROD_f32_1x256(float *dst, float *src, void *stream) { + TCOLPROD_f32_1x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_f32_16x128(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLPROD_f32_16x128(float *dst, float *src, void *stream) { + TCOLPROD_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_f32_16x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLPROD_f32_16x256(float *dst, float *src, void *stream) { + TCOLPROD_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 3: i16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_i16_1x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLPROD_i16_1x256(void *dst, void *src, void *stream) { + TCOLPROD_i16_1x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 4: i16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_i16_16x128(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLPROD_i16_16x128(void *dst, void *src, void *stream) { + TCOLPROD_i16_16x128<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 5: i16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_i16_16x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLPROD_i16_16x256(void *dst, void *src, void *stream) { + TCOLPROD_i16_16x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 6: ui16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_ui16_1x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLPROD_ui16_1x256(void *dst, void *src, void *stream) { + TCOLPROD_ui16_1x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 7: ui16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_ui16_16x128(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLPROD_ui16_16x128(void *dst, void *src, void *stream) { + TCOLPROD_ui16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 8: ui16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_ui16_16x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLPROD_ui16_16x256(void *dst, void *src, void *stream) { + TCOLPROD_ui16_16x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 9: i32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_i32_1x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLPROD_i32_1x256(void *dst, void *src, void *stream) { + TCOLPROD_i32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 10: i32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_i32_16x128(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLPROD_i32_16x128(void *dst, void *src, void *stream) { + TCOLPROD_i32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 11: i32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_i32_16x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLPROD_i32_16x256(void *dst, void *src, void *stream) { + TCOLPROD_i32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 12: ui32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_ui32_1x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLPROD_ui32_1x256(void *dst, void *src, void *stream) { + TCOLPROD_ui32_1x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 13: ui32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_ui32_16x128(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLPROD_ui32_16x128(void *dst, void *src, void *stream) { + TCOLPROD_ui32_16x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 14: ui32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_ui32_16x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLPROD_ui32_16x256(void *dst, void *src, void *stream) { + TCOLPROD_ui32_16x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/main.cpp new file mode 100644 index 000000000..6850592a1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/main.cpp @@ -0,0 +1,170 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Host driver for TileLang tcolprod ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLPROD_f32_1x256(float *dst, float *src, void *stream); +void LaunchTCOLPROD_f32_16x128(float *dst, float *src, void *stream); +void LaunchTCOLPROD_f32_16x256(float *dst, float *src, void *stream); +void LaunchTCOLPROD_i16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i32_16x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui32_16x256(void *dst, void *src, void *stream); + +using LaunchFnFloat = void (*)(float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t elemSize; + bool isFp16; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLPROD_f32_1x256, 1, 256, 1, 255, 1, 256, 255, sizeof(float), false}, + {"f32_16x128", (void*)LaunchTCOLPROD_f32_16x128, 16, 128, 16, 127, 1, 128, 127, sizeof(float), false}, + {"f32_16x256", (void*)LaunchTCOLPROD_f32_16x256, 16, 256, 15, 255, 1, 256, 255, sizeof(float), false}, + {"i16_1x256", (void*)LaunchTCOLPROD_i16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"i16_16x128", (void*)LaunchTCOLPROD_i16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"i16_16x256", (void*)LaunchTCOLPROD_i16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"ui16_1x256", (void*)LaunchTCOLPROD_ui16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"ui16_16x128", (void*)LaunchTCOLPROD_ui16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"ui16_16x256", (void*)LaunchTCOLPROD_ui16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i32_1x256", (void*)LaunchTCOLPROD_i32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"i32_16x128", (void*)LaunchTCOLPROD_i32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"i32_16x256", (void*)LaunchTCOLPROD_i32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, + {"ui32_1x256", (void*)LaunchTCOLPROD_ui32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"ui32_16x128", (void*)LaunchTCOLPROD_ui32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"ui32_16x256", (void*)LaunchTCOLPROD_ui32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((float*)dstDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/tcolprod.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/tcolprod.pto new file mode 100644 index 000000000..0ef39ef04 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/tcolprod.pto @@ -0,0 +1,744 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolprod: tload(src) + tcolprod(dst, src) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_f32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 1: f32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_f32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_f32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 3: i16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_i16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 4: i16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_i16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi16> -> !pto.partition_tensor_view<1x1x1x16x127xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x127xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi16>) + return + } + + // Case 5: i16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_i16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi16> -> !pto.partition_tensor_view<1x1x1x15x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 6: ui16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_ui16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 7: ui16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_ui16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x127xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui16>) + return + } + + // Case 8: ui16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_ui16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 9: i32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_i32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 10: i32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_i32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi32> -> !pto.partition_tensor_view<1x1x1x16x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 11: i32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_i32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi32> -> !pto.partition_tensor_view<1x1x1x15x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 12: ui32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_ui32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } + + // Case 13: ui32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_ui32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x127xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui32>) + return + } + + // Case 14: ui32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_ui32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/CMakeLists.txt new file mode 100644 index 000000000..e59d778af --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolsum) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/cases.py new file mode 100644 index 000000000..dfbcbddf3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/cases.py @@ -0,0 +1,173 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolsum ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input. + - valid_shape: (valid_rows, valid_cols) — effective computation region for input. + - dst_shape: (1, cols) — allocated tile dimensions for output. + - dst_valid_shape: (1, valid_cols) — effective computation region for output. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_64x128_1", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (63, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_64x128_2", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 128), + "eps": 1e-6, + }, + { + "name": "f32_1x512", + "dtype": np.float32, + "shape": (1, 512), + "valid_shape": (1, 511), + "dst_shape": (1, 512), + "dst_valid_shape": (1, 511), + "eps": 1e-6, + }, + { + "name": "f16_1x256", + "dtype": np.float16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "f16_16x128", + "dtype": np.float16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-3, + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "f16_64x128_1", + "dtype": np.float16, + "shape": (64, 128), + "valid_shape": (63, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-3, + }, + { + "name": "f16_64x128_2", + "dtype": np.float16, + "shape": (64, 128), + "valid_shape": (64, 128), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 128), + "eps": 1e-3, + }, + { + "name": "i8_1x256", + "dtype": np.int8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i8_16x128", + "dtype": np.int8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i8_16x256", + "dtype": np.int8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i8_64x128_1", + "dtype": np.int8, + "shape": (64, 128), + "valid_shape": (63, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i8_64x128_2", + "dtype": np.int8, + "shape": (64, 128), + "valid_shape": (64, 128), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 128), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/compare.py new file mode 100644 index 000000000..06a17bbda --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/gen_data.py new file mode 100644 index 000000000..7c0e19eef --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dtype) + golden_result = np.sum(input1[:vr, :vc], axis=0, keepdims=True).astype(dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/launch.cpp new file mode 100644 index 000000000..6c8af717f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/launch.cpp @@ -0,0 +1,125 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_f32_1x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_1x256(float *dst, float *src, void *stream) { + TCOLSUM_f32_1x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f32_16x128(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_16x128(float *dst, float *src, void *stream) { + TCOLSUM_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_f32_16x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_16x256(float *dst, float *src, void *stream) { + TCOLSUM_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 3: f32 64x128_1 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f32_64x128_1(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_64x128_1(float *dst, float *src, void *stream) { + TCOLSUM_f32_64x128_1<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 4: f32 64x128_2 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f32_64x128_2(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_64x128_2(float *dst, float *src, void *stream) { + TCOLSUM_f32_64x128_2<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 5: f32 1x512 (input: 1x512, output: 1x512) +extern "C" __global__ AICORE void TCOLSUM_f32_1x512(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_1x512(float *dst, float *src, void *stream) { + TCOLSUM_f32_1x512<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 6: f16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_f16_1x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_1x256(void *dst, void *src, void *stream) { + TCOLSUM_f16_1x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 7: f16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f16_16x128(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_16x128(void *dst, void *src, void *stream) { + TCOLSUM_f16_16x128<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 8: f16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_f16_16x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_16x256(void *dst, void *src, void *stream) { + TCOLSUM_f16_16x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 9: f16 64x128_1 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f16_64x128_1(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_64x128_1(void *dst, void *src, void *stream) { + TCOLSUM_f16_64x128_1<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 10: f16 64x128_2 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f16_64x128_2(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_64x128_2(void *dst, void *src, void *stream) { + TCOLSUM_f16_64x128_2<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 11: i8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_i8_1x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_1x256(void *dst, void *src, void *stream) { + TCOLSUM_i8_1x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 12: i8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_i8_16x128(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_16x128(void *dst, void *src, void *stream) { + TCOLSUM_i8_16x128<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 13: i8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_i8_16x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_16x256(void *dst, void *src, void *stream) { + TCOLSUM_i8_16x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 14: i8 64x128_1 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_i8_64x128_1(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_64x128_1(void *dst, void *src, void *stream) { + TCOLSUM_i8_64x128_1<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 15: i8 64x128_2 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_i8_64x128_2(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_64x128_2(void *dst, void *src, void *stream) { + TCOLSUM_i8_64x128_2<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/main.cpp new file mode 100644 index 000000000..eaff88d45 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/main.cpp @@ -0,0 +1,173 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolsum ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLSUM_f32_1x256(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_16x128(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_16x256(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_64x128_1(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_64x128_2(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_1x512(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLSUM_f16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLSUM_f16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLSUM_f16_64x128_1(void *dst, void *src, void *stream); +void LaunchTCOLSUM_f16_64x128_2(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_64x128_1(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_64x128_2(void *dst, void *src, void *stream); + +using LaunchFnFloat = void (*)(float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t elemSize; + bool isFp16; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLSUM_f32_1x256, 1, 256, 1, 255, 1, 256, 255, sizeof(float), false}, + {"f32_16x128", (void*)LaunchTCOLSUM_f32_16x128, 16, 128, 16, 127, 1, 128, 127, sizeof(float), false}, + {"f32_16x256", (void*)LaunchTCOLSUM_f32_16x256, 16, 256, 15, 255, 1, 256, 255, sizeof(float), false}, + {"f32_64x128_1", (void*)LaunchTCOLSUM_f32_64x128_1, 64, 128, 63, 127, 1, 128, 127, sizeof(float), false}, + {"f32_64x128_2", (void*)LaunchTCOLSUM_f32_64x128_2, 64, 128, 64, 128, 1, 128, 128, sizeof(float), false}, + {"f32_1x512", (void*)LaunchTCOLSUM_f32_1x512, 1, 512, 1, 511, 1, 512, 511, sizeof(float), false}, + {"f16_1x256", (void*)LaunchTCOLSUM_f16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"f16_16x128", (void*)LaunchTCOLSUM_f16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"f16_16x256", (void*)LaunchTCOLSUM_f16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"f16_64x128_1", (void*)LaunchTCOLSUM_f16_64x128_1, 64, 128, 63, 127, 1, 128, 127, 2, true}, + {"f16_64x128_2", (void*)LaunchTCOLSUM_f16_64x128_2, 64, 128, 64, 128, 1, 128, 128, 2, true}, + {"i8_1x256", (void*)LaunchTCOLSUM_i8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"i8_16x128", (void*)LaunchTCOLSUM_i8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"i8_16x256", (void*)LaunchTCOLSUM_i8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"i8_64x128_1", (void*)LaunchTCOLSUM_i8_64x128_1, 64, 128, 63, 127, 1, 128, 127, 1, true}, + {"i8_64x128_2", (void*)LaunchTCOLSUM_i8_64x128_2, 64, 128, 64, 128, 1, 128, 128, 1, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((float*)dstDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/tcolsum.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/tcolsum.pto new file mode 100644 index 000000000..4553cc48d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/tcolsum.pto @@ -0,0 +1,793 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolsum: tload(src) + tcolsum(dst, src) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLSUM_f32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 1: f32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLSUM_f32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLSUM_f32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 3: f32 64x128_1 (input: 64x128, output: 1x128) + func.func @TCOLSUM_f32_64x128_1(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c127] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x63x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 4: f32 64x128_2 (input: 64x128, output: 1x128) + func.func @TCOLSUM_f32_64x128_2(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 5: f32 1x512 (input: 1x512, output: 1x512) + func.func @TCOLSUM_f32_1x512(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c512 = arith.constant 512 : index + %c511 = arith.constant 511 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c511] + : !pto.tensor_view<1x1x1x1x512xf32> -> !pto.partition_tensor_view<1x1x1x1x511xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c511] + : !pto.tensor_view<1x1x1x1x512xf32> -> !pto.partition_tensor_view<1x1x1x1x511xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x511xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x511xf32>) + return + } + + // Case 6: f16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLSUM_f16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 7: f16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLSUM_f16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x127xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf16>) + return + } + + // Case 8: f16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLSUM_f16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 9: f16 64x128_1 (input: 64x128, output: 1x128) + func.func @TCOLSUM_f16_64x128_1(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c127] + : !pto.tensor_view<1x1x1x64x128xf16> -> !pto.partition_tensor_view<1x1x1x63x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x127xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf16>) + return + } + + // Case 10: f16 64x128_2 (input: 64x128, output: 1x128) + func.func @TCOLSUM_f16_64x128_2(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf16> -> !pto.partition_tensor_view<1x1x1x64x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + return + } + + // Case 11: i8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLSUM_i8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 12: i8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLSUM_i8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x127xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi8>) + return + } + + // Case 13: i8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLSUM_i8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 14: i8 64x128_1 (input: 64x128, output: 1x128) + func.func @TCOLSUM_i8_64x128_1(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c127] + : !pto.tensor_view<1x1x1x64x128xi8> -> !pto.partition_tensor_view<1x1x1x63x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x127xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi8>) + return + } + + // Case 15: i8 64x128_2 (input: 64x128, output: 1x128) + func.func @TCOLSUM_i8_64x128_2(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xi8> -> !pto.partition_tensor_view<1x1x1x64x128xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x128xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi8>) + return + } +} \ No newline at end of file From 9c73d05d025f6e2f7676ec8cc99d1e282b519839 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Sat, 25 Apr 2026 17:00:25 +0800 Subject: [PATCH 170/192] Add scalar ops (#191) --- lib/TileOps/render_template_mlir.py | 8 + lib/TileOps/tadd_template.py | 8 + lib/TileOps/tadds_template.py | 8 + lib/TileOps/tands_template.py | 40 +++ lib/TileOps/tdivs_template.py | 61 ++++ lib/TileOps/tmaxs_template.py | 31 ++ lib/TileOps/tmins_template.py | 31 ++ lib/TileOps/tmuls_template.py | 31 ++ lib/TileOps/tors_template.py | 40 +++ lib/TileOps/tshls_template.py | 31 ++ lib/TileOps/tshrs_template.py | 31 ++ lib/TileOps/tsubs_template.py | 38 +++ lib/TileOps/txors_template.py | 40 +++ test/basic/expand_tile_op_tilelang_tadds.pto | 38 +++ test/basic/expand_tile_op_tilelang_tands.pto | 34 ++ test/basic/expand_tile_op_tilelang_tdivs.pto | 63 ++++ test/basic/expand_tile_op_tilelang_tmaxs.pto | 33 ++ test/basic/expand_tile_op_tilelang_tmins.pto | 33 ++ test/basic/expand_tile_op_tilelang_tmuls.pto | 33 ++ test/basic/expand_tile_op_tilelang_tors.pto | 34 ++ test/basic/expand_tile_op_tilelang_tshls.pto | 33 ++ test/basic/expand_tile_op_tilelang_tshrs.pto | 33 ++ test/basic/expand_tile_op_tilelang_tsubs.pto | 33 ++ test/basic/expand_tile_op_tilelang_txors.pto | 39 +++ .../npu/a5/src/st/testcase/CMakeLists.txt | 11 + .../a5/src/st/testcase/tadds/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tadds/cases.py | 69 +++++ .../npu/a5/src/st/testcase/tadds/compare.py | 46 +++ .../npu/a5/src/st/testcase/tadds/gen_data.py | 35 +++ .../npu/a5/src/st/testcase/tadds/launch.cpp | 58 ++++ .../npu/a5/src/st/testcase/tadds/main.cpp | 139 +++++++++ .../npu/a5/src/st/testcase/tadds/tadds.pto | 292 ++++++++++++++++++ .../a5/src/st/testcase/tands/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tands/cases.py | 42 +++ .../npu/a5/src/st/testcase/tands/compare.py | 46 +++ .../npu/a5/src/st/testcase/tands/gen_data.py | 35 +++ .../npu/a5/src/st/testcase/tands/launch.cpp | 45 +++ .../npu/a5/src/st/testcase/tands/main.cpp | 135 ++++++++ .../npu/a5/src/st/testcase/tands/tands.pto | 200 ++++++++++++ .../a5/src/st/testcase/tdivs/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tdivs/cases.py | 85 +++++ .../npu/a5/src/st/testcase/tdivs/compare.py | 46 +++ .../npu/a5/src/st/testcase/tdivs/gen_data.py | 40 +++ .../npu/a5/src/st/testcase/tdivs/launch.cpp | 67 ++++ .../npu/a5/src/st/testcase/tdivs/main.cpp | 143 +++++++++ .../npu/a5/src/st/testcase/tdivs/tdivs.pto | 167 ++++++++++ .../a5/src/st/testcase/tmaxs/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tmaxs/cases.py | 22 ++ .../npu/a5/src/st/testcase/tmaxs/compare.py | 46 +++ .../npu/a5/src/st/testcase/tmaxs/gen_data.py | 35 +++ .../npu/a5/src/st/testcase/tmaxs/launch.cpp | 58 ++++ .../npu/a5/src/st/testcase/tmaxs/main.cpp | 139 +++++++++ .../npu/a5/src/st/testcase/tmaxs/tmaxs.pto | 292 ++++++++++++++++++ .../a5/src/st/testcase/tmins/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tmins/cases.py | 22 ++ .../npu/a5/src/st/testcase/tmins/compare.py | 46 +++ .../npu/a5/src/st/testcase/tmins/gen_data.py | 35 +++ .../npu/a5/src/st/testcase/tmins/launch.cpp | 58 ++++ .../npu/a5/src/st/testcase/tmins/main.cpp | 139 +++++++++ .../npu/a5/src/st/testcase/tmins/tmins.pto | 292 ++++++++++++++++++ .../a5/src/st/testcase/tmuls/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tmuls/cases.py | 69 +++++ .../npu/a5/src/st/testcase/tmuls/compare.py | 46 +++ .../npu/a5/src/st/testcase/tmuls/gen_data.py | 35 +++ .../npu/a5/src/st/testcase/tmuls/launch.cpp | 58 ++++ .../npu/a5/src/st/testcase/tmuls/main.cpp | 139 +++++++++ .../npu/a5/src/st/testcase/tmuls/tmuls.pto | 292 ++++++++++++++++++ .../a5/src/st/testcase/tors/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tors/cases.py | 42 +++ .../npu/a5/src/st/testcase/tors/compare.py | 46 +++ .../npu/a5/src/st/testcase/tors/gen_data.py | 35 +++ .../npu/a5/src/st/testcase/tors/launch.cpp | 45 +++ .../npu/a5/src/st/testcase/tors/main.cpp | 135 ++++++++ .../npu/a5/src/st/testcase/tors/tors.pto | 200 ++++++++++++ .../a5/src/st/testcase/tshls/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tshls/cases.py | 42 +++ .../npu/a5/src/st/testcase/tshls/compare.py | 46 +++ .../npu/a5/src/st/testcase/tshls/gen_data.py | 34 ++ .../npu/a5/src/st/testcase/tshls/launch.cpp | 44 +++ .../npu/a5/src/st/testcase/tshls/main.cpp | 135 ++++++++ .../npu/a5/src/st/testcase/tshls/tshls.pto | 200 ++++++++++++ .../a5/src/st/testcase/tshrs/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tshrs/cases.py | 42 +++ .../npu/a5/src/st/testcase/tshrs/compare.py | 46 +++ .../npu/a5/src/st/testcase/tshrs/gen_data.py | 34 ++ .../npu/a5/src/st/testcase/tshrs/launch.cpp | 44 +++ .../npu/a5/src/st/testcase/tshrs/main.cpp | 135 ++++++++ .../npu/a5/src/st/testcase/tshrs/tshrs.pto | 200 ++++++++++++ .../a5/src/st/testcase/tsubs/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tsubs/cases.py | 22 ++ .../npu/a5/src/st/testcase/tsubs/compare.py | 46 +++ .../npu/a5/src/st/testcase/tsubs/gen_data.py | 35 +++ .../npu/a5/src/st/testcase/tsubs/launch.cpp | 58 ++++ .../npu/a5/src/st/testcase/tsubs/main.cpp | 139 +++++++++ .../npu/a5/src/st/testcase/tsubs/tsubs.pto | 292 ++++++++++++++++++ .../a5/src/st/testcase/txors/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/txors/cases.py | 48 +++ .../npu/a5/src/st/testcase/txors/compare.py | 46 +++ .../npu/a5/src/st/testcase/txors/gen_data.py | 35 +++ .../npu/a5/src/st/testcase/txors/launch.cpp | 41 +++ .../npu/a5/src/st/testcase/txors/main.cpp | 135 ++++++++ .../npu/a5/src/st/testcase/txors/txors.pto | 221 +++++++++++++ 102 files changed, 7050 insertions(+) create mode 100644 lib/TileOps/tands_template.py create mode 100644 lib/TileOps/tdivs_template.py create mode 100644 lib/TileOps/tmaxs_template.py create mode 100644 lib/TileOps/tmins_template.py create mode 100644 lib/TileOps/tmuls_template.py create mode 100644 lib/TileOps/tors_template.py create mode 100644 lib/TileOps/tshls_template.py create mode 100644 lib/TileOps/tshrs_template.py create mode 100644 lib/TileOps/tsubs_template.py create mode 100644 lib/TileOps/txors_template.py create mode 100644 test/basic/expand_tile_op_tilelang_tadds.pto create mode 100644 test/basic/expand_tile_op_tilelang_tands.pto create mode 100644 test/basic/expand_tile_op_tilelang_tdivs.pto create mode 100644 test/basic/expand_tile_op_tilelang_tmaxs.pto create mode 100644 test/basic/expand_tile_op_tilelang_tmins.pto create mode 100644 test/basic/expand_tile_op_tilelang_tmuls.pto create mode 100644 test/basic/expand_tile_op_tilelang_tors.pto create mode 100644 test/basic/expand_tile_op_tilelang_tshls.pto create mode 100644 test/basic/expand_tile_op_tilelang_tshrs.pto create mode 100644 test/basic/expand_tile_op_tilelang_tsubs.pto create mode 100644 test/basic/expand_tile_op_tilelang_txors.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadds/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadds/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadds/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadds/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadds/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadds/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadds/tadds.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tands/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tands/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tands/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tands/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tands/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tands/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tands/tands.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdivs/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdivs/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdivs/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdivs/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdivs/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdivs/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdivs/tdivs.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmaxs/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmaxs/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmaxs/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmaxs/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmaxs/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmaxs/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmaxs/tmaxs.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmins/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmins/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmins/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmins/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmins/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmins/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmins/tmins.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmuls/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmuls/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmuls/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmuls/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmuls/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmuls/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmuls/tmuls.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tors/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tors/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tors/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tors/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tors/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tors/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tors/tors.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshls/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshls/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshls/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshls/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshls/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshls/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshls/tshls.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshrs/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshrs/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshrs/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshrs/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshrs/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshrs/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshrs/tshrs.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsubs/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsubs/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsubs/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsubs/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsubs/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsubs/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsubs/tsubs.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txors/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txors/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txors/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txors/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txors/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txors/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txors/txors.pto diff --git a/lib/TileOps/render_template_mlir.py b/lib/TileOps/render_template_mlir.py index 42ce40552..5c4952674 100644 --- a/lib/TileOps/render_template_mlir.py +++ b/lib/TileOps/render_template_mlir.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + #!/usr/bin/env python3 """Materialize a TileLang DSL library template to authoring-form MLIR. diff --git a/lib/TileOps/tadd_template.py b/lib/TileOps/tadd_template.py index ecab2be5b..8e247fd73 100644 --- a/lib/TileOps/tadd_template.py +++ b/lib/TileOps/tadd_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tadd""" import sys diff --git a/lib/TileOps/tadds_template.py b/lib/TileOps/tadds_template.py index fd36e1f3d..7c3ddb06c 100644 --- a/lib/TileOps/tadds_template.py +++ b/lib/TileOps/tadds_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tadds""" import sys diff --git a/lib/TileOps/tands_template.py b/lib/TileOps/tands_template.py new file mode 100644 index 000000000..f93502900 --- /dev/null +++ b/lib/TileOps/tands_template.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tands + +Note: A5 hardware implements tands as: + TEXPANDS_IMPL(dst, scalar); // broadcast scalar to dst + TAND_IMPL(dst, src, dst); // dst = src & dst + +This template uses vbr + vand to achieve element-wise bitwise AND. +Only supports tile, scalar order (matching TAndS.hpp). +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tands", +) +def template_tands(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vand(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tdivs_template.py b/lib/TileOps/tdivs_template.py new file mode 100644 index 000000000..ab0fdd5de --- /dev/null +++ b/lib/TileOps/tdivs_template.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tdivs + +Supports two operand orders (matching TDivS.hpp): + 1. tdivs(src_tile, scalar, dst) -> src / scalar + 2. tdivs(scalar, src_tile, dst) -> scalar / src + +TODO: Add support for high-precision division (e.g., f64 or extended precision) +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tdivs", +) +def template_tdivs_tile_scalar(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + """src / scalar""" + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vdiv(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tdivs", +) +def template_tdivs_scalar_tile(scalar: pto.AnyType, src: pto.Tile, dst: pto.Tile): + """scalar / src""" + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vdiv(scalar_vec, vec, mask) + # TO DO: support high precision division + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tmaxs_template.py b/lib/TileOps/tmaxs_template.py new file mode 100644 index 000000000..5c9e409a3 --- /dev/null +++ b/lib/TileOps/tmaxs_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmaxs""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmaxs", +) +def template_tmaxs(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vmaxs(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tmins_template.py b/lib/TileOps/tmins_template.py new file mode 100644 index 000000000..bda0df5f9 --- /dev/null +++ b/lib/TileOps/tmins_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmins""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmins", +) +def template_tmins(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vmins(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tmuls_template.py b/lib/TileOps/tmuls_template.py new file mode 100644 index 000000000..8d02ea826 --- /dev/null +++ b/lib/TileOps/tmuls_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmuls""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmuls", +) +def template_tmuls(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vmuls(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tors_template.py b/lib/TileOps/tors_template.py new file mode 100644 index 000000000..4ff567a0a --- /dev/null +++ b/lib/TileOps/tors_template.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tors + +Note: A5 hardware implements tors as: + TEXPANDS_IMPL(dst, scalar); // broadcast scalar to dst + TOR_IMPL(dst, src, dst); // dst = src | dst + +This template uses vbr + vor to achieve element-wise bitwise OR. +Only supports tile, scalar order (matching TOrS.hpp). +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tors", +) +def template_tors(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vor(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tshls_template.py b/lib/TileOps/tshls_template.py new file mode 100644 index 000000000..def0b0353 --- /dev/null +++ b/lib/TileOps/tshls_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tshls""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tshls", +) +def template_tshls(src: pto.Tile, scalar: pto.i16, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vshls(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tshrs_template.py b/lib/TileOps/tshrs_template.py new file mode 100644 index 000000000..8366a638b --- /dev/null +++ b/lib/TileOps/tshrs_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tshrs""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tshrs", +) +def template_tshrs(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vshrs(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tsubs_template.py b/lib/TileOps/tsubs_template.py new file mode 100644 index 000000000..84dc8bfbd --- /dev/null +++ b/lib/TileOps/tsubs_template.py @@ -0,0 +1,38 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsubs + +Note: A5 hardware implements tsubs as vadds with negated scalar: + dst = src - scalar = src + (-scalar) +This template uses vbr + vsub to achieve element-wise subtraction. +TODO: Use vadds(vec, -scalar) when DSL supports unary negation on scalars. +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tsubs", +) +def template_tsubs(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vsub(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/txors_template.py b/lib/TileOps/txors_template.py new file mode 100644 index 000000000..ca3ffcea1 --- /dev/null +++ b/lib/TileOps/txors_template.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.txors + +Note: A5 hardware implements txors as: + TEXPANDS_IMPL(dst, scalar); // broadcast scalar to dst + TXOR_IMPL(dst, src, dst, tmp); // dst = src ^ dst + +This template uses vbr + vxor to achieve element-wise bitwise XOR. +Requires tmp tile matching TXorS.hpp signature (src, scalar, tmp, dst). +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.txors", +) +def template_txors(src: pto.Tile, scalar: pto.AnyType, tmp: pto.Tile, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vxor(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/test/basic/expand_tile_op_tilelang_tadds.pto b/test/basic/expand_tile_op_tilelang_tadds.pto new file mode 100644 index 000000000..9c417ad3c --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tadds.pto @@ -0,0 +1,38 @@ +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that pto.tadds can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TADDS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tadds should use vadds (vector add scalar). +// CHECK: func.func @TADDS +// CHECK-NOT: pto.tadds ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts + +module { + func.func @TADDS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tadds ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tands.pto b/test/basic/expand_tile_op_tilelang_tands.pto new file mode 100644 index 000000000..197820c06 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tands.pto @@ -0,0 +1,34 @@ +// Test that pto.tands can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TANDS has a scalar operand (i32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tands should use vbr + vand. +// CHECK: func.func @TANDS +// CHECK-NOT: pto.tands ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vand +// CHECK: pto.vsts + +module { + func.func @TANDS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0xFF : i32 + + pto.tands ins(%a, %scalar : !pto.tile_buf, + i32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tdivs.pto b/test/basic/expand_tile_op_tilelang_tdivs.pto new file mode 100644 index 000000000..479c51b5c --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tdivs.pto @@ -0,0 +1,63 @@ +// Test that pto.tdivs can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TDIVS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s --check-prefix=CHECK-TILE-SCALAR +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s --check-prefix=CHECK-SCALAR-TILE + +// tile / scalar form: +// CHECK-TILE-SCALAR: func.func @TDIVS +// CHECK-TILE-SCALAR-NOT: pto.tdivs ins +// CHECK-TILE-SCALAR: pto.vecscope +// CHECK-TILE-SCALAR: pto.vlds +// CHECK-TILE-SCALAR: pto.vbr +// CHECK-TILE-SCALAR: pto.vdiv +// CHECK-TILE-SCALAR: pto.vsts + +// scalar / tile form: +// CHECK-SCALAR-TILE: func.func @TDIVS +// CHECK-SCALAR-TILE-NOT: pto.tdivs ins +// CHECK-SCALAR-TILE: pto.vecscope +// CHECK-SCALAR-TILE: pto.vlds +// CHECK-SCALAR-TILE: pto.vbr +// CHECK-SCALAR-TILE: pto.vdiv +// CHECK-SCALAR-TILE: pto.vsts + +module { + func.func @TDIVS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2.0 : f32 + + pto.tdivs ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} + +module { + func.func @TDIVS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2.0 : f32 + + pto.tdivs ins(%scalar, %a : f32, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tmaxs.pto b/test/basic/expand_tile_op_tilelang_tmaxs.pto new file mode 100644 index 000000000..61cb595b2 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tmaxs.pto @@ -0,0 +1,33 @@ +// Test that pto.tmaxs can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TMAXS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tmaxs should use vmaxs (vector max scalar). +// CHECK: func.func @TMAXS +// CHECK-NOT: pto.tmaxs ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmaxs +// CHECK: pto.vsts + +module { + func.func @TMAXS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0.0 : f32 + + pto.tmaxs ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tmins.pto b/test/basic/expand_tile_op_tilelang_tmins.pto new file mode 100644 index 000000000..6707e4691 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tmins.pto @@ -0,0 +1,33 @@ +// Test that pto.tmins can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TMINS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tmins should use vmins (vector min scalar). +// CHECK: func.func @TMINS +// CHECK-NOT: pto.tmins ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmins +// CHECK: pto.vsts + +module { + func.func @TMINS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0.0 : f32 + + pto.tmins ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tmuls.pto b/test/basic/expand_tile_op_tilelang_tmuls.pto new file mode 100644 index 000000000..344beb7d6 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tmuls.pto @@ -0,0 +1,33 @@ +// Test that pto.tmuls can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TMULS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tmuls should use vmuls (vector multiply scalar). +// CHECK: func.func @TMULS +// CHECK-NOT: pto.tmuls ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmuls +// CHECK: pto.vsts + +module { + func.func @TMULS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2.0 : f32 + + pto.tmuls ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tors.pto b/test/basic/expand_tile_op_tilelang_tors.pto new file mode 100644 index 000000000..56069a244 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tors.pto @@ -0,0 +1,34 @@ +// Test that pto.tors can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TORS has a scalar operand (i32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tors should use vbr + vor. +// CHECK: func.func @TORS +// CHECK-NOT: pto.tors ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vor +// CHECK: pto.vsts + +module { + func.func @TORS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0xFF : i32 + + pto.tors ins(%a, %scalar : !pto.tile_buf, + i32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tshls.pto b/test/basic/expand_tile_op_tilelang_tshls.pto new file mode 100644 index 000000000..3fc31aaec --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tshls.pto @@ -0,0 +1,33 @@ +// Test that pto.tshls can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TSHLS has a scalar operand (i16), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tshls should use vshls (vector shift left scalar). +// CHECK: func.func @TSHLS +// CHECK-NOT: pto.tshls ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vshls +// CHECK: pto.vsts + +module { + func.func @TSHLS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2 : i16 + + pto.tshls ins(%a, %scalar : !pto.tile_buf, + i16) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tshrs.pto b/test/basic/expand_tile_op_tilelang_tshrs.pto new file mode 100644 index 000000000..ea8148c3d --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tshrs.pto @@ -0,0 +1,33 @@ +// Test that pto.tshrs can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TSHRS has a scalar operand (i16), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tshrs should use vshrs (vector shift right scalar). +// CHECK: func.func @TSHRS +// CHECK-NOT: pto.tshrs ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vshrs +// CHECK: pto.vsts + +module { + func.func @TSHRS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2 : i16 + + pto.tshrs ins(%a, %scalar : !pto.tile_buf, + i16) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tsubs.pto b/test/basic/expand_tile_op_tilelang_tsubs.pto new file mode 100644 index 000000000..300b1b345 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tsubs.pto @@ -0,0 +1,33 @@ +// Test that pto.tsubs can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TSUBS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tsubs should use vsubs (vector subtract scalar). +// CHECK: func.func @TSUBS +// CHECK-NOT: pto.tsubs ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vsubs +// CHECK: pto.vsts + +module { + func.func @TSUBS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tsubs ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_txors.pto b/test/basic/expand_tile_op_tilelang_txors.pto new file mode 100644 index 000000000..336c8c218 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_txors.pto @@ -0,0 +1,39 @@ +// Test that pto.txors can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TXORS has a scalar operand (i32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.txors should use vbr + vxor. +// CHECK: func.func @TXORS +// CHECK-NOT: pto.txors ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vxor +// CHECK: pto.vsts + +module { + func.func @TXORS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %tmp_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0xFF : i32 + + pto.txors ins(%a, %scalar, %tmp_buf : !pto.tile_buf, + i32, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 347354bb0..54e5038cc 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -155,6 +155,17 @@ set(ALL_TESTCASES trowprod trsqrt tsqrt + tadds + tands + tdivs + tmaxs + tmins + tmuls + tors + tshls + tshrs + tsubs + txors ) if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tadds/CMakeLists.txt new file mode 100644 index 000000000..d4535a569 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tadds) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tadds/cases.py new file mode 100644 index 000000000..5b24462bf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/cases.py @@ -0,0 +1,69 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tadds ST test cases. + +Shapes and dtypes match testcase/tadds (C++ GTest suite): + case1: float, 32x64, valid 32x64 + case2: float16, 63x64, valid 63x64 + case3: int32, 31x128, valid 31x128 + case4: int16, 15x192, valid 15x192 + case5: float, 7x448, valid 7x448 + case6: float, 256x16, valid 256x16 + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + }, + { + "name": "f16_63x64", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "f32_7x448", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadds/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tadds/gen_data.py new file mode 100644 index 000000000..c4f47c5f4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value added to every element (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] + scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadds/launch.cpp new file mode 100644 index 000000000..49f0c98ec --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value added to every element (must match gen_data.py SCALAR) +static constexpr float TADDS_SCALAR_F32 = 3.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TADDS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTADDS_f32_32x64(float *src, float *dst, void *stream) { + TADDS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TADDS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TADDS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTADDS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TADDS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TADDS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTADDS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TADDS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TADDS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTADDS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TADDS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TADDS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTADDS_f32_7x448(float *src, float *dst, void *stream) { + TADDS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TADDS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TADDS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTADDS_f32_256x16(float *src, float *dst, void *stream) { + TADDS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TADDS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadds/main.cpp new file mode 100644 index 000000000..4c6f409dc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tadds ST — case-table driven. +// tadds: dst = src + scalar (single input + scalar, unlike tadd which has two inputs). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTADDS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTADDS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTADDS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTADDS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTADDS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTADDS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTADDS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTADDS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTADDS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTADDS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTADDS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTADDS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tadds [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/tadds.pto b/test/tilelang_st/npu/a5/src/st/testcase/tadds/tadds.pto new file mode 100644 index 000000000..9287b8b46 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/tadds.pto @@ -0,0 +1,292 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tadds: tload(src) + tadds(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: f32 32x64 (2048 elements) + func.func @TADDS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TADDS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TADDS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TADDS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TADDS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TADDS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tands/CMakeLists.txt new file mode 100644 index 000000000..0ff088f8d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tands) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tands/cases.py new file mode 100644 index 000000000..18cc99178 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/cases.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tands/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tands/gen_data.py new file mode 100644 index 000000000..9f187cb03 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for bitwise AND (must match launch.cpp) +SCALAR = 3 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] & scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tands/launch.cpp new file mode 100644 index 000000000..8226ac79e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for bitwise AND (must match gen_data.py SCALAR) +static constexpr int32_t TANDS_SCALAR_I32 = 3; +static constexpr int16_t TANDS_SCALAR_I16 = 3; + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TANDS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTANDS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TANDS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TANDS_SCALAR_I32); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TANDS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTANDS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TANDS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TANDS_SCALAR_I16); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TANDS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTANDS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TANDS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TANDS_SCALAR_I32); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TANDS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTANDS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TANDS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TANDS_SCALAR_I16); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tands/main.cpp new file mode 100644 index 000000000..e0f93f2e7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tands ST — case-table driven. +// tands: dst = src & scalar (single input + scalar, bitwise AND). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTANDS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTANDS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTANDS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTANDS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTANDS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTANDS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTANDS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTANDS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tands [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/tands.pto b/test/tilelang_st/npu/a5/src/st/testcase/tands/tands.pto new file mode 100644 index 000000000..da10360a4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/tands.pto @@ -0,0 +1,200 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tands: tload(src) + tands(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: i32 32x64 (2048 elements) + func.func @TANDS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TANDS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TANDS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TANDS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/CMakeLists.txt new file mode 100644 index 000000000..cfd816f61 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tdivs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/cases.py new file mode 100644 index 000000000..8fbcdea4d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/cases.py @@ -0,0 +1,85 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tdivs ST test cases. + +vdiv only supports f16/f32 in TileLang DSL v1. +""" + +import numpy as np + +CASES = [ + # src / scalar direction + { + "name": "f32_32x64", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + "direction": "src_scalar", + }, + { + "name": "f16_63x64", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + "direction": "src_scalar", + }, + { + "name": "f32_7x448", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + "direction": "src_scalar", + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + "direction": "src_scalar", + }, + # scalar / src direction + { + "name": "f32_32x64_scalar_src", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + "direction": "scalar_src", + }, + { + "name": "f16_63x64_scalar_src", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + "direction": "scalar_src", + }, + { + "name": "f32_7x448_scalar_src", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + "direction": "scalar_src", + }, + { + "name": "f32_256x16_scalar_src", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + "direction": "scalar_src", + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/gen_data.py new file mode 100644 index 000000000..61988f637 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/gen_data.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for division (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + direction = case.get("direction", "src_scalar") + + # Avoid zero values in src for scalar/src direction (division by zero) + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + if direction == "src_scalar": + golden[:vr, :vc] = (input1[:vr, :vc] / scalar_val).astype(dtype, copy=False) + else: # scalar_src + golden[:vr, :vc] = (scalar_val / input1[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} direction={direction} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/launch.cpp new file mode 100644 index 000000000..4ddee7260 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/launch.cpp @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +static constexpr float TDIVS_SCALAR_F32 = 3.0f; + +// ========== src / scalar direction ========== + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TDIVS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_32x64(float *src, float *dst, void *stream) { + TDIVS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TDIVS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: f32 7x448 +extern "C" __global__ AICORE void TDIVS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_7x448(float *src, float *dst, void *stream) { + TDIVS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// Case 3: f32 256x16 +extern "C" __global__ AICORE void TDIVS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_256x16(float *src, float *dst, void *stream) { + TDIVS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// ========== scalar / src direction ========== + +// Case 4: f32 32x64 scalar/src +extern "C" __global__ AICORE void TDIVS_f32_32x64_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_32x64_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_32x64_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// Case 5: f16 63x64 scalar/src +extern "C" __global__ AICORE void TDIVS_f16_63x64_scalar_src(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_63x64_scalar_src(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_63x64_scalar_src<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 6: f32 7x448 scalar/src +extern "C" __global__ AICORE void TDIVS_f32_7x448_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_7x448_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_7x448_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// Case 7: f32 256x16 scalar/src +extern "C" __global__ AICORE void TDIVS_f32_256x16_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_256x16_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_256x16_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/main.cpp new file mode 100644 index 000000000..02c7934fa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/main.cpp @@ -0,0 +1,143 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tdivs ST — case-table driven. +// tdivs: dst = src / scalar (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTDIVS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTDIVS_f32_256x16(float *src, float *dst, void *stream); +void LaunchTDIVS_f32_32x64_scalar_src(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_63x64_scalar_src(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_7x448_scalar_src(float *src, float *dst, void *stream); +void LaunchTDIVS_f32_256x16_scalar_src(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTDIVS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTDIVS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTDIVS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTDIVS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, + {"f32_32x64_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_32x64_scalar_src, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f16_63x64_scalar_src, 63, 64, 63, 64, sizeof(uint16_t)}, + {"f32_7x448_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_7x448_scalar_src, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_256x16_scalar_src, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tdivs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/tdivs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/tdivs.pto new file mode 100644 index 000000000..ccda3f263 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/tdivs.pto @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tdivs: tload(src) + tdivs(src, scalar)->dst + tstore(dst). +// vdiv only supports f16/f32 in TileLang DSL v1. +module { + + // Case 0: f32 32x64 + func.func @TDIVS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 + func.func @TDIVS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f16) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: f32 7x448 + func.func @TDIVS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c7, %c448], strides = [%c3136, %c3136, %c3136, %c448, %c1] : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c7, %c448], strides = [%c3136, %c3136, %c3136, %c448, %c1] : !pto.tensor_view<1x1x1x7x448xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c7, %c448] : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c7, %c448] : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 3: f32 256x16 + func.func @TDIVS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c256, %c16], strides = [%c4096, %c4096, %c4096, %c16, %c1] : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c256, %c16], strides = [%c4096, %c4096, %c4096, %c16, %c1] : !pto.tensor_view<1x1x1x256x16xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c256, %c16] : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c256, %c16] : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + + // ========== scalar / src direction ========== + + // Case 4: f32 32x64 scalar/src + func.func @TDIVS_f32_32x64_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 5: f16 63x64 scalar/src + func.func @TDIVS_f16_63x64_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f16, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 6: f32 7x448 scalar/src + func.func @TDIVS_f32_7x448_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c7, %c448], strides = [%c3136, %c3136, %c3136, %c448, %c1] : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c7, %c448], strides = [%c3136, %c3136, %c3136, %c448, %c1] : !pto.tensor_view<1x1x1x7x448xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c7, %c448] : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c7, %c448] : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 7: f32 256x16 scalar/src + func.func @TDIVS_f32_256x16_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c256, %c16], strides = [%c4096, %c4096, %c4096, %c16, %c1] : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c256, %c16], strides = [%c4096, %c4096, %c4096, %c16, %c1] : !pto.tensor_view<1x1x1x256x16xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c256, %c16] : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c256, %c16] : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/CMakeLists.txt new file mode 100644 index 000000000..a540c4c13 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmaxs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/cases.py new file mode 100644 index 000000000..d3bab221b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/cases.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmaxs ST test cases.""" + +import numpy as np + +CASES = [ + {"name": "f32_32x64", "dtype": np.float32, "shape": (32, 64), "valid_shape": (32, 64), "eps": 1e-6}, + {"name": "f16_63x64", "dtype": np.float16, "shape": (63, 64), "valid_shape": (63, 64), "eps": 1e-3}, + {"name": "i32_31x128", "dtype": np.int32, "shape": (31, 128), "valid_shape": (31, 128), "eps": 0}, + {"name": "i16_15x192", "dtype": np.int16, "shape": (15, 192), "valid_shape": (15, 192), "eps": 0}, + {"name": "f32_7x448", "dtype": np.float32, "shape": (7, 448), "valid_shape": (7, 448), "eps": 1e-6}, + {"name": "f32_256x16", "dtype": np.float32, "shape": (256, 16), "valid_shape": (256, 16), "eps": 1e-6}, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/gen_data.py new file mode 100644 index 000000000..10520c68b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value used for element-wise maximum (matches the scalar passed in launch.cpp) +SCALAR = 5.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = np.maximum(input1[:vr, :vc], scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/launch.cpp new file mode 100644 index 000000000..793db13f1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value used for element-wise maximum (must match gen_data.py SCALAR) +static constexpr float TMAXS_SCALAR_F32 = 5.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TMAXS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMAXS_f32_32x64(float *src, float *dst, void *stream) { + TMAXS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMAXS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TMAXS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTMAXS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TMAXS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4500); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TMAXS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTMAXS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TMAXS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)5); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TMAXS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTMAXS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TMAXS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)5); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TMAXS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMAXS_f32_7x448(float *src, float *dst, void *stream) { + TMAXS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMAXS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TMAXS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMAXS_f32_256x16(float *src, float *dst, void *stream) { + TMAXS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMAXS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/main.cpp new file mode 100644 index 000000000..7104ff7ad --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmaxs ST — case-table driven. +// tmaxs: dst = max(src, scalar) (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMAXS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTMAXS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMAXS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTMAXS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTMAXS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTMAXS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTMAXS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTMAXS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTMAXS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTMAXS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTMAXS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTMAXS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmaxs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/tmaxs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/tmaxs.pto new file mode 100644 index 000000000..9ef3316ba --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/tmaxs.pto @@ -0,0 +1,292 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmaxs: tload(src) + tmaxs(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: f32 32x64 (2048 elements) + func.func @TMAXS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TMAXS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TMAXS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TMAXS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TMAXS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TMAXS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmins/CMakeLists.txt new file mode 100644 index 000000000..038d4e327 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmins) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmins/cases.py new file mode 100644 index 000000000..4526c4182 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/cases.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmins ST test cases.""" + +import numpy as np + +CASES = [ + {"name": "f32_32x64", "dtype": np.float32, "shape": (32, 64), "valid_shape": (32, 64), "eps": 1e-6}, + {"name": "f16_63x64", "dtype": np.float16, "shape": (63, 64), "valid_shape": (63, 64), "eps": 1e-3}, + {"name": "i32_31x128", "dtype": np.int32, "shape": (31, 128), "valid_shape": (31, 128), "eps": 0}, + {"name": "i16_15x192", "dtype": np.int16, "shape": (15, 192), "valid_shape": (15, 192), "eps": 0}, + {"name": "f32_7x448", "dtype": np.float32, "shape": (7, 448), "valid_shape": (7, 448), "eps": 1e-6}, + {"name": "f32_256x16", "dtype": np.float32, "shape": (256, 16), "valid_shape": (256, 16), "eps": 1e-6}, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmins/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmins/gen_data.py new file mode 100644 index 000000000..84da39655 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value used for element-wise minimum (matches the scalar passed in launch.cpp) +SCALAR = 5.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = np.minimum(input1[:vr, :vc], scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmins/launch.cpp new file mode 100644 index 000000000..65d44ffc4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value used for element-wise minimum (must match gen_data.py SCALAR) +static constexpr float TMINS_SCALAR_F32 = 5.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TMINS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMINS_f32_32x64(float *src, float *dst, void *stream) { + TMINS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMINS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TMINS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTMINS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TMINS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4500); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TMINS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTMINS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TMINS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)5); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TMINS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTMINS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TMINS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)5); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TMINS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMINS_f32_7x448(float *src, float *dst, void *stream) { + TMINS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMINS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TMINS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMINS_f32_256x16(float *src, float *dst, void *stream) { + TMINS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMINS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmins/main.cpp new file mode 100644 index 000000000..9fd09e48e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmins ST — case-table driven. +// tmins: dst = min(src, scalar) (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMINS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTMINS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMINS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTMINS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTMINS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTMINS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTMINS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTMINS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTMINS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTMINS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTMINS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTMINS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmins [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/tmins.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmins/tmins.pto new file mode 100644 index 000000000..f05e56825 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/tmins.pto @@ -0,0 +1,292 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmins: tload(src) + tmins(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: f32 32x64 (2048 elements) + func.func @TMINS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TMINS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TMINS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TMINS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TMINS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TMINS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/CMakeLists.txt new file mode 100644 index 000000000..49ba8cd84 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmuls) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/cases.py new file mode 100644 index 000000000..d12724e5f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/cases.py @@ -0,0 +1,69 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmuls ST test cases. + +Shapes and dtype match testcase/tadds (C++ GTest suite): + case1: float, 32x64, valid 32x64 + case2: float16, 63x64, valid 63x64 + case3: int32, 31x128, valid 31x128 + case4: int16, 15x192, valid 15x192 + case5: float, 7x448, valid 7x448 + case6: float, 256x16, valid 256x16 + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + }, + { + "name": "f16_63x64", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "f32_7x448", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/gen_data.py new file mode 100644 index 000000000..a98114643 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value multiplied into every element (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] * scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/launch.cpp new file mode 100644 index 000000000..fdd67d596 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value multiplied into every element (must match gen_data.py SCALAR) +static constexpr float TMULS_SCALAR_F32 = 3.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TMULS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMULS_f32_32x64(float *src, float *dst, void *stream) { + TMULS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMULS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TMULS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTMULS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TMULS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TMULS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTMULS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TMULS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TMULS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTMULS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TMULS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TMULS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMULS_f32_7x448(float *src, float *dst, void *stream) { + TMULS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMULS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TMULS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMULS_f32_256x16(float *src, float *dst, void *stream) { + TMULS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMULS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/main.cpp new file mode 100644 index 000000000..a5372cccc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmuls ST — case-table driven. +// tmuls: dst = src * scalar (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMULS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTMULS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMULS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTMULS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTMULS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTMULS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTMULS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTMULS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTMULS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTMULS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTMULS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTMULS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmuls [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/tmuls.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/tmuls.pto new file mode 100644 index 000000000..9926190da --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/tmuls.pto @@ -0,0 +1,292 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmuls: tload(src) + tmuls(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: f32 32x64 (2048 elements) + func.func @TMULS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TMULS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TMULS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TMULS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TMULS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TMULS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tors/CMakeLists.txt new file mode 100644 index 000000000..5decd02d7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tors) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tors/cases.py new file mode 100644 index 000000000..18cc99178 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/cases.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tors/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tors/gen_data.py new file mode 100644 index 000000000..c4c879dcd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for bitwise OR (must match launch.cpp) +SCALAR = 3 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] | scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tors/launch.cpp new file mode 100644 index 000000000..4495ff38c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for bitwise OR (must match gen_data.py SCALAR) +static constexpr int32_t TORS_SCALAR_I32 = 3; +static constexpr int16_t TORS_SCALAR_I16 = 3; + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TORS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTORS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TORS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TORS_SCALAR_I32); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TORS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTORS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TORS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TORS_SCALAR_I16); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TORS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTORS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TORS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TORS_SCALAR_I32); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TORS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTORS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TORS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TORS_SCALAR_I16); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tors/main.cpp new file mode 100644 index 000000000..b67da6f06 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tors ST — case-table driven. +// tors: dst = src | scalar (single input + scalar, bitwise OR). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTORS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTORS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTORS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTORS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTORS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTORS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTORS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTORS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tors [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/tors.pto b/test/tilelang_st/npu/a5/src/st/testcase/tors/tors.pto new file mode 100644 index 000000000..b939d66eb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/tors.pto @@ -0,0 +1,200 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tors: tload(src) + tors(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: i32 32x64 (2048 elements) + func.func @TORS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tors ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TORS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.tors ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TORS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tors ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TORS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tors ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tshls/CMakeLists.txt new file mode 100644 index 000000000..ae8289e40 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tshls) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshls/cases.py new file mode 100644 index 000000000..18cc99178 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/cases.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshls/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshls/gen_data.py new file mode 100644 index 000000000..9b4624bfc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/gen_data.py @@ -0,0 +1,34 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for left shift (must match launch.cpp) +SCALAR = 2 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] << SCALAR).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshls/launch.cpp new file mode 100644 index 000000000..5e7343071 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for left shift (must match gen_data.py SCALAR) +static constexpr int16_t TSHLS_SCALAR = 2; + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TSHLS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int16_t scalar); + +void LaunchTSHLS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TSHLS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TSHLS_SCALAR); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TSHLS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSHLS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TSHLS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TSHLS_SCALAR); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TSHLS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int16_t scalar); + +void LaunchTSHLS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TSHLS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TSHLS_SCALAR); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TSHLS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSHLS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TSHLS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TSHLS_SCALAR); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshls/main.cpp new file mode 100644 index 000000000..ca0e4d73e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tshls ST — case-table driven. +// tshls: dst = src << scalar (single input + scalar, left shift). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSHLS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTSHLS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTSHLS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTSHLS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTSHLS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTSHLS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTSHLS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTSHLS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tshls [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/tshls.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshls/tshls.pto new file mode 100644 index 000000000..f4fbf0893 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/tshls.pto @@ -0,0 +1,200 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tshls: tload(src) + tshls(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: i32 32x64 (2048 elements) + func.func @TSHLS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tshls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TSHLS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.tshls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TSHLS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tshls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TSHLS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tshls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/CMakeLists.txt new file mode 100644 index 000000000..c8e37c793 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tshrs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/cases.py new file mode 100644 index 000000000..18cc99178 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/cases.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/gen_data.py new file mode 100644 index 000000000..6f269f96a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/gen_data.py @@ -0,0 +1,34 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for right shift (must match launch.cpp) +SCALAR = 2 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] >> SCALAR).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/launch.cpp new file mode 100644 index 000000000..e80c03e16 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for right shift (must match gen_data.py SCALAR) +static constexpr int16_t TSHRS_SCALAR = 2; + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TSHRS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int16_t scalar); + +void LaunchTSHRS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TSHRS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TSHRS_SCALAR); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TSHRS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSHRS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TSHRS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TSHRS_SCALAR); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TSHRS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int16_t scalar); + +void LaunchTSHRS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TSHRS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TSHRS_SCALAR); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TSHRS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSHRS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TSHRS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TSHRS_SCALAR); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/main.cpp new file mode 100644 index 000000000..3afd710f9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tshrs ST — case-table driven. +// tshrs: dst = src >> scalar (single input + scalar, right shift). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSHRS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTSHRS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTSHRS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTSHRS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTSHRS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTSHRS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTSHRS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTSHRS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tshrs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/tshrs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/tshrs.pto new file mode 100644 index 000000000..3196bf8f1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/tshrs.pto @@ -0,0 +1,200 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tshrs: tload(src) + tshrs(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: i32 32x64 (2048 elements) + func.func @TSHRS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tshrs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TSHRS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.tshrs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TSHRS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tshrs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TSHRS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tshrs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/CMakeLists.txt new file mode 100644 index 000000000..3ccdb0fb1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsubs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/cases.py new file mode 100644 index 000000000..af6b6b425 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/cases.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsubs ST test cases.""" + +import numpy as np + +CASES = [ + {"name": "f32_32x64", "dtype": np.float32, "shape": (32, 64), "valid_shape": (32, 64), "eps": 1e-6}, + {"name": "f16_63x64", "dtype": np.float16, "shape": (63, 64), "valid_shape": (63, 64), "eps": 1e-3}, + {"name": "i32_31x128", "dtype": np.int32, "shape": (31, 128), "valid_shape": (31, 128), "eps": 0}, + {"name": "i16_15x192", "dtype": np.int16, "shape": (15, 192), "valid_shape": (15, 192), "eps": 0}, + {"name": "f32_7x448", "dtype": np.float32, "shape": (7, 448), "valid_shape": (7, 448), "eps": 1e-6}, + {"name": "f32_256x16", "dtype": np.float32, "shape": (256, 16), "valid_shape": (256, 16), "eps": 1e-6}, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/gen_data.py new file mode 100644 index 000000000..20d55e1d6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value subtracted from every element (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] - scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/launch.cpp new file mode 100644 index 000000000..d511cf09d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value subtracted from every element (must match gen_data.py SCALAR) +static constexpr float TSUBS_SCALAR_F32 = 3.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TSUBS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTSUBS_f32_32x64(float *src, float *dst, void *stream) { + TSUBS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TSUBS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TSUBS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTSUBS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TSUBS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TSUBS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTSUBS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TSUBS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TSUBS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSUBS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TSUBS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TSUBS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTSUBS_f32_7x448(float *src, float *dst, void *stream) { + TSUBS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TSUBS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TSUBS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTSUBS_f32_256x16(float *src, float *dst, void *stream) { + TSUBS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TSUBS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/main.cpp new file mode 100644 index 000000000..40509f578 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsubs ST — case-table driven. +// tsubs: dst = src - scalar (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSUBS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTSUBS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTSUBS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTSUBS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTSUBS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTSUBS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTSUBS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTSUBS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTSUBS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTSUBS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTSUBS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTSUBS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tsubs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/tsubs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/tsubs.pto new file mode 100644 index 000000000..fa04ca8c7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/tsubs.pto @@ -0,0 +1,292 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsubs: tload(src) + tsubs(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: f32 32x64 (2048 elements) + func.func @TSUBS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TSUBS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TSUBS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TSUBS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TSUBS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TSUBS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/txors/CMakeLists.txt new file mode 100644 index 000000000..1bcd9e681 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(txors) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/txors/cases.py new file mode 100644 index 000000000..9b652056d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/cases.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for txors ST test cases. + +txors: bitwise XOR with scalar, dst = src ^ scalar. +Integer only: i32, i16. +""" + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/txors/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/txors/gen_data.py new file mode 100644 index 000000000..5c12edd5a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for bitwise XOR (matches the scalar passed in launch.cpp) +SCALAR = 3 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] ^ scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txors/launch.cpp new file mode 100644 index 000000000..f61619d9f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TXORS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTXORS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TXORS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TXORS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTXORS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TXORS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TXORS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTXORS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TXORS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TXORS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTXORS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TXORS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txors/main.cpp new file mode 100644 index 000000000..f46282f01 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang txors ST — case-table driven. +// txors: dst = src ^ scalar (bitwise XOR with scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTXORS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTXORS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTXORS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTXORS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTXORS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTXORS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTXORS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTXORS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./txors [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/txors.pto b/test/tilelang_st/npu/a5/src/st/testcase/txors/txors.pto new file mode 100644 index 000000000..4fbf5013d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/txors.pto @@ -0,0 +1,221 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.txors: tload(src) + txors(src, scalar, tmp)->dst + tstore(dst). +// Bitwise XOR with scalar: dst = src ^ scalar. +// Integer only: i32, i16. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + + // Case 0: i32 32x64 (2048 elements) + func.func @TXORS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf, i32, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TXORS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf, i16, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TXORS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf, i32, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TXORS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf, i16, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} From 8466ce4bec81499b6adbb4be4765150d7e1336ce Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sat, 25 Apr 2026 14:19:31 +0800 Subject: [PATCH 171/192] feat: improve the func call in dsl. and support vdiv/vmod func --- lib/TileOps/__init__.py | 9 + lib/TileOps/math.py | 485 ++++++++++++++++++ lib/TileOps/render_template_mlir.py | 12 + .../docs/user_guide/06-control-flow.md | 9 + .../python/tilelang_dsl/expand_helper.py | 7 +- tilelang-dsl/python/tilelang_dsl/kernel.py | 40 +- tilelang-dsl/python/tilelang_dsl/lowering.py | 13 +- tilelang-dsl/python/tilelang_dsl/semantic.py | 62 ++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 121 +++++ 9 files changed, 738 insertions(+), 20 deletions(-) create mode 100644 lib/TileOps/__init__.py create mode 100644 lib/TileOps/math.py diff --git a/lib/TileOps/__init__.py b/lib/TileOps/__init__.py new file mode 100644 index 000000000..34437fbee --- /dev/null +++ b/lib/TileOps/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Shared TileOps template helpers.""" diff --git a/lib/TileOps/math.py b/lib/TileOps/math.py new file mode 100644 index 000000000..3d32340db --- /dev/null +++ b/lib/TileOps/math.py @@ -0,0 +1,485 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import tilelang_dsl as pto + + +@pto.inline_proc +def _vdiv_u16(vec, scalar_vec, mask): + zero = pto.ui16(0) + one = pto.ui16(1) + fp32_one = pto.f32(1.0) + full_mask_b16 = pto.pset_b16(pto.PAT.ALL) + full_mask_b32 = pto.pset_b32(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + zero_u16 = pto.vbr(zero) + vy_lower_u16, vy_higher_u16 = pto.vintlv(scalar_vec, zero_u16) + vy_lower_u32 = pto.vcvt(vy_lower_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vy_higher_u32 = pto.vcvt(vy_higher_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + active_low = pto.vcmps(vy_lower_u32, pto.ui32(0), full_mask_b32, pto.CmpMode.NE) + active_high = pto.vcmps(vy_higher_u32, pto.ui32(0), full_mask_b32, pto.CmpMode.NE) + vy_lower_f32 = pto.vcvt(pto.vbitcast(vy_lower_u32, pto.i32), pto.f32, active_low, rnd=pto.VcvtRoundMode.F) + vy_higher_f32 = pto.vcvt(pto.vbitcast(vy_higher_u32, pto.i32), pto.f32, active_high, rnd=pto.VcvtRoundMode.F) + + vy_rec_lower = pto.vdiv(pto.vbr(fp32_one), vy_lower_f32, active_low) + vy_rec_higher = pto.vdiv(pto.vbr(fp32_one), vy_higher_f32, active_high) + vy_scale_lower = pto.vmul(vy_rec_lower, pto.vbr(pto.f32(65536.0)), active_low) + vy_scale_higher = pto.vmul(vy_rec_higher, pto.vbr(pto.f32(65536.0)), active_high) + + v_lower_i32 = pto.vcvt( + vy_scale_lower, + pto.i32, + active_low, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + ) + v_higher_i32 = pto.vcvt( + vy_scale_higher, + pto.i32, + active_high, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + ) + v_lower_u32 = pto.vbitcast(v_lower_i32, pto.ui32) + v_higher_u32 = pto.vbitcast(v_higher_i32, pto.ui32) + + vx_lower_u16, vx_higher_u16 = pto.vintlv(vec, zero_u16) + vx_lower_u32 = pto.vcvt(vx_lower_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vx_higher_u32 = pto.vcvt(vx_higher_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + q_tmp_lower = pto.vmul(v_lower_u32, vx_lower_u32, active_low) + q_tmp_higher = pto.vmul(v_higher_u32, vx_higher_u32, active_high) + _q_lower, q_tmp = pto.vdintlv(pto.vbitcast(q_tmp_lower, pto.ui16), pto.vbitcast(q_tmp_higher, pto.ui16)) + + yq_tmp = pto.vmul(q_tmp, scalar_vec, active_mask) + r_tmp = pto.vsub(vec, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + zero_q = pto.vbr(pto.ui16(0xFFFF)) + return pto.vsel(zero_q, q_tmp, zero_mask) + + +@pto.inline_proc +def vdiv_u16(vec, scalar_vec, mask): + return _vdiv_u16(vec, scalar_vec, mask) + + +@pto.inline_proc +def vdiv_i16(vec, scalar_vec, mask): + zero = pto.i16(0) + neg_one = pto.i16(-1) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + abs_x = pto.vbitcast(pto.vabs(vec, active_mask), pto.ui16) + abs_y = pto.vbitcast(pto.vabs(scalar_vec, active_mask), pto.ui16) + x_xor_y = pto.vxor(vec, scalar_vec, active_mask) + p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) + + q_abs = _vdiv_u16(abs_x, abs_y, active_mask) + neg_q = pto.vneg(pto.vbitcast(q_abs, pto.i16), active_mask) + q = pto.vsel(pto.vbitcast(q_abs, pto.i16), neg_q, p_pos) + return pto.vsel(pto.vbr(neg_one), q, zero_mask) + + +@pto.inline_proc +def vmod_u16(vec, scalar_vec, mask): + zero = pto.ui16(0) + one = pto.ui16(1) + zero_r = pto.ui16(0xFFFF) + fp32_one = pto.f32(1.0) + full_mask_b16 = pto.pset_b16(pto.PAT.ALL) + full_mask_b32 = pto.pset_b32(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + zero_u16 = pto.vbr(zero) + vy_lower_u16, vy_higher_u16 = pto.vintlv(scalar_vec, zero_u16) + vy_lower_u32 = pto.vcvt(vy_lower_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vy_higher_u32 = pto.vcvt(vy_higher_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + active_low = pto.vcmps(vy_lower_u32, pto.ui32(0), full_mask_b32, pto.CmpMode.NE) + active_high = pto.vcmps(vy_higher_u32, pto.ui32(0), full_mask_b32, pto.CmpMode.NE) + vy_lower_f32 = pto.vcvt(pto.vbitcast(vy_lower_u32, pto.i32), pto.f32, active_low, rnd=pto.VcvtRoundMode.F) + vy_higher_f32 = pto.vcvt(pto.vbitcast(vy_higher_u32, pto.i32), pto.f32, active_high, rnd=pto.VcvtRoundMode.F) + + vy_rec_lower = pto.vdiv(pto.vbr(fp32_one), vy_lower_f32, active_low) + vy_rec_higher = pto.vdiv(pto.vbr(fp32_one), vy_higher_f32, active_high) + vy_scale_lower = pto.vmul(vy_rec_lower, pto.vbr(pto.f32(65536.0)), active_low) + vy_scale_higher = pto.vmul(vy_rec_higher, pto.vbr(pto.f32(65536.0)), active_high) + + v_lower_i32 = pto.vcvt( + vy_scale_lower, + pto.i32, + active_low, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + ) + v_higher_i32 = pto.vcvt( + vy_scale_higher, + pto.i32, + active_high, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + ) + v_lower_u32 = pto.vbitcast(v_lower_i32, pto.ui32) + v_higher_u32 = pto.vbitcast(v_higher_i32, pto.ui32) + + vx_lower_u16, vx_higher_u16 = pto.vintlv(vec, zero_u16) + vx_lower_u32 = pto.vcvt(vx_lower_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vx_higher_u32 = pto.vcvt(vx_higher_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + q_tmp_lower = pto.vmul(v_lower_u32, vx_lower_u32, active_low) + q_tmp_higher = pto.vmul(v_higher_u32, vx_higher_u32, active_high) + _q_lower, q_tmp = pto.vdintlv(pto.vbitcast(q_tmp_lower, pto.ui16), pto.vbitcast(q_tmp_higher, pto.ui16)) + + yq_tmp = pto.vmul(q_tmp, scalar_vec, active_mask) + r_tmp = pto.vsub(vec, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + return pto.vsel(pto.vbr(zero_r), r_tmp, zero_mask) + + +@pto.inline_proc +def vdiv_u32(vec, scalar_vec, mask): + zero = pto.ui32(0) + one = pto.ui32(1) + zero_q = pto.ui32(0xFFFFFFFF) + fp32_one = pto.f32(1.0) + full_mask = pto.pset_b32(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + zero_u32 = pto.vbr(zero) + zero_f32 = pto.vbr(pto.f32(0.0)) + vy_lower_u32, vy_higher_u32 = pto.vintlv(scalar_vec, zero_u32) + vy_lower_f32 = pto.vcvt(pto.vbitcast(vy_lower_u32, pto.i64), pto.f32, full_mask, rnd=pto.VcvtRoundMode.F, part=pto.VcvtPartMode.EVEN) + vy_higher_f32 = pto.vcvt(pto.vbitcast(vy_higher_u32, pto.i64), pto.f32, full_mask, rnd=pto.VcvtRoundMode.F, part=pto.VcvtPartMode.EVEN) + vy_float, _vy_waste = pto.vdintlv(vy_lower_f32, vy_higher_f32) + + vy_rec = pto.vdiv(pto.vbr(fp32_one), vy_float, full_mask) + vy_scale = pto.vmul(vy_rec, pto.vbr(pto.f32(4294966784.0)), full_mask) + + vy_scale_lower_f32, vy_scale_higher_f32 = pto.vintlv(vy_scale, zero_f32) + v_lower_i64 = pto.vcvt( + vy_scale_lower_f32, + pto.i64, + full_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + v_higher_i64 = pto.vcvt( + vy_scale_higher_f32, + pto.i64, + full_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + z, _z_waste = pto.vdintlv(pto.vbitcast(v_lower_i64, pto.ui32), pto.vbitcast(v_higher_i64, pto.ui32)) + + tmp_0 = pto.vmul(z, scalar_vec, full_mask) + tmp_0 = pto.vbitcast(pto.vneg(pto.vbitcast(tmp_0, pto.i32), full_mask), pto.ui32) + _z_lower, z_high = pto.vmull(z, tmp_0, full_mask) + z = pto.vadd(z, z_high, full_mask) + + _q_lower, q_tmp = pto.vmull(vec, z, full_mask) + yq_tmp = pto.vmul(q_tmp, scalar_vec, active_mask) + r_tmp = pto.vsub(vec, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + return pto.vsel(pto.vbr(zero_q), q_tmp, zero_mask) + + +@pto.inline_proc +def vmod_i16(vec, scalar_vec, mask): + zero = pto.i16(0) + neg_one = pto.i16(-1) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + abs_x = pto.vbitcast(pto.vabs(vec, active_mask), pto.ui16) + abs_y = pto.vbitcast(pto.vabs(scalar_vec, active_mask), pto.ui16) + x_xor_y = pto.vxor(vec, scalar_vec, active_mask) + p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) + + q_abs = _vdiv_u16(abs_x, abs_y, active_mask) + neg_q = pto.vneg(pto.vbitcast(q_abs, pto.i16), active_mask) + q = pto.vsel(pto.vbitcast(q_abs, pto.i16), neg_q, p_pos) + + qy = pto.vmul(q, scalar_vec, active_mask) + remainder = pto.vsub(vec, qy, active_mask) + + nonzero_remainder = pto.vcmps(remainder, zero, active_mask, pto.CmpMode.NE) + sign_x = pto.vcmps(vec, zero, active_mask, pto.CmpMode.GE) + sign_y = pto.vcmps(scalar_vec, zero, active_mask, pto.CmpMode.GE) + sign_diff = pto.pxor(sign_x, sign_y, active_mask) + need_floor_fix = pto.pand(sign_diff, nonzero_remainder, active_mask) + amended_remainder = pto.vadd(scalar_vec, remainder, active_mask) + remainder = pto.vsel(amended_remainder, remainder, need_floor_fix) + return pto.vsel(pto.vbr(neg_one), remainder, zero_mask) + + +@pto.inline_proc +def vdiv_i32(vec, scalar_vec, mask): + zero = pto.i32(0) + neg_one = pto.i32(-1) + fp32_one = pto.f32(1.0) + false_mask = pto.pset_b32(pto.PAT.ALLF) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + abs_x = pto.vbitcast(pto.vabs(vec, active_mask), pto.ui32) + abs_y = pto.vbitcast(pto.vabs(scalar_vec, active_mask), pto.ui32) + x_xor_y = pto.vxor(vec, scalar_vec, active_mask) + p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) + + y_float = pto.vcvt(pto.vbitcast(abs_y, pto.i32), pto.f32, active_mask, rnd=pto.VcvtRoundMode.R) + y_rec = pto.vdiv(pto.vbr(fp32_one), y_float, active_mask) + f_z_tmp_bits = pto.vadds(pto.vbitcast(y_rec, pto.ui32), pto.ui32(0x0FFFFFFE), active_mask) + + low_mask, high_mask = pto.pintlv_b32(active_mask, false_mask) + lower_bits, higher_bits = pto.vintlv(f_z_tmp_bits, pto.vbr(pto.ui32(0))) + lower_i64 = pto.vcvt( + pto.vbitcast(lower_bits, pto.f32), + pto.i64, + low_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + higher_i64 = pto.vcvt( + pto.vbitcast(higher_bits, pto.f32), + pto.i64, + high_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + z, _z_waste = pto.vdintlv(pto.vbitcast(lower_i64, pto.ui32), pto.vbitcast(higher_i64, pto.ui32)) + active_mask, _waste_mask = pto.pdintlv_b32(low_mask, high_mask) + + fz_negative = pto.vcmps(pto.vbitcast(f_z_tmp_bits, pto.f32), pto.f32(0.0), active_mask, pto.CmpMode.LT) + z = pto.vsel(pto.vbr(pto.ui32(0)), z, fz_negative) + + tmp_0 = pto.vmul(z, abs_y, active_mask) + tmp_0 = pto.vbitcast(pto.vneg(pto.vbitcast(tmp_0, pto.i32), active_mask), pto.ui32) + _z_lower, z_high = pto.vmull(z, tmp_0, active_mask) + z = pto.vadd(z, z_high, active_mask) + + _q_lower, q_tmp = pto.vmull(abs_x, z, active_mask) + yq_tmp = pto.vmul(q_tmp, abs_y, active_mask) + r_tmp = pto.vsub(abs_x, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, abs_y, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, abs_y, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, pto.ui32(1), active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, abs_y, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, abs_y, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, pto.ui32(1), active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + + neg_q = pto.vneg(pto.vbitcast(q_tmp, pto.i32), active_mask) + q = pto.vsel(pto.vbitcast(q_tmp, pto.i32), neg_q, p_pos) + return pto.vsel(pto.vbr(neg_one), q, zero_mask) + + +@pto.inline_proc +def vmod_u32(vec, scalar_vec, mask): + zero = pto.ui32(0) + one = pto.ui32(1) + zero_r = pto.ui32(0xFFFFFFFF) + fp32_one = pto.f32(1.0) + full_mask = pto.pset_b32(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + zero_u32 = pto.vbr(zero) + zero_f32 = pto.vbr(pto.f32(0.0)) + vy_lower_u32, vy_higher_u32 = pto.vintlv(scalar_vec, zero_u32) + vy_lower_f32 = pto.vcvt(pto.vbitcast(vy_lower_u32, pto.i64), pto.f32, full_mask, rnd=pto.VcvtRoundMode.F, part=pto.VcvtPartMode.EVEN) + vy_higher_f32 = pto.vcvt(pto.vbitcast(vy_higher_u32, pto.i64), pto.f32, full_mask, rnd=pto.VcvtRoundMode.F, part=pto.VcvtPartMode.EVEN) + vy_float, _vy_waste = pto.vdintlv(vy_lower_f32, vy_higher_f32) + + vy_rec = pto.vdiv(pto.vbr(fp32_one), vy_float, full_mask) + vy_scale = pto.vmul(vy_rec, pto.vbr(pto.f32(4294966784.0)), full_mask) + + vy_scale_lower_f32, vy_scale_higher_f32 = pto.vintlv(vy_scale, zero_f32) + v_lower_i64 = pto.vcvt( + vy_scale_lower_f32, + pto.i64, + full_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + v_higher_i64 = pto.vcvt( + vy_scale_higher_f32, + pto.i64, + full_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + z, _z_waste = pto.vdintlv(pto.vbitcast(v_lower_i64, pto.ui32), pto.vbitcast(v_higher_i64, pto.ui32)) + + tmp_0 = pto.vmul(z, scalar_vec, full_mask) + tmp_0 = pto.vbitcast(pto.vneg(pto.vbitcast(tmp_0, pto.i32), full_mask), pto.ui32) + _z_lower, z_high = pto.vmull(z, tmp_0, full_mask) + z = pto.vadd(z, z_high, full_mask) + + _q_lower, q_tmp = pto.vmull(vec, z, full_mask) + yq_tmp = pto.vmul(q_tmp, scalar_vec, active_mask) + r_tmp = pto.vsub(vec, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + return pto.vsel(pto.vbr(zero_r), r_tmp, zero_mask) + + +@pto.inline_proc +def vmod_i32(vec, scalar_vec, mask): + zero = pto.i32(0) + neg_one = pto.i32(-1) + fp32_one = pto.f32(1.0) + false_mask = pto.pset_b32(pto.PAT.ALLF) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + abs_x = pto.vbitcast(pto.vabs(vec, active_mask), pto.ui32) + abs_y = pto.vbitcast(pto.vabs(scalar_vec, active_mask), pto.ui32) + x_xor_y = pto.vxor(vec, scalar_vec, active_mask) + p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) + + y_float = pto.vcvt(pto.vbitcast(abs_y, pto.i32), pto.f32, active_mask, rnd=pto.VcvtRoundMode.R) + y_rec = pto.vdiv(pto.vbr(fp32_one), y_float, active_mask) + f_z_tmp_bits = pto.vadds(pto.vbitcast(y_rec, pto.ui32), pto.ui32(0x0FFFFFFE), active_mask) + + low_mask, high_mask = pto.pintlv_b32(active_mask, false_mask) + lower_bits, higher_bits = pto.vintlv(f_z_tmp_bits, pto.vbr(pto.ui32(0))) + lower_i64 = pto.vcvt( + pto.vbitcast(lower_bits, pto.f32), + pto.i64, + low_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + higher_i64 = pto.vcvt( + pto.vbitcast(higher_bits, pto.f32), + pto.i64, + high_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + z, _z_waste = pto.vdintlv(pto.vbitcast(lower_i64, pto.ui32), pto.vbitcast(higher_i64, pto.ui32)) + active_mask, _waste_mask = pto.pdintlv_b32(low_mask, high_mask) + + fz_negative = pto.vcmps(pto.vbitcast(f_z_tmp_bits, pto.f32), pto.f32(0.0), active_mask, pto.CmpMode.LT) + z = pto.vsel(pto.vbr(pto.ui32(0)), z, fz_negative) + + tmp_0 = pto.vmul(z, abs_y, active_mask) + tmp_0 = pto.vbitcast(pto.vneg(pto.vbitcast(tmp_0, pto.i32), active_mask), pto.ui32) + _z_lower, z_high = pto.vmull(z, tmp_0, active_mask) + z = pto.vadd(z, z_high, active_mask) + + _q_lower, q_tmp = pto.vmull(abs_x, z, active_mask) + yq_tmp = pto.vmul(q_tmp, abs_y, active_mask) + r_tmp = pto.vsub(abs_x, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, abs_y, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, abs_y, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, pto.ui32(1), active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, abs_y, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, abs_y, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, pto.ui32(1), active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + + neg_q = pto.vneg(pto.vbitcast(q_tmp, pto.i32), active_mask) + q = pto.vsel(pto.vbitcast(q_tmp, pto.i32), neg_q, p_pos) + + qy = pto.vmul(q, scalar_vec, active_mask) + remainder = pto.vsub(vec, qy, active_mask) + nonzero_remainder = pto.vcmps(pto.vbitcast(r_tmp, pto.i32), zero, active_mask, pto.CmpMode.NE) + sign_x = pto.vcmps(vec, zero, active_mask, pto.CmpMode.GE) + sign_y = pto.vcmps(scalar_vec, zero, active_mask, pto.CmpMode.GE) + sign_diff = pto.pxor(sign_x, sign_y, active_mask) + need_floor_fix = pto.pand(sign_diff, nonzero_remainder, active_mask) + amended_remainder = pto.vadd(scalar_vec, remainder, active_mask) + remainder = pto.vsel(amended_remainder, remainder, need_floor_fix) + return pto.vsel(pto.vbr(neg_one), remainder, zero_mask) + + +@pto.inline_proc +def vmod(vec, scalar_vec, mask, dtype): + if pto.constexpr(dtype == pto.ui16): + result = vmod_u16(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.i16): + result = vmod_i16(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.ui32): + result = vmod_u32(vec, scalar_vec, mask) + else: + result = vmod_i32(vec, scalar_vec, mask) + return result + + +@pto.inline_proc +def vdiv(vec, scalar_vec, mask, dtype): + if pto.constexpr(dtype == pto.ui16): + result = vdiv_u16(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.i16): + result = vdiv_i16(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.ui32): + result = vdiv_u32(vec, scalar_vec, mask) + else: + result = vdiv_i32(vec, scalar_vec, mask) + return result diff --git a/lib/TileOps/render_template_mlir.py b/lib/TileOps/render_template_mlir.py index 5c4952674..8a92cfcc6 100644 --- a/lib/TileOps/render_template_mlir.py +++ b/lib/TileOps/render_template_mlir.py @@ -7,6 +7,14 @@ # See LICENSE in the root of the software repository for the full text of the License. #!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """Materialize a TileLang DSL library template to authoring-form MLIR. Examples: @@ -95,11 +103,15 @@ def _parse_args() -> argparse.Namespace: def _load_module(template_path: Path) -> ModuleType: + template_parent = template_path.parent.parent + if str(template_parent) not in sys.path: + sys.path.insert(0, str(template_parent)) module_name = f"_tileops_template_{template_path.stem}" spec = importlib.util.spec_from_file_location(module_name, template_path) if spec is None or spec.loader is None: raise ValueError(f"failed to load Python module from {template_path}") module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module spec.loader.exec_module(module) return module diff --git a/tilelang-dsl/docs/user_guide/06-control-flow.md b/tilelang-dsl/docs/user_guide/06-control-flow.md index f4793acc9..098e74110 100644 --- a/tilelang-dsl/docs/user_guide/06-control-flow.md +++ b/tilelang-dsl/docs/user_guide/06-control-flow.md @@ -100,6 +100,15 @@ def row_copy(dst: pto.Tile, src: pto.Tile, row: pto.i32): Important semantics: +- `pto.(...)` and bare helper calls are different mechanisms. +- Calls written as `pto.vadd(...)`, `pto.vdiv(...)`, `pto.vlds(...)`, etc. target + built-in TileLang/VPTO surfaces directly. +- Calls written as bare Python names such as `store_row(...)` target a + user-defined `@pto.inline_proc` helper when the callee name resolves to a + registered top-level inline procedure in the current module. +- `inline_proc` helpers do not live in the `pto` namespace; using the same + basename as a `pto.` op is allowed because the frontend distinguishes + `pto.xxx(...)` from bare `xxx(...)` calls. - Frontend preserves helper `func.func` and `func.call` in `mlir_text()` output. - VPTO backend mainline force-inlines helper calls before downstream lowering. - Helper definitions support default parameter values. diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py index 4d2a875fd..50e7fe9fb 100644 --- a/tilelang-dsl/python/tilelang_dsl/expand_helper.py +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -87,10 +87,15 @@ def _find_descriptors(module) -> list[VKernelDescriptor]: def _import_py_file(path: Path): """Import a .py file as a module and return it.""" - spec = importlib.util.spec_from_file_location(f"_tl_template_{path.stem}", str(path)) + template_parent = path.parent.parent + if str(template_parent) not in sys.path: + sys.path.insert(0, str(template_parent)) + module_name = f"_tl_template_{path.stem}" + spec = importlib.util.spec_from_file_location(module_name, str(path)) if spec is None or spec.loader is None: return None mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod try: spec.loader.exec_module(mod) except Exception as exc: diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index eb01241be..743ae6136 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -13,6 +13,7 @@ import os import inspect import ast +import sys import subprocess import tempfile import textwrap @@ -140,7 +141,16 @@ def _inline_proc_registry_key(fn: Callable[..., Any]) -> tuple[str, str]: def _find_inline_proc(name: str, *, module_name: str | None) -> InlineProcDescriptor | None: if module_name is None: return None - return _INLINE_PROC_REGISTRY.get((module_name, name)) + descriptor = _INLINE_PROC_REGISTRY.get((module_name, name)) + if descriptor is not None: + return descriptor + module = sys.modules.get(module_name) + if module is None: + return None + value = getattr(module, name, None) + if isinstance(value, InlineProcDescriptor): + return value + return None def _validate_inline_proc_call_surface( @@ -175,16 +185,24 @@ def _validate_inline_proc_call_surface( def _collect_inline_procs(module_name: str) -> tuple[tuple[str, InlineProcDescriptor], ...]: - return tuple( - sorted( - ( - (symbol, descriptor) - for (registered_module, symbol), descriptor in _INLINE_PROC_REGISTRY.items() - if registered_module == module_name - ), - key=lambda item: item[0], - ) - ) + collected: dict[str, InlineProcDescriptor] = { + symbol: descriptor + for (registered_module, symbol), descriptor in _INLINE_PROC_REGISTRY.items() + if registered_module == module_name + } + + module = sys.modules.get(module_name) + if module is not None: + for symbol, value in vars(module).items(): + if not isinstance(value, InlineProcDescriptor): + continue + collected.setdefault(symbol, value) + origin_module = value.py_fn.__module__ + for (registered_module, helper_name), helper in _INLINE_PROC_REGISTRY.items(): + if registered_module == origin_module: + collected.setdefault(helper_name, helper) + + return tuple(sorted(collected.items(), key=lambda item: item[0])) def _register_inline_proc(descriptor: InlineProcDescriptor) -> InlineProcDescriptor: diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 690ffe5c8..c2880d84c 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -230,8 +230,13 @@ def render(self) -> str: parameter_list = ", ".join( f"{param.ssa_name}: {self._render_type(param.type)}" for param in self.kernel.parameters - if param.kind != "tile_valid_shape" + if param.kind != "tile_valid_shape" and self._should_materialize_function_boundary_type(param.type) ) + result_sig = "" + if self.kernel.body and isinstance(self.kernel.body[-1], SemanticReturnStmt): + return_value = self.kernel.body[-1].value + if return_value is not None: + result_sig = f" -> {self._render_type(return_value.type)}" env = { param.name: _RenderedValue(name=param.ssa_name, type=param.type) for param in self.kernel.parameters @@ -268,7 +273,7 @@ def render(self) -> str: lines.append(f'module attributes {{pto.target_arch = "{self.kernel.target}"}} {{') lines.append( " func.func " - f"{_format_symbol_name(self.kernel.symbol_name)}({parameter_list}) " + f"{_format_symbol_name(self.kernel.symbol_name)}({parameter_list}){result_sig} " "attributes { pto.tilelang.instance } {" ) lines.extend(self._constant_lines) @@ -279,6 +284,9 @@ def render(self) -> str: lines.append("") return "\n".join(lines) + def _should_materialize_function_boundary_type(self, ty: SemanticType) -> bool: + return not isinstance(ty, (SemanticMetaType, SemanticPadValueType)) + def _collect_used_tile_buffers( self, statements: tuple[SemanticStmt, ...], @@ -2582,6 +2590,7 @@ def _lower_call_expr( rendered_args = [ self._lower_expr(arg, env, indent=indent, into=into) for arg in expr.args + if self._should_materialize_function_boundary_type(arg.type) ] rendered_arg_names = ", ".join(arg.name for arg in rendered_args) rendered_arg_types = ", ".join(self._render_type(arg.type) for arg in rendered_args) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index bda55f221..d781d141f 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -777,10 +777,14 @@ def __init__(self, node: FrontendKernelNode): self._inline_proc_nodes: dict[str, FrontendInlineProcNode] = { inline_proc.name: inline_proc for inline_proc in node.inline_procs } - self._inline_proc_specializations: dict[tuple[str, tuple[SemanticType, ...]], SemanticKernel] = {} - self._inline_proc_return_types: dict[tuple[str, tuple[SemanticType, ...]], SemanticType | None] = {} - self._inline_proc_order: list[tuple[str, tuple[SemanticType, ...]]] = [] - self._inline_proc_active_stack: list[tuple[str, tuple[SemanticType, ...]]] = [] + self._inline_proc_specializations: dict[ + tuple[str, tuple[tuple[SemanticType, object], ...]], SemanticKernel + ] = {} + self._inline_proc_return_types: dict[ + tuple[str, tuple[tuple[SemanticType, object], ...]], SemanticType | None + ] = {} + self._inline_proc_order: list[tuple[str, tuple[tuple[SemanticType, object], ...]]] = [] + self._inline_proc_active_stack: list[tuple[str, tuple[tuple[SemanticType, object], ...]]] = [] def _expr_source_location( self, @@ -1423,8 +1427,53 @@ def _inline_proc_specialization_key( self, name: str, args: tuple[SemanticExpr, ...], - ) -> tuple[str, tuple[SemanticType, ...]]: - return (name, tuple(arg.type for arg in args)) + ) -> tuple[str, tuple[tuple[SemanticType, object], ...]]: + return ( + name, + tuple( + (arg.type, self._inline_proc_static_specialization_token(arg)) + for arg in args + ), + ) + + def _inline_proc_static_specialization_token( + self, + expr: SemanticExpr, + ) -> object: + if isinstance(expr, SemanticLiteralExpr) and expr.value is None: + return ("none",) + + if isinstance(expr.type, SemanticMetaType) and expr.type.kind in { + "dtype", + "ptr_type", + "mask_type", + }: + value = self._try_static_value(expr) + if value is not None: + return ("meta", expr.type.kind, value) + + value = self._try_static_value(expr) + if isinstance(value, bool): + return ("bool", value) + if isinstance(value, int) and not isinstance(value, bool): + return ("int", value) + if value is None: + return ("dynamic",) + return ("dynamic",) + + def _inline_proc_bound_static_value( + self, + expr: SemanticExpr, + ) -> Any | None: + token = self._inline_proc_static_specialization_token(expr) + kind = token[0] + if kind == "meta": + return token[2] + if kind in {"bool", "int"}: + return token[1] + if kind == "none": + return None + return None def _inline_proc_symbol_name( self, @@ -1486,6 +1535,7 @@ def _materialize_inline_proc_specialization( ssa_name=f"%arg{index}", type=arg_expr.type, origin="inline_param", + value=self._inline_proc_bound_static_value(arg_expr), ) helper_env[param.name] = binding helper_parameters.append(SemanticParameter(binding=binding)) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index d7535e6fd..afc2fd9ed 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -7034,6 +7034,53 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertRegex(text, r"= func\.call @__tl_inline_") self.assertIn("pto.vsts", text) + def test_inline_proc_and_pto_surface_can_share_basename(self) -> None: + @pto.inline_proc + def vdiv(src: pto.Tile, lane: pto.i32 = 0): + return pto.vlds(src, lane) + + @pto.vkernel(op="inline_proc_same_basename_as_pto_surface_unique", + dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + helper_vec = vdiv(src, 0) + raw_vec = pto.vdiv(helper_vec, helper_vec, mask) + pto.vsts(raw_vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertGreaterEqual(len(frontend_kernel.inline_procs), 1) + self.assertIn("vdiv", {proc.name for proc in frontend_kernel.inline_procs}) + call_values = [ + stmt.value + for stmt in frontend_kernel.body + if isinstance(stmt, FrontendAssignStmt) + and isinstance(stmt.value, FrontendCallExpr) + ] + helper_call = next( + value for value in call_values if value.namespace is None and value.name == "vdiv" + ) + raw_call = next( + value for value in call_values if value.namespace == "pto" and value.name == "vdiv" + ) + self.assertEqual(len(helper_call.args), 2) + self.assertEqual(len(raw_call.args), 3) + self.assertIsInstance(raw_call, FrontendCallExpr) + self.assertIsNone(helper_call.namespace) + self.assertEqual(helper_call.name, "vdiv") + self.assertEqual(raw_call.namespace, "pto") + self.assertEqual(raw_call.name, "vdiv") + + text = specialized.mlir_text() + self.assertIn("pto.tilelang.inline_proc", text) + self.assertRegex(text, r"func\.call @__tl_inline_vdiv_") + self.assertIn("= pto.vdiv ", text) + def test_inline_proc_rejects_non_trailing_return(self) -> None: with self.assertRaises(pto.TileLangFrontendError) as ctx: @@ -7244,6 +7291,80 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertRegex(text, r"= func\.call @__tl_inline_[A-Za-z0-9_]+\(.*\) : \([^\)]*\) -> index") self.assertRegex(text, r"func\.call @__tl_inline_[A-Za-z0-9_]+\(.*\) : \([^\)]*\) -> \(\)") + def test_inline_proc_supports_constexpr_dtype_dispatch(self) -> None: + @pto.inline_proc + def inline_pick_lane(dtype): + if pto.constexpr(dtype == pto.ui16): + lane = 1 + elif pto.constexpr(dtype == pto.i16): + lane = 2 + elif pto.constexpr(dtype == pto.ui32): + lane = 3 + else: + lane = 4 + return lane + + @pto.vkernel(op="inline_proc_constexpr_dtype_dispatch_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + lane = inline_pick_lane(dst.element_type) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + assign_stmt = next( + stmt + for stmt in semantic_kernel.body + if isinstance(stmt, SemanticAssignStmt) and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertRegex(assign_stmt.value.name, r"^__tl_inline_") + self.assertEqual(len(semantic_kernel.inline_helpers), 1) + helper_assign = next( + stmt + for stmt in semantic_kernel.inline_helpers[0].body + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "lane" + ) + self.assertIsInstance(helper_assign.value, SemanticLiteralExpr) + self.assertEqual(helper_assign.value.value, 4) + + def test_inline_proc_specializes_same_type_with_different_static_values(self) -> None: + @pto.inline_proc + def inline_scale(lane: pto.i32): + if pto.constexpr(lane == 1): + value = 2 + else: + value = 4 + return value + + @pto.vkernel(op="inline_proc_static_value_specialization_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + lane0 = inline_scale(1) + lane1 = inline_scale(2) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + self.assertEqual(len(semantic_kernel.inline_helpers), 2) + literal_values = [] + for helper in semantic_kernel.inline_helpers: + helper_assign = next( + stmt + for stmt in helper.body + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "value" + ) + self.assertIsInstance(helper_assign.value, SemanticLiteralExpr) + literal_values.append(helper_assign.value.value) + self.assertEqual(sorted(literal_values), [2, 4]) + def test_inline_proc_rejects_mutual_recursion(self) -> None: @pto.inline_proc def inline_a(dst: pto.Tile): From c3797c392a59e49f726db093e5e03d434c04ad14 Mon Sep 17 00:00:00 2001 From: kangjiaming1 <1159380836@qq.com> Date: Sat, 25 Apr 2026 18:03:27 +0800 Subject: [PATCH 172/192] add texpand/tfillpad/tfillpad_inplace/tfillpad_expand op (#167) * add texpand/tfillpad/tfillpad_inplace/tfillpad_expand op * revise the review comments * Add uint16-related test cases for tfillpad/tfillpad_expand * solve the CI problem * revise review comments * revise review comments * revise review comments --------- Co-authored-by: kangjiaming Co-authored-by: KurrinQu --- lib/TileOps/texpand_template.py | 72 +++ lib/TileOps/tfillpad_expand_template.py | 164 +++++ lib/TileOps/tfillpad_inplace_template.py | 155 +++++ lib/TileOps/tfillpad_template.py | 170 ++++++ .../basic/expand_tile_op_tilelang_texpand.pto | 38 ++ .../expand_tile_op_tilelang_tfillpad.pto | 45 ++ ...xpand_tile_op_tilelang_tfillpad_expand.pto | 45 ++ ...pand_tile_op_tilelang_tfillpad_inplace.pto | 43 ++ .../npu/a5/src/st/testcase/CMakeLists.txt | 4 + .../src/st/testcase/texpands/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/texpands/cases.py | 124 ++++ .../a5/src/st/testcase/texpands/compare.py | 78 +++ .../a5/src/st/testcase/texpands/gen_data.py | 49 ++ .../a5/src/st/testcase/texpands/launch.cpp | 80 +++ .../npu/a5/src/st/testcase/texpands/main.cpp | 171 ++++++ .../a5/src/st/testcase/texpands/texpands.pto | 387 ++++++++++++ .../src/st/testcase/tfillpad/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tfillpad/cases.py | 193 ++++++ .../a5/src/st/testcase/tfillpad/compare.py | 81 +++ .../a5/src/st/testcase/tfillpad/gen_data.py | 117 ++++ .../a5/src/st/testcase/tfillpad/launch.cpp | 93 +++ .../npu/a5/src/st/testcase/tfillpad/main.cpp | 207 +++++++ .../a5/src/st/testcase/tfillpad/tfillpad.pto | 576 ++++++++++++++++++ .../testcase/tfillpad_expand/CMakeLists.txt | 9 + .../src/st/testcase/tfillpad_expand/cases.py | 74 +++ .../st/testcase/tfillpad_expand/compare.py | 75 +++ .../st/testcase/tfillpad_expand/gen_data.py | 114 ++++ .../st/testcase/tfillpad_expand/launch.cpp | 29 + .../src/st/testcase/tfillpad_expand/main.cpp | 151 +++++ .../tfillpad_expand/tfillpad_expand.pto | 134 ++++ .../testcase/tfillpad_inplace/CMakeLists.txt | 9 + .../src/st/testcase/tfillpad_inplace/cases.py | 38 ++ .../st/testcase/tfillpad_inplace/compare.py | 80 +++ .../st/testcase/tfillpad_inplace/gen_data.py | 99 +++ .../st/testcase/tfillpad_inplace/launch.cpp | 23 + .../src/st/testcase/tfillpad_inplace/main.cpp | 129 ++++ .../tfillpad_inplace/tfillpad_inplace.pto | 76 +++ 37 files changed, 3950 insertions(+) create mode 100644 lib/TileOps/texpand_template.py create mode 100644 lib/TileOps/tfillpad_expand_template.py create mode 100644 lib/TileOps/tfillpad_inplace_template.py create mode 100644 lib/TileOps/tfillpad_template.py create mode 100644 test/basic/expand_tile_op_tilelang_texpand.pto create mode 100644 test/basic/expand_tile_op_tilelang_tfillpad.pto create mode 100644 test/basic/expand_tile_op_tilelang_tfillpad_expand.pto create mode 100644 test/basic/expand_tile_op_tilelang_tfillpad_inplace.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texpands/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texpands/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texpands/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texpands/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texpands/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texpands/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/texpands/texpands.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad/tfillpad.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/tfillpad_expand.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/tfillpad_inplace.pto diff --git a/lib/TileOps/texpand_template.py b/lib/TileOps/texpand_template.py new file mode 100644 index 000000000..d2e07360f --- /dev/null +++ b/lib/TileOps/texpand_template.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.texpands + +This template implements scalar broadcast expansion for location=VEC tiles. +It fills dst.valid_shape region with the broadcasted scalar value. + +Location constraint: + - This template is designed for tiles with location=VEC (vector buffer) + - In PTO-ISA, texpands has separate implementations for VEC and MAT locations + - For MAT location tiles, a different template/implementation path should be used + - Current tilelang_dsl MemorySpace only distinguishes GM and UB, where UB maps to + both VEC and MAT locations. The constraint checks memory_space=="ub" as a proxy + - Future enhancement: tilelang_dsl should support explicit location distinction + (e.g., MemorySpace.VEC vs MemorySpace.MAT) for more precise constraint matching + +Layout considerations: + - PTO-ISA has both rowmajor and colmajor expands implementations + - However, expands (scalar broadcast) is layout-agnostic: it simply fills + the tile with a scalar value using vector stores + - The vector store (vsts) writes data according to the tile's physical layout, + which is handled by the underlying DMA engine + - Therefore, this single template covers both rowmajor and colmajor cases +""" + +import tilelang_dsl as pto + + +def _texpands_vec_location_constraint(scalar, dst) -> bool: + """Constraint: dst tile must have location=VEC (represented as memory_space=ub). + + PTO-ISA defines texpands for both MAT and VEC locations: + - MAT location: expands matrix tiles (different implementation path, not supported here) + - VEC location: expands vector tiles (this template) + + Current tilelang_dsl limitation: + MemorySpace only has UB and GM. VEC and MAT both map to UB. + We check memory_space=="ub" as a proxy for VEC location. + MAT tiles should use a different op/template path and won't match here. + """ + # Check memory_space is "ub" (VEC/MAT location, not GM) + # In current tilelang_dsl, VEC location tiles have memory_space="ub" + ms = dst.memory_space + if isinstance(ms, str): + return ms == "ub" + return hasattr(ms, "value") and ms.value == "ub" + + +@pto.vkernel( + target="a5", + op="pto.texpands", + constraints=[_texpands_vec_location_constraint], +) +def template_texpands(scalar: pto.AnyType, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Use vdup for scalar broadcast + vec = pto.vdup(scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tfillpad_expand_template.py b/lib/TileOps/tfillpad_expand_template.py new file mode 100644 index 000000000..6a685f707 --- /dev/null +++ b/lib/TileOps/tfillpad_expand_template.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tfillpad_expand + +Expand mode semantics: + - TFILLPAD_EXPAND: src rows may be less than dst rows + - Copy src.valid data to dst + - Fill cols from src.valid_cols to dst.valid_cols with FillPadVal + - Fill rows from src.rows to dst.rows with FillPadVal + +Strategy: + - Phase 1: Copy aligned valid blocks (cols 0 to aligned_col-1) + - Phase 2: Fill cols aligned_col to dst_valid_cols-1 with FillPadVal + - Phase 3: Copy tail valid lanes (cols aligned_col to src_valid_cols-1) + - Phase 4: Fill row expansion + +Address alignment and unaligned handling: + - vlds/vsts require 32-byte aligned base addresses + - Phase 1: col=0 is always aligned (tile base address is aligned), each iteration + accesses col + lanes which maintains alignment + - Phase 2/3/4: handle non-aligned lengths using make_mask() to control active lanes + - make_mask approach: simpler than vldus/vstus for isolated tail operations, no need + for alignment state management (vldas/vldus/vsta sequence) + - vldus/vstus is suitable for continuous unaligned streams; for single tail ops, + mask-controlled vlds/vsts is more direct and efficient +""" + +import tilelang_dsl as pto + +_NEG1_F32 = -1.0 + +# All supported dtype pairs for tfillpad_expand +_DTYPE_SIGNATURES = [ + (pto.f32, pto.f32), + (pto.i16, pto.i16), + (pto.si16, pto.si16), + (pto.ui16, pto.ui16), + (pto.i32, pto.i32), + (pto.si32, pto.si32), + (pto.ui32, pto.ui32), + (pto.i8, pto.i8), + (pto.si8, pto.si8), + (pto.ui8, pto.ui8), +] + + +@pto.vkernel( + target="a5", + op="pto.tfillpad_expand", + dtypes=_DTYPE_SIGNATURES, +) +def template_tfillpad_expand(src: pto.Tile, dst: pto.Tile): + """Unified tfillpad_expand template for all dtypes. + + Main logic is identical across dtypes; only PadValue handling differs: + - f32: ZERO + expansion uses -1.0 (special encoding), otherwise eval() or 0.0 + - integer families: eval() or dtype-specific zero constant + """ + dtype = dst.element_type + src_rows, _ = src.shape + src_valid_rows, src_valid_cols = src.valid_shape + dst_rows, _ = dst.shape + dst_valid_rows, dst_valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + aligned_col = (src_valid_cols // lanes) * lanes + has_tail = src_valid_cols > aligned_col + has_valid_expansion = (src_valid_cols < dst_valid_cols) or (src_valid_rows < dst_valid_rows) + + # PadValue handling - dtype-specific + if pto.constexpr(dtype == pto.f32): + if pto.constexpr(dst.pad_value == pto.PadValue.ZERO and has_valid_expansion): + fill_scalar = pto.f32(_NEG1_F32) + elif pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.f32(0.0) + elif pto.constexpr(dtype == pto.ui16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui16(0) + elif pto.constexpr(dtype == pto.si16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si16(0) + elif pto.constexpr(dtype == pto.i16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i16(0) + elif pto.constexpr(dtype == pto.ui32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui32(0) + elif pto.constexpr(dtype == pto.si32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si32(0) + elif pto.constexpr(dtype == pto.i32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i32(0) + elif pto.constexpr(dtype == pto.ui8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui8(0) + elif pto.constexpr(dtype == pto.si8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si8(0) + elif pto.constexpr(dtype == pto.i8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i8(0) + + # Phase 1: Copy aligned valid blocks + for row in range(0, src_valid_rows, 1): + remained = aligned_col + for col in range(0, aligned_col, lanes): + mask, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, col:]) + pto.vsts(data, dst[row, col:], mask) + + # Phase 2: Fill col padding + if pto.constexpr(aligned_col < dst_valid_cols): + for row in range(0, dst_valid_rows, 1): + remained = dst_valid_cols - aligned_col + for col in range(aligned_col, dst_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + # Phase 3: Copy tail valid lanes + if pto.constexpr(has_tail): + for row in range(0, src_valid_rows, 1): + remained = src_valid_cols - aligned_col + mask_copy, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, aligned_col:]) + pto.vsts(data, dst[row, aligned_col:], mask_copy) + + # Phase 4: Fill row expansion + if pto.constexpr(src_rows < dst_rows): + for row in range(src_rows, dst_rows, 1): + remained = dst_valid_cols + for col in range(0, dst_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tfillpad_inplace_template.py b/lib/TileOps/tfillpad_inplace_template.py new file mode 100644 index 000000000..f79cb27fb --- /dev/null +++ b/lib/TileOps/tfillpad_inplace_template.py @@ -0,0 +1,155 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tfillpad_inplace + +Semantic (based on C++ TFillPad.hpp reference): + - TFILLPAD_INPLACE: same physical buffer (src == dst), skips copy phase, only fills expansion + - Inplace mode: src and dst share the same physical UB address + +Strategy (inplace mode): + - Skip Phase 1+3: Copy phases (data already in buffer) + - Phase 2: Fill cols from src_valid_cols to dst_valid_cols-1 with FillPadVal + - Phase 4: Fill row expansion +""" + +import tilelang_dsl as pto + +_NEG1_F32 = -1.0 + +# All supported dtype pairs +_DTYPE_SIGNATURES = [ + (pto.f32, pto.f32), + (pto.i16, pto.i16), + (pto.si16, pto.si16), + (pto.ui16, pto.ui16), + (pto.i32, pto.i32), + (pto.si32, pto.si32), + (pto.ui32, pto.ui32), + (pto.i8, pto.i8), + (pto.si8, pto.si8), + (pto.ui8, pto.ui8), +] + + +@pto.vkernel( + target="a5", + op="pto.tfillpad_inplace", + dtypes=_DTYPE_SIGNATURES, + advanced=True, # Required for as_ptr() +) +def template_tfillpad_inplace(src: pto.Tile, dst: pto.Tile): + """tfillpad_inplace: skip copy phase, only fill expansion regions. + + Uses vstus+vstas for unaligned column fill, matching C++ TFillPad.hpp. + """ + dtype = dst.element_type + src_rows, _ = src.shape + src_valid_rows, src_valid_cols = src.valid_shape + dst_rows, dst_cols = dst.shape + dst_valid_rows, dst_valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + has_valid_expansion = (src_valid_cols < dst_valid_cols) or (src_valid_rows < dst_valid_rows) + + # PadValue handling - same as tfillpad_template.py + if pto.constexpr(dtype == pto.f32): + if pto.constexpr(dst.pad_value == pto.PadValue.ZERO and has_valid_expansion): + fill_scalar = pto.f32(_NEG1_F32) + elif pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.f32(0.0) + elif pto.constexpr(dtype == pto.ui16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui16(0) + elif pto.constexpr(dtype == pto.si16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si16(0) + elif pto.constexpr(dtype == pto.i16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i16(0) + elif pto.constexpr(dtype == pto.ui32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui32(0) + elif pto.constexpr(dtype == pto.si32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si32(0) + elif pto.constexpr(dtype == pto.i32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i32(0) + elif pto.constexpr(dtype == pto.ui8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui8(0) + elif pto.constexpr(dtype == pto.si8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si8(0) + elif pto.constexpr(dtype == pto.i8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i8(0) + + # Phase 2: Fill cols from src_valid_cols to dst_valid_cols-1 + # Use vstus+vstas for unaligned starting column, matching C++ TFillPad.hpp + # vstus signature: vstus(align, offset/i32, value, base) + # offset is the number of elements to store + if pto.constexpr(src_valid_cols < dst_valid_cols): + pad_cols = dst_valid_cols - src_valid_cols + pad_repeat_times = (pad_cols + lanes - 1) // lanes + + # Create fill vector once (reused across all rows) + fill_vec = pto.vdup(fill_scalar) + + for row in range(0, dst_valid_rows, 1): + # Initialize align register for this row + ureg = pto.init_align() + + # Get pointer to UB buffer + base_ptr = dst.as_ptr() + + # Loop with DSL range, using j as constexpr loop variable + for j in range(0, pad_repeat_times, 1): + # sreg calculation: constexpr evaluated per iteration + # remaining = pad_cols - j*lanes, sreg = min(remaining, lanes) + remaining = pad_cols - j * lanes + # Use constexpr to branch on sreg value + if pto.constexpr(remaining >= lanes): + ureg = pto.vstus(ureg, lanes, fill_vec, base_ptr) + elif pto.constexpr(remaining > 0): + ureg = pto.vstus(ureg, remaining, fill_vec, base_ptr) + + # vstas: align final address with offset=0 + pto.vstas(ureg, fill_vec, dst[row, src_valid_cols:], 0) + + # Phase 4: Fill row expansion (rows src_rows to dst_rows-1) + if pto.constexpr(src_rows < dst_rows): + for row in range(src_rows, dst_rows, 1): + remained = dst_valid_cols + for col in range(0, dst_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tfillpad_template.py b/lib/TileOps/tfillpad_template.py new file mode 100644 index 000000000..77730ad11 --- /dev/null +++ b/lib/TileOps/tfillpad_template.py @@ -0,0 +1,170 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tfillpad + +Semantic (based on C++ TFillPad.hpp reference): + - TFILLPAD: copies src.valid data to dst, then fills dst expansion with FillPadVal + - TFILLPAD_INPLACE: same physical buffer (src == dst), skips copy phase, only fills expansion + +Key logic from C++: + if constexpr (!inplace) { + CopyValidElementsVec(dst, src, ...); // Phase 1+3: copy valid data + } + // Phase 2+4: fill expansion (always executed if dst has larger valid region) + FillExpansion(dst, padCols, padRows, padValue); + +Strategy: + - Phase 1: Copy aligned valid blocks (cols 0 to aligned_col-1) [only if !inplace] + - Phase 2: Fill cols aligned_col to dst_valid_cols-1 with FillPadVal + - Phase 3: Copy tail valid lanes (cols aligned_col to src_valid_cols-1) [only if !inplace] + - Phase 4: Fill row expansion + +Address alignment and unaligned handling: + - vlds/vsts require 32-byte aligned base addresses + - Phase 1: col=0 is always aligned (tile base address is aligned) + - Phase 2/4: handle non-aligned lengths using make_mask() to control active lanes + +Note: There is no separate pto.tfillpad_inplace operation in PTO IR. + In-place mode is expressed via tfillpad with src and dst being the same SSA value. + This template handles both cases by detecting if copy phase is needed. +""" + +import tilelang_dsl as pto + +_NEG1_F32 = -1.0 + +# All supported dtype pairs +_DTYPE_SIGNATURES = [ + (pto.f32, pto.f32), + (pto.i16, pto.i16), + (pto.si16, pto.si16), + (pto.ui16, pto.ui16), + (pto.i32, pto.i32), + (pto.si32, pto.si32), + (pto.ui32, pto.ui32), + (pto.i8, pto.i8), + (pto.si8, pto.si8), + (pto.ui8, pto.ui8), +] + + +@pto.vkernel( + target="a5", + op="pto.tfillpad", + dtypes=_DTYPE_SIGNATURES, +) +def template_tfillpad(src: pto.Tile, dst: pto.Tile): + """tfillpad: copy src.valid to dst and fill expansion regions. + + Based on C++ TFillPad.hpp reference: + - TFILLPAD (non-inplace): CopyValidElementsVec + FillExpansion + + tfillpad requires src.shape == dst.shape (same physical size). + If dst.valid > src.valid, fill the expansion regions. + """ + dtype = dst.element_type + src_rows, _ = src.shape + src_valid_rows, src_valid_cols = src.valid_shape + dst_rows, dst_cols = dst.shape + dst_valid_rows, dst_valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + aligned_col = (src_valid_cols // lanes) * lanes + has_tail = src_valid_cols > aligned_col + has_valid_expansion = (src_valid_cols < dst_valid_cols) or (src_valid_rows < dst_valid_rows) + + # PadValue handling - dtype-specific (inline to avoid external call) + if pto.constexpr(dtype == pto.f32): + if pto.constexpr(dst.pad_value == pto.PadValue.ZERO and has_valid_expansion): + fill_scalar = pto.f32(_NEG1_F32) + elif pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.f32(0.0) + elif pto.constexpr(dtype == pto.ui16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui16(0) + elif pto.constexpr(dtype == pto.si16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si16(0) + elif pto.constexpr(dtype == pto.i16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i16(0) + elif pto.constexpr(dtype == pto.ui32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui32(0) + elif pto.constexpr(dtype == pto.si32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si32(0) + elif pto.constexpr(dtype == pto.i32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i32(0) + elif pto.constexpr(dtype == pto.ui8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui8(0) + elif pto.constexpr(dtype == pto.si8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si8(0) + elif pto.constexpr(dtype == pto.i8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i8(0) + + # Phase 1: Copy aligned valid blocks + for row in range(0, src_valid_rows, 1): + remained = aligned_col + for col in range(0, aligned_col, lanes): + mask, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, col:]) + pto.vsts(data, dst[row, col:], mask) + + # Phase 2: Fill cols from aligned_col to dst_cols-1 + if pto.constexpr(aligned_col < dst_cols): + for row in range(0, src_valid_rows, 1): + remained = dst_cols - aligned_col + for col in range(aligned_col, dst_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + # Phase 3: Copy tail valid lanes + if pto.constexpr(has_tail): + for row in range(0, src_valid_rows, 1): + remained = src_valid_cols - aligned_col + mask_copy, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, aligned_col:]) + pto.vsts(data, dst[row, aligned_col:], mask_copy) + + # Phase 4: Fill row expansion + if pto.constexpr(src_rows < dst_rows): + for row in range(src_rows, dst_rows, 1): + remained = dst_cols + for col in range(0, dst_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_texpand.pto b/test/basic/expand_tile_op_tilelang_texpand.pto new file mode 100644 index 000000000..9f76692af --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_texpand.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.texpands via the default TileLang Python DSL template +// lib/TileOps/texpands_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.texpands should be lowered to vector-style VPTO IR. +// CHECK: func.func @TEXPANDS +// CHECK-NOT: pto.texpands ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vdup +// CHECK: pto.vsts + +module { + func.func @TEXPANDS() { + %scalar = arith.constant 1.0 : f32 + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tfillpad.pto b/test/basic/expand_tile_op_tilelang_tfillpad.pto new file mode 100644 index 000000000..d58b37456 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tfillpad.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tfillpad via the default TileLang Python DSL template +// lib/TileOps/tfillpad_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tfillpad should be lowered to vector-style VPTO IR. +// CHECK: func.func @TFILLPAD +// CHECK-NOT: pto.tfillpad ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsts +// CHECK: pto.vdup +// Note: vstus is not supported in TileLang DSL v1, so padding uses vsts instead + +module { + func.func @TFILLPAD() { + // Source Tile: valid region 8x48, total capacity 16x64 + %src = pto.alloc_tile + : !pto.tile_buf + // Destination Tile: same size as source, valid region also 8x48 + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tfillpad_expand.pto b/test/basic/expand_tile_op_tilelang_tfillpad_expand.pto new file mode 100644 index 000000000..953a6342c --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tfillpad_expand.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tfillpad_expand via the default TileLang Python DSL template +// lib/TileOps/tfillpad_expand_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tfillpad_expand should be lowered to vector-style VPTO IR. +// CHECK: func.func @TFILLPAD_EXPAND +// CHECK-NOT: pto.tfillpad_expand ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsts +// CHECK: pto.vdup +// CHECK: pto.vstus + +module { + func.func @TFILLPAD_EXPAND() { + // 源 Tile: 较小尺寸 8x32 + %src = pto.alloc_tile + : !pto.tile_buf + // 目标 Tile: 较大尺寸 16x64 + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tfillpad_expand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/basic/expand_tile_op_tilelang_tfillpad_inplace.pto b/test/basic/expand_tile_op_tilelang_tfillpad_inplace.pto new file mode 100644 index 000000000..931b36490 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tfillpad_inplace.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tfillpad (inplace mode) via the default TileLang Python DSL template +// lib/TileOps/tfillpad_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tfillpad (inplace) should be lowered to vector-style VPTO IR. +// CHECK: func.func @TFILLPAD_INPLACE +// CHECK-NOT: pto.tfillpad ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK-NOT: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vstus +// CHECK: pto.vsts + +module { + func.func @TFILLPAD_INPLACE() { + // 原地操作:src 和 dst 是同一个 Tile + // 有效区域 8x48,总容量 16x64 + %tile = pto.alloc_tile + : !pto.tile_buf + + // src 和 dst 相同,表示原地填充 padding + pto.tfillpad ins(%tile : !pto.tile_buf) + outs(%tile : !pto.tile_buf) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 54e5038cc..0b5c2dc9a 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -155,6 +155,10 @@ set(ALL_TESTCASES trowprod trsqrt tsqrt + texpands + tfillpad + tfillpad_inplace + tfillpad_expand tadds tands tdivs diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/texpands/CMakeLists.txt new file mode 100644 index 000000000..3b48410cc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(texpands) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/texpands/cases.py new file mode 100644 index 000000000..3beb7daef --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/cases.py @@ -0,0 +1,124 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for texpands ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - scalar: the scalar value to broadcast to the tile. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # ========== float32 cases ========== + # Full valid shape cases + { + "name": "f32_16x64_scalar5", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "scalar": 5.0, + "eps": 1e-6, + }, + { + "name": "f32_32x32_scalar3", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "scalar": 3.0, + "eps": 1e-6, + }, + { + "name": "f32_64x64_scalar2", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "scalar": 2.0, + "eps": 1e-6, + }, + # Partial valid shape cases + { + "name": "f32_16x64_partial", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (12, 48), + "scalar": 7.0, + "eps": 1e-6, + }, + { + "name": "f32_64x64_valid_60x60", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (60, 60), + "scalar": 42.0, + "eps": 1e-6, + }, + + # ========== int32 cases ========== + { + "name": "i32_64x64_scalar100", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "scalar": 100, + "eps": 0, # exact match for integers + }, + { + "name": "i32_64x64_valid_60x60", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (60, 60), + "scalar": 99, + "eps": 0, + }, + + # ========== half (fp16) cases ========== + { + "name": "f16_64x64_scalar1_5", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "scalar": 1.5, + "eps": 1e-3, # fp16 has lower precision + }, + { + "name": "f16_2x4096_valid_1x3600", + "dtype": np.float16, + "shape": (2, 4096), + "valid_shape": (1, 3600), + "scalar": 2.5, + "eps": 1e-3, + }, + + # ========== int16 cases ========== + { + "name": "i16_64x64_scalar50", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "scalar": 50, + "eps": 0, + }, + { + "name": "i16_20x512_valid_16x200", + "dtype": np.int16, + "shape": (20, 512), + "valid_shape": (16, 200), + "scalar": 25, + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/texpands/compare.py new file mode 100644 index 000000000..db0cdf826 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/compare.py @@ -0,0 +1,78 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare output against golden for texpands test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + eps = case["eps"] + + vr, vc = valid_shape + + # Load golden and output + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # For integer types, eps=0 means exact match + # For float types, use np.allclose with eps + if eps == 0: + # Integer comparison - exact match + if not np.array_equal(golden[:vr, :vc], output[:vr, :vc]): + diff = golden[:vr, :vc] - output[:vr, :vc] + idx = int(np.argmax(np.abs(diff))) + print(f"[ERROR] {case['name']}: Mismatch at idx={idx} (golden={golden.flat[idx]}, output={output.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + else: + # Float comparison - use allclose + # Convert to float64 for comparison (fp16 precision issues) + g = golden[:vr, :vc].astype(np.float64, copy=False) + o = output[:vr, :vc].astype(np.float64, copy=False) + + if g.shape != o.shape: + print(f"[ERROR] {case['name']}: Shape mismatch: golden {g.shape} vs output {o.shape}") + all_passed = False + continue + + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] {case['name']}: Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at idx={idx} (golden={g.flat[idx]}, output={o.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/texpands/gen_data.py new file mode 100644 index 000000000..b2dd3cac2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/gen_data.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate golden data for texpands test cases.""" + +import os +import numpy as np + +from cases import CASES + + +def setup_case_rng(case): + """Set a per-case deterministic random seed.""" + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry.""" + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + scalar = case["scalar"] + + # Convert scalar to the correct dtype + scalar_val = dtype(scalar) + + # Generate golden: fill valid_shape region with scalar value + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = scalar_val + + save_case_data(case["name"], {"golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} scalar={scalar} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/texpands/launch.cpp new file mode 100644 index 000000000..35ffa4f83 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/launch.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ========== float32 kernels ========== + +extern "C" __global__ AICORE void TEXPANDS_f32_16x64_scalar5(__gm__ float *dst); +extern "C" __global__ AICORE void TEXPANDS_f32_32x32_scalar3(__gm__ float *dst); +extern "C" __global__ AICORE void TEXPANDS_f32_64x64_scalar2(__gm__ float *dst); +extern "C" __global__ AICORE void TEXPANDS_f32_16x64_partial(__gm__ float *dst); +extern "C" __global__ AICORE void TEXPANDS_f32_64x64_valid_60x60(__gm__ float *dst); + +void LaunchTEXPANDS_f32_16x64_scalar5(float *dst, void *stream) { + TEXPANDS_f32_16x64_scalar5<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +void LaunchTEXPANDS_f32_32x32_scalar3(float *dst, void *stream) { + TEXPANDS_f32_32x32_scalar3<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +void LaunchTEXPANDS_f32_64x64_scalar2(float *dst, void *stream) { + TEXPANDS_f32_64x64_scalar2<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +void LaunchTEXPANDS_f32_16x64_partial(float *dst, void *stream) { + TEXPANDS_f32_16x64_partial<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +void LaunchTEXPANDS_f32_64x64_valid_60x60(float *dst, void *stream) { + TEXPANDS_f32_64x64_valid_60x60<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +// ========== int32 kernels ========== + +extern "C" __global__ AICORE void TEXPANDS_i32_64x64_scalar100(__gm__ int32_t *dst); +extern "C" __global__ AICORE void TEXPANDS_i32_64x64_valid_60x60(__gm__ int32_t *dst); + +void LaunchTEXPANDS_i32_64x64_scalar100(int32_t *dst, void *stream) { + TEXPANDS_i32_64x64_scalar100<<<1, nullptr, stream>>>((__gm__ int32_t *)dst); +} + +void LaunchTEXPANDS_i32_64x64_valid_60x60(int32_t *dst, void *stream) { + TEXPANDS_i32_64x64_valid_60x60<<<1, nullptr, stream>>>((__gm__ int32_t *)dst); +} + +// ========== half (fp16) kernels ========== + +extern "C" __global__ AICORE void TEXPANDS_f16_64x64_scalar1_5(__gm__ uint16_t *dst); +extern "C" __global__ AICORE void TEXPANDS_f16_2x4096_valid_1x3600(__gm__ uint16_t *dst); + +void LaunchTEXPANDS_f16_64x64_scalar1_5(uint16_t *dst, void *stream) { + TEXPANDS_f16_64x64_scalar1_5<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst); +} + +void LaunchTEXPANDS_f16_2x4096_valid_1x3600(uint16_t *dst, void *stream) { + TEXPANDS_f16_2x4096_valid_1x3600<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst); +} + +// ========== int16 kernels ========== + +extern "C" __global__ AICORE void TEXPANDS_i16_64x64_scalar50(__gm__ int16_t *dst); +extern "C" __global__ AICORE void TEXPANDS_i16_20x512_valid_16x200(__gm__ int16_t *dst); + +void LaunchTEXPANDS_i16_64x64_scalar50(int16_t *dst, void *stream) { + TEXPANDS_i16_64x64_scalar50<<<1, nullptr, stream>>>((__gm__ int16_t *)dst); +} + +void LaunchTEXPANDS_i16_20x512_valid_16x200(int16_t *dst, void *stream) { + TEXPANDS_i16_20x512_valid_16x200<<<1, nullptr, stream>>>((__gm__ int16_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/texpands/main.cpp new file mode 100644 index 000000000..20179e763 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/main.cpp @@ -0,0 +1,171 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang texpands ST — case-table driven. +// Each case launches a different kernel variant, writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTEXPANDS_f32_16x64_scalar5(float *dst, void *stream); +void LaunchTEXPANDS_f32_32x32_scalar3(float *dst, void *stream); +void LaunchTEXPANDS_f32_64x64_scalar2(float *dst, void *stream); +void LaunchTEXPANDS_f32_16x64_partial(float *dst, void *stream); +void LaunchTEXPANDS_f32_64x64_valid_60x60(float *dst, void *stream); +void LaunchTEXPANDS_i32_64x64_scalar100(int32_t *dst, void *stream); +void LaunchTEXPANDS_i32_64x64_valid_60x60(int32_t *dst, void *stream); +void LaunchTEXPANDS_f16_64x64_scalar1_5(uint16_t *dst, void *stream); +void LaunchTEXPANDS_f16_2x4096_valid_1x3600(uint16_t *dst, void *stream); +void LaunchTEXPANDS_i16_64x64_scalar50(int16_t *dst, void *stream); +void LaunchTEXPANDS_i16_20x512_valid_16x200(int16_t *dst, void *stream); + +enum class DataType { F32, I32, F16, I16 }; + +struct TestCase { + const char *name; + DataType dtype; + void (*launch)(void *, void *); // Generic launch function pointer + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +// Helper to wrap type-specific launch functions +template +void wrapLaunch(void *dst, void *stream, void (*fn)(T *, void *)) { + fn((T *)dst, stream); +} + +static const TestCase kCases[] = { + // ========== float32 cases ========== + {"f32_16x64_scalar5", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_16x64_scalar5); }, + 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32_scalar3", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_32x32_scalar3); }, + 32, 32, 32, 32, sizeof(float)}, + {"f32_64x64_scalar2", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_64x64_scalar2); }, + 64, 64, 64, 64, sizeof(float)}, + {"f32_16x64_partial", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_16x64_partial); }, + 16, 64, 12, 48, sizeof(float)}, + {"f32_64x64_valid_60x60", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_64x64_valid_60x60); }, + 64, 64, 60, 60, sizeof(float)}, + + // ========== int32 cases ========== + {"i32_64x64_scalar100", DataType::I32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_i32_64x64_scalar100); }, + 64, 64, 64, 64, sizeof(int32_t)}, + {"i32_64x64_valid_60x60", DataType::I32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_i32_64x64_valid_60x60); }, + 64, 64, 60, 60, sizeof(int32_t)}, + + // ========== half (fp16) cases ========== + {"f16_64x64_scalar1_5", DataType::F16, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f16_64x64_scalar1_5); }, + 64, 64, 64, 64, sizeof(uint16_t)}, + {"f16_2x4096_valid_1x3600", DataType::F16, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f16_2x4096_valid_1x3600); }, + 2, 4096, 1, 3600, sizeof(uint16_t)}, + + // ========== int16 cases ========== + {"i16_64x64_scalar50", DataType::I16, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_i16_64x64_scalar50); }, + 64, 64, 64, 64, sizeof(int16_t)}, + {"i16_20x512_valid_16x200", DataType::I16, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_i16_20x512_valid_16x200); }, + 20, 512, 16, 200, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *dstHost = nullptr; + void *dstDevice = nullptr; + + aclrtMallocHost(&dstHost, fileSize); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + // Launch kernel (scalar is hardcoded in .pto) + tc.launch(dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./texpands [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/texpands.pto b/test/tilelang_st/npu/a5/src/st/testcase/texpands/texpands.pto new file mode 100644 index 000000000..82a778806 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/texpands.pto @@ -0,0 +1,387 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.texpands: broadcast a scalar to a tile. +// Multiple cases with different shapes, data types, and scalar values. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm + +module { + // ========== float32 cases ========== + + // Case: f32 16x64, scalar=5.0 (full valid shape) + func.func @TEXPANDS_f32_16x64_scalar5(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %scalar = arith.constant 5.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case: f32 32x32, scalar=3.0 (full valid shape) + func.func @TEXPANDS_f32_32x32_scalar3(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %scalar = arith.constant 3.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case: f32 64x64, scalar=2.0 (full valid shape) + func.func @TEXPANDS_f32_64x64_scalar2(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 2.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case: f32 16x64, scalar=7.0 (partial valid shape: 12x48) + func.func @TEXPANDS_f32_16x64_partial(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %scalar = arith.constant 7.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c12, %c48] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x12x48xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x12x48xf32>) + return + } + + // Case: f32 64x64, valid 60x60, scalar=42.0 + func.func @TEXPANDS_f32_64x64_valid_60x60(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 42.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + return + } + + // ========== int32 cases ========== + + // Case: i32 64x64, scalar=100 (full valid shape) + func.func @TEXPANDS_i32_64x64_scalar100(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 100 : i32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : i32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case: i32 64x64, valid 60x60, scalar=99 + func.func @TEXPANDS_i32_64x64_valid_60x60(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 99 : i32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : i32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + return + } + + // ========== half (fp16) cases ========== + + // Case: f16 64x64, scalar=1.5 (full valid shape) + func.func @TEXPANDS_f16_64x64_scalar1_5(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 1.5 : f16 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f16) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } + + // Case: f16 2x4096, valid 1x3600, scalar=2.5 (wide column shape) + func.func @TEXPANDS_f16_2x4096_valid_1x3600(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index // 2*4096 + %scalar = arith.constant 2.5 : f16 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c4096], + strides = [%c8192, %c8192, %c8192, %c4096, %c1] + : !pto.tensor_view<1x1x1x2x4096xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x2x4096xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f16) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + return + } + + // ========== int16 cases ========== + + // Case: i16 64x64, scalar=50 (full valid shape) + func.func @TEXPANDS_i16_64x64_scalar50(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 50 : i16 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : i16) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case: i16 20x512, valid 16x200, scalar=25 + func.func @TEXPANDS_i16_20x512_valid_16x200(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c20 = arith.constant 20 : index + %c200 = arith.constant 200 : index + %c512 = arith.constant 512 : index + %c3200 = arith.constant 3200 : index // 16*200 + %c10240 = arith.constant 10240 : index // 20*512 + %scalar = arith.constant 25 : i16 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c20, %c512], + strides = [%c10240, %c10240, %c10240, %c512, %c1] + : !pto.tensor_view<1x1x1x20x512xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x20x512xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : i16) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/CMakeLists.txt new file mode 100644 index 000000000..0bffcc7fd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfillpad) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/cases.py new file mode 100644 index 000000000..4507ef6be --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/cases.py @@ -0,0 +1,193 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfillpad ST test cases. + +Matches C++ reference test cases exactly (Cases 1-13). + +PadValue semantics: + - Max: +inf for float, MAX for integers + - Min: -inf for float, MIN for integers + - Null: no fill (keep original value) + - Custom(-1.0f): -1.0f for float, -1 for integers + +Each case defines: + - name: case identifier (must match main.cpp kCases[] and launch.cpp) + - dtype: numpy dtype + - shape: (rows, cols) — dst tile physical dimensions + - valid_shape: (valid_rows, valid_cols) — dst valid region (output size) + - src_shape: (rows, cols) — src tile physical dimensions (optional, default=dst) + - src_valid_shape: (valid_rows, valid_cols) — src valid region (optional, default=dst_valid) + - load_padval: PadValue for TLOAD (fill invalid columns in src tile) + - fill_padval: PadValue for TFILLPAD (fill expansion region in dst) + - eps: tolerance for numpy.allclose +""" + +import numpy as np + +# PadValue enum values matching C++ definition +PADVAL_MAX = "Max" # +inf for float, MAX for integers +PADVAL_MIN = "Min" # -inf for float, MIN for integers +PADVAL_NULL = "Null" # no fill (keep original value, treated as 0 in golden) +PADVAL_ZERO = "Zero" # zero fill +PADVAL_NEG1 = "Neg1" # -1.0f for float, -1 for integers (Custom) + +CASES = [ + # ========== Case 1: float, 128x127 -> 128x128, PadMax ========== + # C++: runTFILLPAD + + { + "name": "f32_128x128_pad_128x127", + "dtype": np.float32, + "shape": (128, 128), # dst tile physical + "valid_shape": (128, 128), # dst valid (output size) + "src_shape": (128, 127), # src tile physical (127 cols, < dst 128) + "src_valid_shape": (128, 127), # src valid = full src + "load_padval": PADVAL_MAX, # TLOAD: fill col 127 with +inf + "fill_padval": PADVAL_MAX, # TFILLPAD: no expansion needed + "eps": 1e-6, + }, + + # ========== Case 2: float, 128x127 -> 128x160, PadMax ========== + # C++: runTFILLPAD + + { + "name": "f32_128x160_pad_128x127", + "dtype": np.float32, + "shape": (128, 160), # dst tile physical + "valid_shape": (128, 160), # dst valid (output size) + "src_shape": (128, 127), # src tile physical + "src_valid_shape": (128, 127), # src valid + "load_padval": PADVAL_MAX, # TLOAD: fill col 127 with +inf + "fill_padval": PADVAL_MAX, # TFILLPAD: fill cols 128-159 with +inf + "eps": 1e-6, + }, + + # ========== Case 3: float, 128x127 -> 128x160, LoadPad=Min, FillPad=Max ========== + # C++: runTFILLPAD + + { + "name": "f32_128x160_pad_128x127_v2", + "dtype": np.float32, + "shape": (128, 160), # dst tile physical + "valid_shape": (128, 160), # dst valid (output size) + "src_shape": (128, 127), # src tile physical + "src_valid_shape": (128, 127), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill col 127 with -inf + "fill_padval": PADVAL_MAX, # TFILLPAD: fill cols 128-159 with +inf + "eps": 1e-6, + }, + + # ========== Case 4: float, 260x7 -> 260x16, PadMin/Max ========== + # C++: runTFILLPAD + + { + "name": "f32_260x16_pad_260x7", + "dtype": np.float32, + "shape": (260, 16), # dst tile physical + "valid_shape": (260, 16), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-15 with -inf (32B aligned tile) + "fill_padval": PADVAL_MAX, # TFILLPAD: no expansion needed + "eps": 1e-6, + }, + + # ========== Case 6: uint16, 260x7 -> 260x32, PadMin/Max ========== + # C++: runTFILLPAD + + { + "name": "u16_260x32_pad_260x7", + "dtype": np.uint16, + "shape": (260, 32), # dst tile physical + "valid_shape": (260, 32), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-31 with MIN (uint16 0) + "fill_padval": PADVAL_MAX, # TFILLPAD: fill cols 8-31 with MAX (uint16 65535) + "eps": 0, + }, + + # ========== Case 7: int8, 260x7 -> 260x64, PadMin/Max ========== + # C++: runTFILLPAD + + { + "name": "s8_260x64_pad_260x7", + "dtype": np.int8, + "shape": (260, 64), # dst tile physical + "valid_shape": (260, 64), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-63 with MIN (int8 -128) + "fill_padval": PADVAL_MAX, # TFILLPAD: no expansion needed + "eps": 0, + }, + + # ========== Case 10: int16, 260x7 -> 260x32, PadMin/Min ========== + # C++: runTFILLPAD + + { + "name": "s16_260x32_pad_260x7", + "dtype": np.int16, + "shape": (260, 32), # dst tile physical + "valid_shape": (260, 32), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-31 with MIN (int16 -32768) + "fill_padval": PADVAL_MIN, # TFILLPAD: no expansion needed + "eps": 0, + }, + + # ========== Case 11: int32, 260x7 -> 260x32, PadMin/Min ========== + # C++: runTFILLPAD + + { + "name": "s32_260x32_pad_260x7", + "dtype": np.int32, + "shape": (260, 32), # dst tile physical + "valid_shape": (260, 32), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-31 with MIN (int32 -2147483648) + "fill_padval": PADVAL_MIN, # TFILLPAD: no expansion needed + "eps": 0, + }, + + # ========== Case 12: float, 128x64 -> 128x128, LoadPad=Null, FillPad=Neg1 ========== + # C++: runTFILLPAD + + { + "name": "f32_128x128_pad_128x64_neg1", + "dtype": np.float32, + "shape": (128, 128), # dst tile physical + "valid_shape": (128, 128), # dst valid = full dst (output size) + "src_shape": (128, 64), # src tile physical (64 cols) + "src_valid_shape": (128, 64), # src valid = full src + "load_padval": PADVAL_NULL, # TLOAD: no fill (src cols 64 aligned to 32B) + "fill_padval": PADVAL_NEG1, # TFILLPAD: fill cols 64-127 with -1.0f + "eps": 1e-6, + }, + + # ========== Case 13: float, 128x127 -> 128x160, LoadPad=Neg1, FillPad=Neg1 ========== + # C++: runTFILLPAD + + { + "name": "f32_128x160_pad_128x127_neg1", + "dtype": np.float32, + "shape": (128, 160), # dst tile physical + "valid_shape": (128, 160), # dst valid = full dst (output size) - CHANGED! + "src_shape": (128, 127), # src tile physical (127 cols) + "src_valid_shape": (128, 127), # src valid = full src + "load_padval": PADVAL_NEG1, # TLOAD: fill col 127 with -1.0f (127 not 32B aligned) + "fill_padval": PADVAL_NEG1, # TFILLPAD: fill cols 128-159 with -1.0f + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/compare.py new file mode 100644 index 000000000..1a023b000 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/compare.py @@ -0,0 +1,81 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare output against golden for tfillpad test cases. + +For tfillpad: + - Input: full tile shape (rows x cols) + - Output: only valid region (valid_rows x valid_cols) + - Golden: valid region only +""" + +import os +import sys +import numpy as np + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + valid_shape = case["valid_shape"] + eps = case["eps"] + + # Load golden and output (both stored with valid_shape) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(valid_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(valid_shape) + + # For integer types, eps=0 means exact match + # For float types, use np.allclose with eps + if eps == 0: + # Integer comparison - exact match + if not np.array_equal(golden, output): + diff = golden - output + idx = int(np.argmax(np.abs(diff))) + print(f"[ERROR] {case['name']}: Mismatch at idx={idx} (golden={golden.flat[idx]}, output={output.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + else: + # Float comparison - use allclose + # Convert to float64 for comparison (fp16 precision issues) + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + + if g.shape != o.shape: + print(f"[ERROR] {case['name']}: Shape mismatch: golden {g.shape} vs output {o.shape}") + all_passed = False + continue + + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] {case['name']}: Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at idx={idx} (golden={g.flat[idx]}, output={o.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/gen_data.py new file mode 100644 index 000000000..80b1a15a1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/gen_data.py @@ -0,0 +1,117 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate golden data for tfillpad test cases. + +TFILLPAD semantics: + 1. Copy src.valid_shape data to dst + 2. Fill cols from src.valid_cols to dst.cols with FillPadVal + 3. Fill rows from src.rows to dst.rows with FillPadVal + +Note: LoadPadVal is used by TLOAD only, TFILLPAD uses FillPadVal for expansion. +""" + +import os +import numpy as np +import struct + +from cases import CASES, PADVAL_MAX, PADVAL_MIN, PADVAL_NULL, PADVAL_ZERO, PADVAL_NEG1 + + +# FLT_MAX and -FLT_MAX (matching DSL PadValue.MAX/MIN) +def _float32_from_bits(bits: int) -> float: + return struct.unpack(">f", bits.to_bytes(4, byteorder="big", signed=False))[0] + +_FLT_MAX = _float32_from_bits(0x7F7FFFFF) # ~3.4028235e+38 +_FLT_MIN = _float32_from_bits(0xFF7FFFFF) # ~-3.4028235e+38 + + +def get_pad_value(dtype, padval_name): + """Get the actual pad value for a dtype based on PadValue enum. + + Matches DSL PadValue.materialize_scalar behavior: + - MAX: FLT_MAX for float (not inf), max for integers + - MIN: -FLT_MAX for float (not -inf), min for integers + - NEG1: -1.0 for float, -1 for integers + - NULL/ZERO: 0 + """ + if padval_name == PADVAL_MAX: + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MAX) + else: + return np.iinfo(dtype).max + elif padval_name == PADVAL_MIN: + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MIN) + else: + return np.iinfo(dtype).min + elif padval_name == PADVAL_NEG1: + if np.issubdtype(dtype, np.floating): + return np.float32(-1.0) + else: + return dtype(-1) + else: # PADVAL_NULL or PADVAL_ZERO + return dtype(0) + + +def setup_case_rng(case): + """Set a per-case deterministic random seed.""" + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry.""" + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_shape = case["shape"] + dst_valid = case["valid_shape"] + src_shape = case.get("src_shape", dst_shape) + src_valid = case.get("src_valid_shape", dst_valid) + fill_padval = case.get("fill_padval", PADVAL_ZERO) + + # Input: generated with src_shape (matching C++ input size) + src_vr, src_vc = src_valid + input_data = np.zeros(src_shape, dtype=dtype) + input_data[:src_vr, :src_vc] = np.random.randint(1, 10, size=(src_vr, src_vc)).astype(dtype) + + # Golden: generated with dst_valid (output size) + dst_vr, dst_vc = dst_valid + golden = np.zeros(dst_valid, dtype=dtype) + + # Step 1: Copy src valid data to dst + copy_vr = min(src_vr, dst_vr) + copy_vc = min(src_vc, dst_vc) + golden[:copy_vr, :copy_vc] = input_data[:copy_vr, :copy_vc] + + # Step 2: TFILLPAD fills cols from src_valid_cols to dst_cols with FillPadVal + # (NOT LoadPadVal! TFILLPAD uses FillPadVal for expansion) + if dst_vc > src_vc: + fill_val = get_pad_value(dtype, fill_padval) + golden[:dst_vr, src_vc:dst_vc] = fill_val + + # Step 3: TFILLPAD fills rows from src_rows to dst_rows with FillPadVal + if dst_shape[0] > src_shape[0]: + fill_val = get_pad_value(dtype, fill_padval) + expand_rows_start = src_shape[0] + expand_rows_end = dst_vr + if expand_rows_end > expand_rows_start: + golden[expand_rows_start:expand_rows_end, :dst_vc] = fill_val + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} input={src_shape} golden={dst_valid} " + f"fill_pad={fill_padval} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/launch.cpp new file mode 100644 index 000000000..9f1583dc0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/launch.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ========== Case 1: float, 128x128, valid=128x127 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x128_pad_128x127(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x128_pad_128x127(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x128_pad_128x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 2: float, 128x160, valid=128x127 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x160_pad_128x127(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x160_pad_128x127(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x160_pad_128x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 3: float, 128x160, valid=128x127 (different PadVal) ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x160_pad_128x127_v2(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x160_pad_128x127_v2(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x160_pad_128x127_v2<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 4: float, 260x16, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_260x16_pad_260x7(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_260x16_pad_260x7(float *src, float *dst, void *stream) { + TFILLPAD_f32_260x16_pad_260x7<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 6: uint16, 260x32, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_u16_260x32_pad_260x7(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTFILLPAD_u16_260x32_pad_260x7(uint16_t *src, uint16_t *dst, void *stream) { + TFILLPAD_u16_260x32_pad_260x7<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// ========== Case 7: int8, 260x64, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_s8_260x64_pad_260x7(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTFILLPAD_s8_260x64_pad_260x7(int8_t *src, int8_t *dst, void *stream) { + TFILLPAD_s8_260x64_pad_260x7<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} + +// ========== Case 10: int16, 260x32, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_s16_260x32_pad_260x7(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTFILLPAD_s16_260x32_pad_260x7(int16_t *src, int16_t *dst, void *stream) { + TFILLPAD_s16_260x32_pad_260x7<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +// ========== Case 11: int32, 260x32, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_s32_260x32_pad_260x7(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTFILLPAD_s32_260x32_pad_260x7(int32_t *src, int32_t *dst, void *stream) { + TFILLPAD_s32_260x32_pad_260x7<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// ========== Case 12: float, src=128x64, dst=128x128, PadCustomNeg1 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x128_pad_128x64_neg1(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x128_pad_128x64_neg1(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x128_pad_128x64_neg1<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 13: float, src=128x127, dst=128x160, PadCustomNeg1 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x160_pad_128x127_neg1(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x160_pad_128x127_neg1(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x160_pad_128x127_neg1<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/main.cpp new file mode 100644 index 000000000..a3b0036d0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/main.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tfillpad ST (non-inplace mode). +// Matches C++ reference test cases: Cases 1, 2, 3, 4, 6, 7, 10, 11, 12, 13 +// Output size: dst valid region (dst tile physical shape for full output) + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTFILLPAD_f32_128x128_pad_128x127(float *src, float *dst, void *stream); +void LaunchTFILLPAD_f32_128x160_pad_128x127(float *src, float *dst, void *stream); +void LaunchTFILLPAD_f32_128x160_pad_128x127_v2(float *src, float *dst, void *stream); +void LaunchTFILLPAD_f32_260x16_pad_260x7(float *src, float *dst, void *stream); +void LaunchTFILLPAD_u16_260x32_pad_260x7(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTFILLPAD_s8_260x64_pad_260x7(int8_t *src, int8_t *dst, void *stream); +void LaunchTFILLPAD_s16_260x32_pad_260x7(int16_t *src, int16_t *dst, void *stream); +void LaunchTFILLPAD_s32_260x32_pad_260x7(int32_t *src, int32_t *dst, void *stream); +void LaunchTFILLPAD_f32_128x128_pad_128x64_neg1(float *src, float *dst, void *stream); +void LaunchTFILLPAD_f32_128x160_pad_128x127_neg1(float *src, float *dst, void *stream); + +enum class DataType { F32, U16, S8, S16, S32 }; + +struct TestCase { + const char *name; + DataType dtype; + void (*launch)(void *, void *, void *); + size_t rows; // dst tile rows (physical) + size_t cols; // dst tile cols (physical) + size_t validRows; // dst valid rows (output rows) + size_t validCols; // dst valid cols (output cols) - CHANGED: now = dst physical cols for full output + size_t srcRows; // src tensor rows (0 means same as rows) + size_t srcCols; // src tensor cols (0 means same as cols) + size_t elemSize; +}; + +template +void wrapLaunch(void *src, void *dst, void *stream, void (*fn)(T *, T *, void *)) { + fn((T *)src, (T *)dst, stream); +} + +static const TestCase kCases[] = { + // Case 1: float, src=128x127, dst=128x128, LoadPad=Max, FillPad=Max + // Output: 128x128 (full dst tile) + {"f32_128x128_pad_128x127", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x128_pad_128x127); }, + 128, 128, 128, 128, 128, 127, sizeof(float)}, // CHANGED: validCols=128, srcCols=127 + + // Case 2: float, src=128x127, dst=128x160, LoadPad=Max, FillPad=Max + // Output: 128x160 (full dst tile) + {"f32_128x160_pad_128x127", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x160_pad_128x127); }, + 128, 160, 128, 160, 128, 127, sizeof(float)}, // CHANGED: validCols=160, srcCols=127 + + // Case 3: float, src=128x127, dst=128x160, LoadPad=Min, FillPad=Max + // Output: 128x160 (full dst tile) + {"f32_128x160_pad_128x127_v2", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x160_pad_128x127_v2); }, + 128, 160, 128, 160, 128, 127, sizeof(float)}, // CHANGED: validCols=160, srcCols=127 + + // Case 4: float, src=260x7, dst=260x16, LoadPad=Min, FillPad=Max + // Output: 260x16 (full dst tile) + {"f32_260x16_pad_260x7", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_260x16_pad_260x7); }, + 260, 16, 260, 16, 260, 7, sizeof(float)}, // CHANGED: validCols=16, srcCols=7 + + // Case 6: uint16, src=260x7, dst=260x32, LoadPad=Min, FillPad=Max + // Output: 260x32 (full dst tile) + {"u16_260x32_pad_260x7", DataType::U16, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_u16_260x32_pad_260x7); }, + 260, 32, 260, 32, 260, 7, sizeof(uint16_t)}, + + // Case 7: int8, src=260x7, dst=260x64, LoadPad=Min, FillPad=Max + // Output: 260x64 (full dst tile) + {"s8_260x64_pad_260x7", DataType::S8, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_s8_260x64_pad_260x7); }, + 260, 64, 260, 64, 260, 7, sizeof(int8_t)}, // CHANGED: validCols=64, srcCols=7 + + // Case 10: int16, src=260x7, dst=260x32, LoadPad=Min, FillPad=Min + // Output: 260x32 (full dst tile) + {"s16_260x32_pad_260x7", DataType::S16, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_s16_260x32_pad_260x7); }, + 260, 32, 260, 32, 260, 7, sizeof(int16_t)}, // CHANGED: validCols=32, srcCols=7 + + // Case 11: int32, src=260x7, dst=260x32, LoadPad=Min, FillPad=Min + // Output: 260x32 (full dst tile) + {"s32_260x32_pad_260x7", DataType::S32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_s32_260x32_pad_260x7); }, + 260, 32, 260, 32, 260, 7, sizeof(int32_t)}, // CHANGED: validCols=32, srcCols=7 + + // Case 12: float, src=128x64, dst=128x128, LoadPad=Null, FillPad=Custom(-1.0f) + // Output: 128x128 (full dst tile) + {"f32_128x128_pad_128x64_neg1", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x128_pad_128x64_neg1); }, + 128, 128, 128, 128, 128, 64, sizeof(float)}, // correct: validCols=128, srcCols=64 + + // Case 13: float, src=128x127, dst=128x160, LoadPad=Custom(-1.0f), FillPad=Custom(-1.0f) + // Output: 128x160 (full dst tile) - CHANGED from 127 to 160 + {"f32_128x160_pad_128x127_neg1", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x160_pad_128x127_neg1); }, + 128, 160, 128, 160, 128, 127, sizeof(float)}, // CHANGED: validCols=160 +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcRows = (tc.srcRows > 0) ? tc.srcRows : tc.rows; + size_t srcCols = (tc.srcCols > 0) ? tc.srcCols : tc.cols; + size_t inputElemCount = srcRows * srcCols; + size_t outputElemCount = tc.validRows * tc.validCols; + size_t inputFileSize = inputElemCount * tc.elemSize; + size_t outputFileSize = outputElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, output=%zux%zu) ===\n", + tc.name, srcRows, srcCols, tc.rows, tc.cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, inputFileSize); + aclrtMallocHost(&dstHost, outputFileSize); + + aclrtMalloc(&srcDevice, inputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, outputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), inputFileSize, srcHost, inputFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, inputFileSize, srcHost, inputFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, outputFileSize, dstDevice, outputFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, outputFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/tfillpad.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/tfillpad.pto new file mode 100644 index 000000000..cb5e5f4dc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/tfillpad.pto @@ -0,0 +1,576 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tfillpad (non-inplace mode). +// Matches C++ reference test cases: Cases 1, 2, 3, 4, 6, 7, 10, 11, 12, 13 +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// +// PadValue encoding: 0=Null, 1=Zero, 2=Max, 3=Min +// Cases 12/13 use Custom(-1.0f) which cannot be encoded in PTO IR, +// template uses shape-based detection for these cases. +// +// C++ template params: shape3=src_rows, shape4=src_cols, kTRows_=dst_rows, kTCols_=dst_cols + +module { + // ========== Case 1: float, src=128x127, dst=128x128, LoadPad=Max, FillPad=Max ========== + + func.func @TFILLPAD_f32_128x128_pad_128x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c16256 = arith.constant 16256 : index // 128*127 (src size) + %c16384 = arith.constant 16384 : index // 128*128 (dst size) + + // Src tensor_view: 128x127 (matching C++ shape4=127) + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c127], + strides = [%c16256, %c16256, %c16256, %c127, %c1] + : !pto.tensor_view<1x1x1x128x127xf32> + // Dst tensor_view: 128x128 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c127] + : !pto.tensor_view<1x1x1x128x127xf32> -> !pto.partition_tensor_view<1x1x1x128x127xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + + // Src tile: LoadPadVal=Max (pad=2), src physical=128x127, v_col=127 + // C++: shape4_aligned = align_to_32B(127, float) = 128, so tile cols=128 for alignment + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=128x128, v_col=128 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + return + } + + // ========== Case 2: float, src=128x127, dst=128x160, LoadPad=Max, FillPad=Max ========== + + func.func @TFILLPAD_f32_128x160_pad_128x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c160 = arith.constant 160 : index + %c16256 = arith.constant 16256 : index // 128*127 (src size) + %c20480 = arith.constant 20480 : index // 128*160 (dst size) + + // Src tensor_view: 128x127 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c127], + strides = [%c16256, %c16256, %c16256, %c127, %c1] + : !pto.tensor_view<1x1x1x128x127xf32> + // Dst tensor_view: 128x160 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c160], + strides = [%c20480, %c20480, %c20480, %c160, %c1] + : !pto.tensor_view<1x1x1x128x160xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c127] + : !pto.tensor_view<1x1x1x128x127xf32> -> !pto.partition_tensor_view<1x1x1x128x127xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c160] + : !pto.tensor_view<1x1x1x128x160xf32> -> !pto.partition_tensor_view<1x1x1x128x160xf32> + + // Src tile: LoadPadVal=Max (pad=2), shape4_aligned=128, v_col=127 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=128x160, v_col=160 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x160xf32>) + return + } + + // ========== Case 3: float, src=128x127, dst=128x160, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_f32_128x160_pad_128x127_v2(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c160 = arith.constant 160 : index + %c16256 = arith.constant 16256 : index // 128*127 (src size) + %c20480 = arith.constant 20480 : index // 128*160 (dst size) + + // Src tensor_view: 128x127 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c127], + strides = [%c16256, %c16256, %c16256, %c127, %c1] + : !pto.tensor_view<1x1x1x128x127xf32> + // Dst tensor_view: 128x160 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c160], + strides = [%c20480, %c20480, %c20480, %c160, %c1] + : !pto.tensor_view<1x1x1x128x160xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c127] + : !pto.tensor_view<1x1x1x128x127xf32> -> !pto.partition_tensor_view<1x1x1x128x127xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c160] + : !pto.tensor_view<1x1x1x128x160xf32> -> !pto.partition_tensor_view<1x1x1x128x160xf32> + + // Src tile: LoadPadVal=Min (pad=3), shape4_aligned=128, v_col=127 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=128x160, v_col=160 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x160xf32>) + return + } + + // ========== Case 4: float, src=260x7, dst=260x16, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_f32_260x16_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c16 = arith.constant 16 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c4160 = arith.constant 4160 : index // 260*16 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xf32> + // Dst tensor_view: 260x16 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c16], + strides = [%c4160, %c4160, %c4160, %c16, %c1] + : !pto.tensor_view<1x1x1x260x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xf32> -> !pto.partition_tensor_view<1x1x1x260x7xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c16] + : !pto.tensor_view<1x1x1x260x16xf32> -> !pto.partition_tensor_view<1x1x1x260x16xf32> + + // Src tile: LoadPadVal=Min (pad=3), shape4_aligned=8 (align 7 to 8 for f32), v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x16, v_col=16 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x16xf32>) + return + } + + // ========== Case 6: uint16, src=260x7, dst=260x32, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_u16_260x32_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c32 = arith.constant 32 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c8320 = arith.constant 8320 : index // 260*32 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xui16> + // Dst tensor_view: 260x32 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c32], + strides = [%c8320, %c8320, %c8320, %c32, %c1] + : !pto.tensor_view<1x1x1x260x32xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xui16> -> !pto.partition_tensor_view<1x1x1x260x7xui16> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c32] + : !pto.tensor_view<1x1x1x260x32xui16> -> !pto.partition_tensor_view<1x1x1x260x32xui16> + + // Src tile: LoadPadVal=Min (pad=3), shape4_aligned=16, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x32, v_col=32 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xui16>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x32xui16>) + return + } + + // ========== Case 7: int8, src=260x7, dst=260x64, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_s8_260x64_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c64 = arith.constant 64 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c16640 = arith.constant 16640 : index // 260*64 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xi8> + // Dst tensor_view: 260x64 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c64], + strides = [%c16640, %c16640, %c16640, %c64, %c1] + : !pto.tensor_view<1x1x1x260x64xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xi8> -> !pto.partition_tensor_view<1x1x1x260x7xi8> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c64] + : !pto.tensor_view<1x1x1x260x64xi8> -> !pto.partition_tensor_view<1x1x1x260x64xi8> + + // Src tile: LoadPadVal=Min (pad=3), shape4_aligned=32 (align 7 to 32 for i8), v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x64, v_col=64 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xi8>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x64xi8>) + return + } + + // ========== Case 10: int16, src=260x7, dst=260x32, LoadPad=Min, FillPad=Min ========== + + func.func @TFILLPAD_s16_260x32_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c32 = arith.constant 32 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c8320 = arith.constant 8320 : index // 260*32 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xi16> + // Dst tensor_view: 260x32 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c32], + strides = [%c8320, %c8320, %c8320, %c32, %c1] + : !pto.tensor_view<1x1x1x260x32xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xi16> -> !pto.partition_tensor_view<1x1x1x260x7xi16> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c32] + : !pto.tensor_view<1x1x1x260x32xi16> -> !pto.partition_tensor_view<1x1x1x260x32xi16> + + // Src tile: LoadPadVal=Min (pad=3), shape4_aligned=16, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Min (pad=3), dst physical=260x32, v_col=32 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xi16>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x32xi16>) + return + } + + // ========== Case 11: int32, src=260x7, dst=260x32, LoadPad=Min, FillPad=Min ========== + + func.func @TFILLPAD_s32_260x32_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c32 = arith.constant 32 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c8320 = arith.constant 8320 : index // 260*32 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xi32> + // Dst tensor_view: 260x32 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c32], + strides = [%c8320, %c8320, %c8320, %c32, %c1] + : !pto.tensor_view<1x1x1x260x32xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xi32> -> !pto.partition_tensor_view<1x1x1x260x7xi32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c32] + : !pto.tensor_view<1x1x1x260x32xi32> -> !pto.partition_tensor_view<1x1x1x260x32xi32> + + // Src tile: LoadPadVal=Min (pad=3), shape4_aligned=8 (align 7 to 8 for i32), v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Min (pad=3), dst physical=260x32, v_col=32 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xi32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x32xi32>) + return + } + + // ========== Case 12: float, src=128x64, dst=128x128, LoadPad=Null, FillPad=Custom(-1.0f) ========== + // PTO IR cannot encode Custom PadValue, template uses shape-based detection: + // src.valid_cols < dst.valid_cols => fill expansion region with -1.0f + + func.func @TFILLPAD_f32_128x128_pad_128x64_neg1(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index // 128*64 (src size) + %c16384 = arith.constant 16384 : index // 128*128 (dst size) + + // Src tensor_view: 128x64 (C++ shape4=64) + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + + // Dst output tensor_view: 128x128 (full dst valid region) + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + + // Src tile: LoadPadVal=Null (pad=0), shape4_aligned=64 (already aligned), v_col=64 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Custom(-1.0f) - detected by template via src.v_col < dst.v_col + // Use pad=1 (Zero) as placeholder (PTO IR cannot encode Custom), template detects expansion and uses -1.0f + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + return + } + + // ========== Case 13: float, src=128x127, dst=128x160, LoadPad=Custom(-1.0f), FillPad=Custom(-1.0f) ========== + // PTO IR cannot encode Custom PadValue, template uses shape-based detection + + func.func @TFILLPAD_f32_128x160_pad_128x127_neg1(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c160 = arith.constant 160 : index + %c16256 = arith.constant 16256 : index // 128*127 (src size) + %c20480 = arith.constant 20480 : index // 128*160 (dst size) + + // Src tensor_view: 128x127 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c127], + strides = [%c16256, %c16256, %c16256, %c127, %c1] + : !pto.tensor_view<1x1x1x128x127xf32> + + // Dst output tensor_view: 128x160 (full dst output) + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c160], + strides = [%c20480, %c20480, %c20480, %c160, %c1] + : !pto.tensor_view<1x1x1x128x160xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c127] + : !pto.tensor_view<1x1x1x128x127xf32> -> !pto.partition_tensor_view<1x1x1x128x127xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c160] + : !pto.tensor_view<1x1x1x128x160xf32> -> !pto.partition_tensor_view<1x1x1x128x160xf32> + + // Src tile: LoadPadVal=Custom(-1.0f), shape4_aligned=128, v_col=127 + // Use pad=0, template will detect and fill src padding region with -1.0f + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Custom(-1.0f), dst physical=128x160, v_col=160 + // Use pad=1 (Zero) as placeholder (PTO IR cannot encode Custom), template detects expansion and uses -1.0f + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x160xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/CMakeLists.txt new file mode 100644 index 000000000..eaf9f0308 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfillpad_expand) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/cases.py new file mode 100644 index 000000000..351d7716a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/cases.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfillpad_expand ST test cases. + +Matches C++ reference test cases: Cases 8, 9 + +C++ expand mode parameters: + - shape3: src physical rows + - shape4: src physical cols + - kTRows_: dst physical rows + - kTCols_: dst physical cols + - expand=true: TFILLPAD_EXPAND copies src valid data, fills expansion with FillPadVal + +Case 8: runTFILLPAD +Case 9: runTFILLPAD + +Each case defines: + - name: case identifier + - dtype: numpy dtype + - shape: (rows, cols) — src tile physical dimensions (input size) + - valid_shape: (valid_rows, valid_cols) — src valid region + - dst_shape: (rows, cols) — dst tile physical dimensions + - dst_valid_shape: (valid_rows, valid_cols) — dst valid region (output size) + - load_padval: PadValue for TLOAD (fill invalid columns in src tile) + - fill_padval: PadValue for TFILLPAD_EXPAND (fill expansion region in dst) + - eps: tolerance for numpy.allclose +""" + +import numpy as np + +# PadValue enum values matching C++ definition +PADVAL_MAX = "Max" # FLT_MAX for float, MAX for integers +PADVAL_MIN = "Min" # -FLT_MAX for float, MIN for integers +PADVAL_NULL = "Null" # no fill +PADVAL_ZERO = "Zero" # zero fill +PADVAL_NEG1 = "Neg1" # -1.0f for float, -1 for integers (Custom) + +CASES = [ + # ========== Case 1: uint16, src=259x7, dst=260x32, expand, LoadPad=Min, FillPad=Max ========== + + { + "name": "u16_260x32_src_259x7", + "dtype": np.uint16, + "shape": (259, 7), # src physical (C++ shape3=259, shape4=7) + "valid_shape": (259, 7), # src valid region (actual data) + "dst_shape": (260, 32), # dst physical + "dst_valid_shape": (260, 32), # dst valid (output size) + "load_padval": PADVAL_MIN, # TLOAD: fill cols 7-31 with MIN (uint16 MIN=0) + "fill_padval": PADVAL_MAX, # TFILLPAD_EXPAND: fill expansion region with MAX (uint16 MAX=65535) + "eps": 0, + }, + + # ========== Case 2: int8, src=259x7, dst=260x64, expand, LoadPad=Min, FillPad=Max ========== + + { + "name": "s8_260x64_src_259x7", + "dtype": np.int8, + "shape": (259, 7), # src physical (C++ shape3=259, shape4=7) + "valid_shape": (259, 7), # src valid region (actual data) + "dst_shape": (260, 64), # dst physical + "dst_valid_shape": (260, 64), # dst valid (output size) + "load_padval": PADVAL_MIN, # TLOAD: fill cols 7-63 with MIN (int8 MIN=-128) + "fill_padval": PADVAL_MAX, # TFILLPAD_EXPAND: fill expansion region with MAX (127) + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/compare.py new file mode 100644 index 000000000..fdd4a1d13 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/compare.py @@ -0,0 +1,75 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare output against golden for tfillpad_expand test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + dst_shape = case["dst_shape"] + eps = case["eps"] + + # Load golden and output (both stored with dst_shape) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + # For integer types, eps=0 means exact match + # For float types, use np.allclose with eps + if eps == 0: + # Integer comparison - exact match + if not np.array_equal(golden, output): + diff = golden - output + idx = int(np.argmax(np.abs(diff))) + print(f"[ERROR] {case['name']}: Mismatch at idx={idx} (golden={golden.flat[idx]}, output={output.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + else: + # Float comparison - use allclose + # Convert to float64 for comparison (fp16 precision issues) + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + + if g.shape != o.shape: + print(f"[ERROR] {case['name']}: Shape mismatch: golden {g.shape} vs output {o.shape}") + all_passed = False + continue + + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] {case['name']}: Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at idx={idx} (golden={g.flat[idx]}, output={o.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/gen_data.py new file mode 100644 index 000000000..c7b55c8a5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/gen_data.py @@ -0,0 +1,114 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate golden data for tfillpad_expand test cases. + +TFILLPAD_EXPAND semantics: + 1. Copy src.valid_shape data to dst + 2. Fill cols from src.valid_cols to dst.valid_cols with FillPadVal + 3. Fill rows from src.rows to dst.rows with FillPadVal + +Note: LoadPadVal is used by TLOAD only, TFILLPAD_EXPAND uses FillPadVal for expansion. +""" + +import os +import numpy as np +import struct + +from cases import CASES, PADVAL_MAX, PADVAL_MIN, PADVAL_NEG1, PADVAL_ZERO + + +# FLT_MAX and -FLT_MAX (matching DSL PadValue.MAX/MIN) +def _float32_from_bits(bits: int) -> float: + return struct.unpack(">f", bits.to_bytes(4, byteorder="big", signed=False))[0] + +_FLT_MAX = _float32_from_bits(0x7F7FFFFF) # ~3.4028235e+38 +_FLT_MIN = _float32_from_bits(0xFF7FFFFF) # ~-3.4028235e+38 + + +def get_pad_value(dtype, padval_name): + """Get the actual pad value for a dtype based on PadValue enum. + + Matches DSL PadValue.materialize_scalar behavior: + - MAX: FLT_MAX for float (not inf), max for integers + - MIN: -FLT_MAX for float (not -inf), min for integers + - NEG1: -1.0 for float, -1 for integers + - NULL/ZERO: 0 + """ + if padval_name == PADVAL_MAX: + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MAX) + else: + return np.iinfo(dtype).max + elif padval_name == PADVAL_MIN: + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MIN) + else: + return np.iinfo(dtype).min + elif padval_name == PADVAL_NEG1: + if np.issubdtype(dtype, np.floating): + return np.float32(-1.0) + else: + return dtype(-1) + else: # PADVAL_NULL or PADVAL_ZERO + return dtype(0) + + +def setup_case_rng(case): + """Set a per-case deterministic random seed.""" + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry.""" + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src_shape = case["shape"] # src physical (input size, matching tensor_view) + src_valid = case["valid_shape"] # src valid region (actual data in input) + dst_shape = case["dst_shape"] # dst physical + dst_valid = case["dst_valid_shape"] # dst valid (output size) + fill_padval = case.get("fill_padval", PADVAL_ZERO) + + src_vr, src_vc = src_valid + dst_vr, dst_vc = dst_valid + + # Generate input: random values in src valid region, zeros elsewhere + # Input size = src_shape (matching tensor_view and C++ input) + input_data = np.zeros(src_shape, dtype=dtype) + input_data[:src_vr, :src_vc] = np.random.randint(1, 10, size=(src_vr, src_vc)).astype(dtype) + + # Generate golden: dst valid region (output size) + golden = np.zeros(dst_valid, dtype=dtype) + + # Step 1: Copy src valid data to dst + copy_vr = min(src_vr, dst_vr) + copy_vc = min(src_vc, dst_vc) + golden[:copy_vr, :copy_vc] = input_data[:copy_vr, :copy_vc] + + # Step 2: Fill column expansion region (cols from src_vc to dst_vc) + if dst_vc > src_vc: + fill_val = get_pad_value(dtype, fill_padval) + golden[:dst_vr, src_vc:dst_vc] = fill_val + + # Step 3: Fill row expansion region (rows from src_vr to dst_vr) + if dst_vr > src_vr: + fill_val = get_pad_value(dtype, fill_padval) + golden[src_vr:dst_vr, :dst_vc] = fill_val + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src={src_shape} valid={src_valid} -> dst={dst_shape} " + f"fill_pad={fill_padval} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/launch.cpp new file mode 100644 index 000000000..c2f6a6da0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/launch.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ========== uint16 kernel (C++ case 8) ========== + +extern "C" __global__ AICORE void TFILLPAD_EXPAND_u16_260x32_src_259x7(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTFILLPAD_EXPAND_u16_260x32_src_259x7(uint16_t *src, uint16_t *dst, void *stream) { + TFILLPAD_EXPAND_u16_260x32_src_259x7<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// ========== int8 kernel (C++ case 9) ========== + +extern "C" __global__ AICORE void TFILLPAD_EXPAND_s8_260x64_src_259x7(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTFILLPAD_EXPAND_s8_260x64_src_259x7(int8_t *src, int8_t *dst, void *stream) { + TFILLPAD_EXPAND_s8_260x64_src_259x7<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/main.cpp new file mode 100644 index 000000000..72a657248 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/main.cpp @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tfillpad_expand ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTFILLPAD_EXPAND_u16_260x32_src_259x7(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTFILLPAD_EXPAND_s8_260x64_src_259x7(int8_t *src, int8_t *dst, void *stream); + +enum class DataType { U16, S8 }; + +struct TestCase { + const char *name; + DataType dtype; + void (*launch)(void *, void *, void *); // Generic launch function pointer + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidRows; + size_t dstValidCols; + size_t elemSize; +}; + +// Helper to wrap type-specific launch functions +template +void wrapLaunch(void *src, void *dst, void *stream, void (*fn)(T *, T *, void *)) { + fn((T *)src, (T *)dst, stream); +} + +static const TestCase kCases[] = { + // ========== uint16 case (C++ case 8) ========== + {"u16_260x32_src_259x7", DataType::U16, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_EXPAND_u16_260x32_src_259x7); }, + 260, 32, 259, 7, 260, 32, 260, 32, sizeof(uint16_t)}, + + // ========== int8 case (C++ case 9) ========== + {"s8_260x64_src_259x7", DataType::S8, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_EXPAND_s8_260x64_src_259x7); }, + 260, 64, 259, 7, 260, 64, 260, 64, sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcElemCount = tc.srcRows * tc.srcCols; + size_t dstElemCount = tc.dstRows * tc.dstCols; + size_t srcFileSize = srcElemCount * tc.elemSize; + size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu valid=%zux%zu -> dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.srcValidRows, tc.srcValidCols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + size_t inputFileSize = srcFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), inputFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/tfillpad_expand.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/tfillpad_expand.pto new file mode 100644 index 000000000..211a415c1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/tfillpad_expand.pto @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tfillpad_expand: copy src to dst and fill padding. +// Matches C++ test cases: case 8, 9 +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// +// PadValue encoding: 0=Null, 1=Zero, 2=Max, 3=Min +// Case 8: uint16, LoadPad=Min(pad=3), FillPad=Max(pad=2) +// Case 9: int8, LoadPad=Min(pad=3), FillPad=Max(pad=2) +// +// C++ template params: shape3=src_rows, shape4=src_cols, kTRows_=dst_rows, kTCols_=dst_cols + +module { + // ========== Case 8: uint16, src=259x7, dst=260x32, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_EXPAND_u16_260x32_src_259x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c32 = arith.constant 32 : index + %c259 = arith.constant 259 : index + %c260 = arith.constant 260 : index + %c1813 = arith.constant 1813 : index // 259*7 (src size) + %c8320 = arith.constant 8320 : index // 260*32 (dst size) + + // Src tensor_view: 259x7 (matching C++ shape3=259, shape4=7) + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c259, %c7], + strides = [%c1813, %c1813, %c1813, %c7, %c1] + : !pto.tensor_view<1x1x1x259x7xui16> + + // Dst tensor_view: 260x32 + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c32], + strides = [%c8320, %c8320, %c8320, %c32, %c1] + : !pto.tensor_view<1x1x1x260x32xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c259, %c7] + : !pto.tensor_view<1x1x1x259x7xui16> -> !pto.partition_tensor_view<1x1x1x259x7xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c32] + : !pto.tensor_view<1x1x1x260x32xui16> -> !pto.partition_tensor_view<1x1x1x260x32xui16> + + // Src tile: LoadPadVal=Min (pad=3), src physical aligned to 32B = 260x32, v_row=259, v_col=7 + // shape4_aligned = align_to_32B(7, uint16) = 16, but for expand we need dst cols = 32 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x32, v_row=260, v_col=32 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x259x7xui16>) + outs(%src : !pto.tile_buf) + + pto.tfillpad_expand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x32xui16>) + return + } + + // ========== Case 9: int8, src=259x7, dst=260x64, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_EXPAND_s8_260x64_src_259x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c64 = arith.constant 64 : index + %c259 = arith.constant 259 : index + %c260 = arith.constant 260 : index + %c1813 = arith.constant 1813 : index // 259*7 (src size) + %c16640 = arith.constant 16640 : index // 260*64 (dst size) + + // Src tensor_view: 259x7 (matching C++ shape3=259, shape4=7) + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c259, %c7], + strides = [%c1813, %c1813, %c1813, %c7, %c1] + : !pto.tensor_view<1x1x1x259x7xi8> + + // Dst tensor_view: 260x64 + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c64], + strides = [%c16640, %c16640, %c16640, %c64, %c1] + : !pto.tensor_view<1x1x1x260x64xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c259, %c7] + : !pto.tensor_view<1x1x1x259x7xi8> -> !pto.partition_tensor_view<1x1x1x259x7xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c64] + : !pto.tensor_view<1x1x1x260x64xi8> -> !pto.partition_tensor_view<1x1x1x260x64xi8> + + // Src tile: LoadPadVal=Min (pad=3), src physical aligned = 260x64, v_row=259, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x64, v_row=260, v_col=64 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x259x7xi8>) + outs(%src : !pto.tile_buf) + + pto.tfillpad_expand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x64xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/CMakeLists.txt new file mode 100644 index 000000000..9d0f6b924 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfillpad_inplace PTO_LEVEL level3) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/cases.py new file mode 100644 index 000000000..c1b1dae17 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/cases.py @@ -0,0 +1,38 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfillpad_inplace ST test cases. + +Matches C++ reference test case: Case 5 + +Each case defines: + - name: case identifier + - dtype: numpy dtype + - shape: (rows, cols) — tile dimensions (physical buffer size) + - valid_shape: (valid_rows, valid_cols) — valid region (smaller than shape) + - eps: tolerance for numpy.allclose +""" + +import numpy as np + +CASES = [ + # ========== Case: float, src_valid == dst_valid (no expansion) ========== + + { + "name": "f32_260x16_noexpand", + "dtype": np.float32, + "src_shape": (260, 16), # src physical + "src_valid": (260, 16), # src valid = dst valid (no expansion) + "dst_shape": (260, 16), # dst physical + "dst_valid": (260, 16), # dst valid = full output + "fill_padval": "Max", # FillPadVal (not used since no expansion) + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/compare.py new file mode 100644 index 000000000..a58a46a13 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/compare.py @@ -0,0 +1,80 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare output against golden for tfillpad_inplace test cases. + +For tfillpad_inplace: + - Input: full tile shape (rows x cols) + - Output: full tile shape (rows x cols) after inplace fill + - Golden: full tile shape +""" + +import os +import sys +import numpy as np + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + dst_shape = case["dst_shape"] + eps = case["eps"] + + # Load golden and output (both stored with dst_shape) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + # For integer types, eps=0 means exact match + # For float types, use np.allclose with eps + if eps == 0: + # Integer comparison - exact match + if not np.array_equal(golden, output): + diff = golden - output + idx = int(np.argmax(np.abs(diff))) + print(f"[ERROR] {case['name']}: Mismatch at idx={idx} (golden={golden.flat[idx]}, output={output.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + else: + # Float comparison - use allclose + # Convert to float64 for comparison (fp16 precision issues) + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + + if g.shape != o.shape: + print(f"[ERROR] {case['name']}: Shape mismatch: golden {g.shape} vs output {o.shape}") + all_passed = False + continue + + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] {case['name']}: Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at idx={idx} (golden={g.flat[idx]}, output={o.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/gen_data.py new file mode 100644 index 000000000..2345a38a4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/gen_data.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate golden data for tfillpad_inplace test cases. + +For tfillpad_inplace: + - Only one tile, valid_shape smaller than tile shape + - Input: full tile shape (rows x cols), random values in valid region, zeros in padding + - Golden: full tile shape with valid region copied and padding filled with MAX (PadValue.Max) +""" + +import os +import numpy as np +import struct + +from cases import CASES + +# FLT_MAX for float (matching DSL PadValue.MAX) +def _float32_from_bits(bits: int) -> float: + return struct.unpack(">f", bits.to_bytes(4, byteorder="big", signed=False))[0] + +_FLT_MAX = _float32_from_bits(0x7F7FFFFF) # ~3.4028235e+38 + + +def get_pad_value(dtype, padval_name): + """Get the actual pad value for a dtype based on PadValue enum.""" + if padval_name == "Max": + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MAX) + else: + return np.iinfo(dtype).max + elif padval_name == "Min": + if np.issubdtype(dtype, np.floating): + return np.float32(-_FLT_MAX) + else: + return np.iinfo(dtype).min + elif padval_name == "Zero": + return dtype(0) + else: + return dtype(0) + + +def setup_case_rng(case): + """Set a per-case deterministic random seed.""" + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry.""" + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src_shape = case["src_shape"] + src_valid = case["src_valid"] + dst_shape = case["dst_shape"] + dst_valid = case["dst_valid"] + fill_padval = case.get("fill_padval", "Max") + + src_vr, src_vc = src_valid + dst_r, dst_c = dst_shape + dst_vr, dst_vc = dst_valid + + # Input: src valid region data (random values) + input_data = np.random.uniform(1.0, 10.0, size=(src_vr, src_vc)).astype(dtype) + + # Golden: dst full region + # Copy src.valid region to dst[:src_vr, :src_vc] + # Fill cols src_vc to dst_vc with FillPadVal + # Fill rows src_vr to dst_vr with FillPadVal (row expansion, if any) + golden = np.zeros(dst_shape, dtype=dtype) + golden[:src_vr, :src_vc] = input_data + + # Fill column padding (cols src_vc to dst_vc) + if dst_vc > src_vc: + fill_val = get_pad_value(dtype, fill_padval) + golden[:dst_vr, src_vc:dst_vc] = fill_val + + # Fill row padding (rows src_vr to dst_vr) + if dst_vr > src_vr: + fill_val = get_pad_value(dtype, fill_padval) + golden[src_vr:dst_vr, :dst_vc] = fill_val + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} " + f"src_valid={src_valid} dst_shape={dst_shape} " + f"fill_pad={fill_padval} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/launch.cpp new file mode 100644 index 000000000..39f42fc3a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/launch.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ========== Case: float, 260x16, no expansion (inplace single buffer) ========== + +extern "C" __global__ AICORE void TFILLPAD_INPLACE_f32_260x16_noexpand(__gm__ float *buf); + +void LaunchTFILLPAD_INPLACE_f32_260x16_noexpand(float *buf, float *dummy, void *stream) { + // Inplace kernel: single buffer, src == dst physically + // dummy parameter ignored, only buf is used + TFILLPAD_INPLACE_f32_260x16_noexpand<<<1, nullptr, stream>>>((__gm__ float *)buf); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/main.cpp new file mode 100644 index 000000000..34f94bf40 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/main.cpp @@ -0,0 +1,129 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tfillpad_inplace ST. +// Matches C++ reference test case: Case 5 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrapper (defined in launch.cpp) +// Inplace kernel takes single buffer pointer +void LaunchTFILLPAD_INPLACE_f32_260x16_noexpand(float *buf, float *dummy, void *stream); + +enum class DataType { F32 }; + +struct TestCase { + const char *name; + DataType dtype; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // Case: float, 260x16, no expansion (inplace: single buffer) + {"f32_260x16_noexpand", DataType::F32, + 260, 16, 260, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t elemCount = tc.rows * tc.cols; + size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (%zux%zu, inplace) ===\n", + tc.name, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + // Single buffer for inplace operation + void *bufHost = nullptr; + void *bufDevice = nullptr; + + aclrtMallocHost(&bufHost, fileSize); + aclrtMalloc(&bufDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + // Load input data into the single buffer + if (!ReadFile((caseDir + "/input.bin").c_str(), fileSize, bufHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + // Copy input to device buffer + aclrtMemcpy(bufDevice, fileSize, bufHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + // Run inplace kernel (src == dst = bufDevice) + // Note: launch wrapper takes two args but inplace kernel uses same physical address + LaunchTFILLPAD_INPLACE_f32_260x16_noexpand((float *)bufDevice, (float *)bufDevice, stream); + + aclrtSynchronizeStream(stream); + // Copy result back (same buffer contains output) + aclrtMemcpy(bufHost, fileSize, bufDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), bufHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (bufDevice != nullptr) + aclrtFree(bufDevice); + if (bufHost != nullptr) + aclrtFreeHost(bufHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/tfillpad_inplace.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/tfillpad_inplace.pto new file mode 100644 index 000000000..489219240 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/tfillpad_inplace.pto @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tfillpad (inplace mode). +// Matches C++ reference test case: Case 5 +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// +// PadValue encoding: 0=Null, 1=Zero, 2=Max, 3=Min +// Case 5: float, 260x16, valid=260x7, FillPad=Max (pad=2) +// +// Note: PTOAS tstore requires dst size to match src valid_shape. +// For outputting full buffer after inplace fill, we use two tiles: +// - src tile: holds input data (valid=260x7) +// - dst tile: receives filled data (valid=260x16 for output) + +module { + // ========== No expansion: float, 260x16 physical, src_valid == dst_valid ========== + + func.func @TFILLPAD_INPLACE_f32_260x16_noexpand(%tile_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c260 = arith.constant 260 : index + %c4160 = arith.constant 4160 : index // 260*16 (full tile size) + + // Input tensor_view: 260x16 + %src_view = pto.make_tensor_view %tile_ptr, + shape = [%c1, %c1, %c1, %c260, %c16], + strides = [%c4160, %c4160, %c4160, %c16, %c1] + : !pto.tensor_view<1x1x1x260x16xf32> + + // Output tensor_view: 260x16 (same as input) + %dst_view = pto.make_tensor_view %tile_ptr, + shape = [%c1, %c1, %c1, %c260, %c16], + strides = [%c4160, %c4160, %c4160, %c16, %c1] + : !pto.tensor_view<1x1x1x260x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c16] + : !pto.tensor_view<1x1x1x260x16xf32> -> !pto.partition_tensor_view<1x1x1x260x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c16] + : !pto.tensor_view<1x1x1x260x16xf32> -> !pto.partition_tensor_view<1x1x1x260x16xf32> + + // Single tile buffer in UB space at address 0 + // src_valid = dst_valid = 260x16, so no expansion needed + %tile_buf = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + // Load full tile (260x16) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x16xf32>) + outs(%tile_buf : !pto.tile_buf) + + // tfillpad_inplace: src_valid == dst_valid, no expansion + pto.tfillpad_inplace ins(%tile_buf : !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + + // Store full tile + pto.tstore ins(%tile_buf : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x16xf32>) + return + } +} \ No newline at end of file From b5f9e97d584da8bc80c14c3fd5c352fc0cc6acd1 Mon Sep 17 00:00:00 2001 From: mly <978226558@qq.com> Date: Sat, 25 Apr 2026 18:35:16 +0800 Subject: [PATCH 173/192] bugfix: vprelu should accept mask operand (#266) Co-authored-by: mouliangyu --- docs/isa/13-dsa-sfu-ops.md | 6 +- docs/vpto-spec.md | 2 +- include/PTO/IR/VPTOOps.td | 17 +++++- lib/PTO/IR/VPTO.cpp | 19 ++++++- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 56 ++----------------- .../micro-op/dsa-sfu/vprelu-f32/kernel.pto | 2 +- .../micro-op/dsa-sfu/vprelu-tail/kernel.pto | 2 +- 7 files changed, 44 insertions(+), 60 deletions(-) diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md index f4d88fe98..0196a559d 100644 --- a/docs/isa/13-dsa-sfu-ops.md +++ b/docs/isa/13-dsa-sfu-ops.md @@ -40,7 +40,7 @@ for (int i = 0; i < N; i++) ### `pto.vprelu` -- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vprelu %input, %alpha, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` - **A5 types:** f16, f32 - **semantics:** Parametric ReLU with per-element alpha vector. @@ -49,8 +49,8 @@ for (int i = 0; i < N; i++) dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; ``` -- **inputs:** `%input` is the activation vector and `%alpha` is the per-element - slope vector. +- **inputs:** `%input` is the activation vector, `%alpha` is the per-element + slope vector, and `%mask` selects active lanes. - **outputs:** `%result` is the parametric-ReLU vector. - **constraints and limitations:** Floating-point element types only on the current A5 surface. diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index c644d109e..3ac469bd7 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -1201,7 +1201,7 @@ pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto. %lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> // Parametric ReLU (per-element alpha) -%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%prelu = pto.vprelu %input, %alpha_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> ``` diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 07fa649c7..4963a7e07 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -1806,7 +1806,22 @@ class PTO_UnmaskedBinaryVecOp : PTO_Op { }]; } -def PTO_VpreluOp : PTO_UnmaskedBinaryVecOp<"vprelu">; +class PTO_BinaryVecMaskedOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VpreluOp : PTO_BinaryVecMaskedOp<"vprelu">; def PTO_VexpdifOp : PTO_Op<"vexpdif", [Pure]> { let arguments = (ins PTO_VectorType:$input, diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index cbb89e90d..23534684b 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -3132,7 +3132,24 @@ static LogicalResult verifyFloatBinaryVecNoMaskOp(BinaryVecNoMaskOp op) { return success(); } -LogicalResult VpreluOp::verify() { return verifyFloatBinaryVecNoMaskOp(*this); } +template +static LogicalResult verifyFloatBinaryVecMaskOp(BinaryVecMaskOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires lhs, rhs, and result to share one vector type"); + auto lhsType = cast(op.getLhs().getType()); + Type elemType = lhsType.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return op.emitOpError("requires f16 or f32 vector element type"); + return success(); +} + +LogicalResult VpreluOp::verify() { return verifyFloatBinaryVecMaskOp(*this); } LogicalResult VexpdifOp::verify() { if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input type")) || failed(verifyVRegTypeLike(*this, getMax().getType(), "max type")) || diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index cadcd96d8..57fd57635 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -1676,6 +1676,8 @@ static StringRef getBinaryMaskedStem() { return "vshl"; if constexpr (std::is_same_v) return "vshr"; + if constexpr (std::is_same_v) + return "vprelu"; return {}; } @@ -2381,11 +2383,6 @@ static FailureOr buildVscatterCallee(MLIRContext *context, return buildLaneTypedCallee(context, valueType, "vscatter", ".v300"); } -static FailureOr buildVpreluCallee(MLIRContext *context, - Type resultType) { - return buildLaneTypedCallee(context, resultType, "vprelu", ".x"); -} - static FailureOr buildVaxpyCallee(MLIRContext *context, Type resultType) { return buildLaneTypedCallee(context, resultType, "vaxpy", ".m"); @@ -5622,52 +5619,6 @@ class LowerVscatterOpPattern final LoweringState &state; }; -class LowerVpreluOpPattern final : public OpConversionPattern { -public: - explicit LowerVpreluOpPattern(TypeConverter &typeConverter, - MLIRContext *context, LoweringState &state) - : OpConversionPattern(typeConverter, context), - state(state) {} - - LogicalResult - matchAndRewrite(pto::VpreluOp op, pto::VpreluOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto laneCount = getElementCountFromVectorLike(op.getResult().getType()); - Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); - if (!laneCount || !elemType) - return rewriter.notifyMatchFailure(op, "unsupported vprelu signature"); - - FailureOr mask = materializeDynamicPltMask( - rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), - elemType); - if (failed(mask)) - return rewriter.notifyMatchFailure(op, "failed to materialize vprelu mask"); - - Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "failed to convert vprelu result type"); - - FailureOr calleeName = - buildVpreluCallee(op.getContext(), op.getResult().getType()); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported vprelu callee"); - - auto funcType = rewriter.getFunctionType( - TypeRange{adaptor.getLhs().getType(), adaptor.getRhs().getType(), - (*mask).getType()}, - TypeRange{resultType}); - auto call = rewriter.create( - op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getLhs(), adaptor.getRhs(), *mask}); - state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); - rewriter.replaceOp(op, call.getResults()); - return success(); - } - -private: - LoweringState &state; -}; - class LowerVaxpyOpPattern final : public OpConversionPattern { public: explicit LowerVaxpyOpPattern(TypeConverter &typeConverter, @@ -6704,6 +6655,7 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerBinaryMaskedOpPattern, LowerBinaryMaskedOpPattern, LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, LowerCarryBinaryOpPattern, LowerCarryBinaryOpPattern, LowerCarryBinaryOpPattern, @@ -6795,7 +6747,7 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerVstarOpPattern, LowerVstasOpPattern, LowerVgather2OpPattern, LowerVgather2BcOpPattern, LowerVgatherbOpPattern, LowerVscatterOpPattern, - LowerVpreluOpPattern, LowerVaxpyOpPattern, + LowerVaxpyOpPattern, LowerVciOpPattern, LowerVexpdifOpPattern, LowerVbitsortOpPattern, LowerVtrcOpPattern, LowerVcvtOpPattern, LowerVbitcastOpPattern, LowerPbitcastOpPattern, diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto index 207b17691..4948416bb 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto @@ -40,7 +40,7 @@ module attributes {pto.target_arch = "a5"} { scf.for %offset = %c0 to %c1024 step %c64 { %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %alpha = pto.vlds %ub_alpha[%offset] : !pto.ptr -> !pto.vreg<64xf32> - %sum = pto.vprelu %vec, %alpha : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %sum = pto.vprelu %vec, %alpha, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto index 854b70201..e85e14ae9 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto @@ -41,7 +41,7 @@ module attributes {pto.target_arch = "a5"} { %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %alpha = pto.vlds %ub_alpha[%offset] : !pto.ptr -> !pto.vreg<64xf32> - %sum = pto.vprelu %vec, %alpha : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %sum = pto.vprelu %vec, %alpha, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 } From 499e9d79c28a84168e88616a528b0dc32f819878 Mon Sep 17 00:00:00 2001 From: yuqiha <93496818+yuqiha@users.noreply.github.com> Date: Sun, 26 Apr 2026 01:19:50 +0800 Subject: [PATCH 174/192] ST tests for TReLU, TLReLU , TPrelu, Tsel and Tselsand operators have been passed. (#212) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * trelu tlrelu trandom 算子ST测试已通过,tsel tsels tprelu算子ST测试暂未通过 * trelu tlrelu trandom 算子ST测试已通过 * 针对评论修改trelu和tlrelu算子 * trelu, tlrelu, tprelu, tsel, tsels算子编译与ST测试均已通过 * trelu, tlrelu, tprelu, tsel, tsels算子编译与ST测试均已通过 * 添加license * 根据评论修改,删除prelu算子后续再提 * 针对tsels ci不通过修改 --------- Co-authored-by: KurrinQu --- lib/TileOps/tlrelu_template.py | 46 + lib/TileOps/trelu_template.py | 43 + lib/TileOps/tsel_template.py | 99 ++ lib/TileOps/tsels_template.py | 121 +++ test/basic/expand_tile_op_tlrelu_tilelang.pto | 47 + test/basic/expand_tile_op_trelu_tilelang.pto | 42 + test/basic/expand_tile_op_tsel_tilelang.pto | 57 ++ test/basic/expand_tile_op_tsels_tilelang.pto | 57 ++ test/basic/tlrelu/CMakeLists.txt | 9 + test/basic/tlrelu/cases.py | 65 ++ test/basic/tlrelu/compare.py | 49 + test/basic/tlrelu/gen_data.py | 45 + test/basic/tlrelu/launch.cpp | 41 + test/basic/tlrelu/main.cpp | 154 +++ test/basic/tlrelu/tlrelu.pto | 238 +++++ .../npu/a5/src/st/testcase/CMakeLists.txt | 4 + .../a5/src/st/testcase/tlrelu/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tlrelu/cases.py | 65 ++ .../npu/a5/src/st/testcase/tlrelu/compare.py | 49 + .../npu/a5/src/st/testcase/tlrelu/gen_data.py | 45 + .../npu/a5/src/st/testcase/tlrelu/launch.cpp | 41 + .../npu/a5/src/st/testcase/tlrelu/main.cpp | 154 +++ .../npu/a5/src/st/testcase/tlrelu/tlrelu.pto | 238 +++++ .../a5/src/st/testcase/trelu/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/trelu/cases.py | 52 ++ .../npu/a5/src/st/testcase/trelu/compare.py | 48 + .../npu/a5/src/st/testcase/trelu/gen_data.py | 32 + .../npu/a5/src/st/testcase/trelu/launch.cpp | 34 + .../npu/a5/src/st/testcase/trelu/main.cpp | 128 +++ .../npu/a5/src/st/testcase/trelu/trelu.pto | 157 ++++ .../a5/src/st/testcase/tsel/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tsel/cases.py | 97 ++ .../npu/a5/src/st/testcase/tsel/compare.py | 49 + .../npu/a5/src/st/testcase/tsel/gen_data.py | 44 + .../npu/a5/src/st/testcase/tsel/launch.cpp | 83 ++ .../npu/a5/src/st/testcase/tsel/main.cpp | 312 +++++++ .../npu/a5/src/st/testcase/tsel/tsel.pto | 884 ++++++++++++++++++ .../a5/src/st/testcase/tsels/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tsels/cases.py | 50 + .../npu/a5/src/st/testcase/tsels/compare.py | 49 + .../npu/a5/src/st/testcase/tsels/gen_data.py | 49 + .../npu/a5/src/st/testcase/tsels/launch.cpp | 168 ++++ .../npu/a5/src/st/testcase/tsels/main.cpp | 181 ++++ .../npu/a5/src/st/testcase/tsels/tsels.pto | 654 +++++++++++++ 44 files changed, 4816 insertions(+) create mode 100644 lib/TileOps/tlrelu_template.py create mode 100644 lib/TileOps/trelu_template.py create mode 100644 lib/TileOps/tsel_template.py create mode 100644 lib/TileOps/tsels_template.py create mode 100644 test/basic/expand_tile_op_tlrelu_tilelang.pto create mode 100644 test/basic/expand_tile_op_trelu_tilelang.pto create mode 100644 test/basic/expand_tile_op_tsel_tilelang.pto create mode 100644 test/basic/expand_tile_op_tsels_tilelang.pto create mode 100644 test/basic/tlrelu/CMakeLists.txt create mode 100644 test/basic/tlrelu/cases.py create mode 100644 test/basic/tlrelu/compare.py create mode 100644 test/basic/tlrelu/gen_data.py create mode 100644 test/basic/tlrelu/launch.cpp create mode 100644 test/basic/tlrelu/main.cpp create mode 100644 test/basic/tlrelu/tlrelu.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlrelu/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlrelu/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlrelu/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlrelu/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlrelu/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlrelu/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tlrelu/tlrelu.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trelu/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trelu/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trelu/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trelu/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trelu/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trelu/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trelu/trelu.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsel/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsel/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsel/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsel/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsel/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsel/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsel/tsel.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsels/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsels/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsels/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsels/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsels/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsels/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto diff --git a/lib/TileOps/tlrelu_template.py b/lib/TileOps/tlrelu_template.py new file mode 100644 index 000000000..33087ca51 --- /dev/null +++ b/lib/TileOps/tlrelu_template.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tlrelu (Leaky ReLU with scalar slope)""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tlrelu", + advanced=True +) +def template_tlrelu(src: pto.Tile, slope: pto.f32, dst: pto.Tile): + """Leaky ReLU: dst = src if src > 0 else src * slope. + + Semantics: + For each element (i, j): + dst[i, j] = src[i, j] > 0 ? src[i, j] : slope * src[i, j] + + Supported data types: f16, f32 + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + if pto.constexpr(dtype == pto.f16): + slope_scalar = pto.f16(slope) + else: + slope_scalar = slope + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + src_vec = pto.vlds(src[row, col:]) + result = pto.vlrelu(src_vec, slope_scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trelu_template.py b/lib/TileOps/trelu_template.py new file mode 100644 index 000000000..0cdd0e7eb --- /dev/null +++ b/lib/TileOps/trelu_template.py @@ -0,0 +1,43 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trelu (Elementwise ReLU)""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trelu", + dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32), (pto.i32, pto.i32)], + advanced=True +) +def template_trelu(src: pto.Tile, dst: pto.Tile): + """Elementwise ReLU: dst = max(0, src). + + Semantics: + For each element (i, j): + dst[i, j] = max(0, src[i, j]) + + Supported data types: f16, f32, i32 + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + src_vec = pto.vlds(src[row, col:]) + result = pto.vrelu(src_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tsel_template.py b/lib/TileOps/tsel_template.py new file mode 100644 index 000000000..92285716b --- /dev/null +++ b/lib/TileOps/tsel_template.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsel + +NOTE: This template uses pto.plds for mask loading which directly +loads predicate mask from UB without vcmps comparison. +This approach matches the TSel.hpp implementation in pto-isa. + +Mask tile format: +- Packed predicate bytes in UB (`i8` tile data). +- Each row stores `ceil(valid_cols / 8)` valid bytes; tile row stride may be padded. + +REQUIRES: tilelang_dsl support for plds, astype(mask), pintlv_b16, castptr operations +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tsel", + dtypes=[ + (pto.i8, pto.f32, pto.f32, pto.f32, pto.f32), + (pto.i8, pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i8, pto.i8, pto.i8, pto.i8, pto.i8), + ], + advanced=True +) +def template_tsel(mask: pto.Tile, src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + mask_row_stride = mask.shape[1] + mask_ptr = pto.castptr(mask.as_ptr(), pto.ptr(pto.ui8, pto.MemorySpace.UB)) + + if pto.constexpr(dtype == pto.f32): + full_mask_b16 = pto.pset_b16(pto.MaskPattern.ALL) + pair_width = lanes * 2 + paired_cols = (valid_cols // pair_width) * pair_width + for row in range(0, valid_rows, 1): + for col in range(0, paired_cols, pair_width): + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + pred0, _ = pto.make_mask(dtype, pair_width) + pred1, _ = pto.make_mask(dtype, lanes) + select_mask0, select_mask1 = pto.pintlv_b16(select_mask, full_mask_b16) + select_mask0 = select_mask0.astype(pto.mask_b32) + select_mask1 = select_mask1.astype(pto.mask_b32) + lhs0 = pto.vlds(src0[row, col:]) + rhs0 = pto.vlds(src1[row, col:]) + lhs1 = pto.vlds(src0[row, col + lanes:]) + rhs1 = pto.vlds(src1[row, col + lanes:]) + selected0 = pto.vsel(lhs0, rhs0, select_mask0) + selected1 = pto.vsel(lhs1, rhs1, select_mask1) + pto.vsts(selected0, dst[row, col:], pred0) + pto.vsts(selected1, dst[row, col + lanes:], pred1) + tail_cols = valid_cols - paired_cols + if tail_cols > 0: + col = paired_cols + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + select_mask0 = pto.punpack(select_mask, pto.PredicatePart.LOWER) + select_mask0 = select_mask0.astype(pto.mask_b32) + pred0, _ = pto.make_mask(dtype, tail_cols) + lhs0 = pto.vlds(src0[row, col:]) + rhs0 = pto.vlds(src1[row, col:]) + selected0 = pto.vsel(lhs0, rhs0, select_mask0) + pto.vsts(selected0, dst[row, col:], pred0) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + pred_mask, remained = pto.make_mask(dtype, remained) + mask_offset = row * mask_row_stride + col // 8 + if pto.constexpr(dtype == pto.f16): + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + selected = pto.vsel(lhs, rhs, select_mask) + pto.vsts(selected, dst[row, col:], pred_mask) + else: + select_mask = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.NORM) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + selected = pto.vsel(lhs, rhs, select_mask) + pto.vsts(selected, dst[row, col:], pred_mask) + return \ No newline at end of file diff --git a/lib/TileOps/tsels_template.py b/lib/TileOps/tsels_template.py new file mode 100644 index 000000000..905d0dd06 --- /dev/null +++ b/lib/TileOps/tsels_template.py @@ -0,0 +1,121 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsels + +NOTE: This template uses pto.plds for mask loading which directly +loads predicate mask from UB without vcmps comparison. + +TSels: Select between source tile and scalar based on mask. +- mask=true: select from src +- mask=false: select scalar value + +Mask tile format: +- Packed predicate bytes in UB. +- Each row stores ceil(valid_cols / 8) valid bytes; tile row stride may be padded. +- mask_dtype determines the storage format (i8/i16/i32), but the actual + predicate bits are packed and accessed as bytes. + +IMPORTANT: mask_row_stride is always mask.shape[1] (element count), +because mask tile stride equals cols in element units regardless of mask_dtype. +Byte offset for plds is col // 8 (one byte covers 8 elements). +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tsels", + dtypes=[ + (pto.i8, pto.i8, pto.i8, pto.i8, pto.i8), + (pto.i16, pto.i8, pto.i8, pto.i8, pto.i8), + (pto.i32, pto.i8, pto.i8, pto.i8, pto.i8), + (pto.i8, pto.i16, pto.i16, pto.i16, pto.i16), + (pto.i16, pto.i16, pto.i16, pto.i16, pto.i16), + (pto.i32, pto.i16, pto.i16, pto.i16, pto.i16), + (pto.i8, pto.i32, pto.i32, pto.i32, pto.i32), + (pto.i16, pto.i32, pto.i32, pto.i32, pto.i32), + (pto.i32, pto.i32, pto.i32, pto.i32, pto.i32), + (pto.i8, pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i16, pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i32, pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i8, pto.f32, pto.f32, pto.f32, pto.f32), + (pto.i16, pto.f32, pto.f32, pto.f32, pto.f32), + (pto.i32, pto.f32, pto.f32, pto.f32, pto.f32), + ], + advanced=True +) +def template_tsels(mask: pto.Tile, src: pto.Tile, tmp: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + mask_dtype = mask.element_type + + lanes = pto.get_lanes(dtype) + mask_row_stride = mask.shape[1] * pto.bytewidth(mask_dtype) + mask_ptr = pto.castptr(mask.as_ptr(), pto.ptr(pto.ui8, pto.MemorySpace.UB)) + + scalar_mask, _ = pto.make_mask(dtype, lanes) + vreg_scalar = pto.vdup(scalar, scalar_mask) + + if pto.constexpr(lanes == 64): + full_mask_b16 = pto.pset_b16(pto.MaskPattern.ALL) + pair_width = lanes * 2 + paired_cols = (valid_cols // pair_width) * pair_width + for row in range(0, valid_rows, 1): + for col in range(0, paired_cols, pair_width): + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + pred0, _ = pto.make_mask(dtype, pair_width) + pred1, _ = pto.make_mask(dtype, lanes) + select_mask0, select_mask1 = pto.pintlv_b16(select_mask, full_mask_b16) + select_mask0 = select_mask0.astype(pto.mask_b32) + select_mask1 = select_mask1.astype(pto.mask_b32) + src0 = pto.vlds(src[row, col:]) + src1 = pto.vlds(src[row, col + lanes:]) + selected0 = pto.vsel(src0, vreg_scalar, select_mask0) + selected1 = pto.vsel(src1, vreg_scalar, select_mask1) + pto.vsts(selected0, dst[row, col:], pred0) + pto.vsts(selected1, dst[row, col + lanes:], pred1) + tail_cols = valid_cols - paired_cols + if tail_cols > 0: + col = paired_cols + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + select_mask0 = pto.punpack(select_mask, pto.PredicatePart.LOWER) + select_mask0 = select_mask0.astype(pto.mask_b32) + pred0, _ = pto.make_mask(dtype, tail_cols) + src0 = pto.vlds(src[row, col:]) + selected0 = pto.vsel(src0, vreg_scalar, select_mask0) + pto.vsts(selected0, dst[row, col:], pred0) + elif pto.constexpr(lanes == 128): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + pred_mask, remained = pto.make_mask(dtype, remained) + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + src_vec = pto.vlds(src[row, col:]) + selected = pto.vsel(src_vec, vreg_scalar, select_mask) + pto.vsts(selected, dst[row, col:], pred_mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + pred_mask, remained = pto.make_mask(dtype, remained) + mask_offset = row * mask_row_stride + col // 8 + select_mask = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.NORM) + src_vec = pto.vlds(src[row, col:]) + selected = pto.vsel(src_vec, vreg_scalar, select_mask) + pto.vsts(selected, dst[row, col:], pred_mask) + return \ No newline at end of file diff --git a/test/basic/expand_tile_op_tlrelu_tilelang.pto b/test/basic/expand_tile_op_tlrelu_tilelang.pto new file mode 100644 index 000000000..c7cd09aef --- /dev/null +++ b/test/basic/expand_tile_op_tlrelu_tilelang.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tlrelu via the default TileLang Python DSL template +// lib/TileOps/tlrelu_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tlrelu should be lowered to vector-style VPTO IR. +// CHECK: func.func @TLRelu_test +// CHECK-NOT: pto.tlrelu ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vcmps +// CHECK: pto.vmuls +// CHECK: pto.vsel +// CHECK: pto.vsts + +module { + func.func @TLRelu_test() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + %slope = arith.constant 0.1 : f32 + + pto.tlrelu ins(%src, %slope : !pto.tile_buf, + f32) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_trelu_tilelang.pto b/test/basic/expand_tile_op_trelu_tilelang.pto new file mode 100644 index 000000000..2cb50b0e4 --- /dev/null +++ b/test/basic/expand_tile_op_trelu_tilelang.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trelu via the default TileLang Python DSL template +// lib/TileOps/trelu_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trelu should be lowered to vector-style VPTO IR. +// CHECK: func.func @TRelu_test +// CHECK-NOT: pto.trelu ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vrelu +// CHECK: pto.vsts + +module { + func.func @TRelu_test() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tsel_tilelang.pto b/test/basic/expand_tile_op_tsel_tilelang.pto new file mode 100644 index 000000000..c30bc87c0 --- /dev/null +++ b/test/basic/expand_tile_op_tsel_tilelang.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tsel via the default TileLang Python DSL template +// lib/TileOps/tsel_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsel should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSEL_test +// CHECK-NOT: pto.tsel ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsel +// CHECK: pto.vsts + +module { + func.func @TSEL_test() { + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tsels_tilelang.pto b/test/basic/expand_tile_op_tsels_tilelang.pto new file mode 100644 index 000000000..f9a5205cf --- /dev/null +++ b/test/basic/expand_tile_op_tsels_tilelang.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tsels via the default TileLang Python DSL template +// lib/TileOps/tsels_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsels should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSELS_test +// CHECK-NOT: pto.tsels ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vcmps +// CHECK: pto.vdup +// CHECK: pto.vsel +// CHECK: pto.vsts + +module { + func.func @TSELS_test() { + %mask = pto.alloc_tile + : !pto.tile_buf + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + %scalar = arith.constant 42.0 : f32 + + pto.tsels ins(%mask, %src, %tmp, %scalar : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + f32) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/tlrelu/CMakeLists.txt b/test/basic/tlrelu/CMakeLists.txt new file mode 100644 index 000000000..7c79f07d5 --- /dev/null +++ b/test/basic/tlrelu/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tlrelu) \ No newline at end of file diff --git a/test/basic/tlrelu/cases.py b/test/basic/tlrelu/cases.py new file mode 100644 index 000000000..a2b897e5d --- /dev/null +++ b/test/basic/tlrelu/cases.py @@ -0,0 +1,65 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tlrelu ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — src tile dimensions (UB allocation). + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - dst_shape: (rows, cols) — dst tile physical dimensions (UB allocation, may have padding). + - dst_valid_shape: (valid_rows, valid_cols) — dst effective region (same as valid_shape). + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64_dst128", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "dst_shape": (32, 128), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + { + "name": "f16_63x64_dst128", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "dst_shape": (63, 128), + "dst_valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "f32_7x448_dst512", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "dst_shape": (7, 512), + "dst_valid_shape": (7, 448), + "eps": 1e-3, + }, + { + "name": "f32_256x16_dst32", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "dst_shape": (256, 32), + "dst_valid_shape": (256, 16), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/basic/tlrelu/compare.py b/test/basic/tlrelu/compare.py new file mode 100644 index 000000000..6af6f6d5c --- /dev/null +++ b/test/basic/tlrelu/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/basic/tlrelu/gen_data.py b/test/basic/tlrelu/gen_data.py new file mode 100644 index 000000000..22b2b3314 --- /dev/null +++ b/test/basic/tlrelu/gen_data.py @@ -0,0 +1,45 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import struct +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + dst_shape = case["dst_shape"] + valid_shape = case["valid_shape"] + + rows, cols = shape + dst_rows, dst_cols = dst_shape + vr, vc = valid_shape + + input_arr = np.random.uniform(low=-8, high=8, size=(rows, cols)).astype(dtype) + slope = np.random.uniform(low=-8, high=8, size=(1, 1)).astype(np.float32) + golden = np.zeros((dst_rows, dst_cols), dtype=dtype) + + for i in range(vr): + for j in range(vc): + if input_arr[i, j] > 0: + golden[i, j] = input_arr[i, j] + else: + golden[i, j] = dtype(input_arr[i, j] * slope[0, 0]) + + slope_arr = np.array([slope[0, 0]], dtype=np.float32) + + save_case_data(case["name"], {"input": input_arr, "slope": slope_arr, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} dst_shape={dst_shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/basic/tlrelu/launch.cpp b/test/basic/tlrelu/launch.cpp new file mode 100644 index 000000000..356bd5115 --- /dev/null +++ b/test/basic/tlrelu/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 32x64 -> dst 32x128 (valid 32x64) +extern "C" __global__ AICORE void TLRELU_f32_32x64_dst128(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_32x64_dst128(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_32x64_dst128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} + +// Case 1: f16 63x64 -> dst 63x128 (valid 63x64) +extern "C" __global__ AICORE void TLRELU_f16_63x64_dst128(__gm__ uint16_t *src, __gm__ uint16_t *dst, float slope); + +void LaunchTLRELU_f16_63x64_dst128(uint16_t *src, uint16_t *dst, float slope, void *stream) { + TLRELU_f16_63x64_dst128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, slope); +} + +// Case 2: f32 7x448 -> dst 7x512 (valid 7x448) +extern "C" __global__ AICORE void TLRELU_f32_7x448_dst512(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_7x448_dst512(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_7x448_dst512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} + +// Case 3: f32 256x16 -> dst 256x32 (valid 256x16) +extern "C" __global__ AICORE void TLRELU_f32_256x16_dst32(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_256x16_dst32(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_256x16_dst32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} \ No newline at end of file diff --git a/test/basic/tlrelu/main.cpp b/test/basic/tlrelu/main.cpp new file mode 100644 index 000000000..3b75edf69 --- /dev/null +++ b/test/basic/tlrelu/main.cpp @@ -0,0 +1,154 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tlrelu ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTLRELU_f32_32x64_dst128(float *src, float *dst, float slope, void *stream); +void LaunchTLRELU_f16_63x64_dst128(uint16_t *src, uint16_t *dst, float slope, void *stream); +void LaunchTLRELU_f32_7x448_dst512(float *src, float *dst, float slope, void *stream); +void LaunchTLRELU_f32_256x16_dst32(float *src, float *dst, float slope, void *stream); + +using LaunchFn = void (*)(void *, void *, float, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; // src tile rows + size_t srcCols; // src tile cols + size_t dstRows; // dst tile rows (may have padding) + size_t dstCols; // dst tile cols (may have padding) + size_t validRows; // effective computation rows (<= srcRows, dstRows) + size_t validCols; // effective computation cols (<= srcCols, dstCols) + size_t elemSize; // bytes per element + bool isFp16; // true for float16 case +}; + +static const TestCase kCases[] = { + {"f32_32x64_dst128", (LaunchFn)LaunchTLRELU_f32_32x64_dst128, 32, 64, 32, 128, 32, 64, sizeof(float), false}, + {"f16_63x64_dst128", (LaunchFn)LaunchTLRELU_f16_63x64_dst128, 63, 64, 63, 128, 63, 64, sizeof(uint16_t), true}, + {"f32_7x448_dst512", (LaunchFn)LaunchTLRELU_f32_7x448_dst512, 7, 448, 7, 512, 7, 448, sizeof(float), false}, + {"f32_256x16_dst32", (LaunchFn)LaunchTLRELU_f32_256x16_dst32, 256, 16, 256, 32, 256, 16, sizeof(float), false}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcFileSize = tc.srcRows * tc.srcCols * tc.elemSize; + size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + size_t actualSize = 0; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + float slope = 0.0f; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile(caseDir + "/input.bin", actualSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + // Read slope (4 bytes float) + if (rc == 0) { + std::ifstream slopeFile(caseDir + "/slope.bin", std::ios::binary); + if (!slopeFile) { + std::fprintf(stderr, "[ERROR] failed to open %s/slope.bin\n", caseDir.c_str()); + rc = 1; + } else { + slopeFile.read(reinterpret_cast(&slope), sizeof(float)); + slopeFile.close(); + } + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, slope, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tlrelu [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/basic/tlrelu/tlrelu.pto b/test/basic/tlrelu/tlrelu.pto new file mode 100644 index 000000000..c4131bc67 --- /dev/null +++ b/test/basic/tlrelu/tlrelu.pto @@ -0,0 +1,238 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tlrelu: tload(src) + tlrelu(src, slope)->dst + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 src 32x64 -> dst 32x128 (valid 32x64) + func.func @TLRELU_f32_32x64_dst128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + %c4096 = arith.constant 4096 : index + + // Src GM view: 1x1x1x32x64 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + // Dst GM view: shape=valid_shape (32x64), strides based on dst allocation (32x128) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + // Dst partition: sizes = valid_shape (32x64) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + // Src UB tile: 32x64, valid 32x64 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 32x64, valid 32x64 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 src 63x64 -> dst 63x128 (valid 63x64) + func.func @TLRELU_f16_63x64_dst128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c4032 = arith.constant 4032 : index + %c8064 = arith.constant 8064 : index + + // Src GM view: 1x1x1x63x64 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + // Dst GM view: shape=valid_shape (63x64), strides based on dst allocation (63x128) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c8064, %c8064, %c8064, %c128, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + // Dst partition: sizes = valid_shape (63x64) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + // Src UB tile: 63x64, valid 63x64 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 63x64, valid 63x64 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: f32 src 7x448 -> dst 7x512 (valid 7x448) + func.func @TLRELU_f32_7x448_dst512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c512 = arith.constant 512 : index + %c3136 = arith.constant 3136 : index + %c3584 = arith.constant 3584 : index + + // Src GM view: 1x1x1x7x448 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + // Dst GM view: shape=valid_shape (7x448), strides based on dst allocation (7x512) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3584, %c3584, %c3584, %c512, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + // Dst partition: sizes = valid_shape (7x448) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + // Src UB tile: 7x448, valid 7x448 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 7x448, valid 7x448 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 3: f32 src 256x16 -> dst 256x32 (valid 256x16) + func.func @TLRELU_f32_256x16_dst32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + + // Src GM view: 1x1x1x256x16 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + // Dst GM view: shape=valid_shape (256x16), strides based on dst allocation (256x32) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c8192, %c8192, %c8192, %c32, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + // Dst partition: sizes = valid_shape (256x16) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + // Src UB tile: 256x16, valid 256x16 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 256x16, valid 256x16 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 0b5c2dc9a..cb0337cff 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -132,6 +132,10 @@ set(ALL_TESTCASES tadd tcvt tload + tlrelu + trelu + tsel + tsels tcolmax tcolmin tcolsum diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/CMakeLists.txt new file mode 100644 index 000000000..7c79f07d5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tlrelu) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/cases.py new file mode 100644 index 000000000..a2b897e5d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/cases.py @@ -0,0 +1,65 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tlrelu ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — src tile dimensions (UB allocation). + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - dst_shape: (rows, cols) — dst tile physical dimensions (UB allocation, may have padding). + - dst_valid_shape: (valid_rows, valid_cols) — dst effective region (same as valid_shape). + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64_dst128", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "dst_shape": (32, 128), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + { + "name": "f16_63x64_dst128", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "dst_shape": (63, 128), + "dst_valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "f32_7x448_dst512", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "dst_shape": (7, 512), + "dst_valid_shape": (7, 448), + "eps": 1e-3, + }, + { + "name": "f32_256x16_dst32", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "dst_shape": (256, 32), + "dst_valid_shape": (256, 16), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/compare.py new file mode 100644 index 000000000..6af6f6d5c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/gen_data.py new file mode 100644 index 000000000..22b2b3314 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/gen_data.py @@ -0,0 +1,45 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import struct +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + dst_shape = case["dst_shape"] + valid_shape = case["valid_shape"] + + rows, cols = shape + dst_rows, dst_cols = dst_shape + vr, vc = valid_shape + + input_arr = np.random.uniform(low=-8, high=8, size=(rows, cols)).astype(dtype) + slope = np.random.uniform(low=-8, high=8, size=(1, 1)).astype(np.float32) + golden = np.zeros((dst_rows, dst_cols), dtype=dtype) + + for i in range(vr): + for j in range(vc): + if input_arr[i, j] > 0: + golden[i, j] = input_arr[i, j] + else: + golden[i, j] = dtype(input_arr[i, j] * slope[0, 0]) + + slope_arr = np.array([slope[0, 0]], dtype=np.float32) + + save_case_data(case["name"], {"input": input_arr, "slope": slope_arr, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} dst_shape={dst_shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/launch.cpp new file mode 100644 index 000000000..356bd5115 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 32x64 -> dst 32x128 (valid 32x64) +extern "C" __global__ AICORE void TLRELU_f32_32x64_dst128(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_32x64_dst128(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_32x64_dst128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} + +// Case 1: f16 63x64 -> dst 63x128 (valid 63x64) +extern "C" __global__ AICORE void TLRELU_f16_63x64_dst128(__gm__ uint16_t *src, __gm__ uint16_t *dst, float slope); + +void LaunchTLRELU_f16_63x64_dst128(uint16_t *src, uint16_t *dst, float slope, void *stream) { + TLRELU_f16_63x64_dst128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, slope); +} + +// Case 2: f32 7x448 -> dst 7x512 (valid 7x448) +extern "C" __global__ AICORE void TLRELU_f32_7x448_dst512(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_7x448_dst512(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_7x448_dst512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} + +// Case 3: f32 256x16 -> dst 256x32 (valid 256x16) +extern "C" __global__ AICORE void TLRELU_f32_256x16_dst32(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_256x16_dst32(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_256x16_dst32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/main.cpp new file mode 100644 index 000000000..3b75edf69 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/main.cpp @@ -0,0 +1,154 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tlrelu ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTLRELU_f32_32x64_dst128(float *src, float *dst, float slope, void *stream); +void LaunchTLRELU_f16_63x64_dst128(uint16_t *src, uint16_t *dst, float slope, void *stream); +void LaunchTLRELU_f32_7x448_dst512(float *src, float *dst, float slope, void *stream); +void LaunchTLRELU_f32_256x16_dst32(float *src, float *dst, float slope, void *stream); + +using LaunchFn = void (*)(void *, void *, float, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; // src tile rows + size_t srcCols; // src tile cols + size_t dstRows; // dst tile rows (may have padding) + size_t dstCols; // dst tile cols (may have padding) + size_t validRows; // effective computation rows (<= srcRows, dstRows) + size_t validCols; // effective computation cols (<= srcCols, dstCols) + size_t elemSize; // bytes per element + bool isFp16; // true for float16 case +}; + +static const TestCase kCases[] = { + {"f32_32x64_dst128", (LaunchFn)LaunchTLRELU_f32_32x64_dst128, 32, 64, 32, 128, 32, 64, sizeof(float), false}, + {"f16_63x64_dst128", (LaunchFn)LaunchTLRELU_f16_63x64_dst128, 63, 64, 63, 128, 63, 64, sizeof(uint16_t), true}, + {"f32_7x448_dst512", (LaunchFn)LaunchTLRELU_f32_7x448_dst512, 7, 448, 7, 512, 7, 448, sizeof(float), false}, + {"f32_256x16_dst32", (LaunchFn)LaunchTLRELU_f32_256x16_dst32, 256, 16, 256, 32, 256, 16, sizeof(float), false}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcFileSize = tc.srcRows * tc.srcCols * tc.elemSize; + size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + size_t actualSize = 0; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + float slope = 0.0f; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile(caseDir + "/input.bin", actualSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + // Read slope (4 bytes float) + if (rc == 0) { + std::ifstream slopeFile(caseDir + "/slope.bin", std::ios::binary); + if (!slopeFile) { + std::fprintf(stderr, "[ERROR] failed to open %s/slope.bin\n", caseDir.c_str()); + rc = 1; + } else { + slopeFile.read(reinterpret_cast(&slope), sizeof(float)); + slopeFile.close(); + } + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, slope, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tlrelu [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/tlrelu.pto b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/tlrelu.pto new file mode 100644 index 000000000..c4131bc67 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/tlrelu.pto @@ -0,0 +1,238 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tlrelu: tload(src) + tlrelu(src, slope)->dst + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 src 32x64 -> dst 32x128 (valid 32x64) + func.func @TLRELU_f32_32x64_dst128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + %c4096 = arith.constant 4096 : index + + // Src GM view: 1x1x1x32x64 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + // Dst GM view: shape=valid_shape (32x64), strides based on dst allocation (32x128) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + // Dst partition: sizes = valid_shape (32x64) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + // Src UB tile: 32x64, valid 32x64 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 32x64, valid 32x64 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 src 63x64 -> dst 63x128 (valid 63x64) + func.func @TLRELU_f16_63x64_dst128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c4032 = arith.constant 4032 : index + %c8064 = arith.constant 8064 : index + + // Src GM view: 1x1x1x63x64 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + // Dst GM view: shape=valid_shape (63x64), strides based on dst allocation (63x128) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c8064, %c8064, %c8064, %c128, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + // Dst partition: sizes = valid_shape (63x64) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + // Src UB tile: 63x64, valid 63x64 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 63x64, valid 63x64 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: f32 src 7x448 -> dst 7x512 (valid 7x448) + func.func @TLRELU_f32_7x448_dst512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c512 = arith.constant 512 : index + %c3136 = arith.constant 3136 : index + %c3584 = arith.constant 3584 : index + + // Src GM view: 1x1x1x7x448 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + // Dst GM view: shape=valid_shape (7x448), strides based on dst allocation (7x512) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3584, %c3584, %c3584, %c512, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + // Dst partition: sizes = valid_shape (7x448) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + // Src UB tile: 7x448, valid 7x448 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 7x448, valid 7x448 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 3: f32 src 256x16 -> dst 256x32 (valid 256x16) + func.func @TLRELU_f32_256x16_dst32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + + // Src GM view: 1x1x1x256x16 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + // Dst GM view: shape=valid_shape (256x16), strides based on dst allocation (256x32) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c8192, %c8192, %c8192, %c32, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + // Dst partition: sizes = valid_shape (256x16) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + // Src UB tile: 256x16, valid 256x16 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 256x16, valid 256x16 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trelu/CMakeLists.txt new file mode 100644 index 000000000..5b01f92c5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trelu) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trelu/cases.py new file mode 100644 index 000000000..85823525c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/cases.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trelu ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — global data dimensions (input/output size). + - tile_shape: (tile_rows, tile_cols) — allocated tile buffer dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "int32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "tile_shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "f16_64x64_valid_60x60", + "dtype": np.float16, + "shape": (60, 60), + "tile_shape": (64, 64), + "valid_shape": (60, 60), + "eps": 1e-3, + }, + { + "name": "f32_64x64_valid_60x60", + "dtype": np.float32, + "shape": (60, 60), + "tile_shape": (64, 64), + "valid_shape": (60, 60), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trelu/compare.py new file mode 100644 index 000000000..4409f6261 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/compare.py @@ -0,0 +1,48 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trelu/gen_data.py new file mode 100644 index 000000000..8911a7cb4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/gen_data.py @@ -0,0 +1,32 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + + if dtype == np.int32: + input1 = np.random.randint(-3_000_000, 3_000_000, size=shape).astype(dtype) + else: + input1 = np.random.uniform(-10, 10, size=shape).astype(dtype) + + golden = np.maximum(input1, 0) + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={case['valid_shape']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trelu/launch.cpp new file mode 100644 index 000000000..94e256ff4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: int32 64x64 +extern "C" __global__ AICORE void TRELU_int32_64x64(__gm__ int32_t *input, __gm__ int32_t *output); + +void LaunchTRELU_int32_64x64(int32_t *input, int32_t *output, void *stream) { + TRELU_int32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)input, (__gm__ int32_t *)output); +} + +// Case 1: f16 64x64 valid 60x60 +extern "C" __global__ AICORE void TRELU_f16_64x64_v60x60(__gm__ uint16_t *input, __gm__ uint16_t *output); + +void LaunchTRELU_f16_64x64_v60x60(uint16_t *input, uint16_t *output, void *stream) { + TRELU_f16_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ uint16_t *)input, (__gm__ uint16_t *)output); +} + +// Case 2: f32 64x64 valid 60x60 +extern "C" __global__ AICORE void TRELU_f32_64x64_v60x60(__gm__ float *input, __gm__ float *output); + +void LaunchTRELU_f32_64x64_v60x60(float *input, float *output, void *stream) { + TRELU_f32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ float *)input, (__gm__ float *)output); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trelu/main.cpp new file mode 100644 index 000000000..b19a7fd95 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trelu ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTRELU_int32_64x64(int32_t *input, int32_t *output, void *stream); +void LaunchTRELU_f16_64x64_v60x60(uint16_t *input, uint16_t *output, void *stream); +void LaunchTRELU_f32_64x64_v60x60(float *input, float *output, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); + size_t rows; + size_t cols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"int32_64x64", (void (*)(void*, void*, void*))LaunchTRELU_int32_64x64, 64, 64, sizeof(int32_t)}, + {"f16_64x64_valid_60x60", (void (*)(void*, void*, void*))LaunchTRELU_f16_64x64_v60x60, 60, 60, sizeof(uint16_t)}, + {"f32_64x64_valid_60x60", (void (*)(void*, void*, void*))LaunchTRELU_f32_64x64_v60x60, 60, 60, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols); + + std::string caseDir = std::string("./") + tc.name; + + void *inputHost = nullptr, *outputHost = nullptr; + void *inputDevice = nullptr, *outputDevice = nullptr; + + aclrtMallocHost(&inputHost, fileSize); + aclrtMallocHost(&outputHost, fileSize); + + aclrtMalloc(&inputDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&outputDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), fileSize, inputHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(inputDevice, fileSize, inputHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(inputDevice, outputDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(outputHost, fileSize, outputDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), outputHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (inputDevice != nullptr) + aclrtFree(inputDevice); + if (outputDevice != nullptr) + aclrtFree(outputDevice); + if (inputHost != nullptr) + aclrtFreeHost(inputHost); + if (outputHost != nullptr) + aclrtFreeHost(outputHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/trelu.pto b/test/tilelang_st/npu/a5/src/st/testcase/trelu/trelu.pto new file mode 100644 index 000000000..152e0d3f5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/trelu.pto @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trelu: tload(input) + trelu(input)->output + tstore(output). +// Multiple cases with different shapes and dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: int32 64x64 (4096 elements, valid=64x64) + func.func @TRELU_int32_64x64(%input_ptr: !pto.ptr, %output_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %input_view = pto.make_tensor_view %input_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %output_view = pto.make_tensor_view %output_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %input_part = pto.partition_view %input_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %output_part = pto.partition_view %output_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%input_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%output_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 1: f16 64x64 (4096 elements, valid=60x60) + func.func @TRELU_f16_64x64_v60x60(%input_ptr: !pto.ptr, %output_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %input_view = pto.make_tensor_view %input_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c60, %c1] + : !pto.tensor_view<1x1x1x60x60xf16> + %output_view = pto.make_tensor_view %output_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c60, %c1] + : !pto.tensor_view<1x1x1x60x60xf16> + + %input_part = pto.partition_view %input_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf16> -> !pto.partition_tensor_view<1x1x1x60x60xf16> + %output_part = pto.partition_view %output_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf16> -> !pto.partition_tensor_view<1x1x1x60x60xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%input_part : !pto.partition_tensor_view<1x1x1x60x60xf16>) + outs(%src : !pto.tile_buf) + + pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%output_part : !pto.partition_tensor_view<1x1x1x60x60xf16>) + return + } + + // Case 2: f32 64x64 (4096 elements, valid=60x60) + func.func @TRELU_f32_64x64_v60x60(%input_ptr: !pto.ptr, %output_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %input_view = pto.make_tensor_view %input_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c60, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + %output_view = pto.make_tensor_view %output_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c60, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + + %input_part = pto.partition_view %input_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + %output_part = pto.partition_view %output_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%input_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + outs(%src : !pto.tile_buf) + + pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%output_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsel/CMakeLists.txt new file mode 100644 index 000000000..73a77806b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsel) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsel/cases.py new file mode 100644 index 000000000..3432ab7e9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/cases.py @@ -0,0 +1,97 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsel ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_2x128", + "dtype": np.float32, + "shape": (2, 128), + "valid_shape": (2, 128), + "eps": 1e-6, + }, + { + "name": "f32_2x32", + "dtype": np.float32, + "shape": (2, 32), + "valid_shape": (2, 32), + "eps": 1e-6, + }, + { + "name": "f32_2x160", + "dtype": np.float32, + "shape": (2, 160), + "valid_shape": (2, 160), + "eps": 1e-6, + }, + { + "name": "f32_2x512", + "dtype": np.float32, + "shape": (2, 512), + "valid_shape": (2, 512), + "eps": 1e-6, + }, + { + "name": "f16_2x128", + "dtype": np.float16, + "shape": (2, 128), + "valid_shape": (2, 128), + "eps": 1e-3, + }, + { + "name": "f16_2x32", + "dtype": np.float16, + "shape": (2, 32), + "valid_shape": (2, 32), + "eps": 1e-3, + }, + { + "name": "f16_2x160", + "dtype": np.float16, + "shape": (2, 160), + "valid_shape": (2, 160), + "eps": 1e-3, + }, + { + "name": "i8_2x128", + "dtype": np.int8, + "shape": (2, 128), + "valid_shape": (2, 128), + "eps": 0, + }, + { + "name": "i8_2x32", + "dtype": np.int8, + "shape": (2, 32), + "valid_shape": (2, 32), + "eps": 0, + }, + { + "name": "i8_2x160", + "dtype": np.int8, + "shape": (2, 160), + "valid_shape": (2, 160), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsel/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsel/gen_data.py new file mode 100644 index 000000000..7308ac94d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/gen_data.py @@ -0,0 +1,44 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + vr, vc = valid_shape + mask_cols = (vc + 7) // 8 + + src0 = np.random.randint(1, 10, size=shape).astype(dtype) + src1 = np.random.randint(1, 10, size=shape).astype(dtype) + mask = np.random.randint(0, 256, size=(vr, mask_cols), dtype=np.uint8) + + golden = np.zeros(shape, dtype=dtype) + src0_valid = src0[:vr, :vc] + src1_valid = src1[:vr, :vc] + for row in range(vr): + for packed_col in range(mask_cols): + byte = int(mask[row, packed_col]) + for bit in range(8): + col = packed_col * 8 + bit + if col >= vc: + break + golden[row, col] = src0_valid[row, col] if ((byte >> bit) & 1) else src1_valid[row, col] + + save_case_data(case["name"], {"input1": src0, "input2": src1, "input3": mask, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsel/launch.cpp new file mode 100644 index 000000000..dbd0edbe1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/launch.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 2x128 +extern "C" __global__ AICORE void TSEL_f32_2x128(__gm__ uint8_t *mask, __gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTSEL_f32_2x128(uint8_t *mask, float *src0, float *src1, float *dst, void *stream) { + TSEL_f32_2x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 1: f32 2x32 +extern "C" __global__ AICORE void TSEL_f32_2x32(__gm__ uint8_t *mask, __gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTSEL_f32_2x32(uint8_t *mask, float *src0, float *src1, float *dst, void *stream) { + TSEL_f32_2x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: f32 2x160 +extern "C" __global__ AICORE void TSEL_f32_2x160(__gm__ uint8_t *mask, __gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTSEL_f32_2x160(uint8_t *mask, float *src0, float *src1, float *dst, void *stream) { + TSEL_f32_2x160<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: f32 2x512 +extern "C" __global__ AICORE void TSEL_f32_2x512(__gm__ uint8_t *mask, __gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTSEL_f32_2x512(uint8_t *mask, float *src0, float *src1, float *dst, void *stream) { + TSEL_f32_2x512<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 4: f16 2x128 +extern "C" __global__ AICORE void TSEL_f16_2x128(__gm__ uint8_t *mask, __gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTSEL_f16_2x128(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TSEL_f16_2x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 5: f16 2x32 +extern "C" __global__ AICORE void TSEL_f16_2x32(__gm__ uint8_t *mask, __gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTSEL_f16_2x32(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TSEL_f16_2x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 6: f16 2x160 +extern "C" __global__ AICORE void TSEL_f16_2x160(__gm__ uint8_t *mask, __gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTSEL_f16_2x160(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TSEL_f16_2x160<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 7: i8 2x128 +extern "C" __global__ AICORE void TSEL_i8_2x128(__gm__ uint8_t *mask, __gm__ int8_t *src0, __gm__ int8_t *src1, __gm__ int8_t *dst); + +void LaunchTSEL_i8_2x128(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream) { + TSEL_i8_2x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ int8_t *)src0, (__gm__ int8_t *)src1, (__gm__ int8_t *)dst); +} + +// Case 8: i8 2x32 +extern "C" __global__ AICORE void TSEL_i8_2x32(__gm__ uint8_t *mask, __gm__ int8_t *src0, __gm__ int8_t *src1, __gm__ int8_t *dst); + +void LaunchTSEL_i8_2x32(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream) { + TSEL_i8_2x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ int8_t *)src0, (__gm__ int8_t *)src1, (__gm__ int8_t *)dst); +} + +// Case 9: i8 2x160 +extern "C" __global__ AICORE void TSEL_i8_2x160(__gm__ uint8_t *mask, __gm__ int8_t *src0, __gm__ int8_t *src1, __gm__ int8_t *dst); + +void LaunchTSEL_i8_2x160(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream) { + TSEL_i8_2x160<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ int8_t *)src0, (__gm__ int8_t *)src1, (__gm__ int8_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsel/main.cpp new file mode 100644 index 000000000..4bf41b7bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/main.cpp @@ -0,0 +1,312 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsel ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSEL_f32_2x128(uint8_t *mask, float *src0, float *src1, float *dst, void *stream); +void LaunchTSEL_f32_2x32(uint8_t *mask, float *src0, float *src1, float *dst, void *stream); +void LaunchTSEL_f32_2x160(uint8_t *mask, float *src0, float *src1, float *dst, void *stream); +void LaunchTSEL_f32_2x512(uint8_t *mask, float *src0, float *src1, float *dst, void *stream); +void LaunchTSEL_f16_2x128(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTSEL_f16_2x32(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTSEL_f16_2x160(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTSEL_i8_2x128(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream); +void LaunchTSEL_i8_2x32(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream); +void LaunchTSEL_i8_2x160(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream); + +enum DataType { DT_F32, DT_F16, DT_I8 }; + +using LaunchFnF32 = void (*)(uint8_t *, float *, float *, float *, void *); +using LaunchFnF16 = void (*)(uint8_t *, uint16_t *, uint16_t *, uint16_t *, void *); +using LaunchFnI8 = void (*)(uint8_t *, int8_t *, int8_t *, int8_t *, void *); + +struct TestCase { + const char *name; + DataType dtype; + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI8 launchI8; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"f32_2x128", DT_F32, LaunchTSEL_f32_2x128, nullptr, nullptr, 2, 128, 2, 128, sizeof(float)}, + {"f32_2x32", DT_F32, LaunchTSEL_f32_2x32, nullptr, nullptr, 2, 32, 2, 32, sizeof(float)}, + {"f32_2x160", DT_F32, LaunchTSEL_f32_2x160, nullptr, nullptr, 2, 160, 2, 160, sizeof(float)}, + {"f32_2x512", DT_F32, LaunchTSEL_f32_2x512, nullptr, nullptr, 2, 512, 2, 512, sizeof(float)}, + {"f16_2x128", DT_F16, nullptr, LaunchTSEL_f16_2x128, nullptr, 2, 128, 2, 128, sizeof(uint16_t)}, + {"f16_2x32", DT_F16, nullptr, LaunchTSEL_f16_2x32, nullptr, 2, 32, 2, 32, sizeof(uint16_t)}, + {"f16_2x160", DT_F16, nullptr, LaunchTSEL_f16_2x160, nullptr, 2, 160, 2, 160, sizeof(uint16_t)}, + {"i8_2x128", DT_I8, nullptr, nullptr, LaunchTSEL_i8_2x128, 2, 128, 2, 128, sizeof(int8_t)}, + {"i8_2x32", DT_I8, nullptr, nullptr, LaunchTSEL_i8_2x32, 2, 32, 2, 32, sizeof(int8_t)}, + {"i8_2x160", DT_I8, nullptr, nullptr, LaunchTSEL_i8_2x160, 2, 160, 2, 160, sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSizeConst = elemCount * tc.elemSize; + const size_t maskCols = (tc.validCols + 7) / 8; + const size_t maskFileSizeConst = tc.validRows * maskCols * sizeof(uint8_t); + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + if (tc.dtype == DT_F32) { + uint8_t *maskHost = nullptr; + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + uint8_t *maskDevice = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&maskHost), maskFileSizeConst); + aclrtMallocHost((void **)(&src0Host), fileSizeConst); + aclrtMallocHost((void **)(&src1Host), fileSizeConst); + aclrtMallocHost((void **)(&dstHost), fileSizeConst); + + aclrtMalloc((void **)&maskDevice, maskFileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = fileSizeConst; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = fileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + size_t maskFileSize = maskFileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input3.bin").c_str(), maskFileSize, maskHost, maskFileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(maskDevice, maskFileSizeConst, maskHost, maskFileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, fileSizeConst, src0Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSizeConst, src1Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launchF32(maskDevice, src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSizeConst, dstDevice, fileSizeConst, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (maskDevice != nullptr) + aclrtFree(maskDevice); + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (maskHost != nullptr) + aclrtFreeHost(maskHost); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + } else if (tc.dtype == DT_F16) { + uint8_t *maskHost = nullptr; + uint16_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + uint8_t *maskDevice = nullptr; + uint16_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&maskHost), maskFileSizeConst); + aclrtMallocHost((void **)(&src0Host), fileSizeConst); + aclrtMallocHost((void **)(&src1Host), fileSizeConst); + aclrtMallocHost((void **)(&dstHost), fileSizeConst); + + aclrtMalloc((void **)&maskDevice, maskFileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = fileSizeConst; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = fileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + size_t maskFileSize = maskFileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input3.bin").c_str(), maskFileSize, maskHost, maskFileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(maskDevice, maskFileSizeConst, maskHost, maskFileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, fileSizeConst, src0Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSizeConst, src1Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launchF16(maskDevice, src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSizeConst, dstDevice, fileSizeConst, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (maskDevice != nullptr) + aclrtFree(maskDevice); + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (maskHost != nullptr) + aclrtFreeHost(maskHost); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + } else { + uint8_t *maskHost = nullptr; + int8_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + uint8_t *maskDevice = nullptr; + int8_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&maskHost), maskFileSizeConst); + aclrtMallocHost((void **)(&src0Host), fileSizeConst); + aclrtMallocHost((void **)(&src1Host), fileSizeConst); + aclrtMallocHost((void **)(&dstHost), fileSizeConst); + + aclrtMalloc((void **)&maskDevice, maskFileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = fileSizeConst; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = fileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + size_t maskFileSize = maskFileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input3.bin").c_str(), maskFileSize, maskHost, maskFileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(maskDevice, maskFileSizeConst, maskHost, maskFileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, fileSizeConst, src0Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSizeConst, src1Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launchI8(maskDevice, src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSizeConst, dstDevice, fileSizeConst, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (maskDevice != nullptr) + aclrtFree(maskDevice); + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (maskHost != nullptr) + aclrtFreeHost(maskHost); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + } + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/tsel.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsel/tsel.pto new file mode 100644 index 000000000..479b5fdc6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/tsel.pto @@ -0,0 +1,884 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsel: packed mask tload + tload(src0) + tload(src1) + tsel(mask,src0,src1,tmp,dst) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case: f32 2x128 + func.func @TSEL_f32_2x128(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xi8> -> !pto.partition_tensor_view<1x1x1x2x16xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case: f32 2x32 + func.func @TSEL_f32_2x32(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c4], + strides = [%c8, %c8, %c8, %c4, %c1] + : !pto.tensor_view<1x1x1x2x4xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c4] + : !pto.tensor_view<1x1x1x2x4xi8> -> !pto.partition_tensor_view<1x1x1x2x4xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x32xf32> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x4xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xf32>) + return + } + + // Case: f32 2x160 + func.func @TSEL_f32_2x160(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c20 = arith.constant 20 : index + %c40 = arith.constant 40 : index + %c160 = arith.constant 160 : index + %c320 = arith.constant 320 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c20], + strides = [%c40, %c40, %c40, %c20, %c1] + : !pto.tensor_view<1x1x1x2x20xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf32> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c20] + : !pto.tensor_view<1x1x1x2x20xi8> -> !pto.partition_tensor_view<1x1x1x2x20xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf32> -> !pto.partition_tensor_view<1x1x1x2x160xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf32> -> !pto.partition_tensor_view<1x1x1x2x160xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf32> -> !pto.partition_tensor_view<1x1x1x2x160xf32> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x20xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x160xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x160xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x160xf32>) + return + } + + // Case: f32 2x512 + func.func @TSEL_f32_2x512(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c512], + strides = [%c1024, %c1024, %c1024, %c512, %c1] + : !pto.tensor_view<1x1x1x2x512xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c512], + strides = [%c1024, %c1024, %c1024, %c512, %c1] + : !pto.tensor_view<1x1x1x2x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c512], + strides = [%c1024, %c1024, %c1024, %c512, %c1] + : !pto.tensor_view<1x1x1x2x512xf32> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi8> -> !pto.partition_tensor_view<1x1x1x2x64xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c512] + : !pto.tensor_view<1x1x1x2x512xf32> -> !pto.partition_tensor_view<1x1x1x2x512xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c512] + : !pto.tensor_view<1x1x1x2x512xf32> -> !pto.partition_tensor_view<1x1x1x2x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c512] + : !pto.tensor_view<1x1x1x2x512xf32> -> !pto.partition_tensor_view<1x1x1x2x512xf32> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x64xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x512xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x512xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x512xf32>) + return + } + + // Case: f16 2x128 + func.func @TSEL_f16_2x128(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xi8> -> !pto.partition_tensor_view<1x1x1x2x16xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + return + } + + // Case: f16 2x32 + func.func @TSEL_f16_2x32(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c4], + strides = [%c8, %c8, %c8, %c4, %c1] + : !pto.tensor_view<1x1x1x2x4xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c4] + : !pto.tensor_view<1x1x1x2x4xi8> -> !pto.partition_tensor_view<1x1x1x2x4xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x4xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + return + } + + // Case: f16 2x160 + func.func @TSEL_f16_2x160(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c20 = arith.constant 20 : index + %c40 = arith.constant 40 : index + %c160 = arith.constant 160 : index + %c320 = arith.constant 320 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c20], + strides = [%c40, %c40, %c40, %c20, %c1] + : !pto.tensor_view<1x1x1x2x20xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf16> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c20] + : !pto.tensor_view<1x1x1x2x20xi8> -> !pto.partition_tensor_view<1x1x1x2x20xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf16> -> !pto.partition_tensor_view<1x1x1x2x160xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf16> -> !pto.partition_tensor_view<1x1x1x2x160xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf16> -> !pto.partition_tensor_view<1x1x1x2x160xf16> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x20xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x160xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x160xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x160xf16>) + return + } + + // Case: i8 2x128 + func.func @TSEL_i8_2x128(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi8> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi8> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xi8> -> !pto.partition_tensor_view<1x1x1x2x16xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi8> -> !pto.partition_tensor_view<1x1x1x2x128xi8> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi8> -> !pto.partition_tensor_view<1x1x1x2x128xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi8> -> !pto.partition_tensor_view<1x1x1x2x128xi8> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x128xi8>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x128xi8>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi8>) + return + } + + // Case: i8 2x32 + func.func @TSEL_i8_2x32(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c4], + strides = [%c8, %c8, %c8, %c4, %c1] + : !pto.tensor_view<1x1x1x2x4xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xi8> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c4] + : !pto.tensor_view<1x1x1x2x4xi8> -> !pto.partition_tensor_view<1x1x1x2x4xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x4xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + return + } + + // Case: i8 2x160 + func.func @TSEL_i8_2x160(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c20 = arith.constant 20 : index + %c40 = arith.constant 40 : index + %c160 = arith.constant 160 : index + %c320 = arith.constant 320 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c20], + strides = [%c40, %c40, %c40, %c20, %c1] + : !pto.tensor_view<1x1x1x2x20xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xi8> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xi8> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c20] + : !pto.tensor_view<1x1x1x2x20xi8> -> !pto.partition_tensor_view<1x1x1x2x20xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xi8> -> !pto.partition_tensor_view<1x1x1x2x160xi8> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xi8> -> !pto.partition_tensor_view<1x1x1x2x160xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xi8> -> !pto.partition_tensor_view<1x1x1x2x160xi8> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x20xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x160xi8>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x160xi8>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x160xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsels/CMakeLists.txt new file mode 100644 index 000000000..d699a3c35 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsels) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsels/cases.py new file mode 100644 index 000000000..496b37c96 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/cases.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsels ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype for data (src/dst) + - dtype_mask: numpy dtype for mask + - dst_shape: (dst_rows, dst_cols) — allocated dst tile dimensions + - mask_shape: (mask_rows, mask_cols) — allocated mask tile dimensions + - src_shape: (src_rows, src_cols) — allocated src tile dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) +""" + +import numpy as np + +CASES = [ + {"name": "uint8_uint8_2x32_2x32_2x32_2x32", "dtype": np.uint8, "dtype_mask": np.uint8, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 32), "mask_shape": (2, 32), "src_shape": (2, 32), "valid_shape": (2, 32), "eps": 0}, + {"name": "uint8_uint16_2x32_2x16_2x32_2x32", "dtype": np.uint8, "dtype_mask": np.uint16, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 32), "mask_shape": (2, 16), "src_shape": (2, 32), "valid_shape": (2, 32), "eps": 0}, + {"name": "uint8_uint32_2x32_2x8_2x32_2x32", "dtype": np.uint8, "dtype_mask": np.uint32, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 32), "mask_shape": (2, 8), "src_shape": (2, 32), "valid_shape": (2, 32), "eps": 0}, + {"name": "uint16_uint8_2x16_2x32_2x16_2x16", "dtype": np.uint16, "dtype_mask": np.uint8, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 32), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 0}, + {"name": "uint16_uint16_2x16_2x16_2x16_2x16", "dtype": np.uint16, "dtype_mask": np.uint16, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 16), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 0}, + {"name": "uint16_uint32_2x16_2x8_2x16_2x16", "dtype": np.uint16, "dtype_mask": np.uint32, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 8), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 0}, + {"name": "uint32_uint8_2x8_2x32_2x8_2x8", "dtype": np.uint32, "dtype_mask": np.uint8, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 32), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 0}, + {"name": "uint32_uint16_2x8_2x16_2x8_2x8", "dtype": np.uint32, "dtype_mask": np.uint16, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 16), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 0}, + {"name": "uint32_uint32_2x8_2x8_2x8_2x8", "dtype": np.uint32, "dtype_mask": np.uint32, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 8), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 0}, + {"name": "f16_uint8_2x16_2x32_2x16_2x16", "dtype": np.float16, "dtype_mask": np.uint8, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 32), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 1e-3}, + {"name": "f16_uint16_2x16_2x16_2x16_2x16", "dtype": np.float16, "dtype_mask": np.uint16, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 16), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 1e-3}, + {"name": "f16_uint32_2x16_2x8_2x16_2x16", "dtype": np.float16, "dtype_mask": np.uint32, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 8), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 1e-3}, + {"name": "f32_uint8_2x8_2x32_2x8_2x8", "dtype": np.float32, "dtype_mask": np.uint8, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 32), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 1e-6}, + {"name": "f32_uint16_2x8_2x16_2x8_2x8", "dtype": np.float32, "dtype_mask": np.uint16, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 16), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 1e-6}, + {"name": "f32_uint32_2x8_2x8_2x8_2x8", "dtype": np.float32, "dtype_mask": np.uint32, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 8), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 1e-6}, + {"name": "uint8_uint8_2x32_2x64_2x128_2x31", "dtype": np.uint8, "dtype_mask": np.uint8, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 31), "mask_shape": (2, 64), "src_shape": (2, 128), "valid_shape": (2, 31), "eps": 0}, + {"name": "uint16_uint8_2x32_2x64_2x128_2x31", "dtype": np.uint16, "dtype_mask": np.uint8, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 31), "mask_shape": (2, 64), "src_shape": (2, 128), "valid_shape": (2, 31), "eps": 0}, + {"name": "f32_uint8_2x32_2x64_2x128_2x31", "dtype": np.float32, "dtype_mask": np.uint8, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 31), "mask_shape": (2, 64), "src_shape": (2, 128), "valid_shape": (2, 31), "eps": 1e-6}, + {"name": "uint8_uint8_32x672_32x96_32x672_32x666", "dtype": np.uint8, "dtype_mask": np.uint8, "shape": (32, 672), "dst_shape": (32, 672), "dst_valid_shape": (32, 666), "mask_shape": (32, 96), "src_shape": (32, 672), "valid_shape": (32, 666), "eps": 0}, + {"name": "f16_uint8_32x672_32x96_32x672_32x666", "dtype": np.float16, "dtype_mask": np.uint8, "shape": (32, 672), "dst_shape": (32, 672), "dst_valid_shape": (32, 666), "mask_shape": (32, 96), "src_shape": (32, 672), "valid_shape": (32, 666), "eps": 1e-3}, + {"name": "f32_uint8_32x672_32x96_32x672_32x666", "dtype": np.float32, "dtype_mask": np.uint8, "shape": (32, 672), "dst_shape": (32, 672), "dst_valid_shape": (32, 666), "mask_shape": (32, 96), "src_shape": (32, 672), "valid_shape": (32, 666), "eps": 1e-6}, + {"name": "f32_uint8_1x8192_1x4096_1x8192_1x8192", "dtype": np.float32, "dtype_mask": np.uint8, "shape": (1, 8192), "dst_shape": (1, 8192), "dst_valid_shape": (1, 8192), "mask_shape": (1, 4096), "src_shape": (1, 8192), "valid_shape": (1, 8192), "eps": 1e-6}, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsels/compare.py new file mode 100644 index 000000000..6af6f6d5c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsels/gen_data.py new file mode 100644 index 000000000..b462425c6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/gen_data.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dtype_mask = case["dtype_mask"] + dst_shape = case["dst_shape"] + mask_shape = case["mask_shape"] + src_shape = case["src_shape"] + valid_shape = case["valid_shape"] + height, width = valid_shape + + if dtype in (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32): + dtype_info = np.iinfo(dtype) + input1 = np.random.randint(dtype_info.min, dtype_info.max, size=src_shape).astype(dtype) + input2 = np.random.randint(dtype_info.min, dtype_info.max, size=[1]).astype(dtype) + else: + dtype_info = np.finfo(dtype) + input1 = np.random.uniform(low=dtype_info.min, high=dtype_info.max, size=src_shape).astype(dtype) + input2 = np.random.uniform(low=dtype_info.min, high=dtype_info.max, size=[1]).astype(dtype) + + mask_dtype_info = np.iinfo(dtype_mask) + mask = np.random.randint(mask_dtype_info.min, mask_dtype_info.max, size=mask_shape).astype(dtype_mask) + mask_u8view = mask.view(np.uint8).reshape(mask_shape[0], -1) + golden = np.zeros(dst_shape, dtype=dtype) + + for y in range(height): + for x in range(width): + do_select = (1 << (x & 7)) & mask_u8view[y, x >> 3] + golden[y, x] = input1[y, x] if do_select != 0 else input2[0] + + save_case_data(case["name"], {"mask": mask, "input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dst={dst_shape} mask={mask_shape} src={src_shape} valid={valid_shape} dtype={dtype.__name__} mask_dtype={dtype_mask.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsels/launch.cpp new file mode 100644 index 000000000..30372dbb8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/launch.cpp @@ -0,0 +1,168 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TSELS_uint8_uint8_2x32_2x32_2x32_2x32(__gm__ uint8_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint8_2x32_2x32_2x32_2x32(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint8_2x32_2x32_2x32_2x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint8_uint16_2x32_2x16_2x32_2x32(__gm__ uint16_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint16_2x32_2x16_2x32_2x32(uint16_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint16_2x32_2x16_2x32_2x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint8_uint32_2x32_2x8_2x32_2x32(__gm__ uint32_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint32_2x32_2x8_2x32_2x32(uint32_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint32_2x32_2x8_2x32_2x32<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint16_uint8_2x16_2x32_2x16_2x16(__gm__ uint8_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_uint16_uint8_2x16_2x32_2x16_2x16(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_uint16_uint8_2x16_2x32_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint16_uint16_2x16_2x16_2x16_2x16(__gm__ uint16_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_uint16_uint16_2x16_2x16_2x16_2x16(uint16_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_uint16_uint16_2x16_2x16_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint16_uint32_2x16_2x8_2x16_2x16(__gm__ uint32_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_uint16_uint32_2x16_2x8_2x16_2x16(uint32_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_uint16_uint32_2x16_2x8_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint32_uint8_2x8_2x32_2x8_2x8(__gm__ uint8_t *mask, __gm__ uint32_t *src, __gm__ uint32_t *dst, uint32_t scalar); +void LaunchTSELS_uint32_uint8_2x8_2x32_2x8_2x8(uint8_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream) { + uint32_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint32_t)); + TSELS_uint32_uint8_2x8_2x32_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint32_t *)src, (__gm__ uint32_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint32_uint16_2x8_2x16_2x8_2x8(__gm__ uint16_t *mask, __gm__ uint32_t *src, __gm__ uint32_t *dst, uint32_t scalar); +void LaunchTSELS_uint32_uint16_2x8_2x16_2x8_2x8(uint16_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream) { + uint32_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint32_t)); + TSELS_uint32_uint16_2x8_2x16_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ uint32_t *)src, (__gm__ uint32_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint32_uint32_2x8_2x8_2x8_2x8(__gm__ uint32_t *mask, __gm__ uint32_t *src, __gm__ uint32_t *dst, uint32_t scalar); +void LaunchTSELS_uint32_uint32_2x8_2x8_2x8_2x8(uint32_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream) { + uint32_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint32_t)); + TSELS_uint32_uint32_2x8_2x8_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ uint32_t *)src, (__gm__ uint32_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f16_uint8_2x16_2x32_2x16_2x16(__gm__ uint8_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_f16_uint8_2x16_2x32_2x16_2x16(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_f16_uint8_2x16_2x32_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f16_uint16_2x16_2x16_2x16_2x16(__gm__ uint16_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_f16_uint16_2x16_2x16_2x16_2x16(uint16_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_f16_uint16_2x16_2x16_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f16_uint32_2x16_2x8_2x16_2x16(__gm__ uint32_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_f16_uint32_2x16_2x8_2x16_2x16(uint32_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_f16_uint32_2x16_2x8_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint8_2x8_2x32_2x8_2x8(__gm__ uint8_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint8_2x8_2x32_2x8_2x8(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint8_2x8_2x32_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint16_2x8_2x16_2x8_2x8(__gm__ uint16_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint16_2x8_2x16_2x8_2x8(uint16_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint16_2x8_2x16_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint32_2x8_2x8_2x8_2x8(__gm__ uint32_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint32_2x8_2x8_2x8_2x8(uint32_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint32_2x8_2x8_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint8_uint8_2x32_2x64_2x128_2x31(__gm__ uint8_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint8_2x32_2x64_2x128_2x31<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint16_uint8_2x32_2x64_2x128_2x31(__gm__ uint8_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_uint16_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_uint16_uint8_2x32_2x64_2x128_2x31<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint8_2x32_2x64_2x128_2x31(__gm__ uint8_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint8_2x32_2x64_2x128_2x31<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint8_uint8_32x672_32x96_32x672_32x666(__gm__ uint8_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint8_32x672_32x96_32x672_32x666<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f16_uint8_32x672_32x96_32x672_32x666(__gm__ uint8_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_f16_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_f16_uint8_32x672_32x96_32x672_32x666<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint8_32x672_32x96_32x672_32x666(__gm__ uint8_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint8_32x672_32x96_32x672_32x666<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(__gm__ uint8_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsels/main.cpp new file mode 100644 index 000000000..351e822cc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/main.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTSELS_uint8_uint8_2x32_2x32_2x32_2x32(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint8_uint16_2x32_2x16_2x32_2x32(uint16_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint8_uint32_2x32_2x8_2x32_2x32(uint32_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint16_uint8_2x16_2x32_2x16_2x16(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint16_uint16_2x16_2x16_2x16_2x16(uint16_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint16_uint32_2x16_2x8_2x16_2x16(uint32_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint32_uint8_2x8_2x32_2x8_2x8(uint8_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint32_uint16_2x8_2x16_2x8_2x8(uint16_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint32_uint32_2x8_2x8_2x8_2x8(uint32_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f16_uint8_2x16_2x32_2x16_2x16(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f16_uint16_2x16_2x16_2x16_2x16(uint16_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f16_uint32_2x16_2x8_2x16_2x16(uint32_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint8_2x8_2x32_2x8_2x8(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint16_2x8_2x16_2x8_2x8(uint16_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint32_2x8_2x8_2x8_2x8(uint32_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint8_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint16_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint8_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f16_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void*, void*, void*, void*, void*); + size_t dstRows, dstCols; + size_t maskRows, maskCols; + size_t srcRows, srcCols; + size_t validRows, validCols; + size_t dstElemSize; + size_t maskElemSize; + size_t srcElemSize; +}; + +static const TestCase kCases[] = { + {"uint8_uint8_2x32_2x32_2x32_2x32", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint8_2x32_2x32_2x32_2x32, 2, 32, 2, 32, 2, 32, 2, 32, 1, 1, 1}, + {"uint8_uint16_2x32_2x16_2x32_2x32", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint16_2x32_2x16_2x32_2x32, 2, 32, 2, 16, 2, 32, 2, 32, 1, 2, 1}, + {"uint8_uint32_2x32_2x8_2x32_2x32", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint32_2x32_2x8_2x32_2x32, 2, 32, 2, 8, 2, 32, 2, 32, 1, 4, 1}, + {"uint16_uint8_2x16_2x32_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint16_uint8_2x16_2x32_2x16_2x16, 2, 16, 2, 32, 2, 16, 2, 16, 2, 1, 2}, + {"uint16_uint16_2x16_2x16_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint16_uint16_2x16_2x16_2x16_2x16, 2, 16, 2, 16, 2, 16, 2, 16, 2, 2, 2}, + {"uint16_uint32_2x16_2x8_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint16_uint32_2x16_2x8_2x16_2x16, 2, 16, 2, 8, 2, 16, 2, 16, 2, 4, 2}, + {"uint32_uint8_2x8_2x32_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint32_uint8_2x8_2x32_2x8_2x8, 2, 8, 2, 32, 2, 8, 2, 8, 4, 1, 4}, + {"uint32_uint16_2x8_2x16_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint32_uint16_2x8_2x16_2x8_2x8, 2, 8, 2, 16, 2, 8, 2, 8, 4, 2, 4}, + {"uint32_uint32_2x8_2x8_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint32_uint32_2x8_2x8_2x8_2x8, 2, 8, 2, 8, 2, 8, 2, 8, 4, 4, 4}, + {"f16_uint8_2x16_2x32_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f16_uint8_2x16_2x32_2x16_2x16, 2, 16, 2, 32, 2, 16, 2, 16, 2, 1, 2}, + {"f16_uint16_2x16_2x16_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f16_uint16_2x16_2x16_2x16_2x16, 2, 16, 2, 16, 2, 16, 2, 16, 2, 2, 2}, + {"f16_uint32_2x16_2x8_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f16_uint32_2x16_2x8_2x16_2x16, 2, 16, 2, 8, 2, 16, 2, 16, 2, 4, 2}, + {"f32_uint8_2x8_2x32_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint8_2x8_2x32_2x8_2x8, 2, 8, 2, 32, 2, 8, 2, 8, 4, 1, 4}, + {"f32_uint16_2x8_2x16_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint16_2x8_2x16_2x8_2x8, 2, 8, 2, 16, 2, 8, 2, 8, 4, 2, 4}, + {"f32_uint32_2x8_2x8_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint32_2x8_2x8_2x8_2x8, 2, 8, 2, 8, 2, 8, 2, 8, 4, 4, 4}, + {"uint8_uint8_2x32_2x64_2x128_2x31", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint8_2x32_2x64_2x128_2x31, 2, 32, 2, 64, 2, 128, 2, 31, 1, 1, 1}, + {"uint16_uint8_2x32_2x64_2x128_2x31", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint16_uint8_2x32_2x64_2x128_2x31, 2, 32, 2, 64, 2, 128, 2, 31, 2, 1, 2}, + {"f32_uint8_2x32_2x64_2x128_2x31", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint8_2x32_2x64_2x128_2x31, 2, 32, 2, 64, 2, 128, 2, 31, 4, 1, 4}, + {"uint8_uint8_32x672_32x96_32x672_32x666", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint8_32x672_32x96_32x672_32x666, 32, 672, 32, 96, 32, 672, 32, 666, 1, 1, 1}, + {"f16_uint8_32x672_32x96_32x672_32x666", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f16_uint8_32x672_32x96_32x672_32x666, 32, 672, 32, 96, 32, 672, 32, 666, 2, 1, 2}, + {"f32_uint8_32x672_32x96_32x672_32x666", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint8_32x672_32x96_32x672_32x666, 32, 672, 32, 96, 32, 672, 32, 666, 4, 1, 4}, + {"f32_uint8_1x8192_1x4096_1x8192_1x8192", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192, 1, 8192, 1, 4096, 1, 8192, 1, 8192, 4, 1, 4}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t dstFileSize = tc.dstRows * tc.dstCols * tc.dstElemSize; + size_t maskFileSize = tc.maskRows * tc.maskCols * tc.maskElemSize; + size_t srcFileSize = tc.srcRows * tc.srcCols * tc.srcElemSize; + size_t scalarFileSize = tc.dstElemSize; + + std::printf("[INFO] === case: %s (dst=%zux%zu, mask=%zux%zu, src=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.dstRows, tc.dstCols, tc.maskRows, tc.maskCols, tc.srcRows, tc.srcCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + const size_t maskFileSizeBuf = maskFileSize; + const size_t srcFileSizeBuf = srcFileSize; + const size_t scalarFileSizeBuf = scalarFileSize; + + void *maskHost = nullptr, *srcHost = nullptr, *dstHost = nullptr, *scalarHost = nullptr; + void *maskDevice = nullptr, *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&maskHost, maskFileSize); + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + aclrtMallocHost(&scalarHost, scalarFileSize); + + aclrtMalloc(&maskDevice, maskFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + memset(dstHost, 0, dstFileSize); + + if (!ReadFile(caseDir + "/mask.bin", maskFileSize, maskHost, maskFileSizeBuf)) { + std::fprintf(stderr, "[ERROR] failed to read %s/mask.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile(caseDir + "/input1.bin", srcFileSize, srcHost, srcFileSizeBuf)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile(caseDir + "/input2.bin", scalarFileSize, scalarHost, scalarFileSizeBuf)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(maskDevice, maskFileSize, maskHost, maskFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(maskDevice, srcDevice, dstDevice, scalarHost, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (maskDevice != nullptr) aclrtFree(maskDevice); + if (srcDevice != nullptr) aclrtFree(srcDevice); + if (dstDevice != nullptr) aclrtFree(dstDevice); + if (maskHost != nullptr) aclrtFreeHost(maskHost); + if (srcHost != nullptr) aclrtFreeHost(srcHost); + if (dstHost != nullptr) aclrtFreeHost(dstHost); + if (scalarHost != nullptr) aclrtFreeHost(scalarHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto new file mode 100644 index 000000000..c3f9dacd0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto @@ -0,0 +1,654 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsels: tload(mask) + tload(src) + tsels(mask,src,tmp,scalar)->dst + tstore(dst) +// 22 cases from pto-isa tests. + +module { + func.func @TSELS_uint8_uint8_2x32_2x32_2x32_2x32(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + return + } + + func.func @TSELS_uint8_uint16_2x32_2x16_2x32_2x32(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + return + } + + func.func @TSELS_uint8_uint32_2x32_2x8_2x32_2x32(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + return + } + + func.func @TSELS_uint16_uint8_2x16_2x32_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_uint16_uint16_2x16_2x16_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_uint16_uint32_2x16_2x8_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_uint32_uint8_2x8_2x32_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%src_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) + return + } + + func.func @TSELS_uint32_uint16_2x8_2x16_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%src_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) + return + } + + func.func @TSELS_uint32_uint32_2x8_2x8_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%src_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) + return + } + + func.func @TSELS_f16_uint8_2x16_2x32_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_f16_uint16_2x16_2x16_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_f16_uint32_2x16_2x8_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_f32_uint8_2x8_2x32_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) + return + } + + func.func @TSELS_f32_uint16_2x8_2x16_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) + return + } + + func.func @TSELS_f32_uint32_2x8_2x8_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) + return + } + + func.func @TSELS_uint8_uint8_2x32_2x64_2x128_2x31(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c64], strides = [%c64, %c64, %c64, %c64, %c1] : !pto.tensor_view<1x1x1x2x64xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c128], strides = [%c128, %c128, %c128, %c128, %c1] : !pto.tensor_view<1x1x1x2x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c32, %c32, %c32, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c64] : !pto.tensor_view<1x1x1x2x64xi8> -> !pto.partition_tensor_view<1x1x1x2x64xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c128] : !pto.tensor_view<1x1x1x2x128xi8> -> !pto.partition_tensor_view<1x1x1x2x128xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c31] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x31xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x64xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x31xi8>) + return + } + + func.func @TSELS_uint16_uint8_2x32_2x64_2x128_2x31(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c64], strides = [%c64, %c64, %c64, %c64, %c1] : !pto.tensor_view<1x1x1x2x64xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c128], strides = [%c256, %c256, %c256, %c128, %c1] : !pto.tensor_view<1x1x1x2x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c64] : !pto.tensor_view<1x1x1x2x64xi8> -> !pto.partition_tensor_view<1x1x1x2x64xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c128] : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c31] : !pto.tensor_view<1x1x1x2x32xi16> -> !pto.partition_tensor_view<1x1x1x2x31xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x64xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x31xi16>) + return + } + + func.func @TSELS_f32_uint8_2x32_2x64_2x128_2x31(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c64], strides = [%c64, %c64, %c64, %c64, %c1] : !pto.tensor_view<1x1x1x2x64xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c128], strides = [%c512, %c512, %c512, %c128, %c1] : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c128, %c128, %c128, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c64] : !pto.tensor_view<1x1x1x2x64xi8> -> !pto.partition_tensor_view<1x1x1x2x64xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c128] : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c31] : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x64xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) + return + } + + func.func @TSELS_uint8_uint8_32x672_32x96_32x672_32x666(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c96 = arith.constant 96 : index + %c672 = arith.constant 672 : index + %c666 = arith.constant 666 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c32, %c96], strides = [%c96, %c96, %c96, %c96, %c1] : !pto.tensor_view<1x1x1x32x96xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c672, %c672, %c672, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c672, %c672, %c672, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c96] : !pto.tensor_view<1x1x1x32x96xi8> -> !pto.partition_tensor_view<1x1x1x32x96xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c672] : !pto.tensor_view<1x1x1x32x672xi8> -> !pto.partition_tensor_view<1x1x1x32x672xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c666] : !pto.tensor_view<1x1x1x32x672xi8> -> !pto.partition_tensor_view<1x1x1x32x666xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x32x96xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x672xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x666xi8>) + return + } + + func.func @TSELS_f16_uint8_32x672_32x96_32x672_32x666(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c96 = arith.constant 96 : index + %c672 = arith.constant 672 : index + %c666 = arith.constant 666 : index + %c1344 = arith.constant 1344 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c32, %c96], strides = [%c96, %c96, %c96, %c96, %c1] : !pto.tensor_view<1x1x1x32x96xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c1344, %c1344, %c1344, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c1344, %c1344, %c1344, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c96] : !pto.tensor_view<1x1x1x32x96xi8> -> !pto.partition_tensor_view<1x1x1x32x96xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c672] : !pto.tensor_view<1x1x1x32x672xi16> -> !pto.partition_tensor_view<1x1x1x32x672xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c666] : !pto.tensor_view<1x1x1x32x672xi16> -> !pto.partition_tensor_view<1x1x1x32x666xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x32x96xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x672xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x666xi16>) + return + } + + func.func @TSELS_f32_uint8_32x672_32x96_32x672_32x666(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c96 = arith.constant 96 : index + %c672 = arith.constant 672 : index + %c666 = arith.constant 666 : index + %c2688 = arith.constant 2688 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c32, %c96], strides = [%c96, %c96, %c96, %c96, %c1] : !pto.tensor_view<1x1x1x32x96xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c2688, %c2688, %c2688, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c2688, %c2688, %c2688, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c96] : !pto.tensor_view<1x1x1x32x96xi8> -> !pto.partition_tensor_view<1x1x1x32x96xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c672] : !pto.tensor_view<1x1x1x32x672xf32> -> !pto.partition_tensor_view<1x1x1x32x672xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c666] : !pto.tensor_view<1x1x1x32x672xf32> -> !pto.partition_tensor_view<1x1x1x32x666xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x32x96xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x672xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x666xf32>) + return + } + + func.func @TSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c1, %c4096], strides = [%c4096, %c4096, %c4096, %c4096, %c1] : !pto.tensor_view<1x1x1x1x4096xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c1, %c8192], strides = [%c8192, %c8192, %c8192, %c8192, %c1] : !pto.tensor_view<1x1x1x1x8192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c1, %c8192], strides = [%c8192, %c8192, %c8192, %c8192, %c1] : !pto.tensor_view<1x1x1x1x8192xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c1, %c4096] : !pto.tensor_view<1x1x1x1x4096xi8> -> !pto.partition_tensor_view<1x1x1x1x4096xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c1, %c8192] : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c1, %c8192] : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x1x4096xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + return + } +} \ No newline at end of file From 4597cd18cf9ad41c658f08e07ddbb9b4c09ad128 Mon Sep 17 00:00:00 2001 From: pbbb205 <625293903@qq.com> Date: Sun, 26 Apr 2026 01:20:14 +0800 Subject: [PATCH 175/192] add trowexpand op (#149) * add trowexpand op * trowexpand series op * trowexpand series op * trowexpand series op --- lib/TileOps/trowexpand_template.py | 48 + lib/TileOps/trowexpandadd_template.py | 52 ++ lib/TileOps/trowexpanddiv_template.py | 84 ++ lib/TileOps/trowexpandexpdif_template.py | 78 ++ lib/TileOps/trowexpandmax_template.py | 52 ++ lib/TileOps/trowexpandmin_template.py | 52 ++ lib/TileOps/trowexpandmul_template.py | 52 ++ lib/TileOps/trowexpandsub_template.py | 52 ++ test/basic/trowexpand_tile_op_expand.pto | 44 + test/basic/trowexpandadd_tile_op_expand.pto | 51 ++ test/basic/trowexpanddiv_tile_op_expand.pto | 51 ++ .../basic/trowexpandexpdif_tile_op_expand.pto | 52 ++ test/basic/trowexpandmax_tile_op_expand.pto | 51 ++ test/basic/trowexpandmin_tile_op_expand.pto | 51 ++ test/basic/trowexpandmul_tile_op_expand.pto | 51 ++ test/basic/trowexpandsub_tile_op_expand.pto | 51 ++ .../npu/a5/src/st/testcase/CMakeLists.txt | 8 + .../src/st/testcase/trowexpand/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowexpand/cases.py | 90 ++ .../a5/src/st/testcase/trowexpand/compare.py | 64 ++ .../a5/src/st/testcase/trowexpand/gen_data.py | 52 ++ .../a5/src/st/testcase/trowexpand/launch.cpp | 46 + .../a5/src/st/testcase/trowexpand/main.cpp | 145 +++ .../src/st/testcase/trowexpand/trowexpand.pto | 312 +++++++ .../st/testcase/trowexpandadd/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowexpandadd/cases.py | 116 +++ .../src/st/testcase/trowexpandadd/compare.py | 61 ++ .../src/st/testcase/trowexpandadd/gen_data.py | 69 ++ .../src/st/testcase/trowexpandadd/launch.cpp | 55 ++ .../a5/src/st/testcase/trowexpandadd/main.cpp | 166 ++++ .../testcase/trowexpandadd/trowexpandadd.pto | 476 ++++++++++ .../st/testcase/trowexpanddiv/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowexpanddiv/cases.py | 130 +++ .../src/st/testcase/trowexpanddiv/compare.py | 61 ++ .../src/st/testcase/trowexpanddiv/gen_data.py | 77 ++ .../src/st/testcase/trowexpanddiv/launch.cpp | 51 ++ .../a5/src/st/testcase/trowexpanddiv/main.cpp | 131 +++ .../testcase/trowexpanddiv/trowexpanddiv.pto | 860 ++++++++++++++++++ .../testcase/trowexpandexpdif/CMakeLists.txt | 9 + .../src/st/testcase/trowexpandexpdif/cases.py | 87 ++ .../st/testcase/trowexpandexpdif/compare.py | 61 ++ .../st/testcase/trowexpandexpdif/gen_data.py | 54 ++ .../st/testcase/trowexpandexpdif/launch.cpp | 41 + .../src/st/testcase/trowexpandexpdif/main.cpp | 126 +++ .../trowexpandexpdif/trowexpandexpdif.pto | 317 +++++++ .../st/testcase/trowexpandmax/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowexpandmax/cases.py | 111 +++ .../src/st/testcase/trowexpandmax/compare.py | 61 ++ .../src/st/testcase/trowexpandmax/gen_data.py | 53 ++ .../src/st/testcase/trowexpandmax/launch.cpp | 55 ++ .../a5/src/st/testcase/trowexpandmax/main.cpp | 134 +++ .../testcase/trowexpandmax/trowexpandmax.pto | 437 +++++++++ .../st/testcase/trowexpandmin/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowexpandmin/cases.py | 111 +++ .../src/st/testcase/trowexpandmin/compare.py | 61 ++ .../src/st/testcase/trowexpandmin/gen_data.py | 53 ++ .../src/st/testcase/trowexpandmin/launch.cpp | 55 ++ .../a5/src/st/testcase/trowexpandmin/main.cpp | 134 +++ .../testcase/trowexpandmin/trowexpandmin.pto | 437 +++++++++ .../st/testcase/trowexpandmul/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowexpandmul/cases.py | 111 +++ .../src/st/testcase/trowexpandmul/compare.py | 61 ++ .../src/st/testcase/trowexpandmul/gen_data.py | 53 ++ .../src/st/testcase/trowexpandmul/launch.cpp | 55 ++ .../a5/src/st/testcase/trowexpandmul/main.cpp | 134 +++ .../testcase/trowexpandmul/trowexpandmul.pto | 212 +++++ .../st/testcase/trowexpandsub/CMakeLists.txt | 9 + .../a5/src/st/testcase/trowexpandsub/cases.py | 111 +++ .../src/st/testcase/trowexpandsub/compare.py | 61 ++ .../src/st/testcase/trowexpandsub/gen_data.py | 54 ++ .../src/st/testcase/trowexpandsub/launch.cpp | 55 ++ .../a5/src/st/testcase/trowexpandsub/main.cpp | 134 +++ .../testcase/trowexpandsub/trowexpandsub.pto | 211 +++++ 73 files changed, 7554 insertions(+) create mode 100644 lib/TileOps/trowexpand_template.py create mode 100644 lib/TileOps/trowexpandadd_template.py create mode 100644 lib/TileOps/trowexpanddiv_template.py create mode 100644 lib/TileOps/trowexpandexpdif_template.py create mode 100644 lib/TileOps/trowexpandmax_template.py create mode 100644 lib/TileOps/trowexpandmin_template.py create mode 100644 lib/TileOps/trowexpandmul_template.py create mode 100644 lib/TileOps/trowexpandsub_template.py create mode 100644 test/basic/trowexpand_tile_op_expand.pto create mode 100644 test/basic/trowexpandadd_tile_op_expand.pto create mode 100644 test/basic/trowexpanddiv_tile_op_expand.pto create mode 100644 test/basic/trowexpandexpdif_tile_op_expand.pto create mode 100644 test/basic/trowexpandmax_tile_op_expand.pto create mode 100644 test/basic/trowexpandmin_tile_op_expand.pto create mode 100644 test/basic/trowexpandmul_tile_op_expand.pto create mode 100644 test/basic/trowexpandsub_tile_op_expand.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpand/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpand/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpand/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpand/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpand/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpand/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpand/trowexpand.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/trowexpandadd.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/trowexpanddiv.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/trowexpandexpdif.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/trowexpandmax.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/trowexpandmin.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/trowexpandmul.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/trowexpandsub.pto diff --git a/lib/TileOps/trowexpand_template.py b/lib/TileOps/trowexpand_template.py new file mode 100644 index 000000000..4dc6c6a79 --- /dev/null +++ b/lib/TileOps/trowexpand_template.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpand""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpand_row_major(src: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpand template.""" + # Both src and dst must be RowMajor layout + src_row_major = src.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpand", + dtypes=[(pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpand_row_major], +) +def template_trowexpand(src: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpand. + + Broadcast src[row, 0] to entire dst[row, :] for each row. + Semantics: dst[row, col] = src[row, 0] for all col. + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the first element of each row (src has cols=1, so entire row is the scalar) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + pto.vsts(broadcasted, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandadd_template.py b/lib/TileOps/trowexpandadd_template.py new file mode 100644 index 000000000..96019d501 --- /dev/null +++ b/lib/TileOps/trowexpandadd_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandadd""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandadd_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandadd template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandadd", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandadd_row_major], +) +def template_trowexpandadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandadd. + + Add a per-row scalar from src1[row, 0] to each row of src0. + Semantics: dst[row, col] = src0[row, col] + src1[row, 0] + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vadd(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpanddiv_template.py b/lib/TileOps/trowexpanddiv_template.py new file mode 100644 index 000000000..5c8325408 --- /dev/null +++ b/lib/TileOps/trowexpanddiv_template.py @@ -0,0 +1,84 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpanddiv""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpanddiv_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpanddiv template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpanddiv", + dtypes=[(pto.f32, pto.f32, pto.f32)], + constraints=[_constraint_trowexpanddiv_row_major], +) +def template_trowexpanddiv_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpanddiv with f32 dtype. + + Divide each row of src0 by a per-row scalar from src1[row, 0]. + Semantics: dst[row, col] = src0[row, col] / src1[row, 0] + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vdiv(lhs, broadcasted, mask) + # TODO: pto-isa vdiv supports high-precision mode. Current implementation uses Default mode. High-precision division needs to be implemented in future. + pto.vsts(result, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.trowexpanddiv", + dtypes=[(pto.f16, pto.f16, pto.f16)], + constraints=[_constraint_trowexpanddiv_row_major], +) +def template_trowexpanddiv_f16(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpanddiv with f16 dtype. + + Divide each row of src0 by a per-row scalar from src1[row, 0]. + Semantics: dst[row, col] = src0[row, col] / src1[row, 0] + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 16 for f16) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vdiv(lhs, broadcasted, mask) + # TODO: pto-isa vdiv supports high-precision mode. Current implementation uses Default mode. High-precision division needs to be implemented in future. + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandexpdif_template.py b/lib/TileOps/trowexpandexpdif_template.py new file mode 100644 index 000000000..0e84294b9 --- /dev/null +++ b/lib/TileOps/trowexpandexpdif_template.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandexpdif""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandexpdif_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandexpdif template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandexpdif", + dtypes=[(pto.f32, pto.f32, pto.f32)], + constraints=[_constraint_trowexpandexpdif_row_major], +) +def template_trowexpandexpdif_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandexpdif with f32 dtype. + + Compute exp(src0 - scalar) for each row using per-row scalars from src1[row, 0]. + Semantics: dst[row, col] = exp(src0[row, col] - src1[row, 0]) + Used in numerically stable softmax computation. + """ + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + mask, remained = pto.make_mask(pto.f32, remained) + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vexpdif(lhs, broadcasted, pto.VcvtPartMode.EVEN) + pto.vsts(result, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.trowexpandexpdif", + dtypes=[(pto.f16, pto.f16, pto.f16)], + constraints=[_constraint_trowexpandexpdif_row_major], +) +def template_trowexpandexpdif_f16(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandexpdif with f16 dtype. + + Compute exp(src0 - scalar) for each row using per-row scalars from src1[row, 0]. + Semantics: dst[row, col] = exp(src0[row, col] - src1[row, 0]) + Used in numerically stable softmax computation. + """ + dtype = pto.f16 + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + diff = pto.vsub(lhs, broadcasted, mask) + result = pto.vexp(diff, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandmax_template.py b/lib/TileOps/trowexpandmax_template.py new file mode 100644 index 000000000..2e5d53124 --- /dev/null +++ b/lib/TileOps/trowexpandmax_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandmax_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandmax template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandmax", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandmax_row_major], +) +def template_trowexpandmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandmax. + + Compute element-wise max of each row of src0 with a per-row scalar from src1[row, 0]. + Semantics: dst[row, col] = max(src0[row, col], src1[row, 0]) + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vmax(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandmin_template.py b/lib/TileOps/trowexpandmin_template.py new file mode 100644 index 000000000..eae99ff30 --- /dev/null +++ b/lib/TileOps/trowexpandmin_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandmin_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandmin template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandmin", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandmin_row_major], +) +def template_trowexpandmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandmin. + + Compute element-wise min of each row of src0 with a per-row scalar from src1[row, 0]. + Semantics: dst[row, col] = min(src0[row, col], src1[row, 0]) + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vmin(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandmul_template.py b/lib/TileOps/trowexpandmul_template.py new file mode 100644 index 000000000..593420125 --- /dev/null +++ b/lib/TileOps/trowexpandmul_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandmul""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandmul_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandmul template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandmul", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandmul_row_major], +) +def template_trowexpandmul(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandmul. + + Multiply each row of src0 by a per-row scalar from src1[row, 0]. + Semantics: dst[row, col] = src0[row, col] * src1[row, 0] + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vmul(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandsub_template.py b/lib/TileOps/trowexpandsub_template.py new file mode 100644 index 000000000..ed44e8613 --- /dev/null +++ b/lib/TileOps/trowexpandsub_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandsub""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandsub_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandsub template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandsub", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandsub_row_major], +) +def template_trowexpandsub(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandsub. + + Subtract a per-row scalar from src1[row, 0] from each row of src0. + Semantics: dst[row, col] = src0[row, col] - src1[row, 0] + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vsub(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/test/basic/trowexpand_tile_op_expand.pto b/test/basic/trowexpand_tile_op_expand.pto new file mode 100644 index 000000000..4beee5faf --- /dev/null +++ b/test/basic/trowexpand_tile_op_expand.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpand via the TileLang Python DSL template +// lib/TileOps/trowexpand.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpand should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPAND +// CHECK-NOT: pto.trowexpand ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vsts + +module { + func.func @TROWEXPAND() { + // Source tile: 32x8 (32-byte aligned, cols=32/sizeof(f32)=8) + // Only column 0 contains valid data (v_col=8 for alignment, actual valid=1) + %src = pto.alloc_tile + : !pto.tile_buf + // Destination tile: 32x32 (broadcast each scalar across the row) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/trowexpandadd_tile_op_expand.pto b/test/basic/trowexpandadd_tile_op_expand.pto new file mode 100644 index 000000000..f9aca7d43 --- /dev/null +++ b/test/basic/trowexpandadd_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandadd via the TileLang Python DSL template +// lib/TileOps/trowexpandadd.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandadd should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDADD +// CHECK-NOT: pto.trowexpandadd ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vadd +// CHECK: pto.vsts + +module { + func.func @TROWEXPANDADD() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/trowexpanddiv_tile_op_expand.pto b/test/basic/trowexpanddiv_tile_op_expand.pto new file mode 100644 index 000000000..11f7db6a4 --- /dev/null +++ b/test/basic/trowexpanddiv_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpanddiv via the TileLang Python DSL template +// lib/TileOps/trowexpanddiv.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpanddiv should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDDIV +// CHECK-NOT: pto.trowexpanddiv ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vdiv +// CHECK: pto.vsts + +module { + func.func @TROWEXPANDDIV() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/trowexpandexpdif_tile_op_expand.pto b/test/basic/trowexpandexpdif_tile_op_expand.pto new file mode 100644 index 000000000..6ba51b00a --- /dev/null +++ b/test/basic/trowexpandexpdif_tile_op_expand.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandexpdif via the TileLang Python DSL template +// lib/TileOps/trowexpandexpdif.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandexpdif should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDEXPDIF +// CHECK-NOT: pto.trowexpandexpdif ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vsubs +// CHECK: pto.vexp +// CHECK: pto.vsts + +module { + func.func @TROWEXPANDEXPDIF() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/trowexpandmax_tile_op_expand.pto b/test/basic/trowexpandmax_tile_op_expand.pto new file mode 100644 index 000000000..26d4a6781 --- /dev/null +++ b/test/basic/trowexpandmax_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandmax via the TileLang Python DSL template +// lib/TileOps/trowexpandmax.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDMAX +// CHECK-NOT: pto.trowexpandmax ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vmax +// CHECK: pto.vsts + +module { + func.func @TROWEXPANDMAX() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/trowexpandmin_tile_op_expand.pto b/test/basic/trowexpandmin_tile_op_expand.pto new file mode 100644 index 000000000..10d7641b2 --- /dev/null +++ b/test/basic/trowexpandmin_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandmin via the TileLang Python DSL template +// lib/TileOps/trowexpandmin.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDMIN +// CHECK-NOT: pto.trowexpandmin ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vmin +// CHECK: pto.vsts + +module { + func.func @TROWEXPANDMIN() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/trowexpandmul_tile_op_expand.pto b/test/basic/trowexpandmul_tile_op_expand.pto new file mode 100644 index 000000000..a68cd3e2b --- /dev/null +++ b/test/basic/trowexpandmul_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandmul via the TileLang Python DSL template +// lib/TileOps/trowexpandmul.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandmul should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDMUL +// CHECK-NOT: pto.trowexpandmul ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vmul +// CHECK: pto.vsts + +module { + func.func @TROWEXPANDMUL() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/trowexpandsub_tile_op_expand.pto b/test/basic/trowexpandsub_tile_op_expand.pto new file mode 100644 index 000000000..c78dfb87c --- /dev/null +++ b/test/basic/trowexpandsub_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandsub via the TileLang Python DSL template +// lib/TileOps/trowexpandsub.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandsub should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDSUB +// CHECK-NOT: pto.trowexpandsub ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vsub +// CHECK: pto.vsts + +module { + func.func @TROWEXPANDSUB() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index cb0337cff..f4c4eaa96 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -159,6 +159,14 @@ set(ALL_TESTCASES trowprod trsqrt tsqrt + trowexpand + trowexpandadd + trowexpanddiv + trowexpandexpdif + trowexpandmax + trowexpandmin + trowexpandmul + trowexpandsub texpands tfillpad tfillpad_inplace diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/CMakeLists.txt new file mode 100644 index 000000000..254ce36e5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpand) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/cases.py new file mode 100644 index 000000000..44effd8a7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/cases.py @@ -0,0 +1,90 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpand ST test cases. + +trowexpand is a row broadcast operation: expands a scalar per row to the entire row. +- Input shape: (rows, srcCols) - physical layout for NPU alignment +- srcCols = 32/sizeof(dtype) for 32-byte alignment +- Output shape: (rows, dstCols) - broadcast each scalar across the row +- dstValidCols may be less than dstCols for partial valid region + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32, np.float16, np.int8). + - src0_shape: (rows, srcCols) — physical input tile dimensions. + - src0_valid_shape: (valid_rows, 1) — effective input region. + - dst_shape: (rows, dstCols) — output tile dimensions. + - dst_valid_shape: (valid_rows, valid_cols) — effective output region. + - eps: tolerance for numpy.allclose (atol and rtol). +""" + +import numpy as np + +CASES = [ + # f32 cases (srcCols=8 for 32-byte alignment) + { + "name": "f32_16x128", + "dtype": np.float32, + "src0_shape": (16, 8), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-6, + }, + { + "name": "f32_16x127", + "dtype": np.float32, + "src0_shape": (16, 8), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 127), # partial valid region + "eps": 1e-6, + }, + # f16 cases (srcCols=16 for 32-byte alignment) + { + "name": "f16_16x512", + "dtype": np.float16, + "src0_shape": (16, 16), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 512), + "dst_valid_shape": (16, 512), + "eps": 1e-3, + }, + { + "name": "f16_16x511", + "dtype": np.float16, + "src0_shape": (16, 16), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 512), + "dst_valid_shape": (16, 511), # partial valid region + "eps": 1e-3, + }, + # i8 cases (srcCols=32 for 32-byte alignment) + { + "name": "i8_16x256", + "dtype": np.int8, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 256), + "dst_valid_shape": (16, 256), + "eps": 0, # exact match for integers + }, + { + "name": "i8_16x255", + "dtype": np.int8, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 256), + "dst_valid_shape": (16, 255), # partial valid region + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/compare.py new file mode 100644 index 000000000..bf00d9b08 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpand ST test cases. + +trowexpand: row broadcast operation. +Compare output (rows, cols) against golden (rows, cols). +""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpand uses src0/dst only) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/gen_data.py new file mode 100644 index 000000000..8eec79d9e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/gen_data.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpand ST test cases. + +trowexpand: row broadcast operation. +- Input: (rows, 1) - one scalar per row +- Output: (rows, cols) - broadcast each scalar across the entire row +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpand uses src0/dst only) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] # Physical shape (rows, 8) + src0_valid_shape = case["src0_valid_shape"] # Valid shape (rows, 1) + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + # Generate input: random values for each row's scalar, padded to 8 columns + # Physical layout: (rows, 8), but only column 0 is valid data + input_data = np.zeros(src0_shape, dtype=dtype) + src_vr = src0_valid_shape[0] + input_data[:src_vr, 0] = np.random.randint(1, 10, size=src_vr).astype(dtype) + + # Generate golden: broadcast each row's scalar across columns + # dst[i, :] = src[i, 0] for all columns + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + golden[:dst_vr, :dst_vc] = np.broadcast_to(input_data[:src_vr, 0:1], (dst_vr, dst_vc)).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0_shape={src0_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/launch.cpp new file mode 100644 index 000000000..ab55c0382 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPAND_f32_16x128(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPAND_f32_16x127(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWEXPAND_f32_16x128(float *src, float *dst, void *stream) { + TROWEXPAND_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWEXPAND_f32_16x127(float *src, float *dst, void *stream) { + TROWEXPAND_f32_16x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPAND_f16_16x512(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPAND_f16_16x511(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTROWEXPAND_f16_16x512(void *src, void *dst, void *stream) { + TROWEXPAND_f16_16x512<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPAND_f16_16x511(void *src, void *dst, void *stream) { + TROWEXPAND_f16_16x511<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// i8 kernels +extern "C" __global__ AICORE void TROWEXPAND_i8_16x256(__gm__ int8_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TROWEXPAND_i8_16x255(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTROWEXPAND_i8_16x256(void *src, void *dst, void *stream) { + TROWEXPAND_i8_16x256<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} +void LaunchTROWEXPAND_i8_16x255(void *src, void *dst, void *stream) { + TROWEXPAND_i8_16x255<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/main.cpp new file mode 100644 index 000000000..60413d4e4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpand ST — row broadcast operation. +// Supports multiple data types: f32, f16, i8 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +// f32 +void LaunchTROWEXPAND_f32_16x128(float *src, float *dst, void *stream); +void LaunchTROWEXPAND_f32_16x127(float *src, float *dst, void *stream); +// f16 +void LaunchTROWEXPAND_f16_16x512(void *src, void *dst, void *stream); +void LaunchTROWEXPAND_f16_16x511(void *src, void *dst, void *stream); +// i8 +void LaunchTROWEXPAND_i8_16x256(void *src, void *dst, void *stream); +void LaunchTROWEXPAND_i8_16x255(void *src, void *dst, void *stream); + +// Generic launch function type +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; + size_t srcCols; // srcCols = 32/sizeof(dtype) for alignment + size_t dstRows; + size_t dstCols; + size_t dstValidCols; // effective output columns + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32: srcCols=8 (32/4), dstCols=128, dstValidCols=128 or 127 + {"f32_16x128", (LaunchFn)LaunchTROWEXPAND_f32_16x128, 16, 8, 16, 128, 128, sizeof(float)}, + {"f32_16x127", (LaunchFn)LaunchTROWEXPAND_f32_16x127, 16, 8, 16, 128, 127, sizeof(float)}, + // f16: srcCols=16 (32/2), dstCols=512, dstValidCols=512 or 511 + {"f16_16x512", LaunchTROWEXPAND_f16_16x512, 16, 16, 16, 512, 512, sizeof(uint16_t)}, + {"f16_16x511", LaunchTROWEXPAND_f16_16x511, 16, 16, 16, 512, 511, sizeof(uint16_t)}, + // i8: srcCols=32 (32/1), dstCols=256, dstValidCols=256 or 255 + {"i8_16x256", LaunchTROWEXPAND_i8_16x256, 16, 32, 16, 256, 256, sizeof(int8_t)}, + {"i8_16x255", LaunchTROWEXPAND_i8_16x255, 16, 32, 16, 256, 255, sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcFileSize = tc.srcRows * tc.srcCols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, valid_cols=%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.dstValidCols); + + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), srcFileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/trowexpand.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/trowexpand.pto new file mode 100644 index 000000000..8f71191ae --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/trowexpand.pto @@ -0,0 +1,312 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpand: row broadcast operation. +// dst[row, col] = src[row, 0] (broadcast scalar per row) +// srcCols = 32/sizeof(dtype) for NPU 32-byte alignment + +module { + // f32_16x128: rows=16, srcCols=8, dstValidCols=128, dstCols=128 + func.func @TROWEXPAND_f32_16x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // f32_16x127: rows=16, srcCols=8, dstValidCols=127, dstCols=128 + func.func @TROWEXPAND_f32_16x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + return + } + + // f16_16x512: rows=16, srcCols=16, dstValidCols=512, dstCols=512 + func.func @TROWEXPAND_f16_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf16> -> !pto.partition_tensor_view<1x1x1x16x512xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x512xf16>) + return + } + + // f16_16x511: rows=16, srcCols=16, dstValidCols=511, dstCols=512 + func.func @TROWEXPAND_f16_16x511(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c511 = arith.constant 511 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c511] + : !pto.tensor_view<1x1x1x16x512xf16> -> !pto.partition_tensor_view<1x1x1x16x511xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x511xf16>) + return + } + + // i8_16x256: rows=16, srcCols=32, dstValidCols=256, dstCols=256 + func.func @TROWEXPAND_i8_16x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi8> -> !pto.partition_tensor_view<1x1x1x16x32xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x16x256xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x32xi8>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x256xi8>) + return + } + + // i8_16x255: rows=16, srcCols=32, dstValidCols=255, dstCols=256 + func.func @TROWEXPAND_i8_16x255(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c255 = arith.constant 255 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi8> -> !pto.partition_tensor_view<1x1x1x16x32xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x16x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x32xi8>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x255xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/CMakeLists.txt new file mode 100644 index 000000000..47f7afb3f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandadd) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/cases.py new file mode 100644 index 000000000..cf5e55bb0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/cases.py @@ -0,0 +1,116 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandadd ST test cases. + +trowexpandadd: dst = src0 + broadcast(src1) across columns. +- src1Col determines how src1 is broadcast: + - src1Col=1: only first column is valid, broadcast to dstCols + - src1Col=8 (for f32): 8 columns are valid, no broadcast needed +- src1Cols (physical) = 32/sizeof(dtype) for NPU alignment + +Template parameters: + - dstRow, dstCol: dst shape + - src1Row, src1Col: src1 shape (src1Col is valid columns, not physical) + - src0eqdst: true means src0 shape equals dst, false means different +""" + +import numpy as np + +CASES = [ + # launchTRowExpandAdd + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), # src0eqdst=true, same as dst + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (16, 1), # src1Col=1, only first column valid + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-6, + }, + # launchTRowExpandAdd + { + "name": "f32_56x128", + "dtype": np.float32, + "src0_shape": (56, 128), # src0eqdst=true + "src0_valid_shape": (56, 128), + "src1_shape": (56, 8), # physical: 8 + "src1_valid_shape": (56, 1), # src1Col=1 + "dst_shape": (56, 128), + "dst_valid_shape": (56, 128), + "eps": 1e-6, + }, + # launchTRowExpandAdd + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), # src0eqdst=true + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (48, 1), # src1Col=1 + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandAdd + { + "name": "f16_16x128", + "dtype": np.float16, + "src0_shape": (16, 128), # src0eqdst=true + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), # physical: 16 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + }, + # Note: launchTRowExpandAdd2 with src1Col=8 has different semantics - TBD + # launchTRowExpandAdd2 - needs investigation + # launchTRowExpandAdd2 - needs investigation + # launchTRowExpandAdd + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), # src0eqdst=false, but src0 shape still matches dst + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), # physical: 16 + "src1_valid_shape": (32, 1), # src1Col=1 + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandAdd + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), # src0eqdst=true + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandAdd + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), # src0eqdst=true + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/compare.py new file mode 100644 index 000000000..c88279ea7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandadd ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandadd uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/gen_data.py new file mode 100644 index 000000000..b13261332 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/gen_data.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandadd ST test cases. + +trowexpandadd: dst = src0 + broadcast(src1) across columns. +- src1Col=1: only first column valid, broadcast to all dst columns +- src1Col>1: each src1 column maps to a block of dst columns (dstCol/src1Col) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandadd uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + # Generate inputs + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) # src0 matrix + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) # src1 row vectors + + # Generate golden: dst = src0 + broadcast(src1) + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr, src1_vc = src1_valid_shape + + if src1_vc == 1: + # src1Col=1: broadcast first column to all dst columns + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] + input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + else: + # src1Col>1: each src1 column maps to dstCol/src1_vc columns + # dst[:, block*repeat:(block+1)*repeat] = src0 + src1[:, block:block+1] + repeat = dst_vc // src1_vc + for block in range(src1_vc): + start_col = block * repeat + end_col = min((block + 1) * repeat, dst_vc) + golden[:dst_vr, start_col:end_col] = ( + input1[:src0_vr, start_col:end_col] + input2[:src1_vr, block:block+1] + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/launch.cpp new file mode 100644 index 000000000..5bde96197 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDADD_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDADD_f32_56x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDADD_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDADD_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDADD_f32_56x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDADD_f32_56x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Note: launchTRowExpandAdd2 with src1Col=8 has different semantics - TBD + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDADD_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDADD_f16_16x128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDADD_f16_32x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDADD_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDADD_f16_16x128(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_f16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDADD_f16_32x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_f16_32x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDADD_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDADD_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDADD_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDADD_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/main.cpp new file mode 100644 index 000000000..7e6b1dce6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/main.cpp @@ -0,0 +1,166 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandadd ST — row-wise broadcast addition. +// Supports multiple data types: f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +// f32 +void LaunchTROWEXPANDADD_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDADD_f32_56x128(float *src0, float *src1, float *dst, void *stream); +// f16 (use void* for aclFloat16) +void LaunchTROWEXPANDADD_f16_48x64(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDADD_f16_16x128(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDADD_f16_32x64(void *src0, void *src1, void *dst, void *stream); +// i32 +void LaunchTROWEXPANDADD_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 +void LaunchTROWEXPANDADD_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandAdd2 with src1Col=8 has different semantics - TBD + +// Generic launch function type +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; // physical src1 cols = 32/sizeof(dtype) + size_t dstRows; + size_t dstCols; + size_t dstValidCols; // effective dst cols + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDADD_f32_16x32, 16, 32, 16, 8, 16, 32, 32, sizeof(float)}, + {"f32_56x128", (LaunchFn)LaunchTROWEXPANDADD_f32_56x128, 56, 128, 56, 8, 56, 128, 128, sizeof(float)}, + // Note: f32_24x64_v2 and f32_20x64_v2_noeq have different semantics - TBD + // f16 cases + {"f16_48x64", LaunchTROWEXPANDADD_f16_48x64, 48, 64, 48, 16, 48, 64, 64, sizeof(uint16_t)}, + {"f16_16x128", LaunchTROWEXPANDADD_f16_16x128, 16, 128, 16, 16, 16, 128, 128, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDADD_f16_32x64, 32, 64, 32, 16, 32, 64, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDADD_i32_16x32, 16, 32, 16, 8, 16, 32, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDADD_i16_16x64, 16, 64, 16, 16, 16, 64, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/trowexpandadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/trowexpandadd.pto new file mode 100644 index 000000000..d719219be --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/trowexpandadd.pto @@ -0,0 +1,476 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandadd: row-wise broadcast addition. +// dst = src0 + broadcast(src1) where src1 is expanded across columns. +// src1 physical cols = 32/sizeof(dtype) for NPU alignment +// src1 v_col = src1Col from template (1 or 8) + +module { + // f32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1 + func.func @TROWEXPANDADD_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f32_56x128: dstRow=56, dstCol=128, src1Row=56, src1Col=1 + func.func @TROWEXPANDADD_f32_56x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c56 = arith.constant 56 : index + %c128 = arith.constant 128 : index + %c448 = arith.constant 448 : index + %c7168 = arith.constant 7168 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c56, %c8], + strides = [%c448, %c448, %c448, %c8, %c1] + : !pto.tensor_view<1x1x1x56x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c8] + : !pto.tensor_view<1x1x1x56x8xf32> -> !pto.partition_tensor_view<1x1x1x56x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x56x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + return + } + + // Note: launchTRowExpandAdd2 with src1Col=8 has different semantics - TBD + + // f16_48x64: dstRow=48, dstCol=64, src1Row=48, src1Col=1 + func.func @TROWEXPANDADD_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c48, %c16], + strides = [%c768, %c768, %c768, %c16, %c1] + : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c16] + : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f16_16x128: dstRow=16, dstCol=128, src1Row=16, src1Col=1 + func.func @TROWEXPANDADD_f16_16x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f16_32x64_noeq: dstRow=32, dstCol=64, src1Row=32, src1Col=1 (src0eqdst=false) + func.func @TROWEXPANDADD_f16_32x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // i32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1 + func.func @TROWEXPANDADD_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: dstRow=16, dstCol=64, src1Row=16, src1Col=1 + func.func @TROWEXPANDADD_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/CMakeLists.txt new file mode 100644 index 000000000..2cbab13c9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpanddiv) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/cases.py new file mode 100644 index 000000000..ad20e8203 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/cases.py @@ -0,0 +1,130 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpanddiv ST test cases. + +trowexpanddiv: dst = src0 / broadcast(src1) across columns. +- src1Col determines how src1 is broadcast: + - src1Col=1: only first column is valid, broadcast to dstCols + - src1Col>1: each src1 column maps to a block of dst columns (dstCol/src1Col columns per src1 value) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- highPrecision: use high precision mode for computation +""" + +import numpy as np + +CASES = [ + # launchTRowExpandDiv + { + "name": "f32_40x64", + "dtype": np.float32, + "src0_shape": (40, 64), # src0eqdst=true + "src0_valid_shape": (40, 64), + "src1_shape": (40, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (40, 1), # src1Col=1 + "dst_shape": (40, 64), + "dst_valid_shape": (40, 64), + "eps": 1e-6, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f32_16x256", + "dtype": np.float32, + "src0_shape": (16, 256), + "src0_valid_shape": (16, 256), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 256), + "dst_valid_shape": (16, 256), + "eps": 1e-6, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f16_16x32", + "dtype": np.float16, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-3, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f16_32x512", + "dtype": np.float16, + "src0_shape": (32, 512), + "src0_valid_shape": (32, 512), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 512), + "dst_valid_shape": (32, 512), + "eps": 1e-3, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f32_16x128_noeq", + "dtype": np.float32, + "src0_shape": (16, 128), # src0eqdst=false + "src0_valid_shape": (16, 128), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-6, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f32_40x32_hp", + "dtype": np.float32, + "src0_shape": (40, 32), + "src0_valid_shape": (40, 32), + "src1_shape": (40, 8), + "src1_valid_shape": (40, 1), + "dst_shape": (40, 32), + "dst_valid_shape": (40, 32), + "eps": 1e-6, + "high_precision": True, + }, + # launchTRowExpandDiv + { + "name": "f16_16x128_hp", + "dtype": np.float16, + "src0_shape": (16, 128), + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + "high_precision": True, + }, + # Note: launchTRowExpandDiv2 with src1Col>1 has different semantics - TBD +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/compare.py new file mode 100644 index 000000000..c6dfc114c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpanddiv ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpanddiv uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/gen_data.py new file mode 100644 index 000000000..a260bf68e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/gen_data.py @@ -0,0 +1,77 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpanddiv ST test cases. + +trowexpanddiv: dst = src0 / broadcast(src1) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpanddiv uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr, src1_vc = src1_valid_shape + + # Compute golden based on src1Col semantics + # src1Col=1: broadcast single column to all dst columns + # src1Col>1: each src1 column broadcasts to dst_vc/src1_vc columns + if dtype in (np.int8, np.int16, np.int32): + if src1_vc == 1: + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] // input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + else: + # src1Col > 1: each src1 column broadcasts to dst_vc/src1_vc dst columns + block_size = dst_vc // src1_vc + for c in range(src1_vc): + golden[:dst_vr, c*block_size:(c+1)*block_size] = ( + input1[:src0_vr, c*block_size:(c+1)*block_size] // input2[:src1_vr, c:c+1] + ).astype(dtype, copy=False) + else: + if src1_vc == 1: + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] / input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + else: + # src1Col > 1: each src1 column broadcasts to dst_vc/src1_vc dst columns + block_size = dst_vc // src1_vc + for c in range(src1_vc): + golden[:dst_vr, c*block_size:(c+1)*block_size] = ( + input1[:src0_vr, c*block_size:(c+1)*block_size] / input2[:src1_vr, c:c+1] + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/launch.cpp new file mode 100644 index 000000000..e028030ce --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDDIV_f32_40x64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f32_16x256(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f32_16x128_noeq(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f32_40x32_hp(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDDIV_f32_40x64(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDDIV_f32_40x64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDDIV_f32_16x256(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDDIV_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDDIV_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDDIV_f32_16x128_noeq<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDDIV_f32_40x32_hp(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDDIV_f32_40x32_hp<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDDIV_f16_16x32(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f16_32x512(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f16_32x64_noeq(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f16_16x128_hp(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDDIV_f16_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDDIV_f16_16x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDDIV_f16_32x512(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDDIV_f16_32x512<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDDIV_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDDIV_f16_32x64_noeq<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDDIV_f16_16x128_hp(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDDIV_f16_16x128_hp<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/main.cpp new file mode 100644 index 000000000..6d2fcded5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/main.cpp @@ -0,0 +1,131 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpanddiv ST — row-wise broadcast division. +// Supports f32, f16 +// Div variants: src1Col=1 (broadcast single value) or src1Col>1 (block broadcast) + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDDIV_f32_40x64(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDDIV_f32_16x256(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDDIV_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDDIV_f32_40x32_hp(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDDIV_f16_16x32(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDDIV_f16_32x512(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDDIV_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDDIV_f16_16x128_hp(void *src0, void *src1, void *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_40x64", (LaunchFn)LaunchTROWEXPANDDIV_f32_40x64, 40, 64, 40, 8, 40, 64, 40, 64, sizeof(float)}, + {"f32_16x256", (LaunchFn)LaunchTROWEXPANDDIV_f32_16x256, 16, 256, 16, 8, 16, 256, 16, 256, sizeof(float)}, + {"f32_16x128_noeq", (LaunchFn)LaunchTROWEXPANDDIV_f32_16x128_noeq, 16, 128, 16, 8, 16, 128, 16, 128, sizeof(float)}, + {"f32_40x32_hp", (LaunchFn)LaunchTROWEXPANDDIV_f32_40x32_hp, 40, 32, 40, 8, 40, 32, 40, 32, sizeof(float)}, + // f16 cases + {"f16_16x32", LaunchTROWEXPANDDIV_f16_16x32, 16, 32, 16, 16, 16, 32, 16, 32, sizeof(uint16_t)}, + {"f16_32x512", LaunchTROWEXPANDDIV_f16_32x512, 32, 512, 32, 16, 32, 512, 32, 512, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDDIV_f16_32x64_noeq, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + {"f16_16x128_hp", LaunchTROWEXPANDDIV_f16_16x128_hp, 16, 128, 16, 16, 16, 128, 16, 128, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/trowexpanddiv.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/trowexpanddiv.pto new file mode 100644 index 000000000..651f1e5d7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/trowexpanddiv.pto @@ -0,0 +1,860 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpanddiv: row-wise broadcast division. +// Supports f32, f16 types. +// src1Col=1: broadcast single column value to all dst columns +// src1Col>1: each src1 column broadcasts to dstCol/src1Col columns + +module { + // f32_40x64: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f32_40x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c64 = arith.constant 64 : index + %c320 = arith.constant 320 : index + %c2560 = arith.constant 2560 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c40, %c64], + strides = [%c2560, %c2560, %c2560, %c64, %c1] + : !pto.tensor_view<1x1x1x40x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c40, %c8], + strides = [%c320, %c320, %c320, %c8, %c1] + : !pto.tensor_view<1x1x1x40x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c40, %c64], + strides = [%c2560, %c2560, %c2560, %c64, %c1] + : !pto.tensor_view<1x1x1x40x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c64] + : !pto.tensor_view<1x1x1x40x64xf32> -> !pto.partition_tensor_view<1x1x1x40x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c8] + : !pto.tensor_view<1x1x1x40x8xf32> -> !pto.partition_tensor_view<1x1x1x40x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c64] + : !pto.tensor_view<1x1x1x40x64xf32> -> !pto.partition_tensor_view<1x1x1x40x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x40x64xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x40x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x40x64xf32>) + return + } + + // f32_16x256: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f32_16x256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x16x256xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x16x256xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x256xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x256xf32>) + return + } + + // f16_16x32: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f16_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf16> -> !pto.partition_tensor_view<1x1x1x16x32xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf16> -> !pto.partition_tensor_view<1x1x1x16x32xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf16>) + return + } + + // f16_32x512: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f16_32x512(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c512_ = arith.constant 512 : index + %c16384 = arith.constant 16384 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c512], + strides = [%c16384, %c16384, %c16384, %c512, %c1] + : !pto.tensor_view<1x1x1x32x512xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512_, %c512_, %c512_, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c512], + strides = [%c16384, %c16384, %c16384, %c512, %c1] + : !pto.tensor_view<1x1x1x32x512xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c512] + : !pto.tensor_view<1x1x1x32x512xf16> -> !pto.partition_tensor_view<1x1x1x32x512xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c512] + : !pto.tensor_view<1x1x1x32x512xf16> -> !pto.partition_tensor_view<1x1x1x32x512xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x512xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x512xf16>) + return + } + + // f32_16x128_noeq: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f32_16x128_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c128_ = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128_, %c128_, %c128_, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // f16_32x64_noeq: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f16_32x64_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // f32_40x32_hp: launchTRowExpandDiv (highPrecision) + func.func @TROWEXPANDDIV_f32_40x32_hp(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c40 = arith.constant 40 : index + %c320 = arith.constant 320 : index + %c1280 = arith.constant 1280 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c40, %c32], + strides = [%c1280, %c1280, %c1280, %c32, %c1] + : !pto.tensor_view<1x1x1x40x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c40, %c8], + strides = [%c320, %c320, %c320, %c8, %c1] + : !pto.tensor_view<1x1x1x40x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c40, %c32], + strides = [%c1280, %c1280, %c1280, %c32, %c1] + : !pto.tensor_view<1x1x1x40x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c32] + : !pto.tensor_view<1x1x1x40x32xf32> -> !pto.partition_tensor_view<1x1x1x40x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c8] + : !pto.tensor_view<1x1x1x40x8xf32> -> !pto.partition_tensor_view<1x1x1x40x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c32] + : !pto.tensor_view<1x1x1x40x32xf32> -> !pto.partition_tensor_view<1x1x1x40x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x40x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x40x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x40x32xf32>) + return + } + + // f16_16x128_hp: launchTRowExpandDiv (highPrecision) + func.func @TROWEXPANDDIV_f16_16x128_hp(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f32_24x64_v2: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f32_24x64_v2(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c24 = arith.constant 24 : index + %c64 = arith.constant 64 : index + %c192 = arith.constant 192 : index + %c1536 = arith.constant 1536 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c24, %c64], + strides = [%c1536, %c1536, %c1536, %c64, %c1] + : !pto.tensor_view<1x1x1x24x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c24, %c8], + strides = [%c192, %c192, %c192, %c8, %c1] + : !pto.tensor_view<1x1x1x24x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c24, %c64], + strides = [%c1536, %c1536, %c1536, %c64, %c1] + : !pto.tensor_view<1x1x1x24x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c24, %c64] + : !pto.tensor_view<1x1x1x24x64xf32> -> !pto.partition_tensor_view<1x1x1x24x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c24, %c8] + : !pto.tensor_view<1x1x1x24x8xf32> -> !pto.partition_tensor_view<1x1x1x24x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c24, %c64] + : !pto.tensor_view<1x1x1x24x64xf32> -> !pto.partition_tensor_view<1x1x1x24x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x24x64xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x24x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 8 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x24x64xf32>) + return + } + + // f16_32x32_v2: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f16_32x32_v2(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 16 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // f32_20x64_v2_noeq: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f32_20x64_v2_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c20 = arith.constant 20 : index + %c64 = arith.constant 64 : index + %c160 = arith.constant 160 : index + %c1280 = arith.constant 1280 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c20, %c64], + strides = [%c1280, %c1280, %c1280, %c64, %c1] + : !pto.tensor_view<1x1x1x20x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c20, %c8], + strides = [%c160, %c160, %c160, %c8, %c1] + : !pto.tensor_view<1x1x1x20x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c20, %c64], + strides = [%c1280, %c1280, %c1280, %c64, %c1] + : !pto.tensor_view<1x1x1x20x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c64] + : !pto.tensor_view<1x1x1x20x64xf32> -> !pto.partition_tensor_view<1x1x1x20x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c8] + : !pto.tensor_view<1x1x1x20x8xf32> -> !pto.partition_tensor_view<1x1x1x20x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c64] + : !pto.tensor_view<1x1x1x20x64xf32> -> !pto.partition_tensor_view<1x1x1x20x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x20x64xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x20x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 8 : i64, src0eqdst = false, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x20x64xf32>) + return + } + + // f16_16x64_v2_noeq: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f16_16x64_v2_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 16 : i64, src0eqdst = false, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // f32_8x32_v2_hp: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f32_8x32_v2_hp(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c8, %c32], + strides = [%c256, %c256, %c256, %c32, %c1] + : !pto.tensor_view<1x1x1x8x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c8, %c8], + strides = [%c256, %c256, %c256, %c8, %c1] + : !pto.tensor_view<1x1x1x8x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c32], + strides = [%c256, %c256, %c256, %c32, %c1] + : !pto.tensor_view<1x1x1x8x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c32] + : !pto.tensor_view<1x1x1x8x32xf32> -> !pto.partition_tensor_view<1x1x1x8x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c8] + : !pto.tensor_view<1x1x1x8x8xf32> -> !pto.partition_tensor_view<1x1x1x8x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c32] + : !pto.tensor_view<1x1x1x8x32xf32> -> !pto.partition_tensor_view<1x1x1x8x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x8x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x8x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 8 : i64, src0eqdst = true, highPrecision = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x32xf32>) + return + } + + // f16_8x128_v2_hp: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f16_8x128_v2_hp(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c8, %c128], + strides = [%c1024, %c1024, %c1024, %c128, %c1] + : !pto.tensor_view<1x1x1x8x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c8, %c16], + strides = [%c128, %c128, %c128, %c16, %c1] + : !pto.tensor_view<1x1x1x8x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c128], + strides = [%c1024, %c1024, %c1024, %c128, %c1] + : !pto.tensor_view<1x1x1x8x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c128] + : !pto.tensor_view<1x1x1x8x128xf16> -> !pto.partition_tensor_view<1x1x1x8x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c16] + : !pto.tensor_view<1x1x1x8x16xf16> -> !pto.partition_tensor_view<1x1x1x8x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c128] + : !pto.tensor_view<1x1x1x8x128xf16> -> !pto.partition_tensor_view<1x1x1x8x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x8x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x8x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 16 : i64, src0eqdst = true, highPrecision = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x128xf16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/CMakeLists.txt new file mode 100644 index 000000000..fd0640efc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandexpdif) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/cases.py new file mode 100644 index 000000000..111391868 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/cases.py @@ -0,0 +1,87 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandexpdif ST test cases. + +trowexpandexpdif: dst = exp(src0 - broadcast(src1)) +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandExpdif2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandExpdif + { + "name": "f32_32x64", + "dtype": np.float32, + "src0_shape": (32, 64), + "src0_valid_shape": (32, 64), + "src1_shape": (32, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (32, 1), # src1Col=1 + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-5, + }, + # launchTRowExpandExpdif + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-5, + }, + # launchTRowExpandExpdif + { + "name": "f16_16x32", + "dtype": np.float16, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-3, + }, + # launchTRowExpandExpdif + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), + "src1_valid_shape": (48, 1), + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandExpdif + { + "name": "f32_16x128_noeq", + "dtype": np.float32, + "src0_shape": (16, 128), # src0eqdst=false + "src0_valid_shape": (16, 128), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-5, + }, + # Note: launchTRowExpandExpdif2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - aclFloat16, 16, 64, 16, 16, false (src1Col=16) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/compare.py new file mode 100644 index 000000000..98aff7854 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandexpdif ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandexpdif uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/gen_data.py new file mode 100644 index 000000000..8b6e09814 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/gen_data.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandexpdif ST test cases. + +trowexpandexpdif: dst = exp(src0 - broadcast(src1)) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandexpdif uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + # Use small values to avoid overflow in exp + input1 = np.random.randint(1, 5, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 5, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + # exp(src0 - src1_scalar) + diff = input1[:src0_vr, :src0_vc] - input2[:src1_vr, 0:1] + golden[:dst_vr, :dst_vc] = np.exp(diff).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/launch.cpp new file mode 100644 index 000000000..6b65978f7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f32_32x64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f32_16x128_noeq(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDEXPDIF_f32_32x64(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDEXPDIF_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDEXPDIF_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDEXPDIF_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDEXPDIF_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDEXPDIF_f32_16x128_noeq<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f16_16x32(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDEXPDIF_f16_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDEXPDIF_f16_16x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDEXPDIF_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDEXPDIF_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Note: launchTRowExpandExpdif2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/main.cpp new file mode 100644 index 000000000..92c7ba517 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/main.cpp @@ -0,0 +1,126 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandexpdif ST — row-wise broadcast exponential difference. +// Supports f32, f16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDEXPDIF_f32_32x64(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDEXPDIF_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDEXPDIF_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDEXPDIF_f16_16x32(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDEXPDIF_f16_48x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandExpdif2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_32x64", (LaunchFn)LaunchTROWEXPANDEXPDIF_f32_32x64, 32, 64, 32, 8, 32, 64, 32, 64, sizeof(float)}, + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDEXPDIF_f32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(float)}, + {"f32_16x128_noeq", (LaunchFn)LaunchTROWEXPANDEXPDIF_f32_16x128_noeq, 16, 128, 16, 8, 16, 128, 16, 128, sizeof(float)}, + // f16 cases + {"f16_16x32", LaunchTROWEXPANDEXPDIF_f16_16x32, 16, 32, 16, 16, 16, 32, 16, 32, sizeof(uint16_t)}, + {"f16_48x64", LaunchTROWEXPANDEXPDIF_f16_48x64, 48, 64, 48, 16, 48, 64, 48, 64, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/trowexpandexpdif.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/trowexpandexpdif.pto new file mode 100644 index 000000000..12248f7a1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/trowexpandexpdif.pto @@ -0,0 +1,317 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandexpdif: row-wise broadcast exponential difference. +// dst = exp(src0 - broadcast(src1)) +// Supports f32, f16 + +module { + // f32_32x64: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f32_32x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c8], + strides = [%c256, %c256, %c256, %c8, %c1] + : !pto.tensor_view<1x1x1x32x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c8] + : !pto.tensor_view<1x1x1x32x8xf32> -> !pto.partition_tensor_view<1x1x1x32x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // f32_16x32: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f16_16x32: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f16_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf16> -> !pto.partition_tensor_view<1x1x1x16x32xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf16> -> !pto.partition_tensor_view<1x1x1x16x32xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf16>) + return + } + + // f16_48x64: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c48, %c16], + strides = [%c768, %c768, %c768, %c16, %c1] + : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c16] + : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f32_16x128_noeq: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f32_16x128_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c128_ = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128_, %c128_, %c128_, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/CMakeLists.txt new file mode 100644 index 000000000..7f6c82ffe --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/cases.py new file mode 100644 index 000000000..be9646b0f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/cases.py @@ -0,0 +1,111 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandmax ST test cases. + +trowexpandmax: row-wise broadcast maximum. +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandMax2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandMax + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-6, + }, + # launchTRowExpandMax + { + "name": "f32_56x128", + "dtype": np.float32, + "src0_shape": (56, 128), + "src0_valid_shape": (56, 128), + "src1_shape": (56, 8), + "src1_valid_shape": (56, 1), + "dst_shape": (56, 128), + "dst_valid_shape": (56, 128), + "eps": 1e-6, + }, + # launchTRowExpandMax + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (48, 1), + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandMax + { + "name": "f16_16x128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + }, + # launchTRowExpandMax + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), # src0eqdst=false + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandMax + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandMax + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, + # Note: launchTRowExpandMax2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - float, 20, 64, 20, 8, false (src1Col=8) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/compare.py new file mode 100644 index 000000000..d02ad1760 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandmax ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandmax uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/gen_data.py new file mode 100644 index 000000000..ba1a8a2c4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/gen_data.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandmax ST test cases. + +trowexpandmax: dst = max(src0, broadcast(src1)) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandmax uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + golden[:dst_vr, :dst_vc] = np.maximum( + input1[:src0_vr, :src0_vc], np.broadcast_to(input2[:src1_vr, 0:1], (dst_vr, dst_vc)) + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/launch.cpp new file mode 100644 index 000000000..8a16fd19a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDMAX_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDMAX_f32_56x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDMAX_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMAX_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDMAX_f32_56x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMAX_f32_56x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDMAX_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMAX_f16_16x128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMAX_f16_32x64_noeq(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDMAX_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMAX_f16_16x128(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_f16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMAX_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_f16_32x64_noeq<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDMAX_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDMAX_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDMAX_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDMAX_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Note: launchTRowExpandMax2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/main.cpp new file mode 100644 index 000000000..cf2b51036 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/main.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandmax ST — row-wise broadcast maximum. +// Supports f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDMAX_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDMAX_f32_56x128(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDMAX_f16_48x64(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMAX_f16_16x128(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMAX_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream); +// i32 kernels +void LaunchTROWEXPANDMAX_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 kernels +void LaunchTROWEXPANDMAX_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandMax2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDMAX_f32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(float)}, + {"f32_56x128", (LaunchFn)LaunchTROWEXPANDMAX_f32_56x128, 56, 128, 56, 8, 56, 128, 56, 128, sizeof(float)}, + // f16 cases + {"f16_48x64", LaunchTROWEXPANDMAX_f16_48x64, 48, 64, 48, 16, 48, 64, 48, 64, sizeof(uint16_t)}, + {"f16_16x128", LaunchTROWEXPANDMAX_f16_16x128, 16, 128, 16, 16, 16, 128, 16, 128, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDMAX_f16_32x64_noeq, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDMAX_i32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDMAX_i16_16x64, 16, 64, 16, 16, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/trowexpandmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/trowexpandmax.pto new file mode 100644 index 000000000..a15c0c529 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/trowexpandmax.pto @@ -0,0 +1,437 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandmax: row-wise broadcast maximum. +// Supports f32, f16, i32, i16 + +module { + // f32_16x32: launchTRowExpandMax + func.func @TROWEXPANDMAX_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f32_56x128: launchTRowExpandMax + func.func @TROWEXPANDMAX_f32_56x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c56 = arith.constant 56 : index + %c128 = arith.constant 128 : index + %c448 = arith.constant 448 : index + %c7168 = arith.constant 7168 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c56, %c8], + strides = [%c448, %c448, %c448, %c8, %c1] + : !pto.tensor_view<1x1x1x56x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c8] + : !pto.tensor_view<1x1x1x56x8xf32> -> !pto.partition_tensor_view<1x1x1x56x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x56x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + return + } + + // f16_48x64: launchTRowExpandMax + func.func @TROWEXPANDMAX_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c48, %c16], + strides = [%c768, %c768, %c768, %c16, %c1] + : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c16] + : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f16_16x128: launchTRowExpandMax + func.func @TROWEXPANDMAX_f16_16x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f16_32x64_noeq: launchTRowExpandMax + func.func @TROWEXPANDMAX_f16_32x64_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // i32_16x32: launchTRowExpandMax + func.func @TROWEXPANDMAX_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: launchTRowExpandMax + func.func @TROWEXPANDMAX_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/CMakeLists.txt new file mode 100644 index 000000000..2d154b940 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/cases.py new file mode 100644 index 000000000..97443ca5b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/cases.py @@ -0,0 +1,111 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandmin ST test cases. + +trowexpandmin: row-wise broadcast minimum. +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandMin2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandMin + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-6, + }, + # launchTRowExpandMin + { + "name": "f32_56x128", + "dtype": np.float32, + "src0_shape": (56, 128), + "src0_valid_shape": (56, 128), + "src1_shape": (56, 8), + "src1_valid_shape": (56, 1), + "dst_shape": (56, 128), + "dst_valid_shape": (56, 128), + "eps": 1e-6, + }, + # launchTRowExpandMin + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (48, 1), + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandMin + { + "name": "f16_16x128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + }, + # launchTRowExpandMin + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), # src0eqdst=false + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandMin + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandMin + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, + # Note: launchTRowExpandMin2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - float, 20, 64, 20, 8, false (src1Col=8) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/compare.py new file mode 100644 index 000000000..7f637101d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandmin ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandmin uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/gen_data.py new file mode 100644 index 000000000..1c88e0eef --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/gen_data.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandmin ST test cases. + +trowexpandmin: dst = min(src0, broadcast(src1)) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandmin uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + golden[:dst_vr, :dst_vc] = np.minimum( + input1[:src0_vr, :src0_vc], np.broadcast_to(input2[:src1_vr, 0:1], (dst_vr, dst_vc)) + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/launch.cpp new file mode 100644 index 000000000..ba11c02ff --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDMIN_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDMIN_f32_56x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDMIN_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMIN_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDMIN_f32_56x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMIN_f32_56x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDMIN_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMIN_f16_16x128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMIN_f16_32x64_noeq(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDMIN_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMIN_f16_16x128(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_f16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMIN_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_f16_32x64_noeq<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDMIN_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDMIN_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDMIN_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDMIN_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Note: launchTRowExpandMin2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/main.cpp new file mode 100644 index 000000000..53e40102a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/main.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandmin ST — row-wise broadcast minimum. +// Supports f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDMIN_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDMIN_f32_56x128(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDMIN_f16_48x64(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMIN_f16_16x128(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMIN_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream); +// i32 kernels +void LaunchTROWEXPANDMIN_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 kernels +void LaunchTROWEXPANDMIN_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandMin2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDMIN_f32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(float)}, + {"f32_56x128", (LaunchFn)LaunchTROWEXPANDMIN_f32_56x128, 56, 128, 56, 8, 56, 128, 56, 128, sizeof(float)}, + // f16 cases + {"f16_48x64", LaunchTROWEXPANDMIN_f16_48x64, 48, 64, 48, 16, 48, 64, 48, 64, sizeof(uint16_t)}, + {"f16_16x128", LaunchTROWEXPANDMIN_f16_16x128, 16, 128, 16, 16, 16, 128, 16, 128, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDMIN_f16_32x64_noeq, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDMIN_i32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDMIN_i16_16x64, 16, 64, 16, 16, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/trowexpandmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/trowexpandmin.pto new file mode 100644 index 000000000..f4e551b80 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/trowexpandmin.pto @@ -0,0 +1,437 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandmin: row-wise broadcast minimum. +// Supports f32, f16, i32, i16 + +module { + // f32_16x32: launchTRowExpandMin + func.func @TROWEXPANDMIN_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f32_56x128: launchTRowExpandMin + func.func @TROWEXPANDMIN_f32_56x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c56 = arith.constant 56 : index + %c128 = arith.constant 128 : index + %c448 = arith.constant 448 : index + %c7168 = arith.constant 7168 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c56, %c8], + strides = [%c448, %c448, %c448, %c8, %c1] + : !pto.tensor_view<1x1x1x56x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c8] + : !pto.tensor_view<1x1x1x56x8xf32> -> !pto.partition_tensor_view<1x1x1x56x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x56x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + return + } + + // f16_48x64: launchTRowExpandMin + func.func @TROWEXPANDMIN_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c48, %c16], + strides = [%c768, %c768, %c768, %c16, %c1] + : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c16] + : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f16_16x128: launchTRowExpandMin + func.func @TROWEXPANDMIN_f16_16x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f16_32x64_noeq: launchTRowExpandMin + func.func @TROWEXPANDMIN_f16_32x64_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // i32_16x32: launchTRowExpandMin + func.func @TROWEXPANDMIN_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: launchTRowExpandMin + func.func @TROWEXPANDMIN_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/CMakeLists.txt new file mode 100644 index 000000000..5a71ec723 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandmul) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/cases.py new file mode 100644 index 000000000..883563cb9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/cases.py @@ -0,0 +1,111 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandmul ST test cases. + +trowexpandmul: row-wise broadcast multiplication. +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandMul2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandMul + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-6, + }, + # launchTRowExpandMul + { + "name": "f32_56x128", + "dtype": np.float32, + "src0_shape": (56, 128), + "src0_valid_shape": (56, 128), + "src1_shape": (56, 8), + "src1_valid_shape": (56, 1), + "dst_shape": (56, 128), + "dst_valid_shape": (56, 128), + "eps": 1e-6, + }, + # launchTRowExpandMul + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (48, 1), + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandMul + { + "name": "f16_16x128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + }, + # launchTRowExpandMul + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), # src0eqdst=false + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandMul + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandMul + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, + # Note: launchTRowExpandMul2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - float, 20, 64, 20, 8, false (src1Col=8) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/compare.py new file mode 100644 index 000000000..5dea2f271 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandmul ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandmul uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/gen_data.py new file mode 100644 index 000000000..4d0dde3b6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/gen_data.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandmul ST test cases. + +trowexpandmul: dst = src0 * broadcast(src1) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandmul uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] * input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/launch.cpp new file mode 100644 index 000000000..3bccc18d7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDMUL_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDMUL_f32_56x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDMUL_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMUL_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDMUL_f32_56x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMUL_f32_56x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDMUL_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMUL_f16_16x128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMUL_f16_32x64_noeq(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDMUL_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMUL_f16_16x128(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_f16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMUL_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_f16_32x64_noeq<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDMUL_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDMUL_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDMUL_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDMUL_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Note: launchTRowExpandMul2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/main.cpp new file mode 100644 index 000000000..4c4cd0310 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/main.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandmul ST — row-wise broadcast multiplication. +// Supports f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDMUL_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDMUL_f32_56x128(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDMUL_f16_48x64(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMUL_f16_16x128(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMUL_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream); +// i32 kernels +void LaunchTROWEXPANDMUL_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 kernels +void LaunchTROWEXPANDMUL_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandMul2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDMUL_f32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(float)}, + {"f32_56x128", (LaunchFn)LaunchTROWEXPANDMUL_f32_56x128, 56, 128, 56, 8, 56, 128, 56, 128, sizeof(float)}, + // f16 cases + {"f16_48x64", LaunchTROWEXPANDMUL_f16_48x64, 48, 64, 48, 16, 48, 64, 48, 64, sizeof(uint16_t)}, + {"f16_16x128", LaunchTROWEXPANDMUL_f16_16x128, 16, 128, 16, 16, 16, 128, 16, 128, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDMUL_f16_32x64_noeq, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDMUL_i32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDMUL_i16_16x64, 16, 64, 16, 16, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/trowexpandmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/trowexpandmul.pto new file mode 100644 index 000000000..8a4cd396a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/trowexpandmul.pto @@ -0,0 +1,212 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandmul: row-wise broadcast multiplication. + +module { + // f32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128, %c128, %c128, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f32_56x128: dstRow=56, dstCol=128, src1Row=56, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_f32_56x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c56 = arith.constant 56 : index + %c128 = arith.constant 128 : index + %c448 = arith.constant 448 : index + %c7168 = arith.constant 7168 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c56, %c128], strides = [%c7168, %c7168, %c7168, %c128, %c1] : !pto.tensor_view<1x1x1x56x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c56, %c8], strides = [%c448, %c448, %c448, %c8, %c1] : !pto.tensor_view<1x1x1x56x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c56, %c128], strides = [%c7168, %c7168, %c7168, %c128, %c1] : !pto.tensor_view<1x1x1x56x128xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c56, %c128] : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c56, %c8] : !pto.tensor_view<1x1x1x56x8xf32> -> !pto.partition_tensor_view<1x1x1x56x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c56, %c128] : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x56x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + return + } + + // f16_48x64: dstRow=48, dstCol=64, src1Row=48, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c48, %c64], strides = [%c3072, %c3072, %c3072, %c64, %c1] : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c48, %c16], strides = [%c768, %c768, %c768, %c16, %c1] : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c48, %c64], strides = [%c3072, %c3072, %c3072, %c64, %c1] : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c48, %c64] : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c48, %c16] : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c48, %c64] : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f16_16x128: dstRow=16, dstCol=128, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_f16_16x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c128], strides = [%c2048, %c2048, %c2048, %c128, %c1] : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c16], strides = [%c256, %c256, %c256, %c16, %c1] : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c128], strides = [%c2048, %c2048, %c2048, %c128, %c1] : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c128] : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c16] : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c128] : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f16_32x64_noeq: dstRow=32, dstCol=64, src1Row=32, src1Col=1, src0eqdst=false + func.func @TROWEXPANDMUL_f16_32x64_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c32, %c16], strides = [%c512, %c512, %c512, %c16, %c1] : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c16] : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // i32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128, %c128, %c128, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: dstRow=16, dstCol=64, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c16], strides = [%c256, %c256, %c256, %c16, %c1] : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c16] : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/CMakeLists.txt new file mode 100644 index 000000000..fe69e4770 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandsub) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/cases.py new file mode 100644 index 000000000..c9ac76001 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/cases.py @@ -0,0 +1,111 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandsub ST test cases. + +trowexpandsub: row-wise broadcast subtraction. +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandSub2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandSub + { + "name": "f32_8x128", + "dtype": np.float32, + "src0_shape": (8, 128), + "src0_valid_shape": (8, 128), + "src1_shape": (8, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (8, 1), # src1Col=1 + "dst_shape": (8, 128), + "dst_valid_shape": (8, 128), + "eps": 1e-6, + }, + # launchTRowExpandSub + { + "name": "f32_24x32", + "dtype": np.float32, + "src0_shape": (24, 32), + "src0_valid_shape": (24, 32), + "src1_shape": (24, 8), + "src1_valid_shape": (24, 1), + "dst_shape": (24, 32), + "dst_valid_shape": (24, 32), + "eps": 1e-6, + }, + # launchTRowExpandSub + { + "name": "f16_16x256", + "dtype": np.float16, + "src0_shape": (16, 256), + "src0_valid_shape": (16, 256), + "src1_shape": (16, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 256), + "dst_valid_shape": (16, 256), + "eps": 1e-3, + }, + # launchTRowExpandSub + { + "name": "f16_32x64", + "dtype": np.float16, + "src0_shape": (32, 64), + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandSub + { + "name": "f32_16x128_noeq", + "dtype": np.float32, + "src0_shape": (16, 128), # src0eqdst=false + "src0_valid_shape": (16, 128), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-6, + }, + # launchTRowExpandSub + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandSub + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, + # Note: launchTRowExpandSub2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - aclFloat16, 16, 64, 16, 16, false (src1Col=16) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/compare.py new file mode 100644 index 000000000..3105c4da6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandsub ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandsub uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/gen_data.py new file mode 100644 index 000000000..64b40f040 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/gen_data.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandsub ST test cases. + +trowexpandsub: dst = src0 - broadcast(src1) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandsub uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + # dst = src0 - src1_scalar (broadcasted) + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] - input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/launch.cpp new file mode 100644 index 000000000..4692348e2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDSUB_f32_8x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDSUB_f32_24x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDSUB_f32_16x128_noeq(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDSUB_f32_8x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDSUB_f32_8x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDSUB_f32_24x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDSUB_f32_24x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDSUB_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDSUB_f32_16x128_noeq<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDSUB_f16_16x256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDSUB_f16_32x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDSUB_f16_16x256(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDSUB_f16_16x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDSUB_f16_32x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDSUB_f16_32x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDSUB_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDSUB_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDSUB_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDSUB_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDSUB_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDSUB_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Note: launchTRowExpandSub2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/main.cpp new file mode 100644 index 000000000..4943e2cc7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/main.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandsub ST — row-wise broadcast subtraction. +// Supports f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDSUB_f32_8x128(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDSUB_f32_24x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDSUB_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDSUB_f16_16x256(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDSUB_f16_32x64(void *src0, void *src1, void *dst, void *stream); +// i32 kernels +void LaunchTROWEXPANDSUB_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 kernels +void LaunchTROWEXPANDSUB_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandSub2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_8x128", (LaunchFn)LaunchTROWEXPANDSUB_f32_8x128, 8, 128, 8, 8, 8, 128, 8, 128, sizeof(float)}, + {"f32_24x32", (LaunchFn)LaunchTROWEXPANDSUB_f32_24x32, 24, 32, 24, 8, 24, 32, 24, 32, sizeof(float)}, + {"f32_16x128_noeq", (LaunchFn)LaunchTROWEXPANDSUB_f32_16x128_noeq, 16, 128, 16, 8, 16, 128, 16, 128, sizeof(float)}, + // f16 cases + {"f16_16x256", LaunchTROWEXPANDSUB_f16_16x256, 16, 256, 16, 16, 16, 256, 16, 256, sizeof(uint16_t)}, + {"f16_32x64", LaunchTROWEXPANDSUB_f16_32x64, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDSUB_i32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDSUB_i16_16x64, 16, 64, 16, 16, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/trowexpandsub.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/trowexpandsub.pto new file mode 100644 index 000000000..dad74917d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/trowexpandsub.pto @@ -0,0 +1,211 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandsub: row-wise broadcast subtraction. + +module { + // f32_8x128: dstRow=8, dstCol=128, src1Row=8, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_f32_8x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c8, %c128], strides = [%c1024, %c1024, %c1024, %c128, %c1] : !pto.tensor_view<1x1x1x8x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c8, %c8], strides = [%c64, %c64, %c64, %c8, %c1] : !pto.tensor_view<1x1x1x8x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c8, %c128], strides = [%c1024, %c1024, %c1024, %c128, %c1] : !pto.tensor_view<1x1x1x8x128xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c8, %c128] : !pto.tensor_view<1x1x1x8x128xf32> -> !pto.partition_tensor_view<1x1x1x8x128xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c8, %c8] : !pto.tensor_view<1x1x1x8x8xf32> -> !pto.partition_tensor_view<1x1x1x8x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c8, %c128] : !pto.tensor_view<1x1x1x8x128xf32> -> !pto.partition_tensor_view<1x1x1x8x128xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x8x128xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x8x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x128xf32>) + return + } + + // f32_24x32: dstRow=24, dstCol=32, src1Row=24, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_f32_24x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c24 = arith.constant 24 : index + %c32 = arith.constant 32 : index + %c192 = arith.constant 192 : index + %c768 = arith.constant 768 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c24, %c32], strides = [%c768, %c768, %c768, %c32, %c1] : !pto.tensor_view<1x1x1x24x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c24, %c8], strides = [%c192, %c192, %c192, %c8, %c1] : !pto.tensor_view<1x1x1x24x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c24, %c32], strides = [%c768, %c768, %c768, %c32, %c1] : !pto.tensor_view<1x1x1x24x32xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c24, %c32] : !pto.tensor_view<1x1x1x24x32xf32> -> !pto.partition_tensor_view<1x1x1x24x32xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c24, %c8] : !pto.tensor_view<1x1x1x24x8xf32> -> !pto.partition_tensor_view<1x1x1x24x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c24, %c32] : !pto.tensor_view<1x1x1x24x32xf32> -> !pto.partition_tensor_view<1x1x1x24x32xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x24x32xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x24x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x24x32xf32>) + return + } + + // f16_16x256: dstRow=16, dstCol=256, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_f16_16x256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c256_2 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c256], strides = [%c4096, %c4096, %c4096, %c256, %c1] : !pto.tensor_view<1x1x1x16x256xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c16], strides = [%c256_2, %c256_2, %c256_2, %c16, %c1] : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c256], strides = [%c4096, %c4096, %c4096, %c256, %c1] : !pto.tensor_view<1x1x1x16x256xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c256] : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c16] : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c256] : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + return + } + + // f16_32x64: dstRow=32, dstCol=64, src1Row=32, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_f16_32x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c32, %c16], strides = [%c512, %c512, %c512, %c16, %c1] : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c16] : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // f32_16x128_noeq: dstRow=16, dstCol=128, src1Row=16, src1Col=1, src0eqdst=false + func.func @TROWEXPANDSUB_f32_16x128_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c128_2 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c128], strides = [%c2048, %c2048, %c2048, %c128, %c1] : !pto.tensor_view<1x1x1x16x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128_2, %c128_2, %c128_2, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c128], strides = [%c2048, %c2048, %c2048, %c128, %c1] : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c128] : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c128] : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // i32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128, %c128, %c128, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: dstRow=16, dstCol=64, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c16], strides = [%c256, %c256, %c256, %c16, %c1] : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c16] : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file From e2f991feabf3c78e13fb36d6abfc3c88c8fb18f0 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sun, 26 Apr 2026 23:04:23 +0800 Subject: [PATCH 176/192] feat: enable vpto sim --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index abe08cbd5..48b362363 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -437,7 +437,7 @@ jobs: echo "PTOAS_BIN=${GITHUB_WORKSPACE}/build/tools/ptoas/ptoas" >> "${GITHUB_ENV}" - name: Run VPTO SIM validation - if: ${{ false }} + if: ${{ true }} shell: bash run: | set -euo pipefail From d71462bb5b3fe10206c271b9e8ab22009efe3da8 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 25 Apr 2026 17:02:50 +0800 Subject: [PATCH 177/192] fix(dsl): allow integer vdiv vector types --- lib/TileOps/math.py | 153 +++- .../.openspec.yaml | 2 + .../design.md | 139 ++++ .../proposal.md | 107 +++ .../specs/tilelang-dsl-diagnostics/spec.md | 29 + .../specs/tilelang-dsl-surface/spec.md | 31 + .../specs/tilelang-dsl-vpto-lowering/spec.md | 47 ++ .../tasks.md | 38 + .../tilelang_soft_vmod_backend_inline.pto | 126 +++ ...o_tilelang_inline_soft_divmod_fastpath.pto | 158 ++++ .../11-vector-arithmetic-operations.md | 28 + .../python/tilelang_dsl/frontend_ast.py | 81 ++ tilelang-dsl/python/tilelang_dsl/kernel.py | 53 ++ tilelang-dsl/python/tilelang_dsl/semantic.py | 78 +- .../python/tilelang_dsl/support_matrix.py | 1 + tilelang-dsl/tests/test_tilelang_dsl_v1.py | 785 ++++++++++++++++++ tools/ptoas/ptoas.cpp | 31 + 17 files changed, 1855 insertions(+), 32 deletions(-) create mode 100644 openspec/changes/hide-soft-divmod-behind-pto-surface/.openspec.yaml create mode 100644 openspec/changes/hide-soft-divmod-behind-pto-surface/design.md create mode 100644 openspec/changes/hide-soft-divmod-behind-pto-surface/proposal.md create mode 100644 openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-diagnostics/spec.md create mode 100644 openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-surface/spec.md create mode 100644 openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-vpto-lowering/spec.md create mode 100644 openspec/changes/hide-soft-divmod-behind-pto-surface/tasks.md create mode 100644 test/basic/tilelang_soft_vmod_backend_inline.pto create mode 100644 test/vpto_tilelang_inline_soft_divmod_fastpath.pto diff --git a/lib/TileOps/math.py b/lib/TileOps/math.py index 3d32340db..83c4634f6 100644 --- a/lib/TileOps/math.py +++ b/lib/TileOps/math.py @@ -10,7 +10,55 @@ @pto.inline_proc -def _vdiv_u16(vec, scalar_vec, mask): +def _tl_soft_vdiv_u8(vec, scalar_vec, mask): + zero = pto.ui8(0) + zero_q = pto.ui8(0xFF) + full_mask_b8 = pto.pset_b8(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + active_low = pto.punpack(active_mask, pto.PredicatePart.LOWER) + active_high = pto.punpack(active_mask, pto.PredicatePart.HIGHER) + + vec_low = pto.vzunpack(vec, 0) + vec_high = pto.vzunpack(vec, 1) + scalar_low = pto.vzunpack(scalar_vec, 0) + scalar_high = pto.vzunpack(scalar_vec, 1) + + q_low = _tl_soft_vdiv_u16(vec_low, scalar_low, active_low) + q_high = _tl_soft_vdiv_u16(vec_high, scalar_high, active_high) + packed_low = pto.vpack(q_low, pto.PredicatePart.LOWER) + packed_high = pto.vpack(q_high, pto.PredicatePart.HIGHER) + q = pto.vor(packed_low, packed_high, full_mask_b8) + return pto.vsel(pto.vbr(zero_q), q, zero_mask) + + +@pto.inline_proc +def _tl_soft_vdiv_i8(vec, scalar_vec, mask): + zero = pto.i8(0) + neg_one = pto.i8(-1) + full_mask_b8 = pto.pset_b8(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + active_low = pto.punpack(active_mask, pto.PredicatePart.LOWER) + active_high = pto.punpack(active_mask, pto.PredicatePart.HIGHER) + + vec_low = pto.vsunpack(vec, 0) + vec_high = pto.vsunpack(vec, 1) + scalar_low = pto.vsunpack(scalar_vec, 0) + scalar_high = pto.vsunpack(scalar_vec, 1) + + q_low = _tl_soft_vdiv_i16(vec_low, scalar_low, active_low) + q_high = _tl_soft_vdiv_i16(vec_high, scalar_high, active_high) + packed_low = pto.vpack(q_low, pto.PredicatePart.LOWER) + packed_high = pto.vpack(q_high, pto.PredicatePart.HIGHER) + q = pto.vbitcast(pto.vor(packed_low, packed_high, full_mask_b8), pto.i8) + return pto.vsel(pto.vbr(neg_one), q, zero_mask) + + +@pto.inline_proc +def _tl_soft_vdiv_u16(vec, scalar_vec, mask): zero = pto.ui16(0) one = pto.ui16(1) fp32_one = pto.f32(1.0) @@ -75,12 +123,7 @@ def _vdiv_u16(vec, scalar_vec, mask): @pto.inline_proc -def vdiv_u16(vec, scalar_vec, mask): - return _vdiv_u16(vec, scalar_vec, mask) - - -@pto.inline_proc -def vdiv_i16(vec, scalar_vec, mask): +def _tl_soft_vdiv_i16(vec, scalar_vec, mask): zero = pto.i16(0) neg_one = pto.i16(-1) @@ -92,14 +135,62 @@ def vdiv_i16(vec, scalar_vec, mask): x_xor_y = pto.vxor(vec, scalar_vec, active_mask) p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) - q_abs = _vdiv_u16(abs_x, abs_y, active_mask) + q_abs = _tl_soft_vdiv_u16(abs_x, abs_y, active_mask) neg_q = pto.vneg(pto.vbitcast(q_abs, pto.i16), active_mask) q = pto.vsel(pto.vbitcast(q_abs, pto.i16), neg_q, p_pos) return pto.vsel(pto.vbr(neg_one), q, zero_mask) @pto.inline_proc -def vmod_u16(vec, scalar_vec, mask): +def _tl_soft_vmod_u8(vec, scalar_vec, mask): + zero = pto.ui8(0) + zero_r = pto.ui8(0xFF) + full_mask_b8 = pto.pset_b8(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + active_low = pto.punpack(active_mask, pto.PredicatePart.LOWER) + active_high = pto.punpack(active_mask, pto.PredicatePart.HIGHER) + + vec_low = pto.vzunpack(vec, 0) + vec_high = pto.vzunpack(vec, 1) + scalar_low = pto.vzunpack(scalar_vec, 0) + scalar_high = pto.vzunpack(scalar_vec, 1) + + r_low = _tl_soft_vmod_u16(vec_low, scalar_low, active_low) + r_high = _tl_soft_vmod_u16(vec_high, scalar_high, active_high) + packed_low = pto.vpack(r_low, pto.PredicatePart.LOWER) + packed_high = pto.vpack(r_high, pto.PredicatePart.HIGHER) + r = pto.vor(packed_low, packed_high, full_mask_b8) + return pto.vsel(pto.vbr(zero_r), r, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod_i8(vec, scalar_vec, mask): + zero = pto.i8(0) + neg_one = pto.i8(-1) + full_mask_b8 = pto.pset_b8(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + active_low = pto.punpack(active_mask, pto.PredicatePart.LOWER) + active_high = pto.punpack(active_mask, pto.PredicatePart.HIGHER) + + vec_low = pto.vsunpack(vec, 0) + vec_high = pto.vsunpack(vec, 1) + scalar_low = pto.vsunpack(scalar_vec, 0) + scalar_high = pto.vsunpack(scalar_vec, 1) + + r_low = _tl_soft_vmod_i16(vec_low, scalar_low, active_low) + r_high = _tl_soft_vmod_i16(vec_high, scalar_high, active_high) + packed_low = pto.vpack(r_low, pto.PredicatePart.LOWER) + packed_high = pto.vpack(r_high, pto.PredicatePart.HIGHER) + r = pto.vbitcast(pto.vor(packed_low, packed_high, full_mask_b8), pto.i8) + return pto.vsel(pto.vbr(neg_one), r, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod_u16(vec, scalar_vec, mask): zero = pto.ui16(0) one = pto.ui16(1) zero_r = pto.ui16(0xFFFF) @@ -164,7 +255,7 @@ def vmod_u16(vec, scalar_vec, mask): @pto.inline_proc -def vdiv_u32(vec, scalar_vec, mask): +def _tl_soft_vdiv_u32(vec, scalar_vec, mask): zero = pto.ui32(0) one = pto.ui32(1) zero_q = pto.ui32(0xFFFFFFFF) @@ -225,7 +316,7 @@ def vdiv_u32(vec, scalar_vec, mask): @pto.inline_proc -def vmod_i16(vec, scalar_vec, mask): +def _tl_soft_vmod_i16(vec, scalar_vec, mask): zero = pto.i16(0) neg_one = pto.i16(-1) @@ -237,7 +328,7 @@ def vmod_i16(vec, scalar_vec, mask): x_xor_y = pto.vxor(vec, scalar_vec, active_mask) p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) - q_abs = _vdiv_u16(abs_x, abs_y, active_mask) + q_abs = _tl_soft_vdiv_u16(abs_x, abs_y, active_mask) neg_q = pto.vneg(pto.vbitcast(q_abs, pto.i16), active_mask) q = pto.vsel(pto.vbitcast(q_abs, pto.i16), neg_q, p_pos) @@ -255,7 +346,7 @@ def vmod_i16(vec, scalar_vec, mask): @pto.inline_proc -def vdiv_i32(vec, scalar_vec, mask): +def _tl_soft_vdiv_i32(vec, scalar_vec, mask): zero = pto.i32(0) neg_one = pto.i32(-1) fp32_one = pto.f32(1.0) @@ -322,7 +413,7 @@ def vdiv_i32(vec, scalar_vec, mask): @pto.inline_proc -def vmod_u32(vec, scalar_vec, mask): +def _tl_soft_vmod_u32(vec, scalar_vec, mask): zero = pto.ui32(0) one = pto.ui32(1) zero_r = pto.ui32(0xFFFFFFFF) @@ -383,7 +474,7 @@ def vmod_u32(vec, scalar_vec, mask): @pto.inline_proc -def vmod_i32(vec, scalar_vec, mask): +def _tl_soft_vmod_i32(vec, scalar_vec, mask): zero = pto.i32(0) neg_one = pto.i32(-1) fp32_one = pto.f32(1.0) @@ -460,26 +551,34 @@ def vmod_i32(vec, scalar_vec, mask): @pto.inline_proc -def vmod(vec, scalar_vec, mask, dtype): - if pto.constexpr(dtype == pto.ui16): - result = vmod_u16(vec, scalar_vec, mask) +def _tl_soft_vmod(vec, scalar_vec, mask, dtype): + if pto.constexpr(dtype == pto.ui8): + result = _tl_soft_vmod_u8(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.i8): + result = _tl_soft_vmod_i8(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.ui16): + result = _tl_soft_vmod_u16(vec, scalar_vec, mask) elif pto.constexpr(dtype == pto.i16): - result = vmod_i16(vec, scalar_vec, mask) + result = _tl_soft_vmod_i16(vec, scalar_vec, mask) elif pto.constexpr(dtype == pto.ui32): - result = vmod_u32(vec, scalar_vec, mask) + result = _tl_soft_vmod_u32(vec, scalar_vec, mask) else: - result = vmod_i32(vec, scalar_vec, mask) + result = _tl_soft_vmod_i32(vec, scalar_vec, mask) return result @pto.inline_proc -def vdiv(vec, scalar_vec, mask, dtype): - if pto.constexpr(dtype == pto.ui16): - result = vdiv_u16(vec, scalar_vec, mask) +def _tl_soft_vdiv(vec, scalar_vec, mask, dtype): + if pto.constexpr(dtype == pto.ui8): + result = _tl_soft_vdiv_u8(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.i8): + result = _tl_soft_vdiv_i8(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.ui16): + result = _tl_soft_vdiv_u16(vec, scalar_vec, mask) elif pto.constexpr(dtype == pto.i16): - result = vdiv_i16(vec, scalar_vec, mask) + result = _tl_soft_vdiv_i16(vec, scalar_vec, mask) elif pto.constexpr(dtype == pto.ui32): - result = vdiv_u32(vec, scalar_vec, mask) + result = _tl_soft_vdiv_u32(vec, scalar_vec, mask) else: - result = vdiv_i32(vec, scalar_vec, mask) + result = _tl_soft_vdiv_i32(vec, scalar_vec, mask) return result diff --git a/openspec/changes/hide-soft-divmod-behind-pto-surface/.openspec.yaml b/openspec/changes/hide-soft-divmod-behind-pto-surface/.openspec.yaml new file mode 100644 index 000000000..1b75776f7 --- /dev/null +++ b/openspec/changes/hide-soft-divmod-behind-pto-surface/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-25 diff --git a/openspec/changes/hide-soft-divmod-behind-pto-surface/design.md b/openspec/changes/hide-soft-divmod-behind-pto-surface/design.md new file mode 100644 index 000000000..ddf808d15 --- /dev/null +++ b/openspec/changes/hide-soft-divmod-behind-pto-surface/design.md @@ -0,0 +1,139 @@ +## Context + +### 范围 + +本 design 只覆盖 TileLang DSL 中与 `pto.vdiv` / `pto.vmod` 直接相关的 public surface、frontend diagnostics 与 lowering contract: + +- `pto.vdiv` +- `pto.vmod` +- internal soft helper 注入与调用建模 +- integer `i8/i16/i32` soft div/mod 路径 +- `f16/f32` `vdiv` 的 VPTO authoring path + +它不覆盖: + +- `bf16` div/mod 支持 +- floating-point `vmod` / `fmod` 语义扩展 +- 新的 backend inline 机制设计 +- 非 vector `tdiv*` / `trem*` / `tfmod*` tile op 的独立语义重构 + +### 当前状态 + +当前仓库中已经有以下事实: + +1. `pto.vdiv` 已是 TileLang DSL public surface,但 lowering 仍按统一 `pto.vdiv` 语义继续往下走,没有在 spec 层冻结 dtype-directed 分流。 +2. `lib/TileOps/math.py` 已有整数 `vdiv` / `vmod` soft helper,但它们仍以 helper 形式存在,容易泄漏成用户可见实现细节。 +3. `vmod` 还没有完整的 `pto.vmod` public surface / semantic / lowering 链路。 +4. `i8` 虽然已被纳入 `pto.vdiv` public support matrix,但当前 soft helper 的正式 contract 还没有覆盖 `i8/u8` widen / narrow profile。 +5. 仓库已经具备 `inline_proc` backend-inline 主线,因此“frontend 保留 helper call,backend 主线消除 helper”是可复用路径。 + +### 设计约束 + +- 用户编写 DSL 时只能面向 `pto.vdiv` / `pto.vmod`,不能要求显式调用 soft helper。 +- `f16/f32` `vdiv` 不能被无差别改写到 soft path;需要保留现有 VPTO 指令 authoring contract。 +- `vmod` 本次优先收敛整数族,不在没有明确语义的前提下把 floating `vmod` 一起公开承诺。 +- soft helper 的可见性必须是 internal-only,即便其底层继续以 `inline_proc` 或等价 helper 形式存在。 + +## Goals / Non-Goals + +**Goals:** + +- 统一 `vdiv` / `vmod` 的用户 surface。 +- 让 `vdiv` 在 semantic/lowering 层按 dtype 分流。 +- 让 `vmod` 有正式 public surface,但不暴露 soft helper 细节。 +- 把 `i8` 族 soft path 补成正式契约的一部分。 + +**Non-Goals:** + +- 不把 `vmod` 定义成新的 VPTO public op。 +- 不重新设计整数除零、饱和或 exceptional-value 语义之外的更大算术模型。 +- 不让用户通过新的 public helper surface 显式选择“硬件路径”或“软实现路径”。 + +## Decisions + +### 1. `pto.vdiv` 保持单一 public surface,但 lowering 按 dtype 分流 + +决策: + +- `pto.vdiv` 继续是用户唯一可见的 vector division API。 +- `f16/f32` `pto.vdiv` 在 frontend/semantic/lowering 中继续保留为 `namespace="pto"` 的 VPTO authoring op。 +- `i8/i16/i32` `pto.vdiv` 在 semantic 阶段重写为 internal helper call,再通过 backend-inline 主线消除。 + +原因: + +- 这能保持用户心智统一,同时允许浮点复用现有 VPTO path,整数则走 soft algorithm。 + +备选方案: + +- 让所有 `pto.vdiv` 都走 soft path。 + - 放弃原因:会破坏现有 `f16/f32` VPTO authoring / backend contract,也会丢失已经存在的硬件路径价值。 + +### 2. `pto.vmod` 作为 public surface 新增,但优先只承诺整数族 + +决策: + +- 用户 surface 新增 `pto.vmod(vec0, vec1, mask)`。 +- 本 change 中,`pto.vmod` 的正式支持矩阵优先限定为 `i8/i16/i32` 家族。 +- `pto.vmod` 不新增新的 VPTO public op,而是直接走 internal soft helper path。 + +原因: + +- 当前仓库中已经有整数 soft `vmod` 算法,可以形成闭环。 +- floating `vmod` / `fmod` 语义尚未被冻结,不适合在本 change 中一起 over-commit。 + +### 3. soft helper 继续存在,但必须 internal-only + +决策: + +- soft helper 可以继续以 `inline_proc` 或等价 helper 形式保留。 +- helper 命名与接线路径只属于实现细节,不属于 public surface。 +- frontend diagnostics 对这些 internal helper 名字仍按“unsupported public call surface”处理。 + +原因: + +- 用户 contract 要以 `pto.vdiv` / `pto.vmod` 为唯一入口,不能让 `math.py` helper 变成事实上的次级 API。 + +### 4. `i8` 族必须通过 widen / narrow profile 完成 soft div/mod + +决策: + +- `i8/u8` soft `vdiv` / `vmod` 采用 widen -> soft div/mod -> narrow 的正式实现路线。 +- 该 widen / narrow 过程属于实现细节,但其存在本身是本 change 的 contract 一部分。 + +原因: + +- 当前 soft helper 主要覆盖 16/32-bit 家族;如果不把 `i8` 单独写进 contract,就会继续出现“surface 允许、实现无正式路径”的漂移。 + +## 测试策略 + +- Python/frontend 单测: + - `f16/f32` `pto.vdiv` 继续产出 `pto.vdiv` + - 整数 `pto.vdiv` 产出 internal helper call,而不是 authoring-form `pto.vdiv` + - `pto.vmod` public surface 的正向/负向测试 + - internal helper 名字 direct call 仍被 reject +- lowering / backend 回归: + - 验证整数 `vdiv` / `vmod` 通过 helper + backend-inline 收敛 + - 验证 `f16/f32` `vdiv` 保持 VPTO path + +## Risks / Trade-offs + +- [Risk] `vdiv` 同名但双路径 lowering 可能让测试断言更复杂 + Mitigation:把测试分成浮点 path 和整数 path 两类,不再只断言统一文本形态。 + +- [Risk] `vmod` public surface 若过早承诺浮点 family,会让语义边界变模糊 + Mitigation:本 change 只先冻结整数族,floating `vmod` 另开 change。 + +- [Risk] internal helper 注入机制会增加 frontend/semantic 复杂度 + Mitigation:复用已有 inline-proc helper/call 模型,而不是再开第三套 helper 体系。 + +## Migration Plan + +1. 先冻结 OpenSpec contract,明确 `vdiv` / `vmod` 的 public/lowering 边界。 +2. 在 frontend/semantic 中补 internal helper 注入与 dtype-directed rewrite。 +3. 补齐 `i8` 族 soft div/mod helper。 +4. 补测试与文档,验证 `f16/f32` 和整数 path 均符合新契约。 + +## Open Questions + +- 浮点 `pto.vmod` 是否需要在后续 change 中定义为 `fmod` 还是 remainder 语义。 + 本 change 暂不回答该问题。 diff --git a/openspec/changes/hide-soft-divmod-behind-pto-surface/proposal.md b/openspec/changes/hide-soft-divmod-behind-pto-surface/proposal.md new file mode 100644 index 000000000..47600883e --- /dev/null +++ b/openspec/changes/hide-soft-divmod-behind-pto-surface/proposal.md @@ -0,0 +1,107 @@ +# Proposal: 用 `pto.vdiv` / `pto.vmod` 收敛 TileLang DSL 的软实现细节 + +## 概述 + +当前仓库中已经存在整数 `vdiv` / `vmod` 的软实现算法,但这套实现仍以 `lib/TileOps/math.py` 中的 helper 形式存在,尚未完全收敛到 TileLang DSL 的正式 public surface。 +与此同时,`pto.vdiv` 在 DSL 侧已经是公开 API,但其 lowering 仍缺少按 dtype 分流的稳定契约:`f16/f32` 需要继续走现有 VPTO `vdiv` 指令路径,`i8/i16/i32` 需要改走内部软实现路径,而不把软实现 helper 暴露给用户。 + +本 change 的目标是把 `vdiv` / `vmod` 的用户心智统一收敛为: + +- 用户只写 `pto.vdiv(...)` / `pto.vmod(...)` +- `f16/f32` `vdiv` 继续保留硬件/VPTO 指令 lowering +- 整数 `vdiv` 与 `vmod` 通过内部 soft helper lowering +- 软实现 helper 不作为 TileLang DSL public API 暴露 + +## 背景与动机 + +当前实现存在四个直接问题: + +1. `pto.vdiv` 已经是 public surface,但没有冻结“浮点走 VPTO、整数走 soft path”的 lowering 契约。 +2. `vmod` 目前只有内部 soft helper,没有完整的 `pto.vmod` public surface 和 end-to-end lowering 链路。 +3. 现有 soft helper 仍以可见名字存在,容易把实现细节泄漏给 DSL 用户。 +4. `i8` 族虽然已经被当成 `pto.vdiv` 的支持范围之一,但当前 soft helper 还没有形成明确、可验证的 `i8/u8 -> widen -> div/mod -> narrow` 契约。 + +如果不把这几层契约一次性写清楚,后续实现很容易出现: + +- 文档支持范围与真实 lowering 分叉 +- 用户直接依赖内部 helper 名字 +- `f16/f32` 与整数 `vdiv` 走出两套无 spec 约束的隐式路径 +- `vmod` 继续停留在“有算法、无 public surface”的半完成状态 + +## 目标 + +- 把 TileLang DSL 的除法/取模 public surface 统一为 `pto.vdiv` / `pto.vmod`。 +- 明确 `pto.vdiv` 的 dtype-directed lowering: + - `f16/f32` 走 authoring-form VPTO `pto.vdiv` + - `i8/i16/i32` 走内部 soft helper +- 为 `pto.vmod` 补齐完整 public surface 和内部 soft lowering 链路。 +- 规定 soft helper 只作为内部实现细节存在,不作为用户可依赖 API 暴露。 +- 把 `i8` 族 soft div/mod 的 widen / narrow 路径纳入正式实现范围。 + +## 非目标 + +- 不在本 change 中扩展 `bf16` 的 `vdiv` / `vmod` 支持。 +- 不在本 change 中重新定义 floating-point remainder/fmod 语义;`pto.vmod` 本次优先收敛整数族 public surface。 +- 不在本 change 中引入新的 public helper 命名空间或让用户显式调用 soft helper。 +- 不在本 change 中改变现有 `inline_proc` backend-inline 主线;本 change 仅复用该能力承载内部 soft helper lowering。 + +## What Changes + +- `tilelang-dsl-surface`: + - 明确 `pto.vdiv` 是唯一公开的 vector division surface,支持 `i8/i16/i32/f16/f32`。 + - 新增 `pto.vmod` 作为公开 vector modulo surface,优先覆盖整数族。 + - 明确用户不得依赖内部 soft helper 名字。 +- `tilelang-dsl-diagnostics`: + - 明确 `pto.vdiv` / `pto.vmod` 的 dtype reject 行为。 + - 明确内部 soft helper 名字不属于 public DSL call surface。 +- `tilelang-dsl-vpto-lowering`: + - 为 `pto.vdiv` 新增 dtype-directed lowering 契约。 + - 为 `pto.vmod` 新增“public call -> internal helper -> backend-inline -> legal VPTO”契约。 + - 明确整数 `i8` 族 soft path 需要通过 widen / narrow profile 实现。 + +## Capabilities + +### New Capabilities + +- 无 + +### Modified Capabilities + +- `tilelang-dsl-surface`: 统一 `vdiv` / `vmod` 的 public API,隐藏 soft helper 实现细节。 +- `tilelang-dsl-diagnostics`: 明确 `vdiv` / `vmod` 的支持范围与内部 helper reject 行为。 +- `tilelang-dsl-vpto-lowering`: 为 `vdiv` 增加 dtype-directed lowering,并为 `vmod` 补齐 public-to-soft-lowering 链路。 + +## 预期结果 + +- DSL 用户只需要面向 `pto.vdiv` / `pto.vmod` 编程,不再接触 `math.py` 中的 soft helper 名字。 +- `f16/f32` `vdiv` 继续保留现有 VPTO 指令 authoring/lowering 路径。 +- 整数 `vdiv` / `vmod` 通过内部 soft helper 统一落到 backend-inline 主线,不把软实现细节暴露为 public contract。 +- `i8` 族 support matrix 和实现路径重新一致,不再只有表层放行而无清晰 lowering 契约。 + +## 成功标准 + +- 新增 `openspec/changes/hide-soft-divmod-behind-pto-surface/`,包含 `proposal.md`、`design.md`、`tasks.md`。 +- 新增 spec delta: + - `specs/tilelang-dsl-surface/spec.md` + - `specs/tilelang-dsl-diagnostics/spec.md` + - `specs/tilelang-dsl-vpto-lowering/spec.md` +- 变更文本明确写清: + - `pto.vdiv` 的双 lowering 路径 + - `pto.vmod` 的 public surface 与 soft lowering + - soft helper 的 internal-only 定位 + - `i8` 族 widen / narrow soft path 要求 + +## Impact + +- 受影响目录: + - `tilelang-dsl/python/tilelang_dsl/` + - `tilelang-dsl/tests/` + - `tilelang-dsl/docs/user_guide/` + - `lib/TileOps/` + - `openspec/changes/hide-soft-divmod-behind-pto-surface/` +- 受影响 public API: + - `pto.vdiv` + - `pto.vmod` +- 受影响 lowering 行为: + - `pto.vdiv` 的 dtype-directed lowering + - `pto.vmod` 的 internal soft-helper lowering diff --git a/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-diagnostics/spec.md b/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-diagnostics/spec.md new file mode 100644 index 000000000..f8142e3db --- /dev/null +++ b/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-diagnostics/spec.md @@ -0,0 +1,29 @@ +## ADDED Requirements + +### Requirement: diagnostics MUST enforce the public `vdiv` / `vmod` support matrix and reject internal helper names + +TileLang DSL diagnostics MUST 对 `pto.vdiv` / `pto.vmod` 的 public support matrix 做 fail-fast 校验。 +其中: + +- `pto.vdiv` MUST 接受 `i8/i16/i32/f16/f32` +- `pto.vmod` MUST 接受当前公开承诺的整数 family +- `bf16` 和其他未纳入本 change support matrix 的 dtype MUST 报错 +- internal soft helper 名字 direct call MUST 继续按 unsupported public call surface 报错 + +#### Scenario: unsupported `pto.vdiv` dtype is rejected explicitly + +- **WHEN** 用户以不在 public support matrix 中的 dtype 调用 `pto.vdiv` +- **THEN** frontend MUST 在生成 IR 之前报错 +- **AND** 诊断 MUST 明确指出 `pto.vdiv` 当前支持的 dtype family + +#### Scenario: unsupported `pto.vmod` dtype is rejected explicitly + +- **WHEN** 用户以不在当前公开承诺范围内的 dtype 调用 `pto.vmod` +- **THEN** frontend MUST 在生成 IR 之前报错 +- **AND** 诊断 MUST 明确指出 `pto.vmod` 当前支持范围 + +#### Scenario: internal soft helper name is not accepted as public DSL surface + +- **WHEN** 用户在 TileLang DSL kernel 中直接调用 internal soft helper 名字,而不是 `pto.vdiv` / `pto.vmod` +- **THEN** frontend MUST 把该调用视为 unsupported public call surface +- **AND** 诊断 MUST NOT 暗示用户应该直接依赖该 helper 名字 diff --git a/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-surface/spec.md b/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-surface/spec.md new file mode 100644 index 000000000..678433d8a --- /dev/null +++ b/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-surface/spec.md @@ -0,0 +1,31 @@ +## ADDED Requirements + +### Requirement: TileLang DSL MUST expose `pto.vdiv` / `pto.vmod` as the only public vector div/mod surface + +TileLang DSL public surface 中,vector division / modulo MUST 统一通过 `pto.vdiv(...)` 与 `pto.vmod(...)` 暴露。 +用户 MUST NOT 被要求显式调用 `lib/TileOps/math.py` 中的 helper,或依赖任何 internal soft helper 名字来获得整数 `div/mod` 能力。 +`pto.vdiv` MUST 支持 `i8/i16/i32/f16/f32`。 +`pto.vmod` 在本 change 中 MUST 至少支持整数 8/16/32-bit family。 + +#### Scenario: user writes integer vector division through `pto.vdiv` + +- **WHEN** 用户在 TileLang DSL kernel 中编写 `pto.vdiv(lhs, rhs, mask)`,且向量元素类型为 `i8/i16/i32` +- **THEN** frontend MUST 接受该 public surface +- **AND** 用户 MUST NOT 需要额外调用任何 soft helper 名字 + +#### Scenario: user writes integer vector modulo through `pto.vmod` + +- **WHEN** 用户在 TileLang DSL kernel 中编写 `pto.vmod(lhs, rhs, mask)`,且向量元素类型为整数 8/16/32-bit family +- **THEN** frontend MUST 接受该 public surface +- **AND** 该能力 MUST 通过正式 DSL API 提供,而不是停留在内部 helper 层 + +### Requirement: internal soft helper names MUST NOT become TileLang DSL public API + +即便实现层继续保留 `inline_proc` 或等价 helper 作为整数 `div/mod` 的软实现承载,这些 helper 名字也 MUST NOT 成为 TileLang DSL public API。 +用户文档、support matrix 和 surface 说明 MUST 只暴露 `pto.vdiv` / `pto.vmod`,不得把 internal helper 当成推荐或稳定入口。 + +#### Scenario: public documentation does not advertise soft helper names + +- **WHEN** 用户查看 TileLang DSL 的 vector arithmetic public surface +- **THEN** 文档 MUST 只描述 `pto.vdiv` / `pto.vmod` +- **AND** MUST NOT 把内部 soft helper 名字写成用户应直接调用的 API diff --git a/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-vpto-lowering/spec.md b/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-vpto-lowering/spec.md new file mode 100644 index 000000000..ca836e88d --- /dev/null +++ b/openspec/changes/hide-soft-divmod-behind-pto-surface/specs/tilelang-dsl-vpto-lowering/spec.md @@ -0,0 +1,47 @@ +## ADDED Requirements + +### Requirement: `pto.vdiv` MUST use dtype-directed lowering + +`pto.vdiv` 在 TileLang DSL lowering 中 MUST 按元素类型分流,而不是对所有 dtype 使用同一条后端路径。 +其中: + +- `f16/f32` `pto.vdiv` MUST 保留 authoring-form VPTO `pto.vdiv` 路径 +- `i8/i16/i32` `pto.vdiv` MUST 改写为 internal soft helper path + +整数 `pto.vdiv` 的 soft helper path MAY 以 helper `func.func` + `func.call` 的形式存在于 `mlir_text()` 阶段,但该 helper call MUST 通过现有 backend-inline 主线在后续阶段被消除。 + +#### Scenario: floating `pto.vdiv` stays on the VPTO path + +- **WHEN** 用户以 `f16` 或 `f32` 向量调用 `pto.vdiv` +- **THEN** lowering MUST 保留合法的 authoring-form `pto.vdiv` +- **AND** 后续 VPTO backend/emitter MUST 继续沿用现有 `vdiv` 指令路径 + +#### Scenario: integer `pto.vdiv` is rewritten to an internal soft helper + +- **WHEN** 用户以 `i8/i16/i32` 向量调用 `pto.vdiv` +- **THEN** semantic/lowering MUST NOT 继续把该调用保留为 authoring-form `pto.vdiv` +- **AND** 该调用 MUST 被改写到 internal soft helper path +- **AND** 最终用户 contract MUST 仍表现为“使用 `pto.vdiv` 获得整数 vector division” + +### Requirement: `pto.vmod` MUST lower through the internal soft-helper path + +在本 change 中,`pto.vmod` MUST 通过 internal soft helper path 实现,而不是要求新的 VPTO public op。 +helper call MAY 出现在 frontend materialized module 中,但它 MUST 通过既有 backend-inline 主线在后续阶段被消除。 + +#### Scenario: integer `pto.vmod` lowers through helper plus backend-inline + +- **WHEN** 用户以当前支持的整数 family 调用 `pto.vmod` +- **THEN** frontend MUST 生成合法的 helper-based lowering 形态 +- **AND** 后续 backend 主线 MUST 消除对应 helper call + +### Requirement: integer 8-bit div/mod MUST be implemented through an explicit widen / narrow profile + +对于 `i8/u8` family,本 change 的 `vdiv` / `vmod` 支持 MUST 通过明确的 widen / narrow soft path 完成。 +也即,实现 MUST 先把 8-bit lane 扩展到更宽整数 profile,完成 soft `div/mod`,再收敛回 8-bit 结果。 +该 widen / narrow 过程属于实现细节,但其存在本身 MUST 被视为正式 lowering contract 的一部分。 + +#### Scenario: `i8` vector div/mod does not depend on a fictitious direct 8-bit hardware op + +- **WHEN** 用户对 `i8` 向量调用 `pto.vdiv` 或 `pto.vmod` +- **THEN** lowering MUST 走正式定义的 widen / soft-compute / narrow 路线 +- **AND** MUST NOT 假定存在可直接承载该语义的 8-bit hardware `vdiv` / `vmod` op diff --git a/openspec/changes/hide-soft-divmod-behind-pto-surface/tasks.md b/openspec/changes/hide-soft-divmod-behind-pto-surface/tasks.md new file mode 100644 index 000000000..86a6c5661 --- /dev/null +++ b/openspec/changes/hide-soft-divmod-behind-pto-surface/tasks.md @@ -0,0 +1,38 @@ +## 1. OpenSpec 契约落定 + +- [x] 1.1 完成 `specs/tilelang-dsl-surface/spec.md`,固定 `pto.vdiv` / `pto.vmod` public surface 与 internal helper 不外露契约。 +- [x] 1.2 完成 `specs/tilelang-dsl-diagnostics/spec.md`,固定 `vdiv` / `vmod` dtype reject 与 internal helper name reject 契约。 +- [x] 1.3 完成 `specs/tilelang-dsl-vpto-lowering/spec.md`,固定 `vdiv` 双路径 lowering、`vmod` soft lowering 与 `i8` widen / narrow 约束。 + +## 2. Frontend / semantic surface + +- [x] 2.1 在 `tilelang-dsl/python/tilelang_dsl/semantic.py` 中为 `pto.vdiv` 增加 dtype-directed rewrite:`f16/f32` 保持 `pto.vdiv`,整数族改写为 internal helper call。 +- [x] 2.2 为 `pto.vmod` 补齐 public surface、semantic 分析与 dtype 校验。 +- [x] 2.3 增加 internal soft helper 注入机制,使 kernel 无需显式 import helper 即可完成 rewrite。 +- [x] 2.4 保持 internal helper 名字不属于 public DSL call surface。 + +## 3. Soft helper implementation + +- [x] 3.1 整理 `lib/TileOps/math.py` 中现有 `vdiv` / `vmod` soft helper,使其 internal-only。 +- [x] 3.2 补齐 `i8/u8` 的 soft `vdiv` 路径,采用 widen -> div -> narrow profile。 +- [x] 3.3 补齐 `i8/u8` 的 soft `vmod` 路径,采用 widen -> mod -> narrow profile。 +- [x] 3.4 明确整数 `vdiv` / `vmod` 的除零返回约定与符号约定,并在测试中锁定。 + +## 4. Lowering / backend path + +- [x] 4.1 确保 `f16/f32` `pto.vdiv` 继续走现有 authoring-form VPTO / backend emitter 路径。 +- [x] 4.2 确保整数 `pto.vdiv` / `pto.vmod` 通过 internal helper + backend-inline 收敛,不把 helper 名字暴露为最终 public contract。 +- [x] 4.3 为 `pto.vmod` 路径补齐与现有 inline-proc backend-inline 主线的接线验证。 + +## 5. 回归测试与文档 + +- [x] 5.1 更新 `tilelang-dsl/tests/test_tilelang_dsl_v1.py`,区分浮点 `vdiv` VPTO path 和整数 `vdiv` helper path。 +- [x] 5.2 新增 `pto.vmod` 的 public surface 正向/负向测试。 +- [x] 5.3 新增 `i8` 族 `vdiv` / `vmod` regression,锁定 widen / narrow 行为。 +- [x] 5.4 更新 `tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md`,明确 `vdiv` / `vmod` public contract 与 dtype 支持范围。 + +## 6. 验证 + +- [x] 6.1 执行针对 `vdiv` / `vmod` 的 TileLang DSL 单测。 +- [x] 6.2 执行覆盖 helper + backend-inline 收敛路径的定向回归。 +- [x] 6.3 执行 `openspec validate hide-soft-divmod-behind-pto-surface --type change --strict --json --no-interactive`。 diff --git a/test/basic/tilelang_soft_vmod_backend_inline.pto b/test/basic/tilelang_soft_vmod_backend_inline.pto new file mode 100644 index 000000000..4c7b1e7fc --- /dev/null +++ b/test/basic/tilelang_soft_vmod_backend_inline.pto @@ -0,0 +1,126 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s --emit-vpto -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%arg0: !pto.tile_buf, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance } { + %c0 = arith.constant 0 : index + %tmp_0 = pto.tile_buf_addr %arg0 : !pto.tile_buf -> memref<8x16xi16, #pto.address_space> + %tmp_1 = pto.tile_buf_addr %arg1 : !pto.tile_buf -> memref<8x16xi16, #pto.address_space> + pto.vecscope { + %mask_0 = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec_1 = pto.vlds %tmp_1[%c0] : memref<8x16xi16, #pto.address_space> -> !pto.vreg<128xi16> + %result_81 = func.call @__tl_inline__tl_soft_vmod_2(%vec_1, %vec_1, %mask_0) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + pto.vsts %result_81, %tmp_0[%c0], %mask_0 : !pto.vreg<128xi16>, memref<8x16xi16, #pto.address_space>, !pto.mask + } + return + } + + func.func private @__tl_inline__tl_soft_vdiv_u16_0(%arg0: !pto.vreg<128xui16>, %arg1: !pto.vreg<128xui16>, %arg2: !pto.mask) -> !pto.vreg<128xui16> attributes { pto.tilelang.inline_proc } { + %c0_i32 = arith.constant 0 : i32 + %c0_ui32 = builtin.unrealized_conversion_cast %c0_i32 : i32 to ui32 + %c65536_0_f32 = arith.constant 65536.0 : f32 + %c65535_i16 = arith.constant 65535 : i16 + %c65535_ui16 = builtin.unrealized_conversion_cast %c65535_i16 : i16 to ui16 + %tmp_0 = arith.constant 0 : i16 + %zero_10 = builtin.unrealized_conversion_cast %tmp_0 : i16 to ui16 + %tmp_1 = arith.constant 1 : i16 + %one_11 = builtin.unrealized_conversion_cast %tmp_1 : i16 to ui16 + %fp32_one_12 = arith.constant 1.0 : f32 + %full_mask_b16_13 = pto.pset_b16 "PAT_ALL" : !pto.mask + %full_mask_b32_14 = pto.pset_b32 "PAT_ALL" : !pto.mask + %zero_mask_15 = pto.vcmps %arg1, %zero_10, %arg2, "eq" : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.mask + %active_mask_16 = pto.pnot %zero_mask_15, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %zero_u16_17 = pto.vbr %zero_10 : ui16 -> !pto.vreg<128xui16> + %vy_lower_u16_18, %vy_higher_u16_19 = pto.vintlv %arg1, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %vy_lower_u32_20 = pto.vcvt %vy_lower_u16_18 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %vy_higher_u32_21 = pto.vcvt %vy_higher_u16_19 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %active_low_22 = pto.vcmps %vy_lower_u32_20, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask + %active_high_23 = pto.vcmps %vy_higher_u32_21, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask + %tmp_2 = pto.vbitcast %vy_lower_u32_20 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> + %vy_lower_f32_24 = pto.vcvt %tmp_2 {rnd = "F"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %tmp_3 = pto.vbitcast %vy_higher_u32_21 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> + %vy_higher_f32_25 = pto.vcvt %tmp_3 {rnd = "F"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %tmp_4 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> + %vy_rec_lower_26 = pto.vdiv %tmp_4, %vy_lower_f32_24, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_5 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> + %vy_rec_higher_27 = pto.vdiv %tmp_5, %vy_higher_f32_25, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_6 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> + %vy_scale_lower_28 = pto.vmul %vy_rec_lower_26, %tmp_6, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_7 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> + %vy_scale_higher_29 = pto.vmul %vy_rec_higher_27, %tmp_7, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %v_lower_i32_30 = pto.vcvt %vy_scale_lower_28 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + %v_higher_i32_31 = pto.vcvt %vy_scale_higher_29 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + %v_lower_u32_32 = pto.vbitcast %v_lower_i32_30 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> + %v_higher_u32_33 = pto.vbitcast %v_higher_i32_31 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> + %vx_lower_u16_34, %vx_higher_u16_35 = pto.vintlv %arg0, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %vx_lower_u32_36 = pto.vcvt %vx_lower_u16_34 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %vx_higher_u32_37 = pto.vcvt %vx_higher_u16_35 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %q_tmp_lower_38 = pto.vmul %v_lower_u32_32, %vx_lower_u32_36, %active_low_22 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + %q_tmp_higher_39 = pto.vmul %v_higher_u32_33, %vx_higher_u32_37, %active_high_23 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + %tmp_8 = pto.vbitcast %q_tmp_lower_38 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> + %tmp_9 = pto.vbitcast %q_tmp_higher_39 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> + %_q_lower_40, %q_tmp_41 = pto.vdintlv %tmp_8, %tmp_9 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %yq_tmp_42 = pto.vmul %q_tmp_41, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_43 = pto.vsub %arg0, %yq_tmp_42, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %ge_mask_44 = pto.vcmp %r_tmp_43, %arg1, %active_mask_16, "ge" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + %refined_r_45 = pto.vsub %r_tmp_43, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_46 = pto.vsel %refined_r_45, %r_tmp_43, %ge_mask_44 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %q_inc_47 = pto.vadds %q_tmp_41, %one_11, %active_mask_16 : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.vreg<128xui16> + %q_tmp_48 = pto.vsel %q_inc_47, %q_tmp_41, %ge_mask_44 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %ge_mask_49 = pto.vcmp %r_tmp_46, %arg1, %active_mask_16, "ge" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + %refined_r_50 = pto.vsub %r_tmp_46, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_51 = pto.vsel %refined_r_50, %r_tmp_46, %ge_mask_49 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %q_inc_52 = pto.vadds %q_tmp_48, %one_11, %active_mask_16 : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.vreg<128xui16> + %q_tmp_53 = pto.vsel %q_inc_52, %q_tmp_48, %ge_mask_49 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %zero_q_54 = pto.vbr %c65535_ui16 : ui16 -> !pto.vreg<128xui16> + %tmp_10 = pto.vsel %zero_q_54, %q_tmp_53, %zero_mask_15 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + return %tmp_10 : !pto.vreg<128xui16> + } + + func.func private @__tl_inline__tl_soft_vmod_i16_1(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %zero_60 = arith.constant 0 : i16 + %neg_one_61 = arith.constant -1 : i16 + %zero_mask_62 = pto.vcmps %arg1, %zero_60, %arg2, "eq" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %active_mask_63 = pto.pnot %zero_mask_62, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %tmp_11 = pto.vabs %arg0, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_x_64 = pto.vbitcast %tmp_11 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %tmp_12 = pto.vabs %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_y_65 = pto.vbitcast %tmp_12 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %x_xor_y_66 = pto.vxor %arg0, %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %p_pos_67 = pto.vcmps %x_xor_y_66, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %q_abs_68 = func.call @__tl_inline__tl_soft_vdiv_u16_0(%abs_x_64, %abs_y_65, %active_mask_63) : (!pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask) -> !pto.vreg<128xui16> + %tmp_13 = pto.vbitcast %q_abs_68 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %neg_q_69 = pto.vneg %tmp_13, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_14 = pto.vbitcast %q_abs_68 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %q_70 = pto.vsel %tmp_14, %neg_q_69, %p_pos_67 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %qy_71 = pto.vmul %q_70, %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %remainder_72 = pto.vsub %arg0, %qy_71, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %nonzero_remainder_73 = pto.vcmps %remainder_72, %zero_60, %active_mask_63, "ne" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_x_74 = pto.vcmps %arg0, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_y_75 = pto.vcmps %arg1, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_diff_76 = pto.pxor %sign_x_74, %sign_y_75, %active_mask_63 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %need_floor_fix_77 = pto.pand %sign_diff_76, %nonzero_remainder_73, %active_mask_63 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %amended_remainder_78 = pto.vadd %arg1, %remainder_72, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %remainder_79 = pto.vsel %amended_remainder_78, %remainder_72, %need_floor_fix_77 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_15 = pto.vbr %neg_one_61 : i16 -> !pto.vreg<128xi16> + %tmp_16 = pto.vsel %tmp_15, %remainder_79, %zero_mask_62 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + return %tmp_16 : !pto.vreg<128xi16> + } + + func.func private @__tl_inline__tl_soft_vmod_2(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %result_80 = func.call @__tl_inline__tl_soft_vmod_i16_1(%arg0, %arg1, %arg2) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + return %result_80 : !pto.vreg<128xi16> + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK: pto.vecscope { +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.pxor +// CHECK: pto.pand +// CHECK: pto.vsel +// CHECK: pto.vsts +// CHECK-NOT: func.call @__tl_inline__tl_soft_ +// CHECK-NOT: func.func private @__tl_inline__tl_soft_ diff --git a/test/vpto_tilelang_inline_soft_divmod_fastpath.pto b/test/vpto_tilelang_inline_soft_divmod_fastpath.pto new file mode 100644 index 000000000..98a2c297b --- /dev/null +++ b/test/vpto_tilelang_inline_soft_divmod_fastpath.pto @@ -0,0 +1,158 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s --emit-vpto -o - | FileCheck %s + +// CHECK-LABEL: func.func @kernel( +// CHECK: pto.vecscope { +// CHECK: pto.vlds +// CHECK: pto.vcmps +// CHECK: pto.vxor +// CHECK: pto.vdiv +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.vsts +// CHECK-NOT: func.call @__tl_inline__tl_soft +// CHECK-NOT: func.func private @__tl_inline__tl_soft +// CHECK-NOT: __tl_inline__tl_soft_vdiv_ +// CHECK-NOT: __tl_inline__tl_soft_vmod_ + +// tilelang.target = a5 +// tilelang.op = dump_i16_divmod_lit_tmp +// tilelang.dtypes = (i16, i16) +// tilelang.verify = True +// tilelang.advanced = False +// tilelang.specialize dst shape=(8, 16) memory_space=ub config=None +// tilelang.specialize src shape=(8, 16) memory_space=ub config=None +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%arg0: !pto.tile_buf, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance } { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %tmp_0 = pto.tile_buf_addr %arg0 : !pto.tile_buf -> memref<8x16xi16, #pto.address_space> + %tmp_1 = pto.tile_buf_addr %arg1 : !pto.tile_buf -> memref<8x16xi16, #pto.address_space> + pto.vecscope { + %mask_0 = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec_1 = pto.vlds %tmp_1[%c0] : memref<8x16xi16, #pto.address_space> -> !pto.vreg<128xi16> + %q_59 = func.call @__tl_inline__tl_soft_vdiv_2(%vec_1, %vec_1, %mask_0) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + %r_81 = func.call @__tl_inline__tl_soft_vmod_4(%vec_1, %vec_1, %mask_0) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + pto.vsts %q_59, %tmp_0[%c0], %mask_0 : !pto.vreg<128xi16>, memref<8x16xi16, #pto.address_space>, !pto.mask + pto.vsts %r_81, %tmp_0[%c1], %mask_0 : !pto.vreg<128xi16>, memref<8x16xi16, #pto.address_space>, !pto.mask + } + return + } + func.func private @__tl_inline__tl_soft_vdiv_u16_0(%arg0: !pto.vreg<128xui16>, %arg1: !pto.vreg<128xui16>, %arg2: !pto.mask) -> !pto.vreg<128xui16> attributes { pto.tilelang.inline_proc } { + %c0_i32 = arith.constant 0 : i32 + %c0_ui32 = builtin.unrealized_conversion_cast %c0_i32 : i32 to ui32 + %c65536_0_f32 = arith.constant 65536.0 : f32 + %c65535_i16 = arith.constant 65535 : i16 + %c65535_ui16 = builtin.unrealized_conversion_cast %c65535_i16 : i16 to ui16 + %tmp_0 = arith.constant 0 : i16 + %zero_10 = builtin.unrealized_conversion_cast %tmp_0 : i16 to ui16 + %tmp_1 = arith.constant 1 : i16 + %one_11 = builtin.unrealized_conversion_cast %tmp_1 : i16 to ui16 + %fp32_one_12 = arith.constant 1.0 : f32 + %full_mask_b16_13 = pto.pset_b16 "PAT_ALL" : !pto.mask + %full_mask_b32_14 = pto.pset_b32 "PAT_ALL" : !pto.mask + %zero_mask_15 = pto.vcmps %arg1, %zero_10, %arg2, "eq" : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.mask + %active_mask_16 = pto.pnot %zero_mask_15, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %zero_u16_17 = pto.vbr %zero_10 : ui16 -> !pto.vreg<128xui16> + %vy_lower_u16_18, %vy_higher_u16_19 = pto.vintlv %arg1, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %vy_lower_u32_20 = pto.vcvt %vy_lower_u16_18 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %vy_higher_u32_21 = pto.vcvt %vy_higher_u16_19 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %active_low_22 = pto.vcmps %vy_lower_u32_20, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask + %active_high_23 = pto.vcmps %vy_higher_u32_21, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask + %tmp_2 = pto.vbitcast %vy_lower_u32_20 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> + %vy_lower_f32_24 = pto.vcvt %tmp_2 {rnd = "F"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %tmp_3 = pto.vbitcast %vy_higher_u32_21 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> + %vy_higher_f32_25 = pto.vcvt %tmp_3 {rnd = "F"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %tmp_4 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> + %vy_rec_lower_26 = pto.vdiv %tmp_4, %vy_lower_f32_24, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_5 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> + %vy_rec_higher_27 = pto.vdiv %tmp_5, %vy_higher_f32_25, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_6 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> + %vy_scale_lower_28 = pto.vmul %vy_rec_lower_26, %tmp_6, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_7 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> + %vy_scale_higher_29 = pto.vmul %vy_rec_higher_27, %tmp_7, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %v_lower_i32_30 = pto.vcvt %vy_scale_lower_28 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + %v_higher_i32_31 = pto.vcvt %vy_scale_higher_29 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + %v_lower_u32_32 = pto.vbitcast %v_lower_i32_30 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> + %v_higher_u32_33 = pto.vbitcast %v_higher_i32_31 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> + %vx_lower_u16_34, %vx_higher_u16_35 = pto.vintlv %arg0, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %vx_lower_u32_36 = pto.vcvt %vx_lower_u16_34 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %vx_higher_u32_37 = pto.vcvt %vx_higher_u16_35 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %q_tmp_lower_38 = pto.vmul %v_lower_u32_32, %vx_lower_u32_36, %active_low_22 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + %q_tmp_higher_39 = pto.vmul %v_higher_u32_33, %vx_higher_u32_37, %active_high_23 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + %tmp_8 = pto.vbitcast %q_tmp_lower_38 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> + %tmp_9 = pto.vbitcast %q_tmp_higher_39 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> + %_q_lower_40, %q_tmp_41 = pto.vdintlv %tmp_8, %tmp_9 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %yq_tmp_42 = pto.vmul %q_tmp_41, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_43 = pto.vsub %arg0, %yq_tmp_42, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %ge_mask_44 = pto.vcmp %r_tmp_43, %arg1, %active_mask_16, "ge" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + %refined_r_45 = pto.vsub %r_tmp_43, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_46 = pto.vsel %refined_r_45, %r_tmp_43, %ge_mask_44 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %q_inc_47 = pto.vadds %q_tmp_41, %one_11, %active_mask_16 : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.vreg<128xui16> + %q_tmp_48 = pto.vsel %q_inc_47, %q_tmp_41, %ge_mask_44 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %ge_mask_49 = pto.vcmp %r_tmp_46, %arg1, %active_mask_16, "ge" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + %refined_r_50 = pto.vsub %r_tmp_46, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_51 = pto.vsel %refined_r_50, %r_tmp_46, %ge_mask_49 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %q_inc_52 = pto.vadds %q_tmp_48, %one_11, %active_mask_16 : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.vreg<128xui16> + %q_tmp_53 = pto.vsel %q_inc_52, %q_tmp_48, %ge_mask_49 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %zero_q_54 = pto.vbr %c65535_ui16 : ui16 -> !pto.vreg<128xui16> + %tmp_10 = pto.vsel %zero_q_54, %q_tmp_53, %zero_mask_15 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + return %tmp_10 : !pto.vreg<128xui16> + } + func.func private @__tl_inline__tl_soft_vdiv_i16_1(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %zero_2 = arith.constant 0 : i16 + %neg_one_3 = arith.constant -1 : i16 + %zero_mask_4 = pto.vcmps %arg1, %zero_2, %arg2, "eq" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %active_mask_5 = pto.pnot %zero_mask_4, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %tmp_0 = pto.vabs %arg0, %active_mask_5 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_x_6 = pto.vbitcast %tmp_0 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %tmp_1 = pto.vabs %arg1, %active_mask_5 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_y_7 = pto.vbitcast %tmp_1 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %x_xor_y_8 = pto.vxor %arg0, %arg1, %active_mask_5 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %p_pos_9 = pto.vcmps %x_xor_y_8, %zero_2, %active_mask_5, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %q_abs_55 = func.call @__tl_inline__tl_soft_vdiv_u16_0(%abs_x_6, %abs_y_7, %active_mask_5) : (!pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask) -> !pto.vreg<128xui16> + %tmp_2 = pto.vbitcast %q_abs_55 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %neg_q_56 = pto.vneg %tmp_2, %active_mask_5 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_3 = pto.vbitcast %q_abs_55 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %q_57 = pto.vsel %tmp_3, %neg_q_56, %p_pos_9 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_5 = pto.vbr %neg_one_3 : i16 -> !pto.vreg<128xi16> + %tmp_4 = pto.vsel %tmp_5, %q_57, %zero_mask_4 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + return %tmp_4 : !pto.vreg<128xi16> + } + func.func private @__tl_inline__tl_soft_vdiv_2(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %result_58 = func.call @__tl_inline__tl_soft_vdiv_i16_1(%arg0, %arg1, %arg2) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + return %result_58 : !pto.vreg<128xi16> + } + func.func private @__tl_inline__tl_soft_vmod_i16_3(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %zero_60 = arith.constant 0 : i16 + %neg_one_61 = arith.constant -1 : i16 + %zero_mask_62 = pto.vcmps %arg1, %zero_60, %arg2, "eq" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %active_mask_63 = pto.pnot %zero_mask_62, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %tmp_0 = pto.vabs %arg0, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_x_64 = pto.vbitcast %tmp_0 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %tmp_1 = pto.vabs %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_y_65 = pto.vbitcast %tmp_1 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %x_xor_y_66 = pto.vxor %arg0, %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %p_pos_67 = pto.vcmps %x_xor_y_66, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %q_abs_68 = func.call @__tl_inline__tl_soft_vdiv_u16_0(%abs_x_64, %abs_y_65, %active_mask_63) : (!pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask) -> !pto.vreg<128xui16> + %tmp_2 = pto.vbitcast %q_abs_68 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %neg_q_69 = pto.vneg %tmp_2, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_3 = pto.vbitcast %q_abs_68 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %q_70 = pto.vsel %tmp_3, %neg_q_69, %p_pos_67 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %qy_71 = pto.vmul %q_70, %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %remainder_72 = pto.vsub %arg0, %qy_71, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %nonzero_remainder_73 = pto.vcmps %remainder_72, %zero_60, %active_mask_63, "ne" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_x_74 = pto.vcmps %arg0, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_y_75 = pto.vcmps %arg1, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_diff_76 = pto.pxor %sign_x_74, %sign_y_75, %active_mask_63 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %need_floor_fix_77 = pto.pand %sign_diff_76, %nonzero_remainder_73, %active_mask_63 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %amended_remainder_78 = pto.vadd %arg1, %remainder_72, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %remainder_79 = pto.vsel %amended_remainder_78, %remainder_72, %need_floor_fix_77 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_5 = pto.vbr %neg_one_61 : i16 -> !pto.vreg<128xi16> + %tmp_4 = pto.vsel %tmp_5, %remainder_79, %zero_mask_62 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + return %tmp_4 : !pto.vreg<128xi16> + } + func.func private @__tl_inline__tl_soft_vmod_4(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %result_80 = func.call @__tl_inline__tl_soft_vmod_i16_3(%arg0, %arg1, %arg2) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + return %result_80 : !pto.vreg<128xi16> + } +} diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index 2791b7ce4..e492cbdce 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -428,6 +428,12 @@ sum_vec = pto.vadd(vec_a, vec_b, mask32) **Description**: Element-wise division of two vectors. +- Supported element types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16` and `f32`. +- `f16`/`f32` authoring code stays on the public `pto.vdiv` VPTO path. +- Integer `pto.vdiv` also uses the same public surface, but lowers through an internal soft-helper path. +- For `i8`/`ui8`, the integer lowering widens to 16-bit lanes, computes the soft division, then narrows back to 8-bit lanes. +- Internal helper names such as `_tl_soft_vdiv_*` are implementation details and are not part of the supported DSL call surface. + **Parameters**: | Parameter | Type | Description | |-----------|------|-------------| @@ -440,6 +446,28 @@ sum_vec = pto.vadd(vec_a, vec_b, mask32) |--------------|------|-------------| | `result` | `VRegType` | Quotient of vectors | +#### `pto.vmod(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise modulo of two vectors. + +- Supported element types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`). +- Floating-point `vmod` is not part of the current TileLang DSL v1 public surface. +- `pto.vmod` is the only public vector modulo entry point in TileLang DSL v1. +- The current implementation lowers through an internal soft-helper path; helper names such as `_tl_soft_vmod_*` are intentionally hidden implementation details. +- For `i8`/`ui8`, the modulo path uses an explicit widen-to-16-bit, soft-compute, narrow-back-to-8-bit profile. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | Dividend vector | +| `vec2` | `VRegType` | Divisor vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Remainder vector | + #### `pto.vmax(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` **Description**: Element-wise maximum of two vectors. diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py index 49580a289..036433136 100644 --- a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -206,6 +206,7 @@ class FrontendKernelNode: body: tuple[FrontendStmtNode, ...] context_attrs: tuple[tuple[str, Any], ...] = () inline_procs: tuple[FrontendInlineProcNode, ...] = () + internal_inline_procs: tuple[FrontendInlineProcNode, ...] = () @dataclass(frozen=True) @@ -1415,6 +1416,9 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: local_bindings=local_bindings, ) sorted_inline_procs = tuple(sorted(descriptor.inline_procs.items(), key=lambda item: item[0])) + sorted_internal_inline_procs = tuple( + sorted(descriptor.internal_inline_procs.items(), key=lambda item: item[0]) + ) context = _FrontendBuildContext( source_info=source_info, module_globals=getattr(descriptor._py_fn, "__globals__", None), @@ -1518,6 +1522,82 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: inline_proc_source_infos, ) + internal_inline_proc_nodes: tuple[FrontendInlineProcNode, ...] = () + if sorted_internal_inline_procs: + merged_inline_proc_descriptors = { + name: _FrontendInlineProc( + name=name, + source_info=proc.source_info, + signature=proc.signature, + ) + for name, proc in (*sorted_inline_procs, *sorted_internal_inline_procs) + } + internal_context = _FrontendBuildContext( + source_info=source_info, + module_globals=getattr(descriptor._py_fn, "__globals__", None), + templates=descriptor.templates, + selected_op=descriptor.selected_op, + advanced_enabled=descriptor.advanced_enabled, + inline_procs=merged_inline_proc_descriptors, + global_literal_constants=global_literal_constants, + local_bindings=local_bindings, + ) + internal_nodes: list[FrontendInlineProcNode] = [] + internal_source_infos: dict[str, Any] = {} + for name, inline_proc_descriptor in sorted_internal_inline_procs: + inline_source = inline_proc_descriptor.source_info + if inline_source is None: + if source_info is not None: + raise context.error( + source_info.function_def, + f"inline_proc `{name}` requires source-visible Python functions", + ) + raise ValueError( + f"inline_proc `{name}` requires source-visible Python functions" + ) + internal_source_infos[name] = inline_source + helper_context = internal_context.enter_inline_proc(name, inline_source) + helper_body = _build_stmt_list(inline_source.function_def.body, helper_context) + parameter_specs = _inline_proc_param_specs( + _FrontendInlineProc( + name=name, + source_info=inline_source, + signature=inline_proc_descriptor.signature, + ) + ) + inline_proc_node = FrontendInlineProcNode( + name=name, + parameters=tuple( + FrontendInlineProcParameterNode( + name=param_name, + annotation=arg.annotation, + default=None + if default_node is None + else _build_expr(default_node, helper_context), + ) + for (param_name, default_node), arg in zip( + parameter_specs, + inline_source.function_def.args.args, + ) + ), + body=helper_body, + ) + internal_nodes.append(inline_proc_node) + + internal_inline_proc_nodes = tuple(internal_nodes) + for inline_proc_node in internal_inline_proc_nodes: + source = internal_source_infos[inline_proc_node.name] + helper_context = internal_context.enter_inline_proc(inline_proc_node.name, source) + assigned_names: set[str] = set() + param_names = {parameter.name for parameter in inline_proc_node.parameters} + for stmt in inline_proc_node.body: + _validate_inline_capture( + stmt, + param_names, + assigned_names, + context=helper_context, + ) + return FrontendKernelNode( target=descriptor.target, op=descriptor.op, @@ -1532,6 +1612,7 @@ def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: sorted(descriptor.constraint_context_attrs.items(), key=lambda item: item[0]) ), inline_procs=reachable_inline_proc_nodes, + internal_inline_procs=internal_inline_proc_nodes, ) diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index 743ae6136..e061f4567 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -13,6 +13,7 @@ import os import inspect import ast +import importlib.util import sys import subprocess import tempfile @@ -54,6 +55,7 @@ _UNSET = object() _PTOAS_BIN_ENV = "PTOAS_BIN" +_INTERNAL_SOFT_MATH_MODULE_NAME = "tilelang_dsl._internal_soft_math" _SUPPORTED_TEMPLATE_PTO_CALLS = frozenset( SUPPORTED_TOPLEVEL_PTO_CALLS | SUPPORTED_VECSCOPE_PTO_CALLS @@ -85,6 +87,7 @@ _INLINE_PROC_REGISTRY: dict[tuple[str, str], "InlineProcDescriptor"] = {} +_INTERNAL_INLINE_PROC_CACHE: tuple[tuple[str, "InlineProcDescriptor"], ...] | None = None @dataclass(frozen=True) @@ -205,6 +208,45 @@ def _collect_inline_procs(module_name: str) -> tuple[tuple[str, InlineProcDescri return tuple(sorted(collected.items(), key=lambda item: item[0])) +def _load_module_from_path(module_name: str, path: Path) -> Any: + module = sys.modules.get(module_name) + if module is not None: + return module + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise ImportError(f"unable to load module {module_name!r} from {path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _collect_internal_inline_procs() -> tuple[tuple[str, InlineProcDescriptor], ...]: + global _INTERNAL_INLINE_PROC_CACHE + if _INTERNAL_INLINE_PROC_CACHE is not None: + return _INTERNAL_INLINE_PROC_CACHE + + repo_root = Path(__file__).resolve().parents[3] + soft_math_path = repo_root / "lib" / "TileOps" / "math.py" + if not soft_math_path.exists(): + _INTERNAL_INLINE_PROC_CACHE = () + return _INTERNAL_INLINE_PROC_CACHE + + try: + module = _load_module_from_path(_INTERNAL_SOFT_MATH_MODULE_NAME, soft_math_path) + except Exception: + _INTERNAL_INLINE_PROC_CACHE = () + return _INTERNAL_INLINE_PROC_CACHE + + collected: dict[str, InlineProcDescriptor] = {} + for symbol, value in vars(module).items(): + if isinstance(value, InlineProcDescriptor): + collected.setdefault(symbol, value) + + _INTERNAL_INLINE_PROC_CACHE = tuple(sorted(collected.items(), key=lambda item: item[0])) + return _INTERNAL_INLINE_PROC_CACHE + + def _register_inline_proc(descriptor: InlineProcDescriptor) -> InlineProcDescriptor: _INLINE_PROC_REGISTRY[_inline_proc_registry_key(descriptor.py_fn)] = descriptor return descriptor @@ -847,6 +889,7 @@ class VKernelDescriptor: priority: int = 0 _templates: tuple[tuple[str, tuple[tuple[str, str], ...]], ...] = field(default=(), repr=False) _inline_procs: tuple[tuple[str, InlineProcDescriptor], ...] = field(default=(), repr=False) + _internal_inline_procs: tuple[tuple[str, InlineProcDescriptor], ...] = field(default=(), repr=False) _selected_op: str | None = None _selected_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None _parameters: tuple[BoundKernelParameter, ...] | None = field(default=None, repr=False) @@ -880,6 +923,10 @@ def templates(self) -> dict[str, dict[str, str]]: def inline_procs(self) -> dict[str, InlineProcDescriptor]: return {name: descriptor for name, descriptor in self._inline_procs} + @property + def internal_inline_procs(self) -> dict[str, InlineProcDescriptor]: + return {name: descriptor for name, descriptor in self._internal_inline_procs} + @property def dtype_signature(self) -> tuple[ScalarType | MaskType, ...]: if self._selected_dtype_signature is None: @@ -954,6 +1001,7 @@ def _bind_constraint_context_attrs( priority=self.priority, _templates=self._templates, _inline_procs=self._inline_procs, + _internal_inline_procs=self._internal_inline_procs, _selected_op=self._selected_op, _selected_dtype_signature=self._selected_dtype_signature, _parameters=self._parameters, @@ -980,6 +1028,7 @@ def _bind_selected_dtype_signature( priority=self.priority, _templates=self._templates, _inline_procs=self._inline_procs, + _internal_inline_procs=self._internal_inline_procs, _selected_op=self._selected_op, _selected_dtype_signature=dtype_signature, _parameters=bound_parameters, @@ -1009,6 +1058,7 @@ def _bind_selected_op(self, op: str) -> "VKernelDescriptor": priority=self.priority, _templates=self._templates, _inline_procs=self._inline_procs, + _internal_inline_procs=self._internal_inline_procs, _selected_op=normalized_op, _selected_dtype_signature=self._selected_dtype_signature, _parameters=self._parameters, @@ -1050,6 +1100,7 @@ def specialize(self, **bindings: Any) -> "VKernelDescriptor": priority=self.priority, _templates=self._templates, _inline_procs=self._inline_procs, + _internal_inline_procs=self._internal_inline_procs, _selected_op=self._selected_op, _selected_dtype_signature=self._selected_dtype_signature, _parameters=self._parameters, @@ -2121,6 +2172,7 @@ def _build_descriptor( source_info = _load_function_source_info(py_fn) advanced_enabled = _validate_advanced(advanced) inline_procs = _collect_inline_procs(py_fn.__module__) + internal_inline_procs = _collect_internal_inline_procs() _validate_function_body( source_info, advanced_enabled=advanced_enabled, @@ -2157,6 +2209,7 @@ def _build_descriptor( priority=_validate_priority(priority), _templates=frozen_templates, _inline_procs=inline_procs, + _internal_inline_procs=internal_inline_procs, _selected_op=selected_op, _selected_dtype_signature=selected_dtype_signature, _parameters=bound_parameters, diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index d781d141f..7882ad28b 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -246,6 +246,7 @@ def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: "vsub", "vmul", "vdiv", + "vmod", "vmax", "vmin", "vand", @@ -777,6 +778,9 @@ def __init__(self, node: FrontendKernelNode): self._inline_proc_nodes: dict[str, FrontendInlineProcNode] = { inline_proc.name: inline_proc for inline_proc in node.inline_procs } + self._internal_inline_proc_nodes: dict[str, FrontendInlineProcNode] = { + inline_proc.name: inline_proc for inline_proc in node.internal_inline_procs + } self._inline_proc_specializations: dict[ tuple[str, tuple[tuple[SemanticType, object], ...]], SemanticKernel ] = {} @@ -1427,9 +1431,12 @@ def _inline_proc_specialization_key( self, name: str, args: tuple[SemanticExpr, ...], + *, + internal: bool = False, ) -> tuple[str, tuple[tuple[SemanticType, object], ...]]: + specialization_name = f"__internal__::{name}" if internal else name return ( - name, + specialization_name, tuple( (arg.type, self._inline_proc_static_specialization_token(arg)) for arg in args @@ -1508,12 +1515,17 @@ def _materialize_inline_proc_specialization( self, name: str, args: tuple[SemanticExpr, ...], + *, + internal: bool = False, ) -> SemanticKernel: - inline_proc_node = self._inline_proc_nodes.get(name) + inline_proc_nodes = ( + self._internal_inline_proc_nodes if internal else self._inline_proc_nodes + ) + inline_proc_node = inline_proc_nodes.get(name) if inline_proc_node is None: raise TypeError(f"inline_proc `{name}` is not registered in the current TileLang module") - key = self._inline_proc_specialization_key(name, args) + key = self._inline_proc_specialization_key(name, args, internal=internal) existing = self._inline_proc_specializations.get(key) if existing is not None: return existing @@ -1598,6 +1610,27 @@ def _analyze_inline_proc_call_expr( type=self._inline_proc_return_types.get(key), ) + def _analyze_internal_inline_proc_call_expr( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + helper_kernel = self._materialize_inline_proc_specialization( + name, + args, + internal=True, + ) + key = self._inline_proc_specialization_key(name, args, internal=True) + return SemanticCallExpr( + namespace=None, + name=helper_kernel.symbol_name, + args=args, + type=self._inline_proc_return_types.get(key), + ) + + def _is_internal_inline_proc_context(self) -> bool: + return any(key[0].startswith("__internal__::") for key in self._inline_proc_active_stack) + def _contains_explicit_vecscope(self, statements: tuple[FrontendStmtNode, ...]) -> bool: for stmt in statements: if isinstance(stmt, FrontendVecscopeStmt): @@ -3976,6 +4009,10 @@ def _analyze_call_expr( if namespace is None and name == "range": return SemanticCallExpr(namespace=namespace, name=name, args=args, type=None) if namespace is None: + if name in self._inline_proc_nodes: + return self._analyze_inline_proc_call_expr(name, args) + if name in self._internal_inline_proc_nodes and self._is_internal_inline_proc_context(): + return self._analyze_internal_inline_proc_call_expr(name, args) raise TypeError( f"call surface `{name}` is not supported in TileLang DSL v1" ) @@ -4854,6 +4891,20 @@ def _analyze_binary_vector_op( raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") self._require_mask_for_vreg(mask, lhs, f"pto.{name}") self._validate_binary_dtype(name, lhs.element_dtype) + if ( + name in {"vdiv", "vmod"} + and is_integer_dtype(lhs.element_dtype) + and integer_bitwidth(lhs.element_dtype) in {8, 16, 32} + ): + return self._analyze_internal_inline_proc_call_expr( + "_tl_soft_vdiv" if name == "vdiv" else "_tl_soft_vmod", + ( + lhs_expr, + rhs_expr, + mask, + self._dtype_symbol_expr(lhs.element_dtype), + ), + ) return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) def _analyze_vector_scalar_op( @@ -5376,6 +5427,14 @@ def _require_dtype_symbol(self, expr: SemanticExpr, context: str) -> ScalarType: raise TypeError(f"{context} must be a TileLang scalar dtype symbol in TileLang DSL v1") return expr.value + def _dtype_symbol_expr(self, dtype: ScalarType) -> SemanticSymbolExpr: + return SemanticSymbolExpr( + namespace="pto", + name=dtype.name, + value=dtype, + type=SemanticMetaType(kind="dtype"), + ) + def _require_memory_space_symbol(self, expr: SemanticExpr, context: str) -> MemorySpace: if ( isinstance(expr, SemanticSymbolExpr) @@ -6180,9 +6239,18 @@ def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: if name == "vdiv" and not ( - dtype.name in {"f16", "f32"} or (is_integer_dtype(dtype) and integer_bitwidth(dtype) == 16) + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) + or dtype.name in {"f16", "f32"} ): - raise TypeError("pto.vdiv only supports f16/f32/i16/ui16 in TileLang DSL v1") + raise TypeError( + "pto.vdiv only supports 8/16/32-bit integer families and f16/f32 in TileLang DSL v1" + ) + if name == "vmod" and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): + raise TypeError( + "pto.vmod only supports 8/16/32-bit integer families in TileLang DSL v1" + ) if name == "vprelu" and dtype.name not in {"f16", "f32"}: raise TypeError("pto.vprelu only supports f16/f32 in TileLang DSL v1") if name in {"vaddreluconv", "vmulconv"} and dtype.name not in {"f16", "bf16", "f32"}: diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py index bca098b02..afc15d400 100644 --- a/tilelang-dsl/python/tilelang_dsl/support_matrix.py +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -113,6 +113,7 @@ "vsub", "vmul", "vdiv", + "vmod", "vmax", "vmin", "vand", diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index afc2fd9ed..d29688f54 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -47,6 +47,7 @@ SemanticAlignStoreStmt, SemanticAlignType, SemanticAssignStmt, + SemanticBindingRef, SemanticBinaryExpr, SemanticCallExpr, SemanticDmaConfigStmt, @@ -65,6 +66,7 @@ SemanticPipeBarrierStmt, SemanticPtrType, SemanticPredicateStoreStmt, + SemanticReturnStmt, SemanticRlsBufStmt, SemanticScalarStoreStmt, SemanticScalarType, @@ -92,6 +94,62 @@ INLINE_PROC_GLOBAL_LANE = 0 +def _walk_semantic_stmts(statements): + for stmt in statements: + yield stmt + if isinstance(stmt, SemanticVecscopeStmt): + yield from _walk_semantic_stmts(stmt.body) + elif isinstance(stmt, SemanticForStmt): + yield from _walk_semantic_stmts(stmt.body) + elif isinstance(stmt, SemanticIfStmt): + yield from _walk_semantic_stmts(stmt.then_body) + yield from _walk_semantic_stmts(stmt.else_body) + + +def _find_inline_helper(semantic_kernel, symbol_prefix): + return next( + helper for helper in semantic_kernel.inline_helpers if helper.symbol_name.startswith(symbol_prefix) + ) + + +def _find_helper_assign_by_ssa(helper, ssa_name): + return next( + stmt + for stmt in helper.body + if isinstance(stmt, SemanticAssignStmt) + and any(target.ssa_name == ssa_name for target in stmt.targets) + ) + + +def _find_last_helper_assign_by_name(helper, name): + return next( + stmt + for stmt in reversed(helper.body) + if isinstance(stmt, SemanticAssignStmt) + and any(target.name == name for target in stmt.targets) + ) + + +def _find_helper_return_stmt(helper): + return next(stmt for stmt in helper.body if isinstance(stmt, SemanticReturnStmt)) + + +def _resolve_helper_expr(helper, expr): + if isinstance(expr, SemanticBindingRef): + assign = _find_helper_assign_by_ssa(helper, expr.binding.ssa_name) + return _resolve_helper_expr(helper, assign.value) + return expr + + +def _resolve_helper_broadcast_scalar_literal(helper, expr): + resolved = _resolve_helper_expr(helper, expr) + if isinstance(resolved, SemanticLiteralExpr): + return resolved.value + if isinstance(resolved, SemanticCallExpr) and resolved.namespace == "pto" and resolved.name == "vbr": + return _resolve_helper_broadcast_scalar_literal(helper, resolved.args[0]) + raise AssertionError(f"expected helper scalar literal or broadcast, got {resolved!r}") + + class TileLangDSLPackageTests(unittest.TestCase): def test_package_exports_surface(self) -> None: self.assertIsNotNone(pto.__file__) @@ -419,6 +477,7 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertIn("pto.vsts", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("pto.vadd", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("pto.vmuls", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vmod", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) self.assertIn("tile[start:]", BASIC_TILE_INDEXING_SURFACES) self.assertIn("tile[row, col:]", BASIC_TILE_INDEXING_SURFACES) @@ -428,6 +487,7 @@ def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: self.assertEqual(get_feature_tier("pto.vsts"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vadd"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.vmuls"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vmod"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.get_buf"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.rls_buf"), BASIC_TIER) self.assertEqual(get_feature_tier("pto.get_block_idx"), BASIC_TIER) @@ -7034,6 +7094,731 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertRegex(text, r"= func\.call @__tl_inline_") self.assertIn("pto.vsts", text) + def test_vdiv_integer_vector_types_rewrite_to_internal_helper(self) -> None: + @pto.vkernel(op="vdiv_i16_dtype_support_unique", dtypes=[(pto.i16, pto.i16)]) + def kernel_i16(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + @pto.vkernel(op="vdiv_i32_dtype_support_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel_i32(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized_i16 = kernel_i16.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i16 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i16)) + assign_i16 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i16.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i16.value.namespace) + self.assertRegex(assign_i16.value.name, r"^__tl_inline__tl_soft_vdiv_") + self.assertEqual(assign_i16.value.type, SemanticVRegType(element_dtype=pto.i16, lanes=128)) + self.assertGreaterEqual(len(semantic_i16.inline_helpers), 1) + self.assertRegex(specialized_i16.mlir_text(), r"func\.call @__tl_inline__tl_soft_vdiv_") + + text_i32 = kernel_i32.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i32 = analyze_frontend_kernel(build_frontend_kernel_node(text_i32)) + assign_i32 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i32.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i32.value.namespace) + self.assertRegex(assign_i32.value.name, r"^__tl_inline__tl_soft_vdiv_") + self.assertEqual(assign_i32.value.type, SemanticVRegType(element_dtype=pto.i32, lanes=64)) + self.assertGreaterEqual(len(semantic_i32.inline_helpers), 1) + self.assertRegex(text_i32.mlir_text(), r"func\.call @__tl_inline__tl_soft_vdiv_") + + def test_vdiv_f16_and_f32_vector_types_keep_authoring_form_vpto_path(self) -> None: + @pto.vkernel( + op="vdiv_float_dtype_support_unique", + dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + cases = [ + (pto.f16, 128), + (pto.f32, 64), + ] + + for dtype, lanes in cases: + with self.subTest(dtype=dtype): + selected = pto.select_kernel("a5", "vdiv_float_dtype_support_unique", (dtype, dtype)) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + assign_stmt = next( + stmt + for stmt in _walk_semantic_stmts(semantic_kernel.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertEqual(assign_stmt.value.namespace, "pto") + self.assertEqual(assign_stmt.value.name, "vdiv") + self.assertEqual( + assign_stmt.value.type, + SemanticVRegType(element_dtype=dtype, lanes=lanes), + ) + self.assertEqual(len(semantic_kernel.inline_helpers), 0) + + text = lower_semantic_kernel(semantic_kernel).render() + self.assertEqual(text, specialized.mlir_text()) + self.assertIn("= pto.vdiv ", text) + self.assertNotIn("__tl_inline__tl_soft_vdiv_", text) + + def test_vdiv_i8_and_ui8_vector_types_rewrite_to_internal_helper(self) -> None: + @pto.vkernel(op="vdiv_i8_dtype_support_unique", dtypes=[(pto.i8, pto.i8)]) + def kernel_i8(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + @pto.vkernel(op="vdiv_ui8_dtype_support_unique", dtypes=[(pto.ui8, pto.ui8)]) + def kernel_ui8(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized_i8 = kernel_i8.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i8 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i8)) + assign_i8 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i8.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i8.value.namespace) + self.assertRegex(assign_i8.value.name, r"^__tl_inline__tl_soft_vdiv_") + self.assertEqual(assign_i8.value.type, SemanticVRegType(element_dtype=pto.i8, lanes=256)) + self.assertGreaterEqual(len(semantic_i8.inline_helpers), 1) + self.assertRegex(specialized_i8.mlir_text(), r"func\.call @__tl_inline__tl_soft_vdiv_") + + specialized_ui8 = kernel_ui8.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_ui8 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_ui8)) + assign_ui8 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_ui8.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_ui8.value.namespace) + self.assertRegex(assign_ui8.value.name, r"^__tl_inline__tl_soft_vdiv_") + self.assertEqual(assign_ui8.value.type, SemanticVRegType(element_dtype=pto.ui8, lanes=256)) + self.assertGreaterEqual(len(semantic_ui8.inline_helpers), 1) + self.assertRegex(specialized_ui8.mlir_text(), r"func\.call @__tl_inline__tl_soft_vdiv_") + + def test_vdiv_rejects_bf16_vector_type(self) -> None: + @pto.vkernel(op="vdiv_bf16_reject_unique", dtypes=[(pto.bf16, pto.bf16)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + self.assertIn( + "pto.vdiv only supports 8/16/32-bit integer families and f16/f32 in TileLang DSL v1", + str(ctx.exception), + ) + + def test_vmod_integer_vector_types_rewrite_to_internal_helper(self) -> None: + @pto.vkernel(op="vmod_i16_dtype_support_unique", dtypes=[(pto.i16, pto.i16)]) + def kernel_i16(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + @pto.vkernel(op="vmod_i32_dtype_support_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel_i32(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized_i16 = kernel_i16.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i16 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i16)) + assign_i16 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i16.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i16.value.namespace) + self.assertRegex(assign_i16.value.name, r"^__tl_inline__tl_soft_vmod_") + self.assertEqual(assign_i16.value.type, SemanticVRegType(element_dtype=pto.i16, lanes=128)) + self.assertGreaterEqual(len(semantic_i16.inline_helpers), 1) + self.assertRegex(specialized_i16.mlir_text(), r"func\.call @__tl_inline__tl_soft_vmod_") + + specialized_i32 = kernel_i32.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i32 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i32)) + assign_i32 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i32.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i32.value.namespace) + self.assertRegex(assign_i32.value.name, r"^__tl_inline__tl_soft_vmod_") + self.assertEqual(assign_i32.value.type, SemanticVRegType(element_dtype=pto.i32, lanes=64)) + self.assertGreaterEqual(len(semantic_i32.inline_helpers), 1) + self.assertRegex(specialized_i32.mlir_text(), r"func\.call @__tl_inline__tl_soft_vmod_") + + def test_vmod_i8_and_ui8_vector_types_rewrite_to_internal_helper(self) -> None: + @pto.vkernel(op="vmod_i8_dtype_support_unique", dtypes=[(pto.i8, pto.i8)]) + def kernel_i8(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + @pto.vkernel(op="vmod_ui8_dtype_support_unique", dtypes=[(pto.ui8, pto.ui8)]) + def kernel_ui8(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized_i8 = kernel_i8.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i8 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i8)) + assign_i8 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i8.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i8.value.namespace) + self.assertRegex(assign_i8.value.name, r"^__tl_inline__tl_soft_vmod_") + self.assertEqual(assign_i8.value.type, SemanticVRegType(element_dtype=pto.i8, lanes=256)) + self.assertGreaterEqual(len(semantic_i8.inline_helpers), 1) + self.assertRegex(specialized_i8.mlir_text(), r"func\.call @__tl_inline__tl_soft_vmod_") + + specialized_ui8 = kernel_ui8.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_ui8 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_ui8)) + assign_ui8 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_ui8.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_ui8.value.namespace) + self.assertRegex(assign_ui8.value.name, r"^__tl_inline__tl_soft_vmod_") + self.assertEqual(assign_ui8.value.type, SemanticVRegType(element_dtype=pto.ui8, lanes=256)) + self.assertGreaterEqual(len(semantic_ui8.inline_helpers), 1) + self.assertRegex(specialized_ui8.mlir_text(), r"func\.call @__tl_inline__tl_soft_vmod_") + + def test_vmod_rejects_f32_vector_type(self) -> None: + @pto.vkernel(op="vmod_f32_reject_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + self.assertIn( + "pto.vmod only supports 8/16/32-bit integer families in TileLang DSL v1", + str(ctx.exception), + ) + + def test_integer_divmod_helpers_lock_zero_divisor_sentinel_convention(self) -> None: + @pto.vkernel( + op="integer_divmod_zero_divisor_contract_unique", + dtypes=[ + (pto.i8, pto.i8), + (pto.ui8, pto.ui8), + (pto.i16, pto.i16), + (pto.ui16, pto.ui16), + (pto.i32, pto.i32), + (pto.ui32, pto.ui32), + ], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + quot = pto.vdiv(vec, vec, mask) + rem = pto.vmod(vec, vec, mask) + pto.vsts(quot, dst, 0, mask) + pto.vsts(rem, dst, 0, mask) + return None + + cases = [ + ("vdiv", pto.i8, "__tl_inline__tl_soft_vdiv_i8_", -1), + ("vdiv", pto.ui8, "__tl_inline__tl_soft_vdiv_u8_", 0xFF), + ("vdiv", pto.i16, "__tl_inline__tl_soft_vdiv_i16_", -1), + ("vdiv", pto.ui16, "__tl_inline__tl_soft_vdiv_u16_", 0xFFFF), + ("vdiv", pto.i32, "__tl_inline__tl_soft_vdiv_i32_", -1), + ("vdiv", pto.ui32, "__tl_inline__tl_soft_vdiv_u32_", 0xFFFFFFFF), + ("vmod", pto.i8, "__tl_inline__tl_soft_vmod_i8_", -1), + ("vmod", pto.ui8, "__tl_inline__tl_soft_vmod_u8_", 0xFF), + ("vmod", pto.i16, "__tl_inline__tl_soft_vmod_i16_", -1), + ("vmod", pto.ui16, "__tl_inline__tl_soft_vmod_u16_", 0xFFFF), + ("vmod", pto.i32, "__tl_inline__tl_soft_vmod_i32_", -1), + ("vmod", pto.ui32, "__tl_inline__tl_soft_vmod_u32_", 0xFFFFFFFF), + ] + + for op_name, dtype, helper_prefix, expected_sentinel in cases: + with self.subTest(op=op_name, dtype=dtype): + selected = pto.select_kernel( + "a5", + "integer_divmod_zero_divisor_contract_unique", + (dtype, dtype), + ) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + helper = _find_inline_helper(semantic_kernel, helper_prefix) + + zero_mask_assign = _find_last_helper_assign_by_name(helper, "zero_mask") + self.assertIsInstance(zero_mask_assign.value, SemanticCallExpr) + self.assertEqual(zero_mask_assign.value.namespace, "pto") + self.assertEqual(zero_mask_assign.value.name, "vcmps") + self.assertEqual(zero_mask_assign.value.args[3].value, "eq") + self.assertEqual( + _resolve_helper_broadcast_scalar_literal(helper, zero_mask_assign.value.args[1]), + 0, + ) + + return_stmt = _find_helper_return_stmt(helper) + self.assertIsInstance(return_stmt.value, SemanticCallExpr) + self.assertEqual(return_stmt.value.namespace, "pto") + self.assertEqual(return_stmt.value.name, "vsel") + self.assertIsInstance(return_stmt.value.args[2], SemanticBindingRef) + self.assertEqual(return_stmt.value.args[2].binding.name, "zero_mask") + self.assertEqual( + _resolve_helper_broadcast_scalar_literal(helper, return_stmt.value.args[0]), + expected_sentinel, + ) + + def test_signed_vdiv_helpers_derive_result_sign_from_operand_signs(self) -> None: + @pto.vkernel( + op="signed_vdiv_sign_contract_unique", + dtypes=[(pto.i16, pto.i16), (pto.i32, pto.i32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + cases = [ + (pto.i16, "__tl_inline__tl_soft_vdiv_i16_", "i16"), + (pto.i32, "__tl_inline__tl_soft_vdiv_i32_", "i32"), + ] + + for dtype, helper_prefix, dtype_name in cases: + with self.subTest(dtype=dtype): + selected = pto.select_kernel("a5", "signed_vdiv_sign_contract_unique", (dtype, dtype)) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + helper = _find_inline_helper(semantic_kernel, helper_prefix) + + xor_assign = _find_last_helper_assign_by_name(helper, "x_xor_y") + self.assertIsInstance(xor_assign.value, SemanticCallExpr) + self.assertEqual(xor_assign.value.namespace, "pto") + self.assertEqual(xor_assign.value.name, "vxor") + self.assertEqual(xor_assign.value.args[0].binding.name, "vec") + self.assertEqual(xor_assign.value.args[1].binding.name, "scalar_vec") + self.assertEqual(xor_assign.value.args[2].binding.name, "active_mask") + + p_pos_assign = _find_last_helper_assign_by_name(helper, "p_pos") + self.assertIsInstance(p_pos_assign.value, SemanticCallExpr) + self.assertEqual(p_pos_assign.value.namespace, "pto") + self.assertEqual(p_pos_assign.value.name, "vcmps") + self.assertEqual(p_pos_assign.value.args[0].binding.name, "x_xor_y") + self.assertEqual(p_pos_assign.value.args[1].binding.name, "zero") + self.assertEqual(p_pos_assign.value.args[2].binding.name, "active_mask") + self.assertEqual(p_pos_assign.value.args[3].value, "ge") + + q_assign = _find_last_helper_assign_by_name(helper, "q") + self.assertIsInstance(q_assign.value, SemanticCallExpr) + self.assertEqual(q_assign.value.namespace, "pto") + self.assertEqual(q_assign.value.name, "vsel") + self.assertEqual(q_assign.value.args[1].binding.name, "neg_q") + self.assertEqual(q_assign.value.args[2].binding.name, "p_pos") + self.assertIsInstance(q_assign.value.args[0], SemanticCallExpr) + self.assertEqual(q_assign.value.args[0].namespace, "pto") + self.assertEqual(q_assign.value.args[0].name, "vbitcast") + self.assertIsInstance(q_assign.value.args[0].args[0], SemanticBindingRef) + self.assertEqual(q_assign.value.args[0].args[1].name, dtype_name) + + def test_signed_vmod_helpers_apply_floor_fix_when_signs_differ(self) -> None: + @pto.vkernel( + op="signed_vmod_sign_contract_unique", + dtypes=[(pto.i16, pto.i16), (pto.i32, pto.i32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + cases = [ + (pto.i16, "__tl_inline__tl_soft_vmod_i16_"), + (pto.i32, "__tl_inline__tl_soft_vmod_i32_"), + ] + + for dtype, helper_prefix in cases: + with self.subTest(dtype=dtype): + selected = pto.select_kernel("a5", "signed_vmod_sign_contract_unique", (dtype, dtype)) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + helper = _find_inline_helper(semantic_kernel, helper_prefix) + + nonzero_assign = _find_last_helper_assign_by_name(helper, "nonzero_remainder") + self.assertIsInstance(nonzero_assign.value, SemanticCallExpr) + self.assertEqual(nonzero_assign.value.namespace, "pto") + self.assertEqual(nonzero_assign.value.name, "vcmps") + self.assertEqual(nonzero_assign.value.args[1].binding.name, "zero") + self.assertEqual(nonzero_assign.value.args[2].binding.name, "active_mask") + self.assertEqual(nonzero_assign.value.args[3].value, "ne") + + sign_diff_assign = _find_last_helper_assign_by_name(helper, "sign_diff") + self.assertIsInstance(sign_diff_assign.value, SemanticCallExpr) + self.assertEqual(sign_diff_assign.value.namespace, "pto") + self.assertEqual(sign_diff_assign.value.name, "pxor") + self.assertEqual(sign_diff_assign.value.args[0].binding.name, "sign_x") + self.assertEqual(sign_diff_assign.value.args[1].binding.name, "sign_y") + self.assertEqual(sign_diff_assign.value.args[2].binding.name, "active_mask") + + need_fix_assign = _find_last_helper_assign_by_name(helper, "need_floor_fix") + self.assertIsInstance(need_fix_assign.value, SemanticCallExpr) + self.assertEqual(need_fix_assign.value.namespace, "pto") + self.assertEqual(need_fix_assign.value.name, "pand") + self.assertEqual(need_fix_assign.value.args[0].binding.name, "sign_diff") + self.assertEqual(need_fix_assign.value.args[1].binding.name, "nonzero_remainder") + self.assertEqual(need_fix_assign.value.args[2].binding.name, "active_mask") + + amended_assign = _find_last_helper_assign_by_name(helper, "amended_remainder") + self.assertIsInstance(amended_assign.value, SemanticCallExpr) + self.assertEqual(amended_assign.value.namespace, "pto") + self.assertEqual(amended_assign.value.name, "vadd") + self.assertEqual(amended_assign.value.args[0].binding.name, "scalar_vec") + self.assertEqual(amended_assign.value.args[1].binding.name, "remainder") + self.assertEqual(amended_assign.value.args[2].binding.name, "active_mask") + + remainder_assign = _find_last_helper_assign_by_name(helper, "remainder") + self.assertIsInstance(remainder_assign.value, SemanticCallExpr) + self.assertEqual(remainder_assign.value.namespace, "pto") + self.assertEqual(remainder_assign.value.name, "vsel") + self.assertEqual(remainder_assign.value.args[0].binding.name, "amended_remainder") + self.assertEqual(remainder_assign.value.args[2].binding.name, "need_floor_fix") + + def test_i8_divmod_helpers_use_explicit_widen_narrow_profile(self) -> None: + @pto.vkernel( + op="i8_divmod_widen_narrow_contract_unique", + dtypes=[ + (pto.i8, pto.i8), + (pto.ui8, pto.ui8), + ], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + quot = pto.vdiv(vec, vec, mask) + rem = pto.vmod(vec, vec, mask) + pto.vsts(quot, dst, 0, mask) + pto.vsts(rem, dst, 1, mask) + return None + + cases = [ + ( + pto.i8, + "__tl_inline__tl_soft_vdiv_i8_", + "__tl_inline__tl_soft_vdiv_i16_", + "vsunpack", + "q", + "q_low", + "q_high", + "vbitcast", + ), + ( + pto.ui8, + "__tl_inline__tl_soft_vdiv_u8_", + "__tl_inline__tl_soft_vdiv_u16_", + "vzunpack", + "q", + "q_low", + "q_high", + "vor", + ), + ( + pto.i8, + "__tl_inline__tl_soft_vmod_i8_", + "__tl_inline__tl_soft_vmod_i16_", + "vsunpack", + "r", + "r_low", + "r_high", + "vbitcast", + ), + ( + pto.ui8, + "__tl_inline__tl_soft_vmod_u8_", + "__tl_inline__tl_soft_vmod_u16_", + "vzunpack", + "r", + "r_low", + "r_high", + "vor", + ), + ] + + for ( + dtype, + helper_prefix, + widened_helper_prefix, + unpack_name, + packed_result_name, + lower_result_name, + higher_result_name, + packed_result_op, + ) in cases: + with self.subTest(dtype=dtype, helper=helper_prefix): + selected = pto.select_kernel( + "a5", + "i8_divmod_widen_narrow_contract_unique", + (dtype, dtype), + ) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + helper = _find_inline_helper(semantic_kernel, helper_prefix) + + active_low_assign = _find_last_helper_assign_by_name(helper, "active_low") + self.assertIsInstance(active_low_assign.value, SemanticCallExpr) + self.assertEqual(active_low_assign.value.namespace, "pto") + self.assertEqual(active_low_assign.value.name, "punpack") + self.assertEqual(active_low_assign.value.args[0].binding.name, "active_mask") + + active_high_assign = _find_last_helper_assign_by_name(helper, "active_high") + self.assertIsInstance(active_high_assign.value, SemanticCallExpr) + self.assertEqual(active_high_assign.value.namespace, "pto") + self.assertEqual(active_high_assign.value.name, "punpack") + self.assertEqual(active_high_assign.value.args[0].binding.name, "active_mask") + + for name, expected_half in ( + ("vec_low", 0), + ("vec_high", 1), + ("scalar_low", 0), + ("scalar_high", 1), + ): + assign = _find_last_helper_assign_by_name(helper, name) + self.assertIsInstance(assign.value, SemanticCallExpr) + self.assertEqual(assign.value.namespace, "pto") + self.assertEqual(assign.value.name, unpack_name) + self.assertEqual(assign.value.args[1].value, expected_half) + + lower_assign = _find_last_helper_assign_by_name(helper, lower_result_name) + self.assertIsInstance(lower_assign.value, SemanticCallExpr) + self.assertIsNone(lower_assign.value.namespace) + self.assertRegex(lower_assign.value.name, rf"^{widened_helper_prefix}") + self.assertEqual(lower_assign.value.args[2].binding.name, "active_low") + + higher_assign = _find_last_helper_assign_by_name(helper, higher_result_name) + self.assertIsInstance(higher_assign.value, SemanticCallExpr) + self.assertIsNone(higher_assign.value.namespace) + self.assertRegex(higher_assign.value.name, rf"^{widened_helper_prefix}") + self.assertEqual(higher_assign.value.args[2].binding.name, "active_high") + + packed_low_assign = _find_last_helper_assign_by_name(helper, "packed_low") + self.assertIsInstance(packed_low_assign.value, SemanticCallExpr) + self.assertEqual(packed_low_assign.value.namespace, "pto") + self.assertEqual(packed_low_assign.value.name, "vpack") + self.assertEqual(packed_low_assign.value.args[0].binding.name, lower_result_name) + + packed_high_assign = _find_last_helper_assign_by_name(helper, "packed_high") + self.assertIsInstance(packed_high_assign.value, SemanticCallExpr) + self.assertEqual(packed_high_assign.value.namespace, "pto") + self.assertEqual(packed_high_assign.value.name, "vpack") + self.assertEqual(packed_high_assign.value.args[0].binding.name, higher_result_name) + + packed_result_assign = _find_last_helper_assign_by_name(helper, packed_result_name) + self.assertIsInstance(packed_result_assign.value, SemanticCallExpr) + self.assertEqual(packed_result_assign.value.namespace, "pto") + self.assertEqual(packed_result_assign.value.name, packed_result_op) + if packed_result_op == "vor": + combined_expr = packed_result_assign.value + else: + self.assertIsInstance(packed_result_assign.value.args[0], SemanticCallExpr) + combined_expr = packed_result_assign.value.args[0] + self.assertEqual(combined_expr.namespace, "pto") + self.assertEqual(combined_expr.name, "vor") + self.assertEqual(combined_expr.args[0].binding.name, "packed_low") + self.assertEqual(combined_expr.args[1].binding.name, "packed_high") + self.assertEqual(combined_expr.args[2].binding.name, "full_mask_b8") + + def test_integer_divmod_rewrite_uses_injected_internal_helpers(self) -> None: + @pto.vkernel(op="divmod_internal_helper_injection_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + quot = pto.vdiv(vec, vec, mask) + rem = pto.vmod(vec, vec, mask) + pto.vsts(quot, dst, 0, mask) + pto.vsts(rem, dst, 1, mask) + return None + + self.assertNotIn("vdiv", kernel.py_fn.__globals__) + self.assertNotIn("vmod", kernel.py_fn.__globals__) + self.assertNotIn("_tl_soft_vdiv_i32", kernel.inline_procs) + self.assertNotIn("_tl_soft_vmod_i32", kernel.inline_procs) + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + frontend_kernel = build_frontend_kernel_node(specialized) + + self.assertEqual({proc.name for proc in frontend_kernel.inline_procs}, set()) + internal_names = {proc.name for proc in frontend_kernel.internal_inline_procs} + self.assertIn("_tl_soft_vdiv", internal_names) + self.assertIn("_tl_soft_vmod", internal_names) + self.assertIn("_tl_soft_vdiv_i32", internal_names) + self.assertIn("_tl_soft_vmod_i32", internal_names) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + helper_symbols = {helper.symbol_name for helper in semantic_kernel.inline_helpers} + self.assertTrue(any(name.startswith("__tl_inline__tl_soft_vdiv_") for name in helper_symbols)) + self.assertTrue(any(name.startswith("__tl_inline__tl_soft_vmod_") for name in helper_symbols)) + + def test_internal_vdiv_helper_name_is_not_public_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="internal_vdiv_helper_reject_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = _tl_soft_vdiv_i32(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + self.assertIn( + "arbitrary external call `_tl_soft_vdiv_i32` is not supported in TileLang DSL v1", + str(ctx.exception), + ) + + def test_internal_vmod_helper_name_is_not_public_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="internal_vmod_helper_reject_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = _tl_soft_vmod_i32(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + self.assertIn( + "arbitrary external call `_tl_soft_vmod_i32` is not supported in TileLang DSL v1", + str(ctx.exception), + ) + def test_inline_proc_and_pto_surface_can_share_basename(self) -> None: @pto.inline_proc def vdiv(src: pto.Tile, lane: pto.i32 = 0): diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 1b0bb054a..e2bb6d134 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -487,6 +487,17 @@ static bool hasUnexpandedTileOps(ModuleOp module) { return found; } +static bool hasTilelangInlineHelpers(ModuleOp module) { + bool found = false; + module.walk([&](func::FuncOp func) { + if (found) + return; + if (func->hasAttr("pto.tilelang.inline_proc")) + found = true; + }); + return found; +} + // -------------------------------------------------------------------------- // Post-process C++ output: rewrite marker calls into Tile member calls. // @@ -1196,6 +1207,21 @@ static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { return success(); } +static LogicalResult inlineTilelangHelpersOnVPTOInput(ModuleOp module) { + PassManager inlinePM(module.getContext()); + inlinePM.addPass(pto::createPTOInlineLibCallPass()); + inlinePM.addPass(mlir::createSCCPPass()); + inlinePM.addPass(mlir::createCanonicalizerPass()); + if (failed(applyConfiguredPassManagerCLOptions( + inlinePM, "VPTO TileLang helper inlining"))) + return failure(); + if (failed(inlinePM.run(module))) { + llvm::errs() << "Error: VPTO TileLang helper inlining failed.\n"; + return failure(); + } + return success(); +} + static pto::VPTOEmissionOptions buildVPTOEmissionOptions() { pto::VPTOEmissionOptions options; options.dumpVPTOIR = false; @@ -1510,6 +1536,7 @@ int main(int argc, char **argv) { } const bool hasTileOpsToExpand = hasUnexpandedTileOps(*module); + const bool hasTilelangHelpers = hasTilelangInlineHelpers(*module); if (effectiveBackend == PTOBackend::VPTO && inputIsVPTOIR && !hasTileOpsToExpand) { @@ -1519,6 +1546,10 @@ int main(int argc, char **argv) { return 1; } + if (hasTilelangHelpers && + failed(inlineTilelangHelpersOnVPTOInput(*module))) + return 1; + return emitVPTOBackendResult(*module, outputFile); } From 0923963578c7374cfbb3c971987196815926110d Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sat, 25 Apr 2026 23:53:11 +0800 Subject: [PATCH 178/192] feat: support vmrgsort4 emit --- lib/PTO/IR/VPTO.cpp | 18 +++ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 121 +++++++++++++++++- .../micro-op/dsa-sfu/vmrgsort4/compare.py | 56 ++++++++ .../micro-op/dsa-sfu/vmrgsort4/golden.py | 41 ++++++ .../micro-op/dsa-sfu/vmrgsort4/kernel.pto | 39 ++++++ .../micro-op/dsa-sfu/vmrgsort4/launch.cpp | 25 ++++ .../cases/micro-op/dsa-sfu/vmrgsort4/main.cpp | 84 ++++++++++++ .../cases/micro-op/dsa-sfu/vmrgsort4/stub.cpp | 21 +++ 8 files changed, 403 insertions(+), 2 deletions(-) create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/compare.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/golden.py create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/kernel.pto create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/launch.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/main.cpp create mode 100644 test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/stub.cpp diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 23534684b..511e488f5 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1895,6 +1895,24 @@ LogicalResult Vmrgsort4Op::verify() { classifyMemoryRole(getSource2().getType()) != MemoryRole::UB || classifyMemoryRole(getSource3().getType()) != MemoryRole::UB) return emitOpError("requires UB-backed destination and sources"); + auto dstPtrType = dyn_cast(getDestination().getType()); + auto src0PtrType = dyn_cast(getSource0().getType()); + auto src1PtrType = dyn_cast(getSource1().getType()); + auto src2PtrType = dyn_cast(getSource2().getType()); + auto src3PtrType = dyn_cast(getSource3().getType()); + if (!dstPtrType || !src0PtrType || !src1PtrType || !src2PtrType || + !src3PtrType) + return emitOpError("requires ptr-backed destination and sources"); + + Type elemType = dstPtrType.getElementType(); + if (src0PtrType.getElementType() != elemType || + src1PtrType.getElementType() != elemType || + src2PtrType.getElementType() != elemType || + src3PtrType.getElementType() != elemType) + return emitOpError( + "requires destination and all sources to have the same element type"); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 element type"); if (failed(verifyNotNestedInVecScope(*this, "pto.vmrgsort4"))) return failure(); return success(); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 57fd57635..621fd1098 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -2438,6 +2438,63 @@ static FailureOr buildVbitsortCallee(MLIRContext *context, return failure(); } +static FailureOr buildVmrgsort4Callee(MLIRContext *context, + pto::Vmrgsort4Op op) { + Type elemType = + cast(op.getDestination().getType()).getElementType(); + if (elemType.isF16()) + return StringAttr::get(context, "llvm.hivm.VMRGSORT.f16.V300").getValue(); + if (elemType.isF32()) + return StringAttr::get(context, "llvm.hivm.VMRGSORT.f32.V300").getValue(); + return failure(); +} + +static FailureOr packVmrgsort4SourceAddr(Operation *anchor, Value source0, + Value source1, Value source2, + Value source3, Type elemType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + unsigned addrShift = 0; + if (elemType.isF16()) + addrShift = 3; + else if (elemType.isF32()) + addrShift = 3; + else + return failure(); + + auto packOne = [&](Value source, uint64_t laneShift) -> FailureOr { + FailureOr ubPtr = reinterpretPointerToAddrSpace(anchor, source, 6); + if (failed(ubPtr)) + return failure(); + Value asInt = + builder.create(loc, builder.getI64Type(), *ubPtr); + Value shifted = builder.create( + loc, asInt, getI64Constant(builder, loc, addrShift)); + Value masked = builder.create( + loc, shifted, getI64Constant(builder, loc, 0xFFFFULL)); + if (laneShift == 0) + return masked; + return builder + .create(loc, masked, + getI64Constant(builder, loc, laneShift)) + .getResult(); + }; + + FailureOr low0 = packOne(source0, 0); + FailureOr low1 = packOne(source1, 16); + FailureOr low2 = packOne(source2, 32); + FailureOr low3 = packOne(source3, 48); + if (failed(low0) || failed(low1) || failed(low2) || failed(low3)) + return failure(); + + Value packed01 = builder.create(loc, *low0, *low1); + Value packed23 = builder.create(loc, *low2, *low3); + Value packed = builder.create(loc, packed01, packed23); + Type ubPtrTy = LLVM::LLVMPointerType::get(anchor->getContext(), 6); + return builder.create(loc, ubPtrTy, packed).getResult(); +} + static FailureOr buildVcvtContract(pto::VcvtOp op) { Type inputElemType = getElementTypeFromVectorLike(op.getInput().getType()); Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); @@ -5815,6 +5872,64 @@ class LowerVbitsortOpPattern final LoweringState &state; }; +class LowerVmrgsort4OpPattern final + : public OpConversionPattern { +public: + explicit LowerVmrgsort4OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vmrgsort4Op op, pto::Vmrgsort4Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = + dyn_cast(adaptor.getDestination().getType()); + auto src0Type = + dyn_cast(adaptor.getSource0().getType()); + auto src1Type = + dyn_cast(adaptor.getSource1().getType()); + auto src2Type = + dyn_cast(adaptor.getSource2().getType()); + auto src3Type = + dyn_cast(adaptor.getSource3().getType()); + if (!dstType || !src0Type || !src1Type || !src2Type || !src3Type) + return rewriter.notifyMatchFailure( + op, "unexpected converted vmrgsort4 operand types"); + + Type elemType = + cast(op.getDestination().getType()).getElementType(); + FailureOr packedSrc = packVmrgsort4SourceAddr( + op, adaptor.getSource0(), adaptor.getSource1(), adaptor.getSource2(), + adaptor.getSource3(), elemType); + if (failed(packedSrc)) + return rewriter.notifyMatchFailure( + op, "failed to pack vmrgsort4 source addresses"); + + FailureOr dst = reinterpretPointerToAddrSpace(op, adaptor.getDestination(), 6); + if (failed(dst)) + return rewriter.notifyMatchFailure(op, "failed to normalize vmrgsort4 destination"); + + FailureOr calleeName = buildVmrgsort4Callee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmrgsort4 signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{(*dst).getType(), (*packedSrc).getType(), + adaptor.getCount().getType(), adaptor.getConfig().getType()}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*dst, *packedSrc, adaptor.getCount(), adaptor.getConfig()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + class LowerVcvtOpPattern final : public OpConversionPattern { public: explicit LowerVcvtOpPattern(TypeConverter &typeConverter, @@ -6749,7 +6864,8 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerVgatherbOpPattern, LowerVscatterOpPattern, LowerVaxpyOpPattern, LowerVciOpPattern, LowerVexpdifOpPattern, - LowerVbitsortOpPattern, LowerVtrcOpPattern, LowerVcvtOpPattern, + LowerVbitsortOpPattern, LowerVmrgsort4OpPattern, + LowerVtrcOpPattern, LowerVcvtOpPattern, LowerVbitcastOpPattern, LowerPbitcastOpPattern, LowerPredicateLoadOpPattern, LowerPredicateLoadOpPattern, @@ -6827,7 +6943,8 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, pto::VsunpackOp, pto::VzunpackOp, pto::VpackOp, pto::VintlvOp, pto::VdintlvOp, pto::VpreluOp, pto::VaxpyOp, pto::VciOp, pto::VexpdifOp, - pto::VbitsortOp, pto::VtrcOp, pto::VcvtOp, + pto::VbitsortOp, pto::Vmrgsort4Op, pto::VtrcOp, + pto::VcvtOp, pto::VbitcastOp, pto::VcmpOp, pto::VcmpsOp, pto::CopyGmToUbufOp, pto::CopyUbufToGmOp, diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/compare.py new file mode 100644 index 000000000..6f6150678 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/compare.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import struct +import sys + +import numpy as np + + +PAIR_FMT = "fI" +PAIR_SIZE = struct.calcsize(PAIR_FMT) +PAIR_COUNT = 4 + + +def read_pairs(path: str): + values = [] + indices = [] + with open(path, "rb") as f: + for _ in range(PAIR_COUNT): + data = f.read(PAIR_SIZE) + if len(data) != PAIR_SIZE: + break + value, index = struct.unpack(PAIR_FMT, data) + values.append(value) + indices.append(index) + return np.array(values, dtype=np.float32), np.array(indices, dtype=np.uint32) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden_values, golden_indices = read_pairs("golden_v2.bin") + output_values, output_indices = read_pairs("v2.bin") + ok = ( + golden_values.shape == output_values.shape + and golden_indices.shape == output_indices.shape + and np.allclose(golden_values, output_values) + and np.array_equal(golden_indices, output_indices) + ) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/golden.py new file mode 100644 index 000000000..214680570 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/golden.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +import struct +from pathlib import Path + + +PAIR_FMT = "fI" + + +def write_pairs(path: Path, pairs) -> None: + with path.open("wb") as f: + for score, index in pairs: + f.write(struct.pack(PAIR_FMT, score, index)) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + + out = args.output_dir + out.mkdir(parents=True, exist_ok=True) + + src = [(9.0, 90), (7.0, 70), (8.0, 80), (6.0, 60)] + golden = [(9.0, 90), (8.0, 80), (7.0, 70), (6.0, 60)] + + write_pairs(out / "v1.bin", src) + write_pairs(out / "v2.bin", [(0.0, 0)] * 4) + write_pairs(out / "golden_v2.bin", golden) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/kernel.pto new file mode 100644 index 000000000..297d4bf51 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/kernel.pto @@ -0,0 +1,39 @@ +module attributes {pto.target_arch = "a5"} { + func.func @vmrgsort4_kernel_f32(%arg0: !pto.ptr, + %arg1: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c6 = arith.constant 6 : index + %c32_i64 = arith.constant 32 : i64 + %c_count = arith.constant 281479271743489 : i64 + %c_config = arith.constant 3841 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c32_i64 : i64 -> !pto.ptr + %src1 = pto.addptr %ub_in, %c2 : !pto.ptr -> !pto.ptr + %src2 = pto.addptr %ub_in, %c4 : !pto.ptr -> !pto.ptr + %src3 = pto.addptr %ub_in, %c6 : !pto.ptr -> !pto.ptr + + pto.dma_load %arg0, %ub_in, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vmrgsort4 %ub_out, %ub_in, %src1, %src2, %src3, %c_count, %c_config + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, + !pto.ptr, i64, i64 + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.dma_store %ub_out, %arg1, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/launch.cpp new file mode 100644 index 000000000..29b0fbc2f --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/launch.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#include + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmrgsort4_kernel_f32(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmrgsort4_kernel_f32(float *src, float *dst, void *stream) { + vmrgsort4_kernel_f32<<<1, nullptr, stream>>>((__gm__ float *)src, + (__gm__ float *)dst); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/main.cpp new file mode 100644 index 000000000..cdece1d0b --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/main.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmrgsort4_kernel_f32(float *src, float *dst, void *stream); + +int main() { + size_t inputBytes = 32; + size_t outputBytes = 32; + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), inputBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), outputBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, inputBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, outputBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + if (!ReadFile("./v1.bin", inputBytes, srcHost, inputBytes)) { + std::fprintf(stderr, "[ERROR] failed to read v1.bin\n"); + rc = 1; + goto cleanup; + } + if (!ReadFile("./v2.bin", outputBytes, dstHost, outputBytes)) { + std::fprintf(stderr, "[ERROR] failed to read v2.bin\n"); + rc = 1; + goto cleanup; + } + + ACL_CHECK(aclrtMemcpy(srcDevice, inputBytes, srcHost, inputBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, outputBytes, dstHost, outputBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmrgsort4_kernel_f32(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, outputBytes, dstDevice, outputBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, outputBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/stub.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/stub.cpp new file mode 100644 index 000000000..fc5d9299e --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/stub.cpp @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vmrgsort4_kernel_f32(__gm__ float *src, + __gm__ float *dst) { + (void)src; + (void)dst; +} From 6135d61356633580d990a58ef102e9504a0dbc33 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sun, 26 Apr 2026 00:12:28 +0800 Subject: [PATCH 179/192] feat: support 64bits dma op --- include/PTO/Transforms/HIVMIntrinsicNaming.h | 68 -- include/PTO/Transforms/VPTOLLVMEmitter.h | 3 - lib/PTO/Transforms/CMakeLists.txt | 1 - lib/PTO/Transforms/HIVMIntrinsicNaming.cpp | 918 ------------------ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 6 + lib/TileOps/tcvt_template.py | 40 + .../conversion/vcvt-i64-to-f32/compare.py | 42 + .../conversion/vcvt-i64-to-f32/golden.py | 84 ++ .../conversion/vcvt-i64-to-f32/kernel.pto | 55 ++ .../conversion/vcvt-i64-to-f32/launch.cpp | 52 + .../conversion/vcvt-i64-to-f32/main.cpp | 132 +++ .../conversion/vcvt-i64-to-f32/stub.cpp | 29 + tools/ptoas/ptoas.cpp | 27 - 13 files changed, 440 insertions(+), 1017 deletions(-) delete mode 100644 include/PTO/Transforms/HIVMIntrinsicNaming.h delete mode 100644 lib/PTO/Transforms/HIVMIntrinsicNaming.cpp create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/compare.py create mode 100755 test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/golden.py create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/launch.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/main.cpp create mode 100644 test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/stub.cpp diff --git a/include/PTO/Transforms/HIVMIntrinsicNaming.h b/include/PTO/Transforms/HIVMIntrinsicNaming.h deleted file mode 100644 index f2ea4c899..000000000 --- a/include/PTO/Transforms/HIVMIntrinsicNaming.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#ifndef MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H -#define MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H - -#include -#include - -#include "mlir/IR/Operation.h" -#include "mlir/Support/LLVM.h" - -namespace mlir::pto { - -struct NamingInputs { - std::string sourceOpName; - std::string family; - std::string vectorShape; - std::string elementType; - std::vector usedFields; - std::vector missingFields; -}; - -struct UnresolvedEmissionRecord { - std::string sourceOpName; - std::string placeholderName; - std::string candidateName; - std::vector usedFields; - std::vector missingFields; - std::string resultTypeFragment; - std::string location; -}; - -struct IntrinsicSelection { - bool resolved = false; - std::string sourceOpName; - std::string calleeName; - std::string placeholderName; - std::string candidateName; - std::vector usedFields; - std::vector missingFields; - std::string resultTypeFragment; - std::string location; - - std::string getEmittedCallee() const { - return resolved ? calleeName : placeholderName; - } - - UnresolvedEmissionRecord asUnresolvedRecord() const { - return UnresolvedEmissionRecord{sourceOpName, placeholderName, candidateName, - usedFields, missingFields, resultTypeFragment, - location}; - } -}; - -FailureOr selectIntrinsic(Operation *op); -FailureOr selectLoadIntrinsic(Operation *op); -FailureOr selectUnaryIntrinsic(Operation *op); -FailureOr selectStoreIntrinsic(Operation *op); - -} // namespace mlir::pto - -#endif // MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H diff --git a/include/PTO/Transforms/VPTOLLVMEmitter.h b/include/PTO/Transforms/VPTOLLVMEmitter.h index 625d5b2fa..dd747740c 100644 --- a/include/PTO/Transforms/VPTOLLVMEmitter.h +++ b/include/PTO/Transforms/VPTOLLVMEmitter.h @@ -26,9 +26,6 @@ namespace mlir::pto { struct VPTOEmissionOptions { bool dumpVPTOIR = false; - bool printIntrinsicSelections = false; - bool allowUnresolved = true; - std::string unresolvedReportPath; std::string targetTriple; std::string march; std::string aicoreArch; diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index e79c97f64..1cab801ad 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -12,7 +12,6 @@ # See LICENSE in the root of the software repository for the full text of the License. add_mlir_dialect_library(PTOTransforms - HIVMIntrinsicNaming.cpp VPTOLLVMEmitter.cpp VPTOLLVMEmitterHelper.cpp VPTOPtrNormalize.cpp diff --git a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp deleted file mode 100644 index 6eb78bd43..000000000 --- a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp +++ /dev/null @@ -1,918 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -//===- HIVMIntrinsicNaming.cpp - HIVM intrinsic selection -----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PTO/Transforms/HIVMIntrinsicNaming.h" - -#include "PTO/IR/PTO.h" - -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/BuiltinTypes.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include - -using namespace mlir; - -namespace mlir::pto { -namespace { - -static std::string getLocationString(Location loc) { - std::string storage; - llvm::raw_string_ostream os(storage); - loc.print(os); - return storage; -} - -static std::string sanitizeNameFragment(llvm::StringRef text) { - std::string out; - out.reserve(text.size()); - for (char c : text) { - if (std::isalnum(static_cast(c)) || c == '.' || c == '_') - out.push_back(c); - else - out.push_back('_'); - } - return out; -} - -static std::string printAttrText(Attribute attr) { - std::string storage; - llvm::raw_string_ostream os(storage); - os << attr; - return storage; -} - -static std::string getElementTypeFragment(Type type) { - if (type.isF16()) - return "f16"; - if (type.isBF16()) - return "bf16"; - if (type.isF32()) - return "f32"; - if (auto intType = dyn_cast(type)) - return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); - return "unknown"; -} - -static std::string getVectorTypeFragment(Type type) { - auto vecType = dyn_cast(type); - if (!vecType) - return {}; - return ("v" + std::to_string(vecType.getElementCount()) + - getElementTypeFragment(vecType.getElementType())); -} - -static std::string getCopyElementFragment(Type type) { - auto ptrType = dyn_cast(type); - if (!ptrType) - return {}; - Type elementType = ptrType.getElementType(); - if (elementType.isF16()) - return "f16"; - if (elementType.isBF16()) - return "bf16"; - if (elementType.isF32()) - return "f32"; - std::string typeText; - llvm::raw_string_ostream os(typeText); - elementType.print(os); - os.flush(); - std::string lower = StringRef(typeText).lower(); - if (StringRef(lower).contains("e4m3")) - return "e4m3"; - if (StringRef(lower).contains("e5m2")) - return "e5m2"; - if (StringRef(lower).contains("e8m0")) - return "e8m0"; - if (StringRef(lower).contains("hif8")) - return "hif8"; - if (auto intType = dyn_cast(elementType)) { - switch (intType.getWidth()) { - case 8: - return "u8"; - case 16: - return "u16"; - case 32: - return "u32"; - default: - return {}; - } - } - return {}; -} - -static bool isMxElementType(Type type) { - if (auto floatType = dyn_cast(type)) - return floatType.getWidth() == 8; - std::string typeText; - llvm::raw_string_ostream os(typeText); - type.print(os); - os.flush(); - return StringRef(typeText).starts_with("f8"); -} - -static std::string getMadMxElementFragment(Type type) { - if (type.isF16()) - return "f16"; - if (type.isBF16()) - return "bf16"; - - std::string typeText; - llvm::raw_string_ostream os(typeText); - type.print(os); - os.flush(); - - std::string lower = StringRef(typeText).lower(); - if (StringRef(lower).contains("e4m3")) - return "e4m3"; - if (StringRef(lower).contains("e5m2")) - return "e5m2"; - if (StringRef(lower).contains("hif4")) - return "hif4"; - if (StringRef(lower).contains("e2m1x2")) - return "e2m1x2"; - if (StringRef(lower).contains("e1m2x2")) - return "e1m2x2"; - return {}; -} - -static std::string buildMadMxIntrinsicName(Type lhsType, Type rhsType) { - std::string lhs = getMadMxElementFragment(lhsType); - std::string rhs = getMadMxElementFragment(rhsType); - if (lhs.empty() || rhs.empty()) - return {}; - return "llvm.hivm.MMAD.MX." + lhs + rhs; -} - -static std::string getMadRhsFragment(Type type) { - if (type.isF16()) - return "f16"; - if (type.isBF16()) - return "bf16"; - if (type.isF32()) - return "f32"; - if (auto intType = dyn_cast(type)) { - if (intType.isSigned() && intType.getWidth() == 4) - return "s4"; - if (intType.isSigned() && intType.getWidth() == 8) - return "s8"; - if (intType.isUnsigned() && intType.getWidth() == 2) - return "u2"; - } - - std::string typeText; - llvm::raw_string_ostream os(typeText); - type.print(os); - os.flush(); - std::string lower = StringRef(typeText).lower(); - if (StringRef(lower).contains("e8m0")) - return "e8m0"; - return {}; -} - -static std::string getMadDstFragment(Type type) { - if (type.isF16()) - return "f16"; - if (type.isF32()) - return "f32"; - if (auto intType = dyn_cast(type)) { - if (intType.isSigned() && intType.getWidth() == 32) - return "s32"; - } - return {}; -} - -static std::string buildMadIntrinsicName(Type lhsType, Type rhsType, - Type dstType) { - std::string rhs = getMadRhsFragment(rhsType); - std::string dst = getMadDstFragment(dstType); - if (lhsType.isF16() && rhs == "f16" && dst == "f32") - return "llvm.hivm.MAD.f162f32.c310"; - if (lhsType.isF16() && rhs == "f16" && dst == "f16") - return "llvm.hivm.MAD.f162f16"; - if (lhsType.isF16() && rhs == "f16" && dst == "s32") - return "llvm.hivm.MAD.f162s32.1952"; - if (lhsType.isBF16() && rhs == "bf16" && dst == "f32") - return "llvm.hivm.MAD.bf162f32.c310"; - if (lhsType.isF32() && rhs == "f32" && dst == "f32") - return "llvm.hivm.MAD.f322f32.c310"; - if (lhsType.isF16() && rhs == "s4") - return "llvm.hivm.MAD.f16s4.c310"; - if (lhsType.isF16() && rhs == "s8") - return "llvm.hivm.MAD.f16s8.c310"; - if (lhsType.isF16() && rhs == "u2") - return "llvm.hivm.MAD.f16u2"; - if (lhsType.isF16() && rhs == "e8m0") - return "llvm.hivm.MAD.f16e8m0.c310"; - return {}; -} - -static std::string getOpMnemonic(Operation *op) { - return op->getName().stripDialect().str(); -} - -static IntrinsicSelection makeResolved(Operation *op, llvm::StringRef calleeName, - llvm::ArrayRef usedFields, - llvm::StringRef resultTypeFragment) { - IntrinsicSelection selection; - selection.resolved = true; - selection.sourceOpName = op->getName().getStringRef().str(); - selection.calleeName = calleeName.str(); - selection.usedFields.assign(usedFields.begin(), usedFields.end()); - selection.resultTypeFragment = resultTypeFragment.str(); - selection.location = getLocationString(op->getLoc()); - return selection; -} - -static IntrinsicSelection makeUnresolved(Operation *op, - llvm::StringRef familyOrOp, - llvm::StringRef candidateName, - llvm::ArrayRef usedFields, - llvm::ArrayRef missingFields, - llvm::StringRef resultTypeFragment) { - IntrinsicSelection selection; - selection.resolved = false; - selection.sourceOpName = op->getName().getStringRef().str(); - selection.candidateName = candidateName.str(); - selection.usedFields.assign(usedFields.begin(), usedFields.end()); - selection.missingFields.assign(missingFields.begin(), missingFields.end()); - selection.resultTypeFragment = resultTypeFragment.str(); - selection.location = getLocationString(op->getLoc()); - - std::string name = "__ptoas_hivm_unresolved."; - name += sanitizeNameFragment(familyOrOp); - if (!resultTypeFragment.empty()) { - name += "."; - name += sanitizeNameFragment(resultTypeFragment); - } - selection.placeholderName = std::move(name); - return selection; -} - -static StringRef getMemBarIntrinsicName(MemBarKind kind) { - switch (kind) { - case MemBarKind::VV_ALL: - return "llvm.hivm.mem.bar.vv.all"; - case MemBarKind::VST_VLD: - return "llvm.hivm.mem.bar.vst.vld"; - case MemBarKind::VLD_VST: - return "llvm.hivm.mem.bar.vld.vst"; - case MemBarKind::VST_VST: - return "llvm.hivm.mem.bar.vst.vst"; - case MemBarKind::VS_ALL: - return "llvm.hivm.mem.bar.vs.all"; - case MemBarKind::VST_LD: - return "llvm.hivm.mem.bar.vst.ld"; - case MemBarKind::VLD_ST: - return "llvm.hivm.mem.bar.vld.st"; - case MemBarKind::VST_ST: - return "llvm.hivm.mem.bar.vst.st"; - case MemBarKind::SV_ALL: - return "llvm.hivm.mem.bar.sv.all"; - case MemBarKind::ST_VLD: - return "llvm.hivm.mem.bar.st.vld"; - case MemBarKind::LD_VST: - return "llvm.hivm.mem.bar.ld.vst"; - case MemBarKind::ST_VST: - return "llvm.hivm.mem.bar.st.vst"; - case MemBarKind::SS_ALL: - return "llvm.hivm.mem.bar.ss.all"; - case MemBarKind::ST_LD: - return "llvm.hivm.mem.bar.st.ld"; - case MemBarKind::LD_ST: - return "llvm.hivm.mem.bar.ld.st"; - case MemBarKind::ST_ST: - return "llvm.hivm.mem.bar.st.st"; - } - llvm_unreachable("unexpected membar kind"); -} - -static FailureOr selectSyncLike(Operation *op) { - llvm::SmallVector usedFields; - usedFields.push_back("op=" + getOpMnemonic(op)); - - if (auto setFlag = dyn_cast(op)) { - usedFields.push_back("src_pipe=" + printAttrText(setFlag.getSrcPipe())); - usedFields.push_back("dst_pipe=" + printAttrText(setFlag.getDstPipe())); - usedFields.push_back("event=" + printAttrText(setFlag.getEventId())); - return makeResolved(op, "llvm.hivm.SET.FLAG.IMM", usedFields, ""); - } else if (auto waitFlag = dyn_cast(op)) { - usedFields.push_back("src_pipe=" + printAttrText(waitFlag.getSrcPipe())); - usedFields.push_back("dst_pipe=" + printAttrText(waitFlag.getDstPipe())); - usedFields.push_back("event=" + printAttrText(waitFlag.getEventId())); - return makeResolved(op, "llvm.hivm.WAIT.FLAG.IMM", usedFields, ""); - } else if (auto barrier = dyn_cast(op)) { - usedFields.push_back("pipe=" + printAttrText(barrier.getPipe())); - return makeResolved(op, "llvm.hivm.BARRIER", usedFields, ""); - } else if (auto membar = dyn_cast(op)) { - usedFields.push_back("kind=" + printAttrText(membar.getKind())); - return makeResolved(op, getMemBarIntrinsicName(membar.getKind().getKind()), - usedFields, ""); - } - - llvm::SmallVector missingFields = {"confirmed_hivm_name"}; - return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, ""); -} - -static FailureOr selectConfigLike(Operation *op) { - llvm::SmallVector usedFields = {"op=" + getOpMnemonic(op)}; - - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB", usedFields, - ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB", - usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.OUTTOUB", usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOL1", usedFields, - ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOL1", usedFields, - ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.OUTTOL1", usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT", usedFields, - ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT", usedFields, - ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT", usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.MTE2.NZ.PARA", usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.MOV.PAD.VAL", usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.PAD.VAL.OUTTOL1", usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.FPC", usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.ATOMIC.S32", usedFields, ""); - if (isa(op)) - return makeResolved(op, "llvm.hivm.SET.ATOMIC.S8", usedFields, ""); - - llvm::SmallVector missingFields = {"confirmed_hivm_name"}; - return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, - ""); -} - -static FailureOr selectPredicateIntrinsic(Operation *op) { - llvm::SmallVector usedFields; - if (auto pset = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(pset.getResult().getType()); - usedFields = {"family=pset", "bitwidth=8", "result=" + resultFragment, - "pattern=i32"}; - return makeResolved(op, "llvm.hivm.pset.b8", usedFields, resultFragment); - } - if (auto pset = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(pset.getResult().getType()); - usedFields = {"family=pset", "bitwidth=16", "result=" + resultFragment, - "pattern=i32"}; - return makeResolved(op, "llvm.hivm.pset.b16", usedFields, resultFragment); - } - if (auto pset = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(pset.getResult().getType()); - usedFields = {"family=pset", "bitwidth=32", "result=" + resultFragment, - "pattern=i32"}; - return makeResolved(op, "llvm.hivm.pset.b32", usedFields, resultFragment); - } - if (auto pge = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(pge.getResult().getType()); - usedFields = {"family=pge", "bitwidth=8", "result=" + resultFragment, - "pattern=i32", "variant=i32_zero"}; - return makeResolved(op, "llvm.hivm.pge.b8", usedFields, resultFragment); - } - if (auto pge = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(pge.getResult().getType()); - usedFields = {"family=pge", "bitwidth=16", "result=" + resultFragment, - "pattern=i32", "variant=i32_zero"}; - return makeResolved(op, "llvm.hivm.pge.b16", usedFields, resultFragment); - } - if (auto pge = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(pge.getResult().getType()); - usedFields = {"family=pge", "bitwidth=32", "result=" + resultFragment, - "pattern=i32", "variant=i32_zero"}; - return makeResolved(op, "llvm.hivm.pge.b32", usedFields, resultFragment); - } - if (auto plt = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(plt.getMask().getType()); - usedFields = {"family=plt", "bitwidth=8", "result=" + resultFragment, - "variant=v300", "scalar=i32", "scalar_out=i32"}; - return makeResolved(op, "llvm.hivm.plt.b8.v300", usedFields, resultFragment); - } - if (auto plt = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(plt.getMask().getType()); - usedFields = {"family=plt", "bitwidth=16", "result=" + resultFragment, - "variant=v300", "scalar=i32", "scalar_out=i32"}; - return makeResolved(op, "llvm.hivm.plt.b16.v300", usedFields, resultFragment); - } - if (auto plt = dyn_cast(op)) { - const std::string resultFragment = - getVectorTypeFragment(plt.getMask().getType()); - usedFields = {"family=plt", "bitwidth=32", "result=" + resultFragment, - "variant=v300", "scalar=i32", "scalar_out=i32"}; - return makeResolved(op, "llvm.hivm.plt.b32.v300", usedFields, resultFragment); - } - - return failure(); -} - -} // namespace - -FailureOr selectLoadIntrinsic(Operation *op) { - if (auto vlds = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(vlds.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vldsx1", "vector=" + vecFragment, "mode=NO_POST_UPDATE"}; - if (vlds.getDistAttr()) - usedFields.push_back("dist=" + (*vlds.getDist()).str()); - - if (vecFragment == "v64f32") - return makeResolved(op, "llvm.hivm.vldsx1", usedFields, vecFragment); - - llvm::SmallVector missingFields = {"confirmed_hivm_name"}; - std::string candidate = "llvm.hivm.vldsx1"; - return makeUnresolved(op, "vldsx1", candidate, usedFields, missingFields, - vecFragment); - } - - if (auto vldsPost = dyn_cast(op)) { - const std::string vecFragment = - getVectorTypeFragment(vldsPost.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vldsx1", "variant=post", "vector=" + vecFragment, - "mode=POST_UPDATE"}; - if (vldsPost.getDistAttr()) - usedFields.push_back("dist=" + (*vldsPost.getDist()).str()); - - if (vecFragment == "v64f32") - return makeResolved(op, "llvm.hivm.vldsx1.post", usedFields, vecFragment); - - llvm::SmallVector missingFields = {"confirmed_hivm_name"}; - std::string candidate = "llvm.hivm.vldsx1.post"; - return makeUnresolved(op, "vldsx1.post", candidate, usedFields, - missingFields, vecFragment); - } - - return failure(); -} - -FailureOr selectUnaryIntrinsic(Operation *op) { - auto vabs = dyn_cast(op); - if (vabs) { - const std::string vecFragment = getVectorTypeFragment(vabs.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vabs", "vector=" + vecFragment, "variant=x"}; - - if (vecFragment == "v64f32") - return makeResolved(op, "llvm.hivm.vabs.v64f32.x", usedFields, vecFragment); - - llvm::SmallVector missingFields = {"confirmed_hivm_name"}; - std::string candidate = "llvm.hivm.vabs"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeUnresolved(op, "vabs", candidate, usedFields, missingFields, - vecFragment); - } - - if (auto vexp = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(vexp.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vexp", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vexp"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto vdup = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(vdup.getResult().getType()); - const bool vectorInput = isa(vdup.getInput().getType()); - const StringRef position = vdup.getPosition().value_or("LOWEST"); - const char *family = - vectorInput ? (position == "HIGHEST" ? "vdupm" : "vdup") : "vdups"; - llvm::SmallVector usedFields = { - "family=" + std::string(family), "vector=" + vecFragment, - "variant=z"}; - if (!vectorInput && !isa(vdup.getInput().getType())) { - llvm::SmallVector missingFields = {"scalar_input_vdup_mapping"}; - return makeUnresolved(op, "vdup", "llvm.hivm.vdups", usedFields, missingFields, - vecFragment); - } - std::string candidate = "llvm.hivm."; - candidate += family; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".z"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vadd", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vadd"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vsub", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vsub"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vmul", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vmul"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vmax", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vmax"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vmuls", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vmuls"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vadds", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vadds"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vmaxs", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vmaxs"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vmins", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vmins"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vlrelu", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vlrelu"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vshls", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vshls"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - if (auto binary = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); - llvm::SmallVector usedFields = { - "family=vshrs", "vector=" + vecFragment, "variant=x"}; - std::string candidate = "llvm.hivm.vshrs"; - if (!vecFragment.empty()) - candidate += "." + vecFragment + ".x"; - return makeResolved(op, candidate, usedFields, vecFragment); - } - - return failure(); -} - -FailureOr selectStoreIntrinsic(Operation *op) { - llvm::SmallVector usedFields; - llvm::SmallVector missingFields = {"confirmed_hivm_name"}; - - if (auto vsts = dyn_cast(op)) { - const std::string vecFragment = getVectorTypeFragment(vsts.getValue().getType()); - usedFields = {"family=vstsx1", "vector=" + vecFragment, - "predicate_source=explicit_mask", "mode=NO_POST_UPDATE"}; - if (vsts.getDistAttr()) - usedFields.push_back("dist=" + (*vsts.getDist()).str()); - if (vecFragment == "v64f32") - return makeResolved(op, "llvm.hivm.vstsx1", usedFields, vecFragment); - return makeUnresolved(op, "vstsx1", "llvm.hivm.vstsx1", usedFields, missingFields, - vecFragment); - } - - if (auto vstsPost = dyn_cast(op)) { - const std::string vecFragment = - getVectorTypeFragment(vstsPost.getValue().getType()); - usedFields = {"family=vstsx1", "variant=post", "vector=" + vecFragment, - "predicate_source=explicit_mask", "mode=POST_UPDATE"}; - if (vstsPost.getDistAttr()) - usedFields.push_back("dist=" + (*vstsPost.getDist()).str()); - if (vecFragment == "v64f32") - return makeResolved(op, "llvm.hivm.vstsx1.post", usedFields, - vecFragment); - std::string candidate = "llvm.hivm.vstsx1.post"; - return makeUnresolved(op, "vstsx1.post", candidate, usedFields, - missingFields, vecFragment); - } - - if (auto copy = dyn_cast(op)) { - std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); - usedFields = {"family=copy_gm_to_ubuf"}; - if (!elemFragment.empty()) - usedFields.push_back("element=" + elemFragment); - if (elemFragment == "u8" || elemFragment == "u16" || - elemFragment == "u32" || elemFragment == "f32") { - std::string callee = "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2."; - callee += elemFragment; - callee += ".DV"; - return makeResolved(op, callee, usedFields, ""); - } - std::string candidate = "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2"; - if (!elemFragment.empty()) - candidate += "." + elemFragment + ".DV"; - missingFields.push_back("element_type_mapping"); - return makeUnresolved(op, "copy_gm_to_ubuf", candidate, usedFields, - missingFields, ""); - } - - if (auto copy = dyn_cast(op)) { - usedFields = {"family=copy_gm_to_cbuf_multi_nd2nz"}; - return makeResolved(op, "llvm.hivm.MOV.OUT.TO.L1.MULTI.ND2NZ", usedFields, - ""); - } - - if (auto copy = dyn_cast(op)) { - usedFields = {"family=copy_gm_to_cbuf_multi_dn2nz"}; - return makeResolved(op, "llvm.hivm.MOV.OUT.TO.L1.MULTI.DN2NZ", usedFields, - ""); - } - - if (auto matmul = dyn_cast(op)) { - std::string lhsElem = getElementTypeFragment( - cast(matmul.getLhs().getType()).getElementType()); - std::string rhsElem = getElementTypeFragment( - cast(matmul.getRhs().getType()).getElementType()); - std::string dstElem = getElementTypeFragment( - cast(matmul.getDst().getType()).getElementType()); - usedFields = {"family=mad", "lhs=" + lhsElem, "rhs=" + rhsElem, - "dst=" + dstElem, "shape=i64xm_n_k"}; - Type lhsType = cast(matmul.getLhs().getType()).getElementType(); - Type rhsType = cast(matmul.getRhs().getType()).getElementType(); - Type dstType = cast(matmul.getDst().getType()).getElementType(); - std::string madName = buildMadIntrinsicName(lhsType, rhsType, dstType); - if (!madName.empty()) - return makeResolved(op, madName, usedFields, ""); - if (isMxElementType(lhsType) && isMxElementType(rhsType)) { - std::string mxName = buildMadMxIntrinsicName(lhsType, rhsType); - if (!mxName.empty()) - return makeResolved(op, mxName, usedFields, ""); - } - missingFields.push_back("lhs/rhs_element_type_mapping"); - return makeUnresolved(op, "mad", "llvm.hivm.MAD/llvm.hivm.MMAD.MX.*", - usedFields, - missingFields, ""); - } - - if (auto matmulMx = dyn_cast(op)) { - std::string lhsElem = getElementTypeFragment( - cast(matmulMx.getLhs().getType()).getElementType()); - std::string rhsElem = getElementTypeFragment( - cast(matmulMx.getRhs().getType()).getElementType()); - std::string dstElem = getElementTypeFragment( - cast(matmulMx.getDst().getType()).getElementType()); - usedFields = {"family=mad_mx", "lhs=" + lhsElem, "rhs=" + rhsElem, - "dst=" + dstElem, "shape=i64xm_n_k"}; - Type lhsType = - cast(matmulMx.getLhs().getType()).getElementType(); - Type rhsType = - cast(matmulMx.getRhs().getType()).getElementType(); - if (isMxElementType(lhsType) && isMxElementType(rhsType)) { - std::string mxName = buildMadMxIntrinsicName(lhsType, rhsType); - if (!mxName.empty()) - return makeResolved(op, mxName, usedFields, ""); - } - missingFields.push_back("lhs/rhs_mx_element_type"); - return makeUnresolved(op, "mad_mx", "llvm.hivm.MMAD.MX.*", usedFields, - missingFields, ""); - } - - if (auto copy = dyn_cast(op)) { - std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); - usedFields = {"family=copy_gm_to_cbuf"}; - if (!elemFragment.empty()) - usedFields.push_back("element=" + elemFragment); - if (elemFragment.empty()) { - missingFields.push_back("element_type_mapping"); - return makeUnresolved(op, "copy_gm_to_cbuf", - "llvm.hivm.MOV.OUT.TO.L1.ALIGN.V2..DV", - usedFields, missingFields, ""); - } - return makeResolved(op, "llvm.hivm.MOV.OUT.TO.L1.ALIGN.V2." + elemFragment + - ".DV", - usedFields, ""); - } - - if (auto load = dyn_cast(op)) { - std::string srcElem = getElementTypeFragment( - cast(load.getSource().getType()).getElementType()); - std::string dstElem = getElementTypeFragment( - cast(load.getDestination().getType()).getElementType()); - usedFields = {"family=load_cbuf_to_ca", "src=" + srcElem, "dst=" + dstElem, - "shape=i64xm_k"}; - if (srcElem.empty()) { - missingFields.push_back("src_element_type_mapping"); - return makeUnresolved(op, "load_cbuf_to_ca", - "llvm.hivm.LOAD.L1.TO.L0A.2Dv2.", usedFields, - missingFields, ""); - } - return makeResolved(op, "llvm.hivm.LOAD.L1.TO.L0A.2Dv2." + srcElem, - usedFields, ""); - } - - if (auto load = dyn_cast(op)) { - std::string srcElem = getElementTypeFragment( - cast(load.getSource().getType()).getElementType()); - std::string dstElem = getElementTypeFragment( - cast(load.getDestination().getType()).getElementType()); - usedFields = {"family=load_cbuf_to_cb", "src=" + srcElem, "dst=" + dstElem, - "shape=i64xk_n"}; - if (srcElem.empty()) { - missingFields.push_back("src_element_type_mapping"); - return makeUnresolved(op, "load_cbuf_to_cb", - "llvm.hivm.LOAD.L1.TO.L0B.2Dv2.", usedFields, - missingFields, ""); - } - return makeResolved(op, "llvm.hivm.LOAD.L1.TO.L0B.2Dv2." + srcElem, - usedFields, ""); - } - - if (auto copy = dyn_cast(op)) { - std::string srcElem = getElementTypeFragment( - cast(copy.getSource().getType()).getElementType()); - std::string dstElem = getElementTypeFragment( - cast(copy.getDestination().getType()).getElementType()); - usedFields = {"family=copy_matrix_cc_to_gm", "src=" + srcElem, - "dst=" + dstElem, "shape=i64xm_n"}; - return makeResolved(op, "llvm.hivm.FIX.L0C.TO.OUT.f32.EXT", usedFields, ""); - } - - if (auto copy = dyn_cast(op)) { - std::string srcElem = getElementTypeFragment( - cast(copy.getSource().getType()).getElementType()); - std::string dstElem = getElementTypeFragment( - cast(copy.getDestination().getType()).getElementType()); - usedFields = {"family=copy_matrix_cc_to_cbuf", "src=" + srcElem, - "dst=" + dstElem}; - return makeResolved(op, "llvm.hivm.FIX.L0C.TO.L1.f32.EXT", usedFields, ""); - } - - if (auto copy = dyn_cast(op)) { - std::string srcElem = getElementTypeFragment( - cast(copy.getSource().getType()).getElementType()); - std::string dstElem = getElementTypeFragment( - cast(copy.getDestination().getType()).getElementType()); - usedFields = {"family=copy_matrix_cc_to_ub", "src=" + srcElem, - "dst=" + dstElem}; - if (dstElem == "f16") - return makeResolved(op, "llvm.hivm.MOV.L0CDPF32.TO.UB.f322f16", - usedFields, ""); - if (dstElem == "f32") - return makeResolved(op, "llvm.hivm.MOV.L0CDPF32.TO.UB.f322f32", - usedFields, ""); - missingFields.push_back("dst_element_type_mapping"); - return makeUnresolved(op, "copy_matrix_cc_to_ub", - "llvm.hivm.MOV.L0CDPF32.TO.UB.f322f{16|32}", - usedFields, missingFields, ""); - } - - if (auto copy = dyn_cast(op)) { - usedFields = {"family=copy_cbuf_to_bt", "src=f16"}; - return makeResolved(op, "llvm.hivm.MOV.L1.TO.BT.f16", usedFields, ""); - } - - if (auto copy = dyn_cast(op)) { - usedFields = {"family=copy_cbuf_to_fbuf"}; - return makeResolved(op, "llvm.hivm.MOV.L1.TO.FB.V2", usedFields, ""); - } - - if (auto load = dyn_cast(op)) { - usedFields = {"family=load_cbuf_to_ca_s4"}; - return makeResolved(op, "llvm.hivm.LOAD.L1.TO.L0A.2Dv2.s4", usedFields, - ""); - } - - if (auto load = dyn_cast(op)) { - usedFields = {"family=load_cbuf_to_cb_s4"}; - return makeResolved(op, "llvm.hivm.LOAD.L1.TO.L0B.2Dv2.s4", usedFields, - ""); - } - - if (auto copy = dyn_cast(op)) { - std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); - usedFields = {"family=copy_ubuf_to_gm"}; - if (!elemFragment.empty()) - usedFields.push_back("element=" + elemFragment); - return makeResolved(op, "llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV", - usedFields, ""); - } - - if (isa(op)) { - usedFields = {"family=copy_ubuf_to_ubuf"}; - return makeResolved(op, "llvm.hivm.MOV.UB.TO.UB.v310", usedFields, ""); - } - - return failure(); -} - -FailureOr selectIntrinsic(Operation *op) { - if (isa(op)) - return selectSyncLike(op); - - if (isa(op)) - return selectConfigLike(op); - - if (succeeded(selectLoadIntrinsic(op))) - return *selectLoadIntrinsic(op); - if (succeeded(selectUnaryIntrinsic(op))) - return *selectUnaryIntrinsic(op); - if (succeeded(selectPredicateIntrinsic(op))) - return *selectPredicateIntrinsic(op); - if (succeeded(selectStoreIntrinsic(op))) - return *selectStoreIntrinsic(op); - - llvm::SmallVector usedFields = {"op=" + getOpMnemonic(op)}; - llvm::SmallVector missingFields = {"family_mapping", - "confirmed_hivm_name"}; - return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, - ""); -} - -} // namespace mlir::pto diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 621fd1098..7ec8d71fc 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -1916,6 +1916,12 @@ static FailureOr buildCopyGmToUbCallee(MLIRContext *context, if (!ptrType) return failure(); Type elementType = ptrType.getElementType(); + if ((isa(elementType) && + cast(elementType).getWidth() == 64) || + elementType.isF64()) { + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2.s32.DV") + .getValue(); + } std::string elem = getCopyElementFragment(elementType); if (elem.empty()) return failure(); diff --git a/lib/TileOps/tcvt_template.py b/lib/TileOps/tcvt_template.py index 19f177a60..13268b306 100644 --- a/lib/TileOps/tcvt_template.py +++ b/lib/TileOps/tcvt_template.py @@ -182,3 +182,43 @@ def template_tcvt_i32_to_f32(src: pto.Tile, dst: pto.Tile): ) pto.vsts(converted, dst[row, col:], mask) return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i64, pto.f32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i64_to_f32(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.i64, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + store_mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.f32, + full_mask, + rnd=rnd, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B64) + return diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/compare.py new file mode 100755 index 000000000..7fd03a0ed --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# case: micro-op/conversion/vcvt-i64-to-f32 +# family: conversion +# target_ops: pto.dma_load, pto.dma_store, pto.vcvt, pto.vsts +# scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32) + ok = ok and compare_bin("golden_v3.bin", "v3.bin", np.int64) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/golden.py new file mode 100755 index 000000000..7abb10ef7 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/golden.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-i64-to-f32 +# family: conversion +# target_ops: pto.dma_load, pto.dma_store, pto.vcvt, pto.vsts +# scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half + +import argparse +from pathlib import Path + +import numpy as np + + +INPUT_ELEMS = 1024 +OUTPUT_ELEMS = 512 +ROUNDTRIP_ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + edge = np.array( + [ + -(1 << 31), + -(1 << 24) - 3, + -(1 << 24) - 1, + -(1 << 24), + -(1 << 24) + 1, + -65537, + -32769, + -32768, + -1, + 0, + 1, + 32767, + 32768, + 65537, + (1 << 24) - 1, + 1 << 24, + (1 << 24) + 1, + (1 << 24) + 3, + (1 << 31) - 2, + (1 << 31) - 1, + ], + dtype=np.int32, + ) + base = rng.integers(np.iinfo(np.int32).min, np.iinfo(np.int32).max, + size=INPUT_ELEMS, dtype=np.int32) + base[: edge.size] = edge + v1 = base.astype(np.int64) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.float32) + v3 = np.zeros(ROUNDTRIP_ELEMS, dtype=np.int64) + golden_v2 = np.concatenate( + [base[offset : offset + 16] for offset in range(0, INPUT_ELEMS, 32)] + ).astype(np.float32) + golden_v3 = v1.copy() + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-i64-to-f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto new file mode 100644 index 000000000..6c8152203 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i64-to-f32 +// family: conversion +// target_ops: pto.dma_load, pto.dma_store, pto.vcvt, pto.vsts +// scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_i64_to_f32_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.dma_load %arg0, %ub_in, %c0_i64, %c256_i64 + nburst(%c32_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.set_flag["PIPE_MTE2", "PIPE_MTE3", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE3", "EVENT_ID1"] + pto.dma_store %ub_in, %arg2, %c256_i64 + nburst(%c32_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %store_offset = %c0 to %c512 step %c16 { + %input_offset = arith.muli %store_offset, %c2 : index + %loaded = pto.vlds %ub_in[%input_offset] : !pto.ptr -> !pto.vreg<32xsi64> + %converted = pto.vcvt %loaded {rnd = "R", part = "EVEN"} : !pto.vreg<32xsi64> -> !pto.vreg<64xf32> + pto.vsts %converted, %ub_out[%store_offset], %mask {dist = "PK_B64"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.dma_store %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/launch.cpp new file mode 100644 index 000000000..6d6ef0dbf --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i64-to-f32 +// family: conversion +// target_ops: pto.dma_load, pto.dma_store, pto.vcvt, pto.vsts +// scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_i64_to_f32_kernel( + __gm__ int64_t *v1, __gm__ float *v2, __gm__ int64_t *v3); + +void LaunchVcvt_i64_to_f32_kernel(int64_t *v1, float *v2, int64_t *v3, + void *stream) { + vcvt_i64_to_f32_kernel<<<1, nullptr, stream>>>( + (__gm__ int64_t *)v1, (__gm__ float *)v2, (__gm__ int64_t *)v3); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/main.cpp new file mode 100644 index 000000000..194361b3e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/main.cpp @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i64-to-f32 +// family: conversion +// target_ops: pto.dma_load, pto.dma_store, pto.vcvt, pto.vsts +// scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half +// ----------------------------------------------------------------------------- + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_i64_to_f32_kernel(int64_t *v1, float *v2, int64_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int64_t); + size_t elemCount_v2 = 512; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int64_t); + int64_t *v1Host = nullptr; + int64_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int64_t *v3Host = nullptr; + int64_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_i64_to_f32_kernel(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/stub.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/stub.cpp new file mode 100644 index 000000000..faaf00df8 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/stub.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i64-to-f32 +// family: conversion +// target_ops: pto.dma_load, pto.dma_store, pto.vcvt, pto.vsts +// scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half +// ----------------------------------------------------------------------------- + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void vcvt_i64_to_f32_kernel( + __gm__ int64_t *v1, __gm__ float *v2, __gm__ int64_t *v3) { + (void)v1; + (void)v2; + (void)v3; +} diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index e2bb6d134..3eb5d43a7 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -286,11 +286,6 @@ static llvm::cl::opt ptoSeamIRFile( llvm::cl::value_desc("path"), llvm::cl::init("")); -static llvm::cl::opt vptoPrintIntrinsics( - "vpto-print-intrinsics", - llvm::cl::desc("Print VPTO intrinsic selection decisions to stderr"), - llvm::cl::init(false)); - static llvm::cl::opt vptoEmitHIVMOfficialLLVM( "vpto-emit-hivm-llvm", llvm::cl::desc("After lowering to VPTO IR, emit textual LLVM/HIVM via " @@ -303,16 +298,6 @@ static llvm::cl::opt vptoEmitHIVMOfficialBitcode( "official LLVM dialect export path"), llvm::cl::init(false)); -static llvm::cl::opt vptoAllowUnresolved( - "vpto-allow-unresolved", - llvm::cl::desc("Emit explicit unresolved VPTO comments instead of failing"), - llvm::cl::init(false)); - -static llvm::cl::opt vptoUnresolvedReport( - "vpto-unresolved-report", - llvm::cl::desc("Write unresolved VPTO mappings to a sidecar report"), - llvm::cl::value_desc("path"), llvm::cl::init("")); - static llvm::cl::opt vptoMarch( "vpto-march", llvm::cl::desc("Bisheng -march for VPTO HIVM LLVM emission (default: " @@ -325,12 +310,6 @@ static llvm::cl::opt vptoCceAicoreArch( "queries (default: same as --vpto-march)."), llvm::cl::value_desc("arch"), llvm::cl::init("")); -static llvm::cl::opt hivmUnresolvedReport( - "hivm-unresolved-report", - llvm::cl::desc("Write unresolved HIVM mappings to a sidecar report"), - llvm::cl::value_desc("path"), - llvm::cl::init("")); - enum class PTOBuildLevel { Level1, Level2, @@ -1225,10 +1204,6 @@ static LogicalResult inlineTilelangHelpersOnVPTOInput(ModuleOp module) { static pto::VPTOEmissionOptions buildVPTOEmissionOptions() { pto::VPTOEmissionOptions options; options.dumpVPTOIR = false; - options.printIntrinsicSelections = vptoPrintIntrinsics; - options.allowUnresolved = vptoAllowUnresolved; - options.unresolvedReportPath = - !hivmUnresolvedReport.empty() ? hivmUnresolvedReport : vptoUnresolvedReport; options.targetTriple = "hiipu64-hisilicon-cce"; const std::string kVecMarch = "dav-c310-vec"; @@ -1358,8 +1333,6 @@ int main(int argc, char **argv) { if (effectiveBackend != PTOBackend::VPTO && (vptoEmitHIVMOfficialLLVM || vptoEmitHIVMOfficialBitcode || emitVPTO || - vptoPrintIntrinsics || vptoAllowUnresolved || - !vptoUnresolvedReport.empty() || !hivmUnresolvedReport.empty() || ptoPrintSeamIR || !ptoSeamIRFile.empty())) { llvm::errs() << "Error: VPTO-specific flags require " "--pto-backend=vpto.\n"; From 0d03772339b3ad7f639ecab8dbc28b3fb582bcbe Mon Sep 17 00:00:00 2001 From: bingmeiyou <39639340+bingmeiyou@users.noreply.github.com> Date: Mon, 27 Apr 2026 14:13:08 +0800 Subject: [PATCH 180/192] add tcolexpand series op (#169) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add TileOps templates and basic test cases for tcolexpand operations * test: add tcolexpand operators test cases * fix: 添加TODO说明tcolexpanddiv需要高精度版本 * feat: fp32使用vexpdif实现tcolexpandexpdif,fp16使用vsub+vexp * fix: add PR386 license headers to template and test files - Add license headers to 7 tcolexpand*_template.py files - Add license headers to test case files (CMakeLists.txt, compare.py, launch.cpp, main.cpp, gen_data.py, cases.py) * feat: register tcolexpand operators in CMakeLists.txt * fix: replace aclFloat16 with uint16_t in tcolexpand test cases - Replace aclFloat16 with uint16_t in main.cpp and launch.cpp (16 files) - Remove duplicate license headers in 5 launch.cpp files - Fix .pto comments: aclFloat16 -> fp16 - Remove unnecessary #include "acl/acl.h" from launch.cpp files - Align with tpartmax implementation pattern --------- Co-authored-by: User --- lib/TileOps/tcolexpand_template.py | 29 + lib/TileOps/tcolexpandadd_template.py | 31 + lib/TileOps/tcolexpanddiv_template.py | 32 + lib/TileOps/tcolexpandexpdif_template.py | 57 ++ lib/TileOps/tcolexpandmax_template.py | 31 + lib/TileOps/tcolexpandmin_template.py | 31 + lib/TileOps/tcolexpandmul_template.py | 31 + lib/TileOps/tcolexpandsub_template.py | 31 + test/basic/tcolexpand_tilelang.pto | 33 + test/basic/tcolexpandadd_tilelang.pto | 39 ++ test/basic/tcolexpanddiv_tilelang.pto | 39 ++ test/basic/tcolexpandexpdif_tilelang.pto | 40 ++ test/basic/tcolexpandmax_tilelang.pto | 39 ++ test/basic/tcolexpandmin_tilelang.pto | 39 ++ test/basic/tcolexpandmul_tilelang.pto | 39 ++ test/basic/tcolexpandsub_tilelang.pto | 39 ++ .../npu/a5/src/st/testcase/CMakeLists.txt | 8 + .../src/st/testcase/tcolexpand/CMakeLists.txt | 16 + .../a5/src/st/testcase/tcolexpand/cases.py | 75 ++ .../a5/src/st/testcase/tcolexpand/compare.py | 56 ++ .../a5/src/st/testcase/tcolexpand/gen_data.py | 34 + .../a5/src/st/testcase/tcolexpand/launch.cpp | 55 ++ .../a5/src/st/testcase/tcolexpand/main.cpp | 143 ++++ .../src/st/testcase/tcolexpand/tcolexpand.pto | 327 +++++++++ .../st/testcase/tcolexpandadd/CMakeLists.txt | 9 + .../a5/src/st/testcase/tcolexpandadd/cases.py | 77 +++ .../src/st/testcase/tcolexpandadd/compare.py | 56 ++ .../src/st/testcase/tcolexpandadd/gen_data.py | 38 + .../src/st/testcase/tcolexpandadd/launch.cpp | 54 ++ .../a5/src/st/testcase/tcolexpandadd/main.cpp | 158 +++++ .../testcase/tcolexpandadd/tcolexpandadd.pto | 440 ++++++++++++ .../st/testcase/tcolexpanddiv/CMakeLists.txt | 9 + .../a5/src/st/testcase/tcolexpanddiv/cases.py | 115 +++ .../src/st/testcase/tcolexpanddiv/compare.py | 56 ++ .../src/st/testcase/tcolexpanddiv/gen_data.py | 38 + .../src/st/testcase/tcolexpanddiv/launch.cpp | 75 ++ .../a5/src/st/testcase/tcolexpanddiv/main.cpp | 167 +++++ .../testcase/tcolexpanddiv/tcolexpanddiv.pto | 652 ++++++++++++++++++ .../testcase/tcolexpandexpdif/CMakeLists.txt | 9 + .../src/st/testcase/tcolexpandexpdif/cases.py | 67 ++ .../st/testcase/tcolexpandexpdif/compare.py | 56 ++ .../st/testcase/tcolexpandexpdif/gen_data.py | 39 ++ .../st/testcase/tcolexpandexpdif/launch.cpp | 41 ++ .../src/st/testcase/tcolexpandexpdif/main.cpp | 157 +++++ .../tcolexpandexpdif/tcolexpandexpdif.pto | 296 ++++++++ .../st/testcase/tcolexpandmax/CMakeLists.txt | 16 + .../a5/src/st/testcase/tcolexpandmax/cases.py | 91 +++ .../src/st/testcase/tcolexpandmax/compare.py | 56 ++ .../src/st/testcase/tcolexpandmax/gen_data.py | 41 ++ .../src/st/testcase/tcolexpandmax/launch.cpp | 54 ++ .../a5/src/st/testcase/tcolexpandmax/main.cpp | 167 +++++ .../testcase/tcolexpandmax/tcolexpandmax.pto | 438 ++++++++++++ .../st/testcase/tcolexpandmin/CMakeLists.txt | 16 + .../a5/src/st/testcase/tcolexpandmin/cases.py | 97 +++ .../src/st/testcase/tcolexpandmin/compare.py | 56 ++ .../src/st/testcase/tcolexpandmin/gen_data.py | 42 ++ .../src/st/testcase/tcolexpandmin/launch.cpp | 54 ++ .../a5/src/st/testcase/tcolexpandmin/main.cpp | 167 +++++ .../testcase/tcolexpandmin/tcolexpandmin.pto | 437 ++++++++++++ .../st/testcase/tcolexpandmul/CMakeLists.txt | 16 + .../a5/src/st/testcase/tcolexpandmul/cases.py | 84 +++ .../src/st/testcase/tcolexpandmul/compare.py | 56 ++ .../src/st/testcase/tcolexpandmul/gen_data.py | 42 ++ .../src/st/testcase/tcolexpandmul/launch.cpp | 54 ++ .../a5/src/st/testcase/tcolexpandmul/main.cpp | 165 +++++ .../testcase/tcolexpandmul/tcolexpandmul.pto | 436 ++++++++++++ .../st/testcase/tcolexpandsub/CMakeLists.txt | 16 + .../a5/src/st/testcase/tcolexpandsub/cases.py | 84 +++ .../src/st/testcase/tcolexpandsub/compare.py | 56 ++ .../src/st/testcase/tcolexpandsub/gen_data.py | 35 + .../src/st/testcase/tcolexpandsub/launch.cpp | 54 ++ .../a5/src/st/testcase/tcolexpandsub/main.cpp | 167 +++++ .../testcase/tcolexpandsub/tcolexpandsub.pto | 439 ++++++++++++ 73 files changed, 7339 insertions(+) create mode 100644 lib/TileOps/tcolexpand_template.py create mode 100644 lib/TileOps/tcolexpandadd_template.py create mode 100644 lib/TileOps/tcolexpanddiv_template.py create mode 100644 lib/TileOps/tcolexpandexpdif_template.py create mode 100644 lib/TileOps/tcolexpandmax_template.py create mode 100644 lib/TileOps/tcolexpandmin_template.py create mode 100644 lib/TileOps/tcolexpandmul_template.py create mode 100644 lib/TileOps/tcolexpandsub_template.py create mode 100644 test/basic/tcolexpand_tilelang.pto create mode 100644 test/basic/tcolexpandadd_tilelang.pto create mode 100644 test/basic/tcolexpanddiv_tilelang.pto create mode 100644 test/basic/tcolexpandexpdif_tilelang.pto create mode 100644 test/basic/tcolexpandmax_tilelang.pto create mode 100644 test/basic/tcolexpandmin_tilelang.pto create mode 100644 test/basic/tcolexpandmul_tilelang.pto create mode 100644 test/basic/tcolexpandsub_tilelang.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/tcolexpand.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/tcolexpandadd.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/tcolexpanddiv.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/tcolexpandexpdif.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/tcolexpandmax.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/tcolexpandmin.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/tcolexpandmul.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/tcolexpandsub.pto diff --git a/lib/TileOps/tcolexpand_template.py b/lib/TileOps/tcolexpand_template.py new file mode 100644 index 000000000..5d20fcbe1 --- /dev/null +++ b/lib/TileOps/tcolexpand_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpand""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpand" +) +def template_tcolexpand(src0: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[0, col:]) + pto.vsts(lhs, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandadd_template.py b/lib/TileOps/tcolexpandadd_template.py new file mode 100644 index 000000000..287be93dc --- /dev/null +++ b/lib/TileOps/tcolexpandadd_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandadd""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandadd" +) +def template_tcolexpandadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpanddiv_template.py b/lib/TileOps/tcolexpanddiv_template.py new file mode 100644 index 000000000..37ade4ca0 --- /dev/null +++ b/lib/TileOps/tcolexpanddiv_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpanddiv""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpanddiv" +) +def template_tcolexpanddiv(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + # TODO: 当前使用普通精度版本,后续需要添加高精度版本(vdivh) + result = pto.vdiv(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandexpdif_template.py b/lib/TileOps/tcolexpandexpdif_template.py new file mode 100644 index 000000000..977fa690c --- /dev/null +++ b/lib/TileOps/tcolexpandexpdif_template.py @@ -0,0 +1,57 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandexpdif""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandexpdif", + dtypes=[ + (pto.f16, pto.f16, pto.f16), + ], +) +def template_tcolexpandexpdif_f16(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + diff = pto.vsub(lhs, rhs, mask) + result = pto.vexp(diff, mask) + pto.vsts(result, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandexpdif", + dtypes=[ + (pto.f32, pto.f32, pto.f32), + ], +) +def template_tcolexpandexpdif_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vexpdif(lhs, rhs, pto.VcvtPartMode.ODD) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tcolexpandmax_template.py b/lib/TileOps/tcolexpandmax_template.py new file mode 100644 index 000000000..79f3699b7 --- /dev/null +++ b/lib/TileOps/tcolexpandmax_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandmax" +) +def template_tcolexpandmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vmax(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandmin_template.py b/lib/TileOps/tcolexpandmin_template.py new file mode 100644 index 000000000..054b35dfc --- /dev/null +++ b/lib/TileOps/tcolexpandmin_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandmin" +) +def template_tcolexpandmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vmin(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandmul_template.py b/lib/TileOps/tcolexpandmul_template.py new file mode 100644 index 000000000..5dcacfa91 --- /dev/null +++ b/lib/TileOps/tcolexpandmul_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandmul""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandmul" +) +def template_tcolexpandmul(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vmul(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandsub_template.py b/lib/TileOps/tcolexpandsub_template.py new file mode 100644 index 000000000..f46bbaf11 --- /dev/null +++ b/lib/TileOps/tcolexpandsub_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandsub""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandsub" +) +def template_tcolexpandsub(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vsub(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/test/basic/tcolexpand_tilelang.pto b/test/basic/tcolexpand_tilelang.pto new file mode 100644 index 000000000..0d0f4f74a --- /dev/null +++ b/test/basic/tcolexpand_tilelang.pto @@ -0,0 +1,33 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpand via the default TileLang Python DSL template +// lib/TileOps/tcolexpand_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpand should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPAND +// CHECK-NOT: pto.tcolexpand ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsts + +module { + func.func @TCOLEXPAND() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpand ins(%src0 : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/tcolexpandadd_tilelang.pto b/test/basic/tcolexpandadd_tilelang.pto new file mode 100644 index 000000000..2a558dc1f --- /dev/null +++ b/test/basic/tcolexpandadd_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandadd via the default TileLang Python DSL template +// lib/TileOps/tcolexpandadd_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandadd should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDADD +// CHECK-NOT: pto.tcolexpandadd ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module { + func.func @TCOLEXPANDADD() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/tcolexpanddiv_tilelang.pto b/test/basic/tcolexpanddiv_tilelang.pto new file mode 100644 index 000000000..2773728aa --- /dev/null +++ b/test/basic/tcolexpanddiv_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpanddiv via the default TileLang Python DSL template +// lib/TileOps/tcolexpanddiv_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpanddiv should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDDIV +// CHECK-NOT: pto.tcolexpanddiv ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vsts + +module { + func.func @TCOLEXPANDDIV() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/tcolexpandexpdif_tilelang.pto b/test/basic/tcolexpandexpdif_tilelang.pto new file mode 100644 index 000000000..9ef49ad90 --- /dev/null +++ b/test/basic/tcolexpandexpdif_tilelang.pto @@ -0,0 +1,40 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandexpdif via the default TileLang Python DSL template +// lib/TileOps/tcolexpandexpdif_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandexpdif should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDEXPDIF +// CHECK-NOT: pto.tcolexpandexpdif ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsub +// CHECK: pto.vexp +// CHECK: pto.vsts + +module { + func.func @TCOLEXPANDEXPDIF() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandexpdif ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/tcolexpandmax_tilelang.pto b/test/basic/tcolexpandmax_tilelang.pto new file mode 100644 index 000000000..5485bfc73 --- /dev/null +++ b/test/basic/tcolexpandmax_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandmax via the default TileLang Python DSL template +// lib/TileOps/tcolexpandmax_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDMAX +// CHECK-NOT: pto.tcolexpandmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmax +// CHECK: pto.vsts + +module { + func.func @TCOLEXPANDMAX() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/tcolexpandmin_tilelang.pto b/test/basic/tcolexpandmin_tilelang.pto new file mode 100644 index 000000000..ce2c44853 --- /dev/null +++ b/test/basic/tcolexpandmin_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandmin via the default TileLang Python DSL template +// lib/TileOps/tcolexpandmin_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDMIN +// CHECK-NOT: pto.tcolexpandmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmin +// CHECK: pto.vsts + +module { + func.func @TCOLEXPANDMIN() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandmin ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/tcolexpandmul_tilelang.pto b/test/basic/tcolexpandmul_tilelang.pto new file mode 100644 index 000000000..4e185b6bc --- /dev/null +++ b/test/basic/tcolexpandmul_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandmul via the default TileLang Python DSL template +// lib/TileOps/tcolexpandmul_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandmul should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDMUL +// CHECK-NOT: pto.tcolexpandmul ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vsts + +module { + func.func @TCOLEXPANDMUL() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/tcolexpandsub_tilelang.pto b/test/basic/tcolexpandsub_tilelang.pto new file mode 100644 index 000000000..6ebd82dd5 --- /dev/null +++ b/test/basic/tcolexpandsub_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandsub via the default TileLang Python DSL template +// lib/TileOps/tcolexpandsub_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandsub should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDSUB +// CHECK-NOT: pto.tcolexpandsub ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsub +// CHECK: pto.vsts + +module { + func.func @TCOLEXPANDSUB() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index f4c4eaa96..b443ac5ea 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -140,6 +140,14 @@ set(ALL_TESTCASES tcolmin tcolsum tcolprod + tcolexpand + tcolexpandadd + tcolexpandsub + tcolexpandmul + tcolexpanddiv + tcolexpandmax + tcolexpandmin + tcolexpandexpdif softmax tabs texp diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/CMakeLists.txt new file mode 100644 index 000000000..f0c992931 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpand) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/cases.py new file mode 100644 index 000000000..2ea12b560 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/cases.py @@ -0,0 +1,75 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpand ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpand/ + +TCOLEXPAND: expand src first row to dst all rows by broadcasting. + - src_shape: (src_row, cols) - input tile (only first row is used for broadcast) + - dst_shape: (dst_row, cols) - expanded output + - shape: (dst_row, cols) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + +Case naming: {dtype}_{src_row}_{dst_row}_{cols}_{valid_col} +""" + +import numpy as np + +CASES = [ + { + "name": "half_1_16_512_512", + "dtype": np.float16, + "src_shape": (1, 512), + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-3, + }, + { + "name": "int8_2_32_256_255", + "dtype": np.int8, + "src_shape": (2, 256), + "shape": (32, 256), + "valid_shape": (32, 255), + "eps": 0, + }, + { + "name": "float_1_8_128_63", + "dtype": np.float32, + "src_shape": (1, 128), + "shape": (8, 128), + "valid_shape": (8, 63), + "eps": 1e-6, + }, + { + "name": "half_1_33_512_512", + "dtype": np.float16, + "src_shape": (1, 512), + "shape": (33, 512), + "valid_shape": (33, 512), + "eps": 1e-3, + }, + { + "name": "int8_2_17_256_44", + "dtype": np.int8, + "src_shape": (2, 256), + "shape": (17, 256), + "valid_shape": (17, 44), + "eps": 0, + }, + { + "name": "float_1_54_64_63", + "dtype": np.float32, + "src_shape": (1, 64), + "shape": (54, 64), + "valid_shape": (54, 63), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/gen_data.py new file mode 100644 index 000000000..7f727226d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/gen_data.py @@ -0,0 +1,34 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src_shape = case["src_shape"] + dst_shape = case["shape"] + valid_shape = case["valid_shape"] + + src = np.random.randint(1, 10, size=src_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + valid_row, valid_col = valid_shape + for i in range(valid_row): + golden[i, :valid_col] = src[0, :valid_col] + + save_case_data(case["name"], {"input0": src, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src={src_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/launch.cpp new file mode 100644 index 000000000..c22f39010 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: half_1_16_512_512 +extern "C" __global__ AICORE void TCOLEXPAND_half_1_16_512_512(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTCOLEXPAND_half_1_16_512_512(uint16_t *src, uint16_t *dst, void *stream) { + TCOLEXPAND_half_1_16_512_512<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// Case 2: int8_2_32_256_255 +extern "C" __global__ AICORE void TCOLEXPAND_int8_2_32_256_255(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTCOLEXPAND_int8_2_32_256_255(int8_t *src, int8_t *dst, void *stream) { + TCOLEXPAND_int8_2_32_256_255<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} + +// Case 3: float_1_8_128_63 +extern "C" __global__ AICORE void TCOLEXPAND_float_1_8_128_63(__gm__ float *src, __gm__ float *dst); + +void LaunchTCOLEXPAND_float_1_8_128_63(float *src, float *dst, void *stream) { + TCOLEXPAND_float_1_8_128_63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case 4: half_1_33_512_512 +extern "C" __global__ AICORE void TCOLEXPAND_half_1_33_512_512(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTCOLEXPAND_half_1_33_512_512(uint16_t *src, uint16_t *dst, void *stream) { + TCOLEXPAND_half_1_33_512_512<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// Case 5: int8_2_17_256_44 +extern "C" __global__ AICORE void TCOLEXPAND_int8_2_17_256_44(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTCOLEXPAND_int8_2_17_256_44(int8_t *src, int8_t *dst, void *stream) { + TCOLEXPAND_int8_2_17_256_44<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} + +// Case 6: float_1_54_64_63 +extern "C" __global__ AICORE void TCOLEXPAND_float_1_54_64_63(__gm__ float *src, __gm__ float *dst); + +void LaunchTCOLEXPAND_float_1_54_64_63(float *src, float *dst, void *stream) { + TCOLEXPAND_float_1_54_64_63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/main.cpp new file mode 100644 index 000000000..717b68094 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/main.cpp @@ -0,0 +1,143 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpand ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpand/ +// TCOLEXPAND: expand src first row to dst all rows by broadcasting + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPAND_half_1_16_512_512(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTCOLEXPAND_int8_2_32_256_255(int8_t *src, int8_t *dst, void *stream); +void LaunchTCOLEXPAND_float_1_8_128_63(float *src, float *dst, void *stream); +void LaunchTCOLEXPAND_half_1_33_512_512(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTCOLEXPAND_int8_2_17_256_44(int8_t *src, int8_t *dst, void *stream); +void LaunchTCOLEXPAND_float_1_54_64_63(float *src, float *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; + size_t srcCols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"half_1_16_512_512", (LaunchFn)LaunchTCOLEXPAND_half_1_16_512_512, 1, 512, 16, 512, 16, 512, sizeof(uint16_t)}, + {"int8_2_32_256_255", (LaunchFn)LaunchTCOLEXPAND_int8_2_32_256_255, 2, 256, 32, 256, 32, 255, sizeof(int8_t)}, + {"float_1_8_128_63", (LaunchFn)LaunchTCOLEXPAND_float_1_8_128_63, 1, 128, 8, 128, 8, 63, sizeof(float)}, + {"half_1_33_512_512", (LaunchFn)LaunchTCOLEXPAND_half_1_33_512_512, 1, 512, 33, 512, 33, 512, sizeof(uint16_t)}, + {"int8_2_17_256_44", (LaunchFn)LaunchTCOLEXPAND_int8_2_17_256_44, 2, 256, 17, 256, 17, 44, sizeof(int8_t)}, + {"float_1_54_64_63", (LaunchFn)LaunchTCOLEXPAND_float_1_54_64_63, 1, 64, 54, 64, 54, 63, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrcFileSize = srcFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), srcFileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrcFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/tcolexpand.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/tcolexpand.pto new file mode 100644 index 000000000..4716de32c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/tcolexpand.pto @@ -0,0 +1,327 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpand: expand src (src_row x valid_col) to dst (dst_row x valid_col). +// Matches PTO-ISA testcase parameters. +// Key: tile_buf cols = full tensor width, v_col = valid portion + +module { + // Case 1: half_1_16_512_512 (fp16, valid_col=512, cols=512) + func.func @TCOLEXPAND_half_1_16_512_512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf16> -> !pto.partition_tensor_view<1x1x1x16x512xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x512xf16>) + return + } + + // Case 2: int8_2_32_256_255 (int8, cols=256, valid_col=255) + func.func @TCOLEXPAND_int8_2_32_256_255(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c255 = arith.constant 255 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c256], + strides = [%c512, %c512, %c512, %c256, %c1] + : !pto.tensor_view<1x1x1x2x256xi8> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c255] + : !pto.tensor_view<1x1x1x2x256xi8> -> !pto.partition_tensor_view<1x1x1x2x255xi8> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c255] + : !pto.tensor_view<1x1x1x32x256xi8> -> !pto.partition_tensor_view<1x1x1x32x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x255xi8>) + return + } + + // Case 3: float_1_8_128_63 (float32, cols=128, valid_col=63) + func.func @TCOLEXPAND_float_1_8_128_63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c128], + strides = [%c1024, %c1024, %c1024, %c128, %c1] + : !pto.tensor_view<1x1x1x8x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c63] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x63xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c63] + : !pto.tensor_view<1x1x1x8x128xf32> -> !pto.partition_tensor_view<1x1x1x8x63xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x63xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x63xf32>) + return + } + + // Case 4: half_1_33_512_512 (fp16, cols=512, valid_col=512) + func.func @TCOLEXPAND_half_1_33_512_512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c33 = arith.constant 33 : index + %c512 = arith.constant 512 : index + %c16896 = arith.constant 16896 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c33, %c512], + strides = [%c16896, %c16896, %c16896, %c512, %c1] + : !pto.tensor_view<1x1x1x33x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c33, %c512] + : !pto.tensor_view<1x1x1x33x512xf16> -> !pto.partition_tensor_view<1x1x1x33x512xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x33x512xf16>) + return + } + + // Case 5: int8_2_17_256_44 (int8, cols=256, valid_col=44) + func.func @TCOLEXPAND_int8_2_17_256_44(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c17 = arith.constant 17 : index + %c44 = arith.constant 44 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c4352 = arith.constant 4352 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c256], + strides = [%c512, %c512, %c512, %c256, %c1] + : !pto.tensor_view<1x1x1x2x256xi8> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c17, %c256], + strides = [%c4352, %c4352, %c4352, %c256, %c1] + : !pto.tensor_view<1x1x1x17x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c44] + : !pto.tensor_view<1x1x1x2x256xi8> -> !pto.partition_tensor_view<1x1x1x2x44xi8> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c17, %c44] + : !pto.tensor_view<1x1x1x17x256xi8> -> !pto.partition_tensor_view<1x1x1x17x44xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x44xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x17x44xi8>) + return + } + + // Case 6: float_1_54_64_63 (float32, cols=64, valid_col=63) + func.func @TCOLEXPAND_float_1_54_64_63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c54 = arith.constant 54 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c3456 = arith.constant 3456 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c54, %c64], + strides = [%c3456, %c3456, %c3456, %c64, %c1] + : !pto.tensor_view<1x1x1x54x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c63] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x63xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c54, %c63] + : !pto.tensor_view<1x1x1x54x64xf32> -> !pto.partition_tensor_view<1x1x1x54x63xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x63xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x54x63xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/CMakeLists.txt new file mode 100644 index 000000000..7151caba5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandadd) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/cases.py new file mode 100644 index 000000000..fa496319d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/cases.py @@ -0,0 +1,77 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandadd ST test cases. + +TCOLEXPANDADD: expand src1 then add with src0. + - src0_shape: (dst_row, dst_col) - already expanded (src0_shape = shape) + - src1_shape: (src1_row, src1_col) - to be expanded (usually src1_row=1) + - shape: (dst_row, dst_col) - output shape +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_16_128_1_128", + "dtype": np.float32, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "eps": 1e-3, + }, + { + "name": "fp32_32_32_1_32", + "dtype": np.float32, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "fp16_4_256_1_256", + "dtype": np.float16, + "src0_shape": (4, 256), + "src1_shape": (1, 256), + "shape": (4, 256), + "valid_shape": (4, 256), + "eps": 1e-3, + }, + { + "name": "fp16_10_64_1_64", + "dtype": np.float16, + "src0_shape": (10, 64), + "src1_shape": (1, 64), + "shape": (10, 64), + "valid_shape": (10, 64), + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/gen_data.py new file mode 100644 index 000000000..d8556e3a3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + valid_row, valid_col = valid_shape + src1_row, src1_col = src1_shape + reps = dst_shape[0] // src1_row + expanded_src1 = np.tile(src1, (reps, 1))[:, :valid_col] + golden[:valid_row, :valid_col] = (src0[:valid_row, :valid_col] + expanded_src1[:valid_row, :valid_col]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/launch.cpp new file mode 100644 index 000000000..fe60cd7c2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDADD_fp32_16_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDADD_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDADD_fp32_16_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDADD_fp32_32_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDADD_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDADD_fp32_32_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_4_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDADD_fp16_4_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDADD_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDADD_fp16_4_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_10_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDADD_fp16_10_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDADD_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDADD_fp16_10_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDADD_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDADD_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDADD_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDADD_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDADD_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDADD_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/main.cpp new file mode 100644 index 000000000..93a9581fd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/main.cpp @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandadd ST +// TCOLEXPANDADD: src0 + expand(src1) -> dst + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDADD_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDADD_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDADD_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDADD_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDADD_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDADD_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDADD_fp32_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(float)}, + {"fp32_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDADD_fp32_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(float)}, + {"fp16_4_256_1_256", (LaunchFn)LaunchTCOLEXPANDADD_fp16_4_256_1_256, 4, 256, 1, 256, 4, 256, 4, 256, sizeof(uint16_t)}, + {"fp16_10_64_1_64", (LaunchFn)LaunchTCOLEXPANDADD_fp16_10_64_1_64, 10, 64, 1, 64, 10, 64, 10, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDADD_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDADD_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/tcolexpandadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/tcolexpandadd.pto new file mode 100644 index 000000000..6b78a5a49 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/tcolexpandadd.pto @@ -0,0 +1,440 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandadd: expand src1 and add with src0. +// TCOLEXPANDADD: src0 + expand(src1) -> dst +// - src0: (dst_row, dst_col) +// - src1: (src1_row, src1_col), usually src1_row=1 +// - dst: (dst_row, dst_col) + +module { + // Case 1: fp32_16_128_1_128 (float32, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDADD_fp32_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 2: fp32_32_32_1_32 (float32, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDADD_fp32_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 3: fp16_4_256_1_256 (float16, src0=(4,256), src1=(1,256), dst=(4,256)) + func.func @TCOLEXPANDADD_fp16_4_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case 4: fp16_10_64_1_64 (float16, src0=(10,64), src1=(1,64), dst=(10,64)) + func.func @TCOLEXPANDADD_fp16_10_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c64 = arith.constant 64 : index + %c640 = arith.constant 640 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDADD_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDADD_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/CMakeLists.txt new file mode 100644 index 000000000..4fe30b496 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpanddiv) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/cases.py new file mode 100644 index 000000000..c4710cb93 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/cases.py @@ -0,0 +1,115 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpanddiv ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpanddiv/ + +TCOLEXPANDDIV: column-wise broadcast divide - dst[i,j] = src0[i,j] / src1[0,j] + - src0_shape: (src0_row, cols) - dividend input tile + - src1_shape: (1, cols) - divisor input tile (single row, broadcast) + - dst_shape: (dst_row, cols) - output tile + - valid_shape: (valid_row, valid_col) - effective computation region + +Case naming: {dtype}_{src0_row}_{src0_col}_{src1_row}_{src1_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_32_64_1_64", + "dtype": np.float32, + "src0_shape": (32, 64), + "src1_shape": (1, 64), + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + }, + { + "name": "fp32_8_32_1_32", + "dtype": np.float32, + "src0_shape": (8, 32), + "src1_shape": (1, 32), + "shape": (8, 32), + "valid_shape": (8, 32), + "eps": 1e-6, + }, + { + "name": "fp16_16_64_1_64", + "dtype": np.float16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "fp16_4_128_1_128", + "dtype": np.float16, + "src0_shape": (4, 128), + "src1_shape": (1, 128), + "shape": (4, 128), + "valid_shape": (4, 128), + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "fp32_40_32_1_32", + "dtype": np.float32, + "src0_shape": (40, 32), + "src1_shape": (1, 32), + "shape": (40, 32), + "valid_shape": (40, 32), + "eps": 1e-6, + }, + { + "name": "fp16_16_128_1_128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "eps": 1e-3, + }, + { + "name": "fp32_20_64_1_64", + "dtype": np.float32, + "src0_shape": (20, 64), + "src1_shape": (1, 64), + "shape": (20, 64), + "valid_shape": (20, 64), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/gen_data.py new file mode 100644 index 000000000..60cad96b9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.uniform(1.0, 10.0, size=src0_shape).astype(dtype) + src1 = np.random.uniform(1.0, 10.0, size=src1_shape).astype(dtype) + + valid_row, valid_col = valid_shape + reps = dst_shape[0] // src1_shape[0] + + golden = np.zeros(dst_shape, dtype=dtype) + expanded_src1 = np.tile(src1, (reps, 1))[:, :valid_col] + golden[:valid_row, :valid_col] = src0[:valid_row, :valid_col] / expanded_src1 + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/launch.cpp new file mode 100644 index 000000000..c50222329 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/launch.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_32_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp32_32_64_1_64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDDIV_fp32_32_64_1_64(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDDIV_fp32_32_64_1_64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_8_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp32_8_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDDIV_fp32_8_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDDIV_fp32_8_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp16_16_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDDIV_fp16_16_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDDIV_fp16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_4_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp16_4_128_1_128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDDIV_fp16_4_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDDIV_fp16_4_128_1_128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDDIV_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDDIV_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDDIV_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDDIV_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDDIV_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDDIV_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Case: fp32_40_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp32_40_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDDIV_fp32_40_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDDIV_fp32_40_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case: fp16_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp16_16_128_1_128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDDIV_fp16_16_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDDIV_fp16_16_128_1_128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case: fp32_20_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp32_20_64_1_64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDDIV_fp32_20_64_1_64(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDDIV_fp32_20_64_1_64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/main.cpp new file mode 100644 index 000000000..48b46c797 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/main.cpp @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpanddiv ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpanddiv/ +// TCOLEXPANDDIV: column-wise broadcast divide - dst[i,j] = src0[i,j] / src1[0,j] + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDDIV_fp32_32_64_1_64(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp32_8_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp16_16_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp16_4_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp32_40_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp16_16_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp32_20_64_1_64(float *src0, float *src1, float *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_32_64_1_64", (LaunchFn)LaunchTCOLEXPANDDIV_fp32_32_64_1_64, 32, 64, 1, 64, 32, 64, 32, 64, sizeof(float)}, + {"fp32_8_32_1_32", (LaunchFn)LaunchTCOLEXPANDDIV_fp32_8_32_1_32, 8, 32, 1, 32, 8, 32, 8, 32, sizeof(float)}, + {"fp16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDDIV_fp16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"fp16_4_128_1_128", (LaunchFn)LaunchTCOLEXPANDDIV_fp16_4_128_1_128, 4, 128, 1, 128, 4, 128, 4, 128, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDDIV_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDDIV_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, + {"fp32_40_32_1_32", (LaunchFn)LaunchTCOLEXPANDDIV_fp32_40_32_1_32, 40, 32, 1, 32, 40, 32, 40, 32, sizeof(float)}, + {"fp16_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDDIV_fp16_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(uint16_t)}, + {"fp32_20_64_1_64", (LaunchFn)LaunchTCOLEXPANDDIV_fp32_20_64_1_64, 20, 64, 1, 64, 20, 64, 20, 64, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, + tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/tcolexpanddiv.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/tcolexpanddiv.pto new file mode 100644 index 000000000..9e210e780 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/tcolexpanddiv.pto @@ -0,0 +1,652 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpanddiv: column-wise broadcast divide. +// Matches PTO-ISA testcase parameters. +// Key: tile_buf cols = full tensor width, v_col = valid portion + +module { + // Case 1: fp32_32_64_1_64 (float32, src0=(32,64), src1=(1,64), dst=(32,64)) + func.func @TCOLEXPANDDIV_fp32_32_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 2: fp32_8_32_1_32 (float32, src0=(8,32), src1=(1,32), dst=(8,32)) + func.func @TCOLEXPANDDIV_fp32_8_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c8, %c32], + strides = [%c256, %c256, %c256, %c32, %c1] + : !pto.tensor_view<1x1x1x8x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c32], + strides = [%c256, %c256, %c256, %c32, %c1] + : !pto.tensor_view<1x1x1x8x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c32] + : !pto.tensor_view<1x1x1x8x32xf32> -> !pto.partition_tensor_view<1x1x1x8x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c32] + : !pto.tensor_view<1x1x1x8x32xf32> -> !pto.partition_tensor_view<1x1x1x8x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x8x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x32xf32>) + return + } + + // Case 3: fp16_16_64_1_64 (float16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDDIV_fp16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 4: fp16_4_128_1_128 (float16, src0=(4,128), src1=(1,128), dst=(4,128)) + func.func @TCOLEXPANDDIV_fp16_4_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c128] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x128xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c128] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x128xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + +pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x128xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDDIV_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDDIV_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } + + // Case 7: fp32_40_32_1_32 (float32, src0=(40,32), src1=(1,32), dst=(40,32)) + func.func @TCOLEXPANDDIV_fp32_40_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c40 = arith.constant 40 : index + %c32 = arith.constant 32 : index + %c1280 = arith.constant 1280 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c40, %c32], + strides = [%c1280, %c1280, %c1280, %c32, %c1] + : !pto.tensor_view<1x1x1x40x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c40, %c32], + strides = [%c1280, %c1280, %c1280, %c32, %c1] + : !pto.tensor_view<1x1x1x40x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c32] + : !pto.tensor_view<1x1x1x40x32xf32> -> !pto.partition_tensor_view<1x1x1x40x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c32] + : !pto.tensor_view<1x1x1x40x32xf32> -> !pto.partition_tensor_view<1x1x1x40x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x40x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x40x32xf32>) + return + } + + // Case 8: fp16_16_128_1_128 (float16, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDDIV_fp16_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // Case 9: fp32_20_64_1_64 (float32, src0=(20,64), src1=(1,64), dst=(20,64)) + func.func @TCOLEXPANDDIV_fp32_20_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c20 = arith.constant 20 : index + %c64 = arith.constant 64 : index + %c1280 = arith.constant 1280 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c20, %c64], + strides = [%c1280, %c1280, %c1280, %c64, %c1] + : !pto.tensor_view<1x1x1x20x64xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c20, %c64], + strides = [%c1280, %c1280, %c1280, %c64, %c1] + : !pto.tensor_view<1x1x1x20x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c64] + : !pto.tensor_view<1x1x1x20x64xf32> -> !pto.partition_tensor_view<1x1x1x20x64xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c64] + : !pto.tensor_view<1x1x1x20x64xf32> -> !pto.partition_tensor_view<1x1x1x20x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x20x64xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x20x64xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/CMakeLists.txt new file mode 100644 index 000000000..acf3e74d4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandexpdif) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/cases.py new file mode 100644 index 000000000..e4ea5f005 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/cases.py @@ -0,0 +1,67 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandexpdif ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandexpdif/ + +TCOLEXPANDEXPDIF: compute exp(src0) - exp(expanded_src1) where src1 is expanded by tiling. + - src0_shape: (src0_row, cols) - first input tile + - src1_shape: (src1_row, cols) - second input tile (tiled to match src0 rows) + - dst_shape: (dst_row, dst_col) - output tile + - shape: (dst_row, dst_col) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + +Golden: np.exp(src0) - np.exp(np.tile(src1, (reps, 1))[:, :dst_col]) + where reps = dst_row // src1_row + +Case naming: {dtype}_{src0_row}_{src0_col}_{src1_row}_{src1_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_32_16_1_16", + "dtype": np.float32, + "src0_shape": (32, 16), + "src1_shape": (1, 16), + "shape": (32, 16), + "valid_shape": (32, 16), + "eps": 1e-5, + }, + { + "name": "fp32_16_32_1_32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 1e-5, + }, + { + "name": "fp16_32_32_1_32", + "dtype": np.float16, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-2, + }, + { + "name": "fp16_16_128_1_128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "eps": 1e-2, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/gen_data.py new file mode 100644 index 000000000..4606b3571 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/gen_data.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.uniform(-255, 255, size=src0_shape).astype(dtype) + src1 = np.random.uniform(1, 255, size=src1_shape).astype(dtype) + + dst_row, dst_col = dst_shape + src1_row = src1_shape[0] + reps = (dst_row + src1_row - 1) // src1_row + + expanded_src1 = np.tile(src1, (reps, 1))[:dst_row, :dst_col] + golden = np.exp((src0.astype(np.float64) - expanded_src1.astype(np.float64))) + golden = golden.astype(dtype) + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/launch.cpp new file mode 100644 index 000000000..b6f8ad19d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_32_16_1_16 +extern "C" __global__ AICORE void TCOLEXPANDEXPDIF_fp32_32_16_1_16(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDEXPDIF_fp32_32_16_1_16(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDEXPDIF_fp32_32_16_1_16<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDEXPDIF_fp32_16_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDEXPDIF_fp32_16_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDEXPDIF_fp32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDEXPDIF_fp16_32_32_1_32(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDEXPDIF_fp16_32_32_1_32(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDEXPDIF_fp16_32_32_1_32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDEXPDIF_fp16_16_128_1_128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDEXPDIF_fp16_16_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDEXPDIF_fp16_16_128_1_128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/main.cpp new file mode 100644 index 000000000..afa716050 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/main.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandexpdif ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandexpdif/ +// TCOLEXPANDEXPDIF: compute exp(src0) - exp(tiled_src1) + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDEXPDIF_fp32_32_16_1_16(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDEXPDIF_fp32_16_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDEXPDIF_fp16_32_32_1_32(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDEXPDIF_fp16_16_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_32_16_1_16", (LaunchFn)LaunchTCOLEXPANDEXPDIF_fp32_32_16_1_16, 32, 16, 1, 16, 32, 16, 32, 16, sizeof(float)}, + {"fp32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDEXPDIF_fp32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(float)}, + {"fp16_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDEXPDIF_fp16_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"fp16_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDEXPDIF_fp16_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/tcolexpandexpdif.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/tcolexpandexpdif.pto new file mode 100644 index 000000000..07c10b2e8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/tcolexpandexpdif.pto @@ -0,0 +1,296 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandexpdif: compute exp(src0) - exp(tiled_src1). +// Matches PTO-ISA testcase parameters. +// Key: tile_buf cols = full tensor width, v_col = valid portion + +module { + // Case 1: fp32_32_16_1_16 (float32, src0=(32,16), src1=(1,16), dst=(32,16)) + func.func @TCOLEXPANDEXPDIF_fp32_32_16_1_16(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c16], + strides = [%c16, %c16, %c16, %c16, %c1] + : !pto.tensor_view<1x1x1x1x16xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf32> -> !pto.partition_tensor_view<1x1x1x32x16xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16] + : !pto.tensor_view<1x1x1x1x16xf32> -> !pto.partition_tensor_view<1x1x1x1x16xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf32> -> !pto.partition_tensor_view<1x1x1x32x16xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x16xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x16xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandexpdif ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x16xf32>) + return + } + + // Case 2: fp32_16_32_1_32 (float32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDEXPDIF_fp32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandexpdif ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // Case 3: fp16_32_32_1_32 (float16, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDEXPDIF_fp16_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf16> -> !pto.partition_tensor_view<1x1x1x1x32xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandexpdif ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 4: fp16_16_128_1_128 (float16, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDEXPDIF_fp16_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandexpdif ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/CMakeLists.txt new file mode 100644 index 000000000..c132ea923 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/cases.py new file mode 100644 index 000000000..78c7d0d6b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/cases.py @@ -0,0 +1,91 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandmax ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandmax/ + +TCOLEXPANDMAX: compute elementwise maximum of src0 and tiled src1. + - src0_shape: (src0_row, cols) - first input tile + - src1_shape: (1, cols) - second input tile (single row, broadcasted) + - dst_shape: (dst_row, cols) - output tile + - shape: (dst_row, cols) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + - reps: number of times to tile src1 (equals src0_row) + +Golden: np.maximum(src0, np.tile(src1, (reps, 1))[:, :dst_col]) + +Case naming: {dtype}_{src0_row}_{src0_col}_{src1_row}_{dst_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_16_128_1_128", + "dtype": np.float32, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "reps": 16, + "eps": 1e-6, + }, + { + "name": "fp32_32_32_1_32", + "dtype": np.float32, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "reps": 32, + "eps": 1e-6, + }, + { + "name": "fp16_4_256_1_256", + "dtype": np.float16, + "src0_shape": (4, 256), + "src1_shape": (1, 256), + "shape": (4, 256), + "valid_shape": (4, 256), + "reps": 4, + "eps": 1e-3, + }, + { + "name": "fp16_10_64_1_64", + "dtype": np.float16, + "src0_shape": (10, 64), + "src1_shape": (1, 64), + "shape": (10, 64), + "valid_shape": (10, 64), + "reps": 10, + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "reps": 16, + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "reps": 16, + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/gen_data.py new file mode 100644 index 000000000..010fafb80 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/gen_data.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + reps = case["reps"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.maximum(src0, np.tile(src1, (reps, 1))[:, :dst_shape[1]]) + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/launch.cpp new file mode 100644 index 000000000..a9797c562 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDMAX_fp32_16_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMAX_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMAX_fp32_16_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMAX_fp32_32_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMAX_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMAX_fp32_32_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_4_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDMAX_fp16_4_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMAX_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMAX_fp16_4_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_10_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMAX_fp16_10_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMAX_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMAX_fp16_10_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMAX_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDMAX_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDMAX_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMAX_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDMAX_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDMAX_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/main.cpp new file mode 100644 index 000000000..3972121e9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/main.cpp @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandmax ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandmax/ +// TCOLEXPANDMAX: elementwise maximum of src0 and tiled src1 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDMAX_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMAX_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMAX_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMAX_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMAX_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDMAX_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDMAX_fp32_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(float)}, + {"fp32_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDMAX_fp32_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(float)}, + {"fp16_4_256_1_256", (LaunchFn)LaunchTCOLEXPANDMAX_fp16_4_256_1_256, 4, 256, 1, 256, 4, 256, 4, 256, sizeof(uint16_t)}, + {"fp16_10_64_1_64", (LaunchFn)LaunchTCOLEXPANDMAX_fp16_10_64_1_64, 10, 64, 1, 64, 10, 64, 10, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDMAX_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDMAX_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/tcolexpandmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/tcolexpandmax.pto new file mode 100644 index 000000000..2e46e80f6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/tcolexpandmax.pto @@ -0,0 +1,438 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandmax: elementwise maximum of src0 and tiled src1. +// Matches PTO-ISA testcase parameters. +// Key: tile_buf cols = full tensor width, v_col = valid portion + +module { + // Case 1: fp32_16_128_1_128 (float32, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDMAX_fp32_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 2: fp32_32_32_1_32 (float32, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDMAX_fp32_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 3: fp16_4_256_1_256 (float16, src0=(4,256), src1=(1,256), dst=(4,256)) + func.func @TCOLEXPANDMAX_fp16_4_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case 4: fp16_10_64_1_64 (float16, src0=(10,64), src1=(1,64), dst=(10,64)) + func.func @TCOLEXPANDMAX_fp16_10_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c64 = arith.constant 64 : index + %c640 = arith.constant 640 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + +pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDMAX_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDMAX_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) +outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/CMakeLists.txt new file mode 100644 index 000000000..f541b2396 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/cases.py new file mode 100644 index 000000000..f77ae2eb2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/cases.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandmin ST test cases. + +TCOLEXPANDMIN: compute elementwise minimum of src0 and tiled src1. + - src0_shape: (src0_row, cols) - first input tile + - src1_shape: (1, cols) - second input tile (single row, broadcasted) + - dst_shape: (dst_row, cols) - output tile + - shape: (dst_row, cols) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + - reps: number of times to tile src1 (equals src0_row) + +Golden: np.minimum(src0, np.tile(src1, (reps, 1))[:, :dst_col]) + +Case naming: {dtype}_{src0_row}_{src0_col}_{src1_row}_{dst_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_16_128_1_128", + "dtype": np.float32, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "reps": 16, + "eps": 1e-6, + }, + { + "name": "fp32_32_32_1_32", + "dtype": np.float32, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "reps": 32, + "eps": 1e-6, + }, + { + "name": "fp16_4_256_1_256", + "dtype": np.float16, + "src0_shape": (4, 256), + "src1_shape": (1, 256), + "shape": (4, 256), + "valid_shape": (4, 256), + "reps": 4, + "eps": 1e-3, + }, + { + "name": "fp16_10_64_1_64", + "dtype": np.float16, + "src0_shape": (10, 64), + "src1_shape": (1, 64), + "shape": (10, 64), + "valid_shape": (10, 64), + "reps": 10, + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "reps": 16, + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "reps": 16, + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/gen_data.py new file mode 100644 index 000000000..a0b5e82be --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/gen_data.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + reps = case["reps"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.minimum(src0, np.tile(src1, (reps, 1))[:, :dst_shape[1]]) + golden = golden.astype(dtype) + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/launch.cpp new file mode 100644 index 000000000..c8cdf39e3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDMIN_fp32_16_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMIN_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMIN_fp32_16_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMIN_fp32_32_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMIN_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMIN_fp32_32_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_4_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDMIN_fp16_4_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMIN_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMIN_fp16_4_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_10_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMIN_fp16_10_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMIN_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMIN_fp16_10_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMIN_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDMIN_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDMIN_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMIN_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDMIN_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDMIN_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/main.cpp new file mode 100644 index 000000000..64219b470 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/main.cpp @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandmin ST +// Test cases match PTO-ISA +// TCOLEXPANDMIN: compute elementwise minimum of src0 and tiled src1 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDMIN_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMIN_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMIN_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMIN_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMIN_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDMIN_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDMIN_fp32_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(float)}, + {"fp32_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDMIN_fp32_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(float)}, + {"fp16_4_256_1_256", (LaunchFn)LaunchTCOLEXPANDMIN_fp16_4_256_1_256, 4, 256, 1, 256, 4, 256, 4, 256, sizeof(uint16_t)}, + {"fp16_10_64_1_64", (LaunchFn)LaunchTCOLEXPANDMIN_fp16_10_64_1_64, 10, 64, 1, 64, 10, 64, 10, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDMIN_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDMIN_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, + tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/tcolexpandmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/tcolexpandmin.pto new file mode 100644 index 000000000..321a5ec07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/tcolexpandmin.pto @@ -0,0 +1,437 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandmin: elementwise minimum of src0 and tiled src1. +// Matches PTO-ISA testcase parameters. + +module { + // Case 1: fp32_16_128_1_128 (float32, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDMIN_fp32_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 2: fp32_32_32_1_32 (float32, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDMIN_fp32_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 3: fp16_4_256_1_256 (float16, src0=(4,256), src1=(1,256), dst=(4,256)) + func.func @TCOLEXPANDMIN_fp16_4_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case 4: fp16_10_64_1_64 (float16, src0=(10,64), src1=(1,64), dst=(10,64)) + func.func @TCOLEXPANDMIN_fp16_10_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c64 = arith.constant 64 : index + %c640 = arith.constant 640 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + +pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDMIN_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDMIN_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/CMakeLists.txt new file mode 100644 index 000000000..c957813e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandmul) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/cases.py new file mode 100644 index 000000000..7bc14eb3a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/cases.py @@ -0,0 +1,84 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandmul ST test cases. + +TCOLEXPANDMUL: expand src1 then multiply with src0. + - src0_shape: (dst_row, dst_col) - already expanded + - src1_shape: (src1_row, src1_col) - to be expanded (usually src1_row=1) + - dst_shape: (dst_row, dst_col) +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_16_128_1_128", + "dtype": np.float32, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "eps": 1e-3, + }, + { + "name": "fp32_32_32_1_32", + "dtype": np.float32, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "fp16_4_256_1_256", + "dtype": np.float16, + "src0_shape": (4, 256), + "src1_shape": (1, 256), + "shape": (4, 256), + "valid_shape": (4, 256), + "eps": 1e-3, + }, + { + "name": "fp16_10_64_1_64", + "dtype": np.float16, + "src0_shape": (10, 64), + "src1_shape": (1, 64), + "shape": (10, 64), + "valid_shape": (10, 64), + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/gen_data.py new file mode 100644 index 000000000..d8ee0880f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/gen_data.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + dst_row, dst_col = dst_shape + reps = dst_row + golden = src0 * np.tile(src1, (reps, 1))[:, :dst_col] + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/launch.cpp new file mode 100644 index 000000000..1fb3521d9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDMUL_fp32_16_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMUL_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMUL_fp32_16_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMUL_fp32_32_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMUL_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMUL_fp32_32_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_4_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDMUL_fp16_4_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMUL_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMUL_fp16_4_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_10_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMUL_fp16_10_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMUL_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMUL_fp16_10_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMUL_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDMUL_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDMUL_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMUL_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDMUL_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDMUL_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/main.cpp new file mode 100644 index 000000000..6186d0068 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/main.cpp @@ -0,0 +1,165 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandmul ST +// TCOLEXPANDMUL: expand src1 then multiply with src0 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDMUL_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMUL_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMUL_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMUL_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMUL_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDMUL_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDMUL_fp32_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(float)}, + {"fp32_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDMUL_fp32_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(float)}, + {"fp16_4_256_1_256", (LaunchFn)LaunchTCOLEXPANDMUL_fp16_4_256_1_256, 4, 256, 1, 256, 4, 256, 4, 256, sizeof(uint16_t)}, + {"fp16_10_64_1_64", (LaunchFn)LaunchTCOLEXPANDMUL_fp16_10_64_1_64, 10, 64, 1, 64, 10, 64, 10, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDMUL_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDMUL_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/tcolexpandmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/tcolexpandmul.pto new file mode 100644 index 000000000..b73c78ac5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/tcolexpandmul.pto @@ -0,0 +1,436 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandmul: expand src1 then multiply with src0. + +module { + // Case 1: fp32_16_128_1_128 (float32, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDMUL_fp32_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 2: fp32_32_32_1_32 (float32, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDMUL_fp32_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 3: fp16_4_256_1_256 (float16, src0=(4,256), src1=(1,256), dst=(4,256)) + func.func @TCOLEXPANDMUL_fp16_4_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case 4: fp16_10_64_1_64 (float16, src0=(10,64), src1=(1,64), dst=(10,64)) + func.func @TCOLEXPANDMUL_fp16_10_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c64 = arith.constant 64 : index + %c640 = arith.constant 640 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + +pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDMUL_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDMUL_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/CMakeLists.txt new file mode 100644 index 000000000..0eacb2968 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandsub) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/cases.py new file mode 100644 index 000000000..47d04eed7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/cases.py @@ -0,0 +1,84 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandsub ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandsub/ + +TCOLEXPANDSUB: subtract src0 by expanded src1 (broadcast src1 first row). + - src0_shape: (src0_row, cols) - first input tile + - src1_shape: (src1_row, cols) - second input tile (only first row used for broadcast) + - dst_shape: (dst_row, cols) - result output + - shape: (dst_row, cols) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + +Golden: src0 - np.tile(src1, (reps, 1))[:, :dst_col] # expand then subtract + +Case naming: {dtype}_{src0_row}_{cols}_{src1_row}_{dst_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_6_128_1_128", + "dtype": np.float32, + "src0_shape": (6, 128), + "src1_shape": (1, 128), + "shape": (6, 128), + "valid_shape": (6, 128), + "eps": 1e-6, + }, + { + "name": "fp32_18_32_1_32", + "dtype": np.float32, + "src0_shape": (18, 32), + "src1_shape": (1, 32), + "shape": (18, 32), + "valid_shape": (18, 32), + "eps": 1e-6, + }, + { + "name": "fp16_10_256_1_256", + "dtype": np.float16, + "src0_shape": (10, 256), + "src1_shape": (1, 256), + "shape": (10, 256), + "valid_shape": (10, 256), + "eps": 1e-3, + }, + { + "name": "fp16_12_64_1_64", + "dtype": np.float16, + "src0_shape": (12, 64), + "src1_shape": (1, 64), + "shape": (12, 64), + "valid_shape": (12, 64), + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/gen_data.py new file mode 100644 index 000000000..a2fc88aab --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + valid_row, valid_col = valid_shape + reps = valid_row + golden = src0 - np.tile(src1, (reps, 1))[:, :valid_col] + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/launch.cpp new file mode 100644 index 000000000..6d9ff773f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_6_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDSUB_fp32_6_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDSUB_fp32_6_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDSUB_fp32_6_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_18_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDSUB_fp32_18_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDSUB_fp32_18_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDSUB_fp32_18_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_10_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDSUB_fp16_10_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDSUB_fp16_10_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDSUB_fp16_10_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_12_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDSUB_fp16_12_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDSUB_fp16_12_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDSUB_fp16_12_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDSUB_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDSUB_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDSUB_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDSUB_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDSUB_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDSUB_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/main.cpp new file mode 100644 index 000000000..20b950896 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/main.cpp @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandsub ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandsub/ +// TCOLEXPANDSUB: subtract src0 by expanded src1 (broadcast src1 first row) + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDSUB_fp32_6_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDSUB_fp32_18_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDSUB_fp16_10_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDSUB_fp16_12_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDSUB_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDSUB_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_6_128_1_128", (LaunchFn)LaunchTCOLEXPANDSUB_fp32_6_128_1_128, 6, 128, 1, 128, 6, 128, 6, 128, sizeof(float)}, + {"fp32_18_32_1_32", (LaunchFn)LaunchTCOLEXPANDSUB_fp32_18_32_1_32, 18, 32, 1, 32,18, 32,18, 32, sizeof(float)}, + {"fp16_10_256_1_256", (LaunchFn)LaunchTCOLEXPANDSUB_fp16_10_256_1_256, 10, 256, 1, 256,10, 256,10, 256, sizeof(uint16_t)}, + {"fp16_12_64_1_64", (LaunchFn)LaunchTCOLEXPANDSUB_fp16_12_64_1_64, 12, 64, 1, 64,12, 64,12, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDSUB_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDSUB_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/tcolexpandsub.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/tcolexpandsub.pto new file mode 100644 index 000000000..ebab6624d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/tcolexpandsub.pto @@ -0,0 +1,439 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandsub: subtract src0 by expanded src1. +// Matches PTO-ISA testcase parameters. +// Golden: src0 - np.tile(src1, (reps, 1))[:, :dst_col] + +module { + // Case 1: fp32_6_128_1_128 (float32, src0=(6,128), src1=(1,128), dst=(6,128)) + func.func @TCOLEXPANDSUB_fp32_6_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c6 = arith.constant 6 : index + %c128 = arith.constant 128 : index + %c768 = arith.constant 768 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c6, %c128], + strides = [%c768, %c768, %c768, %c128, %c1] + : !pto.tensor_view<1x1x1x6x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c6, %c128], + strides = [%c768, %c768, %c768, %c128, %c1] + : !pto.tensor_view<1x1x1x6x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c6, %c128] + : !pto.tensor_view<1x1x1x6x128xf32> -> !pto.partition_tensor_view<1x1x1x6x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c6, %c128] + : !pto.tensor_view<1x1x1x6x128xf32> -> !pto.partition_tensor_view<1x1x1x6x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x6x128xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x6x128xf32>) + return + } + + // Case 2: fp32_18_32_1_32 (float32, src0=(18,32), src1=(1,32), dst=(18,32)) + func.func @TCOLEXPANDSUB_fp32_18_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c18 = arith.constant 18 : index + %c32 = arith.constant 32 : index + %c576 = arith.constant 576 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c18, %c32], + strides = [%c576, %c576, %c576, %c32, %c1] + : !pto.tensor_view<1x1x1x18x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c18, %c32], + strides = [%c576, %c576, %c576, %c32, %c1] + : !pto.tensor_view<1x1x1x18x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c18, %c32] + : !pto.tensor_view<1x1x1x18x32xf32> -> !pto.partition_tensor_view<1x1x1x18x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c18, %c32] + : !pto.tensor_view<1x1x1x18x32xf32> -> !pto.partition_tensor_view<1x1x1x18x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x18x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x18x32xf32>) + return + } + + // Case 3: fp16_10_256_1_256 (float16, src0=(10,256), src1=(1,256), dst=(10,256)) + func.func @TCOLEXPANDSUB_fp16_10_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c256 = arith.constant 256 : index + %c2560 = arith.constant 2560 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c256], + strides = [%c2560, %c2560, %c2560, %c256, %c1] + : !pto.tensor_view<1x1x1x10x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c256], + strides = [%c2560, %c2560, %c2560, %c256, %c1] + : !pto.tensor_view<1x1x1x10x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c256] + : !pto.tensor_view<1x1x1x10x256xf16> -> !pto.partition_tensor_view<1x1x1x10x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c256] + : !pto.tensor_view<1x1x1x10x256xf16> -> !pto.partition_tensor_view<1x1x1x10x256xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x256xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x256xf16>) + return + } + + // Case 4: fp16_12_64_1_64 (float16, src0=(12,64), src1=(1,64), dst=(12,64)) + func.func @TCOLEXPANDSUB_fp16_12_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c12, %c64], + strides = [%c768, %c768, %c768, %c64, %c1] + : !pto.tensor_view<1x1x1x12x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c12, %c64], + strides = [%c768, %c768, %c768, %c64, %c1] + : !pto.tensor_view<1x1x1x12x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c12, %c64] + : !pto.tensor_view<1x1x1x12x64xf16> -> !pto.partition_tensor_view<1x1x1x12x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c12, %c64] + : !pto.tensor_view<1x1x1x12x64xf16> -> !pto.partition_tensor_view<1x1x1x12x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x12x64xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x12x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDSUB_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDSUB_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file From 38cc246a89a56cc75228c345837e301f9dd1c836 Mon Sep 17 00:00:00 2001 From: zwd060924 Date: Mon, 27 Apr 2026 12:18:58 +0800 Subject: [PATCH 181/192] [Add] tadd tsub tmul tdiv tmax tmin tshl tshr tor tand txor tcmp trem tfmod --- lib/TileOps/tand_template.py | 24 + lib/TileOps/tcmp_template.py | 34 + lib/TileOps/tdiv_template.py | 24 + lib/TileOps/tfmod_template.py | 30 + lib/TileOps/tmax_template.py | 24 + lib/TileOps/tmin_template.py | 24 + lib/TileOps/tmul_template.py | 24 + lib/TileOps/tor_template.py | 24 + lib/TileOps/trem_template.py | 39 ++ lib/TileOps/tshl_template.py | 24 + lib/TileOps/tshr_template.py | 24 + lib/TileOps/tsub_template.py | 24 + lib/TileOps/txor_template.py | 24 + test/basic/expand_tile_op_tilelang_tand.pto | 47 ++ test/basic/expand_tile_op_tilelang_tcmp.pto | 47 ++ test/basic/expand_tile_op_tilelang_tdiv.pto | 47 ++ test/basic/expand_tile_op_tilelang_tfmod.pto | 51 ++ test/basic/expand_tile_op_tilelang_tmax.pto | 47 ++ test/basic/expand_tile_op_tilelang_tmin.pto | 47 ++ test/basic/expand_tile_op_tilelang_tmul.pto | 47 ++ test/basic/expand_tile_op_tilelang_tor.pto | 47 ++ test/basic/expand_tile_op_tilelang_trem.pto | 56 ++ test/basic/expand_tile_op_tilelang_tshl.pto | 47 ++ test/basic/expand_tile_op_tilelang_tshr.pto | 47 ++ test/basic/expand_tile_op_tilelang_tsub.pto | 47 ++ test/basic/expand_tile_op_tilelang_txor.pto | 47 ++ .../npu/a5/src/st/testcase/CMakeLists.txt | 13 + .../npu/a5/src/st/testcase/tadd/cases.py | 73 +- .../npu/a5/src/st/testcase/tadd/compare.py | 9 +- .../golden.bin | Bin 0 -> 8192 bytes .../input1.bin | Bin 0 -> 8192 bytes .../input2.bin | Bin 0 -> 8192 bytes .../f32_16x64_16x64_16x64_16x64/golden.bin | Bin 0 -> 4096 bytes .../f32_16x64_16x64_16x64_16x64/input1.bin | Bin 0 -> 4096 bytes .../f32_16x64_16x64_16x64_16x64/input2.bin | Bin 0 -> 4096 bytes .../f32_32x32_32x32_32x32_32x32/golden.bin | Bin 0 -> 4096 bytes .../f32_32x32_32x32_32x32_32x32/input1.bin | Bin 0 -> 4096 bytes .../f32_32x32_32x32_32x32_32x32/input2.bin | Bin 0 -> 4096 bytes .../golden.bin | Bin 0 -> 32768 bytes .../input1.bin | Bin 0 -> 32768 bytes .../input2.bin | Bin 0 -> 32768 bytes .../f32_64x64_64x64_64x64_64x64/golden.bin | Bin 0 -> 16384 bytes .../f32_64x64_64x64_64x64_64x64/input1.bin | Bin 0 -> 16384 bytes .../f32_64x64_64x64_64x64_64x64/input2.bin | Bin 0 -> 16384 bytes .../npu/a5/src/st/testcase/tadd/gen_data.py | 14 +- .../half_16x64_16x128_16x128_16x64/golden.bin | Bin 0 -> 2048 bytes .../half_16x64_16x128_16x128_16x64/input1.bin | Bin 0 -> 4096 bytes .../half_16x64_16x128_16x128_16x64/input2.bin | Bin 0 -> 4096 bytes .../i16_64x64_64x64_64x64_64x64/golden.bin | Bin 0 -> 8192 bytes .../i16_64x64_64x64_64x64_64x64/input1.bin | Bin 0 -> 8192 bytes .../i16_64x64_64x64_64x64_64x64/input2.bin | Bin 0 -> 8192 bytes .../i32_64x64_64x64_64x64_64x64/golden.bin | Bin 0 -> 16384 bytes .../i32_64x64_64x64_64x64_64x64/input1.bin | Bin 0 -> 16384 bytes .../i32_64x64_64x64_64x64_64x64/input2.bin | Bin 0 -> 16384 bytes .../npu/a5/src/st/testcase/tadd/launch.cpp | 60 +- .../npu/a5/src/st/testcase/tadd/main.cpp | 82 ++- .../npu/a5/src/st/testcase/tadd/tadd.pto | 395 ++++++++++- .../a5/src/st/testcase/tand/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tand/cases.py | 39 ++ .../npu/a5/src/st/testcase/tand/compare.py | 49 ++ .../npu/a5/src/st/testcase/tand/gen_data.py | 35 + .../npu/a5/src/st/testcase/tand/launch.cpp | 25 + .../npu/a5/src/st/testcase/tand/main.cpp | 139 ++++ .../npu/a5/src/st/testcase/tand/tand.pto | 163 +++++ .../a5/src/st/testcase/tcmp/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tcmp/cases.py | 63 ++ .../npu/a5/src/st/testcase/tcmp/compare.py | 56 ++ .../npu/a5/src/st/testcase/tcmp/gen_data.py | 72 ++ .../npu/a5/src/st/testcase/tcmp/launch.cpp | 37 + .../npu/a5/src/st/testcase/tcmp/main.cpp | 147 ++++ .../npu/a5/src/st/testcase/tcmp/tcmp.pto | 266 +++++++ .../a5/src/st/testcase/tdiv/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tdiv/cases.py | 71 ++ .../npu/a5/src/st/testcase/tdiv/compare.py | 49 ++ .../npu/a5/src/st/testcase/tdiv/gen_data.py | 33 + .../npu/a5/src/st/testcase/tdiv/launch.cpp | 55 ++ .../npu/a5/src/st/testcase/tdiv/main.cpp | 154 +++++ .../npu/a5/src/st/testcase/tdiv/tdiv.pto | 398 +++++++++++ .../a5/src/st/testcase/tfmod/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tfmod/cases.py | 75 ++ .../npu/a5/src/st/testcase/tfmod/compare.py | 50 ++ .../npu/a5/src/st/testcase/tfmod/gen_data.py | 37 + .../npu/a5/src/st/testcase/tfmod/launch.cpp | 48 ++ .../npu/a5/src/st/testcase/tfmod/main.cpp | 151 ++++ .../npu/a5/src/st/testcase/tfmod/tfmod.pto | 340 +++++++++ .../a5/src/st/testcase/tmax/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tmax/cases.py | 97 +++ .../npu/a5/src/st/testcase/tmax/compare.py | 49 ++ .../npu/a5/src/st/testcase/tmax/gen_data.py | 33 + .../npu/a5/src/st/testcase/tmax/launch.cpp | 83 +++ .../npu/a5/src/st/testcase/tmax/main.cpp | 162 +++++ .../npu/a5/src/st/testcase/tmax/tmax.pto | 649 ++++++++++++++++++ .../a5/src/st/testcase/tmin/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tmin/cases.py | 83 +++ .../npu/a5/src/st/testcase/tmin/compare.py | 49 ++ .../npu/a5/src/st/testcase/tmin/gen_data.py | 33 + .../npu/a5/src/st/testcase/tmin/launch.cpp | 69 ++ .../npu/a5/src/st/testcase/tmin/main.cpp | 158 +++++ .../npu/a5/src/st/testcase/tmin/tmin.pto | 522 ++++++++++++++ .../a5/src/st/testcase/tmul/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tmul/cases.py | 90 +++ .../npu/a5/src/st/testcase/tmul/compare.py | 49 ++ .../npu/a5/src/st/testcase/tmul/gen_data.py | 33 + .../npu/a5/src/st/testcase/tmul/launch.cpp | 76 ++ .../npu/a5/src/st/testcase/tmul/main.cpp | 160 +++++ .../npu/a5/src/st/testcase/tmul/tmul.pto | 586 ++++++++++++++++ .../npu/a5/src/st/testcase/tor/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tor/cases.py | 39 ++ .../npu/a5/src/st/testcase/tor/compare.py | 49 ++ .../npu/a5/src/st/testcase/tor/gen_data.py | 33 + .../npu/a5/src/st/testcase/tor/launch.cpp | 27 + .../npu/a5/src/st/testcase/tor/main.cpp | 145 ++++ .../npu/a5/src/st/testcase/tor/tor.pto | 144 ++++ .../a5/src/st/testcase/trem/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/trem/cases.py | 102 +++ .../npu/a5/src/st/testcase/trem/compare.py | 50 ++ .../npu/a5/src/st/testcase/trem/gen_data.py | 37 + .../npu/a5/src/st/testcase/trem/launch.cpp | 69 ++ .../npu/a5/src/st/testcase/trem/main.cpp | 157 +++++ .../npu/a5/src/st/testcase/trem/trem.pto | 577 ++++++++++++++++ .../a5/src/st/testcase/tshl/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tshl/cases.py | 39 ++ .../npu/a5/src/st/testcase/tshl/compare.py | 49 ++ .../npu/a5/src/st/testcase/tshl/gen_data.py | 33 + .../npu/a5/src/st/testcase/tshl/launch.cpp | 27 + .../npu/a5/src/st/testcase/tshl/main.cpp | 145 ++++ .../npu/a5/src/st/testcase/tshl/tshl.pto | 144 ++++ .../a5/src/st/testcase/tshr/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tshr/cases.py | 39 ++ .../npu/a5/src/st/testcase/tshr/compare.py | 49 ++ .../npu/a5/src/st/testcase/tshr/gen_data.py | 33 + .../npu/a5/src/st/testcase/tshr/launch.cpp | 27 + .../npu/a5/src/st/testcase/tshr/main.cpp | 145 ++++ .../npu/a5/src/st/testcase/tshr/tshr.pto | 144 ++++ .../a5/src/st/testcase/tsub/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/tsub/cases.py | 69 ++ .../npu/a5/src/st/testcase/tsub/compare.py | 49 ++ .../npu/a5/src/st/testcase/tsub/gen_data.py | 33 + .../npu/a5/src/st/testcase/tsub/launch.cpp | 55 ++ .../npu/a5/src/st/testcase/tsub/main.cpp | 154 +++++ .../npu/a5/src/st/testcase/tsub/tsub.pto | 393 +++++++++++ .../a5/src/st/testcase/txor/CMakeLists.txt | 9 + .../npu/a5/src/st/testcase/txor/cases.py | 66 ++ .../npu/a5/src/st/testcase/txor/compare.py | 50 ++ .../npu/a5/src/st/testcase/txor/gen_data.py | 38 + .../npu/a5/src/st/testcase/txor/launch.cpp | 41 ++ .../npu/a5/src/st/testcase/txor/main.cpp | 149 ++++ .../npu/a5/src/st/testcase/txor/txor.pto | 287 ++++++++ 148 files changed, 10889 insertions(+), 61 deletions(-) create mode 100644 lib/TileOps/tand_template.py create mode 100644 lib/TileOps/tcmp_template.py create mode 100644 lib/TileOps/tdiv_template.py create mode 100644 lib/TileOps/tfmod_template.py create mode 100644 lib/TileOps/tmax_template.py create mode 100644 lib/TileOps/tmin_template.py create mode 100644 lib/TileOps/tmul_template.py create mode 100644 lib/TileOps/tor_template.py create mode 100644 lib/TileOps/trem_template.py create mode 100644 lib/TileOps/tshl_template.py create mode 100644 lib/TileOps/tshr_template.py create mode 100644 lib/TileOps/tsub_template.py create mode 100644 lib/TileOps/txor_template.py create mode 100644 test/basic/expand_tile_op_tilelang_tand.pto create mode 100644 test/basic/expand_tile_op_tilelang_tcmp.pto create mode 100644 test/basic/expand_tile_op_tilelang_tdiv.pto create mode 100644 test/basic/expand_tile_op_tilelang_tfmod.pto create mode 100644 test/basic/expand_tile_op_tilelang_tmax.pto create mode 100644 test/basic/expand_tile_op_tilelang_tmin.pto create mode 100644 test/basic/expand_tile_op_tilelang_tmul.pto create mode 100644 test/basic/expand_tile_op_tilelang_tor.pto create mode 100644 test/basic/expand_tile_op_tilelang_trem.pto create mode 100644 test/basic/expand_tile_op_tilelang_tshl.pto create mode 100644 test/basic/expand_tile_op_tilelang_tshr.pto create mode 100644 test/basic/expand_tile_op_tilelang_tsub.pto create mode 100644 test/basic/expand_tile_op_tilelang_txor.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/golden.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/input1.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/input2.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/golden.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/input1.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/input2.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/golden.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/input1.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/input2.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/golden.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/input1.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/input2.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/golden.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input1.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input2.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/golden.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/input1.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/input2.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/golden.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/input1.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/input2.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/golden.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input1.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input2.bin create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tand/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdiv/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmax/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmin/tmin.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmul/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tor/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tor/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tor/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshl/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshl/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshl/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshr/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshr/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshr/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsub/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txor/CMakeLists.txt create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto diff --git a/lib/TileOps/tand_template.py b/lib/TileOps/tand_template.py new file mode 100644 index 000000000..4d771d44b --- /dev/null +++ b/lib/TileOps/tand_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tand""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tand" +) +def template_tand(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vand(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tcmp_template.py b/lib/TileOps/tcmp_template.py new file mode 100644 index 000000000..c008ca237 --- /dev/null +++ b/lib/TileOps/tcmp_template.py @@ -0,0 +1,34 @@ +"""TileLang DSL template for pto.tcmp""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcmp", + advanced=True, +) +def template_tcmp(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = src0.element_type + valid_rows, valid_cols = dst.valid_shape + cmp_mode = pto.get_op_attr("cmp_mode", "eq") + + lanes = pto.get_lanes(dtype) + + dst_ptr = dst.as_ptr() + mask_ptr = pto.castptr(dst_ptr, pto.ptr(pto.ui32, pto.MemorySpace.UB)) + + align_stride = 32 + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vcmp(lhs, rhs, mask, cmp_mode) + byte_offset = row * align_stride + (col // 8) + pto.psts(result, mask_ptr, byte_offset) + return \ No newline at end of file diff --git a/lib/TileOps/tdiv_template.py b/lib/TileOps/tdiv_template.py new file mode 100644 index 000000000..c1c8aea65 --- /dev/null +++ b/lib/TileOps/tdiv_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tdiv""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tdiv" +) +def template_tdiv(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + divided = pto.vdiv(lhs, rhs, mask) + pto.vsts(divided, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tfmod_template.py b/lib/TileOps/tfmod_template.py new file mode 100644 index 000000000..45df41d3d --- /dev/null +++ b/lib/TileOps/tfmod_template.py @@ -0,0 +1,30 @@ +"""TileLang DSL template for pto.tfmod""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tfmod", + dtypes=[ + (pto.f32, pto.f32, pto.f32), + (pto.f16, pto.f16, pto.f16), + ], +) +def template_tfmod(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + quotient = pto.vdiv(lhs, rhs, mask) + if pto.constexpr(dtype == pto.f32 or dtype == pto.f16): + quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.Z) + truncated_mul = pto.vmul(quotient, rhs, mask) + result = pto.vsub(lhs, truncated_mul, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tmax_template.py b/lib/TileOps/tmax_template.py new file mode 100644 index 000000000..645da3924 --- /dev/null +++ b/lib/TileOps/tmax_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmax" +) +def template_tmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + max_val = pto.vmax(lhs, rhs, mask) + pto.vsts(max_val, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tmin_template.py b/lib/TileOps/tmin_template.py new file mode 100644 index 000000000..c03d74e3c --- /dev/null +++ b/lib/TileOps/tmin_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmin" +) +def template_tmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + min_val = pto.vmin(lhs, rhs, mask) + pto.vsts(min_val, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tmul_template.py b/lib/TileOps/tmul_template.py new file mode 100644 index 000000000..46acebab0 --- /dev/null +++ b/lib/TileOps/tmul_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tmul""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmul" +) +def template_tmul(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + multiplied = pto.vmul(lhs, rhs, mask) + pto.vsts(multiplied, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tor_template.py b/lib/TileOps/tor_template.py new file mode 100644 index 000000000..6efa2fe4e --- /dev/null +++ b/lib/TileOps/tor_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tor""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tor" +) +def template_tor(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vor(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trem_template.py b/lib/TileOps/trem_template.py new file mode 100644 index 000000000..947001974 --- /dev/null +++ b/lib/TileOps/trem_template.py @@ -0,0 +1,39 @@ +"""TileLang DSL template for pto.trem""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trem", + dtypes=[ + (pto.f32, pto.f32, pto.f32, pto.f32), + (pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i32, pto.i32, pto.i32, pto.i32), + ], +) +def template_trem(src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + if pto.constexpr(dtype == pto.f32 or dtype == pto.f16): + quotient = pto.vdiv(lhs, rhs, mask) + quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.F) + floored_mul = pto.vmul(quotient, rhs, mask) + result = pto.vsub(lhs, floored_mul, mask) + elif pto.constexpr(dtype == pto.i32): + lhs_f32 = pto.vcvt(lhs, pto.f32, mask, rnd=pto.VcvtRoundMode.R) + rhs_f32 = pto.vcvt(rhs, pto.f32, mask, rnd=pto.VcvtRoundMode.R) + quotient = pto.vdiv(lhs_f32, rhs_f32, mask) + quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.F) + floored_mul = pto.vmul(quotient, rhs_f32, mask) + result_f32 = pto.vsub(lhs_f32, floored_mul, mask) + result = pto.vcvt(result_f32, pto.i32, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.NOSAT) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tshl_template.py b/lib/TileOps/tshl_template.py new file mode 100644 index 000000000..28f448509 --- /dev/null +++ b/lib/TileOps/tshl_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tshl""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tshl" +) +def template_tshl(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vshl(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tshr_template.py b/lib/TileOps/tshr_template.py new file mode 100644 index 000000000..59a9e8117 --- /dev/null +++ b/lib/TileOps/tshr_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tshr""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tshr" +) +def template_tshr(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vshr(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tsub_template.py b/lib/TileOps/tsub_template.py new file mode 100644 index 000000000..b97777376 --- /dev/null +++ b/lib/TileOps/tsub_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.tsub""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tsub" +) +def template_tsub(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + subtracted = pto.vsub(lhs, rhs, mask) + pto.vsts(subtracted, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/txor_template.py b/lib/TileOps/txor_template.py new file mode 100644 index 000000000..593a4ca74 --- /dev/null +++ b/lib/TileOps/txor_template.py @@ -0,0 +1,24 @@ +"""TileLang DSL template for pto.txor""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.txor" +) +def template_txor(src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vxor(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tand.pto b/test/basic/expand_tile_op_tilelang_tand.pto new file mode 100644 index 000000000..46699675e --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tand.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tand via the default TileLang Python DSL template +// lib/TileOps/tand_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tand should be lowered to vector-style VPTO IR. +// CHECK: func.func @TAND +// CHECK-NOT: pto.tand ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vand +// CHECK: pto.vsts + +module { + func.func @TAND() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tand ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tcmp.pto b/test/basic/expand_tile_op_tilelang_tcmp.pto new file mode 100644 index 000000000..b7ef1cd57 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tcmp.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcmp via the default TileLang Python DSL template +// lib/TileOps/tcmp_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcmp should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCMP +// CHECK-NOT: pto.tcmp ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vcmp +// CHECK: pto.psts + +module { + func.func @TCMP() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tdiv.pto b/test/basic/expand_tile_op_tilelang_tdiv.pto new file mode 100644 index 000000000..dba316e29 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tdiv.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tdiv via the default TileLang Python DSL template +// lib/TileOps/tdiv_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tdiv should be lowered to vector-style VPTO IR. +// CHECK: func.func @TDIV +// CHECK-NOT: pto.tdiv ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vsts + +module { + func.func @TDIV() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tfmod.pto b/test/basic/expand_tile_op_tilelang_tfmod.pto new file mode 100644 index 000000000..8e7f2cf2b --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tfmod.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tfmod via the default TileLang Python DSL template +// lib/TileOps/tfmod_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tfmod should be lowered to vector-style VPTO IR. +// fmod(a, b) = a - trunc(a/b) * b +// CHECK: func.func @TFMOD +// CHECK-NOT: pto.tfmod ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vtrc +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.vsts + +module { + func.func @TFMOD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tmax.pto b/test/basic/expand_tile_op_tilelang_tmax.pto new file mode 100644 index 000000000..b0957bb0a --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tmax.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tmax via the default TileLang Python DSL template +// lib/TileOps/tmax_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TMAX +// CHECK-NOT: pto.tmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmax +// CHECK: pto.vsts + +module { + func.func @TMAX() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tmin.pto b/test/basic/expand_tile_op_tilelang_tmin.pto new file mode 100644 index 000000000..011f70f69 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tmin.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tmin via the default TileLang Python DSL template +// lib/TileOps/tmin_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TMIN +// CHECK-NOT: pto.tmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmin +// CHECK: pto.vsts + +module { + func.func @TMIN() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tmul.pto b/test/basic/expand_tile_op_tilelang_tmul.pto new file mode 100644 index 000000000..2561132af --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tmul.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tmul via the default TileLang Python DSL template +// lib/TileOps/tmul_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tmul should be lowered to vector-style VPTO IR. +// CHECK: func.func @TMUL +// CHECK-NOT: pto.tmul ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vsts + +module { + func.func @TMUL() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tor.pto b/test/basic/expand_tile_op_tilelang_tor.pto new file mode 100644 index 000000000..e21125005 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tor.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tor via the default TileLang Python DSL template +// lib/TileOps/tor_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tor should be lowered to vector-style VPTO IR. +// CHECK: func.func @TOR +// CHECK-NOT: pto.tor ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vor +// CHECK: pto.vsts + +module { + func.func @TOR() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tor ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_trem.pto b/test/basic/expand_tile_op_tilelang_trem.pto new file mode 100644 index 000000000..b4ad35f00 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_trem.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trem via the default TileLang Python DSL template +// lib/TileOps/trem_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trem should be lowered to vector-style VPTO IR. +// remainder(a, b) = a - floor(a/b) * b +// CHECK: func.func @TREM +// CHECK-NOT: pto.trem ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vtrc +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.vsts + +module { + func.func @TREM() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tshl.pto b/test/basic/expand_tile_op_tilelang_tshl.pto new file mode 100644 index 000000000..a0fec48ac --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tshl.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tshl via the default TileLang Python DSL template +// lib/TileOps/tshl_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tshl should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSHL +// CHECK-NOT: pto.tshl ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vshl +// CHECK: pto.vsts + +module { + func.func @TSHL() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tshl ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tshr.pto b/test/basic/expand_tile_op_tilelang_tshr.pto new file mode 100644 index 000000000..1b8e3ceda --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tshr.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tshr via the default TileLang Python DSL template +// lib/TileOps/tshr_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tshr should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSHR +// CHECK-NOT: pto.tshr ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vshr +// CHECK: pto.vsts + +module { + func.func @TSHR() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tshr ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_tsub.pto b/test/basic/expand_tile_op_tilelang_tsub.pto new file mode 100644 index 000000000..177c04d04 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_tsub.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tsub via the default TileLang Python DSL template +// lib/TileOps/tsub_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsub should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSUB +// CHECK-NOT: pto.tsub ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsub +// CHECK: pto.vsts + +module { + func.func @TSUB() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/basic/expand_tile_op_tilelang_txor.pto b/test/basic/expand_tile_op_tilelang_txor.pto new file mode 100644 index 000000000..e4e8f2177 --- /dev/null +++ b/test/basic/expand_tile_op_tilelang_txor.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.txor via the default TileLang Python DSL template +// lib/TileOps/txor_template.py. +// +// Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.txor should be lowered to vector-style VPTO IR. +// CHECK: func.func @TXOR +// CHECK-NOT: pto.txor ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vxor +// CHECK: pto.vsts + +module { + func.func @TXOR() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.txor ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index b443ac5ea..7571f1e4b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -130,6 +130,17 @@ endfunction() # -------------------------------------------------------------------------- set(ALL_TESTCASES tadd + tsub + tmul + tdiv + tmax + tmin + tcmp + tshl + tshr + tand + tor + txor tcvt tload tlrelu @@ -190,6 +201,8 @@ set(ALL_TESTCASES tshrs tsubs txors + tfmod + trem ) if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py index e14e76c54..f911991ec 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py @@ -14,10 +14,15 @@ Each case defines: - name: case identifier, used as subdirectory name and by main.cpp kCases[]. - dtype: numpy dtype (e.g. np.float32). - - shape: (rows, cols) — allocated tile dimensions. + - dst_tile: (rows, cols) — dst tile buffer dimensions. + - src0_tile: (rows, cols) — src0 tile buffer dimensions. + - src1_tile: (rows, cols) — src1 tile buffer dimensions. - valid_shape: (valid_rows, valid_cols) — effective computation region. - eps: tolerance for numpy.allclose (atol and rtol). +Note: src0/src1/dst tile buffer physical sizes can differ, + but valid_shape must be the same for all. + gen_data.py and compare.py both import this list to avoid redundant definitions. """ @@ -25,17 +30,75 @@ CASES = [ { - "name": "f32_16x64", + "name": "f32_64x128_64x128_64x128_64x128", + "dtype": np.float32, + "dst_tile": (64, 128), + "src0_tile": (64, 128), + "src1_tile": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-6, + }, + { + "name": "f32_16x64_16x64_16x64_16x64", "dtype": np.float32, - "shape": (16, 64), + "dst_tile": (16, 64), + "src0_tile": (16, 64), + "src1_tile": (16, 64), "valid_shape": (16, 64), "eps": 1e-6, }, { - "name": "f32_32x32", + "name": "f32_32x32_32x32_32x32_32x32", "dtype": np.float32, - "shape": (32, 32), + "dst_tile": (32, 32), + "src0_tile": (32, 32), + "src1_tile": (32, 32), "valid_shape": (32, 32), "eps": 1e-6, }, + { + "name": "f32_64x64_64x64_64x64_64x64", + "dtype": np.float32, + "dst_tile": (64, 64), + "src0_tile": (64, 64), + "src1_tile": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "i32_64x64_64x64_64x64_64x64", + "dtype": np.int32, + "dst_tile": (64, 64), + "src0_tile": (64, 64), + "src1_tile": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_64x64_64x64_64x64_64x64", + "dtype": np.int16, + "dst_tile": (64, 64), + "src0_tile": (64, 64), + "src1_tile": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "f16_16x256_16x256_16x256_16x256", + "dtype": np.float16, + "dst_tile": (16, 256), + "src0_tile": (16, 256), + "src1_tile": (16, 256), + "valid_shape": (16, 256), + "eps": 1e-3, + }, + { + "name": "half_16x64_16x128_16x128_16x64", + "dtype": np.float16, + "dst_tile": (16, 64), + "src0_tile": (16, 128), + "src1_tile": (16, 128), + "valid_shape": (16, 64), + "eps": 1e-3, + }, ] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py index 428604929..abddab2c7 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py @@ -27,11 +27,12 @@ def main(): continue case_dir = case["name"] - shape = case["shape"] - vr, vc = case["valid_shape"] + dst_tile = case["dst_tile"] + valid_shape = case["valid_shape"] + vr, vc = valid_shape - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) - output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_tile) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_tile) ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) if ok: diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/golden.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/golden.bin new file mode 100644 index 0000000000000000000000000000000000000000..7e251931a6c657d51480420a3df470f1d55c2e0c GIT binary patch literal 8192 zcmX}x33B5|3`Eg78bi@u+UB8n^!X1^PbegjRag>-jmq=kE1! zexC1E{+^%SUytvfKEIyt>ptJf{ptPnwD#QX{r3LbXSPO<^!NEXf0NGE`Fy_J_x?D) ztfj|>)!(si7mlBEo$n;Pe_fxCB;h05{dM>K=ehI6FUi)E{>ri)>wM?EZ=UaV|Fpa7 zcdH*qls~O}&F5zlKj)POOur$a{Wn`;{Z0P!yq<}u70i5%#fg4A{)CfFdhyckS6Z|~ zGWQ}C-=FI)j?dQH&DSuR{I}gl)L~BUHH<7%F?iici@ob!;T-#PAf^r3Yp`zA*2E_g zTz+;f8XOhlL?=cliZooYEUsfWZ~5xFPm$PnEjBr14-ZykJ5*V*vh|X;UZQ4&oP1!U z7@x^;dTQ*?Q_M%+#!_ptEzUUaZgVL{HG(+|;8}HGKoy8VjLsj86d+k6R-m8^fn8z*O@`;r=CfohiJWbcrZP_YQ zMYFk{h}yTyTGQmg%-wZSUH2u9y{CDFM`k~doarT2U9AmyfciCl{v_*VZ!+6eF=j{C z7V*)ms;p_(T^QPjTz-EC$tyXPfNQ%&+8)WTyvUx}LJOmq)F)!lB_GM?G;yj>ys(d{ znN7*5XmKaaHtGChOr`6qkBeQpG4kw`8TOMb+GO>g$)1(;Hxl16)@!ulB-|ul$(Ry& zWmh(r8K23=^Hv%}A?y>{@%&659+Tfx7Bie8!m;XgH7pfB4#e`&+E%iB$p4@_G8MjT zW1YlS9TsTBVisA(q0?qL8ZIARbeYLq<;AsHz~|2;zdC`Q2_rIp+)e&A}Q-7sBN%z~MN zt*HzvJgyyFf0qZj2~}2h&R7=-jIFSlHEfh|-=l%+Jf}w%{^7m5dL^B9(~8+Vuv3=9 zr^86MlZ5RACKK-LV3;(1WMwL6v}J#L%*@btVhaZ1N-?JzG8H{dzXCqB-IU#Z_i{ z+11zZHY0k$A`4Ye*Cv6Qe4lPf$)SqE36wlt=c^U57h^KZ5Q|RUY+vQ4>$NC*s|z7% zH6cS2*Ti!l@*)$t47i@`tbFUS2iKurbtVrdrsbu4glT8p);5Nwf70~qVg78s*Mr%~ zUJRtU)$#V5MjbKsr&^AzsLGvqUGX;em6eX0XkCO3V)!st{{S8>+9 zyzAVrXCsLDcFINHvizQ@t(kbHtPH>iJ4|99lwHdU&DqOalXUNn+S?f6Vxu1D8;)L& z2t+!F)6*H&cHZJ}d+%Kj?L&XFnUxotSz0~$-#hbA{N|Mi(&)~7mp#0j)_Q{{xihaZ zNh_;#t~<+%P{#YcF`9Hm!j9_jv^^1KkG5Exu_^%{Y~dQe;bhZ(YbGwKvYQl5w*3ZU zSw6*MhfLi)T5>#bv?5xw2luirpDKTPBW>$!(OHa}B~MAiVV2}#@Fhoov}Rv~<}Um< zyZ_r#W#LaA!go(!Wag=}IAom{5w^;EyI2>^hn-f)m|9{mBjjXR?0KWL^{D<}Q8(&` z{8U`uo-5dcTFvxIQiar&UCG8gzqPwny*@W1u*)PYHw(1#_L}(9xbCz(AM%Ie?Lp5+ z>uGDZ3OHALVl`b%J9!vfUC(c|tegHKQNt^ZsrOS{`mlQCbL=}^$!*ojPIu5|yVD`N zSN9ujHJbE9oJ6= z3Q@OVX0aU=4!P$Z=iK5lKY7%%szFsi>KU2$sgL&e?8# zR?J>_i=)ZMH7ttu9`GvLC^czyJKbExk;T%Igd!2MK}q{2jUt3yqQ z-s!#BRLxrrI(IuqUGgQuo`bGxA`9|kV&GV;EV}a?q9e;z`#Z-bYAgBI8FDdC2FXB} z^(%zk@zGx_vK?ph7HuAT{mT{`qS?O1zc_Gaj>+iz?wBh$+S-|&BsW`Ii};3X*x4vX zqF4skGLIh_r^62h)JP0Gy~DE4Hu+0x?$zM8syh2Qmn-JBX8qy_NQ+%d+W?H}y)>l*ulf-O2o{cKtfn!&bho zo1FWuNP0$l9Tw6~iihoishIBZxZSDR;7yxU7XJOMtI3%kbF{i414Gv4vpajDo?pmE zg?T0#L z=IEaJYI^O4J3lh5cgz$luRSLP>$uZ@B3hov;cV-=YD~*-sb%Vwym<{j>2!;@fPgxnb^z!p?THjuJK7|8nLkTo=D84U#&-escPmb+xjm#GBprWX|16+_SRz z;*flN1!Vz5N>C7j{udM(XBqd6pRjGH^Yf~jHQ9n-`3mA)bJ7Kl3I_cr+7voeF^>Dbq#Gc*(cFX~o_2iYNE- z9D+S{RjWKY_f%wO&2^7oUZ>{LhKDUt6xH~CPqIf=WW=v4{Hw(*4B}z-5yxU7H??%O z16plIzJ}j;bwQ_7Qq#Q!=9-r2HC?|8E-$& z@4Rkxsn_jDvgUZKBr`Rd7MpcqjNx^w#;oYcw!7^nt8qFp%jEWeUuy2u%br3QFZZp@ z-(4<9b(-0?sGYT}Re48UOilDT&zpHDv(H4^zIxr0b$`jtd65?Hp?{_llF;LP%A^WV zO>1$_c?jX~oTR^G=qKAV<(Jj)g{mDVhU!*{)x}shzlV9>t$O}VHXrXkN^d&*Q73)} z!MGJFdaOOE&5pW{B2EabDCubXP$arNvI!l((e$xR86Ft zmp;uZzZUYwLK1aF@>cjX5!D!9RoUCGSV@=fqsMg8W@*nWtnf%%)!A2n`)sp09mEGV zC_LFgn&d@S4VocWcayDbO-$q-l2$QX=Z_ami2?uChqHWd^%ZBB?{@T)2)iD}XqXIj z;C0V6?$?{Y2&%W~{G>dQPORy#2Xn;)u;=YBDk^^FQ+kuBxtCA(aadfsO_ZG;PY8zL zGCQ%*5d-Wo(|5Q|)i_v|QBPbZvo3)|hV_lQS9xs8kzK2D4V73~nCoz`ua;DwnKSj? zotWM=WG{*BY!x+G91D5PF8jFTL6ypc@thNEccVgd|2H|7LT~u zn(Br$&3zu2b2^D8P@g(mH1~KO4DgPp=3qK(^xazThwY6xoR18v20iH{Cly0|W;5Ax zdi{IuGrM`4o;tcY&Xc^;&~f|UAeo(|z^v-)22dvCvH{kiY*+)5U^Kh8tUukHWc z^J})>`~KM5``*<1c@7!4f__r`ivHbj+?!{QQ`;ikKc#-{-UE z)`uU@u$YYKu!>i$xi!OTKqD9bVY;Zef#By=uEW2}aV&m(M5wPKhTsv2`MQnmVV-3< z4qwbYyvd0#y{8ym!zZ)dFk9i>3}K=QN@F?nv+6K)g!;IC>CeT>tafT8%hUSyWRBJDmToSb0Nf;uw)bw)bUxa#&s8^A|sj%`fHAa0>B;A*WBL z@bkt*lTMpv6=Hp-MLej6C-ExcDBb7;GqOWs7ti=VcJ~gsy;a^Doy(9iVjHumM%9ZK zclbQtw=2tPT{f&XfyK@e-aD_KNWRb>@3TD!LqHW{cJ_SC6M|5kb*nIST9?1=jGf|_(_CPqjb^6`gmI;3_)JWv#WrWTd|__azNvq1*zIBF&wh1v@9ND< z#GRG%X|}r5lfC7(B;7v{u?GD4Spu1D#WrBU^FPq*dz&@+)bvBl6-)^V( zIU}%RDGwM-Mt@kYqcEImXAif$X@Aw0Lm$V=A!LEebUH`dKN|*BWXFn6J<~Xru^rCz z122@adtT11v0UHSX#$Jhjy3#M;29^KTZ*;E1H-4|MRhwo3QdXmKAGtK*x9c>;0gWp zKWmy5^_jxi`uhdVY4w=kah)ws^Lj7()c&!{je&6TXAz}t?8d4LdP&*Kv#M^0t@rjY zj=Q^w&Xy`pg`f4Y`Qb(P>3q7{{xrDfPT%=L0_h={MP3{|_b=a}-m}TjO8L?xt7^^H zo89gV9kN4eHdy{l>3rmDfBCPrIx09P7&Rq6V$HPsM`fx?kvhQZqhTj1hSQ|XcExXI zTXllv-_F+Xx*ZSO*5)Fc`{8u3=bf?b-EF^L`*1Z=HFU>%THbnW&m%PY$CP~7Jyy40 z4=KSN^Xqo(Hi0nretay)(|Jm49X?-q&8tf^ey;k}&YRkM4S zpIj=*vr6;xjt^QhjsE(=lL|t;ZZ79dVsG6ec~iH%P&dVHR$SucD(<&8_aDDyn9aGd zKGgg#GhV$}$D|yr<`rgWQ#xkN?xuS=#hRL|#mPmt)q=$87+S>t*V4C+?Bh|V^V}Qz zrgm?#=A&2H^TBGTHjWoD|3{lnke)h8@QRD|Q~oFSy1&yRwmv#7>2vIzUEP$1onbxe z@XNBV?6HhReaQ85JMfq%eVXC2@ZT4f@b_NVCbiBNC1V{}byM&+uRZ4{zb+xu9h=@d zXEdSibI$yxpX+MAbswPw>=f-ui_zqVfpS|K3ML4yBO*@j%L-LXzZD4iuHu&QB}x4 z49*7w;j(k8+L|7YjqRpMtLClS;)Zgy(hFWx|4im~rkCwkf4d(W5B69R1N*;;z|gex zFcy|~Q`>I};`{J*++6IqnTN^xP`50dh-bE0yZ!2L{9j&YU#0HT_^off+Ku(GeO_3p z0(7*|?=u^SL$m(A?O`{ync#c-X5u%LY(=c%Je)O|aUMUkQYW~2CZ|q}z<=j!zyDaL z{N{q6i{Ec&PIVZbQ$Q6%arUww(x*X`*+wO|-wQEK91}w^Zx+0F z7s?3pX0|8pJL%nTe2TZL{y0Cs?jCW?=>1;?Jo9!_(!KTP@6J}u0D4OKce*yuC*~z9 zkLNb)1{hnjVN1p1V(aT_x$bPvE~9#Bi5JaxnyB-s;~lba$O9`jVpKkry7Np$%`!m| z7pKj|(V>j$;Y$^>89r2*(}XYuC6f*Tn3nT3{dn;vvgG(J)Y@6*s_gQPl6ant#z ze15vuAq?%)kqjy0?z2wg^3!lTI^pMxg*0r7n%}<@1PLZg2Ai?ZR;J_S_C(==XA{)r zW#4+8ee5xDmdp0wgxPt+Zn>#AWm&_mJ`Z*LhJrQCRwHdf?E{Ud`h~~GOyq`8z3Qe2 z3f_J2Ejzn6w~txzD62Pl=;-vk_tp6LrXaoeE{p!R7iRN<51SCB5q!zo9iiLwE5l8G zPr5LvCU2h0v-Ldooc*4v|2Ny$bU3T7C^D6@!1}QBp1(C-;KErrVzxQ73n$JlApGQQZ`fZ>`gT&&8x}bS$F1y(z3zZGfkmE=-`?PhBQf%wzR#!k?jm^Q6M3imGchP|`=}58=BTs!#Wt(n z%SnA$VqF(oHH$;CRm$L1UYx3N-h1;ax~V*>Fpoz|sSp>=8FpdO+oOpr>-hWhvuDa{ Y>T;S`cPO0V$M*cxmf05q{qN5BAJ9S}&;S4c literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/input2.bin new file mode 100644 index 0000000000000000000000000000000000000000..8e916c77ccc55d57a02c7ef3b593a483d85dec01 GIT binary patch literal 8192 zcmYk;TXGxA4n$E6Odkq{EqhoR`pt#SzCNK~vl@j*B9T?p$M?tg$FGmCeLg?_{QUK| z_3Pu$>|YGS>m``+OEyZ!m>U9MmI zhDT%!WE-08sG?h4@R0fQ&DHP4tNb-TJUsjy+SyLS@pa2&sVaY7sPr0-Cv#1ek}O-p;)j|(hqcTe5Wh;CV{$m{CJGmmw%r%Z7p z6ZNDQ-;`8;GVFc+@!EuOHNVT!)XwW-SalK-b2^syJ!M>FET}SN>=|p2h?FOuL3Q}rFLAb!+F&|&u0Qt#_arN+e&8`H=FXasUBWA49gFbP7dp#G?`d-a$uYy_SZQwL(`)qnVPA-dFfNGt!8Z*X2CMx8p2$s`F2y&IumbR9E@ezU84smC~_{=^1jKkJ2DGuYFTpC&vuT#9O#^ z5rf6=ep-Y*e5vS75AwjU4)&S$e!>i^S~80@KfIbMe{1ZR2_KlP=p9NGWbb}~hXzv) zHmM*l2p^U@VX-bBFWK-^e-z1XW~;qtxS$$mdq4l-)RR9vuQNOQgAsE@=}es%*WLTx zkb8d^!M<2|F!ywtCjK!Y`(fC(YEOpP@YD?^4(Hy#=l}fD-GtsMX2X?FIdxqxs&*RG zpmp5U!SymLlw(io17;dvwz@W?LME zlv{>FGAnGUw#xcBM)u?f6DD6azjP-LoO@mk zb?RK*M$3i`=f%%- zkI$?6T3oX+!ES?fb~W^+x(F3%nzxyc!TIp3n!a?r413Cj+f;#tuzqb19t^;!bMY{n zUimKbEco_^&+xCK*?7tX1zu(6C*v}0hYGyak!isUlTYWv>OHxOZmq*-=b0qis=V_~ z=UeIWxqYzIR8yampi5ktac4|@5!tqT86Z<9HQRefZ1--~D%Pea2cFO zOr%_upVLUSt&it-ML4^Ey031E7movb#bB74e*Lh>03~F4Yf`bD4%1PSl^g~TJ|+wwFICs=Eb@lxZpvBUzPPv7GR~T&d8@C3 zQ{Hy!nw6f`@7c$k$RPc^gDM|>6mSHJa8 z%OOAidTD|bg*P##Xrkc~;jeSj>3ZH}WM_q2&sgxgBg18xO+{~a7j6Z=DX6A*Jyvm( zKZc+-8K`7#)tTtL!PT^3@JDTMs;-|9KhZp(y3@0aFr3madoilO1x;*qnzM9ro_U+2 z9Nlzs;nei?6PhM@_rTdGmRIaI>DLF=AjSsti=U+vY*NSLELhj!kcoKn6y9*i`!K&= zF2)n4G_4;N{+@Q|XHm1_)c0nv9kYCNVXE};|C?(bo%S;MwG^FtFy&noyv;)fSxuX@ zK4kOxYqv)_XlB+lhFgeLz?Uh*Z!H3Q z4|h$X>5C{$AXETzoQ);f&Ud;Am*2Gq0@5{dDTeW)=hnr@Kt5gjOReO(#elM5( zIy_dlFE7Ne`7hfn{B>1E{hD*<_pQObHxJoW8jE}`GejcTu?Gd4tiJ8Iv2Yr@@ag1O zxyUJT-`$UJs{$ERkiT!x`M?f^Ihyvx%?@!-Dqq#vRFz4sN2Mx56i4a#7@NQPeL55R ze9gOB-Fi-18N})mf@aw-m8fpLZa5K6npbPO)AaUm7|S})u}Z3n)X%ef`1gLD>>EF@G`(&T>kwnJ4$R$6VHUXSL{!C1 zktW2`vTmAMdQe3lk5b_s^YerQf@e(UdF=e)s{W_jjwH_jq(y!^KcJAP%wa_qtO(gP;`If=dN@rJG*=z#ma&AysNJg?6WTKKjv6vhqu14b*#=0225cGg5jA(S90K{ z^IPS#f4@ZAPr0z>-Two2VO!>9(L*QyrRHk%YdY??8*gUWE7 zZMB?>mbZFV*=R1^RUyXpI5dlqb$?sI6$U+;?fX95KGe^I#HjMzI^uM|lAryna5=Wx z_o-@GUY8lu>uq+m((UbCofxL?vNjL;dQOtLq}Jt|*D)sBKJv&|7A(^tmZ;pj^QAT~ zU2K1Ex`4&hu>&!Dj{+X{F~&RW<);Gny9<6~e$eYcJ^rd}XU|80>ft6Objr-vM^A@s zKF?0r+rvTH;?<1I@pQXys?|ffRX_cTb_$zkEPb7SNNKCin~-z2=<_(F-j*YbcXMp7 z-ppimyx%O;#0;#RCesnwd5wA4U)z6>!q=TBEB`e1+M5acOM(1k_bt9NA$R9=yv<)t z6?X!LXWz5Y2k&lc^~)+!9n957_d*_?F=nQwQpVSN)lG8KU5;{6Ckz;0_Su_fw+if- zsmnWyhgbjNaDvlN%js@(-oE#Jzw`UXP$v>oCK#jZ@Fx!0|EvG!?CVAZT&uXJH&5MD zC$^wauiFuCr`TCUTtEBfsR}&tFUL*oo=?Z!o+_&}RP$<5>U57_w%J;#YgT$m8&=kk R#2%l1jpeXzqHsg}`X2&$9ku`f literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/golden.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/golden.bin new file mode 100644 index 0000000000000000000000000000000000000000..9a1d9f0c5720d99e356a91f6d7babf986ea08d4e GIT binary patch literal 4096 zcmYk*L9ygI5JXWyFa$#|1bqgEAP9z_RAz5c9!BV;Pcqnb zw?bzRXYh`>1$hs;qh=kuo>@=ls@dNp-@{+N-iyA$=kvX1p0(dmZ+T~iEqFucsJUg9 z>+;^Hj(LO@WOPP*RAx8N+PyKtR%(aNS;*{9-hp>}I4XC)hZ$OV_FlVry*YFD2EUDW zy)~j|vAmgG?}$Fp7ImZ7vzy({;D-NZ*4u0E(Gh;Z{8YbXc7r|0^}OW`GlfiVdFIvY zna`kiL?P>8uU+Slcq23auD;vn(`}V+)meA5cK#LgcyND2d)Pn9laKP`E%4v^O!*eh z^2{>a!S2qyf0Ta&9-Yb0DC}-zcK=Sl@T4Q-xA5)Rb!F&bge~yj-1-iX>BzjZVn;I(K|;+r3S_@#ja~J%I=F8}`gC^x*EOnLldQ*O8fJ+ec;g z9`3xUH={oqzWpn^&(&G~Kjdz)1^bhZnd*Z-!g`}+KEn(>?4|CG&~meUyZH*=_jse& zKlz!t=e|YF)Ehdz_?y$AGmKzvh_4Ty=Z){I`^1;&Uj=^)x_+OTe2?n>qGyotw$N+V zyE4mp`tmo5BXbKBJiE>ad(C%sdv-m$?E8lKiuXj9Pkmori@Na!|19+$A1!QQh8}d5 z=N*~1e$gXbLEhrYR(H=V@1^`%yZbxz@CIFYi|-Nk~be8AN=gQp8sNDVP z>hySL<+?Moj@v8fjBqF8>pwwP)`NTVs9jfnl<#}$w7`dseXGnq%hPqMhbyz}sae*; z9`1N%Z}D|nc!Qk3a^Hi0Mz=uSn04)RrgPW5z7FL*;mYpK5!}H3cU!X#{|H-ftK)tK zzWvSYu0>^h-YB*6?Rwst+1qEAU)kMtd61h&*nzR4IIxTvHw>oC`=ZBHmz3i%4 z#~a@AhFQL5I)dIA_S$E(hnBn5=Z!Gqt?Z|l-TU6qX~B*8jj!L9#yi6v^r)WS(p%o} MmTUy?T;c8i0BG4`0ssI2 literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/input1.bin new file mode 100644 index 0000000000000000000000000000000000000000..89d0db3d88910d97aa5cee383df1de28737c7db6 GIT binary patch literal 4096 zcmZ9K(Xk*Y3m9Ix{e{=RKTuogtp_k8W_r(j8G5w~HaK9V9^n~wf6p=}?A>)g^-f(jaKksh1vus?5;~afb0%EcTD(uR_|j*IJ3bG+t2jP z9Qw{#*T2+eo)=xe8P7K}qRA8Np#G`vo4pfUuUXHqejfbpW?$5&zhJ7P9e4xT9r6SN zSetogQg3j>6Bggz+(C1X`ba(D-n;kQ@!x%OZs*K8b22BJ|N1N1g!Sv08{mQo+@)sU zy=LpZd*PW;AMgRPP_y5_>hGvd>U@b0i?~J}-YjtseJIrkA36{uv_0G6TJ<>1uq?SAP?n=Lb_q=d(FGH`^ ze?v`7v>n~r*#Tehol}3`3jOOh^X4+!Ts?U+*RyAKIHRs#?U;)%+Yc~5toC;Ez0|E= zpWKC;(O>C3U#RK7H_5=v1vC7z={J0U+&f#F8F0^DpIWUM`GN_0roEXHt-6_>ZKgT3 ztry??1N{W%^xfIsnG>z^&Ac;mf3?@cH*n5cPw3mr1D=7J@Gcwn`~fduZ{2TkqfK~$ z13G74-TbBB`oP|E_0?rj-5u38df)PVdi|5UbJjg~Sc3~TaK?TCx%H$Lzu$rS0`nUd zCz|}d)1aU6zVF!4Wp(+2@}xdsMC&{3FIeD!33_k)@13frVY0PtB}Q7bE=6rPe#)m#2Q`F{{48jke$! zs5hstet>(e)%16?x7)kW3v&ma>Ai>ANZ;MP-|O|<5jejM#rE<6vOA8+={nYVx7%}u?i3*2#0F%z@*Zf4b&ljYvb9cJ&$H)>vBmN)1(v>6mwAGpx4 z#;ex(&8!RDb=2|do;trrGT$iHXSisx#(O5cWew&Vto0Yp_`Juwz+LHGYFW^*M+?25 z%sSuE<5g$Yr{4H{gKGbLviG}AU8rw!=6O`#u!n=zuu#eO=eykbexHwC=H%sQf4n=a zMZZ7Pc=P`F`>D-Kc)R=czQH_L+fRMz$@~6!XHmUn1xoUs<`1KzJ%Z@5r!K6OvMP}yLX z`wDNan%UcMM%U!t^?rl7)0qqT9rEU$ywG@64rGVt%wCOds8M^r_syG@eY~1GlbK&~ z+2GE+M|J*?_w@7&>?Lz1=W2RS^{}Gu9)D5EH#A<}p7F;T)aN+ZOFZaKX74>x6Zukjw17nb>lBMr(clYqx5j#?HIq&cjxDwO(rGW;mjGYT8rnN zPJPh>s<~$~W8Sd9o?gzNclMiYvb^lyXxqw}skb-p1xoR4q3xF_#j*0AAz z{OePn>A6GS&dm9|%}ngqc=F#|L2JSR#*x);A)NH|pGv4~W=Ue0VhsGB~ z-RV8zoV}xxUv$A%>F0N_)~j;B^RGX9%#-Oo|C^)EuF}H=`wa{7o{g86v%@{#w-mkK zF5Z-$+3VSd1qT}TkX?Cakw5;o$NBHRMg8tU@ARzQW9@w2;p}@Ky@R)3U~Zi^{PNl1 zjQbk3r#h=&aE83gyuR6dzq9}8QM=cCgX)f++Dv~&@8qhPZ%_*w^s2R9&F`Rhw|Reh z*7nSv_pPv|ei^>NU1H|ldot(s4bg=+&zCCtUE?lp@#`|6oBckX~*-QbM6;N9yvb2Ge|95ZIk^*!wU E7yZ$dw*UYD literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/golden.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/golden.bin new file mode 100644 index 0000000000000000000000000000000000000000..0c11a52030ad9e8601c3300b9fe3aaf87f73bcbf GIT binary patch literal 4096 zcmZ9~LDA$o5CdRC2!ucg1TTX?FoZzx@f`<3pae>wtiGxCtEoMxq?Xj}%MfTX>eIdjxX#%~zmD?}Z-T zf=}=y13i3%NAO1LW0@Tp=5K-iCoBh^tQGUtd`ABl9Kl$A_Gi$4vbAf^z5lz>vwFVi zF1Xt(qgS8a%Xg>!yrVC>0=*gVdOCb~qFZTf=Gz&CH{KCF16l{*>tsE4&5%t=+h*&(nD_+tbB*GV=)U!C@|@d2jL_thxQu zF+&I4@Lr%>;Yqf9Ec3qHOnYPXM)Voz!3?h8%^f|STp!GQziThZq1F8DRsRm_>yCh) z@pN0*oo_$l@eZ zR++B*pYZoq`Sxe@_f}XR_+O4bf?iC6Yy@_f)@gx!Fa6}s9hk9f1aGuGzVGRQ`3!Wx zjAgB6eeux5fmzOL#cnMBjOT5T+qnw%^vtvy zPkOyP84YgwjWgH_EuLK&Hm^Il1xIiOPcnCS`oMQ<1o~L-4c4AF?X0|4uDg7i}m3N_(0#m+L`6`c(m*jc0(6>^y-gjcg)8>nrE*;?Yg%-w{-gI zW%m7Y{uYd&1xLVt({Tgq@bz?b`M`r2Jelv{jP8N%9>{svyL$Y-y5ToC;$gnsJ8X|f zuiTvl*{HdPTXkpo_PynvE`5j1PuR^F@JDb3J@^Fr;D#QJd$@)9ceGuzJsM>Ct1C0l zVC}8Hp(kFed};H9WN%^5Y?rt6$lQ3r6>Ndd5%l0r7rbSb@4>pM%ae`L=A&lawZCrY z$$Rd))1q70o>|T_??M03cXA!s6F$TGpl6mH!4>p$R%XvPm|L{Yr)InEVLMy+uFhMt zt{wY}=`+yX%3pe~*-cq1%&;B){=-+ap7(T*U<7*ji9c)hJM?6;c#qawOuNwocX12c zkzK(Zz5n=n@;AWte?9){_sY+7+`BV_9X)ym`fiw?phe?ezD!4slhqpg}0aH>_ AMF0Q* literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/input1.bin new file mode 100644 index 0000000000000000000000000000000000000000..f2db4f3519cfd7607d2fef23a76466f61969c63b GIT binary patch literal 4096 zcmZ9J(Gg@R3Y)LLn4FA(ZQ;q^Vc?p^_}w#&2efF~|7x-x%vJK#P(3rk*mn z{y>WrcJ|FvZ!nV`_@;JF-9EFO%)QnVyf?z%{NjlNCfFbJR(I>Wm^Aq0Sy{iZ6UGzzWuChLrT-|md_g_*qz?|}Qkj-DSL zXlLx}Z0F_rzSZ{@y*vHpPs0<~)BDWoJ*GC>aCh6UZr}Iqll}p(%xqX+4A6R_Mc?i_ zKR?-mPxSuif&SeQZS5Xs`d!Z4)F)hTQ0w*WYU`Wc?ri4XnQUcN*4y!AcL@F7^1Ijn z2LAwaxwSp&8{Zw~+n)SaZ+Cj^oxnL^-v#!R2X+04Ucg>`Yx#g3UhvE9$fPklW)trF z>gH#MH}dY}94oxp5G>)E3}fc^zX z&CG<09lgxJ2;Ljf_v3A}eL}zQ$ky)P@KA#pnCmy!*Bi;N8FP<(BD;m0C&w~-fFJ@^U(SyeBrGQM*Qc|H%?1dwb|AW z-()B3E@#K5p3w`||7GX{`YoNiAAh0sH#~v+?f-iE2iRjKE_$Rkd%JVj(g!>F`yI`| z#XHHSIrx1OdEeFlX21RPY%<{;6K;0VM;{C@!K-t#^^ND7JG#F62AAC8Q;1)omHF5)hoUF>fXQ6^OL<>teL&=cc1#;iERe#tM)B> y^wh%~m|w8KiMD6z(Jx!~E&JY{e|TGi_V%rd=X(>j=gaSo>>gwXIIp&E&Hn*cNsu@cGa?-CCA$z)yX|$c0qTmm*D&k>z{@XxbbTDP8n>_H&yrE z^_fkyzUbXQztNe$5#B(4zlZF^U$}Z_?w*N$9NX#wV7C;yH7HCJIlA~nU|3dFoC(*%k!;0 z<9=s{*#ZOfj=sZ8EiU?i4Q8~y_paz0*7vOUO}s4Z$bFk$*xA8IcHr)DroU#<&fGIU zUcG|@_x^T!cfI`u`X2Q|--^FrC->7Z03Ib_nwLG?u`~dzbn7h{vsDQ^BFcfK)%iByTiSM%=0^Z;2G^q zeZl#K9q*p^p6oE6nRV~Qdzbv)`nm6A-`MIiQ!lZ^H+e$$D|Y|IYeb$<_yU zCUB2h?zgC4FdgeDi$2LAUq0d*X%p2Jg)>XzomZz|-uy=&fhR`mnqE zo8JA*WV@bu^JIPJ#1C*l=kBeRH?!_}cjgUdp1WY$iI(5}oITZ<+Z)MtGIy$-y^Ot{ zzj)EN9<;fzH-MRQam~VZoNe%O*v#IH)(^DaT(90hEly_62fE+J&iAgR2XCjl(D##V z^a&=sfY}B0YVo~e;SX@{hWie=xgEW*BQEgF3?CT`V9(rp%*A(~`Em|q_sAE_V8kDI z1Nj2G>IF6!U`F?i`kTCOKKKE4*ty=fG@JMpID2`1gZ_SI?>FR|+3P*N&+Y-{@3*CA zCKi~{vVIi#Hs(8-d;;%vCKkBK`&~9(=sS~}nNR27`@QsbE;!-$cJu(w3UklOW^y~# zEBRwJuyfS(uMB2*!z2EhrE}E4-)%m?@7et1&Nuk}A8OAZa=XG#G0-bsZGXbYoOj4? zN4(ia4gGx`J_%-~%adSyNH`_r|f#dGsjPiN2m_jdHw2~YejdITLF{(9N3 z%-)c@*YBnLJM#=~XK*X$xxw`_+r4o{^{@s0Ui9SPebrqz{(kh{_;=L3jyL$U1^x`O z9?tS~jzC+O$?Rr|Gi-(38&~v=>dP1IM(FYN^r5#${k%_Rof+J{*(smuUqRnax6r}} zUwUq+8!E%htYdcD3%l7nd(D0p`wUmjdfxD^+}k6{H+%Ca+`8KXzb9Y2?C-58?~n&? zQn@><{LFlZ7LFk6rFP$6`?m6UGB>;#^k(^I_BXSfzen}WS5(iPp82<@%a`-f!zlKs z?8)67J!)TX>t8`{y*;a=yY9_r`QCEF>%krDtG_y1yfgfD)8W6>^&R_HxSg445ANtF z^%kx$gInr;WqWr0hWI@?!yZmv&_BxamX6$eBf9?m|7G#qy-^+W7G$uag)O+JzM)U$ ze5~%u?szTub`LX{Uvp(W`s*LH|GwpC_NKSwbOzqq^?7sz*%sEncY7<}&j3A`8#2#u zha>Ev2RW6^%p>|meakn@Z@A>O^Y+@G%y#G>fk$x$-7Bo0SN9Dw$hOdfzIg<`y@eT8 z#~ZvWj37Hh4}YD!DVO=iUbAnS(Sq*rk3K&gy`lto>y`7r z%7q&2J*0p10Ei?AgFCFt) zEPr+0%4c%#J>keaLJxAf7wt#QzAfuv2H&wS-~Q!|QNC~J`A(~z&QZD9H}!jYZggAe zmD#_=JF4t9OLEw8uxweg!@K z8Q#k1o}KU3-lE{UoAG*d3q7;mto_TrI%m958Skl#|7P}{d<)*0fiIVt-L=5`vg^W5 zTf7<8fB!w6+a7OJenoZpi?{CFuD5tk-F3Ht&ftwBcxMFNFD#vz{S5kiU7o)As(p{1 zL9f-UKQim{m-g&4yKE0K?|AP`_Eq1!Ja3!zXZ%((pC?;C=i2ppy|E?lLErq<8*WB; zvh!p!3Uez~_QdxFe-B5^b~k$b9=)RaIy|$U8Pgu#AnWn4?s$4%bVS$B%Jb&JzJ-?F z8GTFn=)oIw1b6znvvl3)`A!Q@W_{Teb+hheI(IN1(OI6EcSU>HYB#%~@SQ!pb*F!Z z|L;A0nf_6 z{Qc?cd1u`{$#~wx3_UyV&5S1;*>MH=7Fv8W{|ti{?!e!}myX^W^z2(yPsdK}dl!7dbHR6`)^d&-5x!|6FSE*ZX^3E`8*LOFgW|(zeb94{x?0mD{{o}XnBfGcs*Z&iE zGnpIrn>|o?M}PTdc`MJ}!x7AUH{SA%C+gNNyTe(`cy3mA=@swz>G6AZ-7oVNozYRA z%$s(1SI}!wK5wg;UcsB@86MqmX717D>F9n@Z;hJu_AuigHE+?y@_X_gzeV-Y!WEv% zX68G%8BuTO$z}E){4V)=v`2Z~YQY;vshht}w3q*-Z|3Wr;f=StRKEIh-VAPhi`T-s z)AcS-_XsoCb$KJ~p@%2kwadLDE3%iUa| z1vh3H|Eb(}wjk&4LEoDr3Z2#0g^qp?XVBY&@A8)NTF~X`@~Hia%13yz>wQrj%pjv% z&8_A=ik01!pOx#$dyw(%i}gEh@kW^WuAgb;J>C|MphM-nE6_!ck2_wA&hQ0qyiq*i z$bJ>e_l9rEFoW;)^uM6@gi)TZ+-~;X40;`zW&E%9Ewky6@!@~&XMLr-rj^}CV7 z8?v7H2zzj2e&gHqcq7bsU$DzC1CK6m`)BUiTX}lk*VP-fQ{VLS$#h!qw*3lpW_M>F z!43ZoPi7pMb@;I3OV(@Vo2l*@+^*lxoqYs%a;kd{@xXXd-7ExdKVyrsHu|D`jdBkZN#T(>-TeQC0zX1#UOGq=j7er7q` zy3t+m%p>?Y{LXZ~^jD{4Zt--`E7Mthw>rEn3h!Jst}9^UkH<&=Mh9$M+@te*FJ zx=-*1bli>Fx2PR@n8BO+cHLe()or!Y5$@6>D(Ahq-J@-3d9p3&$?wnukJ?+!SCqeg zZ~8iRKim5G^!0FtFPZO-%x-3>9`6X|9`@iJH*fHUnI|9Bnc45kbXsL=)_=oMzP*Jn z%<9j~{42XVZ=0dlOZRB2+0C8#2xs{`eO`-hQMl20D!Xcz@x4#6JnDOIZrmS%uJ>o= z9^BCNwhnLkTV`(@(GlG5QGIvvo>^bEa($kj`3iTKb?3GPDj(qtdvLSnE4vO_updFs z-qT&*;F)JI>wcLVFY^}O%qUdi1IeUuTvlKP%g#Pkj1D8#3!} z!5i!EiOO^#*PlU;E?@R##}j(?qwo&j-hx{SJ^L9(n1OHJYVI{3wXdJgzWTan<@&I% z&Xsv9_NaV$v;SXj#PjwUR+rbp>K~bTGw7K0*WH`9czfumT}S4|e1)UtJKC!EMP<4( z=xp-fy(iuo#T7>Q`%ZLq{49GRpTWD!^Iosr8+T~o2xrj6%C>mkp}knWBl{Je^xr^r z^=9TC?#i1z^B%s;e|hFNvv<&fj`@kdMXx~5plhb~FA8t+&EB}f2s)5sR(@rcZ{^9} zWOf{d{>&`@ub0{H*zSA1^f8As`xe}^>{M@*w{G2C!8=QJ`OBM`N9FVm@-6t@9&NQ# zZ_lv&ky)>onp<^zQ(v~!&v|BdH-iqZ2cDU>aFninCR=~DC$sKM$9r$)U(XF6Ti8Pj zeDKi2412gj3;db)R?Z*6_jo!z9K{};<;l)4O5M0$KHY=&^~`v)Z_y{(OXWO!uiT%X ze+08TxgKQR8(}8%o-X~u8{}9wGBa-m{ZT4wnceBM=o#(d%X~%kc{)5Ze+zq<<=f@F zBj|X?E^C3mc+!>gXJOX$XTQQ;^r-K+r45;Nsoon8BeULFUXRZ3+p&-Eq|a}`TQhqP z-r;@GE2?AmuAfc!3^R<{&F*A+tJ5;i`1)>Fwr5^kng4px5xk>6gFbrc7WIxd-SD5x z-_Om6=MDKU-qhFKqcc!G^^W&@khPkxsC*_rN=JF?Eq7<&>)3l(e`dGdzzn+m=RJMd z;0ODY-23*o^vLX1&#XiDFpDkf=1(j4bDY6@x}#`OnXdT=Gab7wZ_RtS5o4toK*y#*U=2r6^)uBhw z(e=CH_sn+l6?A6!;^{t7^qO_HAmiEP>%Akp+c$XAOc!p=tZgSYj2pz<$lnfF3h zkAG&?ZIw}7{t@ig&pb1u=XNijj?lsre@4C0sw3Nj@9}YzXP)t|sJutd%HHfPGtcZs zhd06u?&a>~da_otzK;CKOxJ#7-h+Dzo?Q2gj&PUq(6Y-YjdCgsW;}iS z6=wPCzGpX6x$avxyiuNe?^FF&zT7u><~#JCUAFQk^LmHpjYZG=#@nMh{%!vK`;52z zrC+==IrNXhF0*@Y3oCcqDmTw?XTQR%p1!*g&YD|#d_8>G-ODM@yM9)gd-D;V?B3$p zx1j5$Co`YrQN3B7uJ;E&D%*?CZ8&5_i(bLajOsw=tn4@E&*Up}XZGG5ytjotc+ZTU z9?v|2KkHv7e$RY_Gd#8HcpEeG74-Of?QVG9YBlTgf7QG4?T49prmN=-$b0s^_A9!2 zBl8hjdFw5ndEMBr@_N+nhG*V_?g%TlpXK|RbiSycW$i5<-_0wFx_dH@=oKjc2r~S3 z-Rqs@uZ|mj&s$&c`&^!V207N7{E_`A^kqECch{ocJHwm(&(GY$R{B)W-RkJ;LU;LJ zx^_98*|+EjE$pSfIcwI1+`Z1wn4JLZ7Ec@x<#@(P9$ST*YSy(+d2qLS z_A}WR)p=^?nf2c27WK}|yhnR^hdJnDD_uVIuFQ?M{H%I>x$Fwwm1Dih|H@na5wFLS z&8YWgIQTVhwYzn2y=Z!`xdn?+<0TDSx>&^vvU6O^jr2BmCwqys9UQ0BzNb} z%73%VXs=mE_C?KjgZne+x9r}*%*@x3&G3}x9i5}NqWX53UEbr3^1SoaF6Z0z{MVLK0HUItWH6PI{ z%yjhZdfpkOJn!l8TUeeO`&pTejA!0z=AGe9{|Te~9<}SY=ofhU>xQTQ_irV$zU=+c zz52W@^Idt*?oRFv+QS)U?L6;&>GHSWZq(eOa@iB#>}K7`Wy>F#XW(z;Q=V)OSM9!O z_q`T{?&?tA)H^cMGwNQKH{-pP>#ttVthe6KKZEXwLdJ*g4063Ey75Z+_8vW=S2zRz z3*L}>12Q|@*s=OlcFD^h(LGSCzdJp??oo4#{`xa}FZeU>+fQC*`<}T4`elE@TfSb) zJSvy<%vbn(FKVA*4@da>HyOQapSjV!g87WX;fwEG^YUpAGn_&G zrl+fSg;5?|_rB4x_dsXOZtd`fUYFPWMEC4?GyCRIJKe%p{;1upu6JZ^DfIQsXH@4Y zpSRv@@iD{le(&G9*^@)B2k-08pS)!s;S5iGL)T2ds7#NyNBLB?_BZnt&-*RR&`Zrc z^z_y|Gq1OKZtc5Xc#>~X{>u1Q?cUNIwZGYATY0ipGsPLc^m^3$yraAkbGj}#W|yz~8PD8;&KCZ9QMYEE-0fbfYu1tDtbIgVkoPddsGX;Kh4ph#{`znG z$?m3wr9GbAE%wSr)V=rQz4(4S*%8jL$&*bQMvzgtyN1k9bdUBRpFu~?Z=r{oo#G1X zy)V8$lYErF|8(?v-STy|z(2zsuCSHrnBUyAnt$)kc)m&Xdid?RXG1v|eLU%aj6)%6B` zd*0U5KPz9icHI%Y(Lzseb>&!pk6U$k?vBzs>c%_X;Oji`%roe}*t!BD>b5H(-BQx#M5nARatc(vmnU4DwuJFd&qcdvf z%a8Eqmp7vF9?r1dUw8I5*w^eW-PM;J*|$JfZr{s4qZq+^OLg^*Qhgauzt?wrW?en_ zS@f2tJ8R#f{b%>R)u(&1EAZVv(KV0k>wOAu^>oZvID;Eq_j}la-3)!Z8{IGHywMi; zcj!U3?sU+?S-Ltry(=8So%iG;$Xe*}o?zyAXWaLy+?5n?Jo}uT~J)Tlot9;#D)#WcwZx45Q?%yy2zt`R~^W3$lyI~Grc0KzZ zRz~%Cy=L7rGkxpEuB&%drVl;-(j)Wgw#>3K=t8#i$*tM#RoPLy9{4=i_i)y3?zx>o zrgN39ePqWOW^f~O!*AiqOncBf!WZv~zLl*mPgiFT_AQ)Y`R3I(%XNLzyLwkJ&(gnK zr*^s95zO)<@XY9?GdhaZ@h$UJ9sO3b@5=l;H@iJ6)87K`NyeX{XSVmuM{u{-th2n9 zc_xDn-@R{*!0S=|7WVMQYo&Ko4!L(m^bAMf&){Z{!rp8DO4on?Zk~8_1pQvBJ1gVc zb>v$q@5<~wzL_Uqf2N+jl|Q4hJ&e*GZRyx`n^OXXc|+_J-Bt zkD%{8`5u+q<-A$`R@oKRlfTLITNnjT2VdsZp-<4qR-LnEZ}67R+Sl)FRCYwYGlN_G z9xfVhkAhEU(EF0%3F}S!>Tc;k*Ug&@y}S`MyW!i9um|(H*O|$mx|Qj)e|Dam7QBa9 zv#f`&{5A7eHZzalXE3+Y9`zPqZhpd9v-}EjH@@Nbv4^cZ{h3*Z-?O_hxlN%kLx7@G2|M=eI z%hAFS&eAvc-WWf--ukxwU70)CR=dn?k9tSW)3Ntp#u<94-Q87cf3mOpExTFgjE*o% z?`R7;GT-P?-Q{2WV4p#E4@c0G@z(s7r{i`Ma-Lo1DZJA{56hd$)|+yf-L38n-jMNO zU$cANGwk6APrR=>l!qR+z|(1&->CZ@HUIkZFP=NO%)H)m_lvh@I#gG$m6~OH^!%Y^ zmZ66&cvFw}e=7i>e4Up>@AF7 zze~^ZR$s@RZ_J+#|B2tDSDt<1@8_cxUpjovvfpz z@ZJ&ibKAYS{9YO5!FSH!X7#SjGo9t{_OJy#IrRgVe!_@y044f}QG^^~|$oJI>Oh)V==Nb^i72eug{x)cbBa zvQyuJTwi|5gB`x{CinN~Egf97>+9+0dTWcq%@xkjYp1*~JwA2sPBz1vnSWQVr{lf7 zyw#Vv8^sZIKcmoV;mW*+R{F)8fqLVKw$Q`K?j4z%El}{UQhBTW$&8h+n>SuhM{kC+ znDJ$9T+bgsag zfv4l$FDmcB8?srtJe@OMkG7~=J<9iv-Tk`RGrL>6{F0Y9qpRZ$w_725gZ&C_bp1Zy zEu9(tqGz-R9X`+P6K&=7%tvNj*#9u0tmEFkzNhPZE9=#vEsSsm zojvLu%D=-`{*k!_*&A=G{VdOpBgoy!X7sAe?p<$eKlGY)WqRIdQ9nPO<+;1cJ4#2? z_pn#)PQKpM@kT3ey|ZrI@#q;^-g0Z^>CMXQ-hZppqr4H0@~Fq&}uG#*{gIhBl@!iY4F)K5#TeDnmd1jt> zc;Jn2*39dH&hXpSS^JgQt@-%UcR2&8O&R7 zgQ*+-@|yWN>n(Rj?R%-rogSW=b=P~|vFq-oSCrptm!IX0D1RnDqPxB3C)%=O7T(g8 z!Od5*+qY((nb)E{=&rY@PLJ|Ou}9Bv1s%DbzgzF<>TE&piKoMd8`+F6*S$)0Wp<39 ze}IkdwAk)(H`8H zdG@3HJ9>pLKF=&~*;|yi2f5of^w5Ir{ON3&vF=t!?+V`a))~Arqxv(m8$Q*6Z0-6z z`x9NW-QTgE{@>qs?cS1K;mgeoBXbMhkh`O9<@!4IrORvOyZwUB>n6;fK7YO2Gq>7X zdhXoF`MR@wUD*}w(Gk_ro#CnMiyqnKd*Mb8=9$hJJ$^jd^49%p6L-@(k@ zqIY!l&zf6QXYEIJ{WqR@t9;h(ZJxY^)$7$8mFqmo&*-}OvYT5lUzHuv9%ej@@_9Eu zc!#c=^>1e5=jI7!w=jS4_Gm9%zn|r=j(2ed-VwZqmRZ)ro83IZ9`4e!bk>b~xLNa; z%uVBGrZe!L+IjkW(2?=p=-^c*U##z(@$|Ro69sSicg^0qGOzcx^7V9D=t0g~Z}MdJ z89f4T3x2Ox?J}Os+`=8)Vb2Wn5j}&xnf7pn)t%|h@@0JJKG9KLtC_#@H=eiLukMlA zJ9{wm@90+Oo}m@1L*2;TH0VTS!yL9?r#+luH>W4R!U%l7H=bQS!^)23NB6Ml1pW-X zuWsC3wYR9Qo;UgKbY<(CPxe{9{woi9Ix=sJum|2*vkq_VEweXh4_6pr2Jfz(d9Un< ny75+vK2dkR%Rg85w6rD1<^Ngz~Debk_E~`%z13b-tAm-19=NHO3TECV8!s1Ta(y#r4)wbeuNG&9kX_U{>&`B`Y`AdZ zoger|{CkfX^+K)PGx6IR1{{^#-o;PY$I6NFFN9Hx&^u4d@(mlL&j!802j;2W*=Gxb4x$4<6BYOk}Z z49UATe!&I(h+2EsG|Oyz=7-#wzDYl=VZe^+J%7;|^kzM)F0!Yezxm0W_j$%zE!Ojs zZ+zc2@X~DIXR_~o-R<0HAMLDt{ejN@=6I{$VSj<@Oxe`eUgn$;vwr`?JAY8Uv+bL& z=23g*1JzG(mu&n3^Osq(ui2#D;X5a6(! zYq8!ipmP&{LH&Na-Z}Fn?7OM$nXBFN-bd!lh68U$-cvuw7u@j8J>x$66I8Pe*P1uj zpZKzpTMyW9Lw8v(=zHDcoH(Oi7|`!rda!?zPjn>PdU}1&qQ1%O%@5|*&OgO3@T|S6 z{pM=F*L+YrYt}bz>UW0hd)B$vLl3+&@y<=WGxn_29dGXZKqpvl@Quz|FX*{}pJ1*x z->9@!-D%dh=nuTT&6#Po(zE~0#8_inZPp-H3cIOV~PU!hT?HlZwi8FoI!gt3t2m75h+b}`zxe@O< zvlR|HVZaS<58ix1&!}e34RFp{FBcZX?Y&*(QSu-CKAetpiqEb83NxrGj} z-^{wZR`Z-a)DzSV1D4R&`_`kvRCCwHJ5%hLjnFrli?y88)k`h??YBO)`2w?^wRhmc z1Z&^0KdN`$o%Sc3=tVcE((iM}q;5|y2b{6rZ*XRrgZ{vbuK3QHdtQAWJs)fcKh<*@S+hXHW8xTs524aKU=R1kdQ*aZ~r(InM?Rn}0=I&Ij z`=;+sb7|)6gbnpa?;dc6{qDMW`_1($*+4J!ob|L8_Pn=O$?SEHGy03}uuyaB`V;T$ z%g{k~p?7p4|GsSmySh3|Kpjr2T!(>v&d zAt5D%r?v2@9or`yZC;ObMES1XBMavn(NJ* zdHzt#fCatBeLK15XS{v$1MOeD^HYW1XKp3~`n~F*?)$I3;MwNA>-9L(+OvJr#7lRZ zZ)#^p^rBzPC!DaO{RZm`6WsH5xi7x?%A9JZo_N*yi;6qeN9|eLJMgXdr|!8)_TG`6 z^TqSuGkU+JzlmqfZt|}lvI*`s*T21}^}CZ^`-iN%_ov^x2ASt4+WzWXGJBnO=DSkJ(GBh~o2kvLFZiDRrq&azJ=6X$D}EP# z!i5bBoOjp3PcRp~`}CbzW^iCa-=Gd_*wA+^eBaZzSx@o{p3|FcRDYm*J#(q;t>n*> zci;Jw-_-37>KS%>sQrUpP$#_aL`Uk4c22dn_ww}g{U&GZUpQb--`~f&d(>qQ-o5+= zy}x@p-#OJ8`!n8|0jimIT@|{ozq8r-oa%2x;r-Y7yF(u>x|ii z0qs3+_I(T9Suq!PSTA_WZvBQCZh8YEbo@1{Pml2^NH*vnb< z_g(HtueoRCfV~M9erJ-sdAxJ=YUc;}dn4*Qc2vLc2fkT<|MyMK7U$X9{#SeN&8~@x zGdt()nGNWk_SC-FnFGDwo_nlcPQB#wjbA@AZ{}I`qUP0R*0SNi0`Hmn>^HZ*@ZN!T zrf*zi&Udf#%{+HL^-isJPk&42X8Je%h4%aHnIACsy!w5YwOloCI{0A?6RcIw>*dX& zyEW>}1?Tz(`_}3L`(Nx0vVK#~IlF^1-gh$Ryezn|;ecnId-KWc&2KOHLi;YydEU%f zy}Y+1d+&5V^-ea?0XKZ-51I5m_MB69dNVrE`u?sfy^X#dsAtWsoxh-p?`vj$V1wC! z3kO!%==&!0yVn_e&Z+jAsn=fqPe$sE&d|QSzVqO%-POCc8MWAV_9=eCfD0R}zjrq4 zKI`uJX1%NTR1fp*>1SA|wb*z5;>B#ij`|++?sUHOJ4>&%dP2W*Q#+^5;7q+e`PPg6 z{>)yy_nZ9`KcT&WKd``k=2d&etUC_*h6QHoP2ICIwfTVo_RLkYw=?>DcKI80{{3B< zS($6yd%MrRne~pheslSMJm$eUd(PG`{7vn9^m1Uqj+(tUGGlK--`{TsfYev_GH0KZhb<}UuqdJ<2Sm(MGs7C?yL@K z?|!wt)aDy*^q{`E@3L>!@0<7wzwhS`XUxtgcYagbYp$Bfga!RQ^efBAeQL#SB)<-AaJp;Y*8|`mn|D>)C>YaS1*Y}v4$%W?b8`SEKdapC>9sCRn-S*(X z2>o3zwYX!W2i&vJ&P@D(3p?K4!gu!Dcjnjw=bm?G-h6XzqtByyXH{9@-IKG^4DVZC zA8&TTMMwObN#7Z<@A-oc*zhvcOeQSY@gx4C`U4AIop)?%-}&D9@#fxVe&XNzGqa-3 zb++@9tbf;=x&9Waec2z?zx|*7NVYQT&c5@0&Q84bfWEPry6kb*zBJqTV@6NkT(veo z=z^QNZ)k0Q2k*8XsV~~!uitI9VM+eR?wr&E)c1|qwW(J!d)9Jc!35{6H*}Xhb8B6Z6?c8#y~r+jM!!?f=s^dp)EhNxU3I6qjPzSC>TQ-cn%_@7@V-l( z@OtRW(GBKm-+oi8112oV9oGFlFW!15bBEr1zydRC&rkFuv+f>etnHaO*ZVf!{KET| z&pyu$@|B)5%?^GhALvTme$_d9t#4|x9Y0a?6V=NA-?>oFn{6<+em(o}^1e4`oOjmV z)TiHFZwL8-6>e&^41CWl{DciBYSwz<`#XN`bEa=L?{}S??fyaS4)Y~^i~fG--ZztP zcWz{6qWjPMN`2C^Zg%zg&CI%g>r;2XJ@Z%N55C`}w>Mz~{Q~uQ?oo&2?{V}AJF`+>bbvacXWXkd zJD=X){z%U`>6?1bC6m_gymQG8D)xJ)^ZhRU0MG5{LB%s-eyzbd{j2dO`AoJ`uhgn{ zxu_VmJJ8&xx>0=++%Nox#){?C%HPrJ?I9r z3(k$`LcM3ATMr{Nw|795p5J))t2bB=yl0)Wr#~vR?z^7Sx97ZiVZw;FcE^rdANcLZ z+Y@tV2kP!$jeqa%o1Bx+UT2-T=nNZ`0Sk_vnTg8%=({)P%@?@u=z9--ScCb6_pU#g zGqS-MamPeQ{L}b@?x64QI`KXCof+h>msxA)q~0BU|DnG#`)24_PjIF}Z-_nR($gWB`f=TrCnW_o8|AKl3Y`k=q?GhA~s-q!v02RQTIL2u*DozY9*Xg1R8 zF6#?3EcAd{zdwHJ|L*zB+jqBTs?J|(wf#lyjJ5v4h=2Rp(cj3vy#r=DwfoK0i_UoK zg_jKnJlFTUb2;nGx1Q9^-tU$-`zCAq15W3nH`=#b&#=(fXUD)FFzf!R*@nJ<;D5gv z{=m0y?~+aD3YodJdyhHw(|CImZgijMh^_ocbD}9=lh*z^$YL$&UBakow`K_9nAA6sjGucLOFthGEM(Txn_M&>3Fyc@2)!e=9F+ZdB zrF;85&Y0P|!QR5lgaghDbcg;f_WLgVf&;zBtY*%*vaj!_~zYj zt#{YOZ!lX?eZRTc`{w*^W->q6-_+gfdHVc0P6$q zjZrgu6FR%_H?{tex0%aqj_%YwZ+``Q=GNvD25eZ-W;eC-?KkVU4QA$_v(B5Z^z6Ui zH(7u4W1g)0%yx1!`({;lPioQIYxaIee-|>jDl_VWV+|8}ws+k0tl!_=I{Q6W7Jh;= zX1(*`rF+erxo4)o&=YUIb7myBo?uVh_q&Vy1oQUa47qyI0UKtp?!0H*_cZfjJwR7-Ee|E&oo=qH$D9X&#F8Aq65?=^i}u2x8XMoc)h%%{*cS} z+@!wn-W~sD{$XyI!3O75Yxj+)*@gEWX6%_QIG}Isj9K%A|Nf3Qy_e%Rm|6G!Z+6J1 zeeTpR*zn$hn*Gf$nGBe)!P+y{zDNIlGxjdC&^Np{lf8M=nUfjSxoeHK-qZ`s#hHO_ zm|(V|)?zK)*E6l}r{B3rHj?*m;rur9d%*rpZQj4Z!k-V;6W*D8lYRGfw|%pnIeVQM zWEcA0cP{lry?3Jvtku5n;P0p2c;_GH*5J&6-n;ScFkfI@tS7whMXjAxFC4J<-brq4 z-goQ!-DbVBZ?>MyY;cdhXL`<_a|4y0v9Eu<^!4V4InQ z!h#v~oOR#u%gHWy|4?BG^XB>udQ}!U-x+)Q*Fz_@SgQ+8)XcNK!%X^H^ldjY16{B) zbI=JF26$)Ru~RQ}>fwU@gYMSojBj7{Zn3u{-#_qXz2~NP&<*yaZ!sURgXb=OLcjB* z?j6?R?7IiQTcZ=4U25`+9+=S`H8)!@qShz%{bYTcy_a*xb}qlg`+HDNaAv>->+kn& z>No5AKmGT;nO&Jt-S3$L*7_U>zaiKW)ChMEdr`8Xc zD(5choZ9bv?;&&Uo9oTQ*@bR6&^KF~PjF9rgIfP~qrGe4+xzMyKk0qH(|dmJ&5U#J zJ^hWSH#oP@38R`fU21#o`}X?2W1oGmX2yAECo1Nu`2lCGck6gF&n?uMoBG{LeyBIN z&z>{Z(sKuI-|RwXMrw71lT6*vde~>TwY)l+`L8wuXUx4b~((znnd0huWPpa|3-d$*m8l z&d<~r^_)Fvuey?}7rfJ68D{W$=!O$@R+>4}@AiD(yE6a!pR)(w`<&@J)nN_xo%_8v zzk$68JH3J4WUBQ^y`t7K;D)}@Y=iYfpUgK*yfX*%YR|UbWW_Vx)9ifu7vFDl&RTUw z-DrLH_#2o_=-)@b;KGPnduB$PpVao=`>eq^d406^*xTT&=*3*F>Mm5w)Ej!{v_`wj z%-PQDc=O(`H(&O^{91EoXUr$KXMNPU^7`ow6;AZ^IBWep>h6KwX70|Nn`}vH{Q&do z#J6u&ui6vy_vY-e-e6A-yx%J7oBFQr+@RmOGw&?<1n+$4Q@hum^UhxQs=3(^ytMyv z>YjDR{i-wHZ`{e8?GCdE7rt*B$yM`%inX|-cTQ@1&fTcB^d9|&8T9WRsOJVM-Er{; zoSEphS2Eu>sP=o;z#r?6UTfx7vUhjte&41iLvruLw}16uX5SrVzx(mq9=va&fAnx+ z#&6W~Beg2Wns1ZtzUg(w`t3&b^TWXpFkj#fRm=uVIIy92nB8Rchk8kVYklABIoV_r zy2lx78DPEP#;Y^lGXrndUDn^ey_3E(E4A;iXDuTfYu>Ah`BUaS|IMxCIv3sX3+*2J z#~e18x7XY~(_ZNAH%r!gIdu+HJa1ag;LWZcob8)c^9g;|`S_kyJ-gC# z-mGubo84gDZ_}UT8y&EO?|7Qpe8YtU?zZ;4*?{+U)O@8r=(`U!x1R9!=eN@Lt?E0I zna-MZ|LZeny~y3a(Fr$tqN~asJN#>j~}|@y%4TzCr&oW_SFn@dwOq@EtRmxr|TUcd3_rrSG{@ zAMI~b-R7M0tmmvfW3Ae|XnzaU+TP#(B$MtNc;|n$ccfMq{R^G5*2{r6i}o$kEVVm} zvllk#o!6_*n2Fx}-Oc@GZzej?3+DB`V~|PbZ~Tr7o4xJn#%&S6Yu$v`ukhx-*=(TPUrLcEPT(@-}K);-%Gv0{Mdu{9s1lk zl6jX|`_5lboskVwpLab^?Kykq1HJHi)^MobUgo~DX3jQSdU$U^rye@@a(9r)iO$TZ z1I+actkv%DuI^FoSDmwe(FrGNZta|V^m1Xr0Q2?^^-T8e!f#k5-!Q1fo*Y?i;(ibI$ghD!*C$fRjuwH}ygfsM7EC&Y8NqtZy=B^)kVFL*Jw}>-z`& z3)bb$iH!>KXIi zt6s7L-#K?~aJFhTq3<$3aFchJXZlwA&A+;GPVJ2SNxj2G&D__Wsxu=vv!QRYcR)4& zox|U-lB+#;@Do&7u;XuhGjnx0UL`G-Zl({oljr^?$;{_f9R zo{^K8s#)jyu0iH^zi&sK={>#E^ZJF(VBIr%&s|WzvxEA8J?B-KQD+DKfN$EL+<9~9 z9r~GUoBu=YjJuq>uv6=0!hj9^2J-{vV!hyiGy2Y|GRedl@lNxthxctc`+gJp=A9e( z;~aSIqB3Cz&-JeJCvQE-^)qyr_t=vK=JImX`PXmmTz{JGHrffVsRk=KC&rY44DU z?-*$3SNugMwBPsc^uO~Xb4LYd)EV!cdbw~uSTE?liTAGe9rifu{CkIQ%{L7)XXU~Q zCz-it2YSJCy~7zZReDF~2ib%hedqA1Z2U@J9WX(E!1~6Uxp$-QyHI=9?wV-7&D`0A z+P5eA_Xg(MKV)U$f4>VegSwf1!^`|V@_w@eOY%N-Qp-(m#BWsJZ_qEeW^kg`?rHtq zx#@}D(z%s-cn;ieUv_G9ee03@Mm-~wT9&z*@6yXwnOiDcIAMOg=iZ!3ZvB3L{07h1 zzfos;zu5xk)t=q>3C;}_4%nBA?|18Gd#GpLY_iwW+kVyRwtmQ;oPwE@`J)64kGxJT(o4*+{>k0NwynQou zNxsGGBFbD?jqnho%bGc$gnvcs{?@9pfQ9@6*b@3hGF z2hV)Do(U6O`s_Q_%)5W%tuM@|nX_gCj!M1Io-^NI&rF?g!P(C5)@Z-S*#R?p(eK{= zMsE*2=if}Uy9f3BVC^2+@h{7a{(`-iXRh}wbKkCt-^@(kGm~%hdHVKlX7#dCJLmbK z#$VL@zzltF@3D4n!3J~Ra#EYi0J9scJJ)wxZ}J%y+BbePb2%{d`R?bbr&-iYUdEhy zK5TsN@P7Md3r19bJ~+SAGr#G1zTdCfo9XEnT-afx?!CY7$R20yRoh?Yu)%v~^yN8g z|A6PW3KN`{i(jzAK%G6P_o(*FCz#pyY`rtmH#*ndH+k=0)Mn0X=sV2aSG6Yxy4Rlb z%{za|&Fncl>rwql{q>t{f%Dz(d1nW>yZ&0kgaOW(Sv&VV_x6$7fA6Ee@x5oYm+Uu3 z?Yy&{@r>T=KzCS6@A|&S%-J{JnY?$To+_vdj>_Ho4Hw)wqvqD`wy!s9|G)40Kg>(3 A2><{9 literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/input2.bin new file mode 100644 index 0000000000000000000000000000000000000000..eead7af3788b2ed2d30c4b2445f9cac2a0d774b8 GIT binary patch literal 32768 zcmZva(UEN-k^?beJX+Q>Ou-aP!IaHpq#NgT>_>)Fs-ocB?lJBh<3Im9#zgfC2DM&2 zsNNab=CDBDT0O0qoA?D62ApX3sO|6cyl>*|jb!%L|NQs7J-wOSsD8o%cUbFRma}H= zHT%ub(>q&E{LAUz)EgbbyqWsG!5!2KruwO!Gdr;3SM*!I@9o<%Yj*LS*L&uh*H7~E z>09?5zSDXqQ~Pb^?r8r|Us#}@(Aj~%u)%!6fCH_))7+h&+j#4v=C>U92^((EALyIg zo9Xur&Rnn_sP}lre1W|gZ$0$%H)`HmZ}w```lL3$;BLLRW1{a~@`JbE@9evurtgdl zxbYKRAAj&8^^Tt^JNk_;r@rL%-fP{gwf&j7zHd>B`|J(yo%XjLZq%N0*3&G%-AZoO zJ!Z!m`u*1C&TRC{r9Lph`M$witkof$x1RXUZM?M{=sWe!SP%RKbwoHmrEnnTdK%uO4de z9xA9?wU63+{hYaE(!R5=jxN-!I|u&Fa^C!ecfY+46EF6bdCpkNCEMnOth*1{0_XHH z!M<<#?yz^$+o<&j6E(BG@e4hmHtTyvW~$A-bLQMe_0soT{F1z_I>g+)Oc@z8WbvM0~P&6%}?r8b&m7S*mK5Sv#Pyqk2z;9_!jFg=FS~xZZ<)6-$wfz z9DI8V@0p8kJ=DLulGQsS*6y9@o4t3dAOFr}zutNK?|i)V1$zrl7|E(j^u5Z26oiWpkdFQQHvZ?31J!d@gz?<1~MlSYOw4{L}On zYTbKO8DtwySZHVZ9{c9*wCC&qv#YubJ>i$P_MW`E^H1q-=zX4>AFRK3+PkpL;r%x8 zX4VH?;o2jcsJXn~FPVMsw%@t$_YM2-=FkoAPQM3#TBGkbNIt2ZecznB^q&3RzcS}+ z=TFYr^USO9edpDO`SIQT`uV-h4m0Acwd&r!ccs^Sq<_#2EA@@`ob|NNU2^P$>OIY@ z#oGRa9h}pfU2vxQc9@mk!I_iUfljdAW~o)rn@e}?_VbpTdY}ue_5Iy9-fV&1_cc4o zFY5e2%{Ex~+m>}Qb%u)`=uYR`fA8h4m6;jcnYZrmsUGSJo+~TaZ;pC{xhfMb4B@P` zEVyv=pLe|P%sU6YnK{+j3rjfTZ13)y)J@->^UXao!K~U@d*&O=#>ZRRH@h%l2mOJ* z<<0sn`JS7exqhVfjQ;zr_NQ6Yy5G0)2UO`k)%?PW+P9u?!1~>T`Y!7W-Ju`J?3-8H z>+E7~gZ-nz1byG4_PcI+JNCx3vYItq2Ber({Ii9PIA3E;LS6)&EbIOChGU=UoU5R z_R{Z8y|d2g|GAgi{J@HLrkQ?za8I%49pj^&?>p=rsM@!FZ-|;%tDb9bSttAM+hm;` zcv)~@gYTPY>2CY>te?-!Ld`~WqhdXy-C=f{=ZrJQxzu{!yUc`qqUL4e`%P-!HNJV~ zXY}o7*UijA?QQ7$?3I&zpiBDOdS7aF!vS|JbVA=$^=`fWH3Z}NdoI6>bx-FVNu`>eq?I4630&hF4$wYSiIUu)I*0SBHZ)4$(=@4tBG zdiK35ncg!KR(QL~oj=K(dHYZ^d+iVE?tgRXzq$Cg6K}ugy6=)Lm~iyDyYFabf5W@G z>O1iEJf|MmFre8@J)_oQ?M&aV9x|yPcyUhj7dGgt&g<7FyQ#Z>QtO@R?_g&Aa`xZ! z%*Q9!zxz4gp0)nMwjV#D?(BWje!k(&;-BW6bH2&B8+!hH>dp+Z{vP%hbcbrDZnS4w zn_Y0Gze8(h7P**B*w(zav-+2jo9SMHpyuwHsJldOw$KA_Uv!}M%{FRZoR^F5`$zJPIzLnY?)J>5x9}G>9B_}> z0MC5)I&U>YiKFH+{47sry^b^xvC!*P!mK^XB`L*%R+{)|qDaQ&0ScX6`ZT*%NQSZ_#(} zHiwhUo>|Y^8{k0XYH9C?Vd#~6Ha_h@a>U&f2r?y{a<<%ontN zyVLhA)_uF`Y;(`r-)Mgu{Rw9B=E>e})R~!D-yPjM%wWNV?+$0I5Bcxh$2)INoyqIX zJ#Rgrx!OAxzL@I|=y!0|y4gv-&;k8k>w0tFJ5jOco}GH2dNX;yl}fhYM4d6Sp0HrY zS4XmA-@h4{-LRq?wKvfbtj!Os=#01C;QNjlvp1?Yo9S<~e;>08>VO5-^{TZ@*x=dz z7S;>wz2C)}`^=gjy!{QHQ*Ub3Y@!R)?(VzQ*GETYF1kUL8TG97_l@msa_N2g0lmjF z%M6ui502{HWS%v8et_?H zu6=dE26x_%H|zX+L(V#5uFjzM{C9qk?d0l;S9{)k!bvvJ2fNfI+|_yCamgn1y?XP3 zI%mB;*-ftYJHBsqzIA8*xyL><(=V9dxvKd|*4ahvn_iyY-*?~7x%Uq04d&LKeeXui z-Dhn+(^H3#+tV+2=a_F#-Rbo=wLZp&nN3*9 zH+n$zoGLfj%2Qd(2K)=!G{!&wM7^sPzpC73&F}@m{swzCH7Sdf!G* zu-Do8n{1;ChUDDDn_GYH(Rc3JhpC!}Wi6?D?;*Re?ZunfH&p7{PID7Ec8)m#`dhRLtfCHU-e{X8f*fV$5`u+BJbF+iqcdrJSkCs}9RJ)yerz2nAv*Zz3>&JQr3aKW8s>Ou88T7Q$h zw=icd6ApN``>nlOUGd%h%?7;-_9iOs=xpDA$TtkofBV+XO|)lMGSyrL^qg}ynfZYY zFYCLRF`J+MNNrZ!wW)oB-kAg5wNQJi_0%)-=BVv`-{joG8T^3}CTg}*>qly{2^UV( z%)0Mt@73h?R~YC__TJ6?huZvxp0l=B-KiHkp}&XqcbD}}|NiZ9pS{jky>pVQW^$;V zRaf*z2RgyI9pAI+gWTQz-Tqo0;BW_GVD$ z?3)cq{hihKEzUUm$p6)(KB0dLvzxh{UUxg={7P=$tbOYn^yc#3Kz%;gU(mk2zW48& zCp{V9+|p;a-ZxBCJlneahBI(L|MsJun}2fW?k96@Wxl;?d)5PPXuYU6c&9z*`!4eV z<|p(`tyTLI+S{q!V`g6+=!v)IoV^+C`M$+jz3Pm8Sc`swXL|nq z9@N)7z2E)hX73DoV1xbpQT+mS##hZxdVQzo?Qa-hU$tKGzVE%wEME1j`6NG}H}5&? zOD^v`wP&0abL)jp*x;OMepNW&UTb-~P=+eW68}#M%rtDp=IDDH>a18F6)u={ zuX*>`8>!8lv;OWgcg|klr{B6 ze^Ad}beXB#ztNZFEzZ0ixoThZ?p7x(;f#LAH*@ZJa{aLn8>XIf{ig14#=b1*xfyR} z??kUTxM!gg%yzO7ubQ{s`o4*GEb0U17y5SV1rz#Kd*)`oTL#|znhDuPJ#(T1b*5bE z1sn9OtBcIpse*ZRf4sR&v%F#8zjxd7P0k%~?)yFO8_3$*pP7a4H|l%t;*0&6IcKal z+I!57wQx_>?1qWH95okbn{VoVlit3)3%$?GGy6|wEuO#VOg-W~v+&aIJJ$KlRBL-Qo*8(vzC%A@CsP+(mAkCKooWsBf|p?~K_p zk2gE8!%FUb?i-)nxeKcOvXi}e&X_whg5DXi=liC;>|gi`=c5~MK4IuNcjH^z+i>7U z^&{RhGu}I$v6cna<~x3%6BZoUa6jJJK|R4n4dxL|L8Sn>8Ic)po+ z@7y2XY>*wWH=`F_;h@&rn!RezF8oRUee+G;Z*guWd-HgAnB7lqy z!g~WV_P&^R_n`OQkeMAdoA|c_wRT5q{e=zs0oM8j&RNR_&&%I>*n2QonCt-Z?Q-%~!nj!Q0zlUd)_TN4#o%VI|WK zVYX~*=o=3Hnt33ckLp?8^E517&3v#9rnzDafdfa=T*3wGiNUKo6TGLJH2_& zz=R!j)|sBY{cZZ4ldK$k?^2y_ZSU>ijRS8zLBBt`@aFfU-DTDt&RL(#jr7|8 z)-%0}F0l9B6@Sp4tJ>e>e}8}bmwdwaoO6Rr_J@f-;U;&7nYHR!eSh25W&@n}Ju~%5 zwo_Mc>h3q6neF#j_nwiyx{~kslio$e+`TLQL?>DX{=$L3Z}xt3`^9%G>ItLjx7MTj z?|auLyZ95lPnu0~8L(ih+;`$9+B5cV@_~xIjhg%Rg-$r2m->yLsb7}wy2X zoBHm-Tzd<@KXtu*z042qjoCHGANf7%oxeF_wqS?;eGaufxv=7onti_4y#4w~w!!*D zZ*-s=-0{x)8-0E!bNdVIJO9;7ZKfVru%YMl6E5^OwUz#1FJ*^xHc0 z&h9wLZ{|ky&E;?8S?6YY?%sIyM(wTmgL;0XemVYvy&dlxoHJW6gI*3e^S(t77wnbS zr&cF=g1-B2{8C}VfeY?7`_kWNlIu76-oWhZp~o}yJhx!N%h2zeF7;04jJT_JjArPK zKk=$_Rr@kwzv?mrx0Aj7`5yPI_R{}m zlT19@xAYs#^y+{UE-EuP_e~SuUFxqUKjEFlU#K5=@r-)mgzo8k=BNMvee0s{nT_iE?ZZ0X*qLU| zzg_70%uT$!EZ(^{PpwMl2ARBh=Jb7|S?h`SoNwFs1wZfanz}luuNm~6>h~@7=4WQ( zPq?VH^Lp=^sk=utSKloCHRq zFS=lV)S2a~W`9!;-vl#(VC)o!accduQHY zRu+B;JvsOboj2Rm3sigh32z1+@xG(|?z1Ng&POla8FTf$m-D|J`KGsrnF-FGsB``9 zk<3~96E^7Gt9}{OV&<8tg8Q7`_zUk%$<1aM(Z0bsGim1Rf|Jbp?cEH zT-?!mP`~fVIeRBF?ic%J6Y9-(YIC{K5p|c|{Dk^R-EVQ%%AEVHop(n1hK;vx_IAf# zYO%lI?aOzVP3XJTo0)-jPv@-VWM(=8o|89Eea+xdmF3=IDcE?%_el;POYAxzu;N@fD_e==Px?J zxxULxPWay6JN@=ee}d{ve@|=u%kUSRlL4JI?{4?nU$9|7>+h~Z?yM>!wKJXD=G68t zsNZ*cr?os!zw^#+Ynai6njO$zn0j}hBi>-Qi{Uzh^F)IPZ*pL=W0M_T=4@zssd| zX2S{tovpLeo|*68_Q>^S?t9;8p8BHwR@K_szGdiPH)H1YqTl_eGpOgRZ}9GB=B?FD z?H-wUYx8#x{=F|feQVWSp6Nc%c)mT=+8J4JRp#EA++%Nm^P+!q(Y~j9?RDN=>rj4Km=b8H34E1dHPU>YAI%j=gK{NeMHdOZa9(8_ZZ~D$SyVE;p ze~Z4OXPsN=Uo~Fc!7~ScgZYF3?=7gA{-OsajIhvVs`Z5#KhO={^N2UQ&^uPV+5Gs8 zw|CThyS}yRj(&&r2J7yAeKOBq6($@Qu;D`YTQ4}tXKM53^LMh}GZ#Osqxuv2j!mtf z(TnQU1t;jgZ{5@r2Hfb|$vd1|WIMTjz=ab!^ZgzCeK!3i?C9?|bH=k5`gVPL-?t6& zvNG?P-eqqmKWOi@KdI%y4*CW5&3=DJ-grK>b26dtf8Xi1keyJ!s4wi`?8FZ+U$BFj z_nXTB-?5@*p5M`d9`MY21K#YrW0}#P=uPjS=Bo6LjhC5hp#43}^zSbF@SDpUt;M%E z_dIp?n;&X>>Ht+POxW?(1I*r?oV7mCJ^ilt?ddzSptGG(`=-v!%wAO7W3Tnbd+wn6 z?>pWsJ^NoKbJ90?&f2riS>M#_kAHj88>k#`PVdfHkJ|H$*#`U944%^K8P$9T-*fQh zW)}ubc;AVdo1J)P_1!s>S9`}uzEOMT;$Ca%9sTCv9PFqwN6$I^f*Z{Ao;7#ZNPW<) zhY8l|dqdRPvpagx0}G~E$zIQz{dVBJb7$^Gd+yECpXR9jX8$Ppf!9a3bE(~L&+MA# zcX;27H(#HgGt$1jo>9e}^P9)p8+dyM+@as8Z`8c+G8-^Iz3-mqnP2ps-BI&{SKVdL zz3L3T$LxUmo%fvdos0g=oY}y)-`$7o^>TJod&a%p{j146Z~ooq8F_hTF4~!a|IVz` z-+T4WG;`*?0lmTg46oPUjLdgA*EcNuNnLNh-@oxA^+CJu_kGEn`|aeuOYJ#x>3j7H zJb&V??VC-r{BG~A^m^X@fT#JEO?}M2JE)hL%Dx$OX5-)cQN7tkozXjg@UMsJn>({W zb;m$=>Tl+}J+lk%eEd#*)c9uoUF=Qx{U)jV+w}eYjXeJ{-mYHi0cyXe>g9|yvkUqS?$ZBue!kJn`<&~J&YOw-Wgj}Pzi^_~ zp6_=|>J6&B6@T>ZgEvpT!CbvzeKLCk%+&WC@$T683HD{c(R0U*zB}XXbw~3}Hlc4C z_?7y7GhTI0-nq{=cur>eZw7Du&i9)?S@)>#US*a5&>3&fTJ;V3_buFg@p6(? z=V!*OZ(sQPne7|iw>XzsXWHwXldNae11pU5Cc5E#a^GtH_T)Evzt3l`JJKLU?_%SD4XAeYQu3K{zxg6QuXwlIj!~H z+V)Rtck~Tc>J~ksPxOqc=~i{GTrdBBQS}wnEy%hvqxh9~L`RT$Grz-``iWlU$@HM- z%WSD(W_Jd@elL&j{WFZPW?S`g%bThB;9bE!y!i$*J>7%M5zfFLK~7EAOxA1*J(%0U z%9zP5mHVUa@o|Lp{p!!)^W8y>wQKHX3s*2(y`7Q1hZ)xAo$-(KXu%wI^)u{&NBOe$ zMwEZ2_8HCHHKTG@*uo5FXhF_SuezsRIkOSW?CGgYi@MvwQ8WF+OHVOVp9O!$`@4~P z4?MWhgY2mK&r5z!ub!oTzLjsOw?O%C^4?g!{)#UT{$BLx2=>3TJ3CLXE89!e>keK2 z7Wnd5pY2ZF?S>Y-F@wD`tedMh+r!_@?6xq<>*c%u27L>&*rQkAyRBDmp@k8=1O1&F zS9DfyuE!R7=Jc%o&r9unyZpWS9)07vVXl|?emC}y%C)Gh`7g|12kx4!durZZ@VC@Q zpk`{jxT=@eo6YRlYhmTLny9DQ=IYk!u%=Ep8GSXk6@2?NA0wrH(T(*m*1k` z(e~M$<*khTEab1^j`s#R-VvVktyFL3-kLA9=hhi|YMB{5qPSCURm@z)V(s4@6H=NqBF>? z+bhev!XD1B?$~o z?3jEE7wwQ<_ z?3mf>!QQIhs(CWzyen$&7q+T-n3-dw_C`;AR!yJuEh_T{xg+$j@q--h%;jXL9Ja8B zE7LwVb9FIIrMh7K;QK0E0{fDefRP* zIL!lZ3v%jF{r%aSAK#8V-ek_`+M9Kw#Z$|w>H52u-@{e<)Qxq^?5;azTY5Y0{DNCY zn3>sGnKif1zYF_jdb_YMBYOsOwb}UeysP|~ygj&OU(L6-1!{+X1)l7PwkXf{@T7*? zyqSj{WbK*HaF+k3K5AwTdHYwX`7g}CbBpKBR<%2O>Q;S^o`JTS`CKw+!OXthQMDQW ziPzLowfE%AF#}JA%KW0e`jtIvc15@1DDTb=&pmm#N6pr4yF7c4#|-voW3#_6TosGqiBi zcxULr9+i=wp%bD5dQSi?w-qfS~E9wUS2yUq99`4N4 zTd`&>{RlF>GEbEMAL;E~!Je$22eNLD^4CnxOi#Zx&%K3xbt|u@KT`91^eW#TX3hCB zZ=iZL+;RgWoM8+69%eEyH*2Y}rc_aG9r(3WytKZc@R&Cxh zQ`@0E__^B0e^QTlf48FcuTr%g88_BfJq3!cnKe}<=Ob6yW>87f0ZkiE(~s^_hY9sV2bl{43yVT zPqwGuqcF#_cB3+_{GU3Wojc6%1{wFfYqs8Z%kCLud-d{r`P7XT)GhRIg%M<--m5l0 zgG>u%N0@=)Bl2-YC3}XUaDR*a~#3^&MlZj?fx6C@_Op;j@%YT<<)M< zz>W6t^max#f|;7<9ryV?T)|8|gZUmcH(PzHnlE=JHv?Zy&HR4Uz1XAYr`C_qg4-19 zmf0S>z5dSRdU|)%6i@1vmsPL18!}rs3OBC6Z^6FW9)G5vx{c!tc{R$=c~rvz2Y>&-6T*5#7TTdeFNge-vl+S9I3S8=d9*y{_6#ckH*Y zZtST?`P9uG?&M|RCeLi9reAMQf0ZZyB;(es+HQ}xa`th@bH|QeO`m8FM{3wxz1m!E zhApgm+Dk7hhgNFl`*2Tf$L?F%JNip651EyHQjd7@{4?4EZx3>2d>);pc4c^PYPlo0 z12bL=tJXvQ4r}i7^RCjfa!)+UnE6Yv%qF zs63T_qkEL6->SC9AK@3|%=CBNpuFXusavr3WTscoU|-z={|F;2Pd}5PJZw5xAg0C9H}A0 zzcY7ZM(so2!x3hX)l+x)zaV!7{`u*>$)j>D`h-Ir?C51MQ}^;VeUMR~!EA(9J;j|p z^Od2zE4qhSx<$Qb&ObitZ64&_^dlzv<)_|@U$g~TI)Xb_xMjmuzWJN}jQaQaH;=Y>cAluc9+lmL zd-fsoW_CvTGu+kB=o8l8s@AnOqCzRMxyl&FEQk{>nV@XZjJ2aHoHR zx!!F3jKA`F`7*M4Dm$a@KhgCwsIPeLot5Fq>gCMcp>O=X>J}Y=&T!YBeLnWE1$(nt zxg&lHbOu>!Uq-LSNbODw@-Uy_&D@Nqw+sGLwXAs$e6;KsT~?e!>MeOKPC9_+fKK0*s})C_iD zjw@_ohPQk>PwE~Wg{++Z3^IBOH~DLJrFP?}d8>YT`d^T<|0avAygPNTURKV$NA2MV zYgYzac#^YkMtk+pT7azjJ5qmYuzH-Tb=~?+RORvqjxClkdUYJ8C(Zo_<8{=nVU3 zwtQYIzenBBKhbq_q&K5~^UCbydvDFo)OABh z-Vu6u>aLmlEqVqy-=pj-<;x#Y`4L+BJIL_t(Ozn2tKMw+yb;V-{z%PVGq>08vc-pt z{K}YL!RMNxS1%)PPJKrA&eU&I*52ynx6tF8-@#1AO#ffz)61@m{aLfWx^~WZa=&N~ z>z<6+QLH!gGkLmx2ln+(l=u5&TIw^b+iLsXS#R0p;|g1lK?{6$`0m@Ap@%nfzWGI~ z@6jh5@#v_GxvU-gXVq)}%^YTXxKl&!&rfaFZ2kOK>aVxZY2nc8`v5R&RbMyGL88Tn}6E z#@ADMHTE!TM&;c#Q@0=k9(|H)QH;{{jvF&xkKW-*{nwY;aeohM{v@YA%A4s&YTwrk z@-1xDo9UN#rIzLY!khgg+Jg+`ol#u8^lK)&^THi}^;>$_T{G{fM|k4vK8RjQ{`u literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input1.bin new file mode 100644 index 0000000000000000000000000000000000000000..9971a667f56c049cb68ea6cedb80657098ad78cc GIT binary patch literal 16384 zcmZ9J(UkSfUz}#BmRnI#ko>Ob;jUI5ORd7bF@u`=6-vS4J2XpU9X8+wwZdULz)ZAKa z)ZLj`)Yl%EsnW3Z%!+5sq~MO<=mFI^{i}J8XBKsXK0C}V)Uc;-?wifctUY7ha6m8F zXBNyn-((xqf;HaSH*VBi)~M$W-W|^L_&xoaT9t-9{++RAN8W#_-Q#R_nt8^Vykk)} zsL2a22h799XGXPeW0ib>JSRPv>n|)Q zb#|C3C^zRo$+SQsb=vTpSQm~^M7^Dyik3E{){(Uutz=He=^TIcffvd?#)iNJU8bz zKEH+ewfHMk=^XZ3e9y_~l$dI!vMzjY7Et!r@ZPXBwCnK-w=v)0bH z3TH6aN55}ra%=l~XYw6p*=t^)I=|5aOSoGK|9bY}Qk$u71~oT3L!-GbevuX2;pKSe zaz1@^fit<=xmS~2$nUym=5Kb$-}^Xc&pERK>jmy;J!&nl9)D1|*3N)u%33C~f2PiF zWUW{8dv9l;!Ct=MQXjL)+|JDGEy&%C*G`s4GR_37ys z9C-K6-PRWx3bxF*FZ>1Mz2GVJ9-lgUzq7|P%-+#L&78^I zi+V$=yvd$37Z&LEczfo*v(#-5T`E}19rcX)@66=yf6jeH{rt>0+wjhL7B02C9sX|P zowHx4*k9=VkloRy=WM>yS}u6rp1gT#dy@S39o%c4JCnOtuU`7W?948A>vN_juhG9} zoi|e#GtPNdpIN=x>*3Ys=idFP^KH(1KH4*ryAJjJWbSa*JTr}7GH-OotU!MTvunor zr4H^X`~iLH>^Jw!MP*s5^KN~?0e$v+-ukaPYj!`gjX$u!T(x#rjW#OpUNg90rdlts z&N*kTZ}RC`=l5Fpg3Rd8)O$4VUDU~*pWo$>J5!_9cl<&d)EeEWndG#r8=pIR za(nWt$-{%hJR4KeHq0W0>Mtvr4)SftJtzOhU3pEq7GrC4SSNP0syyvYi-rByJEblQt z(`$9!u;Ni??5XxIdZt#bCGXHLc<z7a)Z$z;_bl=~eKo(6bJ_RaH8X`asPEpn<4kStxv%Qo;rzlky!%zi zyUeT$oLA+5GmTzY@ZLnG+As6Ze&@3{y?^PP@qF&ulbzYssJ>ti`U}}_tyeenpntnj zdkZd@zk1&4Oi!ME-ejG7*2nMJ(UaMG-_Fj1H!s+5M;F@T-EX}?y`wL)$DSt3+sw-u z-wfu5dQaB+`HqX1^WpWW7reXK;dkkCE;F91&zxuOWR0fYc*$>n@n`CVHn_Vcdw1cT z(O-1Q=bPMR?HSScsIz+WJwD%H{(Z+^zv#Q`qEg@+%=4CHP4?b_>d&8gQ@iWkjsH7u zZ9nJSUu1G&e{|u$dk(dj=e>`fC5T^vv^)i?=?I_t-P< z$?WG_7T$i;UW4yF>UityS!A9s)J!kV+A~kjS}wd9_NXsc1N{TmsU~?^6YR=PyN0R@A-EY^}I9k`#rd)sl`1TU2vx^YjQD@m!W0{m4*%3 zVV3VOf6Q;CFE}%^rp}wJzi+j6_I>x~9nLrT1-)l>YSi9=8ntJAL7%raUQ(+ywf-;f zF8UXC@P2$lLEit}b;$FL=FUI;_&YQE)ALN;;2CTC`CfaT>&cw4c22c#@0d-VEcGRC zFgL5I&t%?Z&)jTJ|LVENyuq1ddun_79?kcpf5{gd)%)3}nr(2Wb>ovAygjucGoBUm zHTg!%`m@tIZ#{U)Oyk{Qtv|59`n_$uxjQ8HI^*mf8hvj=bN|Jk$qRL6!3JyRJfnAS zn}M49F5aGchIbF=oHJ7w+(Ew~JFMGm_IU55E>N92(>L3n%v_wgqw(%3vIRTJ9m$<< zdt{!8np=yri~4|hgLA4q)w$&Mjx#de+F7+Hzo@nP?x64dp4^^fcj2=`b@rM~{`Jo8 zA$|8GJM^ySprGM^GwOwcH8d(4>Vt3i_I%fyBVWv2cyC3sJ2UBJu9-PtUgOs%+tWK} zeph!{_w;JiyS}_P@FuhCc`$Qc-=GhfJCp61v3|QyGg)xQTbF)*moq*4=50^>edhi} z9SYxY&FGt3t7fvn_o)p9cgVfg&e(I$Q9<1wtP9p?qi;9wv!^%PaKUUrjefIDmhW<= z!G3}iS?*NdyQZI>yEfUKyvLiZ@t!q*--I{+W%vW`RyWws8w=lX;Y^-8m-Xbo-;LT? zYwyq(yfeJj>`>dcr#IVZftlLi8EgCU-D$pug-Tm<=FLr3=pDc9L-yD^W^+b0TjHCj z^=1dGRWr#w*_St7dOb4-t*N8A&pKx|{tS(Zy^A{kbn;_1G@7@X*(6>RCir;|EcXc|G+!v{%Dg~zq#=TP42n- z=iHjUGwB~alsb2qS&Mf(@E5$lr+4g+SFO|U=~)+kP3^9HlUe!)e?hfh>UgsTXU(i@ z)Lfr;E_`;ryQZHz_h;VO1NJT|o@vw_+ZwFnb5C;5+Fx)$zu``;%6q%-hrF-J3(iON zvZ3MXd6Swq>FpKh?VosbfwPkSH*06;)kkxObM|E8^H%+V24_|48g*{r%Jti%Ra_cxiZh z=lhpA-d+v%GjquF_VRtsE^BBwkbCq67aHE4sn@7IIp9wH(!X5ZXQ%<&5fUxM% z=j=jr&Y7kzFc*Ek`}xW8hV0Euv}f+1&R*2)&7;=I7rwy0YMnVVYiI4LVNY%Ug0=oT zYklDL<{f42{|c3cGn(4_rKXqfy3}UrIq!YnJnu6*L!k@ay|dfe*+%6;eKNgS!~SUA zvdD7&y?5$D#q6N6z^wJ;Ij3GSciKB(CT41p+1sF3)4TZm7TM8c&MmP2JbSG5*?;hP zqq&&5XQ5)Ies|xg&!}}lYImCDU5(%1z42!<_okp0@}8%VysE;PbXSexqOi zegm?GC7iRr^>D_!r|=gV7VM!^-gHp?8jbI%FDe_nXQ@!o{cjd#D~cPcXTH-}npxX`2DGxqbFE&PV;(mQK@;DX-%M$J^u_eaeO zqWXpnchuTFYce%=6<(|_>V9X<&g45Db>{72#@yPoJ$k19-c4TA*=gpSD$Zu!nfOcI zu;ufvWZCVRni*A`Q7<^>j638^uI}{FvYx*A63#U$cW|cgs<|`EdN32~3kC13sIzAF z+8Tbn@#fqlKAOBIGv9;$LW6#TcRW8k57`>*iP?qsX7um7%*~pfIB%BU$*kalnSC+q z!P*_~j@h@!Hhj-m*W~Jj1!pj8P%n7adJB8y;0rG3v)|fW-dj-T8?`433Xb0QqCGg9 zZ*az*Tv#&S=Zu+p&<5)T$*j!}c+TDi`+C*;UT-q<1Ie7f@Mh5(-;>+d=UsYt7AnaP z-W_7Tr)PelLI3VW3-+guSL4e*^knMMbHC_c@B5yawf~!+?4sVEZ?(VRzy@o3)_3Zf zEcb2d+;7jhsJZo;T=l+2!UI{) zZF^St{%|NIZMPevO6 literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input2.bin new file mode 100644 index 0000000000000000000000000000000000000000..cc15142608550db59428bd0584bcb0a235f00fc2 GIT binary patch literal 16384 zcmZ{i(UC=~5d#Uq85w6rD1<^NgyL>Wx^>>V{Lv+iW*9hkcZ@N|I51(uzy34E`u7+E zZLR<1%-FZK-`>HS$-q0)xo__LwGVF&wbmb)dS+DX0nNj)F=7<==)A;`^`7LJ1736x6l!P(FrGYzuA5E%r|t8I;lH1@GD$Y zUY+yKi`h=Ezu#|%edc?IJ7ng}NPSRe>@D#8`%O^2b5HTkPH;{xoai>=KI;{p&l@h8 z_n1{T-n{2#y!zga4mi<`-k_h!7i#UEX74-5ZhC6J&AEQ_oBw?0VNO=&pJ#qkPjkGP z^#c8c0aInyRzVfdtZ3ix)6DoD=^1r`b8i-ZVL@|iIWWL`ROjFC7k&4CzTwT&TRo$v z-`}L?ul>)=PHq17qceEMyn7DbY^$(>nciFm99S^n>Uqmkyz}PjiEh;Ut$UySzTaG& zE7qNH&U5w#T-7YTyASn*9km{*d){oqY98<0{p1@z;WYn0zjwcLF<0+7{Xjjx!EA+o zPiO2OXzr|;b$6=fa-tUb+Pk%q^{zI+WQy1#@NZ)JrdSnwn;CwRYTHBKaJAR_wZB_=}o_e77>diM? z@a#YjEYQn^11oAK)-&E(e^aXm`Zu%Q;2q|@>%EunwJ+B7D|OEdYBP0)?l-$|f_=01 z`=T4{U6?S!dq=+8p7Srmf8X-`pK$Jc&d=7_V?DrKvQw)ox&GD8#9I$IsnrR+TQBRA z?RekXe*aJUE!P=P-IdjWzhHiFueECbz-mqAuJ)T-`|gw3zjfcb>5Frnx1R9sV~4f=qWbqEVu^+HW^6=e?KSX&?R0 zc&_vt2HE~#zA|S|?~a?SfAe4MjDHJeCSHBDxqSY2+w*Miw|}TT>%4xzdj~x;{fu7J zJ#QDjwf&RYoeS@b`TcnF4eE$LQR@k6cP#7Nck#zs$gG{~TQ;?QGyT-V1?&E{$2>FU zum1G^=WgercaGHVlkWJw^D-l5M}-Y5G@I12qbvS>6K}M4!8i53@u|)2J15_B{YLvM z9CX8kH;eXnx$)MW>ppv)G4Eb$^Zm*D#zi*3*@@nG_nGgonxhliJMkA?p=Z5glTA2q z;oIw*|IRJu#52BYpfX|U`?o?{zd34W-ouVeCO=Rf&NZgpZY0& z!hj3SHr_pI8F-mdcgjuXe81;><~(Ch_1zP7?#<(`d6mBTNw(3>d1g{e&%5tBi!O6L zOt|1#z1q7E^#=Q+(0})QzR5ZJs@T7%xU2ij?Qe5faN)qK^KSE<{QXY&37-Gr%!LtM zt#i+z-k^GJpz?NRW>IfA@P3>A-hN;5YY!%PcB5t|EY!2s6Ra(hisGd2fbguVWuk`HS%?#&wcvy$*d=gsB_k;nfk!rPwg&izelfHA8=m(`sjci>>a$l37#9M znV89j-aS$~`|Uf^yX`G%`>H*$w&$GwLigGqsoy&?XJ+rDS3jxUBMX0Fz=5~Znx5D* ze?8Pp-`skq|IXB--Q}!)q_@yFn>)-O)Nc;;ZJWG1W@f6M*E{p|Cv%?b_boELbNZ+B zC;r0ARI*p2hkB%E{~zb=xeuMenfF%C>dl?GDp>2?qwZt_U0~MxU&fwyC%HN^vr)O= zy!irW4mwirWZhwJf_?F<`ZCULe1D(UBYSh?_71eZc=w94?+)iv_q=mEGy3Fp)c2L?2osqL8`csjJ`MF5A++H>vt`@GiDdmQL@L}Y{HIz zd(-b54!!0Je>2;BelqX3zdl&Yh659ZF#GqnGxsiMPSE$eu6^)a-(+U3p77otKT!MK zF{#Cwt%Ci7zM1HErr*_jE;F8UhrJzdRzL7!-MyY)$?Q$&JFY%;)pLXFzyxb`!-W&= zyB4*)`*_>HOMj1zzsXL#*-X~o!EB}9v-U=E``*#_JxyPKGUwTemzzv~V1U_K1&ntR7a<-moWd%r>Yee0w@ zGPC0CdF~{isJ!1K@9Vdy_E!3xJJdty>&>k%EI81!FJrISgpuBRcWN_x+nSxy-_(3T z)%;M)h8s0s_^SP~@k7{qQR@XWzHha!m-mj$bmvHI-}BZN^zRI6cED`I+r#{gH#=ZG z)rMim|I^|%$%91x%CMb z?RotM^}vD&?@ZpdsV}JS9J=CX)L#3}*mp2a@CilF0Go!B5_D+1~C$;(b%-m$ocK7p{+3EM()4bPr_q_RHcII5~ z9I2h#(BIa+z0OTN*mu|c)XtgpZhM{Se0%QDtDW2Uep}~;S?Ijkh4$VowK~yb55Lh1 zW(#`W+FbMt>}_;>IQr!Jp7)KdcV-qk^>D#juTGUS>f7OO%HC@ZdUflgde2y|cBIQ*57v=YUk7y^vx%}Z&gQn7nR=M-X>e0o^@xOm+v=CYFTi= z*bJ=1Sh#q;X>uK4{&7qvStI$)}wciS_Q z{lR=t-{2YhFUPOsW(Ta_J9y8;J7=$NH9OXzAF%7)$GN`6nFDIGnR-O+ZGFx?pZcN) ze2;29Q|o0%zcamS(R1%a2bi0w((nA6_55bGGZ*iib@QjmyvLaXZ#O!@e8t*}HGWFZr$;_P>vl|9_V1YAR@An9PCLe12 zFQ?wg-dpe!?6+3U4s`#(JM*iVF>~fZ{nFEOhq?9qDsQwtA2wc^jd=Imk9VhY3!bve z>`-^sTJLY+%uKIuGQWRv&pSJkyQ9D7+eh9r&K+>}&FOQex#!Gh`WM|`eP9UvW>tF= zdiUGe`qTU7o{<&x{Cjt@?z*3ufwy+go5!opzC67ZHEZtfNo`LxR}Z{*v18%;8>pMw z-OfyVQS%!{@}uVOv+xt#tKZOGzpGjGe&$!^)d2^#K5x;x&)$LqGip7AS@k<-Z$k6_ zZf5%Cn_Alc?)YW&7xWvS<~QrNs*@S9e^D8*;pmf3{Iv$X`9_x>>bu)_)W2Tta&}4H zVtrFr5A}voWtaQRJ-0qR^WLR;ZracMh&R7rulmkWyYryk<;*DAZ-0S0)4OO{`0g9{ z?mT)>Hw@@EsJ_#Fe`i&kHNT*rV1DY+70h3sdVud}?H&5o7k^;G_+Zvqd(!vWyQ!Pm zf8Xmjr1up6W}^q)p!&|f%{|T48=RZq?1exeTi{GN)Slb+@>~Cu?51z#taAfi9zRo?xwqf4s1MkGIqIH)Zn)9DMLlHp zWrr1==x=wwnaqCo+P}#L+IchUs`D$^K{u!q`quhI=KM`xeS7`ZWD|XF$oFq*XWp*N zdCoaAz4PtAd9wd`KiL8MD?0R?HS@griB|`hsn)%(^(EWTZ(4ZoZ8ogIbMJll8$5IH z&My_zAN2bTle~BT?xKImov|+moLlGyXZ2>^-X$CGG(COugYSFw6W%WN>1CL~)w84T zv$w3MURVW)2HUTN() z{eS}#e5?6;bLKa>Gxk-vu)#Uiz3TJHhdEq0;k^Mr!F#p}>VXN}srDVteCOR`=9&G` ze&0Kn+_?*$cYZ+i<<#nix0ksYZ|*M7>dnR4S$ms$fc~IjPbR*3_k4Hhof+Y(IcNUv zN$;Ry-(2rb*~!~CyJo?<=gh3Tr&y+p~+>p1aL{GtNQhE`Gy=`hh>G7wX>c@8Hgt)0^O~i^>kp=nwpz zH=E!&)&4-u&HA=Y-QRM3ytA#J=lnHu59fb9{K>hA%C?R+Thwx)=Xd->4_N!IewW@^ zaYl7cZLjA%+qcWm@3i&~b;E=M&8*8c2la%` zZEAbwE0|e3qgoH4)=yZeUmkCMpx>zaUh@m;3ijUjqW0XaI$u51&!f(^{_d^sxykHy z59kM2yWd%T)&2<=U9hA218+Y%!bR=vsQG~#tbI$rr@1rkvVJ@JKaJY!{ElC!xq4xP z=LYKj2mXL(+pn5CKkb9+8S`SjspZ0e8E8Ci3(Vau`uAJ73l8;!3->)C;WN_n>C|4l~b*Gph4X>6yt11Kr6MD%StV{GQK@yG~|RXH~h;iS};a j*7w=p={X}Ss^8%3O?FUuzg6}uYWpuohtPAze9!*^snk5= literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py index c7273286a..169a94070 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py @@ -15,19 +15,23 @@ validate_cases(CASES) +np.random.seed(19) + for case in CASES: setup_case_rng(case) dtype = case["dtype"] - shape = case["shape"] + dst_tile = case["dst_tile"] + src0_tile = case["src0_tile"] + src1_tile = case["src1_tile"] valid_shape = case["valid_shape"] - input1 = np.random.randint(1, 10, size=shape).astype(dtype) - input2 = np.random.randint(1, 10, size=shape).astype(dtype) + input1 = np.random.randint(1, 10, size=src0_tile).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_tile).astype(dtype) - golden = np.zeros(shape, dtype=dtype) + golden = np.zeros(dst_tile, dtype=dtype) vr, vc = valid_shape golden[:vr, :vc] = (input1[:vr, :vc] + input2[:vr, :vc]).astype(dtype, copy=False) save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) - print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") + print(f"[INFO] gen_data: {case['name']} dst={dst_tile} src0={src0_tile} src1={src1_tile} valid={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/golden.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/golden.bin new file mode 100644 index 0000000000000000000000000000000000000000..97451bdc6329cf51de745891b11dc9f64358eea1 GIT binary patch literal 2048 zcmXw)368`t2n78!E6N=|il=W?GplG48*?-bcKiZ5iOvz*m2*z;@I>cD>%n>-O7wA~zq3A(`si=t^kqHqc(8x@9mpp7AKD*h z7n;1|PfvHfwZ=EK+{y@eF1>90*m0_j=uYhLiEg7g&@kJ#jDK(@iU$k4j;|gvHoFgB zdwq4Y-jaP=JY(TA9ARljW)@WOL)M(irWweAxc0ow?nDaKDPJ-=(NSt7k$_?Ra^!2A zjq@F(N`t!=dn&T*!*P$fE3}YnX{W z7fhUX)y&9FvAa@XbWXf#4?aA~W!I*^@&h`&TXS(+iExaKx!RX2(e2wXI*cyIsScLb zK`U84V$ro!k`BSrtm!B;x`&M?S~Fu+o|})Zs_PkmxRv;A9j+JvE&mtoZr`==pJZJVWaK?JLe@kzslpE zJBkSQZh9)x6MJrtA~WP>Yy6(4u!)4(0fqhT-xEVuit;4d_S4Ju1Dr~ zW_N7Q%?9EI+Z_+5711Rh&oO;=S_R^@ZH*^L<}G*$4AuR}q#` zd}5IoU)Um1>s{P(vD%opS-InhVKcg99j18?GZ}10-;cau*)y#%>F%u9c`}}sc*&7Y zM@ia@n-~=LmyCme!Mx>z|8jX7Wjz|_O>3F&Ip{KynO@KniCACxFIRl&SL~r>zW%D# zv*nGGx47-K*r2fFFEjh>OOxOkG&3|m%04T-N8c?ab}$^bf*7e%4#}o)o!`V7ef0! z7{w#e`9~$$*!Dny< z#av%MI@e|4Ssk-0v)QLUt}>$kT@-bcGq33}wCWj~v!bT*(6wXZmtxhYM|6%S9YnXb zJG0TVu2SoD3hV}_ho0kWZ@IAqVHo|E2gP@C7pdLfyjWX@`oUT^@3x!fPUp@FoGST^ zEW*Lk58{r;7|==fM@_qR(`u%6vR$p!NbSBJu`i~%8ueu7yUJp&8>eFO%!-+>9O#|y z$SjpExBUjR%lO5cVs!3_i>G)Pt}D&f>CHbY`V7ZZtZQSn9qRJC9rZZ5Of9&(PxOx- z7;Efl>{Qt?d?$alsTty{*s6gk5xi-r;^(BPRN()_S1Bg{GPZ*Umi448{MKiXt5g)d%9Q^yL#ojSU9w5 zvFw%_8$6FPccS+s7D?7oyZ(6jwZM=f9=WRj^8P!ODx$Zrb#Ko9j`j4K9m|%dyne%cMy;&LOAfTDR&}>pUmneAi1Luh{#!eD z0<~1e+}>%~P7)(4kW$BtR9H;e#V;0FiiYXB@7!hdAH&nT`4*#EXS7>x8o<#}qR;;? q%h|^p&w2PRqtq*CQ8;gP4cm^wDBqnn{Y7_OaL)Nuqn;cnR)qgM4VMQ1 literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/input2.bin new file mode 100644 index 0000000000000000000000000000000000000000..137501d2c4672129b998c65b60df395469ab8335 GIT binary patch literal 4096 zcmYM#Tdv|j5ChOGj1LrXfn~Au@v#$)B9cybJ#1H1INrbS?|VFReUDe}=kf09)z#zp z9DmvIWB1AK=ZH^iKgTQ2r)Pf8*x3DZ-s^Gj7SEY@-#i~#z3cOgx9D=;`>)1;fXD7u zY~P7+Zoim=B)fmPKDB)04aP^`V9I;0_xyj~`p;uoKe0{boWOdzcTEepgFNpzVLhII z*>Tb8RWE$zA9Oi=4-69FthK@2gX}Pe(Hc9lzGe9 zV#!Gb@2exG_&^{EZS3W=w%%6Qxs}lyF7HD{pu`p z*RjQ2=kU`H>CxD!BPD3QcRp&s);a%msL4*|c&&yMik5j#mfVA(Ae@y0v$(?Mr1+E8J$eji(kE*dRvrE2km!DTlpmCe2* zblMl}yv}LAi51`OGFh$8PR#&L<|W;dD4PxSQHM@L_*wD8VKYLDA)b}|Q;@Q>T347| zJvMFBep#wdaEV84y$T`St&1v8Ia%2eD_OHU^A3ypJf*C->)c{f*@+Mh(`chNW(iQ8lP&pOV2}OBuIvCbCNv*zY^a|5mH$8%alvW#T3l z`_0QH&Nor?NCt=8nqvX`2OU0<%h|lPMK-xBwdlaFNG&1mnm%P zO%eNdpF&3sR_WUK=;~m}ik1dEBDpl!rwe< zYwG%asLj}nH^!o{5@%Q_VGm;}tRh&kx+~dU$B`zgdXtv7;vM*wrg~=P6?+r>8d#y{uO9wDTJqj;W?P z)Y(^E=d7lSjpN0$8B|0iu-(7pjY0(T&3Zl%v$nS{=3V!FO^1PEJZTY)l4bhz|1WR& zR*Us=qRBK3?#{wbG#ZL^cUW;|kD~ZeaVyiZ`BB?ypB2aV!f#e-o%-0|08IW9rlZ4O?4O5o?Tb#tmL7MipziXczk1Vsp5n#>4aRYySluo zulg|PFKBcG-gMr~ZUXqP!#mZn;?eIzRw7eB2;)1YShpv-W9h=& lKXzxVT`-i*&+c>;bsZY)*{7PPoUzGj>hWCvta6Y%{{e-#jFe&kd+px4>%8{6zjfzTMaS@BF>a+~<70#>#EJ z@8{|r_?vsL@pWDE9&6@d`MZp@_pCV0J@>D3{Wl}Wv1QF~R@`=<89U6J z_q@+z^}0O$4Yu>Vt}|ji+7!wj*($&WSf@gf5(VC??S-e>$` z;WT&N%O}r#7j8!}>)vd*;F7*;jh*eIj=EoqQ zTh}taaAY^=csS9k1Qw@mHEPF|HOEDCTaS>N(pIM6^n5fua6(%b*_@0+o zsw+yuL$#>K;a`{aO&?+3cw_KGjAwSuqb%5}K@|yEQK*S(5koR6ntF&;sv+F<3bXVL z-ft5pY%zgx{oM{seHdZuJ}cxxgm4yvs6TU;Dq*`!4o)*dKWWU4g?CDm<>~=GaUaJ{ zYO?24qo)~D{V9e%+_S^NW`|3MROdQo^b8R|p01%%`{GYKvQam|)qY0Z+bmdRcjR3T z?Ks#@*J*+sY^ub1CLW5mE`gDd>Pa{;mzw1b9yW~FUmlAu>PlnS#(I9CV^5ZzD_Zx} zvn)7Gq+~v-r@S2I4p!OGW!X88g>vCem}06Lrar1qCD6d0E}`QirsBq~WTzD+O?~Et zODol3df;1*(+LD}DZ_PJ%GxguinMMa%HO=5b+9c~J&@|G7d72ba9NTadOHk?_Lcev z?bTy-0jt@fS^0HEtLha~yg`639_d!i!YwPl(=7{~L7nqbFtlk_PL0C`7aJ_o6$Bfz zY0VX`D1ri2suiiqp3xOW@MT~4E$vc*z6MWmI+>Y8Fhi!TYZXpo3oEj(&!%YD_{K$A*`iiCu&_vlG^3Y%2k&9V zA;gF1>x3%8>8vVtI@OQ8>SWc0b+Q?f_9=L1x9&s&?curHe#7bq!jgKS=RmAHa&!iQTNIo!F?G>iCH{XJs0b z$^7^(vryr%zAt}bid(wHgKP2d3r`Ae1#Zm~GsVa~g>3%WlPfj1pY>unFFjF8S&`%I z$tyzFS%hIe!2EkgBO-O-=!~+!YCFNbyfY4cEsfbEP|ek zFeG2P9p9aR<%=~j*HP3^owBn~U4~bCw0&yRC2NqY>v9`&SfU_T^>h)-49_yzEgow6 z^vk}Aq8gY&D~Ga8(u+N&XvdwrAA^?i4;Qvajv>Q+6~3$aQcEI4smFYdVJ)p@0USeH`g zRzqIrNcVh*MW%eRA=5pLtDCguZ+*&n074eAuKub99(4#dpjKCQZhq@P_0FEhXoROu z*gCA0>W^ieh)URE9lA1)Gd@2v$4IrTnqf(wIP6rVlC~Zrs=8C!#5*5OhjE~@c5cK- z>Q{w)>i2BcDfqHau+e!`ro|}&ndc3+J@cqxTtXKAI4pK*bCFGdy!PGr%|7*fCsJ(?1=Ts5qyN)Iz zI@BxeZWLq$md+!zM-4?H^QuMOLe5vcT8Hr+a@gaAb+#$abAsrok*Y=Fw5FQnA4ja? zzMS$dt2$0`rcI2{HHCOysIH(kTPNupp!}=Pa_q#2Tek3t19vb^3%G<pc~*Ct|tnOd*r< z?6*z5NtKDyb;>kX9&u$oTk3^ooNP9wP=M8Tc6u~p>ojmeDZ}_sk5sOM^WtRWP1E*$ zv+-LUr^NJEpH;2(ExoE!1KV~YdZ-pf<$Pvl{R5kLrvtxWU{6N*lu@e1yDn*$Wl2@I zOOyRajGjejw6h=eIrTWXPzpr25`oSmGtjgP>!j>@;dT4ffQs4lyH001Gu4HiJN?Kr zMbyc{A;-Rn2V(4`rSvWO)XIakDcu@~?IdbIPZ7VUT)B%$PaLf3rvB^K)2aBybt zyy+3+xPyu%eAkz0obG5$ws_AnO*ZX%BUz7p6DC!l*@t`Z{JSw_uIIqjIl{9XDk^jyCs=HrV z#H}{kFJn)SNGsKqRWzjY*Nw2y8A7HYms@{Ds~_kVc1W=rmoX^JrVVk~zte1IAS!J( z-q;h9{7{7|!AsqxegL_sVUwkpQ{i&!u1u*FT$IQ9C;miXET!uCyx_Nf3T4QRd#Me( zj^9U}3arN@>v|JQEMkv!9i4>``TeJSngdgBp2Y=WX9x3T61zPgnJFtYiWmBJKBjwk z#Q(o)A;Y*?EXl53hh6(s*6FkwmL;9iNs%6OU^&#c?7#IV5#g8`>S$_^B0HnVubP$N G=k7nj4WFU_ literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/input1.bin new file mode 100644 index 0000000000000000000000000000000000000000..db129a74126eddc9ef4ce56295a8e46703c2317a GIT binary patch literal 8192 zcmYM$36df)5JS-hnq}^P;qj|K8iHt`>|7+JZ2I+c|M~Ot`FWks+xhl5pTFnR>wUdG z*ZS?OKF^=E&)JhD{`b6I=ks|oALspft+s#vpMN{?uYVQpSEu)te>=4v*EiYs^K?5E zKIab#St`R;Co8@FT&us=PN-$wP8HGOAE*2Myz5XsSU(pEyY1*MN}kS8{+=h1@9Uc? z>WYNvbv-@LlW2N~exFs}#Ly=+BB(|Q6`toaeIF;@lU3^~)E6@-=s{UH?V@F9ZgRR^i7$Iav(IA;m{vfP*J@4_N%tD z(+uJKs*OJtoHEYHXDT|i^RHbwv~4EM2gcNw$!j#BLVj{Cx)E1%vB+3vexLXZ;I^BNY6MdKgY~FbD62w(G4F-WfClyP{}D3rIzY?SAzQ zySPzUeVjK#isek`s)q$PnZa$+Va+7)myZ6 zjNy@ymrV^8ndNe-=wi~0ZeeXlC3WEImj`=2^B>fy1u-R3Q#Zcj26*yilc63JsXqVd z{?)7L>qj{qV5EjjE1}fW9>hH4NjJ6l9&=CpzgrSRh&`v0shSH!>NclTP$hNER3{O7 zLENnqe>#dZz2;WE7&t$&(L7IA(@q%u{X^XVS5#N6&h8cjW=T?fvJce7--mMlJ>MRCQh~STl%1+adtW}Zw-*4EuXRwBj>uBiHj+eQ0=4qQ3MCKh`sGI`$ za5Dd>C|)Ru>zZ)v2HzRogkRHKYtQ%*r$ zJ4>*RbxxTpRZVFpf+l$60ClR+X;S6;!V9h@7@DwK>+GrL9OfQVg;t&{)$KV!&RgA@ zBD{8Rrw2unF-7bB z$@JRkQ&FrQ-AtW3Q}f*H%1}`iyzt8nx|GYoddUo-D)z0)$+_+-8g-KV63C=cJ5C;L z_$1A`Eom*zllq-79k|C_`c3d1w&~Vcg`J+cM6R2*DVM1#44#^N{)G-(pL)rIk56!D z-n!;T?4WU{;ie~zu}NJs=bmm0=+i_0xSNkoyA@^l%iC%<^%$EtEY+hcb zL$1bxO5MUa#P`(B6%bC7roJ4rOj*jLYG;cMu%`&!aj_>}R`V~9QY!08O;du5;&wVa zGGOhlu&0W>W)A_xT!mjJOlSQ%v$E2|0uFk|F%PCV z#!XvK&ZCMo`Qqhv&vh$xpWwcG3}dKigC~sP>%^F;opAevRqbvle(bKpn8(D^$g}=d za}Tx@A;_Ozl)3YuDcsc1q_h0hGc z(TWFQ?v~5P{6wQziW1Kuy7L|lt0vYtt7a(S$a^_G27JsLp`niTX5W3Ylaey(mB^9qjCkxf6J@^V;r2QTEZo%!`p30JE5 z&R2TI4T64Ux=HP%)EMM-nsF;0{lalax!&_9_vog=ZUBApY7S}G?ZpF7cdm56JB&?( zBKbr2ooOD?PZzt?RjYeCz5B(pLoI9YbWUW2j=MTbRMZ_qpL)>+Uvp@$cc^w=n~934 zQy-^NE==7js;O>z6jqZLu@%P;o*bZYUT z9~LHrPcv%nwA0@y&2PTtu1+;n-D@gQ-Pue7xSO>o(`a@ORv&|$oSEC@eR*77m)qrexnBM*pYPlA{JziH z-{pSZpWmn3iF%&5`-$`XIJ>v=?9}_j%kVj0Cm!FJ_t|-zwa+)d)1uV6o@ak$Q2+Xk zyPwtPcg-0)BH$ALIe%3Z6VlsxQ%(NYiI(|x-k`SEn|J45=kNVIt5yctoVkDZROn6I z=X_<r*Y zfz#!K5(ij^O3r-4W$LGoU2MmzJ$c~b`#!Sh-CADesrT7UOH9-~gehA8O-U{)S5Bp& zit(mq6DMB(@-!nD)R%podfyz=flIS6VxmnyvRb35z!J6uN9DmTLl%9Dz2Xs^Fo_=t7PFc*SJeEQx zg7flg-TMnCo?-RbovNab^*jNmxb$x}bfCIDweql)`Yx|JFy%cb>&J=k%Z#PDA~emV zb17-n6ThjO?|5ZBrjYv0lf1*76YDRpbgBvAE?QR;D>O^FTy5VjJocI#h%sx;X}*~| zgG^Op5{Be;;`6lReSg7n0o{cH#p{(|FRPmKQrdTS!PmR!K2EBd z$E}H?G}p<6k@$dpy67}N!uuFj*j%VzPuQe*E2#iSH?Wh~&v9`v zPrxe^j-bZ;!Tg_CY7 zyVSsG9ph6cR>K~uGE9x~$VjQ~fmDk-&$3h}A43|;cE$0rOftw(Rtj^hxf126Y}WHV z6=Km7ws5Y3=bxW>%tvf_{L0}Adwr_IBkAr$vq6nc`1MKgW>j~O#%s(&+JttlOmPIG zo;ajAi}^5G^{#JNc_pXEPHqo3&LCIAdO5j5^?Z>3)a>6QUgKT+)|vrsa$YTK@cR`b zs{2Sa?{S`|Xh=`|hk{O4y6c+RdV_#I@~I?`C(iILwzXNU_75;P)6W(Ywm+K)GWr0?8b!eqN9z^UVPcDPFYW}>VBK%c*D9AKNXwBZgcq3&@_iTwZf zi|#|Z*vr!*xtNBk^AGNto$kpWkJzDne%5t9h?ODKA!pG9yD_9@&V#CbCz=3U>I63R z-4~w1!w+&*SSYd?@1Rtf z-a51n_4wxnw@u7+TR_u%Vl+QB+j7W;QF|Q1!8xT{M33nwkKE=B`%da0wXo|XRdn5r z9`5+lt86^o8CiVfcS7}U*#8lU=3E+&B!W zD%~BqS9cuMH1n|R)Q=+-rk#ALVNF#i={TL5j1-n1LOom21BR^OZ<@iHe?9TJ9-0NG zQSu_#U6}qA6?0Yt@8R6 zg9DT0-J9H0#v+bsB(GCA=q=9WE{7h(;=-x&dX7De`HJt(kP9-W;P)$5uQI5hCrs)R zw=pj#ymg!+c$cwr^&+2M^w|ujK{sf#QBK`C-6Z(sRasR1q0B+^NbZGDw^p0Taq7H0 zFr|qu@@{;1x36y*^x~Q3DaU~r^yJ~5Y?(#JUI)yGj zTE|fz>Nac6;!}PR`IlOD;tUUI2xpb?lbdqMRPVgny91lX+|q1urZ}CAYw>u3C10m? zGZc$?H9vRr?Do#JVZx%?<=ND!obs5@q3#$+c%>)(#zFJq9TVa?KJ7foRR_A@eRWlJ z!k0(pkPuag|MeQjbwiJo;4aUy!YAYIj%JSf%h!CQI7D=^$7N$%d%VtJxvsm)E@rwxoKDTU)eA>7VK8?O!9oc$5`u0Q8pT|D;+*ZAvz=1M zgu&{G1DbZ;lL2}>=QUaM);%Pe^VBX&w{yPDb5^J=N;W+6VRCxkby0nYvahZ{%9uP(8G%shju9Asgqk*CaTJ zKbfG<3C-E~D5lyX=iM<-j}+Ri)6G&PxVjf|lDF=+d=_u`vfd4Wag$RuNWF`XCpuxo fU)^jpUhU7hr8&>B9rxOp}a^7%w1wK0J`;DI-we;2Vha*o;eR$%aEqLx;&8a_a7y-52L4i zczo)4$}=qQ?S@w0+FsoC!3g)}aEkV;?v9&IK6&7WXJ?#nFmW_T9?gU^E+_rfgTW#1 zF4|q+SxqlpClAd5-l^9@X7((5zQZPk!j-U9In?2ZuU5Z+$0TVC9_N`tqn(O%Auc@z*}_>4Q`6_W-9R z=;{}A%_hgUy1sbi^_sCc)!G3qZ?$&8t6XzB`GbYy38y>{=BeH(A1547?F+Y%Q*&=W zdDPFNT7BN~a_&4dtCKIMfNL+k`_1r+BQ1S2c;Q3fhfmvFc+_zBec|k--N3`81zUvC z>s#f|36log>caELpPIMX)23&ha(ZFfky;_AjtgGC93FXlap%cf-CMZx;4E(z!J{5` z96Z3r;iOl;cl&Y*oG^Gc=dRCQ1#Tg4jx>1mjc~~6X~zA|c20|qIv#wuoW6yeIr8m} zj|PlBA(tk;z7;NaI=Jw7*XLBD6$TClPg>sc-mu>EoO1QZd8f1e?M2N!ZMtbUXI{fL zQ}e6muRY|=DW`6Jp89v>q{}x=-#{JSTdi-4r@7qI*5h5DJDt9h+5G2<>M7H_l2^NL5!@1);& z7<0&*j|RTDn=>9*^O*$}7hZKe>hY+XRW409`0IbGDld(BI6>A!iqO`#ySc@T>+a{H8WfU%wnJ?|ygnwbSbTu4$MPFI_y|?F|?2 z?IOSWaLVJPlYe;gsl&bT)WCIJyT2F;@a#@Y&2GHY zpoa^NPvJYk^qqNu(Vup_bkm|?Zup(E{pW!u9;aS+TKw~6-rjWjeZa!t-hQ5&gBJgA z@b%!3*Qywier*ALTP-1DN3wi@mIw?aVtap7e)pKDu!7G@NR3&b|?DA%`1R81D@jyh1)+ zd*;Kbzd7aH%j-`|j{oYo@}(y4TeQD8^W)vO4j(5T+_3RD^Up&wxSO{+4z;}bZ=|}M z8V2N=nKyuj5=Ok7Sp}{X28KC|8c?K6U_jp zM=hM3uoL-edfGJ(+I>HFczJyG;P=Zf3?KU2g`6JPywVr#10M`d`F6$+ejHA9zU1@m z#G@}Bbv|&WnP)lPcvkbyOFFd7L#sJq1#F)9qy>Xh-kh}BjgvoEobd7Lp-&gaz1j4c z!CSt+ZMf;+ql;5cZ14Kj4=3F9;n?}}gu{&at|zQ~J8|-^$JuVY)p(*O)M(J%-z$E3 zJt|#yGaoQ6tJ$Z6g=TBWUbq!*rE z-@19?;L+^m+oN3HUq9|<)2j~ar0EV%n{K|Yys-F$ySdsGet8^fG}Yzk2-xm$YH|2B z=2YiFJv=`ALS0|q*IgY}&QA0@;pO9kHJj-B%E80c=kB|O-#zymYQN^>2}X_wxx6^( z!MF=L=9Z_`zh8HMTbeW9dd#S=nY^ph#lf!}yt^FT_T$vwTxtbvb-iZs?t6Ie&UmER zoceI;?cWL=e({?XCJs6Lc>9f{n}>D=*Bs66-S38;e0}DxmySFR7#i-p(str;=haNk zW>7b)KD_NpkKYF$@cNrGe)xRJ@gdrYT$*(BJN4KrPIox-G?Se0hUt{UkAwee`;BP8 zo2Qw=)$8q!&)uxemPYf?u>-T(hc~Q%q2b;>U4zS3OOh zmZy2;^}zK_^|ezu+`QN$T)O$;T~5fsi8T13gGbJMcyRKs1}~p?Sa}%yYTosGyQi5a zv#1I8JoMYeD=+XZCeA$Y)*Nx>#V&F68;&cyI&Nn^d4z*0@_}E@UIZL&^)%PJ@6)${ zqvH*$7go5N$1LqaFJ3$8_mi)1PC2|zSpM;+t0wwJbkfL6eD%_8zs-XmZ+_kJi0!AP zzyE(|F7w0MA)n!3)p(&HakaEc&A=G>T3A(({*n?bvetqio*{S*R=1fZ-zEH!}u2##=@zO~*E_2xtE%)j? zxYy$?;ArA(AM(v$Zae~x51w(!c{}mQ@c~y|&A)AVy#3a^`<>UXOltxn}at zo4xX$@A}m1PZPhpemcwB-pr&=4R8AG2(CTCdBX@E`0S)%&gRnRt`BZ|&Ed^|e)HkG zG^d=v<%Dl8-f;-J_4DIzShIM;n2{FT?qg?fRyyv@sRvK<$ivWWes^5;$h+$iG(~&i z$)3II<*(nXdp^^y2cMlfn~9&a_2I;&9)?zTwSG%@-0?RnT)zYQW~0$Aa($ygR($0ghI8J$xG$8@BMovr>|4JeB{Mg%jvBj#%TtA@TlYPHWM#2@+9x=#FIZd?)cgbj9i>B?(Lh7 zH=H{kIE8uT^rxwx=lt4(8myDHQ;%8TVcHMg@VNEyDktFO;GH5baxmU`2uC*@UUyu4 zd!&b(-}L4BO~nm^vzoJ++(rF)$d8k6 z_58ONx^klXYA=W1q&`0Qq1j&CtHboW(C?i;?{>#eGdcY}@&})<{HN>f?&O&VT=q)K z8MilI%|}B{tan`go>lYaQ^3)~DUY9SC+_@)@3-Dg^!es^>_+(g>QxhJcyZ;=i38qm z5{3pZ_~e__ohMv)zr*BD4L3b_Z+HCk;Dk4yFno!;s^f6db*k}&&z&b6b{Iczg?@9X z%XP;UKfdsI-Sgs(14g*#$@i=;|2T2Lnv1{saC-9q!?(L$-r?FW4f$~L@YfyW*V&AD zaJGM3IQW&vrACL()vLc<>E(@v8F09(_w8VC$iw3IhSjt8c=YIp%^Q4Na=7TkVMcYK z&)M%1UbIV^zL6f7yy(N(OnUIpFnfCPdi-8-_HRwj3^?_rDd&!#E-o5*cTZDJ;LyjT zu#dhA|LqMgecbx=(B{=$-roehJEwjcxYJSB>r}_rca($2C+G@!Z#Z+o@S|5P96nm| zH1s$Hee*X9jL@Uz#ECl|T6)cuMtXGYS+9U~!sSaYe!RP1HShHDjJMh1bHUQDXZfW3aCU-#>jt)*Z`*5c1cfd2eJZ?2QdYwX_`Fs~WxW#`v%AVH% literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input1.bin new file mode 100644 index 0000000000000000000000000000000000000000..dc5405e91e2935201d64ae8ef360a5c1ddd251dc GIT binary patch literal 16384 zcmZwFS!zQu5JXXDbh7^qM+bbN5cv=-b=TC-em)+L&vV|-`F-7=r+%OF^{Qqt-aqGj zoMYbI90wo9`|TVU)xptE(>otH?+#r!?>zX^QSr>F-`Xko#yXdR%G|bA{%sj4HZyE67!|TOs7hK0o zZI0*OUgd>r7rp%G!K&S1^DKKl@$&1Nl$Qn{{c?HJQF--U)ZX;+EC)^p&K##5%M{+* zY<=>Ur9FAk#ml!G<);&d25oh@u&XPZc^PS%m&?q%Z>5^?$)h=}9<642H1V8!-ac@7 z!gSh!N4xFb<}}I#*Y|KQlX*GJ7gH`6`1aAai@sia?GE*6NdEf{n9 zuyA^5mRHq-xo2gi-!3X_?=qB+Ufy`-`O20r4R<~BzFSytn7%_gdW(hUqwk(aI{Ib~+Xp6Xb#djA zS)WcB_4vb>wX0rwa)i@!`0Hh2g(xm6E&l@Ium~_0&+9@rbW!Ljo;rZsLhpXzlw@(`Sc#iqv^l0HY<-m217bh)# zcuu+8alG;S4!(iA>Xb{hJ6=Bfo%(j>d*7`3>noo(Oj%`$uR3P!01Km*!^}y?z1`CL zuNS@@(=5woH48`6?tPOodgGSA-wWP7T{SNrz1`bGy|_okIjI9^U(=KVilyy@WgKY?9;GgxOc;I>=UoL-uF&~(>NL*)~OkF&bX z=$qkm&qH5!?|us!ae2eEmkP_<=^LmHkFaTa({{=Tmk!=)nKf5scIY@XaO*YKr<2EK zqfhCMUeg@76C{zezm$dh{H)uzQc!H<6{?ZoEES{WS5t)w0p>hV5IL>Gj>? z!Q$(~y2I8}dBIovcJknqMa2!5hhDh$Q+d*IH>Zc&j_sQdy?A==@M^ie^Ve$zkFVfJ8w;qA0*S@`ManR7rz~8z&S8!$A!~NSLN4fUIxDMPSxFvmL8s7yT_Bsq2CU$-l`dXnfnga`t6Y}FGn9X z4R_k{+|BB@U%BAi&F`7K`Ieuje%|I~3X|q~(66uVH}H07;xx~veblm6!l}S3h6qK zw~L(haA3;;(;dgGyyf6aLl)R}!PRSD71rFW{OJM2e$p~Vy>_1=a;@}h6mrCe7qNDmQM4ssbw;QYabeE zx$Dc3mKi>t?^j0e_JUbY{l2xCH(iz1dh*GOHoW7Gla7i)v;2#+Q%TRyweHtI3;}9u6I+-tNK|A1A-wxDK86&dZx#@4VvBiem=X zcj-IyjbP2-+{@Q@@Xm`So>Skf(~O@(k8bn$a>By6oA-tZ$0L3#}Ci9ocW}$@6gNFF7EBLI(BP!TFq$am#e#Xd&^di za5VU+eH)q%EN-VfwDoZ4s16Q{d3^Vs!LK(@K4p+GzPr9T&i0$dQ|rkV2d|kN?sW9M zVcQ#D?shYG;N#%Kw1YQ3PFyujIeP24({krmR&R%2I&khh;8c2WX0YuM#$11Mu)Z8IC|yefzx-;FB@z# zJu|wnWmMtX!%P;q#hSMRj=jujmI=nuFGG3Y-PfmCd(p{*Mp%C7HpkUh_3|r|UMHP6 zvN-mHY2Mx2Ooh`|d&{|*XCUA7ySzk^yKRW4~!O8>Ul)tKHj|o;N1u-Bo#xKX$`{6)Z{C|n4mgL7+2%>ltbV>Y_IAp@ zd*PVD<9a*g>K=!8_jv1}w{Mv>0Va%{OIAq z_1)at!Mr|BTpnqZ+01?Wdh}_el~>;8I6UFs=07Wso1I zS-a+EmL{EeWruTMnydEJm%Z=ezCP`j=AGZ1ZaCTa($1@VVZ)o#i(97E;0xa_@a_(L z-@u)YiqF3sH1u%WqkJ%C`g&^Li=G3McKgeeo=U?BpHKSmn+-ONvU=}6FlO{|Rr+Df z+ESuv^~}TIyRSx>@XY8sFlm$rmPY>ehy&Y9A17@XcRPl0udkPuo_Ccltvg>?=)lPA z?$A=hc(-#tJoI|=bK2G1op!VI)H0QeAHCHrSG=<6=UXn7Pv75tHOtnwfiG)3y*%Lf zsAV$eE01cvJ$n1wn{{ybY~l;*pn3 z&U*G;;OJNT#@_C6%b>PX-;EDk8N7Mrufo#Ab>Q>F>kh}Gs_($yop1<%h?2+Br?Qc9sE0W}H9soKTAZ literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input2.bin new file mode 100644 index 0000000000000000000000000000000000000000..fdc17ecedd3bfe13eeed257838153262936c5e16 GIT binary patch literal 16384 zcmZ9}QEntr2t(1%B$;IPzu~DP^_46>5D&18srr7sUVoqX^Th8HAMf~n^3M~W@2obL z^}O-aU+)}OEt<<_vOYeZox8(qac2MCnD*F2qNz|mj7;b3v|qEE-JG-1-T80QqH*O~5(3EMx z(f8Ja>A&dFhS3+*c@3{$Z?)ny=W+FTQRC!S53l|CZ8yBAFE^8yYoGbv=&eUzhQV`x z9`d6HZ*B%J_pR%9L-(Z*i!aRR%6AtX8t~rueOp}pZccAC%<`HSb(rpHKfT@My}EYu zkgx74(}B+?T=l;D4rY75w;k=Yzj^-U=4Q0I(SCDbrhn%&;e;8jy;Dyuyz>p;cQdmi z9M5oe=Fi(ajtm!84U=x)!raVmv2R$Kc*4%`?b63F6M4*+nO(HP=;6V%6Hg5%!+4vU z^$%pT`f1RPx4AhE?|1^9hrUDi!F7lILNC8G_MW~2+};O=lP+94>D_zbXz<&ue5z@+ z2UqkR@{=d{Zr&Z{eY5uPqo01cIer{ieY@$)Je%V+(=%HxjQM(P-W~MyaA8H>Xxh0L zb$;PQcYD*qYfrjyx-}%LHZ=Ag7e z3>{f*-dkVYoJIl5*WINj_>tF7`^@aXO~W2FZ{h51U-LZSdCTpIm#*3yw>#+BueQrO zul4AKp^>NFYC6?!Nmo{<>rGc})=jwj{OQx|jnjSbX6m$LJ>iWf)HLaZm)j3xFE1CH zp1GZJ+VI|V+m-I#6Bo`-nrY=hhQp^DAICc%@#Jo64x6Xu{B;+67(Kh}>}FU#%K6m8 zjbolJ4YNGpK|v-C>5WZijlc>f@)GzqHl$ z(#K01ukVoG?ts^W(d!o9f`*KnzPY}fr{*wbf?nU#j%MlD6;6*HAKvu5Mfb>}oz3mh zr{^uw=^j49$!Ubo=it3Eu9y~R$w zy<5#kzIj!{*2|YnFU>G&ns~yyJ!X08oi}~&=6zp$yJZ}Eirdr;J2jANcntCEV8AfzdJvH6-%5FFB#^ssA>&5f#Zu9Qcr^UZ)mXaSB*oj`^<0ztUf+p?J>iZ^SF1J z;pN-=?#at$yX~aw-97E5=WSOW@b&Vcj;l{k=0P@3zuV~9rG{4*`Ri7DU<>$mnhE>! zB;%&ttbOKb(1B|vo5zL2rzi3qF0Xk^JMVDaXV!PpGoxo$`{;@E^woIHaMKFcH-Mp= z27T{5c({XKD2t%$Y7M(qZn-Why%=+jX5FW_O?Et?5=wK?tOEQV%t+U=2XWnmW{Y_s0o)DHb_(KELP z&&)f|al&+SUi4vbd#lrd+Z$ob^tzvyzKz{^lJV1$o57{)O_v7D-rr1*)_OF{cYS=J z<||z|eLM5u4MVpcd>H+BaQ3Cmt9p6rZjkBV>(i3;gt=YKWjpAl35RDVF1$J|oP6lv zsndj~ndfqSSdpi6+htbIJHK+Z+>Y*~*N!-J^!o-lxOtXW=knYXpO#zl1}kXOkm=#* z!{Vy%9<@1M9($Wzt?=ELCY(qk|Lt?DJutX%eLJ-`j(0n8!{ByDJMsV*7KTo|Jatn# z`UR{zgg5=ZTlg?$c9qvV4$VAx+Y1k;pHI8gFzqbwH;u2CUmD$*PqRG5H^Wz(*(d9l z&0xBZw$QH^FTd%^;b`QqdG~ZfUg*Qp(-V1@h3`chwr@vQouBmJ?V>5?3s$6qXP%zE zcQ~`W(133roOyLO%4T-cQ}Yv7f9Le+n0v#wA66}!%L1+(-!2%@_i8VUKCN_QKJq84 z(>Cv&j-9;q&AiPDc#;3Kc?*+o`uf%8YFsnn&6AnhUKqM&>BEG}8=T+=SBCF~G}L{6 z`25H)-6Zn?Tci(bM?S*h>&e~Ecl-8M^Y+PRbmQUJr#6%Al=aP<<^6i1+l{bs_}yJ< z-1gwO1Kvzu7G`QQVMqSWddvHd?C-nfw_ET8J=y}FCw=((a<}O7ga>Dz97YWzy2qQg z3@6)Dt!Hk>-fOO=18-+H>EW20>8r!USYfVb_ufLId003YoIH7VvtIqQ)=M7X zz0GLKcy7rHOnYJNf%C40U!DCHu=sRg!s?mP(`&E3-LgLaJiy`WH>(Hl-8X>~&Gq{p z`rb6%-ZPZ*E^$ zxgGg|g~J!{^aL;F`DhO=tbUwwJ7ijd4o;q6yR$uT<}}T=Uv0)~d*wJXT|B{e8ts$w zV8&xTx$kA4nttC>?zZ;!-DsM%&wRDK)1;d>S&c_S?)#hJ;F_tu3%quwZ*Tf$IC;^B z>+b%`;>^P7_uc8_F%6k7c$uDE?ZoXSShLlur%&6AR=8@tJh#sbPVH@Xd-c8J<fqc7~TC%@|Y@IsB(H;38pXWo6ZybB)f&m(U9^wo>WPv3>s zdNC#NQos`rfqE_6fK+?(hz4 zpE(>(H4QzvfD?HLPgBoKkG^-i>eI`IJ!%<8tge|_*iTEqRME z7PQlV<3kM>1~32X(>p)#F!tuBxgGIf^0{8@u&?jZ&F~`4F!W&6X0+t)(buypetYu; z-x~%;O(P$Aya~N-!PS#j$9(xb(9p9(-45^WHm{~vt!J;m*T>P%gPymZpfBrt!>8vh z!sID!*>@6skNndR`AL_So}F^v%p0bhrn+xtU!KhAhtsE($M(zkW_H5&Exe0-+t(hm z-tp70&n%;1;KAt`v@7>jIShILC91nQjFz|c`eF4**-mrO!-yV4H_IZb=$%DX?>6XLNfQ_f7 zC%V)Ac>46pbcFeG)#;Y^JK2{nyn>c@-wr2zJKF<8W5XW(yu+I7(@7IXZNHv6o;nZt z#izACx*y&Sb$jsyylB6e)35#d^5fD%)+=g5BbE=m-Woc`umOeRm*A8*TdsapO)MX wZ*$+EZ>N_|{PKDU2a7AbX?w%c?t9ol2d{Vcz~Sa43>*wCn)-b=wfWWm4~QR5O#lD@ literal 0 HcmV?d00001 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp index f1074c838..6dfb594d8 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. +// Please refer to the License for details. You can not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -12,16 +12,58 @@ #define AICORE [aicore] #endif -// Case 0: f32 16x64 -extern "C" __global__ AICORE void TADD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); +// Case 0: f32 64x128_64x128_64x128_64x128 +extern "C" __global__ AICORE void TADD_f32_64x128_64x128_64x128_64x128(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream) { - TADD_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +void LaunchTADD_f32_64x128_64x128_64x128_64x128(void *a, void *b, void *c, void *stream) { + TADD_f32_64x128_64x128_64x128_64x128<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } -// Case 1: f32 32x32 -extern "C" __global__ AICORE void TADD_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); +// Case 1: f32 16x64_16x64_16x64_16x64 +extern "C" __global__ AICORE void TADD_f32_16x64_16x64_16x64_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream) { - TADD_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +void LaunchTADD_f32_16x64_16x64_16x64_16x64(void *a, void *b, void *c, void *stream) { + TADD_f32_16x64_16x64_16x64_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } + +// Case 2: f32 32x32_32x32_32x32_32x32 +extern "C" __global__ AICORE void TADD_f32_32x32_32x32_32x32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTADD_f32_32x32_32x32_32x32_32x32(void *a, void *b, void *c, void *stream) { + TADD_f32_32x32_32x32_32x32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: f32 64x64_64x64_64x64_64x64 +extern "C" __global__ AICORE void TADD_f32_64x64_64x64_64x64_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTADD_f32_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream) { + TADD_f32_64x64_64x64_64x64_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 4: i32 64x64_64x64_64x64_64x64 +extern "C" __global__ AICORE void TADD_i32_64x64_64x64_64x64_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTADD_i32_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream) { + TADD_i32_64x64_64x64_64x64_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 5: i16 64x64_64x64_64x64_64x64 +extern "C" __global__ AICORE void TADD_i16_64x64_64x64_64x64_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTADD_i16_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream) { + TADD_i16_64x64_64x64_64x64_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 6: f16 16x256_16x256_16x256_16x256 +extern "C" __global__ AICORE void TADD_f16_16x256_16x256_16x256_16x256(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTADD_f16_16x256_16x256_16x256_16x256(void *a, void *b, void *c, void *stream) { + TADD_f16_16x256_16x256_16x256_16x256<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 7: half 16x64_16x128_16x128_16x64 +extern "C" __global__ AICORE void TADD_half_16x64_16x128_16x128_16x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTADD_half_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream) { + TADD_half_16x64_16x128_16x128_16x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp index 1a010623f..276bd83a7 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. +// Please refer to the License for details. You can not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -22,71 +22,89 @@ using namespace PtoTestCommon; // Kernel launch wrappers (defined in launch.cpp) -void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream); -void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream); +void LaunchTADD_f32_64x128_64x128_64x128_64x128(void *a, void *b, void *c, void *stream); +void LaunchTADD_f32_16x64_16x64_16x64_16x64(void *a, void *b, void *c, void *stream); +void LaunchTADD_f32_32x32_32x32_32x32_32x32(void *a, void *b, void *c, void *stream); +void LaunchTADD_f32_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream); +void LaunchTADD_i32_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream); +void LaunchTADD_i16_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream); +void LaunchTADD_f16_16x256_16x256_16x256_16x256(void *a, void *b, void *c, void *stream); +void LaunchTADD_half_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream); -using LaunchFn = void (*)(float *, float *, float *, void *); +using LaunchFn = void (*)(void *, void *, void *, void *); struct TestCase { const char *name; LaunchFn launch; - size_t rows; // allocated tile rows - size_t cols; // allocated tile cols - size_t validRows; // effective computation rows (<= rows) - size_t validCols; // effective computation cols (<= cols) - size_t elemSize; // bytes per element + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; }; static const TestCase kCases[] = { - {"f32_16x64", LaunchTADD_f32_16x64, 16, 64, 16, 64, sizeof(float)}, - {"f32_32x32", LaunchTADD_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f32_64x128_64x128_64x128_64x128", LaunchTADD_f32_64x128_64x128_64x128_64x128, 64, 128, 64, 128, 64, 128, 64, 128, sizeof(float)}, + {"f32_16x64_16x64_16x64_16x64", LaunchTADD_f32_16x64_16x64_16x64_16x64, 16, 64, 16, 64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32_32x32_32x32_32x32", LaunchTADD_f32_32x32_32x32_32x32_32x32, 32, 32, 32, 32, 32, 32, 32, 32, sizeof(float)}, + {"f32_64x64_64x64_64x64_64x64", LaunchTADD_f32_64x64_64x64_64x64_64x64, 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"i32_64x64_64x64_64x64_64x64", LaunchTADD_i32_64x64_64x64_64x64_64x64, 64, 64, 64, 64, 64, 64, 64, 64, sizeof(int32_t)}, + {"i16_64x64_64x64_64x64_64x64", LaunchTADD_i16_64x64_64x64_64x64_64x64, 64, 64, 64, 64, 64, 64, 64, 64, sizeof(int16_t)}, + {"f16_16x256_16x256_16x256_16x256", LaunchTADD_f16_16x256_16x256_16x256_16x256, 16, 256, 16, 256, 16, 256, 16, 256, sizeof(uint16_t)}, + {"half_16x64_16x128_16x128_16x64", LaunchTADD_half_16x64_16x128_16x128_16x64, 16, 128, 16, 128, 16, 64, 16, 64, sizeof(uint16_t)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; int rc = 0; - const size_t elemCount = tc.rows * tc.cols; - const size_t fileSize = elemCount * tc.elemSize; + const size_t src0Size = tc.src0Rows * tc.src0Cols * tc.elemSize; + const size_t src1Size = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstSize = tc.dstRows * tc.dstCols * tc.elemSize; - std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", - tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + std::printf("[INFO] === case: %s (dst=%zux%zu, src0=%zux%zu, src1=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.dstRows, tc.dstCols, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.validRows, tc.validCols); // Per-case data directory std::string caseDir = std::string("./") + tc.name; - size_t src0FileSize = fileSize; - size_t src1FileSize = fileSize; - float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - aclrtMallocHost((void **)(&src0Host), fileSize); - aclrtMallocHost((void **)(&src1Host), fileSize); - aclrtMallocHost((void **)(&dstHost), fileSize); + aclrtMallocHost(&src0Host, src0Size); + aclrtMallocHost(&src1Host, src1Size); + aclrtMallocHost(&dstHost, dstSize); - aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src0Device, src0Size, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, src1Size, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST); - if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + size_t fileSize = 0; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, src0Size)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); rc = 1; } - if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + fileSize = 0; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, src1Size)) { std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); rc = 1; } if (rc == 0) { - aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, src0Size, src0Host, src0Size, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1Size, src1Host, src1Size, ACL_MEMCPY_HOST_TO_DEVICE); tc.launch(src0Device, src1Device, dstDevice, stream); aclrtSynchronizeStream(stream); - aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST); } - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstSize)) { std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); rc = 1; } @@ -142,4 +160,4 @@ int main(int argc, char *argv[]) { aclFinalize(); return rc; -} +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto index 340e416c3..52c8cfd05 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. +// Please refer to the License for details. You can not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -12,8 +12,72 @@ // to produce LLVM IR. module { - // Case 0: f32 16x64 (1024 elements) - func.func @TADD_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // Case 0: f32 64x128_64x128_64x128_64x128 (8192 elements) + func.func @TADD_f32_64x128_64x128_64x128_64x128(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + return + } + + // Case 1: f32 16x64_16x64_16x64_16x64 (1024 elements) + func.func @TADD_f32_16x64_16x64_16x64_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index @@ -76,8 +140,8 @@ module { return } - // Case 1: f32 32x32 (1024 elements) - func.func @TADD_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // Case 2: f32 32x32_32x32_32x32_32x32 (1024 elements) + func.func @TADD_f32_32x32_32x32_32x32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index @@ -138,4 +202,323 @@ module { outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) return } -} + + // Case 3: f32 64x64_64x64_64x64_64x64 (4096 elements) + func.func @TADD_f32_64x64_64x64_64x64_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 4: i32 64x64_64x64_64x64_64x64 (4096 elements) + func.func @TADD_i32_64x64_64x64_64x64_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 5: i16 64x64_64x64_64x64_64x64 (4096 elements) + func.func @TADD_i16_64x64_64x64_64x64_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 6: f16 16x256_16x256_16x256_16x256 (4096 elements) + func.func @TADD_f16_16x256_16x256_16x256_16x256(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + return + } + + // Case 7: half 16x64_16x128_16x128_16x64 (src=16x128, dst=16x64, valid=16x64) + func.func @TADD_half_16x64_16x128_16x128_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tand/CMakeLists.txt new file mode 100644 index 000000000..230a97296 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tand) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py new file mode 100644 index 000000000..fff678a85 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tand ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype (np.int16, np.int8) + - shape: (rows, cols) — allocated tile dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) +""" + +import numpy as np + +CASES = [ + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i8_64x64_valid63x63", + "dtype": np.int8, + "shape": (64, 64), + "valid_shape": (63, 63), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py new file mode 100644 index 000000000..de8dab931 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +np.random.seed(19) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 16383, size=shape).astype(dtype) + input2 = np.random.randint(1, 16383, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] & input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp new file mode 100644 index 000000000..31006ed8c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TAND_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTAND_i16_64x64(void *a, void *b, void *c, void *stream) { + TAND_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +extern "C" __global__ AICORE void TAND_i8_64x64_valid63x63(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int8_t *c); + +void LaunchTAND_i8_64x64_valid63x63(void *a, void *b, void *c, void *stream) { + TAND_i8_64x64_valid63x63<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int8_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp new file mode 100644 index 000000000..af4b9cd56 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tand ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTAND_i16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTAND_i8_64x64_valid63x63(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"i16_64x64", LaunchTAND_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, + {"i8_64x64_valid63x63", LaunchTAND_i8_64x64_valid63x63, 64, 64, 63, 63, sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t fileSize = tc.rows * tc.cols * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, fileSize); + aclrtMallocHost(&src1Host, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSizeRead = 0; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSizeRead, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSizeRead = 0; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSizeRead, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto b/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto new file mode 100644 index 000000000..6423bc5a1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto @@ -0,0 +1,163 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tand: tload(a) + tload(b) + tand(a,b)->c + tstore(c). +// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/tand +// Cases cover different dtypes and shapes. + +module { + // NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug + // See BUG_REPORT_UNSIGNED_BITOPS.md for details + + // Case 0: ui16, 64x64, valid 64x64 - DISABLED + // func.func @TAND_ui16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // ... (disabled due to unsigned bitops bug) + // } + + // Case 1: ui16, 64x64, valid 63x63 - DISABLED + // func.func @TAND_ui16_64x64_valid63x63(...) + + // Case 2: ui16, 1x16384, valid 1x16384 - DISABLED + // func.func @TAND_ui16_1x16384(...) + + // Case 3: ui16, 2048x16, valid 2048x16 - DISABLED + // func.func @TAND_ui16_2048x16(...) + + // Case 0 (reindexed): i16, 64x64, valid 64x64 + func.func @TAND_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%b : !pto.tile_buf) + + pto.tand ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 5: ui16, 64x64, valid 64x64 (half mode) - DISABLED + // func.func @TAND_ui16_64x64_half(...) + + // Case 6: ui8, 64x64, valid 63x63 - DISABLED + // func.func @TAND_ui8_64x64_valid63x63(...) + + // Case 1 (reindexed): i8, 64x64, valid 63x63 + func.func @TAND_i8_64x64_valid63x63(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x63x63xi8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x63x63xi8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x63x63xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x63x63xi8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x63x63xi8>) + outs(%b : !pto.tile_buf) + + pto.tand ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x63x63xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt new file mode 100644 index 000000000..a863ea151 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcmp) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py new file mode 100644 index 000000000..c964a0329 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py @@ -0,0 +1,63 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcmp ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype for inputs (np.float32 or np.int32) + - out_dtype: numpy dtype for output (np.int8) + - shape: (rows, cols) — tile buffer dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for comparison + - cmp_mode: comparison mode string (eq, gt, le) +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x64_eq", + "dtype": np.float32, + "out_dtype": np.int8, + "shape": (1, 64), + "valid_shape": (1, 64), + "eps": 0, + "cmp_mode": "eq", + }, + { + "name": "f32_8x64_gt", + "dtype": np.float32, + "out_dtype": np.int8, + "shape": (8, 64), + "valid_shape": (8, 64), + "eps": 0, + "cmp_mode": "gt", + }, + { + "name": "i32_16x32_eq", + "dtype": np.int32, + "out_dtype": np.int8, + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + "cmp_mode": "eq", + }, + { + "name": "i32_32x32_eq", + "dtype": np.int32, + "out_dtype": np.int8, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + "cmp_mode": "eq", + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py new file mode 100644 index 000000000..bef371b3b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py @@ -0,0 +1,56 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +ALIGN_STRIDE = 32 + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + out_dtype = case["out_dtype"] + vr, vc = case["valid_shape"] + packed_shape = (vr, ALIGN_STRIDE) + packed_size = vr * ALIGN_STRIDE + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=out_dtype).reshape(packed_shape) + output_full = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=out_dtype) + output = output_full[:packed_size].reshape(packed_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py new file mode 100644 index 000000000..5cf29a05a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py @@ -0,0 +1,72 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + + +def compute_cmp(a, b, mode): + if mode == "gt": + return (a > b) + elif mode == "ge": + return (a >= b) + elif mode == "lt": + return (a < b) + elif mode == "le": + return (a <= b) + elif mode == "eq": + return (a == b) + elif mode == "ne": + return (a != b) + else: + raise ValueError(f"Unknown cmp_mode: {mode}") + + +ALIGN_STRIDE = 32 + + +def pack_predicate_mask(cmp_result): + cmp_result = cmp_result.astype(np.uint8) + shape = cmp_result.shape + packed_shape = (shape[0], ALIGN_STRIDE) + packed = np.zeros(packed_shape, dtype=np.uint8) + for row in range(shape[0]): + for vl in range(min(8, shape[1] // 8)): + lanes = cmp_result[row, vl*8:(vl+1)*8] + for j in range(8): + if lanes[j]: + byte_idx = j // 2 + bit_pos = (j % 2) * 4 + packed[row, vl*4 + byte_idx] |= (1 << bit_pos) + return packed.view(np.int8) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + out_dtype = case["out_dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + cmp_mode = case["cmp_mode"] + + input1 = np.random.choice([-5, -2, -1, 0, 1, 2, 5], size=shape).astype(dtype) + input2 = np.random.choice([-5, -2, -1, 0, 1, 2, 5], size=shape).astype(dtype) + + vr, vc = valid_shape + cmp_result = compute_cmp(input1[:vr, :vc], input2[:vr, :vc], cmp_mode) + golden = pack_predicate_mask(cmp_result) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} out_dtype={out_dtype.__name__} cmp_mode={cmp_mode}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp new file mode 100644 index 000000000..6841e5589 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TCMP_f32_1x64_eq(__gm__ float *a, __gm__ float *b, __gm__ int8_t *c); + +void LaunchTCMP_f32_1x64_eq(float *a, float *b, int8_t *c, void *stream) { + TCMP_f32_1x64_eq<<<1, nullptr, stream>>>(a, b, c); +} + +extern "C" __global__ AICORE void TCMP_f32_8x64_gt(__gm__ float *a, __gm__ float *b, __gm__ int8_t *c); + +void LaunchTCMP_f32_8x64_gt(float *a, float *b, int8_t *c, void *stream) { + TCMP_f32_8x64_gt<<<1, nullptr, stream>>>(a, b, c); +} + +extern "C" __global__ AICORE void TCMP_i32_16x32_eq(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i32_16x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream) { + TCMP_i32_16x32_eq<<<1, nullptr, stream>>>(a, b, c); +} + +extern "C" __global__ AICORE void TCMP_i32_32x32_eq(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i32_32x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream) { + TCMP_i32_32x32_eq<<<1, nullptr, stream>>>(a, b, c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp new file mode 100644 index 000000000..83cb28f3a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp @@ -0,0 +1,147 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcmp ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTCMP_f32_1x64_eq(float *a, float *b, int8_t *c, void *stream); +void LaunchTCMP_f32_8x64_gt(float *a, float *b, int8_t *c, void *stream); +void LaunchTCMP_i32_16x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream); +void LaunchTCMP_i32_32x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; + size_t inputElemSize; + size_t outputElemSize; +}; + +static const TestCase kCases[] = { + {"f32_1x64_eq", (LaunchFn)LaunchTCMP_f32_1x64_eq, 1, 64, 1, 64, sizeof(float), sizeof(int8_t)}, + {"f32_8x64_gt", (LaunchFn)LaunchTCMP_f32_8x64_gt, 8, 64, 8, 64, sizeof(float), sizeof(int8_t)}, + {"i32_16x32_eq", (LaunchFn)LaunchTCMP_i32_16x32_eq, 16, 32, 16, 32, sizeof(int32_t), sizeof(int8_t)}, + {"i32_32x32_eq", (LaunchFn)LaunchTCMP_i32_32x32_eq, 32, 32, 32, 32, sizeof(int32_t), sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t inputFileSize = elemCount * tc.inputElemSize; + const size_t outputFileSize = elemCount * tc.outputElemSize; + const size_t packedMaskSize = tc.validRows * 32 * tc.outputElemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu, packed_mask=%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols, packedMaskSize); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, inputFileSize); + aclrtMallocHost(&src1Host, inputFileSize); + aclrtMallocHost(&dstHost, outputFileSize); + + aclrtMalloc(&src0Device, inputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, inputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, outputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t readSize = inputFileSize; + if (!ReadFile((caseDir + "/input1.bin").c_str(), readSize, src0Host, inputFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + readSize = inputFileSize; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), readSize, src1Host, inputFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, inputFileSize, src0Host, inputFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, inputFileSize, src1Host, inputFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, outputFileSize, dstDevice, outputFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, outputFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto new file mode 100644 index 000000000..317fd93ac --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto @@ -0,0 +1,266 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcmp: tload(a) + tload(b) + tcmp(a,b)->c + tstore(c). +// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/tcmp +// 4 cases: f32_1x64_eq, f32_8x64_gt, i32_16x32_eq, i32_32x32_eq + +module { + // Case 0: f32 1x64 eq + func.func @TCMP_f32_1x64_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi8> -> !pto.partition_tensor_view<1x1x1x1x64xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x1x64xi8>) + return + } + + // Case 1: f32 8x64 gt + func.func @TCMP_f32_8x64_gt(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c64], + strides = [%c512, %c512, %c512, %c64, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c64], + strides = [%c512, %c512, %c512, %c64, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c64], + strides = [%c512, %c512, %c512, %c64, %c1] + : !pto.tensor_view<1x1x1x8x64xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c64] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c64] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c64] + : !pto.tensor_view<1x1x1x8x64xi8> -> !pto.partition_tensor_view<1x1x1x8x64xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x64xi8>) + return + } + + // Case 2: i32 16x32 eq + func.func @TCMP_i32_16x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi8> -> !pto.partition_tensor_view<1x1x1x16x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xi8>) + return + } + + // Case 3: i32 32x32 eq + func.func @TCMP_i32_32x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi8> -> !pto.partition_tensor_view<1x1x1x32x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/CMakeLists.txt new file mode 100644 index 000000000..506774c21 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tdiv) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py new file mode 100644 index 000000000..68469d49d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py @@ -0,0 +1,71 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tdiv ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +Note: tdiv only supports float32 and float16. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "f16_64x64", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-3, + }, + { + "name": "f16_64x64_v61x61", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (61, 61), + "eps": 1e-3, + }, + { + "name": "f32_64x32_v60x30", + "dtype": np.float32, + "shape": (64, 32), + "valid_shape": (60, 30), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py new file mode 100644 index 000000000..27039e842 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] / input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp new file mode 100644 index 000000000..ee4168bb1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TDIV_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_16x64(void *a, void *b, void *c, void *stream) { + TDIV_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TDIV_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_32x32(void *a, void *b, void *c, void *stream) { + TDIV_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32 64x64 +extern "C" __global__ AICORE void TDIV_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_64x64(void *a, void *b, void *c, void *stream) { + TDIV_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: f16 64x64 +extern "C" __global__ AICORE void TDIV_f16_64x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTDIV_f16_64x64(void *a, void *b, void *c, void *stream) { + TDIV_f16_64x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 4: f16 64x64 v61x61 +extern "C" __global__ AICORE void TDIV_f16_64x64_v61x61(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTDIV_f16_64x64_v61x61(void *a, void *b, void *c, void *stream) { + TDIV_f16_64x64_v61x61<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 5: f32 64x32 v60x30 +extern "C" __global__ AICORE void TDIV_f32_64x32_v60x30(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_64x32_v60x30(void *a, void *b, void *c, void *stream) { + TDIV_f32_64x32_v60x30<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp new file mode 100644 index 000000000..f4e4e56f4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp @@ -0,0 +1,154 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tdiv ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTDIV_f32_16x64(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_32x32(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f16_64x64_v61x61(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_64x32_v60x30(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTDIV_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTDIV_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f32_64x64", LaunchTDIV_f32_64x64, 64, 64, 64, 64, sizeof(float)}, + {"f16_64x64", LaunchTDIV_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, + {"f16_64x64_v61x61", LaunchTDIV_f16_64x64_v61x61, 64, 64, 61, 61, sizeof(uint16_t)}, + {"f32_64x32_v60x30", LaunchTDIV_f32_64x32_v60x30, 64, 32, 60, 30, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, fileSize); + aclrtMallocHost(&src1Host, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tdiv [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto new file mode 100644 index 000000000..e3f6a40a7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto @@ -0,0 +1,398 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tdiv: tload(a) + tload(b) + tdiv(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Note: tdiv only supports float types (f32, f16). +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 + func.func @TDIV_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 + func.func @TDIV_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f32 64x64 + func.func @TDIV_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 3: f16 64x64 + func.func @TDIV_f16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } + + // Case 4: f16 64x64 v61x61 (valid != tile) + func.func @TDIV_f16_64x64_v61x61(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c61 = arith.constant 61 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c61, %c61], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x61x61xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c61, %c61], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x61x61xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c61, %c61], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x61x61xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c61, %c61] + : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c61, %c61] + : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c61, %c61] + : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) + return + } + + // Case 5: f32 64x32 v60x30 (valid != tile) + func.func @TDIV_f32_64x32_v60x30(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c30 = arith.constant 30 : index + %c32 = arith.constant 32 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c60, %c30], + strides = [%c2048, %c2048, %c2048, %c32, %c1] + : !pto.tensor_view<1x1x1x60x30xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c60, %c30], + strides = [%c2048, %c2048, %c2048, %c32, %c1] + : !pto.tensor_view<1x1x1x60x30xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c60, %c30], + strides = [%c2048, %c2048, %c2048, %c32, %c1] + : !pto.tensor_view<1x1x1x60x30xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c30] + : !pto.tensor_view<1x1x1x60x30xf32> -> !pto.partition_tensor_view<1x1x1x60x30xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c30] + : !pto.tensor_view<1x1x1x60x30xf32> -> !pto.partition_tensor_view<1x1x1x60x30xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c30] + : !pto.tensor_view<1x1x1x60x30xf32> -> !pto.partition_tensor_view<1x1x1x60x30xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x30xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x30xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x60x30xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt new file mode 100644 index 000000000..41603223e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfmod) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py new file mode 100644 index 000000000..5b7473873 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py @@ -0,0 +1,75 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfmod ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype (np.float32) + - dst_tile: (rows, cols) — dst tile buffer dimensions + - src0_tile: (rows, cols) — src0 tile buffer dimensions + - src1_tile: (rows, cols) — src1 tile buffer dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) + +Note: src0/src1/dst tile buffer physical sizes can differ, + but valid_shape must be the same for all. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64_16x128_16x128_16x64", + "dtype": np.float32, + "dst_tile": (16, 64), + "src0_tile": (16, 128), + "src1_tile": (16, 128), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f32_16x32_16x64_16x32_16x32", + "dtype": np.float32, + "dst_tile": (16, 32), + "src0_tile": (16, 64), + "src1_tile": (16, 32), + "valid_shape": (16, 32), + "eps": 1e-3, + }, + { + "name": "f32_16x64_16x128_16x128_16x63", + "dtype": np.float32, + "dst_tile": (16, 64), + "src0_tile": (16, 128), + "src1_tile": (16, 128), + "valid_shape": (16, 63), + "eps": 1e-3, + }, + { + "name": "f32_2x32_2x64_2x32_2x31", + "dtype": np.float32, + "dst_tile": (2, 32), + "src0_tile": (2, 64), + "src1_tile": (2, 32), + "valid_shape": (2, 31), + "eps": 1e-3, + }, + { + "name": "f32_1x8192_1x8192_1x8192_1x8192", + "dtype": np.float32, + "dst_tile": (1, 8192), + "src0_tile": (1, 8192), + "src1_tile": (1, 8192), + "valid_shape": (1, 8192), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py new file mode 100644 index 000000000..96ab2dda8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_tile = case["dst_tile"] + valid_shape = case["valid_shape"] + vr, vc = valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_tile) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_tile) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py new file mode 100644 index 000000000..61bfc2a2e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +np.random.seed(19) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_tile = case["dst_tile"] + src0_tile = case["src0_tile"] + src1_tile = case["src1_tile"] + valid_shape = case["valid_shape"] + + input1 = np.random.uniform(low=-1000, high=1000, size=src0_tile).astype(dtype) + input2 = np.random.uniform(low=3, high=100, size=src1_tile).astype(dtype) + + golden = np.zeros(dst_tile, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.fmod(input1[:vr, :vc], input2[:vr, :vc]).astype(dtype) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dst={dst_tile} src0={src0_tile} src1={src1_tile} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp new file mode 100644 index 000000000..764d99aa4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x64 +extern "C" __global__ AICORE void TFMOD_f32_16x64_16x128_16x128_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTFMOD_f32_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream) { + TFMOD_f32_16x64_16x128_16x128_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 +extern "C" __global__ AICORE void TFMOD_f32_16x32_16x64_16x32_16x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTFMOD_f32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream) { + TFMOD_f32_16x32_16x64_16x32_16x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x63 +extern "C" __global__ AICORE void TFMOD_f32_16x64_16x128_16x128_16x63(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTFMOD_f32_16x64_16x128_16x128_16x63(void *a, void *b, void *c, void *stream) { + TFMOD_f32_16x64_16x128_16x128_16x63<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: f32, dst=2x32, src0=2x64, src1=2x32, valid=2x31 +extern "C" __global__ AICORE void TFMOD_f32_2x32_2x64_2x32_2x31(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTFMOD_f32_2x32_2x64_2x32_2x31(void *a, void *b, void *c, void *stream) { + TFMOD_f32_2x32_2x64_2x32_2x31<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 4: f32, dst=1x8192, src0=1x8192, src1=1x8192, valid=1x8192 +extern "C" __global__ AICORE void TFMOD_f32_1x8192_1x8192_1x8192_1x8192(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTFMOD_f32_1x8192_1x8192_1x8192_1x8192(void *a, void *b, void *c, void *stream) { + TFMOD_f32_1x8192_1x8192_1x8192_1x8192<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp new file mode 100644 index 000000000..3f7a170d3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tfmod ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTFMOD_f32_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream); +void LaunchTFMOD_f32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream); +void LaunchTFMOD_f32_16x64_16x128_16x128_16x63(void *a, void *b, void *c, void *stream); +void LaunchTFMOD_f32_2x32_2x64_2x32_2x31(void *a, void *b, void *c, void *stream); +void LaunchTFMOD_f32_1x8192_1x8192_1x8192_1x8192(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"f32_16x64_16x128_16x128_16x64", LaunchTFMOD_f32_16x64_16x128_16x128_16x64, 16, 128, 16, 128, 16, 64, 16, 64, sizeof(float)}, + {"f32_16x32_16x64_16x32_16x32", LaunchTFMOD_f32_16x32_16x64_16x32_16x32, 16, 64, 16, 32, 16, 32, 16, 32, sizeof(float)}, + {"f32_16x64_16x128_16x128_16x63", LaunchTFMOD_f32_16x64_16x128_16x128_16x63, 16, 128, 16, 128, 16, 64, 16, 63, sizeof(float)}, + {"f32_2x32_2x64_2x32_2x31", LaunchTFMOD_f32_2x32_2x64_2x32_2x31, 2, 64, 2, 32, 2, 32, 2, 31, sizeof(float)}, + {"f32_1x8192_1x8192_1x8192_1x8192", LaunchTFMOD_f32_1x8192_1x8192_1x8192_1x8192, 1, 8192, 1, 8192, 1, 8192, 1, 8192, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0Size = tc.src0Rows * tc.src0Cols * tc.elemSize; + const size_t src1Size = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (dst=%zux%zu, src0=%zux%zu, src1=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.dstRows, tc.dstCols, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, src0Size); + aclrtMallocHost(&src1Host, src1Size); + aclrtMallocHost(&dstHost, dstSize); + + aclrtMalloc(&src0Device, src0Size, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, src1Size, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = 0; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, src0Size)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = 0; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, src1Size)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0Size, src0Host, src0Size, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1Size, src1Host, src1Size, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto new file mode 100644 index 000000000..72f227b82 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto @@ -0,0 +1,340 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tfmod: tload(a) + tload(b) + tfmod(a,b)->c + tstore(c). +// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/tfmod +// Cases have different src/dst tile buffer sizes but same valid_shape. + +module { + // Case 0: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x64 + func.func @TFMOD_f32_16x64_16x128_16x128_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c2048 = arith.constant 2048 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 + func.func @TFMOD_f32_16x32_16x64_16x32_16x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // Case 2: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x63 + func.func @TFMOD_f32_16x64_16x128_16x128_16x63(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c2048 = arith.constant 2048 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + outs(%b : !pto.tile_buf) + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + return + } + + // Case 3: f32, dst=2x32, src0=2x64, src1=2x32, valid=2x31 + func.func @TFMOD_f32_2x32_2x64_2x32_2x31(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c31] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c31] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c31] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) + outs(%b : !pto.tile_buf) + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) + return + } + + // Case 4: f32, dst=1x8192, src0=1x8192, src1=1x8192, valid=1x8192 + func.func @TFMOD_f32_1x8192_1x8192_1x8192_1x8192(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8192] + : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8192] + : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8192] + : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + outs(%b : !pto.tile_buf) + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmax/CMakeLists.txt new file mode 100644 index 000000000..8012132e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py new file mode 100644 index 000000000..fc535804b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py @@ -0,0 +1,97 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmax ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "i32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "f16_64x64", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-3, + }, + { + "name": "f32_64x64_v60x60", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 1e-6, + }, + { + "name": "i32_64x64_v60x60", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0, + }, + { + "name": "f16_2x4096_v1x3600", + "dtype": np.float16, + "shape": (2, 4096), + "valid_shape": (1, 3600), + "eps": 1e-3, + }, + { + "name": "i16_20x512_v16x200", + "dtype": np.int16, + "shape": (20, 512), + "valid_shape": (16, 200), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py new file mode 100644 index 000000000..7d44773bb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.maximum(input1[:vr, :vc], input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp new file mode 100644 index 000000000..5b3266627 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TMAX_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMAX_f32_16x64(void *a, void *b, void *c, void *stream) { + TMAX_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TMAX_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMAX_f32_32x32(void *a, void *b, void *c, void *stream) { + TMAX_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32 64x64 +extern "C" __global__ AICORE void TMAX_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMAX_f32_64x64(void *a, void *b, void *c, void *stream) { + TMAX_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: i32 64x64 +extern "C" __global__ AICORE void TMAX_i32_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMAX_i32_64x64(void *a, void *b, void *c, void *stream) { + TMAX_i32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 4: i16 64x64 +extern "C" __global__ AICORE void TMAX_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTMAX_i16_64x64(void *a, void *b, void *c, void *stream) { + TMAX_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 5: f16 64x64 +extern "C" __global__ AICORE void TMAX_f16_64x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTMAX_f16_64x64(void *a, void *b, void *c, void *stream) { + TMAX_f16_64x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 6: f32 64x64 v60x60 +extern "C" __global__ AICORE void TMAX_f32_64x64_v60x60(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMAX_f32_64x64_v60x60(void *a, void *b, void *c, void *stream) { + TMAX_f32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 7: i32 64x64 v60x60 +extern "C" __global__ AICORE void TMAX_i32_64x64_v60x60(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMAX_i32_64x64_v60x60(void *a, void *b, void *c, void *stream) { + TMAX_i32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 8: f16 2x4096 v1x3600 +extern "C" __global__ AICORE void TMAX_f16_2x4096_v1x3600(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTMAX_f16_2x4096_v1x3600(void *a, void *b, void *c, void *stream) { + TMAX_f16_2x4096_v1x3600<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 9: i16 20x512 v16x200 +extern "C" __global__ AICORE void TMAX_i16_20x512_v16x200(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTMAX_i16_20x512_v16x200(void *a, void *b, void *c, void *stream) { + TMAX_i16_20x512_v16x200<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp new file mode 100644 index 000000000..0f00a8513 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp @@ -0,0 +1,162 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmax ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMAX_f32_16x64(void *a, void *b, void *c, void *stream); +void LaunchTMAX_f32_32x32(void *a, void *b, void *c, void *stream); +void LaunchTMAX_f32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMAX_i32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMAX_i16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMAX_f16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMAX_f32_64x64_v60x60(void *a, void *b, void *c, void *stream); +void LaunchTMAX_i32_64x64_v60x60(void *a, void *b, void *c, void *stream); +void LaunchTMAX_f16_2x4096_v1x3600(void *a, void *b, void *c, void *stream); +void LaunchTMAX_i16_20x512_v16x200(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTMAX_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTMAX_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f32_64x64", LaunchTMAX_f32_64x64, 64, 64, 64, 64, sizeof(float)}, + {"i32_64x64", LaunchTMAX_i32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, + {"i16_64x64", LaunchTMAX_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, + {"f16_64x64", LaunchTMAX_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, + {"f32_64x64_v60x60", LaunchTMAX_f32_64x64_v60x60, 64, 64, 60, 60, sizeof(float)}, + {"i32_64x64_v60x60", LaunchTMAX_i32_64x64_v60x60, 64, 64, 60, 60, sizeof(int32_t)}, + {"f16_2x4096_v1x3600", LaunchTMAX_f16_2x4096_v1x3600, 2, 4096, 1, 3600, sizeof(uint16_t)}, + {"i16_20x512_v16x200", LaunchTMAX_i16_20x512_v16x200, 20, 512, 16, 200, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, fileSize); + aclrtMallocHost(&src1Host, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmax [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto new file mode 100644 index 000000000..77172ec2a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto @@ -0,0 +1,649 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmax: tload(a) + tload(b) + tmax(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 (1024 elements) + func.func @TMAX_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TMAX_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f32 64x64 (4096 elements) + func.func @TMAX_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 3: i32 64x64 (4096 elements) + func.func @TMAX_i32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 4: i16 64x64 (4096 elements) + func.func @TMAX_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 5: f16 64x64 (4096 elements) + func.func @TMAX_f16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } + + // Case 6: f32 64x64 tile with 60x60 valid region (padding with MIN for tmax) + func.func @TMAX_f32_64x64_v60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + return + } + + // Case 7: i32 64x64 tile with 60x60 valid region (padding with MIN for tmax) + func.func @TMAX_i32_64x64_v60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + return + } + + // Case 8: f16 2x4096 tile with 1x3600 valid region (padding with MIN for tmax) + func.func @TMAX_f16_2x4096_v1x3600(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + return + } + + // Case 9: i16 20x512 tile with 16x200 valid region (padding with MIN for tmax) + func.func @TMAX_i16_20x512_v16x200(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c200 = arith.constant 200 : index + %c512 = arith.constant 512 : index + %c3200 = arith.constant 3200 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt new file mode 100644 index 000000000..fa90b8804 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py new file mode 100644 index 000000000..4a075e55f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py @@ -0,0 +1,83 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmin ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "i32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "f16_64x64", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-3, + }, + { + "name": "f32_64x64_v60x60", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 1e-6, + }, + { + "name": "i32_64x64_v60x60", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0, + }, + { + "name": "f16_2x4096_v1x3600", + "dtype": np.float16, + "shape": (2, 4096), + "valid_shape": (1, 3600), + "eps": 1e-3, + }, + { + "name": "i16_20x512_v16x200", + "dtype": np.int16, + "shape": (20, 512), + "valid_shape": (16, 200), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py new file mode 100644 index 000000000..e9cb62020 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py new file mode 100644 index 000000000..0722522c3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.minimum(input1[:vr, :vc], input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp new file mode 100644 index 000000000..8757cc295 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 64x64 +extern "C" __global__ AICORE void TMIN_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMIN_f32_64x64(void *a, void *b, void *c, void *stream) { + TMIN_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: i32 64x64 +extern "C" __global__ AICORE void TMIN_i32_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMIN_i32_64x64(void *a, void *b, void *c, void *stream) { + TMIN_i32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 2: i16 64x64 +extern "C" __global__ AICORE void TMIN_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTMIN_i16_64x64(void *a, void *b, void *c, void *stream) { + TMIN_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 3: f16 64x64 +extern "C" __global__ AICORE void TMIN_f16_64x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTMIN_f16_64x64(void *a, void *b, void *c, void *stream) { + TMIN_f16_64x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 4: f32 64x64 v60x60 +extern "C" __global__ AICORE void TMIN_f32_64x64_v60x60(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMIN_f32_64x64_v60x60(void *a, void *b, void *c, void *stream) { + TMIN_f32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 5: i32 64x64 v60x60 +extern "C" __global__ AICORE void TMIN_i32_64x64_v60x60(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMIN_i32_64x64_v60x60(void *a, void *b, void *c, void *stream) { + TMIN_i32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 6: f16 2x4096 v1x3600 +extern "C" __global__ AICORE void TMIN_f16_2x4096_v1x3600(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTMIN_f16_2x4096_v1x3600(void *a, void *b, void *c, void *stream) { + TMIN_f16_2x4096_v1x3600<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 7: i16 20x512 v16x200 +extern "C" __global__ AICORE void TMIN_i16_20x512_v16x200(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTMIN_i16_20x512_v16x200(void *a, void *b, void *c, void *stream) { + TMIN_i16_20x512_v16x200<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp new file mode 100644 index 000000000..4fec9b639 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmin ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMIN_f32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMIN_i32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMIN_i16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMIN_f16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMIN_f32_64x64_v60x60(void *a, void *b, void *c, void *stream); +void LaunchTMIN_i32_64x64_v60x60(void *a, void *b, void *c, void *stream); +void LaunchTMIN_f16_2x4096_v1x3600(void *a, void *b, void *c, void *stream); +void LaunchTMIN_i16_20x512_v16x200(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64", LaunchTMIN_f32_64x64, 64, 64, 64, 64, sizeof(float)}, + {"i32_64x64", LaunchTMIN_i32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, + {"i16_64x64", LaunchTMIN_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, + {"f16_64x64", LaunchTMIN_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, + {"f32_64x64_v60x60", LaunchTMIN_f32_64x64_v60x60, 64, 64, 60, 60, sizeof(float)}, + {"i32_64x64_v60x60", LaunchTMIN_i32_64x64_v60x60, 64, 64, 60, 60, sizeof(int32_t)}, + {"f16_2x4096_v1x3600", LaunchTMIN_f16_2x4096_v1x3600, 2, 4096, 1, 3600, sizeof(uint16_t)}, + {"i16_20x512_v16x200", LaunchTMIN_i16_20x512_v16x200, 20, 512, 16, 200, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, fileSize); + aclrtMallocHost(&src1Host, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmin [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/tmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmin/tmin.pto new file mode 100644 index 000000000..aa51b3b85 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/tmin.pto @@ -0,0 +1,522 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmin: tload(a) + tload(b) + tmin(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 64x64 (4096 elements) + func.func @TMIN_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 1: i32 64x64 (4096 elements) + func.func @TMIN_i32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 2: i16 64x64 (4096 elements) + func.func @TMIN_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 3: f16 64x64 (4096 elements) + func.func @TMIN_f16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } + + // Case 4: f32 64x64 tile with 60x60 valid region (padding with MAX for tmin) + func.func @TMIN_f32_64x64_v60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + return + } + + // Case 5: i32 64x64 tile with 60x60 valid region (padding with MAX for tmin) + func.func @TMIN_i32_64x64_v60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + return + } + + // Case 6: f16 2x4096 tile with 1x3600 valid region (padding with MAX for tmin) + func.func @TMIN_f16_2x4096_v1x3600(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + return + } + + // Case 7: i16 20x512 tile with 16x200 valid region (padding with MAX for tmin) + func.func @TMIN_i16_20x512_v16x200(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c200 = arith.constant 200 : index + %c512 = arith.constant 512 : index + %c3200 = arith.constant 3200 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmul/CMakeLists.txt new file mode 100644 index 000000000..4134fa993 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmul) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py new file mode 100644 index 000000000..9a978c0f3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py @@ -0,0 +1,90 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmul ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "i32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "f16_16x16", + "dtype": np.float16, + "shape": (16, 16), + "valid_shape": (16, 16), + "eps": 1e-3, + }, + { + "name": "f16_64x64_v61x61", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (61, 61), + "eps": 1e-3, + }, + { + "name": "i32_64x32_v60x30", + "dtype": np.int32, + "shape": (64, 32), + "valid_shape": (60, 30), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py new file mode 100644 index 000000000..1d8eddf3f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] * input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp new file mode 100644 index 000000000..4c08a0caf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TMUL_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMUL_f32_16x64(void *a, void *b, void *c, void *stream) { + TMUL_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TMUL_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMUL_f32_32x32(void *a, void *b, void *c, void *stream) { + TMUL_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32 64x64 +extern "C" __global__ AICORE void TMUL_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMUL_f32_64x64(void *a, void *b, void *c, void *stream) { + TMUL_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: i32 64x64 +extern "C" __global__ AICORE void TMUL_i32_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMUL_i32_64x64(void *a, void *b, void *c, void *stream) { + TMUL_i32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 4: i32 32x32 +extern "C" __global__ AICORE void TMUL_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMUL_i32_32x32(void *a, void *b, void *c, void *stream) { + TMUL_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 5: i16 64x64 +extern "C" __global__ AICORE void TMUL_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTMUL_i16_64x64(void *a, void *b, void *c, void *stream) { + TMUL_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 6: f16 16x16 +extern "C" __global__ AICORE void TMUL_f16_16x16(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTMUL_f16_16x16(void *a, void *b, void *c, void *stream) { + TMUL_f16_16x16<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 7: f16 64x64 v61x61 +extern "C" __global__ AICORE void TMUL_f16_64x64_v61x61(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTMUL_f16_64x64_v61x61(void *a, void *b, void *c, void *stream) { + TMUL_f16_64x64_v61x61<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 8: i32 64x32 v60x30 +extern "C" __global__ AICORE void TMUL_i32_64x32_v60x30(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMUL_i32_64x32_v60x30(void *a, void *b, void *c, void *stream) { + TMUL_i32_64x32_v60x30<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp new file mode 100644 index 000000000..334486fd3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp @@ -0,0 +1,160 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmul ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMUL_f32_16x64(void *a, void *b, void *c, void *stream); +void LaunchTMUL_f32_32x32(void *a, void *b, void *c, void *stream); +void LaunchTMUL_f32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMUL_i32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMUL_i32_32x32(void *a, void *b, void *c, void *stream); +void LaunchTMUL_i16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMUL_f16_16x16(void *a, void *b, void *c, void *stream); +void LaunchTMUL_f16_64x64_v61x61(void *a, void *b, void *c, void *stream); +void LaunchTMUL_i32_64x32_v60x30(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTMUL_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTMUL_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f32_64x64", LaunchTMUL_f32_64x64, 64, 64, 64, 64, sizeof(float)}, + {"i32_64x64", LaunchTMUL_i32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTMUL_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, + {"i16_64x64", LaunchTMUL_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, + {"f16_16x16", LaunchTMUL_f16_16x16, 16, 16, 16, 16, sizeof(uint16_t)}, + {"f16_64x64_v61x61", LaunchTMUL_f16_64x64_v61x61, 64, 64, 61, 61, sizeof(uint16_t)}, + {"i32_64x32_v60x30", LaunchTMUL_i32_64x32_v60x30, 64, 32, 60, 30, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, fileSize); + aclrtMallocHost(&src1Host, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmul [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto new file mode 100644 index 000000000..c358c738c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto @@ -0,0 +1,586 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmul: tload(a) + tload(b) + tmul(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 + func.func @TMUL_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 + func.func @TMUL_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f32 64x64 + func.func @TMUL_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 3: i32 64x64 + func.func @TMUL_i32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 4: i32 32x32 + func.func @TMUL_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } + + // Case 5: i16 64x64 + func.func @TMUL_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 6: f16 16x16 + func.func @TMUL_f16_16x16(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + return + } + + // Case 7: f16 64x64 v61x61 (valid != tile) + func.func @TMUL_f16_64x64_v61x61(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c61 = arith.constant 61 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c61, %c61], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x61x61xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c61, %c61], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x61x61xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c61, %c61], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x61x61xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c61, %c61] + : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c61, %c61] + : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c61, %c61] + : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) + return + } + + // Case 8: i32 64x32 v60x30 (valid != tile) + func.func @TMUL_i32_64x32_v60x30(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c30 = arith.constant 30 : index + %c32 = arith.constant 32 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c60, %c30], + strides = [%c2048, %c2048, %c2048, %c32, %c1] + : !pto.tensor_view<1x1x1x60x30xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c60, %c30], + strides = [%c2048, %c2048, %c2048, %c32, %c1] + : !pto.tensor_view<1x1x1x60x30xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c60, %c30], + strides = [%c2048, %c2048, %c2048, %c32, %c1] + : !pto.tensor_view<1x1x1x60x30xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c30] + : !pto.tensor_view<1x1x1x60x30xi32> -> !pto.partition_tensor_view<1x1x1x60x30xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c30] + : !pto.tensor_view<1x1x1x60x30xi32> -> !pto.partition_tensor_view<1x1x1x60x30xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c30] + : !pto.tensor_view<1x1x1x60x30xi32> -> !pto.partition_tensor_view<1x1x1x60x30xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x30xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x30xi32>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x60x30xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tor/CMakeLists.txt new file mode 100644 index 000000000..4d7414cdb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tor) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py new file mode 100644 index 000000000..bbcd9631b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tor ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype (np.int32) + - shape: (rows, cols) — allocated tile dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) +""" + +import numpy as np + +CASES = [ + { + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py new file mode 100644 index 000000000..ada52fdfc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(0, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 100, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] | input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tor/launch.cpp new file mode 100644 index 000000000..1cb9c1454 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TOR_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTOR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TOR_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TOR_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTOR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TOR_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tor/main.cpp new file mode 100644 index 000000000..21d82eeea --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tor ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTOR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTOR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_16x64", LaunchTOR_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTOR_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tor [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto b/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto new file mode 100644 index 000000000..4f7960133 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto @@ -0,0 +1,144 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tor: tload(a) + tload(b) + tor(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. +// +// NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug +// See BUG_REPORT_UNSIGNED_BITOPS.md for details + +module { + // Case 0: i32 16x64 (1024 elements) + func.func @TOR_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tor ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: i32 32x32 (1024 elements) + func.func @TOR_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tor ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt new file mode 100644 index 000000000..1c489d8e5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trem) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py new file mode 100644 index 000000000..c7f0724b9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py @@ -0,0 +1,102 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trem ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype (np.float32 or np.int32) + - dst_tile: (rows, cols) — dst tile buffer dimensions + - src0_tile: (rows, cols) — src0 tile buffer dimensions + - src1_tile: (rows, cols) — src1 tile buffer dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) + +Note: src0/src1/dst tile buffer physical sizes can differ, + but valid_shape must be the same for all. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64_16x128_16x128_16x64", + "dtype": np.float32, + "dst_tile": (16, 64), + "src0_tile": (16, 128), + "src1_tile": (16, 128), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f32_16x32_16x64_16x32_16x32", + "dtype": np.float32, + "dst_tile": (16, 32), + "src0_tile": (16, 64), + "src1_tile": (16, 32), + "valid_shape": (16, 32), + "eps": 1e-3, + }, + { + "name": "i32_4x32_4x32_4x32_4x32", + "dtype": np.int32, + "dst_tile": (4, 32), + "src0_tile": (4, 32), + "src1_tile": (4, 32), + "valid_shape": (4, 32), + "eps": 0, + }, + { + "name": "i32_16x32_16x64_16x32_16x32", + "dtype": np.int32, + "dst_tile": (16, 32), + "src0_tile": (16, 64), + "src1_tile": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "f32_16x64_16x128_16x128_16x63", + "dtype": np.float32, + "dst_tile": (16, 64), + "src0_tile": (16, 128), + "src1_tile": (16, 128), + "valid_shape": (16, 63), + "eps": 1e-3, + }, + { + "name": "f32_2x32_2x64_2x32_2x31", + "dtype": np.float32, + "dst_tile": (2, 32), + "src0_tile": (2, 64), + "src1_tile": (2, 32), + "valid_shape": (2, 31), + "eps": 1e-3, + }, + { + "name": "i32_16x32_16x64_16x32_16x31", + "dtype": np.int32, + "dst_tile": (16, 32), + "src0_tile": (16, 64), + "src1_tile": (16, 32), + "valid_shape": (16, 31), + "eps": 0, + }, + { + "name": "f32_1x8192_1x8192_1x8192_1x8192", + "dtype": np.float32, + "dst_tile": (1, 8192), + "src0_tile": (1, 8192), + "src1_tile": (1, 8192), + "valid_shape": (1, 8192), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py new file mode 100644 index 000000000..f09a461e9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_tile = case["dst_tile"] + valid_shape = case["valid_shape"] + vr, vc = valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_tile) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_tile) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py new file mode 100644 index 000000000..038e4f976 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +np.random.seed(19) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_tile = case["dst_tile"] + src0_tile = case["src0_tile"] + src1_tile = case["src1_tile"] + valid_shape = case["valid_shape"] + + input1 = np.random.uniform(low=-1000, high=1000, size=src0_tile).astype(dtype) + input2 = np.random.uniform(low=3, high=100, size=src1_tile).astype(dtype) + + golden = np.zeros(dst_tile, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.remainder(input1[:vr, :vc], input2[:vr, :vc]).astype(dtype) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dst={dst_tile} src0={src0_tile} src1={src1_tile} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp new file mode 100644 index 000000000..96bbf77eb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x64 +extern "C" __global__ AICORE void TREM_f32_16x64_16x128_16x128_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTREM_f32_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream) { + TREM_f32_16x64_16x128_16x128_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 +extern "C" __global__ AICORE void TREM_f32_16x32_16x64_16x32_16x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTREM_f32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream) { + TREM_f32_16x32_16x64_16x32_16x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: i32, dst=4x32, src0=4x32, src1=4x32, valid=4x32 +extern "C" __global__ AICORE void TREM_i32_4x32_4x32_4x32_4x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTREM_i32_4x32_4x32_4x32_4x32(void *a, void *b, void *c, void *stream) { + TREM_i32_4x32_4x32_4x32_4x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 3: i32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 +extern "C" __global__ AICORE void TREM_i32_16x32_16x64_16x32_16x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTREM_i32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream) { + TREM_i32_16x32_16x64_16x32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 4: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x63 +extern "C" __global__ AICORE void TREM_f32_16x64_16x128_16x128_16x63(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTREM_f32_16x64_16x128_16x128_16x63(void *a, void *b, void *c, void *stream) { + TREM_f32_16x64_16x128_16x128_16x63<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 5: f32, dst=2x32, src0=2x64, src1=2x32, valid=2x31 +extern "C" __global__ AICORE void TREM_f32_2x32_2x64_2x32_2x31(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTREM_f32_2x32_2x64_2x32_2x31(void *a, void *b, void *c, void *stream) { + TREM_f32_2x32_2x64_2x32_2x31<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 6: i32, dst=16x32, src0=16x64, src1=16x32, valid=16x31 +extern "C" __global__ AICORE void TREM_i32_16x32_16x64_16x32_16x31(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTREM_i32_16x32_16x64_16x32_16x31(void *a, void *b, void *c, void *stream) { + TREM_i32_16x32_16x64_16x32_16x31<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 7: f32, dst=1x8192, src0=1x8192, src1=1x8192, valid=1x8192 +extern "C" __global__ AICORE void TREM_f32_1x8192_1x8192_1x8192_1x8192(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTREM_f32_1x8192_1x8192_1x8192_1x8192(void *a, void *b, void *c, void *stream) { + TREM_f32_1x8192_1x8192_1x8192_1x8192<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp new file mode 100644 index 000000000..07c5020bf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trem ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTREM_f32_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream); +void LaunchTREM_f32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream); +void LaunchTREM_i32_4x32_4x32_4x32_4x32(void *a, void *b, void *c, void *stream); +void LaunchTREM_i32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream); +void LaunchTREM_f32_16x64_16x128_16x128_16x63(void *a, void *b, void *c, void *stream); +void LaunchTREM_f32_2x32_2x64_2x32_2x31(void *a, void *b, void *c, void *stream); +void LaunchTREM_i32_16x32_16x64_16x32_16x31(void *a, void *b, void *c, void *stream); +void LaunchTREM_f32_1x8192_1x8192_1x8192_1x8192(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"f32_16x64_16x128_16x128_16x64", LaunchTREM_f32_16x64_16x128_16x128_16x64, 16, 128, 16, 128, 16, 64, 16, 64, sizeof(float)}, + {"f32_16x32_16x64_16x32_16x32", LaunchTREM_f32_16x32_16x64_16x32_16x32, 16, 64, 16, 32, 16, 32, 16, 32, sizeof(float)}, + {"i32_4x32_4x32_4x32_4x32", LaunchTREM_i32_4x32_4x32_4x32_4x32, 4, 32, 4, 32, 4, 32, 4, 32, sizeof(int32_t)}, + {"i32_16x32_16x64_16x32_16x32", LaunchTREM_i32_16x32_16x64_16x32_16x32, 16, 64, 16, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"f32_16x64_16x128_16x128_16x63", LaunchTREM_f32_16x64_16x128_16x128_16x63, 16, 128, 16, 128, 16, 64, 16, 63, sizeof(float)}, + {"f32_2x32_2x64_2x32_2x31", LaunchTREM_f32_2x32_2x64_2x32_2x31, 2, 64, 2, 32, 2, 32, 2, 31, sizeof(float)}, + {"i32_16x32_16x64_16x32_16x31", LaunchTREM_i32_16x32_16x64_16x32_16x31, 16, 64, 16, 32, 16, 32, 16, 31, sizeof(int32_t)}, + {"f32_1x8192_1x8192_1x8192_1x8192", LaunchTREM_f32_1x8192_1x8192_1x8192_1x8192, 1, 8192, 1, 8192, 1, 8192, 1, 8192, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0Size = tc.src0Rows * tc.src0Cols * tc.elemSize; + const size_t src1Size = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (dst=%zux%zu, src0=%zux%zu, src1=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.dstRows, tc.dstCols, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, src0Size); + aclrtMallocHost(&src1Host, src1Size); + aclrtMallocHost(&dstHost, dstSize); + + aclrtMalloc(&src0Device, src0Size, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, src1Size, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = 0; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, src0Size)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = 0; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, src1Size)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0Size, src0Host, src0Size, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1Size, src1Host, src1Size, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto b/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto new file mode 100644 index 000000000..e34974ef5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto @@ -0,0 +1,577 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trem: tload(a) + tload(b) + trem(a,b,tmp)->c + tstore(c). +// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/trem +// Cases have different src/dst tile buffer sizes but same valid_shape. + +module { + // Case 0: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x64 + func.func @TREM_f32_16x64_16x128_16x128_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c2048 = arith.constant 2048 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 + func.func @TREM_f32_16x32_16x64_16x32_16x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // Case 2: i32, dst=4x32, src0=4x32, src1=4x32, valid=4x32 + func.func @TREM_i32_4x32_4x32_4x32_4x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + return + } + + // Case 3: i32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 + func.func @TREM_i32_16x32_16x64_16x32_16x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 4: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x63 + func.func @TREM_f32_16x64_16x128_16x128_16x63(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c2048 = arith.constant 2048 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + return + } + + // Case 5: f32, dst=2x32, src0=2x64, src1=2x32, valid=2x31 + func.func @TREM_f32_2x32_2x64_2x32_2x31(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c31] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c31] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c31] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) + return + } + + // Case 6: i32, dst=16x32, src0=16x64, src1=16x32, valid=16x31 + func.func @TREM_i32_16x32_16x64_16x32_16x31(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c31] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x31xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c31] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x31xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c31] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x31xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x31xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x31xi32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x31xi32>) + return + } + + // Case 7: f32, dst=1x8192, src0=1x8192, src1=1x8192, valid=1x8192 + func.func @TREM_f32_1x8192_1x8192_1x8192_1x8192(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8192] + : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8192] + : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8192] + : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tshl/CMakeLists.txt new file mode 100644 index 000000000..42b3fa0bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tshl) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py new file mode 100644 index 000000000..4b312c5e0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tshl ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype (np.int32) + - shape: (rows, cols) — allocated tile dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) +""" + +import numpy as np + +CASES = [ + { + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py new file mode 100644 index 000000000..fdeae9fa0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 8, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] << input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshl/launch.cpp new file mode 100644 index 000000000..d58d324a4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TSHL_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSHL_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TSHL_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TSHL_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSHL_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TSHL_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshl/main.cpp new file mode 100644 index 000000000..cb35c3a31 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tshl ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSHL_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTSHL_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_16x64", LaunchTSHL_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTSHL_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tshl [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto new file mode 100644 index 000000000..4526b8b3f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto @@ -0,0 +1,144 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tshl: tload(a) + tload(b) + tshl(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. +// +// NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug +// See BUG_REPORT_UNSIGNED_BITOPS.md for details + +module { + // Case 0: i32 16x64 (1024 elements) + func.func @TSHL_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tshl ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: i32 32x32 (1024 elements) + func.func @TSHL_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tshl ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tshr/CMakeLists.txt new file mode 100644 index 000000000..ed8592e59 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tshr) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py new file mode 100644 index 000000000..2ce53b6c5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tshr ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype (np.int32) + - shape: (rows, cols) — allocated tile dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) +""" + +import numpy as np + +CASES = [ + { + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py new file mode 100644 index 000000000..152667627 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 8, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] >> input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshr/launch.cpp new file mode 100644 index 000000000..1d4f9cd1d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TSHR_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSHR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TSHR_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TSHR_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSHR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TSHR_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshr/main.cpp new file mode 100644 index 000000000..e99634eae --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tshr ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSHR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTSHR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_16x64", LaunchTSHR_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTSHR_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tshr [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto new file mode 100644 index 000000000..dda93d3bf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto @@ -0,0 +1,144 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tshr: tload(a) + tload(b) + tshr(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. +// +// NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug +// See BUG_REPORT_UNSIGNED_BITOPS.md for details + +module { + // Case 0: i32 16x64 (1024 elements) + func.func @TSHR_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tshr ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: i32 32x32 (1024 elements) + func.func @TSHR_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tshr ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsub/CMakeLists.txt new file mode 100644 index 000000000..da60b1f64 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsub) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py new file mode 100644 index 000000000..20ef25491 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsub ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "i32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "f16_64x64", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py new file mode 100644 index 000000000..234dd6cfc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] - input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp new file mode 100644 index 000000000..8c00680c0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TSUB_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTSUB_f32_16x64(void *a, void *b, void *c, void *stream) { + TSUB_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TSUB_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTSUB_f32_32x32(void *a, void *b, void *c, void *stream) { + TSUB_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32 64x64 +extern "C" __global__ AICORE void TSUB_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTSUB_f32_64x64(void *a, void *b, void *c, void *stream) { + TSUB_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: i32 64x64 +extern "C" __global__ AICORE void TSUB_i32_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSUB_i32_64x64(void *a, void *b, void *c, void *stream) { + TSUB_i32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 4: i16 64x64 +extern "C" __global__ AICORE void TSUB_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTSUB_i16_64x64(void *a, void *b, void *c, void *stream) { + TSUB_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 5: f16 64x64 +extern "C" __global__ AICORE void TSUB_f16_64x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTSUB_f16_64x64(void *a, void *b, void *c, void *stream) { + TSUB_f16_64x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp new file mode 100644 index 000000000..0827b91a3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp @@ -0,0 +1,154 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsub ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSUB_f32_16x64(void *a, void *b, void *c, void *stream); +void LaunchTSUB_f32_32x32(void *a, void *b, void *c, void *stream); +void LaunchTSUB_f32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTSUB_i32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTSUB_i16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTSUB_f16_64x64(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTSUB_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTSUB_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f32_64x64", LaunchTSUB_f32_64x64, 64, 64, 64, 64, sizeof(float)}, + {"i32_64x64", LaunchTSUB_i32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, + {"i16_64x64", LaunchTSUB_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, + {"f16_64x64", LaunchTSUB_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, fileSize); + aclrtMallocHost(&src1Host, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tsub [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto new file mode 100644 index 000000000..e64f6b9d9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto @@ -0,0 +1,393 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsub: tload(a) + tload(b) + tsub(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module { + // Case 0: f32 16x64 + func.func @TSUB_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 + func.func @TSUB_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f32 64x64 + func.func @TSUB_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 3: i32 64x64 + func.func @TSUB_i32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 4: i16 64x64 + func.func @TSUB_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%b : !pto.tile_buf) + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 5: f16 64x64 + func.func @TSUB_f16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/txor/CMakeLists.txt new file mode 100644 index 000000000..e54d015e1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(txor) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py new file mode 100644 index 000000000..766ddf35a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py @@ -0,0 +1,66 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for txor ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype (np.int16, np.int8) + - dst_tile: (rows, cols) — dst tile buffer dimensions + - src0_tile: (rows, cols) — src0 tile buffer dimensions + - src1_tile: (rows, cols) — src1 tile buffer dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) + +Note: src0/src1/dst tile buffer physical sizes can differ, + but valid_shape must be the same for all. +""" + +import numpy as np + +CASES = [ + { + "name": "i16_64x64_64x64_64x64_64x64", + "dtype": np.int16, + "dst_tile": (64, 64), + "src0_tile": (64, 64), + "src1_tile": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_32x128_32x128_32x256_32x128", + "dtype": np.int16, + "dst_tile": (32, 128), + "src0_tile": (32, 128), + "src1_tile": (32, 256), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "i16_32x128_32x128_32x256_32x127", + "dtype": np.int16, + "dst_tile": (32, 128), + "src0_tile": (32, 128), + "src1_tile": (32, 256), + "valid_shape": (32, 127), + "eps": 0, + }, + { + "name": "i8_32x128_32x128_32x256_32x127", + "dtype": np.int8, + "dst_tile": (32, 128), + "src0_tile": (32, 128), + "src1_tile": (32, 256), + "valid_shape": (32, 127), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py new file mode 100644 index 000000000..96ab2dda8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_tile = case["dst_tile"] + valid_shape = case["valid_shape"] + vr, vc = valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_tile) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_tile) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py new file mode 100644 index 000000000..62e5cb32b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +np.random.seed(19) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_tile = case["dst_tile"] + src0_tile = case["src0_tile"] + src1_tile = case["src1_tile"] + valid_shape = case["valid_shape"] + + dtype_info = np.iinfo(dtype) + input1 = np.random.randint(dtype_info.min, dtype_info.max, size=src0_tile).astype(dtype) + input2 = np.random.randint(dtype_info.min, dtype_info.max, size=src1_tile).astype(dtype) + + golden = np.zeros(dst_tile, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] ^ input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dst={dst_tile} src0={src0_tile} src1={src1_tile} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp new file mode 100644 index 000000000..50bf8af26 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i16, dst=64x64, src0=64x64, src1=64x64, valid=64x64 +extern "C" __global__ AICORE void TXOR_i16_64x64_64x64_64x64_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTXOR_i16_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream) { + TXOR_i16_64x64_64x64_64x64_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 1: i16, dst=32x128, src0=32x128, src1=32x256, valid=32x128 +extern "C" __global__ AICORE void TXOR_i16_32x128_32x128_32x256_32x128(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTXOR_i16_32x128_32x128_32x256_32x128(void *a, void *b, void *c, void *stream) { + TXOR_i16_32x128_32x128_32x256_32x128<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 2: i16, dst=32x128, src0=32x128, src1=32x256, valid=32x127 +extern "C" __global__ AICORE void TXOR_i16_32x128_32x128_32x256_32x127(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTXOR_i16_32x128_32x128_32x256_32x127(void *a, void *b, void *c, void *stream) { + TXOR_i16_32x128_32x128_32x256_32x127<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 3: i8, dst=32x128, src0=32x128, src1=32x256, valid=32x127 +extern "C" __global__ AICORE void TXOR_i8_32x128_32x128_32x256_32x127(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int8_t *c); + +void LaunchTXOR_i8_32x128_32x128_32x256_32x127(void *a, void *b, void *c, void *stream) { + TXOR_i8_32x128_32x128_32x256_32x127<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int8_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp new file mode 100644 index 000000000..70729478e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp @@ -0,0 +1,149 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang txor ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTXOR_i16_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream); +void LaunchTXOR_i16_32x128_32x128_32x256_32x128(void *a, void *b, void *c, void *stream); +void LaunchTXOR_i16_32x128_32x128_32x256_32x127(void *a, void *b, void *c, void *stream); +void LaunchTXOR_i8_32x128_32x128_32x256_32x127(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"i16_64x64_64x64_64x64_64x64", LaunchTXOR_i16_64x64_64x64_64x64_64x64, 64, 64, 64, 64, 64, 64, 64, 64, sizeof(int16_t)}, + {"i16_32x128_32x128_32x256_32x128", LaunchTXOR_i16_32x128_32x128_32x256_32x128, 32, 128, 32, 256, 32, 128, 32, 128, sizeof(int16_t)}, + {"i16_32x128_32x128_32x256_32x127", LaunchTXOR_i16_32x128_32x128_32x256_32x127, 32, 128, 32, 256, 32, 128, 32, 127, sizeof(int16_t)}, + {"i8_32x128_32x128_32x256_32x127", LaunchTXOR_i8_32x128_32x128_32x256_32x127, 32, 128, 32, 256, 32, 128, 32, 127, sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0Size = tc.src0Rows * tc.src0Cols * tc.elemSize; + const size_t src1Size = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (dst=%zux%zu, src0=%zux%zu, src1=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.dstRows, tc.dstCols, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, src0Size); + aclrtMallocHost(&src1Host, src1Size); + aclrtMallocHost(&dstHost, dstSize); + + aclrtMalloc(&src0Device, src0Size, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, src1Size, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = 0; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, src0Size)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = 0; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, src1Size)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0Size, src0Host, src0Size, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1Size, src1Host, src1Size, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto b/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto new file mode 100644 index 000000000..d0a923d18 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto @@ -0,0 +1,287 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.txor: tload(a) + tload(b) + txor(a,b,c)->c + tstore(c). +// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/txor +// Cases have different src/dst tile buffer sizes but same valid_shape. +// +// NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug +// See BUG_REPORT_UNSIGNED_BITOPS.md for details + +module { + // Case 0: i16, dst=64x64, src0=64x64, src1=64x64, valid=64x64 + func.func @TXOR_i16_64x64_64x64_64x64_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%b : !pto.tile_buf) + + pto.txor ins(%a, %b, %c : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 1: i16, dst=32x128, src0=32x128, src1=32x256, valid=32x128 + func.func @TXOR_i16_32x128_32x128_32x256_32x128(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x256xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%b : !pto.tile_buf) + + pto.txor ins(%a, %b, %c : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + return + } + + // Case 2: i16, dst=32x128, src0=32x128, src1=32x256, valid=32x127 + func.func @TXOR_i16_32x128_32x128_32x256_32x127(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c127] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x127xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c127] + : !pto.tensor_view<1x1x1x32x256xi16> -> !pto.partition_tensor_view<1x1x1x32x127xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c127] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x127xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x127xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x127xi16>) + outs(%b : !pto.tile_buf) + + pto.txor ins(%a, %b, %c : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x127xi16>) + return + } + + // Case 3: i8, dst=32x128, src0=32x128, src1=32x256, valid=32x127 + func.func @TXOR_i8_32x128_32x128_32x256_32x127(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xi8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c127] + : !pto.tensor_view<1x1x1x32x128xi8> -> !pto.partition_tensor_view<1x1x1x32x127xi8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c127] + : !pto.tensor_view<1x1x1x32x256xi8> -> !pto.partition_tensor_view<1x1x1x32x127xi8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c127] + : !pto.tensor_view<1x1x1x32x128xi8> -> !pto.partition_tensor_view<1x1x1x32x127xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x127xi8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x127xi8>) + outs(%b : !pto.tile_buf) + + pto.txor ins(%a, %b, %c : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x127xi8>) + return + } +} \ No newline at end of file From 454b0613b261adf9868f675bae5080b032526ce0 Mon Sep 17 00:00:00 2001 From: zwd060924 Date: Mon, 27 Apr 2026 12:18:58 +0800 Subject: [PATCH 182/192] [Add] tcmp trem tfmod --- lib/TileOps/tcmp_template.py | 93 +++++-- lib/TileOps/tfmod_template.py | 43 +++- lib/TileOps/trem_template.py | 25 +- .../golden.bin | Bin 8192 -> 0 bytes .../input1.bin | Bin 8192 -> 0 bytes .../input2.bin | Bin 8192 -> 0 bytes .../f32_16x64_16x64_16x64_16x64/golden.bin | Bin 4096 -> 0 bytes .../f32_16x64_16x64_16x64_16x64/input1.bin | Bin 4096 -> 0 bytes .../f32_16x64_16x64_16x64_16x64/input2.bin | Bin 4096 -> 0 bytes .../f32_32x32_32x32_32x32_32x32/golden.bin | Bin 4096 -> 0 bytes .../f32_32x32_32x32_32x32_32x32/input1.bin | Bin 4096 -> 0 bytes .../f32_32x32_32x32_32x32_32x32/input2.bin | Bin 4096 -> 0 bytes .../golden.bin | Bin 32768 -> 0 bytes .../input1.bin | Bin 32768 -> 0 bytes .../input2.bin | Bin 32768 -> 0 bytes .../f32_64x64_64x64_64x64_64x64/golden.bin | Bin 16384 -> 0 bytes .../f32_64x64_64x64_64x64_64x64/input1.bin | Bin 16384 -> 0 bytes .../f32_64x64_64x64_64x64_64x64/input2.bin | Bin 16384 -> 0 bytes .../half_16x64_16x128_16x128_16x64/golden.bin | Bin 2048 -> 0 bytes .../half_16x64_16x128_16x128_16x64/input1.bin | Bin 4096 -> 0 bytes .../half_16x64_16x128_16x128_16x64/input2.bin | Bin 4096 -> 0 bytes .../i16_64x64_64x64_64x64_64x64/golden.bin | Bin 8192 -> 0 bytes .../i16_64x64_64x64_64x64_64x64/input1.bin | Bin 8192 -> 0 bytes .../i16_64x64_64x64_64x64_64x64/input2.bin | Bin 8192 -> 0 bytes .../i32_64x64_64x64_64x64_64x64/golden.bin | Bin 16384 -> 0 bytes .../i32_64x64_64x64_64x64_64x64/input1.bin | Bin 16384 -> 0 bytes .../i32_64x64_64x64_64x64_64x64/input2.bin | Bin 16384 -> 0 bytes .../npu/a5/src/st/testcase/tcmp/cases.py | 41 +-- .../npu/a5/src/st/testcase/tcmp/compare.py | 13 +- .../npu/a5/src/st/testcase/tcmp/gen_data.py | 23 +- .../npu/a5/src/st/testcase/tcmp/launch.cpp | 24 +- .../npu/a5/src/st/testcase/tcmp/main.cpp | 25 +- .../npu/a5/src/st/testcase/tcmp/tcmp.pto | 241 ++---------------- .../npu/a5/src/st/testcase/tfmod/cases.py | 2 +- 34 files changed, 194 insertions(+), 336 deletions(-) delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/golden.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/input1.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/input2.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/golden.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/input1.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/input2.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/golden.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/input1.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/input2.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/golden.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/input1.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/input2.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/golden.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input1.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input2.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/golden.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/input1.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/input2.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/golden.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/input1.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/input2.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/golden.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input1.bin delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input2.bin diff --git a/lib/TileOps/tcmp_template.py b/lib/TileOps/tcmp_template.py index c008ca237..2fe9e7ae0 100644 --- a/lib/TileOps/tcmp_template.py +++ b/lib/TileOps/tcmp_template.py @@ -1,34 +1,89 @@ -"""TileLang DSL template for pto.tcmp""" +"""TileLang DSL template for pto.tcmp + +Aligned with pto-isa/include/pto/npu/a5/TCmp.hpp: +- 32-bit (int32, float, uint32): TCmp_32B with plt_b32 + pdintlv_b8 +- 16-bit (int16, half, uint16): TCmp with plt_b16 +- 8-bit (int8, uint8): TCmp with plt_b8 +""" -import sys -from pathlib import Path import tilelang_dsl as pto +REPEAT_BYTE = 256 +CMP_BITS_PER_INDEX = 32 + + @pto.vkernel( target="a5", op="pto.tcmp", + dtypes=[ + (pto.si32, pto.si32, pto.i8), + (pto.f32, pto.f32, pto.i8), + (pto.ui32, pto.ui32, pto.i8), + (pto.si16, pto.si16, pto.i8), + (pto.f16, pto.f16, pto.i8), + (pto.ui16, pto.ui16, pto.i8), + (pto.si8, pto.si8, pto.i8), + (pto.ui8, pto.ui8, pto.i8), + ], advanced=True, ) def template_tcmp(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): dtype = src0.element_type - valid_rows, valid_cols = dst.valid_shape - cmp_mode = pto.get_op_attr("cmp_mode", "eq") + valid_rows, valid_cols = src0.valid_shape + cmp_mode = pto.get_op_attr("cmp_mode", pto.CmpMode.EQ) - lanes = pto.get_lanes(dtype) + dtype_size = pto.bytewidth(dtype) + total_elements = valid_rows * valid_cols + repeat_elm = REPEAT_BYTE // dtype_size + src0_ptr = src0.as_ptr() + src1_ptr = src1.as_ptr() dst_ptr = dst.as_ptr() - mask_ptr = pto.castptr(dst_ptr, pto.ptr(pto.ui32, pto.MemorySpace.UB)) - - align_stride = 32 - - for row in range(0, valid_rows, 1): - remained = valid_cols - for col in range(0, valid_cols, lanes): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - result = pto.vcmp(lhs, rhs, mask, cmp_mode) - byte_offset = row * align_stride + (col // 8) - pto.psts(result, mask_ptr, byte_offset) + dst_u32_ptr = pto.castptr(dst_ptr, pto.ptr(pto.ui32, pto.MemorySpace.UB)) + + if pto.constexpr(dtype_size == 4): + repeat_elm_32b = REPEAT_BYTE // 4 + repeat_times_32b = (total_elements + repeat_elm_32b - 1) // repeat_elm_32b + 1 + loop_times = repeat_times_32b // 2 + remaining = total_elements + + for i in range(0, loop_times, 1): + preg0, remaining = pto.plt_b32(remaining) + vreg0 = pto.vlds(src0_ptr, i * 2 * repeat_elm_32b) + vreg1 = pto.vlds(src1_ptr, i * 2 * repeat_elm_32b) + preg1 = pto.vcmp(vreg0, vreg1, preg0, cmp_mode) + + preg0, remaining = pto.plt_b32(remaining) + vreg2 = pto.vlds(src0_ptr, (i * 2 + 1) * repeat_elm_32b) + vreg3 = pto.vlds(src1_ptr, (i * 2 + 1) * repeat_elm_32b) + preg2 = pto.vcmp(vreg2, vreg3, preg0, cmp_mode) + + preg1_b8 = pto.pbitcast(preg1, pto.mask_b8) + preg2_b8 = pto.pbitcast(preg2, pto.mask_b8) + preg3, preg4 = pto.pdintlv_b8(preg1_b8, preg2_b8) + pto.psts(preg3, dst_u32_ptr, i * 16, pto.PredicateDist.PK) + elif pto.constexpr(dtype_size == 2): + repeat_times = (total_elements + repeat_elm - 1) // repeat_elm + dst_stride_bytes = (repeat_elm // CMP_BITS_PER_INDEX) * 4 + remaining = total_elements + + for i in range(0, repeat_times, 1): + preg0, remaining = pto.plt_b16(remaining) + vreg0 = pto.vlds(src0_ptr, i * repeat_elm) + vreg1 = pto.vlds(src1_ptr, i * repeat_elm) + preg1 = pto.vcmp(vreg0, vreg1, preg0, cmp_mode) + pto.psts(preg1, dst_u32_ptr, i * dst_stride_bytes, pto.PredicateDist.PK) + elif pto.constexpr(dtype_size == 1): + repeat_times = (total_elements + repeat_elm - 1) // repeat_elm + dst_stride_bytes = (repeat_elm // CMP_BITS_PER_INDEX) * 4 + remaining = total_elements + + for i in range(0, repeat_times, 1): + preg0, remaining = pto.plt_b8(remaining) + vreg0 = pto.vlds(src0_ptr, i * repeat_elm) + vreg1 = pto.vlds(src1_ptr, i * repeat_elm) + preg1 = pto.vcmp(vreg0, vreg1, preg0, cmp_mode) + pto.psts(preg1, dst_u32_ptr, i * dst_stride_bytes, pto.PredicateDist.PK) + return \ No newline at end of file diff --git a/lib/TileOps/tfmod_template.py b/lib/TileOps/tfmod_template.py index 45df41d3d..a05260863 100644 --- a/lib/TileOps/tfmod_template.py +++ b/lib/TileOps/tfmod_template.py @@ -1,4 +1,11 @@ -"""TileLang DSL template for pto.tfmod""" +"""TileLang DSL template for pto.tfmod + +Aligned with pto-isa/include/pto/npu/a5/TFMod.hpp: +- float: vdiv -> vtrc(ROUND_Z) -> vmul -> vsub +- half: vcvt(half->float, PART_EVEN/ODD) -> vdiv -> vtrc(ROUND_Z) -> vmul -> vsub + -> vcvt(float->half, ROUND_Z, RS_ENABLE, PART_EVEN/ODD) -> vor +- other (i16/ui16): vdiv -> vmul -> vsub (no vtrc, integer div is trunc by nature) +""" import tilelang_dsl as pto @@ -9,6 +16,8 @@ dtypes=[ (pto.f32, pto.f32, pto.f32), (pto.f16, pto.f16, pto.f16), + (pto.i16, pto.i16, pto.i16), + (pto.ui16, pto.ui16, pto.ui16), ], ) def template_tfmod(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): @@ -21,10 +30,34 @@ def template_tfmod(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): mask, remained = pto.make_mask(dtype, remained) lhs = pto.vlds(src0[row, col:]) rhs = pto.vlds(src1[row, col:]) - quotient = pto.vdiv(lhs, rhs, mask) - if pto.constexpr(dtype == pto.f32 or dtype == pto.f16): + + if pto.constexpr(dtype == pto.f32): + quotient = pto.vdiv(lhs, rhs, mask) quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.Z) - truncated_mul = pto.vmul(quotient, rhs, mask) - result = pto.vsub(lhs, truncated_mul, mask) + truncated_mul = pto.vmul(quotient, rhs, mask) + result = pto.vsub(lhs, truncated_mul, mask) + elif pto.constexpr(dtype == pto.f16): + lhs_even = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) + rhs_even = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) + quotient_even = pto.vdiv(lhs_even, rhs_even, mask) + quotient_even = pto.vtrc(quotient_even, mask, rnd=pto.VcvtRoundMode.Z) + truncated_mul_even = pto.vmul(quotient_even, rhs_even, mask) + result_even = pto.vsub(lhs_even, truncated_mul_even, mask) + dst_even = pto.vcvt(result_even, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.EVEN) + + lhs_odd = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) + rhs_odd = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) + quotient_odd = pto.vdiv(lhs_odd, rhs_odd, mask) + quotient_odd = pto.vtrc(quotient_odd, mask, rnd=pto.VcvtRoundMode.Z) + truncated_mul_odd = pto.vmul(quotient_odd, rhs_odd, mask) + result_odd = pto.vsub(lhs_odd, truncated_mul_odd, mask) + dst_odd = pto.vcvt(result_odd, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.ODD) + + result = pto.vor(dst_even, dst_odd, mask) + else: + quotient = pto.vdiv(lhs, rhs, mask) + truncated_mul = pto.vmul(quotient, rhs, mask) + result = pto.vsub(lhs, truncated_mul, mask) + pto.vsts(result, dst[row, col:], mask) return \ No newline at end of file diff --git a/lib/TileOps/trem_template.py b/lib/TileOps/trem_template.py index 947001974..619b664a6 100644 --- a/lib/TileOps/trem_template.py +++ b/lib/TileOps/trem_template.py @@ -11,6 +11,7 @@ (pto.f16, pto.f16, pto.f16, pto.f16), (pto.i32, pto.i32, pto.i32, pto.i32), ], + advanced=True, ) def template_trem(src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): dtype = dst.element_type @@ -22,11 +23,33 @@ def template_trem(src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): mask, remained = pto.make_mask(dtype, remained) lhs = pto.vlds(src0[row, col:]) rhs = pto.vlds(src1[row, col:]) - if pto.constexpr(dtype == pto.f32 or dtype == pto.f16): + if pto.constexpr(dtype == pto.f32): quotient = pto.vdiv(lhs, rhs, mask) quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.F) floored_mul = pto.vmul(quotient, rhs, mask) result = pto.vsub(lhs, floored_mul, mask) + sign_diff_mask = pto.vcmps(pto.vmul(rhs, result, mask), 0.0, mask, pto.CmpMode.LT) + corrected = pto.vadd(result, rhs, sign_diff_mask) + result = pto.vsel(corrected, result, sign_diff_mask) + elif pto.constexpr(dtype == pto.f16): + lhs_even = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) + rhs_even = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) + lhs_odd = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) + rhs_odd = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) + q_even = pto.vdiv(lhs_even, rhs_even, mask) + q_odd = pto.vdiv(lhs_odd, rhs_odd, mask) + q_even = pto.vtrc(q_even, mask, rnd=pto.VcvtRoundMode.F) + q_odd = pto.vtrc(q_odd, mask, rnd=pto.VcvtRoundMode.F) + fm_even = pto.vmul(q_even, rhs_even, mask) + fm_odd = pto.vmul(q_odd, rhs_odd, mask) + r_even = pto.vsub(lhs_even, fm_even, mask) + r_odd = pto.vsub(lhs_odd, fm_odd, mask) + dst_even = pto.vcvt(r_even, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.RS_ENABLE, part=pto.VcvtPartMode.EVEN) + dst_odd = pto.vcvt(r_odd, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.RS_ENABLE, part=pto.VcvtPartMode.ODD) + result = pto.vor(dst_even, dst_odd, mask) + sign_diff_mask = pto.vcmps(pto.vmul(rhs, result, mask), 0.0, mask, pto.CmpMode.LT) + corrected = pto.vadd(result, rhs, sign_diff_mask) + result = pto.vsel(corrected, result, sign_diff_mask) elif pto.constexpr(dtype == pto.i32): lhs_f32 = pto.vcvt(lhs, pto.f32, mask, rnd=pto.VcvtRoundMode.R) rhs_f32 = pto.vcvt(rhs, pto.f32, mask, rnd=pto.VcvtRoundMode.R) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/golden.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/golden.bin deleted file mode 100644 index 7e251931a6c657d51480420a3df470f1d55c2e0c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8192 zcmX}x33B5|3`Eg78bi@u+UB8n^!X1^PbegjRag>-jmq=kE1! zexC1E{+^%SUytvfKEIyt>ptJf{ptPnwD#QX{r3LbXSPO<^!NEXf0NGE`Fy_J_x?D) ztfj|>)!(si7mlBEo$n;Pe_fxCB;h05{dM>K=ehI6FUi)E{>ri)>wM?EZ=UaV|Fpa7 zcdH*qls~O}&F5zlKj)POOur$a{Wn`;{Z0P!yq<}u70i5%#fg4A{)CfFdhyckS6Z|~ zGWQ}C-=FI)j?dQH&DSuR{I}gl)L~BUHH<7%F?iici@ob!;T-#PAf^r3Yp`zA*2E_g zTz+;f8XOhlL?=cliZooYEUsfWZ~5xFPm$PnEjBr14-ZykJ5*V*vh|X;UZQ4&oP1!U z7@x^;dTQ*?Q_M%+#!_ptEzUUaZgVL{HG(+|;8}HGKoy8VjLsj86d+k6R-m8^fn8z*O@`;r=CfohiJWbcrZP_YQ zMYFk{h}yTyTGQmg%-wZSUH2u9y{CDFM`k~doarT2U9AmyfciCl{v_*VZ!+6eF=j{C z7V*)ms;p_(T^QPjTz-EC$tyXPfNQ%&+8)WTyvUx}LJOmq)F)!lB_GM?G;yj>ys(d{ znN7*5XmKaaHtGChOr`6qkBeQpG4kw`8TOMb+GO>g$)1(;Hxl16)@!ulB-|ul$(Ry& zWmh(r8K23=^Hv%}A?y>{@%&659+Tfx7Bie8!m;XgH7pfB4#e`&+E%iB$p4@_G8MjT zW1YlS9TsTBVisA(q0?qL8ZIARbeYLq<;AsHz~|2;zdC`Q2_rIp+)e&A}Q-7sBN%z~MN zt*HzvJgyyFf0qZj2~}2h&R7=-jIFSlHEfh|-=l%+Jf}w%{^7m5dL^B9(~8+Vuv3=9 zr^86MlZ5RACKK-LV3;(1WMwL6v}J#L%*@btVhaZ1N-?JzG8H{dzXCqB-IU#Z_i{ z+11zZHY0k$A`4Ye*Cv6Qe4lPf$)SqE36wlt=c^U57h^KZ5Q|RUY+vQ4>$NC*s|z7% zH6cS2*Ti!l@*)$t47i@`tbFUS2iKurbtVrdrsbu4glT8p);5Nwf70~qVg78s*Mr%~ zUJRtU)$#V5MjbKsr&^AzsLGvqUGX;em6eX0XkCO3V)!st{{S8>+9 zyzAVrXCsLDcFINHvizQ@t(kbHtPH>iJ4|99lwHdU&DqOalXUNn+S?f6Vxu1D8;)L& z2t+!F)6*H&cHZJ}d+%Kj?L&XFnUxotSz0~$-#hbA{N|Mi(&)~7mp#0j)_Q{{xihaZ zNh_;#t~<+%P{#YcF`9Hm!j9_jv^^1KkG5Exu_^%{Y~dQe;bhZ(YbGwKvYQl5w*3ZU zSw6*MhfLi)T5>#bv?5xw2luirpDKTPBW>$!(OHa}B~MAiVV2}#@Fhoov}Rv~<}Um< zyZ_r#W#LaA!go(!Wag=}IAom{5w^;EyI2>^hn-f)m|9{mBjjXR?0KWL^{D<}Q8(&` z{8U`uo-5dcTFvxIQiar&UCG8gzqPwny*@W1u*)PYHw(1#_L}(9xbCz(AM%Ie?Lp5+ z>uGDZ3OHALVl`b%J9!vfUC(c|tegHKQNt^ZsrOS{`mlQCbL=}^$!*ojPIu5|yVD`N zSN9ujHJbE9oJ6= z3Q@OVX0aU=4!P$Z=iK5lKY7%%szFsi>KU2$sgL&e?8# zR?J>_i=)ZMH7ttu9`GvLC^czyJKbExk;T%Igd!2MK}q{2jUt3yqQ z-s!#BRLxrrI(IuqUGgQuo`bGxA`9|kV&GV;EV}a?q9e;z`#Z-bYAgBI8FDdC2FXB} z^(%zk@zGx_vK?ph7HuAT{mT{`qS?O1zc_Gaj>+iz?wBh$+S-|&BsW`Ii};3X*x4vX zqF4skGLIh_r^62h)JP0Gy~DE4Hu+0x?$zM8syh2Qmn-JBX8qy_NQ+%d+W?H}y)>l*ulf-O2o{cKtfn!&bho zo1FWuNP0$l9Tw6~iihoishIBZxZSDR;7yxU7XJOMtI3%kbF{i414Gv4vpajDo?pmE zg?T0#L z=IEaJYI^O4J3lh5cgz$luRSLP>$uZ@B3hov;cV-=YD~*-sb%Vwym<{j>2!;@fPgxnb^z!p?THjuJK7|8nLkTo=D84U#&-escPmb+xjm#GBprWX|16+_SRz z;*flN1!Vz5N>C7j{udM(XBqd6pRjGH^Yf~jHQ9n-`3mA)bJ7Kl3I_cr+7voeF^>Dbq#Gc*(cFX~o_2iYNE- z9D+S{RjWKY_f%wO&2^7oUZ>{LhKDUt6xH~CPqIf=WW=v4{Hw(*4B}z-5yxU7H??%O z16plIzJ}j;bwQ_7Qq#Q!=9-r2HC?|8E-$& z@4Rkxsn_jDvgUZKBr`Rd7MpcqjNx^w#;oYcw!7^nt8qFp%jEWeUuy2u%br3QFZZp@ z-(4<9b(-0?sGYT}Re48UOilDT&zpHDv(H4^zIxr0b$`jtd65?Hp?{_llF;LP%A^WV zO>1$_c?jX~oTR^G=qKAV<(Jj)g{mDVhU!*{)x}shzlV9>t$O}VHXrXkN^d&*Q73)} z!MGJFdaOOE&5pW{B2EabDCubXP$arNvI!l((e$xR86Ft zmp;uZzZUYwLK1aF@>cjX5!D!9RoUCGSV@=fqsMg8W@*nWtnf%%)!A2n`)sp09mEGV zC_LFgn&d@S4VocWcayDbO-$q-l2$QX=Z_ami2?uChqHWd^%ZBB?{@T)2)iD}XqXIj z;C0V6?$?{Y2&%W~{G>dQPORy#2Xn;)u;=YBDk^^FQ+kuBxtCA(aadfsO_ZG;PY8zL zGCQ%*5d-Wo(|5Q|)i_v|QBPbZvo3)|hV_lQS9xs8kzK2D4V73~nCoz`ua;DwnKSj? zotWM=WG{*BY!x+G91D5PF8jFTL6ypc@thNEccVgd|2H|7LT~u zn(Br$&3zu2b2^D8P@g(mH1~KO4DgPp=3qK(^xazThwY6xoR18v20iH{Cly0|W;5Ax zdi{IuGrM`4o;tcY&Xc^;&~f|UAeo(|z^v-)22dvCvH{kiY*+)5U^Kh8tUukHWc z^J})>`~KM5``*<1c@7!4f__r`ivHbj+?!{QQ`;ikKc#-{-UE z)`uU@u$YYKu!>i$xi!OTKqD9bVY;Zef#By=uEW2}aV&m(M5wPKhTsv2`MQnmVV-3< z4qwbYyvd0#y{8ym!zZ)dFk9i>3}K=QN@F?nv+6K)g!;IC>CeT>tafT8%hUSyWRBJDmToSb0Nf;uw)bw)bUxa#&s8^A|sj%`fHAa0>B;A*WBL z@bkt*lTMpv6=Hp-MLej6C-ExcDBb7;GqOWs7ti=VcJ~gsy;a^Doy(9iVjHumM%9ZK zclbQtw=2tPT{f&XfyK@e-aD_KNWRb>@3TD!LqHW{cJ_SC6M|5kb*nIST9?1=jGf|_(_CPqjb^6`gmI;3_)JWv#WrWTd|__azNvq1*zIBF&wh1v@9ND< z#GRG%X|}r5lfC7(B;7v{u?GD4Spu1D#WrBU^FPq*dz&@+)bvBl6-)^V( zIU}%RDGwM-Mt@kYqcEImXAif$X@Aw0Lm$V=A!LEebUH`dKN|*BWXFn6J<~Xru^rCz z122@adtT11v0UHSX#$Jhjy3#M;29^KTZ*;E1H-4|MRhwo3QdXmKAGtK*x9c>;0gWp zKWmy5^_jxi`uhdVY4w=kah)ws^Lj7()c&!{je&6TXAz}t?8d4LdP&*Kv#M^0t@rjY zj=Q^w&Xy`pg`f4Y`Qb(P>3q7{{xrDfPT%=L0_h={MP3{|_b=a}-m}TjO8L?xt7^^H zo89gV9kN4eHdy{l>3rmDfBCPrIx09P7&Rq6V$HPsM`fx?kvhQZqhTj1hSQ|XcExXI zTXllv-_F+Xx*ZSO*5)Fc`{8u3=bf?b-EF^L`*1Z=HFU>%THbnW&m%PY$CP~7Jyy40 z4=KSN^Xqo(Hi0nretay)(|Jm49X?-q&8tf^ey;k}&YRkM4S zpIj=*vr6;xjt^QhjsE(=lL|t;ZZ79dVsG6ec~iH%P&dVHR$SucD(<&8_aDDyn9aGd zKGgg#GhV$}$D|yr<`rgWQ#xkN?xuS=#hRL|#mPmt)q=$87+S>t*V4C+?Bh|V^V}Qz zrgm?#=A&2H^TBGTHjWoD|3{lnke)h8@QRD|Q~oFSy1&yRwmv#7>2vIzUEP$1onbxe z@XNBV?6HhReaQ85JMfq%eVXC2@ZT4f@b_NVCbiBNC1V{}byM&+uRZ4{zb+xu9h=@d zXEdSibI$yxpX+MAbswPw>=f-ui_zqVfpS|K3ML4yBO*@j%L-LXzZD4iuHu&QB}x4 z49*7w;j(k8+L|7YjqRpMtLClS;)Zgy(hFWx|4im~rkCwkf4d(W5B69R1N*;;z|gex zFcy|~Q`>I};`{J*++6IqnTN^xP`50dh-bE0yZ!2L{9j&YU#0HT_^off+Ku(GeO_3p z0(7*|?=u^SL$m(A?O`{ync#c-X5u%LY(=c%Je)O|aUMUkQYW~2CZ|q}z<=j!zyDaL z{N{q6i{Ec&PIVZbQ$Q6%arUww(x*X`*+wO|-wQEK91}w^Zx+0F z7s?3pX0|8pJL%nTe2TZL{y0Cs?jCW?=>1;?Jo9!_(!KTP@6J}u0D4OKce*yuC*~z9 zkLNb)1{hnjVN1p1V(aT_x$bPvE~9#Bi5JaxnyB-s;~lba$O9`jVpKkry7Np$%`!m| z7pKj|(V>j$;Y$^>89r2*(}XYuC6f*Tn3nT3{dn;vvgG(J)Y@6*s_gQPl6ant#z ze15vuAq?%)kqjy0?z2wg^3!lTI^pMxg*0r7n%}<@1PLZg2Ai?ZR;J_S_C(==XA{)r zW#4+8ee5xDmdp0wgxPt+Zn>#AWm&_mJ`Z*LhJrQCRwHdf?E{Ud`h~~GOyq`8z3Qe2 z3f_J2Ejzn6w~txzD62Pl=;-vk_tp6LrXaoeE{p!R7iRN<51SCB5q!zo9iiLwE5l8G zPr5LvCU2h0v-Ldooc*4v|2Ny$bU3T7C^D6@!1}QBp1(C-;KErrVzxQ73n$JlApGQQZ`fZ>`gT&&8x}bS$F1y(z3zZGfkmE=-`?PhBQf%wzR#!k?jm^Q6M3imGchP|`=}58=BTs!#Wt(n z%SnA$VqF(oHH$;CRm$L1UYx3N-h1;ax~V*>Fpoz|sSp>=8FpdO+oOpr>-hWhvuDa{ Y>T;S`cPO0V$M*cxmf05q{qN5BAJ9S}&;S4c diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f16_16x256_16x256_16x256_16x256/input2.bin deleted file mode 100644 index 8e916c77ccc55d57a02c7ef3b593a483d85dec01..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8192 zcmYk;TXGxA4n$E6Odkq{EqhoR`pt#SzCNK~vl@j*B9T?p$M?tg$FGmCeLg?_{QUK| z_3Pu$>|YGS>m``+OEyZ!m>U9MmI zhDT%!WE-08sG?h4@R0fQ&DHP4tNb-TJUsjy+SyLS@pa2&sVaY7sPr0-Cv#1ek}O-p;)j|(hqcTe5Wh;CV{$m{CJGmmw%r%Z7p z6ZNDQ-;`8;GVFc+@!EuOHNVT!)XwW-SalK-b2^syJ!M>FET}SN>=|p2h?FOuL3Q}rFLAb!+F&|&u0Qt#_arN+e&8`H=FXasUBWA49gFbP7dp#G?`d-a$uYy_SZQwL(`)qnVPA-dFfNGt!8Z*X2CMx8p2$s`F2y&IumbR9E@ezU84smC~_{=^1jKkJ2DGuYFTpC&vuT#9O#^ z5rf6=ep-Y*e5vS75AwjU4)&S$e!>i^S~80@KfIbMe{1ZR2_KlP=p9NGWbb}~hXzv) zHmM*l2p^U@VX-bBFWK-^e-z1XW~;qtxS$$mdq4l-)RR9vuQNOQgAsE@=}es%*WLTx zkb8d^!M<2|F!ywtCjK!Y`(fC(YEOpP@YD?^4(Hy#=l}fD-GtsMX2X?FIdxqxs&*RG zpmp5U!SymLlw(io17;dvwz@W?LME zlv{>FGAnGUw#xcBM)u?f6DD6azjP-LoO@mk zb?RK*M$3i`=f%%- zkI$?6T3oX+!ES?fb~W^+x(F3%nzxyc!TIp3n!a?r413Cj+f;#tuzqb19t^;!bMY{n zUimKbEco_^&+xCK*?7tX1zu(6C*v}0hYGyak!isUlTYWv>OHxOZmq*-=b0qis=V_~ z=UeIWxqYzIR8yampi5ktac4|@5!tqT86Z<9HQRefZ1--~D%Pea2cFO zOr%_upVLUSt&it-ML4^Ey031E7movb#bB74e*Lh>03~F4Yf`bD4%1PSl^g~TJ|+wwFICs=Eb@lxZpvBUzPPv7GR~T&d8@C3 zQ{Hy!nw6f`@7c$k$RPc^gDM|>6mSHJa8 z%OOAidTD|bg*P##Xrkc~;jeSj>3ZH}WM_q2&sgxgBg18xO+{~a7j6Z=DX6A*Jyvm( zKZc+-8K`7#)tTtL!PT^3@JDTMs;-|9KhZp(y3@0aFr3madoilO1x;*qnzM9ro_U+2 z9Nlzs;nei?6PhM@_rTdGmRIaI>DLF=AjSsti=U+vY*NSLELhj!kcoKn6y9*i`!K&= zF2)n4G_4;N{+@Q|XHm1_)c0nv9kYCNVXE};|C?(bo%S;MwG^FtFy&noyv;)fSxuX@ zK4kOxYqv)_XlB+lhFgeLz?Uh*Z!H3Q z4|h$X>5C{$AXETzoQ);f&Ud;Am*2Gq0@5{dDTeW)=hnr@Kt5gjOReO(#elM5( zIy_dlFE7Ne`7hfn{B>1E{hD*<_pQObHxJoW8jE}`GejcTu?Gd4tiJ8Iv2Yr@@ag1O zxyUJT-`$UJs{$ERkiT!x`M?f^Ihyvx%?@!-Dqq#vRFz4sN2Mx56i4a#7@NQPeL55R ze9gOB-Fi-18N})mf@aw-m8fpLZa5K6npbPO)AaUm7|S})u}Z3n)X%ef`1gLD>>EF@G`(&T>kwnJ4$R$6VHUXSL{!C1 zktW2`vTmAMdQe3lk5b_s^YerQf@e(UdF=e)s{W_jjwH_jq(y!^KcJAP%wa_qtO(gP;`If=dN@rJG*=z#ma&AysNJg?6WTKKjv6vhqu14b*#=0225cGg5jA(S90K{ z^IPS#f4@ZAPr0z>-Two2VO!>9(L*QyrRHk%YdY??8*gUWE7 zZMB?>mbZFV*=R1^RUyXpI5dlqb$?sI6$U+;?fX95KGe^I#HjMzI^uM|lAryna5=Wx z_o-@GUY8lu>uq+m((UbCofxL?vNjL;dQOtLq}Jt|*D)sBKJv&|7A(^tmZ;pj^QAT~ zU2K1Ex`4&hu>&!Dj{+X{F~&RW<);Gny9<6~e$eYcJ^rd}XU|80>ft6Objr-vM^A@s zKF?0r+rvTH;?<1I@pQXys?|ffRX_cTb_$zkEPb7SNNKCin~-z2=<_(F-j*YbcXMp7 z-ppimyx%O;#0;#RCesnwd5wA4U)z6>!q=TBEB`e1+M5acOM(1k_bt9NA$R9=yv<)t z6?X!LXWz5Y2k&lc^~)+!9n957_d*_?F=nQwQpVSN)lG8KU5;{6Ckz;0_Su_fw+if- zsmnWyhgbjNaDvlN%js@(-oE#Jzw`UXP$v>oCK#jZ@Fx!0|EvG!?CVAZT&uXJH&5MD zC$^wauiFuCr`TCUTtEBfsR}&tFUL*oo=?Z!o+_&}RP$<5>U57_w%J;#YgT$m8&=kk R#2%l1jpeXzqHsg}`X2&$9ku`f diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/golden.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/golden.bin deleted file mode 100644 index 9a1d9f0c5720d99e356a91f6d7babf986ea08d4e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4096 zcmYk*L9ygI5JXWyFa$#|1bqgEAP9z_RAz5c9!BV;Pcqnb zw?bzRXYh`>1$hs;qh=kuo>@=ls@dNp-@{+N-iyA$=kvX1p0(dmZ+T~iEqFucsJUg9 z>+;^Hj(LO@WOPP*RAx8N+PyKtR%(aNS;*{9-hp>}I4XC)hZ$OV_FlVry*YFD2EUDW zy)~j|vAmgG?}$Fp7ImZ7vzy({;D-NZ*4u0E(Gh;Z{8YbXc7r|0^}OW`GlfiVdFIvY zna`kiL?P>8uU+Slcq23auD;vn(`}V+)meA5cK#LgcyND2d)Pn9laKP`E%4v^O!*eh z^2{>a!S2qyf0Ta&9-Yb0DC}-zcK=Sl@T4Q-xA5)Rb!F&bge~yj-1-iX>BzjZVn;I(K|;+r3S_@#ja~J%I=F8}`gC^x*EOnLldQ*O8fJ+ec;g z9`3xUH={oqzWpn^&(&G~Kjdz)1^bhZnd*Z-!g`}+KEn(>?4|CG&~meUyZH*=_jse& zKlz!t=e|YF)Ehdz_?y$AGmKzvh_4Ty=Z){I`^1;&Uj=^)x_+OTe2?n>qGyotw$N+V zyE4mp`tmo5BXbKBJiE>ad(C%sdv-m$?E8lKiuXj9Pkmori@Na!|19+$A1!QQh8}d5 z=N*~1e$gXbLEhrYR(H=V@1^`%yZbxz@CIFYi|-Nk~be8AN=gQp8sNDVP z>hySL<+?Moj@v8fjBqF8>pwwP)`NTVs9jfnl<#}$w7`dseXGnq%hPqMhbyz}sae*; z9`1N%Z}D|nc!Qk3a^Hi0Mz=uSn04)RrgPW5z7FL*;mYpK5!}H3cU!X#{|H-ftK)tK zzWvSYu0>^h-YB*6?Rwst+1qEAU)kMtd61h&*nzR4IIxTvHw>oC`=ZBHmz3i%4 z#~a@AhFQL5I)dIA_S$E(hnBn5=Z!Gqt?Z|l-TU6qX~B*8jj!L9#yi6v^r)WS(p%o} MmTUy?T;c8i0BG4`0ssI2 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_16x64_16x64_16x64_16x64/input1.bin deleted file mode 100644 index 89d0db3d88910d97aa5cee383df1de28737c7db6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4096 zcmZ9K(Xk*Y3m9Ix{e{=RKTuogtp_k8W_r(j8G5w~HaK9V9^n~wf6p=}?A>)g^-f(jaKksh1vus?5;~afb0%EcTD(uR_|j*IJ3bG+t2jP z9Qw{#*T2+eo)=xe8P7K}qRA8Np#G`vo4pfUuUXHqejfbpW?$5&zhJ7P9e4xT9r6SN zSetogQg3j>6Bggz+(C1X`ba(D-n;kQ@!x%OZs*K8b22BJ|N1N1g!Sv08{mQo+@)sU zy=LpZd*PW;AMgRPP_y5_>hGvd>U@b0i?~J}-YjtseJIrkA36{uv_0G6TJ<>1uq?SAP?n=Lb_q=d(FGH`^ ze?v`7v>n~r*#Tehol}3`3jOOh^X4+!Ts?U+*RyAKIHRs#?U;)%+Yc~5toC;Ez0|E= zpWKC;(O>C3U#RK7H_5=v1vC7z={J0U+&f#F8F0^DpIWUM`GN_0roEXHt-6_>ZKgT3 ztry??1N{W%^xfIsnG>z^&Ac;mf3?@cH*n5cPw3mr1D=7J@Gcwn`~fduZ{2TkqfK~$ z13G74-TbBB`oP|E_0?rj-5u38df)PVdi|5UbJjg~Sc3~TaK?TCx%H$Lzu$rS0`nUd zCz|}d)1aU6zVF!4Wp(+2@}xdsMC&{3FIeD!33_k)@13frVY0PtB}Q7bE=6rPe#)m#2Q`F{{48jke$! zs5hstet>(e)%16?x7)kW3v&ma>Ai>ANZ;MP-|O|<5jejM#rE<6vOA8+={nYVx7%}u?i3*2#0F%z@*Zf4b&ljYvb9cJ&$H)>vBmN)1(v>6mwAGpx4 z#;ex(&8!RDb=2|do;trrGT$iHXSisx#(O5cWew&Vto0Yp_`Juwz+LHGYFW^*M+?25 z%sSuE<5g$Yr{4H{gKGbLviG}AU8rw!=6O`#u!n=zuu#eO=eykbexHwC=H%sQf4n=a zMZZ7Pc=P`F`>D-Kc)R=czQH_L+fRMz$@~6!XHmUn1xoUs<`1KzJ%Z@5r!K6OvMP}yLX z`wDNan%UcMM%U!t^?rl7)0qqT9rEU$ywG@64rGVt%wCOds8M^r_syG@eY~1GlbK&~ z+2GE+M|J*?_w@7&>?Lz1=W2RS^{}Gu9)D5EH#A<}p7F;T)aN+ZOFZaKX74>x6Zukjw17nb>lBMr(clYqx5j#?HIq&cjxDwO(rGW;mjGYT8rnN zPJPh>s<~$~W8Sd9o?gzNclMiYvb^lyXxqw}skb-p1xoR4q3xF_#j*0AAz z{OePn>A6GS&dm9|%}ngqc=F#|L2JSR#*x);A)NH|pGv4~W=Ue0VhsGB~ z-RV8zoV}xxUv$A%>F0N_)~j;B^RGX9%#-Oo|C^)EuF}H=`wa{7o{g86v%@{#w-mkK zF5Z-$+3VSd1qT}TkX?Cakw5;o$NBHRMg8tU@ARzQW9@w2;p}@Ky@R)3U~Zi^{PNl1 zjQbk3r#h=&aE83gyuR6dzq9}8QM=cCgX)f++Dv~&@8qhPZ%_*w^s2R9&F`Rhw|Reh z*7nSv_pPv|ei^>NU1H|ldot(s4bg=+&zCCtUE?lp@#`|6oBckX~*-QbM6;N9yvb2Ge|95ZIk^*!wU E7yZ$dw*UYD diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/golden.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/golden.bin deleted file mode 100644 index 0c11a52030ad9e8601c3300b9fe3aaf87f73bcbf..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4096 zcmZ9~LDA$o5CdRC2!ucg1TTX?FoZzx@f`<3pae>wtiGxCtEoMxq?Xj}%MfTX>eIdjxX#%~zmD?}Z-T zf=}=y13i3%NAO1LW0@Tp=5K-iCoBh^tQGUtd`ABl9Kl$A_Gi$4vbAf^z5lz>vwFVi zF1Xt(qgS8a%Xg>!yrVC>0=*gVdOCb~qFZTf=Gz&CH{KCF16l{*>tsE4&5%t=+h*&(nD_+tbB*GV=)U!C@|@d2jL_thxQu zF+&I4@Lr%>;Yqf9Ec3qHOnYPXM)Voz!3?h8%^f|STp!GQziThZq1F8DRsRm_>yCh) z@pN0*oo_$l@eZ zR++B*pYZoq`Sxe@_f}XR_+O4bf?iC6Yy@_f)@gx!Fa6}s9hk9f1aGuGzVGRQ`3!Wx zjAgB6eeux5fmzOL#cnMBjOT5T+qnw%^vtvy zPkOyP84YgwjWgH_EuLK&Hm^Il1xIiOPcnCS`oMQ<1o~L-4c4AF?X0|4uDg7i}m3N_(0#m+L`6`c(m*jc0(6>^y-gjcg)8>nrE*;?Yg%-w{-gI zW%m7Y{uYd&1xLVt({Tgq@bz?b`M`r2Jelv{jP8N%9>{svyL$Y-y5ToC;$gnsJ8X|f zuiTvl*{HdPTXkpo_PynvE`5j1PuR^F@JDb3J@^Fr;D#QJd$@)9ceGuzJsM>Ct1C0l zVC}8Hp(kFed};H9WN%^5Y?rt6$lQ3r6>Ndd5%l0r7rbSb@4>pM%ae`L=A&lawZCrY z$$Rd))1q70o>|T_??M03cXA!s6F$TGpl6mH!4>p$R%XvPm|L{Yr)InEVLMy+uFhMt zt{wY}=`+yX%3pe~*-cq1%&;B){=-+ap7(T*U<7*ji9c)hJM?6;c#qawOuNwocX12c zkzK(Zz5n=n@;AWte?9){_sY+7+`BV_9X)ym`fiw?phe?ezD!4slhqpg}0aH>_ AMF0Q* diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_32x32_32x32_32x32_32x32/input1.bin deleted file mode 100644 index f2db4f3519cfd7607d2fef23a76466f61969c63b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4096 zcmZ9J(Gg@R3Y)LLn4FA(ZQ;q^Vc?p^_}w#&2efF~|7x-x%vJK#P(3rk*mn z{y>WrcJ|FvZ!nV`_@;JF-9EFO%)QnVyf?z%{NjlNCfFbJR(I>Wm^Aq0Sy{iZ6UGzzWuChLrT-|md_g_*qz?|}Qkj-DSL zXlLx}Z0F_rzSZ{@y*vHpPs0<~)BDWoJ*GC>aCh6UZr}Iqll}p(%xqX+4A6R_Mc?i_ zKR?-mPxSuif&SeQZS5Xs`d!Z4)F)hTQ0w*WYU`Wc?ri4XnQUcN*4y!AcL@F7^1Ijn z2LAwaxwSp&8{Zw~+n)SaZ+Cj^oxnL^-v#!R2X+04Ucg>`Yx#g3UhvE9$fPklW)trF z>gH#MH}dY}94oxp5G>)E3}fc^zX z&CG<09lgxJ2;Ljf_v3A}eL}zQ$ky)P@KA#pnCmy!*Bi;N8FP<(BD;m0C&w~-fFJ@^U(SyeBrGQM*Qc|H%?1dwb|AW z-()B3E@#K5p3w`||7GX{`YoNiAAh0sH#~v+?f-iE2iRjKE_$Rkd%JVj(g!>F`yI`| z#XHHSIrx1OdEeFlX21RPY%<{;6K;0VM;{C@!K-t#^^ND7JG#F62AAC8Q;1)omHF5)hoUF>fXQ6^OL<>teL&=cc1#;iERe#tM)B> y^wh%~m|w8KiMD6z(Jx!~E&JY{e|TGi_V%rd=X(>j=gaSo>>gwXIIp&E&Hn*cNsu@cGa?-CCA$z)yX|$c0qTmm*D&k>z{@XxbbTDP8n>_H&yrE z^_fkyzUbXQztNe$5#B(4zlZF^U$}Z_?w*N$9NX#wV7C;yH7HCJIlA~nU|3dFoC(*%k!;0 z<9=s{*#ZOfj=sZ8EiU?i4Q8~y_paz0*7vOUO}s4Z$bFk$*xA8IcHr)DroU#<&fGIU zUcG|@_x^T!cfI`u`X2Q|--^FrC->7Z03Ib_nwLG?u`~dzbn7h{vsDQ^BFcfK)%iByTiSM%=0^Z;2G^q zeZl#K9q*p^p6oE6nRV~Qdzbv)`nm6A-`MIiQ!lZ^H+e$$D|Y|IYeb$<_yU zCUB2h?zgC4FdgeDi$2LAUq0d*X%p2Jg)>XzomZz|-uy=&fhR`mnqE zo8JA*WV@bu^JIPJ#1C*l=kBeRH?!_}cjgUdp1WY$iI(5}oITZ<+Z)MtGIy$-y^Ot{ zzj)EN9<;fzH-MRQam~VZoNe%O*v#IH)(^DaT(90hEly_62fE+J&iAgR2XCjl(D##V z^a&=sfY}B0YVo~e;SX@{hWie=xgEW*BQEgF3?CT`V9(rp%*A(~`Em|q_sAE_V8kDI z1Nj2G>IF6!U`F?i`kTCOKKKE4*ty=fG@JMpID2`1gZ_SI?>FR|+3P*N&+Y-{@3*CA zCKi~{vVIi#Hs(8-d;;%vCKkBK`&~9(=sS~}nNR27`@QsbE;!-$cJu(w3UklOW^y~# zEBRwJuyfS(uMB2*!z2EhrE}E4-)%m?@7et1&Nuk}A8OAZa=XG#G0-bsZGXbYoOj4? zN4(ia4gGx`J_%-~%adSyNH`_r|f#dGsjPiN2m_jdHw2~YejdITLF{(9N3 z%-)c@*YBnLJM#=~XK*X$xxw`_+r4o{^{@s0Ui9SPebrqz{(kh{_;=L3jyL$U1^x`O z9?tS~jzC+O$?Rr|Gi-(38&~v=>dP1IM(FYN^r5#${k%_Rof+J{*(smuUqRnax6r}} zUwUq+8!E%htYdcD3%l7nd(D0p`wUmjdfxD^+}k6{H+%Ca+`8KXzb9Y2?C-58?~n&? zQn@><{LFlZ7LFk6rFP$6`?m6UGB>;#^k(^I_BXSfzen}WS5(iPp82<@%a`-f!zlKs z?8)67J!)TX>t8`{y*;a=yY9_r`QCEF>%krDtG_y1yfgfD)8W6>^&R_HxSg445ANtF z^%kx$gInr;WqWr0hWI@?!yZmv&_BxamX6$eBf9?m|7G#qy-^+W7G$uag)O+JzM)U$ ze5~%u?szTub`LX{Uvp(W`s*LH|GwpC_NKSwbOzqq^?7sz*%sEncY7<}&j3A`8#2#u zha>Ev2RW6^%p>|meakn@Z@A>O^Y+@G%y#G>fk$x$-7Bo0SN9Dw$hOdfzIg<`y@eT8 z#~ZvWj37Hh4}YD!DVO=iUbAnS(Sq*rk3K&gy`lto>y`7r z%7q&2J*0p10Ei?AgFCFt) zEPr+0%4c%#J>keaLJxAf7wt#QzAfuv2H&wS-~Q!|QNC~J`A(~z&QZD9H}!jYZggAe zmD#_=JF4t9OLEw8uxweg!@K z8Q#k1o}KU3-lE{UoAG*d3q7;mto_TrI%m958Skl#|7P}{d<)*0fiIVt-L=5`vg^W5 zTf7<8fB!w6+a7OJenoZpi?{CFuD5tk-F3Ht&ftwBcxMFNFD#vz{S5kiU7o)As(p{1 zL9f-UKQim{m-g&4yKE0K?|AP`_Eq1!Ja3!zXZ%((pC?;C=i2ppy|E?lLErq<8*WB; zvh!p!3Uez~_QdxFe-B5^b~k$b9=)RaIy|$U8Pgu#AnWn4?s$4%bVS$B%Jb&JzJ-?F z8GTFn=)oIw1b6znvvl3)`A!Q@W_{Teb+hheI(IN1(OI6EcSU>HYB#%~@SQ!pb*F!Z z|L;A0nf_6 z{Qc?cd1u`{$#~wx3_UyV&5S1;*>MH=7Fv8W{|ti{?!e!}myX^W^z2(yPsdK}dl!7dbHR6`)^d&-5x!|6FSE*ZX^3E`8*LOFgW|(zeb94{x?0mD{{o}XnBfGcs*Z&iE zGnpIrn>|o?M}PTdc`MJ}!x7AUH{SA%C+gNNyTe(`cy3mA=@swz>G6AZ-7oVNozYRA z%$s(1SI}!wK5wg;UcsB@86MqmX717D>F9n@Z;hJu_AuigHE+?y@_X_gzeV-Y!WEv% zX68G%8BuTO$z}E){4V)=v`2Z~YQY;vshht}w3q*-Z|3Wr;f=StRKEIh-VAPhi`T-s z)AcS-_XsoCb$KJ~p@%2kwadLDE3%iUa| z1vh3H|Eb(}wjk&4LEoDr3Z2#0g^qp?XVBY&@A8)NTF~X`@~Hia%13yz>wQrj%pjv% z&8_A=ik01!pOx#$dyw(%i}gEh@kW^WuAgb;J>C|MphM-nE6_!ck2_wA&hQ0qyiq*i z$bJ>e_l9rEFoW;)^uM6@gi)TZ+-~;X40;`zW&E%9Ewky6@!@~&XMLr-rj^}CV7 z8?v7H2zzj2e&gHqcq7bsU$DzC1CK6m`)BUiTX}lk*VP-fQ{VLS$#h!qw*3lpW_M>F z!43ZoPi7pMb@;I3OV(@Vo2l*@+^*lxoqYs%a;kd{@xXXd-7ExdKVyrsHu|D`jdBkZN#T(>-TeQC0zX1#UOGq=j7er7q` zy3t+m%p>?Y{LXZ~^jD{4Zt--`E7Mthw>rEn3h!Jst}9^UkH<&=Mh9$M+@te*FJ zx=-*1bli>Fx2PR@n8BO+cHLe()or!Y5$@6>D(Ahq-J@-3d9p3&$?wnukJ?+!SCqeg zZ~8iRKim5G^!0FtFPZO-%x-3>9`6X|9`@iJH*fHUnI|9Bnc45kbXsL=)_=oMzP*Jn z%<9j~{42XVZ=0dlOZRB2+0C8#2xs{`eO`-hQMl20D!Xcz@x4#6JnDOIZrmS%uJ>o= z9^BCNwhnLkTV`(@(GlG5QGIvvo>^bEa($kj`3iTKb?3GPDj(qtdvLSnE4vO_updFs z-qT&*;F)JI>wcLVFY^}O%qUdi1IeUuTvlKP%g#Pkj1D8#3!} z!5i!EiOO^#*PlU;E?@R##}j(?qwo&j-hx{SJ^L9(n1OHJYVI{3wXdJgzWTan<@&I% z&Xsv9_NaV$v;SXj#PjwUR+rbp>K~bTGw7K0*WH`9czfumT}S4|e1)UtJKC!EMP<4( z=xp-fy(iuo#T7>Q`%ZLq{49GRpTWD!^Iosr8+T~o2xrj6%C>mkp}knWBl{Je^xr^r z^=9TC?#i1z^B%s;e|hFNvv<&fj`@kdMXx~5plhb~FA8t+&EB}f2s)5sR(@rcZ{^9} zWOf{d{>&`@ub0{H*zSA1^f8As`xe}^>{M@*w{G2C!8=QJ`OBM`N9FVm@-6t@9&NQ# zZ_lv&ky)>onp<^zQ(v~!&v|BdH-iqZ2cDU>aFninCR=~DC$sKM$9r$)U(XF6Ti8Pj zeDKi2412gj3;db)R?Z*6_jo!z9K{};<;l)4O5M0$KHY=&^~`v)Z_y{(OXWO!uiT%X ze+08TxgKQR8(}8%o-X~u8{}9wGBa-m{ZT4wnceBM=o#(d%X~%kc{)5Ze+zq<<=f@F zBj|X?E^C3mc+!>gXJOX$XTQQ;^r-K+r45;Nsoon8BeULFUXRZ3+p&-Eq|a}`TQhqP z-r;@GE2?AmuAfc!3^R<{&F*A+tJ5;i`1)>Fwr5^kng4px5xk>6gFbrc7WIxd-SD5x z-_Om6=MDKU-qhFKqcc!G^^W&@khPkxsC*_rN=JF?Eq7<&>)3l(e`dGdzzn+m=RJMd z;0ODY-23*o^vLX1&#XiDFpDkf=1(j4bDY6@x}#`OnXdT=Gab7wZ_RtS5o4toK*y#*U=2r6^)uBhw z(e=CH_sn+l6?A6!;^{t7^qO_HAmiEP>%Akp+c$XAOc!p=tZgSYj2pz<$lnfF3h zkAG&?ZIw}7{t@ig&pb1u=XNijj?lsre@4C0sw3Nj@9}YzXP)t|sJutd%HHfPGtcZs zhd06u?&a>~da_otzK;CKOxJ#7-h+Dzo?Q2gj&PUq(6Y-YjdCgsW;}iS z6=wPCzGpX6x$avxyiuNe?^FF&zT7u><~#JCUAFQk^LmHpjYZG=#@nMh{%!vK`;52z zrC+==IrNXhF0*@Y3oCcqDmTw?XTQR%p1!*g&YD|#d_8>G-ODM@yM9)gd-D;V?B3$p zx1j5$Co`YrQN3B7uJ;E&D%*?CZ8&5_i(bLajOsw=tn4@E&*Up}XZGG5ytjotc+ZTU z9?v|2KkHv7e$RY_Gd#8HcpEeG74-Of?QVG9YBlTgf7QG4?T49prmN=-$b0s^_A9!2 zBl8hjdFw5ndEMBr@_N+nhG*V_?g%TlpXK|RbiSycW$i5<-_0wFx_dH@=oKjc2r~S3 z-Rqs@uZ|mj&s$&c`&^!V207N7{E_`A^kqECch{ocJHwm(&(GY$R{B)W-RkJ;LU;LJ zx^_98*|+EjE$pSfIcwI1+`Z1wn4JLZ7Ec@x<#@(P9$ST*YSy(+d2qLS z_A}WR)p=^?nf2c27WK}|yhnR^hdJnDD_uVIuFQ?M{H%I>x$Fwwm1Dih|H@na5wFLS z&8YWgIQTVhwYzn2y=Z!`xdn?+<0TDSx>&^vvU6O^jr2BmCwqys9UQ0BzNb} z%73%VXs=mE_C?KjgZne+x9r}*%*@x3&G3}x9i5}NqWX53UEbr3^1SoaF6Z0z{MVLK0HUItWH6PI{ z%yjhZdfpkOJn!l8TUeeO`&pTejA!0z=AGe9{|Te~9<}SY=ofhU>xQTQ_irV$zU=+c zz52W@^Idt*?oRFv+QS)U?L6;&>GHSWZq(eOa@iB#>}K7`Wy>F#XW(z;Q=V)OSM9!O z_q`T{?&?tA)H^cMGwNQKH{-pP>#ttVthe6KKZEXwLdJ*g4063Ey75Z+_8vW=S2zRz z3*L}>12Q|@*s=OlcFD^h(LGSCzdJp??oo4#{`xa}FZeU>+fQC*`<}T4`elE@TfSb) zJSvy<%vbn(FKVA*4@da>HyOQapSjV!g87WX;fwEG^YUpAGn_&G zrl+fSg;5?|_rB4x_dsXOZtd`fUYFPWMEC4?GyCRIJKe%p{;1upu6JZ^DfIQsXH@4Y zpSRv@@iD{le(&G9*^@)B2k-08pS)!s;S5iGL)T2ds7#NyNBLB?_BZnt&-*RR&`Zrc z^z_y|Gq1OKZtc5Xc#>~X{>u1Q?cUNIwZGYATY0ipGsPLc^m^3$yraAkbGj}#W|yz~8PD8;&KCZ9QMYEE-0fbfYu1tDtbIgVkoPddsGX;Kh4ph#{`znG z$?m3wr9GbAE%wSr)V=rQz4(4S*%8jL$&*bQMvzgtyN1k9bdUBRpFu~?Z=r{oo#G1X zy)V8$lYErF|8(?v-STy|z(2zsuCSHrnBUyAnt$)kc)m&Xdid?RXG1v|eLU%aj6)%6B` zd*0U5KPz9icHI%Y(Lzseb>&!pk6U$k?vBzs>c%_X;Oji`%roe}*t!BD>b5H(-BQx#M5nARatc(vmnU4DwuJFd&qcdvf z%a8Eqmp7vF9?r1dUw8I5*w^eW-PM;J*|$JfZr{s4qZq+^OLg^*Qhgauzt?wrW?en_ zS@f2tJ8R#f{b%>R)u(&1EAZVv(KV0k>wOAu^>oZvID;Eq_j}la-3)!Z8{IGHywMi; zcj!U3?sU+?S-Ltry(=8So%iG;$Xe*}o?zyAXWaLy+?5n?Jo}uT~J)Tlot9;#D)#WcwZx45Q?%yy2zt`R~^W3$lyI~Grc0KzZ zRz~%Cy=L7rGkxpEuB&%drVl;-(j)Wgw#>3K=t8#i$*tM#RoPLy9{4=i_i)y3?zx>o zrgN39ePqWOW^f~O!*AiqOncBf!WZv~zLl*mPgiFT_AQ)Y`R3I(%XNLzyLwkJ&(gnK zr*^s95zO)<@XY9?GdhaZ@h$UJ9sO3b@5=l;H@iJ6)87K`NyeX{XSVmuM{u{-th2n9 zc_xDn-@R{*!0S=|7WVMQYo&Ko4!L(m^bAMf&){Z{!rp8DO4on?Zk~8_1pQvBJ1gVc zb>v$q@5<~wzL_Uqf2N+jl|Q4hJ&e*GZRyx`n^OXXc|+_J-Bt zkD%{8`5u+q<-A$`R@oKRlfTLITNnjT2VdsZp-<4qR-LnEZ}67R+Sl)FRCYwYGlN_G z9xfVhkAhEU(EF0%3F}S!>Tc;k*Ug&@y}S`MyW!i9um|(H*O|$mx|Qj)e|Dam7QBa9 zv#f`&{5A7eHZzalXE3+Y9`zPqZhpd9v-}EjH@@Nbv4^cZ{h3*Z-?O_hxlN%kLx7@G2|M=eI z%hAFS&eAvc-WWf--ukxwU70)CR=dn?k9tSW)3Ntp#u<94-Q87cf3mOpExTFgjE*o% z?`R7;GT-P?-Q{2WV4p#E4@c0G@z(s7r{i`Ma-Lo1DZJA{56hd$)|+yf-L38n-jMNO zU$cANGwk6APrR=>l!qR+z|(1&->CZ@HUIkZFP=NO%)H)m_lvh@I#gG$m6~OH^!%Y^ zmZ66&cvFw}e=7i>e4Up>@AF7 zze~^ZR$s@RZ_J+#|B2tDSDt<1@8_cxUpjovvfpz z@ZJ&ibKAYS{9YO5!FSH!X7#SjGo9t{_OJy#IrRgVe!_@y044f}QG^^~|$oJI>Oh)V==Nb^i72eug{x)cbBa zvQyuJTwi|5gB`x{CinN~Egf97>+9+0dTWcq%@xkjYp1*~JwA2sPBz1vnSWQVr{lf7 zyw#Vv8^sZIKcmoV;mW*+R{F)8fqLVKw$Q`K?j4z%El}{UQhBTW$&8h+n>SuhM{kC+ znDJ$9T+bgsag zfv4l$FDmcB8?srtJe@OMkG7~=J<9iv-Tk`RGrL>6{F0Y9qpRZ$w_725gZ&C_bp1Zy zEu9(tqGz-R9X`+P6K&=7%tvNj*#9u0tmEFkzNhPZE9=#vEsSsm zojvLu%D=-`{*k!_*&A=G{VdOpBgoy!X7sAe?p<$eKlGY)WqRIdQ9nPO<+;1cJ4#2? z_pn#)PQKpM@kT3ey|ZrI@#q;^-g0Z^>CMXQ-hZppqr4H0@~Fq&}uG#*{gIhBl@!iY4F)K5#TeDnmd1jt> zc;Jn2*39dH&hXpSS^JgQt@-%UcR2&8O&R7 zgQ*+-@|yWN>n(Rj?R%-rogSW=b=P~|vFq-oSCrptm!IX0D1RnDqPxB3C)%=O7T(g8 z!Od5*+qY((nb)E{=&rY@PLJ|Ou}9Bv1s%DbzgzF<>TE&piKoMd8`+F6*S$)0Wp<39 ze}IkdwAk)(H`8H zdG@3HJ9>pLKF=&~*;|yi2f5of^w5Ir{ON3&vF=t!?+V`a))~Arqxv(m8$Q*6Z0-6z z`x9NW-QTgE{@>qs?cS1K;mgeoBXbMhkh`O9<@!4IrORvOyZwUB>n6;fK7YO2Gq>7X zdhXoF`MR@wUD*}w(Gk_ro#CnMiyqnKd*Mb8=9$hJJ$^jd^49%p6L-@(k@ zqIY!l&zf6QXYEIJ{WqR@t9;h(ZJxY^)$7$8mFqmo&*-}OvYT5lUzHuv9%ej@@_9Eu zc!#c=^>1e5=jI7!w=jS4_Gm9%zn|r=j(2ed-VwZqmRZ)ro83IZ9`4e!bk>b~xLNa; z%uVBGrZe!L+IjkW(2?=p=-^c*U##z(@$|Ro69sSicg^0qGOzcx^7V9D=t0g~Z}MdJ z89f4T3x2Ox?J}Os+`=8)Vb2Wn5j}&xnf7pn)t%|h@@0JJKG9KLtC_#@H=eiLukMlA zJ9{wm@90+Oo}m@1L*2;TH0VTS!yL9?r#+luH>W4R!U%l7H=bQS!^)23NB6Ml1pW-X zuWsC3wYR9Qo;UgKbY<(CPxe{9{woi9Ix=sJum|2*vkq_VEweXh4_6pr2Jfz(d9Un< ny75+vK2dkR%Rg85w6rD1<^Ngz~Debk_E~`%z13b-tAm-19=NHO3TECV8!s1Ta(y#r4)wbeuNG&9kX_U{>&`B`Y`AdZ zoger|{CkfX^+K)PGx6IR1{{^#-o;PY$I6NFFN9Hx&^u4d@(mlL&j!802j;2W*=Gxb4x$4<6BYOk}Z z49UATe!&I(h+2EsG|Oyz=7-#wzDYl=VZe^+J%7;|^kzM)F0!Yezxm0W_j$%zE!Ojs zZ+zc2@X~DIXR_~o-R<0HAMLDt{ejN@=6I{$VSj<@Oxe`eUgn$;vwr`?JAY8Uv+bL& z=23g*1JzG(mu&n3^Osq(ui2#D;X5a6(! zYq8!ipmP&{LH&Na-Z}Fn?7OM$nXBFN-bd!lh68U$-cvuw7u@j8J>x$66I8Pe*P1uj zpZKzpTMyW9Lw8v(=zHDcoH(Oi7|`!rda!?zPjn>PdU}1&qQ1%O%@5|*&OgO3@T|S6 z{pM=F*L+YrYt}bz>UW0hd)B$vLl3+&@y<=WGxn_29dGXZKqpvl@Quz|FX*{}pJ1*x z->9@!-D%dh=nuTT&6#Po(zE~0#8_inZPp-H3cIOV~PU!hT?HlZwi8FoI!gt3t2m75h+b}`zxe@O< zvlR|HVZaS<58ix1&!}e34RFp{FBcZX?Y&*(QSu-CKAetpiqEb83NxrGj} z-^{wZR`Z-a)DzSV1D4R&`_`kvRCCwHJ5%hLjnFrli?y88)k`h??YBO)`2w?^wRhmc z1Z&^0KdN`$o%Sc3=tVcE((iM}q;5|y2b{6rZ*XRrgZ{vbuK3QHdtQAWJs)fcKh<*@S+hXHW8xTs524aKU=R1kdQ*aZ~r(InM?Rn}0=I&Ij z`=;+sb7|)6gbnpa?;dc6{qDMW`_1($*+4J!ob|L8_Pn=O$?SEHGy03}uuyaB`V;T$ z%g{k~p?7p4|GsSmySh3|Kpjr2T!(>v&d zAt5D%r?v2@9or`yZC;ObMES1XBMavn(NJ* zdHzt#fCatBeLK15XS{v$1MOeD^HYW1XKp3~`n~F*?)$I3;MwNA>-9L(+OvJr#7lRZ zZ)#^p^rBzPC!DaO{RZm`6WsH5xi7x?%A9JZo_N*yi;6qeN9|eLJMgXdr|!8)_TG`6 z^TqSuGkU+JzlmqfZt|}lvI*`s*T21}^}CZ^`-iN%_ov^x2ASt4+WzWXGJBnO=DSkJ(GBh~o2kvLFZiDRrq&azJ=6X$D}EP# z!i5bBoOjp3PcRp~`}CbzW^iCa-=Gd_*wA+^eBaZzSx@o{p3|FcRDYm*J#(q;t>n*> zci;Jw-_-37>KS%>sQrUpP$#_aL`Uk4c22dn_ww}g{U&GZUpQb--`~f&d(>qQ-o5+= zy}x@p-#OJ8`!n8|0jimIT@|{ozq8r-oa%2x;r-Y7yF(u>x|ii z0qs3+_I(T9Suq!PSTA_WZvBQCZh8YEbo@1{Pml2^NH*vnb< z_g(HtueoRCfV~M9erJ-sdAxJ=YUc;}dn4*Qc2vLc2fkT<|MyMK7U$X9{#SeN&8~@x zGdt()nGNWk_SC-FnFGDwo_nlcPQB#wjbA@AZ{}I`qUP0R*0SNi0`Hmn>^HZ*@ZN!T zrf*zi&Udf#%{+HL^-isJPk&42X8Je%h4%aHnIACsy!w5YwOloCI{0A?6RcIw>*dX& zyEW>}1?Tz(`_}3L`(Nx0vVK#~IlF^1-gh$Ryezn|;ecnId-KWc&2KOHLi;YydEU%f zy}Y+1d+&5V^-ea?0XKZ-51I5m_MB69dNVrE`u?sfy^X#dsAtWsoxh-p?`vj$V1wC! z3kO!%==&!0yVn_e&Z+jAsn=fqPe$sE&d|QSzVqO%-POCc8MWAV_9=eCfD0R}zjrq4 zKI`uJX1%NTR1fp*>1SA|wb*z5;>B#ij`|++?sUHOJ4>&%dP2W*Q#+^5;7q+e`PPg6 z{>)yy_nZ9`KcT&WKd``k=2d&etUC_*h6QHoP2ICIwfTVo_RLkYw=?>DcKI80{{3B< zS($6yd%MrRne~pheslSMJm$eUd(PG`{7vn9^m1Uqj+(tUGGlK--`{TsfYev_GH0KZhb<}UuqdJ<2Sm(MGs7C?yL@K z?|!wt)aDy*^q{`E@3L>!@0<7wzwhS`XUxtgcYagbYp$Bfga!RQ^efBAeQL#SB)<-AaJp;Y*8|`mn|D>)C>YaS1*Y}v4$%W?b8`SEKdapC>9sCRn-S*(X z2>o3zwYX!W2i&vJ&P@D(3p?K4!gu!Dcjnjw=bm?G-h6XzqtByyXH{9@-IKG^4DVZC zA8&TTMMwObN#7Z<@A-oc*zhvcOeQSY@gx4C`U4AIop)?%-}&D9@#fxVe&XNzGqa-3 zb++@9tbf;=x&9Waec2z?zx|*7NVYQT&c5@0&Q84bfWEPry6kb*zBJqTV@6NkT(veo z=z^QNZ)k0Q2k*8XsV~~!uitI9VM+eR?wr&E)c1|qwW(J!d)9Jc!35{6H*}Xhb8B6Z6?c8#y~r+jM!!?f=s^dp)EhNxU3I6qjPzSC>TQ-cn%_@7@V-l( z@OtRW(GBKm-+oi8112oV9oGFlFW!15bBEr1zydRC&rkFuv+f>etnHaO*ZVf!{KET| z&pyu$@|B)5%?^GhALvTme$_d9t#4|x9Y0a?6V=NA-?>oFn{6<+em(o}^1e4`oOjmV z)TiHFZwL8-6>e&^41CWl{DciBYSwz<`#XN`bEa=L?{}S??fyaS4)Y~^i~fG--ZztP zcWz{6qWjPMN`2C^Zg%zg&CI%g>r;2XJ@Z%N55C`}w>Mz~{Q~uQ?oo&2?{V}AJF`+>bbvacXWXkd zJD=X){z%U`>6?1bC6m_gymQG8D)xJ)^ZhRU0MG5{LB%s-eyzbd{j2dO`AoJ`uhgn{ zxu_VmJJ8&xx>0=++%Nox#){?C%HPrJ?I9r z3(k$`LcM3ATMr{Nw|795p5J))t2bB=yl0)Wr#~vR?z^7Sx97ZiVZw;FcE^rdANcLZ z+Y@tV2kP!$jeqa%o1Bx+UT2-T=nNZ`0Sk_vnTg8%=({)P%@?@u=z9--ScCb6_pU#g zGqS-MamPeQ{L}b@?x64QI`KXCof+h>msxA)q~0BU|DnG#`)24_PjIF}Z-_nR($gWB`f=TrCnW_o8|AKl3Y`k=q?GhA~s-q!v02RQTIL2u*DozY9*Xg1R8 zF6#?3EcAd{zdwHJ|L*zB+jqBTs?J|(wf#lyjJ5v4h=2Rp(cj3vy#r=DwfoK0i_UoK zg_jKnJlFTUb2;nGx1Q9^-tU$-`zCAq15W3nH`=#b&#=(fXUD)FFzf!R*@nJ<;D5gv z{=m0y?~+aD3YodJdyhHw(|CImZgijMh^_ocbD}9=lh*z^$YL$&UBakow`K_9nAA6sjGucLOFthGEM(Txn_M&>3Fyc@2)!e=9F+ZdB zrF;85&Y0P|!QR5lgaghDbcg;f_WLgVf&;zBtY*%*vaj!_~zYj zt#{YOZ!lX?eZRTc`{w*^W->q6-_+gfdHVc0P6$q zjZrgu6FR%_H?{tex0%aqj_%YwZ+``Q=GNvD25eZ-W;eC-?KkVU4QA$_v(B5Z^z6Ui zH(7u4W1g)0%yx1!`({;lPioQIYxaIee-|>jDl_VWV+|8}ws+k0tl!_=I{Q6W7Jh;= zX1(*`rF+erxo4)o&=YUIb7myBo?uVh_q&Vy1oQUa47qyI0UKtp?!0H*_cZfjJwR7-Ee|E&oo=qH$D9X&#F8Aq65?=^i}u2x8XMoc)h%%{*cS} z+@!wn-W~sD{$XyI!3O75Yxj+)*@gEWX6%_QIG}Isj9K%A|Nf3Qy_e%Rm|6G!Z+6J1 zeeTpR*zn$hn*Gf$nGBe)!P+y{zDNIlGxjdC&^Np{lf8M=nUfjSxoeHK-qZ`s#hHO_ zm|(V|)?zK)*E6l}r{B3rHj?*m;rur9d%*rpZQj4Z!k-V;6W*D8lYRGfw|%pnIeVQM zWEcA0cP{lry?3Jvtku5n;P0p2c;_GH*5J&6-n;ScFkfI@tS7whMXjAxFC4J<-brq4 z-goQ!-DbVBZ?>MyY;cdhXL`<_a|4y0v9Eu<^!4V4InQ z!h#v~oOR#u%gHWy|4?BG^XB>udQ}!U-x+)Q*Fz_@SgQ+8)XcNK!%X^H^ldjY16{B) zbI=JF26$)Ru~RQ}>fwU@gYMSojBj7{Zn3u{-#_qXz2~NP&<*yaZ!sURgXb=OLcjB* z?j6?R?7IiQTcZ=4U25`+9+=S`H8)!@qShz%{bYTcy_a*xb}qlg`+HDNaAv>->+kn& z>No5AKmGT;nO&Jt-S3$L*7_U>zaiKW)ChMEdr`8Xc zD(5choZ9bv?;&&Uo9oTQ*@bR6&^KF~PjF9rgIfP~qrGe4+xzMyKk0qH(|dmJ&5U#J zJ^hWSH#oP@38R`fU21#o`}X?2W1oGmX2yAECo1Nu`2lCGck6gF&n?uMoBG{LeyBIN z&z>{Z(sKuI-|RwXMrw71lT6*vde~>TwY)l+`L8wuXUx4b~((znnd0huWPpa|3-d$*m8l z&d<~r^_)Fvuey?}7rfJ68D{W$=!O$@R+>4}@AiD(yE6a!pR)(w`<&@J)nN_xo%_8v zzk$68JH3J4WUBQ^y`t7K;D)}@Y=iYfpUgK*yfX*%YR|UbWW_Vx)9ifu7vFDl&RTUw z-DrLH_#2o_=-)@b;KGPnduB$PpVao=`>eq^d406^*xTT&=*3*F>Mm5w)Ej!{v_`wj z%-PQDc=O(`H(&O^{91EoXUr$KXMNPU^7`ow6;AZ^IBWep>h6KwX70|Nn`}vH{Q&do z#J6u&ui6vy_vY-e-e6A-yx%J7oBFQr+@RmOGw&?<1n+$4Q@hum^UhxQs=3(^ytMyv z>YjDR{i-wHZ`{e8?GCdE7rt*B$yM`%inX|-cTQ@1&fTcB^d9|&8T9WRsOJVM-Er{; zoSEphS2Eu>sP=o;z#r?6UTfx7vUhjte&41iLvruLw}16uX5SrVzx(mq9=va&fAnx+ z#&6W~Beg2Wns1ZtzUg(w`t3&b^TWXpFkj#fRm=uVIIy92nB8Rchk8kVYklABIoV_r zy2lx78DPEP#;Y^lGXrndUDn^ey_3E(E4A;iXDuTfYu>Ah`BUaS|IMxCIv3sX3+*2J z#~e18x7XY~(_ZNAH%r!gIdu+HJa1ag;LWZcob8)c^9g;|`S_kyJ-gC# z-mGubo84gDZ_}UT8y&EO?|7Qpe8YtU?zZ;4*?{+U)O@8r=(`U!x1R9!=eN@Lt?E0I zna-MZ|LZeny~y3a(Fr$tqN~asJN#>j~}|@y%4TzCr&oW_SFn@dwOq@EtRmxr|TUcd3_rrSG{@ zAMI~b-R7M0tmmvfW3Ae|XnzaU+TP#(B$MtNc;|n$ccfMq{R^G5*2{r6i}o$kEVVm} zvllk#o!6_*n2Fx}-Oc@GZzej?3+DB`V~|PbZ~Tr7o4xJn#%&S6Yu$v`ukhx-*=(TPUrLcEPT(@-}K);-%Gv0{Mdu{9s1lk zl6jX|`_5lboskVwpLab^?Kykq1HJHi)^MobUgo~DX3jQSdU$U^rye@@a(9r)iO$TZ z1I+actkv%DuI^FoSDmwe(FrGNZta|V^m1Xr0Q2?^^-T8e!f#k5-!Q1fo*Y?i;(ibI$ghD!*C$fRjuwH}ygfsM7EC&Y8NqtZy=B^)kVFL*Jw}>-z`& z3)bb$iH!>KXIi zt6s7L-#K?~aJFhTq3<$3aFchJXZlwA&A+;GPVJ2SNxj2G&D__Wsxu=vv!QRYcR)4& zox|U-lB+#;@Do&7u;XuhGjnx0UL`G-Zl({oljr^?$;{_f9R zo{^K8s#)jyu0iH^zi&sK={>#E^ZJF(VBIr%&s|WzvxEA8J?B-KQD+DKfN$EL+<9~9 z9r~GUoBu=YjJuq>uv6=0!hj9^2J-{vV!hyiGy2Y|GRedl@lNxthxctc`+gJp=A9e( z;~aSIqB3Cz&-JeJCvQE-^)qyr_t=vK=JImX`PXmmTz{JGHrffVsRk=KC&rY44DU z?-*$3SNugMwBPsc^uO~Xb4LYd)EV!cdbw~uSTE?liTAGe9rifu{CkIQ%{L7)XXU~Q zCz-it2YSJCy~7zZReDF~2ib%hedqA1Z2U@J9WX(E!1~6Uxp$-QyHI=9?wV-7&D`0A z+P5eA_Xg(MKV)U$f4>VegSwf1!^`|V@_w@eOY%N-Qp-(m#BWsJZ_qEeW^kg`?rHtq zx#@}D(z%s-cn;ieUv_G9ee03@Mm-~wT9&z*@6yXwnOiDcIAMOg=iZ!3ZvB3L{07h1 zzfos;zu5xk)t=q>3C;}_4%nBA?|18Gd#GpLY_iwW+kVyRwtmQ;oPwE@`J)64kGxJT(o4*+{>k0NwynQou zNxsGGBFbD?jqnho%bGc$gnvcs{?@9pfQ9@6*b@3hGF z2hV)Do(U6O`s_Q_%)5W%tuM@|nX_gCj!M1Io-^NI&rF?g!P(C5)@Z-S*#R?p(eK{= zMsE*2=if}Uy9f3BVC^2+@h{7a{(`-iXRh}wbKkCt-^@(kGm~%hdHVKlX7#dCJLmbK z#$VL@zzltF@3D4n!3J~Ra#EYi0J9scJJ)wxZ}J%y+BbePb2%{d`R?bbr&-iYUdEhy zK5TsN@P7Md3r19bJ~+SAGr#G1zTdCfo9XEnT-afx?!CY7$R20yRoh?Yu)%v~^yN8g z|A6PW3KN`{i(jzAK%G6P_o(*FCz#pyY`rtmH#*ndH+k=0)Mn0X=sV2aSG6Yxy4Rlb z%{za|&Fncl>rwql{q>t{f%Dz(d1nW>yZ&0kgaOW(Sv&VV_x6$7fA6Ee@x5oYm+Uu3 z?Yy&{@r>T=KzCS6@A|&S%-J{JnY?$To+_vdj>_Ho4Hw)wqvqD`wy!s9|G)40Kg>(3 A2><{9 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x128_64x128_64x128_64x128/input2.bin deleted file mode 100644 index eead7af3788b2ed2d30c4b2445f9cac2a0d774b8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 32768 zcmZva(UEN-k^?beJX+Q>Ou-aP!IaHpq#NgT>_>)Fs-ocB?lJBh<3Im9#zgfC2DM&2 zsNNab=CDBDT0O0qoA?D62ApX3sO|6cyl>*|jb!%L|NQs7J-wOSsD8o%cUbFRma}H= zHT%ub(>q&E{LAUz)EgbbyqWsG!5!2KruwO!Gdr;3SM*!I@9o<%Yj*LS*L&uh*H7~E z>09?5zSDXqQ~Pb^?r8r|Us#}@(Aj~%u)%!6fCH_))7+h&+j#4v=C>U92^((EALyIg zo9Xur&Rnn_sP}lre1W|gZ$0$%H)`HmZ}w```lL3$;BLLRW1{a~@`JbE@9evurtgdl zxbYKRAAj&8^^Tt^JNk_;r@rL%-fP{gwf&j7zHd>B`|J(yo%XjLZq%N0*3&G%-AZoO zJ!Z!m`u*1C&TRC{r9Lph`M$witkof$x1RXUZM?M{=sWe!SP%RKbwoHmrEnnTdK%uO4de z9xA9?wU63+{hYaE(!R5=jxN-!I|u&Fa^C!ecfY+46EF6bdCpkNCEMnOth*1{0_XHH z!M<<#?yz^$+o<&j6E(BG@e4hmHtTyvW~$A-bLQMe_0soT{F1z_I>g+)Oc@z8WbvM0~P&6%}?r8b&m7S*mK5Sv#Pyqk2z;9_!jFg=FS~xZZ<)6-$wfz z9DI8V@0p8kJ=DLulGQsS*6y9@o4t3dAOFr}zutNK?|i)V1$zrl7|E(j^u5Z26oiWpkdFQQHvZ?31J!d@gz?<1~MlSYOw4{L}On zYTbKO8DtwySZHVZ9{c9*wCC&qv#YubJ>i$P_MW`E^H1q-=zX4>AFRK3+PkpL;r%x8 zX4VH?;o2jcsJXn~FPVMsw%@t$_YM2-=FkoAPQM3#TBGkbNIt2ZecznB^q&3RzcS}+ z=TFYr^USO9edpDO`SIQT`uV-h4m0Acwd&r!ccs^Sq<_#2EA@@`ob|NNU2^P$>OIY@ z#oGRa9h}pfU2vxQc9@mk!I_iUfljdAW~o)rn@e}?_VbpTdY}ue_5Iy9-fV&1_cc4o zFY5e2%{Ex~+m>}Qb%u)`=uYR`fA8h4m6;jcnYZrmsUGSJo+~TaZ;pC{xhfMb4B@P` zEVyv=pLe|P%sU6YnK{+j3rjfTZ13)y)J@->^UXao!K~U@d*&O=#>ZRRH@h%l2mOJ* z<<0sn`JS7exqhVfjQ;zr_NQ6Yy5G0)2UO`k)%?PW+P9u?!1~>T`Y!7W-Ju`J?3-8H z>+E7~gZ-nz1byG4_PcI+JNCx3vYItq2Ber({Ii9PIA3E;LS6)&EbIOChGU=UoU5R z_R{Z8y|d2g|GAgi{J@HLrkQ?za8I%49pj^&?>p=rsM@!FZ-|;%tDb9bSttAM+hm;` zcv)~@gYTPY>2CY>te?-!Ld`~WqhdXy-C=f{=ZrJQxzu{!yUc`qqUL4e`%P-!HNJV~ zXY}o7*UijA?QQ7$?3I&zpiBDOdS7aF!vS|JbVA=$^=`fWH3Z}NdoI6>bx-FVNu`>eq?I4630&hF4$wYSiIUu)I*0SBHZ)4$(=@4tBG zdiK35ncg!KR(QL~oj=K(dHYZ^d+iVE?tgRXzq$Cg6K}ugy6=)Lm~iyDyYFabf5W@G z>O1iEJf|MmFre8@J)_oQ?M&aV9x|yPcyUhj7dGgt&g<7FyQ#Z>QtO@R?_g&Aa`xZ! z%*Q9!zxz4gp0)nMwjV#D?(BWje!k(&;-BW6bH2&B8+!hH>dp+Z{vP%hbcbrDZnS4w zn_Y0Gze8(h7P**B*w(zav-+2jo9SMHpyuwHsJldOw$KA_Uv!}M%{FRZoR^F5`$zJPIzLnY?)J>5x9}G>9B_}> z0MC5)I&U>YiKFH+{47sry^b^xvC!*P!mK^XB`L*%R+{)|qDaQ&0ScX6`ZT*%NQSZ_#(} zHiwhUo>|Y^8{k0XYH9C?Vd#~6Ha_h@a>U&f2r?y{a<<%ontN zyVLhA)_uF`Y;(`r-)Mgu{Rw9B=E>e})R~!D-yPjM%wWNV?+$0I5Bcxh$2)INoyqIX zJ#Rgrx!OAxzL@I|=y!0|y4gv-&;k8k>w0tFJ5jOco}GH2dNX;yl}fhYM4d6Sp0HrY zS4XmA-@h4{-LRq?wKvfbtj!Os=#01C;QNjlvp1?Yo9S<~e;>08>VO5-^{TZ@*x=dz z7S;>wz2C)}`^=gjy!{QHQ*Ub3Y@!R)?(VzQ*GETYF1kUL8TG97_l@msa_N2g0lmjF z%M6ui502{HWS%v8et_?H zu6=dE26x_%H|zX+L(V#5uFjzM{C9qk?d0l;S9{)k!bvvJ2fNfI+|_yCamgn1y?XP3 zI%mB;*-ftYJHBsqzIA8*xyL><(=V9dxvKd|*4ahvn_iyY-*?~7x%Uq04d&LKeeXui z-Dhn+(^H3#+tV+2=a_F#-Rbo=wLZp&nN3*9 zH+n$zoGLfj%2Qd(2K)=!G{!&wM7^sPzpC73&F}@m{swzCH7Sdf!G* zu-Do8n{1;ChUDDDn_GYH(Rc3JhpC!}Wi6?D?;*Re?ZunfH&p7{PID7Ec8)m#`dhRLtfCHU-e{X8f*fV$5`u+BJbF+iqcdrJSkCs}9RJ)yerz2nAv*Zz3>&JQr3aKW8s>Ou88T7Q$h zw=icd6ApN``>nlOUGd%h%?7;-_9iOs=xpDA$TtkofBV+XO|)lMGSyrL^qg}ynfZYY zFYCLRF`J+MNNrZ!wW)oB-kAg5wNQJi_0%)-=BVv`-{joG8T^3}CTg}*>qly{2^UV( z%)0Mt@73h?R~YC__TJ6?huZvxp0l=B-KiHkp}&XqcbD}}|NiZ9pS{jky>pVQW^$;V zRaf*z2RgyI9pAI+gWTQz-Tqo0;BW_GVD$ z?3)cq{hihKEzUUm$p6)(KB0dLvzxh{UUxg={7P=$tbOYn^yc#3Kz%;gU(mk2zW48& zCp{V9+|p;a-ZxBCJlneahBI(L|MsJun}2fW?k96@Wxl;?d)5PPXuYU6c&9z*`!4eV z<|p(`tyTLI+S{q!V`g6+=!v)IoV^+C`M$+jz3Pm8Sc`swXL|nq z9@N)7z2E)hX73DoV1xbpQT+mS##hZxdVQzo?Qa-hU$tKGzVE%wEME1j`6NG}H}5&? zOD^v`wP&0abL)jp*x;OMepNW&UTb-~P=+eW68}#M%rtDp=IDDH>a18F6)u={ zuX*>`8>!8lv;OWgcg|klr{B6 ze^Ad}beXB#ztNZFEzZ0ixoThZ?p7x(;f#LAH*@ZJa{aLn8>XIf{ig14#=b1*xfyR} z??kUTxM!gg%yzO7ubQ{s`o4*GEb0U17y5SV1rz#Kd*)`oTL#|znhDuPJ#(T1b*5bE z1sn9OtBcIpse*ZRf4sR&v%F#8zjxd7P0k%~?)yFO8_3$*pP7a4H|l%t;*0&6IcKal z+I!57wQx_>?1qWH95okbn{VoVlit3)3%$?GGy6|wEuO#VOg-W~v+&aIJJ$KlRBL-Qo*8(vzC%A@CsP+(mAkCKooWsBf|p?~K_p zk2gE8!%FUb?i-)nxeKcOvXi}e&X_whg5DXi=liC;>|gi`=c5~MK4IuNcjH^z+i>7U z^&{RhGu}I$v6cna<~x3%6BZoUa6jJJK|R4n4dxL|L8Sn>8Ic)po+ z@7y2XY>*wWH=`F_;h@&rn!RezF8oRUee+G;Z*guWd-HgAnB7lqy z!g~WV_P&^R_n`OQkeMAdoA|c_wRT5q{e=zs0oM8j&RNR_&&%I>*n2QonCt-Z?Q-%~!nj!Q0zlUd)_TN4#o%VI|WK zVYX~*=o=3Hnt33ckLp?8^E517&3v#9rnzDafdfa=T*3wGiNUKo6TGLJH2_& zz=R!j)|sBY{cZZ4ldK$k?^2y_ZSU>ijRS8zLBBt`@aFfU-DTDt&RL(#jr7|8 z)-%0}F0l9B6@Sp4tJ>e>e}8}bmwdwaoO6Rr_J@f-;U;&7nYHR!eSh25W&@n}Ju~%5 zwo_Mc>h3q6neF#j_nwiyx{~kslio$e+`TLQL?>DX{=$L3Z}xt3`^9%G>ItLjx7MTj z?|auLyZ95lPnu0~8L(ih+;`$9+B5cV@_~xIjhg%Rg-$r2m->yLsb7}wy2X zoBHm-Tzd<@KXtu*z042qjoCHGANf7%oxeF_wqS?;eGaufxv=7onti_4y#4w~w!!*D zZ*-s=-0{x)8-0E!bNdVIJO9;7ZKfVru%YMl6E5^OwUz#1FJ*^xHc0 z&h9wLZ{|ky&E;?8S?6YY?%sIyM(wTmgL;0XemVYvy&dlxoHJW6gI*3e^S(t77wnbS zr&cF=g1-B2{8C}VfeY?7`_kWNlIu76-oWhZp~o}yJhx!N%h2zeF7;04jJT_JjArPK zKk=$_Rr@kwzv?mrx0Aj7`5yPI_R{}m zlT19@xAYs#^y+{UE-EuP_e~SuUFxqUKjEFlU#K5=@r-)mgzo8k=BNMvee0s{nT_iE?ZZ0X*qLU| zzg_70%uT$!EZ(^{PpwMl2ARBh=Jb7|S?h`SoNwFs1wZfanz}luuNm~6>h~@7=4WQ( zPq?VH^Lp=^sk=utSKloCHRq zFS=lV)S2a~W`9!;-vl#(VC)o!accduQHY zRu+B;JvsOboj2Rm3sigh32z1+@xG(|?z1Ng&POla8FTf$m-D|J`KGsrnF-FGsB``9 zk<3~96E^7Gt9}{OV&<8tg8Q7`_zUk%$<1aM(Z0bsGim1Rf|Jbp?cEH zT-?!mP`~fVIeRBF?ic%J6Y9-(YIC{K5p|c|{Dk^R-EVQ%%AEVHop(n1hK;vx_IAf# zYO%lI?aOzVP3XJTo0)-jPv@-VWM(=8o|89Eea+xdmF3=IDcE?%_el;POYAxzu;N@fD_e==Px?J zxxULxPWay6JN@=ee}d{ve@|=u%kUSRlL4JI?{4?nU$9|7>+h~Z?yM>!wKJXD=G68t zsNZ*cr?os!zw^#+Ynai6njO$zn0j}hBi>-Qi{Uzh^F)IPZ*pL=W0M_T=4@zssd| zX2S{tovpLeo|*68_Q>^S?t9;8p8BHwR@K_szGdiPH)H1YqTl_eGpOgRZ}9GB=B?FD z?H-wUYx8#x{=F|feQVWSp6Nc%c)mT=+8J4JRp#EA++%Nm^P+!q(Y~j9?RDN=>rj4Km=b8H34E1dHPU>YAI%j=gK{NeMHdOZa9(8_ZZ~D$SyVE;p ze~Z4OXPsN=Uo~Fc!7~ScgZYF3?=7gA{-OsajIhvVs`Z5#KhO={^N2UQ&^uPV+5Gs8 zw|CThyS}yRj(&&r2J7yAeKOBq6($@Qu;D`YTQ4}tXKM53^LMh}GZ#Osqxuv2j!mtf z(TnQU1t;jgZ{5@r2Hfb|$vd1|WIMTjz=ab!^ZgzCeK!3i?C9?|bH=k5`gVPL-?t6& zvNG?P-eqqmKWOi@KdI%y4*CW5&3=DJ-grK>b26dtf8Xi1keyJ!s4wi`?8FZ+U$BFj z_nXTB-?5@*p5M`d9`MY21K#YrW0}#P=uPjS=Bo6LjhC5hp#43}^zSbF@SDpUt;M%E z_dIp?n;&X>>Ht+POxW?(1I*r?oV7mCJ^ilt?ddzSptGG(`=-v!%wAO7W3Tnbd+wn6 z?>pWsJ^NoKbJ90?&f2riS>M#_kAHj88>k#`PVdfHkJ|H$*#`U944%^K8P$9T-*fQh zW)}ubc;AVdo1J)P_1!s>S9`}uzEOMT;$Ca%9sTCv9PFqwN6$I^f*Z{Ao;7#ZNPW<) zhY8l|dqdRPvpagx0}G~E$zIQz{dVBJb7$^Gd+yECpXR9jX8$Ppf!9a3bE(~L&+MA# zcX;27H(#HgGt$1jo>9e}^P9)p8+dyM+@as8Z`8c+G8-^Iz3-mqnP2ps-BI&{SKVdL zz3L3T$LxUmo%fvdos0g=oY}y)-`$7o^>TJod&a%p{j146Z~ooq8F_hTF4~!a|IVz` z-+T4WG;`*?0lmTg46oPUjLdgA*EcNuNnLNh-@oxA^+CJu_kGEn`|aeuOYJ#x>3j7H zJb&V??VC-r{BG~A^m^X@fT#JEO?}M2JE)hL%Dx$OX5-)cQN7tkozXjg@UMsJn>({W zb;m$=>Tl+}J+lk%eEd#*)c9uoUF=Qx{U)jV+w}eYjXeJ{-mYHi0cyXe>g9|yvkUqS?$ZBue!kJn`<&~J&YOw-Wgj}Pzi^_~ zp6_=|>J6&B6@T>ZgEvpT!CbvzeKLCk%+&WC@$T683HD{c(R0U*zB}XXbw~3}Hlc4C z_?7y7GhTI0-nq{=cur>eZw7Du&i9)?S@)>#US*a5&>3&fTJ;V3_buFg@p6(? z=V!*OZ(sQPne7|iw>XzsXWHwXldNae11pU5Cc5E#a^GtH_T)Evzt3l`JJKLU?_%SD4XAeYQu3K{zxg6QuXwlIj!~H z+V)Rtck~Tc>J~ksPxOqc=~i{GTrdBBQS}wnEy%hvqxh9~L`RT$Grz-``iWlU$@HM- z%WSD(W_Jd@elL&j{WFZPW?S`g%bThB;9bE!y!i$*J>7%M5zfFLK~7EAOxA1*J(%0U z%9zP5mHVUa@o|Lp{p!!)^W8y>wQKHX3s*2(y`7Q1hZ)xAo$-(KXu%wI^)u{&NBOe$ zMwEZ2_8HCHHKTG@*uo5FXhF_SuezsRIkOSW?CGgYi@MvwQ8WF+OHVOVp9O!$`@4~P z4?MWhgY2mK&r5z!ub!oTzLjsOw?O%C^4?g!{)#UT{$BLx2=>3TJ3CLXE89!e>keK2 z7Wnd5pY2ZF?S>Y-F@wD`tedMh+r!_@?6xq<>*c%u27L>&*rQkAyRBDmp@k8=1O1&F zS9DfyuE!R7=Jc%o&r9unyZpWS9)07vVXl|?emC}y%C)Gh`7g|12kx4!durZZ@VC@Q zpk`{jxT=@eo6YRlYhmTLny9DQ=IYk!u%=Ep8GSXk6@2?NA0wrH(T(*m*1k` z(e~M$<*khTEab1^j`s#R-VvVktyFL3-kLA9=hhi|YMB{5qPSCURm@z)V(s4@6H=NqBF>? z+bhev!XD1B?$~o z?3jEE7wwQ<_ z?3mf>!QQIhs(CWzyen$&7q+T-n3-dw_C`;AR!yJuEh_T{xg+$j@q--h%;jXL9Ja8B zE7LwVb9FIIrMh7K;QK0E0{fDefRP* zIL!lZ3v%jF{r%aSAK#8V-ek_`+M9Kw#Z$|w>H52u-@{e<)Qxq^?5;azTY5Y0{DNCY zn3>sGnKif1zYF_jdb_YMBYOsOwb}UeysP|~ygj&OU(L6-1!{+X1)l7PwkXf{@T7*? zyqSj{WbK*HaF+k3K5AwTdHYwX`7g}CbBpKBR<%2O>Q;S^o`JTS`CKw+!OXthQMDQW ziPzLowfE%AF#}JA%KW0e`jtIvc15@1DDTb=&pmm#N6pr4yF7c4#|-voW3#_6TosGqiBi zcxULr9+i=wp%bD5dQSi?w-qfS~E9wUS2yUq99`4N4 zTd`&>{RlF>GEbEMAL;E~!Je$22eNLD^4CnxOi#Zx&%K3xbt|u@KT`91^eW#TX3hCB zZ=iZL+;RgWoM8+69%eEyH*2Y}rc_aG9r(3WytKZc@R&Cxh zQ`@0E__^B0e^QTlf48FcuTr%g88_BfJq3!cnKe}<=Ob6yW>87f0ZkiE(~s^_hY9sV2bl{43yVT zPqwGuqcF#_cB3+_{GU3Wojc6%1{wFfYqs8Z%kCLud-d{r`P7XT)GhRIg%M<--m5l0 zgG>u%N0@=)Bl2-YC3}XUaDR*a~#3^&MlZj?fx6C@_Op;j@%YT<<)M< zz>W6t^max#f|;7<9ryV?T)|8|gZUmcH(PzHnlE=JHv?Zy&HR4Uz1XAYr`C_qg4-19 zmf0S>z5dSRdU|)%6i@1vmsPL18!}rs3OBC6Z^6FW9)G5vx{c!tc{R$=c~rvz2Y>&-6T*5#7TTdeFNge-vl+S9I3S8=d9*y{_6#ckH*Y zZtST?`P9uG?&M|RCeLi9reAMQf0ZZyB;(es+HQ}xa`th@bH|QeO`m8FM{3wxz1m!E zhApgm+Dk7hhgNFl`*2Tf$L?F%JNip651EyHQjd7@{4?4EZx3>2d>);pc4c^PYPlo0 z12bL=tJXvQ4r}i7^RCjfa!)+UnE6Yv%qF zs63T_qkEL6->SC9AK@3|%=CBNpuFXusavr3WTscoU|-z={|F;2Pd}5PJZw5xAg0C9H}A0 zzcY7ZM(so2!x3hX)l+x)zaV!7{`u*>$)j>D`h-Ir?C51MQ}^;VeUMR~!EA(9J;j|p z^Od2zE4qhSx<$Qb&ObitZ64&_^dlzv<)_|@U$g~TI)Xb_xMjmuzWJN}jQaQaH;=Y>cAluc9+lmL zd-fsoW_CvTGu+kB=o8l8s@AnOqCzRMxyl&FEQk{>nV@XZjJ2aHoHR zx!!F3jKA`F`7*M4Dm$a@KhgCwsIPeLot5Fq>gCMcp>O=X>J}Y=&T!YBeLnWE1$(nt zxg&lHbOu>!Uq-LSNbODw@-Uy_&D@Nqw+sGLwXAs$e6;KsT~?e!>MeOKPC9_+fKK0*s})C_iD zjw@_ohPQk>PwE~Wg{++Z3^IBOH~DLJrFP?}d8>YT`d^T<|0avAygPNTURKV$NA2MV zYgYzac#^YkMtk+pT7azjJ5qmYuzH-Tb=~?+RORvqjxClkdUYJ8C(Zo_<8{=nVU3 zwtQYIzenBBKhbq_q&K5~^UCbydvDFo)OABh z-Vu6u>aLmlEqVqy-=pj-<;x#Y`4L+BJIL_t(Ozn2tKMw+yb;V-{z%PVGq>08vc-pt z{K}YL!RMNxS1%)PPJKrA&eU&I*52ynx6tF8-@#1AO#ffz)61@m{aLfWx^~WZa=&N~ z>z<6+QLH!gGkLmx2ln+(l=u5&TIw^b+iLsXS#R0p;|g1lK?{6$`0m@Ap@%nfzWGI~ z@6jh5@#v_GxvU-gXVq)}%^YTXxKl&!&rfaFZ2kOK>aVxZY2nc8`v5R&RbMyGL88Tn}6E z#@ADMHTE!TM&;c#Q@0=k9(|H)QH;{{jvF&xkKW-*{nwY;aeohM{v@YA%A4s&YTwrk z@-1xDo9UN#rIzLY!khgg+Jg+`ol#u8^lK)&^THi}^;>$_T{G{fM|k4vK8RjQ{`u diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input1.bin deleted file mode 100644 index 9971a667f56c049cb68ea6cedb80657098ad78cc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmZ9J(UkSfUz}#BmRnI#ko>Ob;jUI5ORd7bF@u`=6-vS4J2XpU9X8+wwZdULz)ZAKa z)ZLj`)Yl%EsnW3Z%!+5sq~MO<=mFI^{i}J8XBKsXK0C}V)Uc;-?wifctUY7ha6m8F zXBNyn-((xqf;HaSH*VBi)~M$W-W|^L_&xoaT9t-9{++RAN8W#_-Q#R_nt8^Vykk)} zsL2a22h799XGXPeW0ib>JSRPv>n|)Q zb#|C3C^zRo$+SQsb=vTpSQm~^M7^Dyik3E{){(Uutz=He=^TIcffvd?#)iNJU8bz zKEH+ewfHMk=^XZ3e9y_~l$dI!vMzjY7Et!r@ZPXBwCnK-w=v)0bH z3TH6aN55}ra%=l~XYw6p*=t^)I=|5aOSoGK|9bY}Qk$u71~oT3L!-GbevuX2;pKSe zaz1@^fit<=xmS~2$nUym=5Kb$-}^Xc&pERK>jmy;J!&nl9)D1|*3N)u%33C~f2PiF zWUW{8dv9l;!Ct=MQXjL)+|JDGEy&%C*G`s4GR_37ys z9C-K6-PRWx3bxF*FZ>1Mz2GVJ9-lgUzq7|P%-+#L&78^I zi+V$=yvd$37Z&LEczfo*v(#-5T`E}19rcX)@66=yf6jeH{rt>0+wjhL7B02C9sX|P zowHx4*k9=VkloRy=WM>yS}u6rp1gT#dy@S39o%c4JCnOtuU`7W?948A>vN_juhG9} zoi|e#GtPNdpIN=x>*3Ys=idFP^KH(1KH4*ryAJjJWbSa*JTr}7GH-OotU!MTvunor zr4H^X`~iLH>^Jw!MP*s5^KN~?0e$v+-ukaPYj!`gjX$u!T(x#rjW#OpUNg90rdlts z&N*kTZ}RC`=l5Fpg3Rd8)O$4VUDU~*pWo$>J5!_9cl<&d)EeEWndG#r8=pIR za(nWt$-{%hJR4KeHq0W0>Mtvr4)SftJtzOhU3pEq7GrC4SSNP0syyvYi-rByJEblQt z(`$9!u;Ni??5XxIdZt#bCGXHLc<z7a)Z$z;_bl=~eKo(6bJ_RaH8X`asPEpn<4kStxv%Qo;rzlky!%zi zyUeT$oLA+5GmTzY@ZLnG+As6Ze&@3{y?^PP@qF&ulbzYssJ>ti`U}}_tyeenpntnj zdkZd@zk1&4Oi!ME-ejG7*2nMJ(UaMG-_Fj1H!s+5M;F@T-EX}?y`wL)$DSt3+sw-u z-wfu5dQaB+`HqX1^WpWW7reXK;dkkCE;F91&zxuOWR0fYc*$>n@n`CVHn_Vcdw1cT z(O-1Q=bPMR?HSScsIz+WJwD%H{(Z+^zv#Q`qEg@+%=4CHP4?b_>d&8gQ@iWkjsH7u zZ9nJSUu1G&e{|u$dk(dj=e>`fC5T^vv^)i?=?I_t-P< z$?WG_7T$i;UW4yF>UityS!A9s)J!kV+A~kjS}wd9_NXsc1N{TmsU~?^6YR=PyN0R@A-EY^}I9k`#rd)sl`1TU2vx^YjQD@m!W0{m4*%3 zVV3VOf6Q;CFE}%^rp}wJzi+j6_I>x~9nLrT1-)l>YSi9=8ntJAL7%raUQ(+ywf-;f zF8UXC@P2$lLEit}b;$FL=FUI;_&YQE)ALN;;2CTC`CfaT>&cw4c22c#@0d-VEcGRC zFgL5I&t%?Z&)jTJ|LVENyuq1ddun_79?kcpf5{gd)%)3}nr(2Wb>ovAygjucGoBUm zHTg!%`m@tIZ#{U)Oyk{Qtv|59`n_$uxjQ8HI^*mf8hvj=bN|Jk$qRL6!3JyRJfnAS zn}M49F5aGchIbF=oHJ7w+(Ew~JFMGm_IU55E>N92(>L3n%v_wgqw(%3vIRTJ9m$<< zdt{!8np=yri~4|hgLA4q)w$&Mjx#de+F7+Hzo@nP?x64dp4^^fcj2=`b@rM~{`Jo8 zA$|8GJM^ySprGM^GwOwcH8d(4>Vt3i_I%fyBVWv2cyC3sJ2UBJu9-PtUgOs%+tWK} zeph!{_w;JiyS}_P@FuhCc`$Qc-=GhfJCp61v3|QyGg)xQTbF)*moq*4=50^>edhi} z9SYxY&FGt3t7fvn_o)p9cgVfg&e(I$Q9<1wtP9p?qi;9wv!^%PaKUUrjefIDmhW<= z!G3}iS?*NdyQZI>yEfUKyvLiZ@t!q*--I{+W%vW`RyWws8w=lX;Y^-8m-Xbo-;LT? zYwyq(yfeJj>`>dcr#IVZftlLi8EgCU-D$pug-Tm<=FLr3=pDc9L-yD^W^+b0TjHCj z^=1dGRWr#w*_St7dOb4-t*N8A&pKx|{tS(Zy^A{kbn;_1G@7@X*(6>RCir;|EcXc|G+!v{%Dg~zq#=TP42n- z=iHjUGwB~alsb2qS&Mf(@E5$lr+4g+SFO|U=~)+kP3^9HlUe!)e?hfh>UgsTXU(i@ z)Lfr;E_`;ryQZHz_h;VO1NJT|o@vw_+ZwFnb5C;5+Fx)$zu``;%6q%-hrF-J3(iON zvZ3MXd6Swq>FpKh?VosbfwPkSH*06;)kkxObM|E8^H%+V24_|48g*{r%Jti%Ra_cxiZh z=lhpA-d+v%GjquF_VRtsE^BBwkbCq67aHE4sn@7IIp9wH(!X5ZXQ%<&5fUxM% z=j=jr&Y7kzFc*Ek`}xW8hV0Euv}f+1&R*2)&7;=I7rwy0YMnVVYiI4LVNY%Ug0=oT zYklDL<{f42{|c3cGn(4_rKXqfy3}UrIq!YnJnu6*L!k@ay|dfe*+%6;eKNgS!~SUA zvdD7&y?5$D#q6N6z^wJ;Ij3GSciKB(CT41p+1sF3)4TZm7TM8c&MmP2JbSG5*?;hP zqq&&5XQ5)Ies|xg&!}}lYImCDU5(%1z42!<_okp0@}8%VysE;PbXSexqOi zegm?GC7iRr^>D_!r|=gV7VM!^-gHp?8jbI%FDe_nXQ@!o{cjd#D~cPcXTH-}npxX`2DGxqbFE&PV;(mQK@;DX-%M$J^u_eaeO zqWXpnchuTFYce%=6<(|_>V9X<&g45Db>{72#@yPoJ$k19-c4TA*=gpSD$Zu!nfOcI zu;ufvWZCVRni*A`Q7<^>j638^uI}{FvYx*A63#U$cW|cgs<|`EdN32~3kC13sIzAF z+8Tbn@#fqlKAOBIGv9;$LW6#TcRW8k57`>*iP?qsX7um7%*~pfIB%BU$*kalnSC+q z!P*_~j@h@!Hhj-m*W~Jj1!pj8P%n7adJB8y;0rG3v)|fW-dj-T8?`433Xb0QqCGg9 zZ*az*Tv#&S=Zu+p&<5)T$*j!}c+TDi`+C*;UT-q<1Ie7f@Mh5(-;>+d=UsYt7AnaP z-W_7Tr)PelLI3VW3-+guSL4e*^knMMbHC_c@B5yawf~!+?4sVEZ?(VRzy@o3)_3Zf zEcb2d+;7jhsJZo;T=l+2!UI{) zZF^St{%|NIZMPevO6 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/f32_64x64_64x64_64x64_64x64/input2.bin deleted file mode 100644 index cc15142608550db59428bd0584bcb0a235f00fc2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmZ{i(UC=~5d#Uq85w6rD1<^NgyL>Wx^>>V{Lv+iW*9hkcZ@N|I51(uzy34E`u7+E zZLR<1%-FZK-`>HS$-q0)xo__LwGVF&wbmb)dS+DX0nNj)F=7<==)A;`^`7LJ1736x6l!P(FrGYzuA5E%r|t8I;lH1@GD$Y zUY+yKi`h=Ezu#|%edc?IJ7ng}NPSRe>@D#8`%O^2b5HTkPH;{xoai>=KI;{p&l@h8 z_n1{T-n{2#y!zga4mi<`-k_h!7i#UEX74-5ZhC6J&AEQ_oBw?0VNO=&pJ#qkPjkGP z^#c8c0aInyRzVfdtZ3ix)6DoD=^1r`b8i-ZVL@|iIWWL`ROjFC7k&4CzTwT&TRo$v z-`}L?ul>)=PHq17qceEMyn7DbY^$(>nciFm99S^n>Uqmkyz}PjiEh;Ut$UySzTaG& zE7qNH&U5w#T-7YTyASn*9km{*d){oqY98<0{p1@z;WYn0zjwcLF<0+7{Xjjx!EA+o zPiO2OXzr|;b$6=fa-tUb+Pk%q^{zI+WQy1#@NZ)JrdSnwn;CwRYTHBKaJAR_wZB_=}o_e77>diM? z@a#YjEYQn^11oAK)-&E(e^aXm`Zu%Q;2q|@>%EunwJ+B7D|OEdYBP0)?l-$|f_=01 z`=T4{U6?S!dq=+8p7Srmf8X-`pK$Jc&d=7_V?DrKvQw)ox&GD8#9I$IsnrR+TQBRA z?RekXe*aJUE!P=P-IdjWzhHiFueECbz-mqAuJ)T-`|gw3zjfcb>5Frnx1R9sV~4f=qWbqEVu^+HW^6=e?KSX&?R0 zc&_vt2HE~#zA|S|?~a?SfAe4MjDHJeCSHBDxqSY2+w*Miw|}TT>%4xzdj~x;{fu7J zJ#QDjwf&RYoeS@b`TcnF4eE$LQR@k6cP#7Nck#zs$gG{~TQ;?QGyT-V1?&E{$2>FU zum1G^=WgercaGHVlkWJw^D-l5M}-Y5G@I12qbvS>6K}M4!8i53@u|)2J15_B{YLvM z9CX8kH;eXnx$)MW>ppv)G4Eb$^Zm*D#zi*3*@@nG_nGgonxhliJMkA?p=Z5glTA2q z;oIw*|IRJu#52BYpfX|U`?o?{zd34W-ouVeCO=Rf&NZgpZY0& z!hj3SHr_pI8F-mdcgjuXe81;><~(Ch_1zP7?#<(`d6mBTNw(3>d1g{e&%5tBi!O6L zOt|1#z1q7E^#=Q+(0})QzR5ZJs@T7%xU2ij?Qe5faN)qK^KSE<{QXY&37-Gr%!LtM zt#i+z-k^GJpz?NRW>IfA@P3>A-hN;5YY!%PcB5t|EY!2s6Ra(hisGd2fbguVWuk`HS%?#&wcvy$*d=gsB_k;nfk!rPwg&izelfHA8=m(`sjci>>a$l37#9M znV89j-aS$~`|Uf^yX`G%`>H*$w&$GwLigGqsoy&?XJ+rDS3jxUBMX0Fz=5~Znx5D* ze?8Pp-`skq|IXB--Q}!)q_@yFn>)-O)Nc;;ZJWG1W@f6M*E{p|Cv%?b_boELbNZ+B zC;r0ARI*p2hkB%E{~zb=xeuMenfF%C>dl?GDp>2?qwZt_U0~MxU&fwyC%HN^vr)O= zy!irW4mwirWZhwJf_?F<`ZCULe1D(UBYSh?_71eZc=w94?+)iv_q=mEGy3Fp)c2L?2osqL8`csjJ`MF5A++H>vt`@GiDdmQL@L}Y{HIz zd(-b54!!0Je>2;BelqX3zdl&Yh659ZF#GqnGxsiMPSE$eu6^)a-(+U3p77otKT!MK zF{#Cwt%Ci7zM1HErr*_jE;F8UhrJzdRzL7!-MyY)$?Q$&JFY%;)pLXFzyxb`!-W&= zyB4*)`*_>HOMj1zzsXL#*-X~o!EB}9v-U=E``*#_JxyPKGUwTemzzv~V1U_K1&ntR7a<-moWd%r>Yee0w@ zGPC0CdF~{isJ!1K@9Vdy_E!3xJJdty>&>k%EI81!FJrISgpuBRcWN_x+nSxy-_(3T z)%;M)h8s0s_^SP~@k7{qQR@XWzHha!m-mj$bmvHI-}BZN^zRI6cED`I+r#{gH#=ZG z)rMim|I^|%$%91x%CMb z?RotM^}vD&?@ZpdsV}JS9J=CX)L#3}*mp2a@CilF0Go!B5_D+1~C$;(b%-m$ocK7p{+3EM()4bPr_q_RHcII5~ z9I2h#(BIa+z0OTN*mu|c)XtgpZhM{Se0%QDtDW2Uep}~;S?Ijkh4$VowK~yb55Lh1 zW(#`W+FbMt>}_;>IQr!Jp7)KdcV-qk^>D#juTGUS>f7OO%HC@ZdUflgde2y|cBIQ*57v=YUk7y^vx%}Z&gQn7nR=M-X>e0o^@xOm+v=CYFTi= z*bJ=1Sh#q;X>uK4{&7qvStI$)}wciS_Q z{lR=t-{2YhFUPOsW(Ta_J9y8;J7=$NH9OXzAF%7)$GN`6nFDIGnR-O+ZGFx?pZcN) ze2;29Q|o0%zcamS(R1%a2bi0w((nA6_55bGGZ*iib@QjmyvLaXZ#O!@e8t*}HGWFZr$;_P>vl|9_V1YAR@An9PCLe12 zFQ?wg-dpe!?6+3U4s`#(JM*iVF>~fZ{nFEOhq?9qDsQwtA2wc^jd=Imk9VhY3!bve z>`-^sTJLY+%uKIuGQWRv&pSJkyQ9D7+eh9r&K+>}&FOQex#!Gh`WM|`eP9UvW>tF= zdiUGe`qTU7o{<&x{Cjt@?z*3ufwy+go5!opzC67ZHEZtfNo`LxR}Z{*v18%;8>pMw z-OfyVQS%!{@}uVOv+xt#tKZOGzpGjGe&$!^)d2^#K5x;x&)$LqGip7AS@k<-Z$k6_ zZf5%Cn_Alc?)YW&7xWvS<~QrNs*@S9e^D8*;pmf3{Iv$X`9_x>>bu)_)W2Tta&}4H zVtrFr5A}voWtaQRJ-0qR^WLR;ZracMh&R7rulmkWyYryk<;*DAZ-0S0)4OO{`0g9{ z?mT)>Hw@@EsJ_#Fe`i&kHNT*rV1DY+70h3sdVud}?H&5o7k^;G_+Zvqd(!vWyQ!Pm zf8Xmjr1up6W}^q)p!&|f%{|T48=RZq?1exeTi{GN)Slb+@>~Cu?51z#taAfi9zRo?xwqf4s1MkGIqIH)Zn)9DMLlHp zWrr1==x=wwnaqCo+P}#L+IchUs`D$^K{u!q`quhI=KM`xeS7`ZWD|XF$oFq*XWp*N zdCoaAz4PtAd9wd`KiL8MD?0R?HS@griB|`hsn)%(^(EWTZ(4ZoZ8ogIbMJll8$5IH z&My_zAN2bTle~BT?xKImov|+moLlGyXZ2>^-X$CGG(COugYSFw6W%WN>1CL~)w84T zv$w3MURVW)2HUTN() z{eS}#e5?6;bLKa>Gxk-vu)#Uiz3TJHhdEq0;k^Mr!F#p}>VXN}srDVteCOR`=9&G` ze&0Kn+_?*$cYZ+i<<#nix0ksYZ|*M7>dnR4S$ms$fc~IjPbR*3_k4Hhof+Y(IcNUv zN$;Ry-(2rb*~!~CyJo?<=gh3Tr&y+p~+>p1aL{GtNQhE`Gy=`hh>G7wX>c@8Hgt)0^O~i^>kp=nwpz zH=E!&)&4-u&HA=Y-QRM3ytA#J=lnHu59fb9{K>hA%C?R+Thwx)=Xd->4_N!IewW@^ zaYl7cZLjA%+qcWm@3i&~b;E=M&8*8c2la%` zZEAbwE0|e3qgoH4)=yZeUmkCMpx>zaUh@m;3ijUjqW0XaI$u51&!f(^{_d^sxykHy z59kM2yWd%T)&2<=U9hA218+Y%!bR=vsQG~#tbI$rr@1rkvVJ@JKaJY!{ElC!xq4xP z=LYKj2mXL(+pn5CKkb9+8S`SjspZ0e8E8Ci3(Vau`uAJ73l8;!3->)C;WN_n>C|4l~b*Gph4X>6yt11Kr6MD%StV{GQK@yG~|RXH~h;iS};a j*7w=p={X}Ss^8%3O?FUuzg6}uYWpuohtPAze9!*^snk5= diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/golden.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/golden.bin deleted file mode 100644 index 97451bdc6329cf51de745891b11dc9f64358eea1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2048 zcmXw)368`t2n78!E6N=|il=W?GplG48*?-bcKiZ5iOvz*m2*z;@I>cD>%n>-O7wA~zq3A(`si=t^kqHqc(8x@9mpp7AKD*h z7n;1|PfvHfwZ=EK+{y@eF1>90*m0_j=uYhLiEg7g&@kJ#jDK(@iU$k4j;|gvHoFgB zdwq4Y-jaP=JY(TA9ARljW)@WOL)M(irWweAxc0ow?nDaKDPJ-=(NSt7k$_?Ra^!2A zjq@F(N`t!=dn&T*!*P$fE3}YnX{W z7fhUX)y&9FvAa@XbWXf#4?aA~W!I*^@&h`&TXS(+iExaKx!RX2(e2wXI*cyIsScLb zK`U84V$ro!k`BSrtm!B;x`&M?S~Fu+o|})Zs_PkmxRv;A9j+JvE&mtoZr`==pJZJVWaK?JLe@kzslpE zJBkSQZh9)x6MJrtA~WP>Yy6(4u!)4(0fqhT-xEVuit;4d_S4Ju1Dr~ zW_N7Q%?9EI+Z_+5711Rh&oO;=S_R^@ZH*^L<}G*$4AuR}q#` zd}5IoU)Um1>s{P(vD%opS-InhVKcg99j18?GZ}10-;cau*)y#%>F%u9c`}}sc*&7Y zM@ia@n-~=LmyCme!Mx>z|8jX7Wjz|_O>3F&Ip{KynO@KniCACxFIRl&SL~r>zW%D# zv*nGGx47-K*r2fFFEjh>OOxOkG&3|m%04T-N8c?ab}$^bf*7e%4#}o)o!`V7ef0! z7{w#e`9~$$*!Dny< z#av%MI@e|4Ssk-0v)QLUt}>$kT@-bcGq33}wCWj~v!bT*(6wXZmtxhYM|6%S9YnXb zJG0TVu2SoD3hV}_ho0kWZ@IAqVHo|E2gP@C7pdLfyjWX@`oUT^@3x!fPUp@FoGST^ zEW*Lk58{r;7|==fM@_qR(`u%6vR$p!NbSBJu`i~%8ueu7yUJp&8>eFO%!-+>9O#|y z$SjpExBUjR%lO5cVs!3_i>G)Pt}D&f>CHbY`V7ZZtZQSn9qRJC9rZZ5Of9&(PxOx- z7;Efl>{Qt?d?$alsTty{*s6gk5xi-r;^(BPRN()_S1Bg{GPZ*Umi448{MKiXt5g)d%9Q^yL#ojSU9w5 zvFw%_8$6FPccS+s7D?7oyZ(6jwZM=f9=WRj^8P!ODx$Zrb#Ko9j`j4K9m|%dyne%cMy;&LOAfTDR&}>pUmneAi1Luh{#!eD z0<~1e+}>%~P7)(4kW$BtR9H;e#V;0FiiYXB@7!hdAH&nT`4*#EXS7>x8o<#}qR;;? q%h|^p&w2PRqtq*CQ8;gP4cm^wDBqnn{Y7_OaL)Nuqn;cnR)qgM4VMQ1 diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/half_16x64_16x128_16x128_16x64/input2.bin deleted file mode 100644 index 137501d2c4672129b998c65b60df395469ab8335..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4096 zcmYM#Tdv|j5ChOGj1LrXfn~Au@v#$)B9cybJ#1H1INrbS?|VFReUDe}=kf09)z#zp z9DmvIWB1AK=ZH^iKgTQ2r)Pf8*x3DZ-s^Gj7SEY@-#i~#z3cOgx9D=;`>)1;fXD7u zY~P7+Zoim=B)fmPKDB)04aP^`V9I;0_xyj~`p;uoKe0{boWOdzcTEepgFNpzVLhII z*>Tb8RWE$zA9Oi=4-69FthK@2gX}Pe(Hc9lzGe9 zV#!Gb@2exG_&^{EZS3W=w%%6Qxs}lyF7HD{pu`p z*RjQ2=kU`H>CxD!BPD3QcRp&s);a%msL4*|c&&yMik5j#mfVA(Ae@y0v$(?Mr1+E8J$eji(kE*dRvrE2km!DTlpmCe2* zblMl}yv}LAi51`OGFh$8PR#&L<|W;dD4PxSQHM@L_*wD8VKYLDA)b}|Q;@Q>T347| zJvMFBep#wdaEV84y$T`St&1v8Ia%2eD_OHU^A3ypJf*C->)c{f*@+Mh(`chNW(iQ8lP&pOV2}OBuIvCbCNv*zY^a|5mH$8%alvW#T3l z`_0QH&Nor?NCt=8nqvX`2OU0<%h|lPMK-xBwdlaFNG&1mnm%P zO%eNdpF&3sR_WUK=;~m}ik1dEBDpl!rwe< zYwG%asLj}nH^!o{5@%Q_VGm;}tRh&kx+~dU$B`zgdXtv7;vM*wrg~=P6?+r>8d#y{uO9wDTJqj;W?P z)Y(^E=d7lSjpN0$8B|0iu-(7pjY0(T&3Zl%v$nS{=3V!FO^1PEJZTY)l4bhz|1WR& zR*Us=qRBK3?#{wbG#ZL^cUW;|kD~ZeaVyiZ`BB?ypB2aV!f#e-o%-0|08IW9rlZ4O?4O5o?Tb#tmL7MipziXczk1Vsp5n#>4aRYySluo zulg|PFKBcG-gMr~ZUXqP!#mZn;?eIzRw7eB2;)1YShpv-W9h=& lKXzxVT`-i*&+c>;bsZY)*{7PPoUzGj>hWCvta6Y%{{e-#jFe&kd+px4>%8{6zjfzTMaS@BF>a+~<70#>#EJ z@8{|r_?vsL@pWDE9&6@d`MZp@_pCV0J@>D3{Wl}Wv1QF~R@`=<89U6J z_q@+z^}0O$4Yu>Vt}|ji+7!wj*($&WSf@gf5(VC??S-e>$` z;WT&N%O}r#7j8!}>)vd*;F7*;jh*eIj=EoqQ zTh}taaAY^=csS9k1Qw@mHEPF|HOEDCTaS>N(pIM6^n5fua6(%b*_@0+o zsw+yuL$#>K;a`{aO&?+3cw_KGjAwSuqb%5}K@|yEQK*S(5koR6ntF&;sv+F<3bXVL z-ft5pY%zgx{oM{seHdZuJ}cxxgm4yvs6TU;Dq*`!4o)*dKWWU4g?CDm<>~=GaUaJ{ zYO?24qo)~D{V9e%+_S^NW`|3MROdQo^b8R|p01%%`{GYKvQam|)qY0Z+bmdRcjR3T z?Ks#@*J*+sY^ub1CLW5mE`gDd>Pa{;mzw1b9yW~FUmlAu>PlnS#(I9CV^5ZzD_Zx} zvn)7Gq+~v-r@S2I4p!OGW!X88g>vCem}06Lrar1qCD6d0E}`QirsBq~WTzD+O?~Et zODol3df;1*(+LD}DZ_PJ%GxguinMMa%HO=5b+9c~J&@|G7d72ba9NTadOHk?_Lcev z?bTy-0jt@fS^0HEtLha~yg`639_d!i!YwPl(=7{~L7nqbFtlk_PL0C`7aJ_o6$Bfz zY0VX`D1ri2suiiqp3xOW@MT~4E$vc*z6MWmI+>Y8Fhi!TYZXpo3oEj(&!%YD_{K$A*`iiCu&_vlG^3Y%2k&9V zA;gF1>x3%8>8vVtI@OQ8>SWc0b+Q?f_9=L1x9&s&?curHe#7bq!jgKS=RmAHa&!iQTNIo!F?G>iCH{XJs0b z$^7^(vryr%zAt}bid(wHgKP2d3r`Ae1#Zm~GsVa~g>3%WlPfj1pY>unFFjF8S&`%I z$tyzFS%hIe!2EkgBO-O-=!~+!YCFNbyfY4cEsfbEP|ek zFeG2P9p9aR<%=~j*HP3^owBn~U4~bCw0&yRC2NqY>v9`&SfU_T^>h)-49_yzEgow6 z^vk}Aq8gY&D~Ga8(u+N&XvdwrAA^?i4;Qvajv>Q+6~3$aQcEI4smFYdVJ)p@0USeH`g zRzqIrNcVh*MW%eRA=5pLtDCguZ+*&n074eAuKub99(4#dpjKCQZhq@P_0FEhXoROu z*gCA0>W^ieh)URE9lA1)Gd@2v$4IrTnqf(wIP6rVlC~Zrs=8C!#5*5OhjE~@c5cK- z>Q{w)>i2BcDfqHau+e!`ro|}&ndc3+J@cqxTtXKAI4pK*bCFGdy!PGr%|7*fCsJ(?1=Ts5qyN)Iz zI@BxeZWLq$md+!zM-4?H^QuMOLe5vcT8Hr+a@gaAb+#$abAsrok*Y=Fw5FQnA4ja? zzMS$dt2$0`rcI2{HHCOysIH(kTPNupp!}=Pa_q#2Tek3t19vb^3%G<pc~*Ct|tnOd*r< z?6*z5NtKDyb;>kX9&u$oTk3^ooNP9wP=M8Tc6u~p>ojmeDZ}_sk5sOM^WtRWP1E*$ zv+-LUr^NJEpH;2(ExoE!1KV~YdZ-pf<$Pvl{R5kLrvtxWU{6N*lu@e1yDn*$Wl2@I zOOyRajGjejw6h=eIrTWXPzpr25`oSmGtjgP>!j>@;dT4ffQs4lyH001Gu4HiJN?Kr zMbyc{A;-Rn2V(4`rSvWO)XIakDcu@~?IdbIPZ7VUT)B%$PaLf3rvB^K)2aBybt zyy+3+xPyu%eAkz0obG5$ws_AnO*ZX%BUz7p6DC!l*@t`Z{JSw_uIIqjIl{9XDk^jyCs=HrV z#H}{kFJn)SNGsKqRWzjY*Nw2y8A7HYms@{Ds~_kVc1W=rmoX^JrVVk~zte1IAS!J( z-q;h9{7{7|!AsqxegL_sVUwkpQ{i&!u1u*FT$IQ9C;miXET!uCyx_Nf3T4QRd#Me( zj^9U}3arN@>v|JQEMkv!9i4>``TeJSngdgBp2Y=WX9x3T61zPgnJFtYiWmBJKBjwk z#Q(o)A;Y*?EXl53hh6(s*6FkwmL;9iNs%6OU^&#c?7#IV5#g8`>S$_^B0HnVubP$N G=k7nj4WFU_ diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/i16_64x64_64x64_64x64_64x64/input1.bin deleted file mode 100644 index db129a74126eddc9ef4ce56295a8e46703c2317a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8192 zcmYM$36df)5JS-hnq}^P;qj|K8iHt`>|7+JZ2I+c|M~Ot`FWks+xhl5pTFnR>wUdG z*ZS?OKF^=E&)JhD{`b6I=ks|oALspft+s#vpMN{?uYVQpSEu)te>=4v*EiYs^K?5E zKIab#St`R;Co8@FT&us=PN-$wP8HGOAE*2Myz5XsSU(pEyY1*MN}kS8{+=h1@9Uc? z>WYNvbv-@LlW2N~exFs}#Ly=+BB(|Q6`toaeIF;@lU3^~)E6@-=s{UH?V@F9ZgRR^i7$Iav(IA;m{vfP*J@4_N%tD z(+uJKs*OJtoHEYHXDT|i^RHbwv~4EM2gcNw$!j#BLVj{Cx)E1%vB+3vexLXZ;I^BNY6MdKgY~FbD62w(G4F-WfClyP{}D3rIzY?SAzQ zySPzUeVjK#isek`s)q$PnZa$+Va+7)myZ6 zjNy@ymrV^8ndNe-=wi~0ZeeXlC3WEImj`=2^B>fy1u-R3Q#Zcj26*yilc63JsXqVd z{?)7L>qj{qV5EjjE1}fW9>hH4NjJ6l9&=CpzgrSRh&`v0shSH!>NclTP$hNER3{O7 zLENnqe>#dZz2;WE7&t$&(L7IA(@q%u{X^XVS5#N6&h8cjW=T?fvJce7--mMlJ>MRCQh~STl%1+adtW}Zw-*4EuXRwBj>uBiHj+eQ0=4qQ3MCKh`sGI`$ za5Dd>C|)Ru>zZ)v2HzRogkRHKYtQ%*r$ zJ4>*RbxxTpRZVFpf+l$60ClR+X;S6;!V9h@7@DwK>+GrL9OfQVg;t&{)$KV!&RgA@ zBD{8Rrw2unF-7bB z$@JRkQ&FrQ-AtW3Q}f*H%1}`iyzt8nx|GYoddUo-D)z0)$+_+-8g-KV63C=cJ5C;L z_$1A`Eom*zllq-79k|C_`c3d1w&~Vcg`J+cM6R2*DVM1#44#^N{)G-(pL)rIk56!D z-n!;T?4WU{;ie~zu}NJs=bmm0=+i_0xSNkoyA@^l%iC%<^%$EtEY+hcb zL$1bxO5MUa#P`(B6%bC7roJ4rOj*jLYG;cMu%`&!aj_>}R`V~9QY!08O;du5;&wVa zGGOhlu&0W>W)A_xT!mjJOlSQ%v$E2|0uFk|F%PCV z#!XvK&ZCMo`Qqhv&vh$xpWwcG3}dKigC~sP>%^F;opAevRqbvle(bKpn8(D^$g}=d za}Tx@A;_Ozl)3YuDcsc1q_h0hGc z(TWFQ?v~5P{6wQziW1Kuy7L|lt0vYtt7a(S$a^_G27JsLp`niTX5W3Ylaey(mB^9qjCkxf6J@^V;r2QTEZo%!`p30JE5 z&R2TI4T64Ux=HP%)EMM-nsF;0{lalax!&_9_vog=ZUBApY7S}G?ZpF7cdm56JB&?( zBKbr2ooOD?PZzt?RjYeCz5B(pLoI9YbWUW2j=MTbRMZ_qpL)>+Uvp@$cc^w=n~934 zQy-^NE==7js;O>z6jqZLu@%P;o*bZYUT z9~LHrPcv%nwA0@y&2PTtu1+;n-D@gQ-Pue7xSO>o(`a@ORv&|$oSEC@eR*77m)qrexnBM*pYPlA{JziH z-{pSZpWmn3iF%&5`-$`XIJ>v=?9}_j%kVj0Cm!FJ_t|-zwa+)d)1uV6o@ak$Q2+Xk zyPwtPcg-0)BH$ALIe%3Z6VlsxQ%(NYiI(|x-k`SEn|J45=kNVIt5yctoVkDZROn6I z=X_<r*Y zfz#!K5(ij^O3r-4W$LGoU2MmzJ$c~b`#!Sh-CADesrT7UOH9-~gehA8O-U{)S5Bp& zit(mq6DMB(@-!nD)R%podfyz=flIS6VxmnyvRb35z!J6uN9DmTLl%9Dz2Xs^Fo_=t7PFc*SJeEQx zg7flg-TMnCo?-RbovNab^*jNmxb$x}bfCIDweql)`Yx|JFy%cb>&J=k%Z#PDA~emV zb17-n6ThjO?|5ZBrjYv0lf1*76YDRpbgBvAE?QR;D>O^FTy5VjJocI#h%sx;X}*~| zgG^Op5{Be;;`6lReSg7n0o{cH#p{(|FRPmKQrdTS!PmR!K2EBd z$E}H?G}p<6k@$dpy67}N!uuFj*j%VzPuQe*E2#iSH?Wh~&v9`v zPrxe^j-bZ;!Tg_CY7 zyVSsG9ph6cR>K~uGE9x~$VjQ~fmDk-&$3h}A43|;cE$0rOftw(Rtj^hxf126Y}WHV z6=Km7ws5Y3=bxW>%tvf_{L0}Adwr_IBkAr$vq6nc`1MKgW>j~O#%s(&+JttlOmPIG zo;ajAi}^5G^{#JNc_pXEPHqo3&LCIAdO5j5^?Z>3)a>6QUgKT+)|vrsa$YTK@cR`b zs{2Sa?{S`|Xh=`|hk{O4y6c+RdV_#I@~I?`C(iILwzXNU_75;P)6W(Ywm+K)GWr0?8b!eqN9z^UVPcDPFYW}>VBK%c*D9AKNXwBZgcq3&@_iTwZf zi|#|Z*vr!*xtNBk^AGNto$kpWkJzDne%5t9h?ODKA!pG9yD_9@&V#CbCz=3U>I63R z-4~w1!w+&*SSYd?@1Rtf z-a51n_4wxnw@u7+TR_u%Vl+QB+j7W;QF|Q1!8xT{M33nwkKE=B`%da0wXo|XRdn5r z9`5+lt86^o8CiVfcS7}U*#8lU=3E+&B!W zD%~BqS9cuMH1n|R)Q=+-rk#ALVNF#i={TL5j1-n1LOom21BR^OZ<@iHe?9TJ9-0NG zQSu_#U6}qA6?0Yt@8R6 zg9DT0-J9H0#v+bsB(GCA=q=9WE{7h(;=-x&dX7De`HJt(kP9-W;P)$5uQI5hCrs)R zw=pj#ymg!+c$cwr^&+2M^w|ujK{sf#QBK`C-6Z(sRasR1q0B+^NbZGDw^p0Taq7H0 zFr|qu@@{;1x36y*^x~Q3DaU~r^yJ~5Y?(#JUI)yGj zTE|fz>Nac6;!}PR`IlOD;tUUI2xpb?lbdqMRPVgny91lX+|q1urZ}CAYw>u3C10m? zGZc$?H9vRr?Do#JVZx%?<=ND!obs5@q3#$+c%>)(#zFJq9TVa?KJ7foRR_A@eRWlJ z!k0(pkPuag|MeQjbwiJo;4aUy!YAYIj%JSf%h!CQI7D=^$7N$%d%VtJxvsm)E@rwxoKDTU)eA>7VK8?O!9oc$5`u0Q8pT|D;+*ZAvz=1M zgu&{G1DbZ;lL2}>=QUaM);%Pe^VBX&w{yPDb5^J=N;W+6VRCxkby0nYvahZ{%9uP(8G%shju9Asgqk*CaTJ zKbfG<3C-E~D5lyX=iM<-j}+Ri)6G&PxVjf|lDF=+d=_u`vfd4Wag$RuNWF`XCpuxo fU)^jpUhU7hr8&>B9rxOp}a^7%w1wK0J`;DI-we;2Vha*o;eR$%aEqLx;&8a_a7y-52L4i zczo)4$}=qQ?S@w0+FsoC!3g)}aEkV;?v9&IK6&7WXJ?#nFmW_T9?gU^E+_rfgTW#1 zF4|q+SxqlpClAd5-l^9@X7((5zQZPk!j-U9In?2ZuU5Z+$0TVC9_N`tqn(O%Auc@z*}_>4Q`6_W-9R z=;{}A%_hgUy1sbi^_sCc)!G3qZ?$&8t6XzB`GbYy38y>{=BeH(A1547?F+Y%Q*&=W zdDPFNT7BN~a_&4dtCKIMfNL+k`_1r+BQ1S2c;Q3fhfmvFc+_zBec|k--N3`81zUvC z>s#f|36log>caELpPIMX)23&ha(ZFfky;_AjtgGC93FXlap%cf-CMZx;4E(z!J{5` z96Z3r;iOl;cl&Y*oG^Gc=dRCQ1#Tg4jx>1mjc~~6X~zA|c20|qIv#wuoW6yeIr8m} zj|PlBA(tk;z7;NaI=Jw7*XLBD6$TClPg>sc-mu>EoO1QZd8f1e?M2N!ZMtbUXI{fL zQ}e6muRY|=DW`6Jp89v>q{}x=-#{JSTdi-4r@7qI*5h5DJDt9h+5G2<>M7H_l2^NL5!@1);& z7<0&*j|RTDn=>9*^O*$}7hZKe>hY+XRW409`0IbGDld(BI6>A!iqO`#ySc@T>+a{H8WfU%wnJ?|ygnwbSbTu4$MPFI_y|?F|?2 z?IOSWaLVJPlYe;gsl&bT)WCIJyT2F;@a#@Y&2GHY zpoa^NPvJYk^qqNu(Vup_bkm|?Zup(E{pW!u9;aS+TKw~6-rjWjeZa!t-hQ5&gBJgA z@b%!3*Qywier*ALTP-1DN3wi@mIw?aVtap7e)pKDu!7G@NR3&b|?DA%`1R81D@jyh1)+ zd*;Kbzd7aH%j-`|j{oYo@}(y4TeQD8^W)vO4j(5T+_3RD^Up&wxSO{+4z;}bZ=|}M z8V2N=nKyuj5=Ok7Sp}{X28KC|8c?K6U_jp zM=hM3uoL-edfGJ(+I>HFczJyG;P=Zf3?KU2g`6JPywVr#10M`d`F6$+ejHA9zU1@m z#G@}Bbv|&WnP)lPcvkbyOFFd7L#sJq1#F)9qy>Xh-kh}BjgvoEobd7Lp-&gaz1j4c z!CSt+ZMf;+ql;5cZ14Kj4=3F9;n?}}gu{&at|zQ~J8|-^$JuVY)p(*O)M(J%-z$E3 zJt|#yGaoQ6tJ$Z6g=TBWUbq!*rE z-@19?;L+^m+oN3HUq9|<)2j~ar0EV%n{K|Yys-F$ySdsGet8^fG}Yzk2-xm$YH|2B z=2YiFJv=`ALS0|q*IgY}&QA0@;pO9kHJj-B%E80c=kB|O-#zymYQN^>2}X_wxx6^( z!MF=L=9Z_`zh8HMTbeW9dd#S=nY^ph#lf!}yt^FT_T$vwTxtbvb-iZs?t6Ie&UmER zoceI;?cWL=e({?XCJs6Lc>9f{n}>D=*Bs66-S38;e0}DxmySFR7#i-p(str;=haNk zW>7b)KD_NpkKYF$@cNrGe)xRJ@gdrYT$*(BJN4KrPIox-G?Se0hUt{UkAwee`;BP8 zo2Qw=)$8q!&)uxemPYf?u>-T(hc~Q%q2b;>U4zS3OOh zmZy2;^}zK_^|ezu+`QN$T)O$;T~5fsi8T13gGbJMcyRKs1}~p?Sa}%yYTosGyQi5a zv#1I8JoMYeD=+XZCeA$Y)*Nx>#V&F68;&cyI&Nn^d4z*0@_}E@UIZL&^)%PJ@6)${ zqvH*$7go5N$1LqaFJ3$8_mi)1PC2|zSpM;+t0wwJbkfL6eD%_8zs-XmZ+_kJi0!AP zzyE(|F7w0MA)n!3)p(&HakaEc&A=G>T3A(({*n?bvetqio*{S*R=1fZ-zEH!}u2##=@zO~*E_2xtE%)j? zxYy$?;ArA(AM(v$Zae~x51w(!c{}mQ@c~y|&A)AVy#3a^`<>UXOltxn}at zo4xX$@A}m1PZPhpemcwB-pr&=4R8AG2(CTCdBX@E`0S)%&gRnRt`BZ|&Ed^|e)HkG zG^d=v<%Dl8-f;-J_4DIzShIM;n2{FT?qg?fRyyv@sRvK<$ivWWes^5;$h+$iG(~&i z$)3II<*(nXdp^^y2cMlfn~9&a_2I;&9)?zTwSG%@-0?RnT)zYQW~0$Aa($ygR($0ghI8J$xG$8@BMovr>|4JeB{Mg%jvBj#%TtA@TlYPHWM#2@+9x=#FIZd?)cgbj9i>B?(Lh7 zH=H{kIE8uT^rxwx=lt4(8myDHQ;%8TVcHMg@VNEyDktFO;GH5baxmU`2uC*@UUyu4 zd!&b(-}L4BO~nm^vzoJ++(rF)$d8k6 z_58ONx^klXYA=W1q&`0Qq1j&CtHboW(C?i;?{>#eGdcY}@&})<{HN>f?&O&VT=q)K z8MilI%|}B{tan`go>lYaQ^3)~DUY9SC+_@)@3-Dg^!es^>_+(g>QxhJcyZ;=i38qm z5{3pZ_~e__ohMv)zr*BD4L3b_Z+HCk;Dk4yFno!;s^f6db*k}&&z&b6b{Iczg?@9X z%XP;UKfdsI-Sgs(14g*#$@i=;|2T2Lnv1{saC-9q!?(L$-r?FW4f$~L@YfyW*V&AD zaJGM3IQW&vrACL()vLc<>E(@v8F09(_w8VC$iw3IhSjt8c=YIp%^Q4Na=7TkVMcYK z&)M%1UbIV^zL6f7yy(N(OnUIpFnfCPdi-8-_HRwj3^?_rDd&!#E-o5*cTZDJ;LyjT zu#dhA|LqMgecbx=(B{=$-roehJEwjcxYJSB>r}_rca($2C+G@!Z#Z+o@S|5P96nm| zH1s$Hee*X9jL@Uz#ECl|T6)cuMtXGYS+9U~!sSaYe!RP1HShHDjJMh1bHUQDXZfW3aCU-#>jt)*Z`*5c1cfd2eJZ?2QdYwX_`Fs~WxW#`v%AVH% diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input1.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input1.bin deleted file mode 100644 index dc5405e91e2935201d64ae8ef360a5c1ddd251dc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmZwFS!zQu5JXXDbh7^qM+bbN5cv=-b=TC-em)+L&vV|-`F-7=r+%OF^{Qqt-aqGj zoMYbI90wo9`|TVU)xptE(>otH?+#r!?>zX^QSr>F-`Xko#yXdR%G|bA{%sj4HZyE67!|TOs7hK0o zZI0*OUgd>r7rp%G!K&S1^DKKl@$&1Nl$Qn{{c?HJQF--U)ZX;+EC)^p&K##5%M{+* zY<=>Ur9FAk#ml!G<);&d25oh@u&XPZc^PS%m&?q%Z>5^?$)h=}9<642H1V8!-ac@7 z!gSh!N4xFb<}}I#*Y|KQlX*GJ7gH`6`1aAai@sia?GE*6NdEf{n9 zuyA^5mRHq-xo2gi-!3X_?=qB+Ufy`-`O20r4R<~BzFSytn7%_gdW(hUqwk(aI{Ib~+Xp6Xb#djA zS)WcB_4vb>wX0rwa)i@!`0Hh2g(xm6E&l@Ium~_0&+9@rbW!Ljo;rZsLhpXzlw@(`Sc#iqv^l0HY<-m217bh)# zcuu+8alG;S4!(iA>Xb{hJ6=Bfo%(j>d*7`3>noo(Oj%`$uR3P!01Km*!^}y?z1`CL zuNS@@(=5woH48`6?tPOodgGSA-wWP7T{SNrz1`bGy|_okIjI9^U(=KVilyy@WgKY?9;GgxOc;I>=UoL-uF&~(>NL*)~OkF&bX z=$qkm&qH5!?|us!ae2eEmkP_<=^LmHkFaTa({{=Tmk!=)nKf5scIY@XaO*YKr<2EK zqfhCMUeg@76C{zezm$dh{H)uzQc!H<6{?ZoEES{WS5t)w0p>hV5IL>Gj>? z!Q$(~y2I8}dBIovcJknqMa2!5hhDh$Q+d*IH>Zc&j_sQdy?A==@M^ie^Ve$zkFVfJ8w;qA0*S@`ManR7rz~8z&S8!$A!~NSLN4fUIxDMPSxFvmL8s7yT_Bsq2CU$-l`dXnfnga`t6Y}FGn9X z4R_k{+|BB@U%BAi&F`7K`Ieuje%|I~3X|q~(66uVH}H07;xx~veblm6!l}S3h6qK zw~L(haA3;;(;dgGyyf6aLl)R}!PRSD71rFW{OJM2e$p~Vy>_1=a;@}h6mrCe7qNDmQM4ssbw;QYabeE zx$Dc3mKi>t?^j0e_JUbY{l2xCH(iz1dh*GOHoW7Gla7i)v;2#+Q%TRyweHtI3;}9u6I+-tNK|A1A-wxDK86&dZx#@4VvBiem=X zcj-IyjbP2-+{@Q@@Xm`So>Skf(~O@(k8bn$a>By6oA-tZ$0L3#}Ci9ocW}$@6gNFF7EBLI(BP!TFq$am#e#Xd&^di za5VU+eH)q%EN-VfwDoZ4s16Q{d3^Vs!LK(@K4p+GzPr9T&i0$dQ|rkV2d|kN?sW9M zVcQ#D?shYG;N#%Kw1YQ3PFyujIeP24({krmR&R%2I&khh;8c2WX0YuM#$11Mu)Z8IC|yefzx-;FB@z# zJu|wnWmMtX!%P;q#hSMRj=jujmI=nuFGG3Y-PfmCd(p{*Mp%C7HpkUh_3|r|UMHP6 zvN-mHY2Mx2Ooh`|d&{|*XCUA7ySzk^yKRW4~!O8>Ul)tKHj|o;N1u-Bo#xKX$`{6)Z{C|n4mgL7+2%>ltbV>Y_IAp@ zd*PVD<9a*g>K=!8_jv1}w{Mv>0Va%{OIAq z_1)at!Mr|BTpnqZ+01?Wdh}_el~>;8I6UFs=07Wso1I zS-a+EmL{EeWruTMnydEJm%Z=ezCP`j=AGZ1ZaCTa($1@VVZ)o#i(97E;0xa_@a_(L z-@u)YiqF3sH1u%WqkJ%C`g&^Li=G3McKgeeo=U?BpHKSmn+-ONvU=}6FlO{|Rr+Df z+ESuv^~}TIyRSx>@XY8sFlm$rmPY>ehy&Y9A17@XcRPl0udkPuo_Ccltvg>?=)lPA z?$A=hc(-#tJoI|=bK2G1op!VI)H0QeAHCHrSG=<6=UXn7Pv75tHOtnwfiG)3y*%Lf zsAV$eE01cvJ$n1wn{{ybY~l;*pn3 z&U*G;;OJNT#@_C6%b>PX-;EDk8N7Mrufo#Ab>Q>F>kh}Gs_($yop1<%h?2+Br?Qc9sE0W}H9soKTAZ diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input2.bin b/test/tilelang_st/npu/a5/src/st/testcase/tadd/i32_64x64_64x64_64x64_64x64/input2.bin deleted file mode 100644 index fdc17ecedd3bfe13eeed257838153262936c5e16..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmZ9}QEntr2t(1%B$;IPzu~DP^_46>5D&18srr7sUVoqX^Th8HAMf~n^3M~W@2obL z^}O-aU+)}OEt<<_vOYeZox8(qac2MCnD*F2qNz|mj7;b3v|qEE-JG-1-T80QqH*O~5(3EMx z(f8Ja>A&dFhS3+*c@3{$Z?)ny=W+FTQRC!S53l|CZ8yBAFE^8yYoGbv=&eUzhQV`x z9`d6HZ*B%J_pR%9L-(Z*i!aRR%6AtX8t~rueOp}pZccAC%<`HSb(rpHKfT@My}EYu zkgx74(}B+?T=l;D4rY75w;k=Yzj^-U=4Q0I(SCDbrhn%&;e;8jy;Dyuyz>p;cQdmi z9M5oe=Fi(ajtm!84U=x)!raVmv2R$Kc*4%`?b63F6M4*+nO(HP=;6V%6Hg5%!+4vU z^$%pT`f1RPx4AhE?|1^9hrUDi!F7lILNC8G_MW~2+};O=lP+94>D_zbXz<&ue5z@+ z2UqkR@{=d{Zr&Z{eY5uPqo01cIer{ieY@$)Je%V+(=%HxjQM(P-W~MyaA8H>Xxh0L zb$;PQcYD*qYfrjyx-}%LHZ=Ag7e z3>{f*-dkVYoJIl5*WINj_>tF7`^@aXO~W2FZ{h51U-LZSdCTpIm#*3yw>#+BueQrO zul4AKp^>NFYC6?!Nmo{<>rGc})=jwj{OQx|jnjSbX6m$LJ>iWf)HLaZm)j3xFE1CH zp1GZJ+VI|V+m-I#6Bo`-nrY=hhQp^DAICc%@#Jo64x6Xu{B;+67(Kh}>}FU#%K6m8 zjbolJ4YNGpK|v-C>5WZijlc>f@)GzqHl$ z(#K01ukVoG?ts^W(d!o9f`*KnzPY}fr{*wbf?nU#j%MlD6;6*HAKvu5Mfb>}oz3mh zr{^uw=^j49$!Ubo=it3Eu9y~R$w zy<5#kzIj!{*2|YnFU>G&ns~yyJ!X08oi}~&=6zp$yJZ}Eirdr;J2jANcntCEV8AfzdJvH6-%5FFB#^ssA>&5f#Zu9Qcr^UZ)mXaSB*oj`^<0ztUf+p?J>iZ^SF1J z;pN-=?#at$yX~aw-97E5=WSOW@b&Vcj;l{k=0P@3zuV~9rG{4*`Ri7DU<>$mnhE>! zB;%&ttbOKb(1B|vo5zL2rzi3qF0Xk^JMVDaXV!PpGoxo$`{;@E^woIHaMKFcH-Mp= z27T{5c({XKD2t%$Y7M(qZn-Why%=+jX5FW_O?Et?5=wK?tOEQV%t+U=2XWnmW{Y_s0o)DHb_(KELP z&&)f|al&+SUi4vbd#lrd+Z$ob^tzvyzKz{^lJV1$o57{)O_v7D-rr1*)_OF{cYS=J z<||z|eLM5u4MVpcd>H+BaQ3Cmt9p6rZjkBV>(i3;gt=YKWjpAl35RDVF1$J|oP6lv zsndj~ndfqSSdpi6+htbIJHK+Z+>Y*~*N!-J^!o-lxOtXW=knYXpO#zl1}kXOkm=#* z!{Vy%9<@1M9($Wzt?=ELCY(qk|Lt?DJutX%eLJ-`j(0n8!{ByDJMsV*7KTo|Jatn# z`UR{zgg5=ZTlg?$c9qvV4$VAx+Y1k;pHI8gFzqbwH;u2CUmD$*PqRG5H^Wz(*(d9l z&0xBZw$QH^FTd%^;b`QqdG~ZfUg*Qp(-V1@h3`chwr@vQouBmJ?V>5?3s$6qXP%zE zcQ~`W(133roOyLO%4T-cQ}Yv7f9Le+n0v#wA66}!%L1+(-!2%@_i8VUKCN_QKJq84 z(>Cv&j-9;q&AiPDc#;3Kc?*+o`uf%8YFsnn&6AnhUKqM&>BEG}8=T+=SBCF~G}L{6 z`25H)-6Zn?Tci(bM?S*h>&e~Ecl-8M^Y+PRbmQUJr#6%Al=aP<<^6i1+l{bs_}yJ< z-1gwO1Kvzu7G`QQVMqSWddvHd?C-nfw_ET8J=y}FCw=((a<}O7ga>Dz97YWzy2qQg z3@6)Dt!Hk>-fOO=18-+H>EW20>8r!USYfVb_ufLId003YoIH7VvtIqQ)=M7X zz0GLKcy7rHOnYJNf%C40U!DCHu=sRg!s?mP(`&E3-LgLaJiy`WH>(Hl-8X>~&Gq{p z`rb6%-ZPZ*E^$ zxgGg|g~J!{^aL;F`DhO=tbUwwJ7ijd4o;q6yR$uT<}}T=Uv0)~d*wJXT|B{e8ts$w zV8&xTx$kA4nttC>?zZ;!-DsM%&wRDK)1;d>S&c_S?)#hJ;F_tu3%quwZ*Tf$IC;^B z>+b%`;>^P7_uc8_F%6k7c$uDE?ZoXSShLlur%&6AR=8@tJh#sbPVH@Xd-c8J<fqc7~TC%@|Y@IsB(H;38pXWo6ZybB)f&m(U9^wo>WPv3>s zdNC#NQos`rfqE_6fK+?(hz4 zpE(>(H4QzvfD?HLPgBoKkG^-i>eI`IJ!%<8tge|_*iTEqRME z7PQlV<3kM>1~32X(>p)#F!tuBxgGIf^0{8@u&?jZ&F~`4F!W&6X0+t)(buypetYu; z-x~%;O(P$Aya~N-!PS#j$9(xb(9p9(-45^WHm{~vt!J;m*T>P%gPymZpfBrt!>8vh z!sID!*>@6skNndR`AL_So}F^v%p0bhrn+xtU!KhAhtsE($M(zkW_H5&Exe0-+t(hm z-tp70&n%;1;KAt`v@7>jIShILC91nQjFz|c`eF4**-mrO!-yV4H_IZb=$%DX?>6XLNfQ_f7 zC%V)Ac>46pbcFeG)#;Y^JK2{nyn>c@-wr2zJKF<8W5XW(yu+I7(@7IXZNHv6o;nZt z#izACx*y&Sb$jsyylB6e)35#d^5fD%)+=g5BbE=m-Woc`umOeRm*A8*TdsapO)MX wZ*$+EZ>N_|{PKDU2a7AbX?w%c?t9ol2d{Vcz~Sa43>*wCn)-b=wfWWm4~QR5O#lD@ diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py index c964a0329..f4ffa0e80 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py @@ -2,7 +2,7 @@ # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -13,48 +13,21 @@ Each case defines: - name: case identifier - - dtype: numpy dtype for inputs (np.float32 or np.int32) - - out_dtype: numpy dtype for output (np.int8) + - dtype: numpy dtype for inputs (np.float16, np.float32, np.int32) + - out_dtype: numpy dtype for output (np.uint8 - packed predicate mask, rows*cols/8 bytes) - shape: (rows, cols) — tile buffer dimensions - valid_shape: (valid_rows, valid_cols) — effective computation region - eps: tolerance for comparison - - cmp_mode: comparison mode string (eq, gt, le) + - cmp_mode: comparison mode string (eq, gt, ne) """ import numpy as np CASES = [ { - "name": "f32_1x64_eq", - "dtype": np.float32, - "out_dtype": np.int8, - "shape": (1, 64), - "valid_shape": (1, 64), - "eps": 0, - "cmp_mode": "eq", - }, - { - "name": "f32_8x64_gt", - "dtype": np.float32, - "out_dtype": np.int8, - "shape": (8, 64), - "valid_shape": (8, 64), - "eps": 0, - "cmp_mode": "gt", - }, - { - "name": "i32_16x32_eq", - "dtype": np.int32, - "out_dtype": np.int8, - "shape": (16, 32), - "valid_shape": (16, 32), - "eps": 0, - "cmp_mode": "eq", - }, - { - "name": "i32_32x32_eq", - "dtype": np.int32, - "out_dtype": np.int8, + "name": "half_32x32_eq", + "dtype": np.float16, + "out_dtype": np.uint8, "shape": (32, 32), "valid_shape": (32, 32), "eps": 0, diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py index bef371b3b..0d5384916 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py @@ -17,9 +17,6 @@ from st_common import result_cmp, style_fail, style_pass, validate_cases -ALIGN_STRIDE = 32 - - def main(): validate_cases(CASES) case_filter = sys.argv[1] if len(sys.argv) > 1 else None @@ -33,14 +30,12 @@ def main(): shape = case["shape"] out_dtype = case["out_dtype"] vr, vc = case["valid_shape"] - packed_shape = (vr, ALIGN_STRIDE) - packed_size = vr * ALIGN_STRIDE + packed_size = vr * vc // 8 - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=out_dtype).reshape(packed_shape) - output_full = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=out_dtype) - output = output_full[:packed_size].reshape(packed_shape) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=out_dtype) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=out_dtype) - ok = result_cmp(golden, output, case["eps"]) + ok = result_cmp(golden[:packed_size], output[:packed_size], case["eps"]) if ok: print(style_pass(f"[INFO] {case['name']}: compare passed")) else: diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py index 5cf29a05a..549ff4bed 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py @@ -33,23 +33,16 @@ def compute_cmp(a, b, mode): raise ValueError(f"Unknown cmp_mode: {mode}") -ALIGN_STRIDE = 32 - - def pack_predicate_mask(cmp_result): cmp_result = cmp_result.astype(np.uint8) shape = cmp_result.shape - packed_shape = (shape[0], ALIGN_STRIDE) - packed = np.zeros(packed_shape, dtype=np.uint8) - for row in range(shape[0]): - for vl in range(min(8, shape[1] // 8)): - lanes = cmp_result[row, vl*8:(vl+1)*8] - for j in range(8): - if lanes[j]: - byte_idx = j // 2 - bit_pos = (j % 2) * 4 - packed[row, vl*4 + byte_idx] |= (1 << bit_pos) - return packed.view(np.int8) + bits_per_row = shape[1] // 8 + packed = [] + func_binar = lambda bits: sum(int(bit) * (2 ** i) for i, bit in enumerate(bits)) + for row in cmp_result: + for i in range(bits_per_row): + packed.append(func_binar(row[i*8:i*8+8])) + return np.array(packed, dtype=np.uint8) for case in CASES: @@ -69,4 +62,4 @@ def pack_predicate_mask(cmp_result): golden = pack_predicate_mask(cmp_result) save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) - print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} out_dtype={out_dtype.__name__} cmp_mode={cmp_mode}") \ No newline at end of file + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} golden_size={golden.size} cmp_mode={cmp_mode}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp index 6841e5589..e613c0ff5 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp @@ -12,26 +12,8 @@ #define AICORE [aicore] #endif -extern "C" __global__ AICORE void TCMP_f32_1x64_eq(__gm__ float *a, __gm__ float *b, __gm__ int8_t *c); +extern "C" __global__ AICORE void TCMP_half_32x32_eq(__gm__ void *a, __gm__ void *b, __gm__ void *c); -void LaunchTCMP_f32_1x64_eq(float *a, float *b, int8_t *c, void *stream) { - TCMP_f32_1x64_eq<<<1, nullptr, stream>>>(a, b, c); -} - -extern "C" __global__ AICORE void TCMP_f32_8x64_gt(__gm__ float *a, __gm__ float *b, __gm__ int8_t *c); - -void LaunchTCMP_f32_8x64_gt(float *a, float *b, int8_t *c, void *stream) { - TCMP_f32_8x64_gt<<<1, nullptr, stream>>>(a, b, c); -} - -extern "C" __global__ AICORE void TCMP_i32_16x32_eq(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); - -void LaunchTCMP_i32_16x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream) { - TCMP_i32_16x32_eq<<<1, nullptr, stream>>>(a, b, c); -} - -extern "C" __global__ AICORE void TCMP_i32_32x32_eq(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); - -void LaunchTCMP_i32_32x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream) { - TCMP_i32_32x32_eq<<<1, nullptr, stream>>>(a, b, c); +void LaunchTCMP_half_32x32_eq(void *a, void *b, void *c, void *stream) { + TCMP_half_32x32_eq<<<1, nullptr, stream>>>(a, b, c); } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp index 83cb28f3a..0553ffef0 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -19,10 +19,7 @@ using namespace PtoTestCommon; -void LaunchTCMP_f32_1x64_eq(float *a, float *b, int8_t *c, void *stream); -void LaunchTCMP_f32_8x64_gt(float *a, float *b, int8_t *c, void *stream); -void LaunchTCMP_i32_16x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream); -void LaunchTCMP_i32_32x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream); +void LaunchTCMP_half_32x32_eq(void *a, void *b, void *c, void *stream); using LaunchFn = void (*)(void *, void *, void *, void *); @@ -35,22 +32,20 @@ struct TestCase { size_t validCols; size_t inputElemSize; size_t outputElemSize; + size_t outputSize; }; static const TestCase kCases[] = { - {"f32_1x64_eq", (LaunchFn)LaunchTCMP_f32_1x64_eq, 1, 64, 1, 64, sizeof(float), sizeof(int8_t)}, - {"f32_8x64_gt", (LaunchFn)LaunchTCMP_f32_8x64_gt, 8, 64, 8, 64, sizeof(float), sizeof(int8_t)}, - {"i32_16x32_eq", (LaunchFn)LaunchTCMP_i32_16x32_eq, 16, 32, 16, 32, sizeof(int32_t), sizeof(int8_t)}, - {"i32_32x32_eq", (LaunchFn)LaunchTCMP_i32_32x32_eq, 32, 32, 32, 32, sizeof(int32_t), sizeof(int8_t)}, + {"half_32x32_eq", LaunchTCMP_half_32x32_eq, 32, 32, 32, 32, 2, 1, 128}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { int rc = 0; - const size_t elemCount = tc.rows * tc.cols; - const size_t inputFileSize = elemCount * tc.inputElemSize; - const size_t outputFileSize = elemCount * tc.outputElemSize; - const size_t packedMaskSize = tc.validRows * 32 * tc.outputElemSize; + const size_t inputElemCount = tc.rows * tc.cols; + const size_t inputFileSize = inputElemCount * tc.inputElemSize; + const size_t outputFileSize = inputElemCount * tc.outputElemSize; + const size_t packedMaskSize = tc.outputSize; std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu, packed_mask=%zu) ===\n", tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols, packedMaskSize); @@ -86,10 +81,10 @@ static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { tc.launch(src0Device, src1Device, dstDevice, stream); aclrtSynchronizeStream(stream); - aclrtMemcpy(dstHost, outputFileSize, dstDevice, outputFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + aclrtMemcpy(dstHost, packedMaskSize, dstDevice, packedMaskSize, ACL_MEMCPY_DEVICE_TO_HOST); } - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, outputFileSize)) { + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, packedMaskSize)) { std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); rc = 1; } diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto index 317fd93ac..cfb08f0f6 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto @@ -6,203 +6,12 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// TileLang ST kernels for pto.tcmp: tload(a) + tload(b) + tcmp(a,b)->c + tstore(c). -// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/tcmp -// 4 cases: f32_1x64_eq, f32_8x64_gt, i32_16x32_eq, i32_32x32_eq +// TileLang ST kernel for pto.tcmp: tload(a) + tload(b) + tcmp(a,b)->c + tstore(c). +// Case: half_32x32_eq (float16, 32x32, EQ mode) +// Output: packed predicate mask (dst tile dtype=i8, shape=32x32, stored as packed mask bytes) module { - // Case 0: f32 1x64 eq - func.func @TCMP_f32_1x64_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c1, %c64], - strides = [%c64, %c64, %c64, %c64, %c1] - : !pto.tensor_view<1x1x1x1x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c1, %c64], - strides = [%c64, %c64, %c64, %c64, %c1] - : !pto.tensor_view<1x1x1x1x64xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c1, %c64], - strides = [%c64, %c64, %c64, %c64, %c1] - : !pto.tensor_view<1x1x1x1x64xi8> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c64] - : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c64] - : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c64] - : !pto.tensor_view<1x1x1x1x64xi8> -> !pto.partition_tensor_view<1x1x1x1x64xi8> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) - outs(%b : !pto.tile_buf) - - pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x1x64xi8>) - return - } - - // Case 1: f32 8x64 gt - func.func @TCMP_f32_8x64_gt(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c64 = arith.constant 64 : index - %c512 = arith.constant 512 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c8, %c64], - strides = [%c512, %c512, %c512, %c64, %c1] - : !pto.tensor_view<1x1x1x8x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c8, %c64], - strides = [%c512, %c512, %c512, %c64, %c1] - : !pto.tensor_view<1x1x1x8x64xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c8, %c64], - strides = [%c512, %c512, %c512, %c64, %c1] - : !pto.tensor_view<1x1x1x8x64xi8> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c8, %c64] - : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c8, %c64] - : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c8, %c64] - : !pto.tensor_view<1x1x1x8x64xi8> -> !pto.partition_tensor_view<1x1x1x8x64xi8> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) - outs(%b : !pto.tile_buf) - - pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x8x64xi8>) - return - } - - // Case 2: i32 16x32 eq - func.func @TCMP_i32_16x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c512 = arith.constant 512 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xi8> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xi8> -> !pto.partition_tensor_view<1x1x1x16x32xi8> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) - outs(%b : !pto.tile_buf) - - pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xi8>) - return - } - - // Case 3: i32 32x32 eq - func.func @TCMP_i32_32x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + func.func @TCMP_half_32x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index @@ -211,11 +20,11 @@ module { %a_view = pto.make_tensor_view %a_ptr, shape = [%c1, %c1, %c1, %c32, %c32], strides = [%c1024, %c1024, %c1024, %c32, %c1] - : !pto.tensor_view<1x1x1x32x32xi32> + : !pto.tensor_view<1x1x1x32x32xf16> %b_view = pto.make_tensor_view %b_ptr, shape = [%c1, %c1, %c1, %c32, %c32], strides = [%c1024, %c1024, %c1024, %c32, %c1] - : !pto.tensor_view<1x1x1x32x32xi32> + : !pto.tensor_view<1x1x1x32x32xf16> %c_view = pto.make_tensor_view %c_ptr, shape = [%c1, %c1, %c1, %c32, %c32], strides = [%c1024, %c1024, %c1024, %c32, %c1] @@ -224,42 +33,42 @@ module { %a_part = pto.partition_view %a_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c32] - : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> %b_part = pto.partition_view %b_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c32] - : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> %c_part = pto.partition_view %c_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c32] : !pto.tensor_view<1x1x1x32x32xi8> -> !pto.partition_tensor_view<1x1x1x32x32xi8> %a = pto.alloc_tile - : !pto.tile_buf + : !pto.tile_buf %b = pto.alloc_tile - : !pto.tile_buf + : !pto.tile_buf %c = pto.alloc_tile : !pto.tile_buf + blayout=row_major, slayout=none_box, fractal=512, pad=0> - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) - outs(%b : !pto.tile_buf) + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%b : !pto.tile_buf) - pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, - !pto.tile_buf) + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) outs(%c : !pto.tile_buf) + blayout=row_major, slayout=none_box, fractal=512, pad=0>) pto.tstore ins(%c : !pto.tile_buf) + blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi8>) return } diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py index 5b7473873..5de9747e2 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py @@ -63,7 +63,7 @@ "valid_shape": (2, 31), "eps": 1e-3, }, - { +{ "name": "f32_1x8192_1x8192_1x8192_1x8192", "dtype": np.float32, "dst_tile": (1, 8192), From c515dbe992324ed7179a056b462e9752d13e1ac7 Mon Sep 17 00:00:00 2001 From: zwd060924 Date: Mon, 27 Apr 2026 12:58:59 +0800 Subject: [PATCH 183/192] [Delete] tcmp tfmod trem --- lib/TileOps/tcmp_template.py | 89 --- lib/TileOps/tfmod_template.py | 63 -- lib/TileOps/trem_template.py | 62 -- .../a5/src/st/testcase/tcmp/CMakeLists.txt | 9 - .../npu/a5/src/st/testcase/tcmp/cases.py | 36 -- .../npu/a5/src/st/testcase/tcmp/compare.py | 51 -- .../npu/a5/src/st/testcase/tcmp/gen_data.py | 65 -- .../npu/a5/src/st/testcase/tcmp/launch.cpp | 19 - .../npu/a5/src/st/testcase/tcmp/main.cpp | 142 ----- .../npu/a5/src/st/testcase/tcmp/tcmp.pto | 75 --- .../a5/src/st/testcase/tfmod/CMakeLists.txt | 9 - .../npu/a5/src/st/testcase/tfmod/cases.py | 75 --- .../npu/a5/src/st/testcase/tfmod/compare.py | 50 -- .../npu/a5/src/st/testcase/tfmod/gen_data.py | 37 -- .../npu/a5/src/st/testcase/tfmod/launch.cpp | 48 -- .../npu/a5/src/st/testcase/tfmod/main.cpp | 151 ----- .../npu/a5/src/st/testcase/tfmod/tfmod.pto | 340 ----------- .../a5/src/st/testcase/trem/CMakeLists.txt | 9 - .../npu/a5/src/st/testcase/trem/cases.py | 102 ---- .../npu/a5/src/st/testcase/trem/compare.py | 50 -- .../npu/a5/src/st/testcase/trem/gen_data.py | 37 -- .../npu/a5/src/st/testcase/trem/launch.cpp | 69 --- .../npu/a5/src/st/testcase/trem/main.cpp | 157 ----- .../npu/a5/src/st/testcase/trem/trem.pto | 577 ------------------ 24 files changed, 2322 deletions(-) delete mode 100644 lib/TileOps/tcmp_template.py delete mode 100644 lib/TileOps/tfmod_template.py delete mode 100644 lib/TileOps/trem_template.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp delete mode 100644 test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto diff --git a/lib/TileOps/tcmp_template.py b/lib/TileOps/tcmp_template.py deleted file mode 100644 index 2fe9e7ae0..000000000 --- a/lib/TileOps/tcmp_template.py +++ /dev/null @@ -1,89 +0,0 @@ -"""TileLang DSL template for pto.tcmp - -Aligned with pto-isa/include/pto/npu/a5/TCmp.hpp: -- 32-bit (int32, float, uint32): TCmp_32B with plt_b32 + pdintlv_b8 -- 16-bit (int16, half, uint16): TCmp with plt_b16 -- 8-bit (int8, uint8): TCmp with plt_b8 -""" - -import tilelang_dsl as pto - - -REPEAT_BYTE = 256 -CMP_BITS_PER_INDEX = 32 - - -@pto.vkernel( - target="a5", - op="pto.tcmp", - dtypes=[ - (pto.si32, pto.si32, pto.i8), - (pto.f32, pto.f32, pto.i8), - (pto.ui32, pto.ui32, pto.i8), - (pto.si16, pto.si16, pto.i8), - (pto.f16, pto.f16, pto.i8), - (pto.ui16, pto.ui16, pto.i8), - (pto.si8, pto.si8, pto.i8), - (pto.ui8, pto.ui8, pto.i8), - ], - advanced=True, -) -def template_tcmp(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - dtype = src0.element_type - valid_rows, valid_cols = src0.valid_shape - cmp_mode = pto.get_op_attr("cmp_mode", pto.CmpMode.EQ) - - dtype_size = pto.bytewidth(dtype) - total_elements = valid_rows * valid_cols - repeat_elm = REPEAT_BYTE // dtype_size - - src0_ptr = src0.as_ptr() - src1_ptr = src1.as_ptr() - dst_ptr = dst.as_ptr() - dst_u32_ptr = pto.castptr(dst_ptr, pto.ptr(pto.ui32, pto.MemorySpace.UB)) - - if pto.constexpr(dtype_size == 4): - repeat_elm_32b = REPEAT_BYTE // 4 - repeat_times_32b = (total_elements + repeat_elm_32b - 1) // repeat_elm_32b + 1 - loop_times = repeat_times_32b // 2 - remaining = total_elements - - for i in range(0, loop_times, 1): - preg0, remaining = pto.plt_b32(remaining) - vreg0 = pto.vlds(src0_ptr, i * 2 * repeat_elm_32b) - vreg1 = pto.vlds(src1_ptr, i * 2 * repeat_elm_32b) - preg1 = pto.vcmp(vreg0, vreg1, preg0, cmp_mode) - - preg0, remaining = pto.plt_b32(remaining) - vreg2 = pto.vlds(src0_ptr, (i * 2 + 1) * repeat_elm_32b) - vreg3 = pto.vlds(src1_ptr, (i * 2 + 1) * repeat_elm_32b) - preg2 = pto.vcmp(vreg2, vreg3, preg0, cmp_mode) - - preg1_b8 = pto.pbitcast(preg1, pto.mask_b8) - preg2_b8 = pto.pbitcast(preg2, pto.mask_b8) - preg3, preg4 = pto.pdintlv_b8(preg1_b8, preg2_b8) - pto.psts(preg3, dst_u32_ptr, i * 16, pto.PredicateDist.PK) - elif pto.constexpr(dtype_size == 2): - repeat_times = (total_elements + repeat_elm - 1) // repeat_elm - dst_stride_bytes = (repeat_elm // CMP_BITS_PER_INDEX) * 4 - remaining = total_elements - - for i in range(0, repeat_times, 1): - preg0, remaining = pto.plt_b16(remaining) - vreg0 = pto.vlds(src0_ptr, i * repeat_elm) - vreg1 = pto.vlds(src1_ptr, i * repeat_elm) - preg1 = pto.vcmp(vreg0, vreg1, preg0, cmp_mode) - pto.psts(preg1, dst_u32_ptr, i * dst_stride_bytes, pto.PredicateDist.PK) - elif pto.constexpr(dtype_size == 1): - repeat_times = (total_elements + repeat_elm - 1) // repeat_elm - dst_stride_bytes = (repeat_elm // CMP_BITS_PER_INDEX) * 4 - remaining = total_elements - - for i in range(0, repeat_times, 1): - preg0, remaining = pto.plt_b8(remaining) - vreg0 = pto.vlds(src0_ptr, i * repeat_elm) - vreg1 = pto.vlds(src1_ptr, i * repeat_elm) - preg1 = pto.vcmp(vreg0, vreg1, preg0, cmp_mode) - pto.psts(preg1, dst_u32_ptr, i * dst_stride_bytes, pto.PredicateDist.PK) - - return \ No newline at end of file diff --git a/lib/TileOps/tfmod_template.py b/lib/TileOps/tfmod_template.py deleted file mode 100644 index a05260863..000000000 --- a/lib/TileOps/tfmod_template.py +++ /dev/null @@ -1,63 +0,0 @@ -"""TileLang DSL template for pto.tfmod - -Aligned with pto-isa/include/pto/npu/a5/TFMod.hpp: -- float: vdiv -> vtrc(ROUND_Z) -> vmul -> vsub -- half: vcvt(half->float, PART_EVEN/ODD) -> vdiv -> vtrc(ROUND_Z) -> vmul -> vsub - -> vcvt(float->half, ROUND_Z, RS_ENABLE, PART_EVEN/ODD) -> vor -- other (i16/ui16): vdiv -> vmul -> vsub (no vtrc, integer div is trunc by nature) -""" - -import tilelang_dsl as pto - - -@pto.vkernel( - target="a5", - op="pto.tfmod", - dtypes=[ - (pto.f32, pto.f32, pto.f32), - (pto.f16, pto.f16, pto.f16), - (pto.i16, pto.i16, pto.i16), - (pto.ui16, pto.ui16, pto.ui16), - ], -) -def template_tfmod(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - dtype = dst.element_type - valid_rows, valid_cols = dst.valid_shape - - for row in range(0, valid_rows, 1): - remained = valid_cols - for col in range(0, valid_cols, pto.get_lanes(dtype)): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - - if pto.constexpr(dtype == pto.f32): - quotient = pto.vdiv(lhs, rhs, mask) - quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.Z) - truncated_mul = pto.vmul(quotient, rhs, mask) - result = pto.vsub(lhs, truncated_mul, mask) - elif pto.constexpr(dtype == pto.f16): - lhs_even = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) - rhs_even = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) - quotient_even = pto.vdiv(lhs_even, rhs_even, mask) - quotient_even = pto.vtrc(quotient_even, mask, rnd=pto.VcvtRoundMode.Z) - truncated_mul_even = pto.vmul(quotient_even, rhs_even, mask) - result_even = pto.vsub(lhs_even, truncated_mul_even, mask) - dst_even = pto.vcvt(result_even, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.EVEN) - - lhs_odd = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) - rhs_odd = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) - quotient_odd = pto.vdiv(lhs_odd, rhs_odd, mask) - quotient_odd = pto.vtrc(quotient_odd, mask, rnd=pto.VcvtRoundMode.Z) - truncated_mul_odd = pto.vmul(quotient_odd, rhs_odd, mask) - result_odd = pto.vsub(lhs_odd, truncated_mul_odd, mask) - dst_odd = pto.vcvt(result_odd, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.ODD) - - result = pto.vor(dst_even, dst_odd, mask) - else: - quotient = pto.vdiv(lhs, rhs, mask) - truncated_mul = pto.vmul(quotient, rhs, mask) - result = pto.vsub(lhs, truncated_mul, mask) - - pto.vsts(result, dst[row, col:], mask) - return \ No newline at end of file diff --git a/lib/TileOps/trem_template.py b/lib/TileOps/trem_template.py deleted file mode 100644 index 619b664a6..000000000 --- a/lib/TileOps/trem_template.py +++ /dev/null @@ -1,62 +0,0 @@ -"""TileLang DSL template for pto.trem""" - -import tilelang_dsl as pto - - -@pto.vkernel( - target="a5", - op="pto.trem", - dtypes=[ - (pto.f32, pto.f32, pto.f32, pto.f32), - (pto.f16, pto.f16, pto.f16, pto.f16), - (pto.i32, pto.i32, pto.i32, pto.i32), - ], - advanced=True, -) -def template_trem(src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): - dtype = dst.element_type - valid_rows, valid_cols = dst.valid_shape - - for row in range(0, valid_rows, 1): - remained = valid_cols - for col in range(0, valid_cols, pto.get_lanes(dtype)): - mask, remained = pto.make_mask(dtype, remained) - lhs = pto.vlds(src0[row, col:]) - rhs = pto.vlds(src1[row, col:]) - if pto.constexpr(dtype == pto.f32): - quotient = pto.vdiv(lhs, rhs, mask) - quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.F) - floored_mul = pto.vmul(quotient, rhs, mask) - result = pto.vsub(lhs, floored_mul, mask) - sign_diff_mask = pto.vcmps(pto.vmul(rhs, result, mask), 0.0, mask, pto.CmpMode.LT) - corrected = pto.vadd(result, rhs, sign_diff_mask) - result = pto.vsel(corrected, result, sign_diff_mask) - elif pto.constexpr(dtype == pto.f16): - lhs_even = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) - rhs_even = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) - lhs_odd = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) - rhs_odd = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) - q_even = pto.vdiv(lhs_even, rhs_even, mask) - q_odd = pto.vdiv(lhs_odd, rhs_odd, mask) - q_even = pto.vtrc(q_even, mask, rnd=pto.VcvtRoundMode.F) - q_odd = pto.vtrc(q_odd, mask, rnd=pto.VcvtRoundMode.F) - fm_even = pto.vmul(q_even, rhs_even, mask) - fm_odd = pto.vmul(q_odd, rhs_odd, mask) - r_even = pto.vsub(lhs_even, fm_even, mask) - r_odd = pto.vsub(lhs_odd, fm_odd, mask) - dst_even = pto.vcvt(r_even, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.RS_ENABLE, part=pto.VcvtPartMode.EVEN) - dst_odd = pto.vcvt(r_odd, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.RS_ENABLE, part=pto.VcvtPartMode.ODD) - result = pto.vor(dst_even, dst_odd, mask) - sign_diff_mask = pto.vcmps(pto.vmul(rhs, result, mask), 0.0, mask, pto.CmpMode.LT) - corrected = pto.vadd(result, rhs, sign_diff_mask) - result = pto.vsel(corrected, result, sign_diff_mask) - elif pto.constexpr(dtype == pto.i32): - lhs_f32 = pto.vcvt(lhs, pto.f32, mask, rnd=pto.VcvtRoundMode.R) - rhs_f32 = pto.vcvt(rhs, pto.f32, mask, rnd=pto.VcvtRoundMode.R) - quotient = pto.vdiv(lhs_f32, rhs_f32, mask) - quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.F) - floored_mul = pto.vmul(quotient, rhs_f32, mask) - result_f32 = pto.vsub(lhs_f32, floored_mul, mask) - result = pto.vcvt(result_f32, pto.i32, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.NOSAT) - pto.vsts(result, dst[row, col:], mask) - return \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt deleted file mode 100644 index a863ea151..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -pto_tilelang_vec_st(tcmp) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py deleted file mode 100644 index f4ffa0e80..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -"""Single source of truth for tcmp ST test cases. - -Each case defines: - - name: case identifier - - dtype: numpy dtype for inputs (np.float16, np.float32, np.int32) - - out_dtype: numpy dtype for output (np.uint8 - packed predicate mask, rows*cols/8 bytes) - - shape: (rows, cols) — tile buffer dimensions - - valid_shape: (valid_rows, valid_cols) — effective computation region - - eps: tolerance for comparison - - cmp_mode: comparison mode string (eq, gt, ne) -""" - -import numpy as np - -CASES = [ - { - "name": "half_32x32_eq", - "dtype": np.float16, - "out_dtype": np.uint8, - "shape": (32, 32), - "valid_shape": (32, 32), - "eps": 0, - "cmp_mode": "eq", - }, -] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py deleted file mode 100644 index 0d5384916..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -import os -import sys -import numpy as np - -from cases import CASES -from st_common import result_cmp, style_fail, style_pass, validate_cases - - -def main(): - validate_cases(CASES) - case_filter = sys.argv[1] if len(sys.argv) > 1 else None - - all_passed = True - for case in CASES: - if case_filter is not None and case["name"] != case_filter: - continue - - case_dir = case["name"] - shape = case["shape"] - out_dtype = case["out_dtype"] - vr, vc = case["valid_shape"] - packed_size = vr * vc // 8 - - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=out_dtype) - output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=out_dtype) - - ok = result_cmp(golden[:packed_size], output[:packed_size], case["eps"]) - if ok: - print(style_pass(f"[INFO] {case['name']}: compare passed")) - else: - print(style_fail(f"[ERROR] {case['name']}: compare failed")) - all_passed = False - - if not all_passed: - sys.exit(2) - print(style_pass("[INFO] all cases passed")) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py deleted file mode 100644 index 549ff4bed..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -import numpy as np -from cases import CASES -from st_common import validate_cases, setup_case_rng, save_case_data - -validate_cases(CASES) - - -def compute_cmp(a, b, mode): - if mode == "gt": - return (a > b) - elif mode == "ge": - return (a >= b) - elif mode == "lt": - return (a < b) - elif mode == "le": - return (a <= b) - elif mode == "eq": - return (a == b) - elif mode == "ne": - return (a != b) - else: - raise ValueError(f"Unknown cmp_mode: {mode}") - - -def pack_predicate_mask(cmp_result): - cmp_result = cmp_result.astype(np.uint8) - shape = cmp_result.shape - bits_per_row = shape[1] // 8 - packed = [] - func_binar = lambda bits: sum(int(bit) * (2 ** i) for i, bit in enumerate(bits)) - for row in cmp_result: - for i in range(bits_per_row): - packed.append(func_binar(row[i*8:i*8+8])) - return np.array(packed, dtype=np.uint8) - - -for case in CASES: - setup_case_rng(case) - - dtype = case["dtype"] - out_dtype = case["out_dtype"] - shape = case["shape"] - valid_shape = case["valid_shape"] - cmp_mode = case["cmp_mode"] - - input1 = np.random.choice([-5, -2, -1, 0, 1, 2, 5], size=shape).astype(dtype) - input2 = np.random.choice([-5, -2, -1, 0, 1, 2, 5], size=shape).astype(dtype) - - vr, vc = valid_shape - cmp_result = compute_cmp(input1[:vr, :vc], input2[:vr, :vc], cmp_mode) - golden = pack_predicate_mask(cmp_result) - - save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) - print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} golden_size={golden.size} cmp_mode={cmp_mode}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp deleted file mode 100644 index e613c0ff5..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include - -#ifndef AICORE -#define AICORE [aicore] -#endif - -extern "C" __global__ AICORE void TCMP_half_32x32_eq(__gm__ void *a, __gm__ void *b, __gm__ void *c); - -void LaunchTCMP_half_32x32_eq(void *a, void *b, void *c, void *stream) { - TCMP_half_32x32_eq<<<1, nullptr, stream>>>(a, b, c); -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp deleted file mode 100644 index 0553ffef0..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// Host driver for TileLang tcmp ST — case-table driven. - -#include "acl/acl.h" -#include "test_common.h" -#include -#include -#include -#include -#include -#include - -using namespace PtoTestCommon; - -void LaunchTCMP_half_32x32_eq(void *a, void *b, void *c, void *stream); - -using LaunchFn = void (*)(void *, void *, void *, void *); - -struct TestCase { - const char *name; - LaunchFn launch; - size_t rows; - size_t cols; - size_t validRows; - size_t validCols; - size_t inputElemSize; - size_t outputElemSize; - size_t outputSize; -}; - -static const TestCase kCases[] = { - {"half_32x32_eq", LaunchTCMP_half_32x32_eq, 32, 32, 32, 32, 2, 1, 128}, -}; -static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); - -static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { - int rc = 0; - const size_t inputElemCount = tc.rows * tc.cols; - const size_t inputFileSize = inputElemCount * tc.inputElemSize; - const size_t outputFileSize = inputElemCount * tc.outputElemSize; - const size_t packedMaskSize = tc.outputSize; - - std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu, packed_mask=%zu) ===\n", - tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols, packedMaskSize); - - std::string caseDir = std::string("./") + tc.name; - - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - - aclrtMallocHost(&src0Host, inputFileSize); - aclrtMallocHost(&src1Host, inputFileSize); - aclrtMallocHost(&dstHost, outputFileSize); - - aclrtMalloc(&src0Device, inputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, inputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, outputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); - - size_t readSize = inputFileSize; - if (!ReadFile((caseDir + "/input1.bin").c_str(), readSize, src0Host, inputFileSize)) { - std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); - rc = 1; - } - readSize = inputFileSize; - if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), readSize, src1Host, inputFileSize)) { - std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); - rc = 1; - } - - if (rc == 0) { - aclrtMemcpy(src0Device, inputFileSize, src0Host, inputFileSize, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(src1Device, inputFileSize, src1Host, inputFileSize, ACL_MEMCPY_HOST_TO_DEVICE); - - tc.launch(src0Device, src1Device, dstDevice, stream); - - aclrtSynchronizeStream(stream); - aclrtMemcpy(dstHost, packedMaskSize, dstDevice, packedMaskSize, ACL_MEMCPY_DEVICE_TO_HOST); - } - - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, packedMaskSize)) { - std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); - rc = 1; - } - - if (src0Device != nullptr) - aclrtFree(src0Device); - if (src1Device != nullptr) - aclrtFree(src1Device); - if (dstDevice != nullptr) - aclrtFree(dstDevice); - if (src0Host != nullptr) - aclrtFreeHost(src0Host); - if (src1Host != nullptr) - aclrtFreeHost(src1Host); - if (dstHost != nullptr) - aclrtFreeHost(dstHost); - - if (rc == 0) - std::printf("[INFO] case %s done\n", tc.name); - return rc; -} - -int main(int argc, char *argv[]) { - const char *caseFilter = (argc > 1) ? argv[1] : nullptr; - - int rc = 0; - int deviceId = 0; - aclrtStream stream = nullptr; - - aclInit(nullptr); - if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { - deviceId = std::atoi(envDevice); - } - aclrtSetDevice(deviceId); - aclrtCreateStream(&stream); - - for (size_t i = 0; i < kNumCases; ++i) { - if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { - continue; - } - int ret = RunCase(kCases[i], deviceId, stream); - if (ret != 0) { - std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); - rc = 1; - break; - } - } - - if (stream != nullptr) - aclrtDestroyStream(stream); - aclrtResetDevice(deviceId); - aclFinalize(); - - return rc; -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto deleted file mode 100644 index cfb08f0f6..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// TileLang ST kernel for pto.tcmp: tload(a) + tload(b) + tcmp(a,b)->c + tstore(c). -// Case: half_32x32_eq (float16, 32x32, EQ mode) -// Output: packed predicate mask (dst tile dtype=i8, shape=32x32, stored as packed mask bytes) - -module { - func.func @TCMP_half_32x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c1024 = arith.constant 1024 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c32, %c32], - strides = [%c1024, %c1024, %c1024, %c32, %c1] - : !pto.tensor_view<1x1x1x32x32xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c32, %c32], - strides = [%c1024, %c1024, %c1024, %c32, %c1] - : !pto.tensor_view<1x1x1x32x32xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c32, %c32], - strides = [%c1024, %c1024, %c1024, %c32, %c1] - : !pto.tensor_view<1x1x1x32x32xi8> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c32] - : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c32] - : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c32] - : !pto.tensor_view<1x1x1x32x32xi8> -> !pto.partition_tensor_view<1x1x1x32x32xi8> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) - outs(%b : !pto.tile_buf) - - pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi8>) - return - } -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt deleted file mode 100644 index 41603223e..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -pto_tilelang_vec_st(tfmod) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py deleted file mode 100644 index 5de9747e2..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -"""Single source of truth for tfmod ST test cases. - -Each case defines: - - name: case identifier - - dtype: numpy dtype (np.float32) - - dst_tile: (rows, cols) — dst tile buffer dimensions - - src0_tile: (rows, cols) — src0 tile buffer dimensions - - src1_tile: (rows, cols) — src1 tile buffer dimensions - - valid_shape: (valid_rows, valid_cols) — effective computation region - - eps: tolerance for numpy.allclose (atol and rtol) - -Note: src0/src1/dst tile buffer physical sizes can differ, - but valid_shape must be the same for all. -""" - -import numpy as np - -CASES = [ - { - "name": "f32_16x64_16x128_16x128_16x64", - "dtype": np.float32, - "dst_tile": (16, 64), - "src0_tile": (16, 128), - "src1_tile": (16, 128), - "valid_shape": (16, 64), - "eps": 1e-3, - }, - { - "name": "f32_16x32_16x64_16x32_16x32", - "dtype": np.float32, - "dst_tile": (16, 32), - "src0_tile": (16, 64), - "src1_tile": (16, 32), - "valid_shape": (16, 32), - "eps": 1e-3, - }, - { - "name": "f32_16x64_16x128_16x128_16x63", - "dtype": np.float32, - "dst_tile": (16, 64), - "src0_tile": (16, 128), - "src1_tile": (16, 128), - "valid_shape": (16, 63), - "eps": 1e-3, - }, - { - "name": "f32_2x32_2x64_2x32_2x31", - "dtype": np.float32, - "dst_tile": (2, 32), - "src0_tile": (2, 64), - "src1_tile": (2, 32), - "valid_shape": (2, 31), - "eps": 1e-3, - }, -{ - "name": "f32_1x8192_1x8192_1x8192_1x8192", - "dtype": np.float32, - "dst_tile": (1, 8192), - "src0_tile": (1, 8192), - "src1_tile": (1, 8192), - "valid_shape": (1, 8192), - "eps": 1e-3, - }, -] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py deleted file mode 100644 index 96ab2dda8..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -import os -import sys -import numpy as np - -from cases import CASES -from st_common import result_cmp, style_fail, style_pass, validate_cases - - -def main(): - validate_cases(CASES) - case_filter = sys.argv[1] if len(sys.argv) > 1 else None - - all_passed = True - for case in CASES: - if case_filter is not None and case["name"] != case_filter: - continue - - case_dir = case["name"] - dst_tile = case["dst_tile"] - valid_shape = case["valid_shape"] - vr, vc = valid_shape - - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_tile) - output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_tile) - - ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) - if ok: - print(style_pass(f"[INFO] {case['name']}: compare passed")) - else: - print(style_fail(f"[ERROR] {case['name']}: compare failed")) - all_passed = False - - if not all_passed: - sys.exit(2) - print(style_pass("[INFO] all cases passed")) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py deleted file mode 100644 index 61bfc2a2e..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -import numpy as np -from cases import CASES -from st_common import validate_cases, setup_case_rng, save_case_data - -validate_cases(CASES) - -np.random.seed(19) - -for case in CASES: - setup_case_rng(case) - - dtype = case["dtype"] - dst_tile = case["dst_tile"] - src0_tile = case["src0_tile"] - src1_tile = case["src1_tile"] - valid_shape = case["valid_shape"] - - input1 = np.random.uniform(low=-1000, high=1000, size=src0_tile).astype(dtype) - input2 = np.random.uniform(low=3, high=100, size=src1_tile).astype(dtype) - - golden = np.zeros(dst_tile, dtype=dtype) - vr, vc = valid_shape - golden[:vr, :vc] = np.fmod(input1[:vr, :vc], input2[:vr, :vc]).astype(dtype) - - save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) - print(f"[INFO] gen_data: {case['name']} dst={dst_tile} src0={src0_tile} src1={src1_tile} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp deleted file mode 100644 index 764d99aa4..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include - -#ifndef AICORE -#define AICORE [aicore] -#endif - -// Case 0: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x64 -extern "C" __global__ AICORE void TFMOD_f32_16x64_16x128_16x128_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTFMOD_f32_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream) { - TFMOD_f32_16x64_16x128_16x128_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 1: f32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 -extern "C" __global__ AICORE void TFMOD_f32_16x32_16x64_16x32_16x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTFMOD_f32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream) { - TFMOD_f32_16x32_16x64_16x32_16x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 2: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x63 -extern "C" __global__ AICORE void TFMOD_f32_16x64_16x128_16x128_16x63(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTFMOD_f32_16x64_16x128_16x128_16x63(void *a, void *b, void *c, void *stream) { - TFMOD_f32_16x64_16x128_16x128_16x63<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 3: f32, dst=2x32, src0=2x64, src1=2x32, valid=2x31 -extern "C" __global__ AICORE void TFMOD_f32_2x32_2x64_2x32_2x31(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTFMOD_f32_2x32_2x64_2x32_2x31(void *a, void *b, void *c, void *stream) { - TFMOD_f32_2x32_2x64_2x32_2x31<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 4: f32, dst=1x8192, src0=1x8192, src1=1x8192, valid=1x8192 -extern "C" __global__ AICORE void TFMOD_f32_1x8192_1x8192_1x8192_1x8192(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTFMOD_f32_1x8192_1x8192_1x8192_1x8192(void *a, void *b, void *c, void *stream) { - TFMOD_f32_1x8192_1x8192_1x8192_1x8192<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp deleted file mode 100644 index 3f7a170d3..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// Host driver for TileLang tfmod ST — case-table driven. - -#include "acl/acl.h" -#include "test_common.h" -#include -#include -#include -#include -#include -#include - -using namespace PtoTestCommon; - -void LaunchTFMOD_f32_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream); -void LaunchTFMOD_f32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream); -void LaunchTFMOD_f32_16x64_16x128_16x128_16x63(void *a, void *b, void *c, void *stream); -void LaunchTFMOD_f32_2x32_2x64_2x32_2x31(void *a, void *b, void *c, void *stream); -void LaunchTFMOD_f32_1x8192_1x8192_1x8192_1x8192(void *a, void *b, void *c, void *stream); - -using LaunchFn = void (*)(void *, void *, void *, void *); - -struct TestCase { - const char *name; - LaunchFn launch; - size_t src0Rows; - size_t src0Cols; - size_t src1Rows; - size_t src1Cols; - size_t dstRows; - size_t dstCols; - size_t validRows; - size_t validCols; - size_t elemSize; -}; - -static const TestCase kCases[] = { - {"f32_16x64_16x128_16x128_16x64", LaunchTFMOD_f32_16x64_16x128_16x128_16x64, 16, 128, 16, 128, 16, 64, 16, 64, sizeof(float)}, - {"f32_16x32_16x64_16x32_16x32", LaunchTFMOD_f32_16x32_16x64_16x32_16x32, 16, 64, 16, 32, 16, 32, 16, 32, sizeof(float)}, - {"f32_16x64_16x128_16x128_16x63", LaunchTFMOD_f32_16x64_16x128_16x128_16x63, 16, 128, 16, 128, 16, 64, 16, 63, sizeof(float)}, - {"f32_2x32_2x64_2x32_2x31", LaunchTFMOD_f32_2x32_2x64_2x32_2x31, 2, 64, 2, 32, 2, 32, 2, 31, sizeof(float)}, - {"f32_1x8192_1x8192_1x8192_1x8192", LaunchTFMOD_f32_1x8192_1x8192_1x8192_1x8192, 1, 8192, 1, 8192, 1, 8192, 1, 8192, sizeof(float)}, -}; -static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); - -static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { - int rc = 0; - const size_t src0Size = tc.src0Rows * tc.src0Cols * tc.elemSize; - const size_t src1Size = tc.src1Rows * tc.src1Cols * tc.elemSize; - const size_t dstSize = tc.dstRows * tc.dstCols * tc.elemSize; - - std::printf("[INFO] === case: %s (dst=%zux%zu, src0=%zux%zu, src1=%zux%zu, valid=%zux%zu) ===\n", - tc.name, tc.dstRows, tc.dstCols, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.validRows, tc.validCols); - - std::string caseDir = std::string("./") + tc.name; - - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - - aclrtMallocHost(&src0Host, src0Size); - aclrtMallocHost(&src1Host, src1Size); - aclrtMallocHost(&dstHost, dstSize); - - aclrtMalloc(&src0Device, src0Size, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, src1Size, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST); - - size_t fileSize = 0; - if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, src0Size)) { - std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); - rc = 1; - } - fileSize = 0; - if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, src1Size)) { - std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); - rc = 1; - } - - if (rc == 0) { - aclrtMemcpy(src0Device, src0Size, src0Host, src0Size, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(src1Device, src1Size, src1Host, src1Size, ACL_MEMCPY_HOST_TO_DEVICE); - - tc.launch(src0Device, src1Device, dstDevice, stream); - - aclrtSynchronizeStream(stream); - aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST); - } - - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstSize)) { - std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); - rc = 1; - } - - if (src0Device != nullptr) - aclrtFree(src0Device); - if (src1Device != nullptr) - aclrtFree(src1Device); - if (dstDevice != nullptr) - aclrtFree(dstDevice); - if (src0Host != nullptr) - aclrtFreeHost(src0Host); - if (src1Host != nullptr) - aclrtFreeHost(src1Host); - if (dstHost != nullptr) - aclrtFreeHost(dstHost); - - if (rc == 0) - std::printf("[INFO] case %s done\n", tc.name); - return rc; -} - -int main(int argc, char *argv[]) { - const char *caseFilter = (argc > 1) ? argv[1] : nullptr; - - int rc = 0; - int deviceId = 0; - aclrtStream stream = nullptr; - - aclInit(nullptr); - if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { - deviceId = std::atoi(envDevice); - } - aclrtSetDevice(deviceId); - aclrtCreateStream(&stream); - - for (size_t i = 0; i < kNumCases; ++i) { - if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { - continue; - } - int ret = RunCase(kCases[i], deviceId, stream); - if (ret != 0) { - std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); - rc = 1; - break; - } - } - - if (stream != nullptr) - aclrtDestroyStream(stream); - aclrtResetDevice(deviceId); - aclFinalize(); - - return rc; -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto deleted file mode 100644 index 72f227b82..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// TileLang ST kernels for pto.tfmod: tload(a) + tload(b) + tfmod(a,b)->c + tstore(c). -// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/tfmod -// Cases have different src/dst tile buffer sizes but same valid_shape. - -module { - // Case 0: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x64 - func.func @TFMOD_f32_16x64_16x128_16x128_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %c2048 = arith.constant 2048 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) - outs(%b : !pto.tile_buf) - - pto.tfmod ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) - return - } - - // Case 1: f32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 - func.func @TFMOD_f32_16x32_16x64_16x32_16x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c512 = arith.constant 512 : index - %c1024 = arith.constant 1024 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) - outs(%b : !pto.tile_buf) - - pto.tfmod ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) - return - } - - // Case 2: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x63 - func.func @TFMOD_f32_16x64_16x128_16x128_16x63(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c63 = arith.constant 63 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %c2048 = arith.constant 2048 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c63] - : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c63] - : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c63] - : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) - outs(%b : !pto.tile_buf) - - pto.tfmod ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) - return - } - - // Case 3: f32, dst=2x32, src0=2x64, src1=2x32, valid=2x31 - func.func @TFMOD_f32_2x32_2x64_2x32_2x31(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c31 = arith.constant 31 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c2, %c64], - strides = [%c128, %c128, %c128, %c64, %c1] - : !pto.tensor_view<1x1x1x2x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c2, %c32], - strides = [%c64, %c64, %c64, %c32, %c1] - : !pto.tensor_view<1x1x1x2x32xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c2, %c32], - strides = [%c64, %c64, %c64, %c32, %c1] - : !pto.tensor_view<1x1x1x2x32xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c2, %c31] - : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c2, %c31] - : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c2, %c31] - : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) - outs(%b : !pto.tile_buf) - - pto.tfmod ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) - return - } - - // Case 4: f32, dst=1x8192, src0=1x8192, src1=1x8192, valid=1x8192 - func.func @TFMOD_f32_1x8192_1x8192_1x8192_1x8192(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8192 = arith.constant 8192 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c1, %c8192], - strides = [%c8192, %c8192, %c8192, %c8192, %c1] - : !pto.tensor_view<1x1x1x1x8192xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c1, %c8192], - strides = [%c8192, %c8192, %c8192, %c8192, %c1] - : !pto.tensor_view<1x1x1x1x8192xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c1, %c8192], - strides = [%c8192, %c8192, %c8192, %c8192, %c1] - : !pto.tensor_view<1x1x1x1x8192xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c8192] - : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c8192] - : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c8192] - : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) - outs(%b : !pto.tile_buf) - - pto.tfmod ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) - return - } -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt deleted file mode 100644 index 1c489d8e5..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -pto_tilelang_vec_st(trem) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py deleted file mode 100644 index c7f0724b9..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -"""Single source of truth for trem ST test cases. - -Each case defines: - - name: case identifier - - dtype: numpy dtype (np.float32 or np.int32) - - dst_tile: (rows, cols) — dst tile buffer dimensions - - src0_tile: (rows, cols) — src0 tile buffer dimensions - - src1_tile: (rows, cols) — src1 tile buffer dimensions - - valid_shape: (valid_rows, valid_cols) — effective computation region - - eps: tolerance for numpy.allclose (atol and rtol) - -Note: src0/src1/dst tile buffer physical sizes can differ, - but valid_shape must be the same for all. -""" - -import numpy as np - -CASES = [ - { - "name": "f32_16x64_16x128_16x128_16x64", - "dtype": np.float32, - "dst_tile": (16, 64), - "src0_tile": (16, 128), - "src1_tile": (16, 128), - "valid_shape": (16, 64), - "eps": 1e-3, - }, - { - "name": "f32_16x32_16x64_16x32_16x32", - "dtype": np.float32, - "dst_tile": (16, 32), - "src0_tile": (16, 64), - "src1_tile": (16, 32), - "valid_shape": (16, 32), - "eps": 1e-3, - }, - { - "name": "i32_4x32_4x32_4x32_4x32", - "dtype": np.int32, - "dst_tile": (4, 32), - "src0_tile": (4, 32), - "src1_tile": (4, 32), - "valid_shape": (4, 32), - "eps": 0, - }, - { - "name": "i32_16x32_16x64_16x32_16x32", - "dtype": np.int32, - "dst_tile": (16, 32), - "src0_tile": (16, 64), - "src1_tile": (16, 32), - "valid_shape": (16, 32), - "eps": 0, - }, - { - "name": "f32_16x64_16x128_16x128_16x63", - "dtype": np.float32, - "dst_tile": (16, 64), - "src0_tile": (16, 128), - "src1_tile": (16, 128), - "valid_shape": (16, 63), - "eps": 1e-3, - }, - { - "name": "f32_2x32_2x64_2x32_2x31", - "dtype": np.float32, - "dst_tile": (2, 32), - "src0_tile": (2, 64), - "src1_tile": (2, 32), - "valid_shape": (2, 31), - "eps": 1e-3, - }, - { - "name": "i32_16x32_16x64_16x32_16x31", - "dtype": np.int32, - "dst_tile": (16, 32), - "src0_tile": (16, 64), - "src1_tile": (16, 32), - "valid_shape": (16, 31), - "eps": 0, - }, - { - "name": "f32_1x8192_1x8192_1x8192_1x8192", - "dtype": np.float32, - "dst_tile": (1, 8192), - "src0_tile": (1, 8192), - "src1_tile": (1, 8192), - "valid_shape": (1, 8192), - "eps": 1e-3, - }, -] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py deleted file mode 100644 index f09a461e9..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -import os -import sys -import numpy as np - -from cases import CASES -from st_common import result_cmp, style_fail, style_pass, validate_cases - - -def main(): - validate_cases(CASES) - case_filter = sys.argv[1] if len(sys.argv) > 1 else None - - all_passed = True - for case in CASES: - if case_filter is not None and case["name"] != case_filter: - continue - - case_dir = case["name"] - dst_tile = case["dst_tile"] - valid_shape = case["valid_shape"] - vr, vc = valid_shape - - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_tile) - output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_tile) - - ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) - if ok: - print(style_pass(f"[INFO] {case['name']}: compare passed")) - else: - print(style_fail(f"[ERROR] {case['name']}: compare failed")) - all_passed = False - - if not all_passed: - sys.exit(2) - print(style_pass("[INFO] all cases passed")) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py deleted file mode 100644 index 038e4f976..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -# coding=utf-8 - -import numpy as np -from cases import CASES -from st_common import validate_cases, setup_case_rng, save_case_data - -validate_cases(CASES) - -np.random.seed(19) - -for case in CASES: - setup_case_rng(case) - - dtype = case["dtype"] - dst_tile = case["dst_tile"] - src0_tile = case["src0_tile"] - src1_tile = case["src1_tile"] - valid_shape = case["valid_shape"] - - input1 = np.random.uniform(low=-1000, high=1000, size=src0_tile).astype(dtype) - input2 = np.random.uniform(low=3, high=100, size=src1_tile).astype(dtype) - - golden = np.zeros(dst_tile, dtype=dtype) - vr, vc = valid_shape - golden[:vr, :vc] = np.remainder(input1[:vr, :vc], input2[:vr, :vc]).astype(dtype) - - save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) - print(f"[INFO] gen_data: {case['name']} dst={dst_tile} src0={src0_tile} src1={src1_tile} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp deleted file mode 100644 index 96bbf77eb..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include - -#ifndef AICORE -#define AICORE [aicore] -#endif - -// Case 0: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x64 -extern "C" __global__ AICORE void TREM_f32_16x64_16x128_16x128_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTREM_f32_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream) { - TREM_f32_16x64_16x128_16x128_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 1: f32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 -extern "C" __global__ AICORE void TREM_f32_16x32_16x64_16x32_16x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTREM_f32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream) { - TREM_f32_16x32_16x64_16x32_16x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 2: i32, dst=4x32, src0=4x32, src1=4x32, valid=4x32 -extern "C" __global__ AICORE void TREM_i32_4x32_4x32_4x32_4x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTREM_i32_4x32_4x32_4x32_4x32(void *a, void *b, void *c, void *stream) { - TREM_i32_4x32_4x32_4x32_4x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 3: i32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 -extern "C" __global__ AICORE void TREM_i32_16x32_16x64_16x32_16x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTREM_i32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream) { - TREM_i32_16x32_16x64_16x32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 4: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x63 -extern "C" __global__ AICORE void TREM_f32_16x64_16x128_16x128_16x63(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTREM_f32_16x64_16x128_16x128_16x63(void *a, void *b, void *c, void *stream) { - TREM_f32_16x64_16x128_16x128_16x63<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 5: f32, dst=2x32, src0=2x64, src1=2x32, valid=2x31 -extern "C" __global__ AICORE void TREM_f32_2x32_2x64_2x32_2x31(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTREM_f32_2x32_2x64_2x32_2x31(void *a, void *b, void *c, void *stream) { - TREM_f32_2x32_2x64_2x32_2x31<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 6: i32, dst=16x32, src0=16x64, src1=16x32, valid=16x31 -extern "C" __global__ AICORE void TREM_i32_16x32_16x64_16x32_16x31(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTREM_i32_16x32_16x64_16x32_16x31(void *a, void *b, void *c, void *stream) { - TREM_i32_16x32_16x64_16x32_16x31<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 7: f32, dst=1x8192, src0=1x8192, src1=1x8192, valid=1x8192 -extern "C" __global__ AICORE void TREM_f32_1x8192_1x8192_1x8192_1x8192(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTREM_f32_1x8192_1x8192_1x8192_1x8192(void *a, void *b, void *c, void *stream) { - TREM_f32_1x8192_1x8192_1x8192_1x8192<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp deleted file mode 100644 index 07c5020bf..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// Host driver for TileLang trem ST — case-table driven. - -#include "acl/acl.h" -#include "test_common.h" -#include -#include -#include -#include -#include -#include - -using namespace PtoTestCommon; - -void LaunchTREM_f32_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream); -void LaunchTREM_f32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream); -void LaunchTREM_i32_4x32_4x32_4x32_4x32(void *a, void *b, void *c, void *stream); -void LaunchTREM_i32_16x32_16x64_16x32_16x32(void *a, void *b, void *c, void *stream); -void LaunchTREM_f32_16x64_16x128_16x128_16x63(void *a, void *b, void *c, void *stream); -void LaunchTREM_f32_2x32_2x64_2x32_2x31(void *a, void *b, void *c, void *stream); -void LaunchTREM_i32_16x32_16x64_16x32_16x31(void *a, void *b, void *c, void *stream); -void LaunchTREM_f32_1x8192_1x8192_1x8192_1x8192(void *a, void *b, void *c, void *stream); - -using LaunchFn = void (*)(void *, void *, void *, void *); - -struct TestCase { - const char *name; - LaunchFn launch; - size_t src0Rows; - size_t src0Cols; - size_t src1Rows; - size_t src1Cols; - size_t dstRows; - size_t dstCols; - size_t validRows; - size_t validCols; - size_t elemSize; -}; - -static const TestCase kCases[] = { - {"f32_16x64_16x128_16x128_16x64", LaunchTREM_f32_16x64_16x128_16x128_16x64, 16, 128, 16, 128, 16, 64, 16, 64, sizeof(float)}, - {"f32_16x32_16x64_16x32_16x32", LaunchTREM_f32_16x32_16x64_16x32_16x32, 16, 64, 16, 32, 16, 32, 16, 32, sizeof(float)}, - {"i32_4x32_4x32_4x32_4x32", LaunchTREM_i32_4x32_4x32_4x32_4x32, 4, 32, 4, 32, 4, 32, 4, 32, sizeof(int32_t)}, - {"i32_16x32_16x64_16x32_16x32", LaunchTREM_i32_16x32_16x64_16x32_16x32, 16, 64, 16, 32, 16, 32, 16, 32, sizeof(int32_t)}, - {"f32_16x64_16x128_16x128_16x63", LaunchTREM_f32_16x64_16x128_16x128_16x63, 16, 128, 16, 128, 16, 64, 16, 63, sizeof(float)}, - {"f32_2x32_2x64_2x32_2x31", LaunchTREM_f32_2x32_2x64_2x32_2x31, 2, 64, 2, 32, 2, 32, 2, 31, sizeof(float)}, - {"i32_16x32_16x64_16x32_16x31", LaunchTREM_i32_16x32_16x64_16x32_16x31, 16, 64, 16, 32, 16, 32, 16, 31, sizeof(int32_t)}, - {"f32_1x8192_1x8192_1x8192_1x8192", LaunchTREM_f32_1x8192_1x8192_1x8192_1x8192, 1, 8192, 1, 8192, 1, 8192, 1, 8192, sizeof(float)}, -}; -static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); - -static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { - int rc = 0; - const size_t src0Size = tc.src0Rows * tc.src0Cols * tc.elemSize; - const size_t src1Size = tc.src1Rows * tc.src1Cols * tc.elemSize; - const size_t dstSize = tc.dstRows * tc.dstCols * tc.elemSize; - - std::printf("[INFO] === case: %s (dst=%zux%zu, src0=%zux%zu, src1=%zux%zu, valid=%zux%zu) ===\n", - tc.name, tc.dstRows, tc.dstCols, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.validRows, tc.validCols); - - std::string caseDir = std::string("./") + tc.name; - - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - - aclrtMallocHost(&src0Host, src0Size); - aclrtMallocHost(&src1Host, src1Size); - aclrtMallocHost(&dstHost, dstSize); - - aclrtMalloc(&src0Device, src0Size, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, src1Size, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST); - - size_t fileSize = 0; - if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, src0Size)) { - std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); - rc = 1; - } - fileSize = 0; - if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, src1Size)) { - std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); - rc = 1; - } - - if (rc == 0) { - aclrtMemcpy(src0Device, src0Size, src0Host, src0Size, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(src1Device, src1Size, src1Host, src1Size, ACL_MEMCPY_HOST_TO_DEVICE); - - tc.launch(src0Device, src1Device, dstDevice, stream); - - aclrtSynchronizeStream(stream); - aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST); - } - - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstSize)) { - std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); - rc = 1; - } - - if (src0Device != nullptr) - aclrtFree(src0Device); - if (src1Device != nullptr) - aclrtFree(src1Device); - if (dstDevice != nullptr) - aclrtFree(dstDevice); - if (src0Host != nullptr) - aclrtFreeHost(src0Host); - if (src1Host != nullptr) - aclrtFreeHost(src1Host); - if (dstHost != nullptr) - aclrtFreeHost(dstHost); - - if (rc == 0) - std::printf("[INFO] case %s done\n", tc.name); - return rc; -} - -int main(int argc, char *argv[]) { - const char *caseFilter = (argc > 1) ? argv[1] : nullptr; - - int rc = 0; - int deviceId = 0; - aclrtStream stream = nullptr; - - aclInit(nullptr); - if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { - deviceId = std::atoi(envDevice); - } - aclrtSetDevice(deviceId); - aclrtCreateStream(&stream); - - for (size_t i = 0; i < kNumCases; ++i) { - if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { - continue; - } - int ret = RunCase(kCases[i], deviceId, stream); - if (ret != 0) { - std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); - rc = 1; - break; - } - } - - if (stream != nullptr) - aclrtDestroyStream(stream); - aclrtResetDevice(deviceId); - aclFinalize(); - - return rc; -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto b/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto deleted file mode 100644 index e34974ef5..000000000 --- a/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto +++ /dev/null @@ -1,577 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// TileLang ST kernels for pto.trem: tload(a) + tload(b) + trem(a,b,tmp)->c + tstore(c). -// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/trem -// Cases have different src/dst tile buffer sizes but same valid_shape. - -module { - // Case 0: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x64 - func.func @TREM_f32_16x64_16x128_16x128_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %c2048 = arith.constant 2048 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tmp = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) - outs(%b : !pto.tile_buf) - - pto.trem ins(%a, %b, %tmp : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) - return - } - - // Case 1: f32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 - func.func @TREM_f32_16x32_16x64_16x32_16x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c512 = arith.constant 512 : index - %c1024 = arith.constant 1024 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tmp = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) - outs(%b : !pto.tile_buf) - - pto.trem ins(%a, %b, %tmp : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) - return - } - - // Case 2: i32, dst=4x32, src0=4x32, src1=4x32, valid=4x32 - func.func @TREM_i32_4x32_4x32_4x32_4x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c32 = arith.constant 32 : index - %c128 = arith.constant 128 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c4, %c32], - strides = [%c128, %c128, %c128, %c32, %c1] - : !pto.tensor_view<1x1x1x4x32xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c4, %c32], - strides = [%c128, %c128, %c128, %c32, %c1] - : !pto.tensor_view<1x1x1x4x32xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c4, %c32], - strides = [%c128, %c128, %c128, %c32, %c1] - : !pto.tensor_view<1x1x1x4x32xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c4, %c32] - : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c4, %c32] - : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c4, %c32] - : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tmp = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) - outs(%b : !pto.tile_buf) - - pto.trem ins(%a, %b, %tmp : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) - return - } - - // Case 3: i32, dst=16x32, src0=16x64, src1=16x32, valid=16x32 - func.func @TREM_i32_16x32_16x64_16x32_16x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c512 = arith.constant 512 : index - %c1024 = arith.constant 1024 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c32] - : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tmp = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) - outs(%b : !pto.tile_buf) - - pto.trem ins(%a, %b, %tmp : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) - return - } - - // Case 4: f32, dst=16x64, src0=16x128, src1=16x128, valid=16x63 - func.func @TREM_f32_16x64_16x128_16x128_16x63(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c63 = arith.constant 63 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %c2048 = arith.constant 2048 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c63] - : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c63] - : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c63] - : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tmp = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) - outs(%b : !pto.tile_buf) - - pto.trem ins(%a, %b, %tmp : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) - return - } - - // Case 5: f32, dst=2x32, src0=2x64, src1=2x32, valid=2x31 - func.func @TREM_f32_2x32_2x64_2x32_2x31(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c31 = arith.constant 31 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c2, %c64], - strides = [%c128, %c128, %c128, %c64, %c1] - : !pto.tensor_view<1x1x1x2x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c2, %c32], - strides = [%c64, %c64, %c64, %c32, %c1] - : !pto.tensor_view<1x1x1x2x32xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c2, %c32], - strides = [%c64, %c64, %c64, %c32, %c1] - : !pto.tensor_view<1x1x1x2x32xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c2, %c31] - : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c2, %c31] - : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c2, %c31] - : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tmp = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) - outs(%b : !pto.tile_buf) - - pto.trem ins(%a, %b, %tmp : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) - return - } - - // Case 6: i32, dst=16x32, src0=16x64, src1=16x32, valid=16x31 - func.func @TREM_i32_16x32_16x64_16x32_16x31(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c31 = arith.constant 31 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c512 = arith.constant 512 : index - %c1024 = arith.constant 1024 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c32], - strides = [%c512, %c512, %c512, %c32, %c1] - : !pto.tensor_view<1x1x1x16x32xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c31] - : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x31xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c31] - : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x31xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c31] - : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x31xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tmp = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x31xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x31xi32>) - outs(%b : !pto.tile_buf) - - pto.trem ins(%a, %b, %tmp : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x31xi32>) - return - } - - // Case 7: f32, dst=1x8192, src0=1x8192, src1=1x8192, valid=1x8192 - func.func @TREM_f32_1x8192_1x8192_1x8192_1x8192(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8192 = arith.constant 8192 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c1, %c8192], - strides = [%c8192, %c8192, %c8192, %c8192, %c1] - : !pto.tensor_view<1x1x1x1x8192xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c1, %c8192], - strides = [%c8192, %c8192, %c8192, %c8192, %c1] - : !pto.tensor_view<1x1x1x1x8192xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c1, %c8192], - strides = [%c8192, %c8192, %c8192, %c8192, %c1] - : !pto.tensor_view<1x1x1x1x8192xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c8192] - : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c8192] - : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c8192] - : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %tmp = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) - outs(%b : !pto.tile_buf) - - pto.trem ins(%a, %b, %tmp : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) - return - } -} \ No newline at end of file From e707c7d687ad056333c016854ec915f7be112d73 Mon Sep 17 00:00:00 2001 From: zwd060924 Date: Mon, 27 Apr 2026 14:09:25 +0800 Subject: [PATCH 184/192] [Delete] tcmp tfmod trem --- lib/TileOps/tand_template.py | 8 ++++++++ lib/TileOps/tdiv_template.py | 8 ++++++++ lib/TileOps/tmax_template.py | 8 ++++++++ lib/TileOps/tmin_template.py | 8 ++++++++ lib/TileOps/tmul_template.py | 8 ++++++++ lib/TileOps/tor_template.py | 8 ++++++++ lib/TileOps/tshl_template.py | 8 ++++++++ lib/TileOps/tshr_template.py | 8 ++++++++ lib/TileOps/tsub_template.py | 8 ++++++++ lib/TileOps/txor_template.py | 8 ++++++++ 10 files changed, 80 insertions(+) diff --git a/lib/TileOps/tand_template.py b/lib/TileOps/tand_template.py index 4d771d44b..6c1477197 100644 --- a/lib/TileOps/tand_template.py +++ b/lib/TileOps/tand_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tand""" import sys diff --git a/lib/TileOps/tdiv_template.py b/lib/TileOps/tdiv_template.py index c1c8aea65..3c3b443f4 100644 --- a/lib/TileOps/tdiv_template.py +++ b/lib/TileOps/tdiv_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tdiv""" import sys diff --git a/lib/TileOps/tmax_template.py b/lib/TileOps/tmax_template.py index 645da3924..8831d3eef 100644 --- a/lib/TileOps/tmax_template.py +++ b/lib/TileOps/tmax_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tmax""" import sys diff --git a/lib/TileOps/tmin_template.py b/lib/TileOps/tmin_template.py index c03d74e3c..61664b14d 100644 --- a/lib/TileOps/tmin_template.py +++ b/lib/TileOps/tmin_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tmin""" import sys diff --git a/lib/TileOps/tmul_template.py b/lib/TileOps/tmul_template.py index 46acebab0..ae7adf44e 100644 --- a/lib/TileOps/tmul_template.py +++ b/lib/TileOps/tmul_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tmul""" import sys diff --git a/lib/TileOps/tor_template.py b/lib/TileOps/tor_template.py index 6efa2fe4e..e8be63d5e 100644 --- a/lib/TileOps/tor_template.py +++ b/lib/TileOps/tor_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tor""" import sys diff --git a/lib/TileOps/tshl_template.py b/lib/TileOps/tshl_template.py index 28f448509..d236c8940 100644 --- a/lib/TileOps/tshl_template.py +++ b/lib/TileOps/tshl_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tshl""" import sys diff --git a/lib/TileOps/tshr_template.py b/lib/TileOps/tshr_template.py index 59a9e8117..f16ba9abe 100644 --- a/lib/TileOps/tshr_template.py +++ b/lib/TileOps/tshr_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tshr""" import sys diff --git a/lib/TileOps/tsub_template.py b/lib/TileOps/tsub_template.py index b97777376..81d1b13dd 100644 --- a/lib/TileOps/tsub_template.py +++ b/lib/TileOps/tsub_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.tsub""" import sys diff --git a/lib/TileOps/txor_template.py b/lib/TileOps/txor_template.py index 593a4ca74..d2ca4f1f7 100644 --- a/lib/TileOps/txor_template.py +++ b/lib/TileOps/txor_template.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + """TileLang DSL template for pto.txor""" import sys From 9a4df37a3a1668e6e67d22ba6f5e78f229ff8928 Mon Sep 17 00:00:00 2001 From: zwd060924 Date: Mon, 27 Apr 2026 17:44:53 +0800 Subject: [PATCH 185/192] [Fix] pass CI --- .../npu/a5/src/st/testcase/tadd/cases.py | 74 +-- .../npu/a5/src/st/testcase/tadd/compare.py | 10 +- .../npu/a5/src/st/testcase/tadd/gen_data.py | 15 +- .../npu/a5/src/st/testcase/tadd/launch.cpp | 60 +-- .../npu/a5/src/st/testcase/tadd/main.cpp | 82 ++- .../npu/a5/src/st/testcase/tadd/tadd.pto | 395 +------------- .../npu/a5/src/st/testcase/tand/cases.py | 31 +- .../npu/a5/src/st/testcase/tand/compare.py | 1 - .../npu/a5/src/st/testcase/tand/gen_data.py | 7 +- .../npu/a5/src/st/testcase/tand/launch.cpp | 16 +- .../npu/a5/src/st/testcase/tand/main.cpp | 52 +- .../npu/a5/src/st/testcase/tand/tand.pto | 154 +++--- .../npu/a5/src/st/testcase/tdiv/cases.py | 33 +- .../npu/a5/src/st/testcase/tdiv/compare.py | 1 - .../npu/a5/src/st/testcase/tdiv/gen_data.py | 1 - .../npu/a5/src/st/testcase/tdiv/launch.cpp | 34 +- .../npu/a5/src/st/testcase/tdiv/main.cpp | 35 +- .../npu/a5/src/st/testcase/tdiv/tdiv.pto | 263 +-------- .../npu/a5/src/st/testcase/tmax/cases.py | 59 +- .../npu/a5/src/st/testcase/tmax/compare.py | 1 - .../npu/a5/src/st/testcase/tmax/gen_data.py | 1 - .../npu/a5/src/st/testcase/tmax/launch.cpp | 62 +-- .../npu/a5/src/st/testcase/tmax/main.cpp | 43 +- .../npu/a5/src/st/testcase/tmax/tmax.pto | 508 ------------------ .../npu/a5/src/st/testcase/tmin/cases.py | 3 +- .../npu/a5/src/st/testcase/tmin/compare.py | 3 +- .../npu/a5/src/st/testcase/tmin/gen_data.py | 3 +- .../npu/a5/src/st/testcase/tmul/cases.py | 52 +- .../npu/a5/src/st/testcase/tmul/compare.py | 1 - .../npu/a5/src/st/testcase/tmul/gen_data.py | 1 - .../npu/a5/src/st/testcase/tmul/launch.cpp | 55 +- .../npu/a5/src/st/testcase/tmul/main.cpp | 41 +- .../npu/a5/src/st/testcase/tmul/tmul.pto | 451 +--------------- .../npu/a5/src/st/testcase/tor/cases.py | 15 +- .../npu/a5/src/st/testcase/tor/compare.py | 1 - .../npu/a5/src/st/testcase/tor/gen_data.py | 1 - .../npu/a5/src/st/testcase/tor/tor.pto | 3 - .../npu/a5/src/st/testcase/tshl/cases.py | 15 +- .../npu/a5/src/st/testcase/tshl/compare.py | 1 - .../npu/a5/src/st/testcase/tshl/gen_data.py | 1 - .../npu/a5/src/st/testcase/tshl/tshl.pto | 3 - .../npu/a5/src/st/testcase/tshr/cases.py | 15 +- .../npu/a5/src/st/testcase/tshr/compare.py | 1 - .../npu/a5/src/st/testcase/tshr/gen_data.py | 1 - .../npu/a5/src/st/testcase/tshr/tshr.pto | 3 - .../npu/a5/src/st/testcase/tsub/cases.py | 31 +- .../npu/a5/src/st/testcase/tsub/compare.py | 1 - .../npu/a5/src/st/testcase/tsub/gen_data.py | 1 - .../npu/a5/src/st/testcase/tsub/launch.cpp | 34 +- .../npu/a5/src/st/testcase/tsub/main.cpp | 33 +- .../npu/a5/src/st/testcase/tsub/output.ll | 0 .../npu/a5/src/st/testcase/tsub/tsub.pto | 258 +-------- .../npu/a5/src/st/testcase/txor/cases.py | 56 +- .../npu/a5/src/st/testcase/txor/compare.py | 10 +- .../npu/a5/src/st/testcase/txor/gen_data.py | 16 +- .../npu/a5/src/st/testcase/txor/launch.cpp | 30 +- .../npu/a5/src/st/testcase/txor/main.cpp | 76 ++- .../npu/a5/src/st/testcase/txor/txor.pto | 276 +++------- 58 files changed, 421 insertions(+), 3009 deletions(-) create mode 100644 test/tilelang_st/npu/a5/src/st/testcase/tsub/output.ll diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py index f911991ec..5958f05d2 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). @@ -14,15 +13,10 @@ Each case defines: - name: case identifier, used as subdirectory name and by main.cpp kCases[]. - dtype: numpy dtype (e.g. np.float32). - - dst_tile: (rows, cols) — dst tile buffer dimensions. - - src0_tile: (rows, cols) — src0 tile buffer dimensions. - - src1_tile: (rows, cols) — src1 tile buffer dimensions. + - shape: (rows, cols) — allocated tile dimensions. - valid_shape: (valid_rows, valid_cols) — effective computation region. - eps: tolerance for numpy.allclose (atol and rtol). -Note: src0/src1/dst tile buffer physical sizes can differ, - but valid_shape must be the same for all. - gen_data.py and compare.py both import this list to avoid redundant definitions. """ @@ -30,75 +24,17 @@ CASES = [ { - "name": "f32_64x128_64x128_64x128_64x128", - "dtype": np.float32, - "dst_tile": (64, 128), - "src0_tile": (64, 128), - "src1_tile": (64, 128), - "valid_shape": (64, 128), - "eps": 1e-6, - }, - { - "name": "f32_16x64_16x64_16x64_16x64", + "name": "f32_16x64", "dtype": np.float32, - "dst_tile": (16, 64), - "src0_tile": (16, 64), - "src1_tile": (16, 64), + "shape": (16, 64), "valid_shape": (16, 64), "eps": 1e-6, }, { - "name": "f32_32x32_32x32_32x32_32x32", + "name": "f32_32x32", "dtype": np.float32, - "dst_tile": (32, 32), - "src0_tile": (32, 32), - "src1_tile": (32, 32), + "shape": (32, 32), "valid_shape": (32, 32), "eps": 1e-6, }, - { - "name": "f32_64x64_64x64_64x64_64x64", - "dtype": np.float32, - "dst_tile": (64, 64), - "src0_tile": (64, 64), - "src1_tile": (64, 64), - "valid_shape": (64, 64), - "eps": 1e-6, - }, - { - "name": "i32_64x64_64x64_64x64_64x64", - "dtype": np.int32, - "dst_tile": (64, 64), - "src0_tile": (64, 64), - "src1_tile": (64, 64), - "valid_shape": (64, 64), - "eps": 0, - }, - { - "name": "i16_64x64_64x64_64x64_64x64", - "dtype": np.int16, - "dst_tile": (64, 64), - "src0_tile": (64, 64), - "src1_tile": (64, 64), - "valid_shape": (64, 64), - "eps": 0, - }, - { - "name": "f16_16x256_16x256_16x256_16x256", - "dtype": np.float16, - "dst_tile": (16, 256), - "src0_tile": (16, 256), - "src1_tile": (16, 256), - "valid_shape": (16, 256), - "eps": 1e-3, - }, - { - "name": "half_16x64_16x128_16x128_16x64", - "dtype": np.float16, - "dst_tile": (16, 64), - "src0_tile": (16, 128), - "src1_tile": (16, 128), - "valid_shape": (16, 64), - "eps": 1e-3, - }, ] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py index abddab2c7..6a4d5d1aa 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). @@ -27,12 +26,11 @@ def main(): continue case_dir = case["name"] - dst_tile = case["dst_tile"] - valid_shape = case["valid_shape"] - vr, vc = valid_shape + shape = case["shape"] + vr, vc = case["valid_shape"] - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_tile) - output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_tile) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) if ok: diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py index 169a94070..986dba17d 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). @@ -15,23 +14,19 @@ validate_cases(CASES) -np.random.seed(19) - for case in CASES: setup_case_rng(case) dtype = case["dtype"] - dst_tile = case["dst_tile"] - src0_tile = case["src0_tile"] - src1_tile = case["src1_tile"] + shape = case["shape"] valid_shape = case["valid_shape"] - input1 = np.random.randint(1, 10, size=src0_tile).astype(dtype) - input2 = np.random.randint(1, 10, size=src1_tile).astype(dtype) + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) - golden = np.zeros(dst_tile, dtype=dtype) + golden = np.zeros(shape, dtype=dtype) vr, vc = valid_shape golden[:vr, :vc] = (input1[:vr, :vc] + input2[:vr, :vc]).astype(dtype, copy=False) save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) - print(f"[INFO] gen_data: {case['name']} dst={dst_tile} src0={src0_tile} src1={src1_tile} valid={valid_shape} dtype={dtype.__name__}") + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp index 6dfb594d8..f1074c838 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -12,58 +12,16 @@ #define AICORE [aicore] #endif -// Case 0: f32 64x128_64x128_64x128_64x128 -extern "C" __global__ AICORE void TADD_f32_64x128_64x128_64x128_64x128(__gm__ float *a, __gm__ float *b, __gm__ float *c); +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TADD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTADD_f32_64x128_64x128_64x128_64x128(void *a, void *b, void *c, void *stream) { - TADD_f32_64x128_64x128_64x128_64x128<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream) { + TADD_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } -// Case 1: f32 16x64_16x64_16x64_16x64 -extern "C" __global__ AICORE void TADD_f32_16x64_16x64_16x64_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TADD_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTADD_f32_16x64_16x64_16x64_16x64(void *a, void *b, void *c, void *stream) { - TADD_f32_16x64_16x64_16x64_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream) { + TADD_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } - -// Case 2: f32 32x32_32x32_32x32_32x32 -extern "C" __global__ AICORE void TADD_f32_32x32_32x32_32x32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTADD_f32_32x32_32x32_32x32_32x32(void *a, void *b, void *c, void *stream) { - TADD_f32_32x32_32x32_32x32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 3: f32 64x64_64x64_64x64_64x64 -extern "C" __global__ AICORE void TADD_f32_64x64_64x64_64x64_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTADD_f32_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream) { - TADD_f32_64x64_64x64_64x64_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 4: i32 64x64_64x64_64x64_64x64 -extern "C" __global__ AICORE void TADD_i32_64x64_64x64_64x64_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTADD_i32_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream) { - TADD_i32_64x64_64x64_64x64_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 5: i16 64x64_64x64_64x64_64x64 -extern "C" __global__ AICORE void TADD_i16_64x64_64x64_64x64_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); - -void LaunchTADD_i16_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream) { - TADD_i16_64x64_64x64_64x64_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); -} - -// Case 6: f16 16x256_16x256_16x256_16x256 -extern "C" __global__ AICORE void TADD_f16_16x256_16x256_16x256_16x256(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTADD_f16_16x256_16x256_16x256_16x256(void *a, void *b, void *c, void *stream) { - TADD_f16_16x256_16x256_16x256_16x256<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); -} - -// Case 7: half 16x64_16x128_16x128_16x64 -extern "C" __global__ AICORE void TADD_half_16x64_16x128_16x128_16x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTADD_half_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream) { - TADD_half_16x64_16x128_16x128_16x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); -} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp index 276bd83a7..1a010623f 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -22,89 +22,71 @@ using namespace PtoTestCommon; // Kernel launch wrappers (defined in launch.cpp) -void LaunchTADD_f32_64x128_64x128_64x128_64x128(void *a, void *b, void *c, void *stream); -void LaunchTADD_f32_16x64_16x64_16x64_16x64(void *a, void *b, void *c, void *stream); -void LaunchTADD_f32_32x32_32x32_32x32_32x32(void *a, void *b, void *c, void *stream); -void LaunchTADD_f32_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream); -void LaunchTADD_i32_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream); -void LaunchTADD_i16_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream); -void LaunchTADD_f16_16x256_16x256_16x256_16x256(void *a, void *b, void *c, void *stream); -void LaunchTADD_half_16x64_16x128_16x128_16x64(void *a, void *b, void *c, void *stream); +void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream); -using LaunchFn = void (*)(void *, void *, void *, void *); +using LaunchFn = void (*)(float *, float *, float *, void *); struct TestCase { const char *name; LaunchFn launch; - size_t src0Rows; - size_t src0Cols; - size_t src1Rows; - size_t src1Cols; - size_t dstRows; - size_t dstCols; - size_t validRows; - size_t validCols; - size_t elemSize; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element }; static const TestCase kCases[] = { - {"f32_64x128_64x128_64x128_64x128", LaunchTADD_f32_64x128_64x128_64x128_64x128, 64, 128, 64, 128, 64, 128, 64, 128, sizeof(float)}, - {"f32_16x64_16x64_16x64_16x64", LaunchTADD_f32_16x64_16x64_16x64_16x64, 16, 64, 16, 64, 16, 64, 16, 64, sizeof(float)}, - {"f32_32x32_32x32_32x32_32x32", LaunchTADD_f32_32x32_32x32_32x32_32x32, 32, 32, 32, 32, 32, 32, 32, 32, sizeof(float)}, - {"f32_64x64_64x64_64x64_64x64", LaunchTADD_f32_64x64_64x64_64x64_64x64, 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, - {"i32_64x64_64x64_64x64_64x64", LaunchTADD_i32_64x64_64x64_64x64_64x64, 64, 64, 64, 64, 64, 64, 64, 64, sizeof(int32_t)}, - {"i16_64x64_64x64_64x64_64x64", LaunchTADD_i16_64x64_64x64_64x64_64x64, 64, 64, 64, 64, 64, 64, 64, 64, sizeof(int16_t)}, - {"f16_16x256_16x256_16x256_16x256", LaunchTADD_f16_16x256_16x256_16x256_16x256, 16, 256, 16, 256, 16, 256, 16, 256, sizeof(uint16_t)}, - {"half_16x64_16x128_16x128_16x64", LaunchTADD_half_16x64_16x128_16x128_16x64, 16, 128, 16, 128, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f32_16x64", LaunchTADD_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTADD_f32_32x32, 32, 32, 32, 32, sizeof(float)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { - (void)deviceId; int rc = 0; - const size_t src0Size = tc.src0Rows * tc.src0Cols * tc.elemSize; - const size_t src1Size = tc.src1Rows * tc.src1Cols * tc.elemSize; - const size_t dstSize = tc.dstRows * tc.dstCols * tc.elemSize; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; - std::printf("[INFO] === case: %s (dst=%zux%zu, src0=%zux%zu, src1=%zux%zu, valid=%zux%zu) ===\n", - tc.name, tc.dstRows, tc.dstCols, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.validRows, tc.validCols); + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); // Per-case data directory std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - aclrtMallocHost(&src0Host, src0Size); - aclrtMallocHost(&src1Host, src1Size); - aclrtMallocHost(&dstHost, dstSize); + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); - aclrtMalloc(&src0Device, src0Size, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, src1Size, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - size_t fileSize = 0; - if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, src0Size)) { + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); rc = 1; } - fileSize = 0; - if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, src1Size)) { + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); rc = 1; } if (rc == 0) { - aclrtMemcpy(src0Device, src0Size, src0Host, src0Size, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(src1Device, src1Size, src1Host, src1Size, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); tc.launch(src0Device, src1Device, dstDevice, stream); aclrtSynchronizeStream(stream); - aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); } - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstSize)) { + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); rc = 1; } @@ -160,4 +142,4 @@ int main(int argc, char *argv[]) { aclFinalize(); return rc; -} \ No newline at end of file +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto index 52c8cfd05..340e416c3 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -12,72 +12,8 @@ // to produce LLVM IR. module { - // Case 0: f32 64x128_64x128_64x128_64x128 (8192 elements) - func.func @TADD_f32_64x128_64x128_64x128_64x128(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c8192 = arith.constant 8192 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c128], - strides = [%c8192, %c8192, %c8192, %c128, %c1] - : !pto.tensor_view<1x1x1x64x128xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c128], - strides = [%c8192, %c8192, %c8192, %c128, %c1] - : !pto.tensor_view<1x1x1x64x128xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c128], - strides = [%c8192, %c8192, %c8192, %c128, %c1] - : !pto.tensor_view<1x1x1x64x128xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c128] - : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c128] - : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c128] - : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) - outs(%b : !pto.tile_buf) - - pto.tadd ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) - return - } - - // Case 1: f32 16x64_16x64_16x64_16x64 (1024 elements) - func.func @TADD_f32_16x64_16x64_16x64_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // Case 0: f32 16x64 (1024 elements) + func.func @TADD_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index @@ -140,8 +76,8 @@ module { return } - // Case 2: f32 32x32_32x32_32x32_32x32 (1024 elements) - func.func @TADD_f32_32x32_32x32_32x32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // Case 1: f32 32x32 (1024 elements) + func.func @TADD_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index @@ -202,323 +138,4 @@ module { outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) return } - - // Case 3: f32 64x64_64x64_64x64_64x64 (4096 elements) - func.func @TADD_f32_64x64_64x64_64x64_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%b : !pto.tile_buf) - - pto.tadd ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - return - } - - // Case 4: i32 64x64_64x64_64x64_64x64 (4096 elements) - func.func @TADD_i32_64x64_64x64_64x64_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - outs(%b : !pto.tile_buf) - - pto.tadd ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - return - } - - // Case 5: i16 64x64_64x64_64x64_64x64 (4096 elements) - func.func @TADD_i16_64x64_64x64_64x64_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%b : !pto.tile_buf) - - pto.tadd ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - return - } - - // Case 6: f16 16x256_16x256_16x256_16x256 (4096 elements) - func.func @TADD_f16_16x256_16x256_16x256_16x256(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c256 = arith.constant 256 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c256], - strides = [%c4096, %c4096, %c4096, %c256, %c1] - : !pto.tensor_view<1x1x1x16x256xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c256], - strides = [%c4096, %c4096, %c4096, %c256, %c1] - : !pto.tensor_view<1x1x1x16x256xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c256], - strides = [%c4096, %c4096, %c4096, %c256, %c1] - : !pto.tensor_view<1x1x1x16x256xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c256] - : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c256] - : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c256] - : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) - outs(%b : !pto.tile_buf) - - pto.tadd ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) - return - } - - // Case 7: half 16x64_16x128_16x128_16x64 (src=16x128, dst=16x64, valid=16x64) - func.func @TADD_half_16x64_16x128_16x128_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c2048 = arith.constant 2048 : index - %c1024 = arith.constant 1024 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c128], - strides = [%c2048, %c2048, %c2048, %c128, %c1] - : !pto.tensor_view<1x1x1x16x128xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c64], - strides = [%c1024, %c1024, %c1024, %c64, %c1] - : !pto.tensor_view<1x1x1x16x64xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c64] - : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) - outs(%b : !pto.tile_buf) - - pto.tadd ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) - return - } -} \ No newline at end of file +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py index fff678a85..8c40489b1 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -12,28 +11,30 @@ """Single source of truth for tand ST test cases. Each case defines: - - name: case identifier - - dtype: numpy dtype (np.int16, np.int8) - - shape: (rows, cols) — allocated tile dimensions - - valid_shape: (valid_rows, valid_cols) — effective computation region - - eps: tolerance for numpy.allclose (atol and rtol) + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. """ import numpy as np CASES = [ { - "name": "i16_64x64", - "dtype": np.int16, - "shape": (64, 64), - "valid_shape": (64, 64), + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), "eps": 0, }, { - "name": "i8_64x64_valid63x63", - "dtype": np.int8, - "shape": (64, 64), - "valid_shape": (63, 63), + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), "eps": 0, }, ] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py index b975718e6..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py index de8dab931..64829832e 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). @@ -15,8 +14,6 @@ validate_cases(CASES) -np.random.seed(19) - for case in CASES: setup_case_rng(case) @@ -24,8 +21,8 @@ shape = case["shape"] valid_shape = case["valid_shape"] - input1 = np.random.randint(1, 16383, size=shape).astype(dtype) - input2 = np.random.randint(1, 16383, size=shape).astype(dtype) + input1 = np.random.randint(0, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 100, size=shape).astype(dtype) golden = np.zeros(shape, dtype=dtype) vr, vc = valid_shape diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp index 31006ed8c..ed3149c6e 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -12,14 +12,16 @@ #define AICORE [aicore] #endif -extern "C" __global__ AICORE void TAND_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TAND_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); -void LaunchTAND_i16_64x64(void *a, void *b, void *c, void *stream) { - TAND_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +void LaunchTAND_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TAND_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); } -extern "C" __global__ AICORE void TAND_i8_64x64_valid63x63(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int8_t *c); +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TAND_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); -void LaunchTAND_i8_64x64_valid63x63(void *a, void *b, void *c, void *stream) { - TAND_i8_64x64_valid63x63<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int8_t *)c); +void LaunchTAND_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TAND_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp index af4b9cd56..21b90e9b3 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp @@ -7,6 +7,8 @@ // See LICENSE in the root of the software repository for the full text of the License. // Host driver for TileLang tand ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. #include "acl/acl.h" #include "test_common.h" @@ -19,54 +21,57 @@ using namespace PtoTestCommon; -void LaunchTAND_i16_64x64(void *a, void *b, void *c, void *stream); -void LaunchTAND_i8_64x64_valid63x63(void *a, void *b, void *c, void *stream); +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTAND_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTAND_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); -using LaunchFn = void (*)(void *, void *, void *, void *); +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); struct TestCase { const char *name; LaunchFn launch; - size_t rows; - size_t cols; - size_t validRows; - size_t validCols; - size_t elemSize; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element }; static const TestCase kCases[] = { - {"i16_64x64", LaunchTAND_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, - {"i8_64x64_valid63x63", LaunchTAND_i8_64x64_valid63x63, 64, 64, 63, 63, sizeof(int8_t)}, + {"i32_16x64", LaunchTAND_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTAND_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { int rc = 0; - const size_t fileSize = tc.rows * tc.cols * tc.elemSize; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + // Per-case data directory std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - aclrtMallocHost(&src0Host, fileSize); - aclrtMallocHost(&src1Host, fileSize); - aclrtMallocHost(&dstHost, fileSize); + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); - aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - size_t fileSizeRead = 0; - if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSizeRead, src0Host, fileSize)) { + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); rc = 1; } - fileSizeRead = 0; - if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSizeRead, src1Host, fileSize)) { + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); rc = 1; } @@ -105,6 +110,7 @@ static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { } int main(int argc, char *argv[]) { + // Optional case filter: ./tand [case_name] const char *caseFilter = (argc > 1) ? argv[1] : nullptr; int rc = 0; diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto b/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto index 6423bc5a1..4b0a2ea84 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto @@ -1,163 +1,141 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except the compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. // TileLang ST kernels for pto.tand: tload(a) + tload(b) + tand(a,b)->c + tstore(c). -// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/tand -// Cases cover different dtypes and shapes. +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. module { - // NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug - // See BUG_REPORT_UNSIGNED_BITOPS.md for details - - // Case 0: ui16, 64x64, valid 64x64 - DISABLED - // func.func @TAND_ui16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - // ... (disabled due to unsigned bitops bug) - // } - - // Case 1: ui16, 64x64, valid 63x63 - DISABLED - // func.func @TAND_ui16_64x64_valid63x63(...) - - // Case 2: ui16, 1x16384, valid 1x16384 - DISABLED - // func.func @TAND_ui16_1x16384(...) - - // Case 3: ui16, 2048x16, valid 2048x16 - DISABLED - // func.func @TAND_ui16_2048x16(...) - - // Case 0 (reindexed): i16, 64x64, valid 64x64 - func.func @TAND_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // Case 0: i32 16x64 (1024 elements) + func.func @TAND_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index + %c1024 = arith.constant 1024 : index %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> %a_part = pto.partition_view %a_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> %b_part = pto.partition_view %b_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> %c_part = pto.partition_view %c_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> %a = pto.alloc_tile - : !pto.tile_buf %b = pto.alloc_tile - : !pto.tile_buf %c = pto.alloc_tile - : !pto.tile_buf - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%a : !pto.tile_buf) + outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%b : !pto.tile_buf) + outs(%b : !pto.tile_buf) - pto.tand ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) return } - // Case 5: ui16, 64x64, valid 64x64 (half mode) - DISABLED - // func.func @TAND_ui16_64x64_half(...) - - // Case 6: ui8, 64x64, valid 63x63 - DISABLED - // func.func @TAND_ui8_64x64_valid63x63(...) - - // Case 1 (reindexed): i8, 64x64, valid 63x63 - func.func @TAND_i8_64x64_valid63x63(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // Case 1: i32 32x32 (1024 elements) + func.func @TAND_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c63 = arith.constant 63 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi8> + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi8> + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi8> + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> %a_part = pto.partition_view %a_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c63, %c63] - : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x63x63xi8> + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> %b_part = pto.partition_view %b_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c63, %c63] - : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x63x63xi8> + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> %c_part = pto.partition_view %c_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c63, %c63] - : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x63x63xi8> + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> %a = pto.alloc_tile - : !pto.tile_buf %b = pto.alloc_tile - : !pto.tile_buf %c = pto.alloc_tile - : !pto.tile_buf - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x63x63xi8>) - outs(%a : !pto.tile_buf) + outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x63x63xi8>) - outs(%b : !pto.tile_buf) + outs(%b : !pto.tile_buf) - pto.tand ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x63x63xi8>) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) return } } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py index 68469d49d..61989f6eb 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -18,8 +17,6 @@ - valid_shape: (valid_rows, valid_cols) — effective computation region. - eps: tolerance for numpy.allclose (atol and rtol). -Note: tdiv only supports float32 and float16. - gen_data.py and compare.py both import this list to avoid redundant definitions. """ @@ -40,32 +37,4 @@ "valid_shape": (32, 32), "eps": 1e-6, }, - { - "name": "f32_64x64", - "dtype": np.float32, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 1e-6, - }, - { - "name": "f16_64x64", - "dtype": np.float16, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 1e-3, - }, - { - "name": "f16_64x64_v61x61", - "dtype": np.float16, - "shape": (64, 64), - "valid_shape": (61, 61), - "eps": 1e-3, - }, - { - "name": "f32_64x32_v60x30", - "dtype": np.float32, - "shape": (64, 32), - "valid_shape": (60, 30), - "eps": 1e-6, - }, ] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py index b975718e6..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py index 27039e842..8f78dd4cf 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp index ee4168bb1..5b677443a 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -15,41 +15,13 @@ // Case 0: f32 16x64 extern "C" __global__ AICORE void TDIV_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTDIV_f32_16x64(void *a, void *b, void *c, void *stream) { +void LaunchTDIV_f32_16x64(float *a, float *b, float *c, void *stream) { TDIV_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } // Case 1: f32 32x32 extern "C" __global__ AICORE void TDIV_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTDIV_f32_32x32(void *a, void *b, void *c, void *stream) { +void LaunchTDIV_f32_32x32(float *a, float *b, float *c, void *stream) { TDIV_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 2: f32 64x64 -extern "C" __global__ AICORE void TDIV_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTDIV_f32_64x64(void *a, void *b, void *c, void *stream) { - TDIV_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 3: f16 64x64 -extern "C" __global__ AICORE void TDIV_f16_64x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTDIV_f16_64x64(void *a, void *b, void *c, void *stream) { - TDIV_f16_64x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); -} - -// Case 4: f16 64x64 v61x61 -extern "C" __global__ AICORE void TDIV_f16_64x64_v61x61(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTDIV_f16_64x64_v61x61(void *a, void *b, void *c, void *stream) { - TDIV_f16_64x64_v61x61<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); -} - -// Case 5: f32 64x32 v60x30 -extern "C" __global__ AICORE void TDIV_f32_64x32_v60x30(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTDIV_f32_64x32_v60x30(void *a, void *b, void *c, void *stream) { - TDIV_f32_64x32_v60x30<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp index f4e4e56f4..a999ddd11 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp @@ -1,14 +1,14 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. // Host driver for TileLang tdiv ST — case-table driven. // Each case launches a different kernel variant, reads/writes from per-case subdirectory. -// Numerical comparison is done externally by compare.py. +// Numerical comparison is done externally by compare.cpp. #include "acl/acl.h" #include "test_common.h" @@ -22,14 +22,10 @@ using namespace PtoTestCommon; // Kernel launch wrappers (defined in launch.cpp) -void LaunchTDIV_f32_16x64(void *a, void *b, void *c, void *stream); -void LaunchTDIV_f32_32x32(void *a, void *b, void *c, void *stream); -void LaunchTDIV_f32_64x64(void *a, void *b, void *c, void *stream); -void LaunchTDIV_f16_64x64(void *a, void *b, void *c, void *stream); -void LaunchTDIV_f16_64x64_v61x61(void *a, void *b, void *c, void *stream); -void LaunchTDIV_f32_64x32_v60x30(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f32_32x32(float *a, float *b, float *c, void *stream); -using LaunchFn = void (*)(void *, void *, void *, void *); +using LaunchFn = void (*)(float *, float *, float *, void *); struct TestCase { const char *name; @@ -44,15 +40,10 @@ struct TestCase { static const TestCase kCases[] = { {"f32_16x64", LaunchTDIV_f32_16x64, 16, 64, 16, 64, sizeof(float)}, {"f32_32x32", LaunchTDIV_f32_32x32, 32, 32, 32, 32, sizeof(float)}, - {"f32_64x64", LaunchTDIV_f32_64x64, 64, 64, 64, 64, sizeof(float)}, - {"f16_64x64", LaunchTDIV_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, - {"f16_64x64_v61x61", LaunchTDIV_f16_64x64_v61x61, 64, 64, 61, 61, sizeof(uint16_t)}, - {"f32_64x32_v60x30", LaunchTDIV_f32_64x32_v60x30, 64, 32, 60, 30, sizeof(float)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { - (void)deviceId; int rc = 0; const size_t elemCount = tc.rows * tc.cols; const size_t fileSize = elemCount * tc.elemSize; @@ -65,16 +56,16 @@ static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { size_t src0FileSize = fileSize; size_t src1FileSize = fileSize; - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - aclrtMallocHost(&src0Host, fileSize); - aclrtMallocHost(&src1Host, fileSize); - aclrtMallocHost(&dstHost, fileSize); + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); - aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto index e3f6a40a7..4c77a53fc 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto @@ -1,19 +1,18 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. // TileLang ST kernels for pto.tdiv: tload(a) + tload(b) + tdiv(a,b)->c + tstore(c). // Multiple cases with different shapes in a single module. -// Note: tdiv only supports float types (f32, f16). // Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm // to produce LLVM IR. module { - // Case 0: f32 16x64 + // Case 0: f32 16x64 (1024 elements) func.func @TDIV_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -77,7 +76,7 @@ module { return } - // Case 1: f32 32x32 + // Case 1: f32 32x32 (1024 elements) func.func @TDIV_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -139,260 +138,4 @@ module { outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) return } - - // Case 2: f32 64x64 - func.func @TDIV_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%b : !pto.tile_buf) - - pto.tdiv ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - return - } - - // Case 3: f16 64x64 - func.func @TDIV_f16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - outs(%b : !pto.tile_buf) - - pto.tdiv ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - return - } - - // Case 4: f16 64x64 v61x61 (valid != tile) - func.func @TDIV_f16_64x64_v61x61(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c61 = arith.constant 61 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c61, %c61], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x61x61xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c61, %c61], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x61x61xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c61, %c61], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x61x61xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c61, %c61] - : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c61, %c61] - : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c61, %c61] - : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) - outs(%b : !pto.tile_buf) - - pto.tdiv ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) - return - } - - // Case 5: f32 64x32 v60x30 (valid != tile) - func.func @TDIV_f32_64x32_v60x30(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c30 = arith.constant 30 : index - %c32 = arith.constant 32 : index - %c60 = arith.constant 60 : index - %c64 = arith.constant 64 : index - %c2048 = arith.constant 2048 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c60, %c30], - strides = [%c2048, %c2048, %c2048, %c32, %c1] - : !pto.tensor_view<1x1x1x60x30xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c60, %c30], - strides = [%c2048, %c2048, %c2048, %c32, %c1] - : !pto.tensor_view<1x1x1x60x30xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c60, %c30], - strides = [%c2048, %c2048, %c2048, %c32, %c1] - : !pto.tensor_view<1x1x1x60x30xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c30] - : !pto.tensor_view<1x1x1x60x30xf32> -> !pto.partition_tensor_view<1x1x1x60x30xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c30] - : !pto.tensor_view<1x1x1x60x30xf32> -> !pto.partition_tensor_view<1x1x1x60x30xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c30] - : !pto.tensor_view<1x1x1x60x30xf32> -> !pto.partition_tensor_view<1x1x1x60x30xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x30xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x30xf32>) - outs(%b : !pto.tile_buf) - - pto.tdiv ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x60x30xf32>) - return - } } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py index fc535804b..69ba77ac4 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -38,60 +37,4 @@ "valid_shape": (32, 32), "eps": 1e-6, }, - { - "name": "f32_64x64", - "dtype": np.float32, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 1e-6, - }, - { - "name": "i32_64x64", - "dtype": np.int32, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 0, - }, - { - "name": "i16_64x64", - "dtype": np.int16, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 0, - }, - { - "name": "f16_64x64", - "dtype": np.float16, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 1e-3, - }, - { - "name": "f32_64x64_v60x60", - "dtype": np.float32, - "shape": (64, 64), - "valid_shape": (60, 60), - "eps": 1e-6, - }, - { - "name": "i32_64x64_v60x60", - "dtype": np.int32, - "shape": (64, 64), - "valid_shape": (60, 60), - "eps": 0, - }, - { - "name": "f16_2x4096_v1x3600", - "dtype": np.float16, - "shape": (2, 4096), - "valid_shape": (1, 3600), - "eps": 1e-3, - }, - { - "name": "i16_20x512_v16x200", - "dtype": np.int16, - "shape": (20, 512), - "valid_shape": (16, 200), - "eps": 0, - }, ] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py index b975718e6..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py index 7d44773bb..0d1487e44 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp index 5b3266627..3d47d685c 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -15,69 +15,13 @@ // Case 0: f32 16x64 extern "C" __global__ AICORE void TMAX_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTMAX_f32_16x64(void *a, void *b, void *c, void *stream) { +void LaunchTMAX_f32_16x64(float *a, float *b, float *c, void *stream) { TMAX_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } // Case 1: f32 32x32 extern "C" __global__ AICORE void TMAX_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTMAX_f32_32x32(void *a, void *b, void *c, void *stream) { +void LaunchTMAX_f32_32x32(float *a, float *b, float *c, void *stream) { TMAX_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 2: f32 64x64 -extern "C" __global__ AICORE void TMAX_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTMAX_f32_64x64(void *a, void *b, void *c, void *stream) { - TMAX_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 3: i32 64x64 -extern "C" __global__ AICORE void TMAX_i32_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTMAX_i32_64x64(void *a, void *b, void *c, void *stream) { - TMAX_i32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 4: i16 64x64 -extern "C" __global__ AICORE void TMAX_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); - -void LaunchTMAX_i16_64x64(void *a, void *b, void *c, void *stream) { - TMAX_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); -} - -// Case 5: f16 64x64 -extern "C" __global__ AICORE void TMAX_f16_64x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTMAX_f16_64x64(void *a, void *b, void *c, void *stream) { - TMAX_f16_64x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); -} - -// Case 6: f32 64x64 v60x60 -extern "C" __global__ AICORE void TMAX_f32_64x64_v60x60(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTMAX_f32_64x64_v60x60(void *a, void *b, void *c, void *stream) { - TMAX_f32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 7: i32 64x64 v60x60 -extern "C" __global__ AICORE void TMAX_i32_64x64_v60x60(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTMAX_i32_64x64_v60x60(void *a, void *b, void *c, void *stream) { - TMAX_i32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 8: f16 2x4096 v1x3600 -extern "C" __global__ AICORE void TMAX_f16_2x4096_v1x3600(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTMAX_f16_2x4096_v1x3600(void *a, void *b, void *c, void *stream) { - TMAX_f16_2x4096_v1x3600<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); -} - -// Case 9: i16 20x512 v16x200 -extern "C" __global__ AICORE void TMAX_i16_20x512_v16x200(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); - -void LaunchTMAX_i16_20x512_v16x200(void *a, void *b, void *c, void *stream) { - TMAX_i16_20x512_v16x200<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp index 0f00a8513..3dd9859a5 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -22,18 +22,10 @@ using namespace PtoTestCommon; // Kernel launch wrappers (defined in launch.cpp) -void LaunchTMAX_f32_16x64(void *a, void *b, void *c, void *stream); -void LaunchTMAX_f32_32x32(void *a, void *b, void *c, void *stream); -void LaunchTMAX_f32_64x64(void *a, void *b, void *c, void *stream); -void LaunchTMAX_i32_64x64(void *a, void *b, void *c, void *stream); -void LaunchTMAX_i16_64x64(void *a, void *b, void *c, void *stream); -void LaunchTMAX_f16_64x64(void *a, void *b, void *c, void *stream); -void LaunchTMAX_f32_64x64_v60x60(void *a, void *b, void *c, void *stream); -void LaunchTMAX_i32_64x64_v60x60(void *a, void *b, void *c, void *stream); -void LaunchTMAX_f16_2x4096_v1x3600(void *a, void *b, void *c, void *stream); -void LaunchTMAX_i16_20x512_v16x200(void *a, void *b, void *c, void *stream); - -using LaunchFn = void (*)(void *, void *, void *, void *); +void LaunchTMAX_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTMAX_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); struct TestCase { const char *name; @@ -48,19 +40,10 @@ struct TestCase { static const TestCase kCases[] = { {"f32_16x64", LaunchTMAX_f32_16x64, 16, 64, 16, 64, sizeof(float)}, {"f32_32x32", LaunchTMAX_f32_32x32, 32, 32, 32, 32, sizeof(float)}, - {"f32_64x64", LaunchTMAX_f32_64x64, 64, 64, 64, 64, sizeof(float)}, - {"i32_64x64", LaunchTMAX_i32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, - {"i16_64x64", LaunchTMAX_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, - {"f16_64x64", LaunchTMAX_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, - {"f32_64x64_v60x60", LaunchTMAX_f32_64x64_v60x60, 64, 64, 60, 60, sizeof(float)}, - {"i32_64x64_v60x60", LaunchTMAX_i32_64x64_v60x60, 64, 64, 60, 60, sizeof(int32_t)}, - {"f16_2x4096_v1x3600", LaunchTMAX_f16_2x4096_v1x3600, 2, 4096, 1, 3600, sizeof(uint16_t)}, - {"i16_20x512_v16x200", LaunchTMAX_i16_20x512_v16x200, 20, 512, 16, 200, sizeof(int16_t)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { - (void)deviceId; int rc = 0; const size_t elemCount = tc.rows * tc.cols; const size_t fileSize = elemCount * tc.elemSize; @@ -73,16 +56,16 @@ static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { size_t src0FileSize = fileSize; size_t src1FileSize = fileSize; - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - aclrtMallocHost(&src0Host, fileSize); - aclrtMallocHost(&src1Host, fileSize); - aclrtMallocHost(&dstHost, fileSize); + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); - aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto index 77172ec2a..d10ed0e73 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto @@ -138,512 +138,4 @@ module { outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) return } - - // Case 2: f32 64x64 (4096 elements) - func.func @TMAX_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%b : !pto.tile_buf) - - pto.tmax ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - return - } - - // Case 3: i32 64x64 (4096 elements) - func.func @TMAX_i32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - outs(%b : !pto.tile_buf) - - pto.tmax ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - return - } - - // Case 4: i16 64x64 (4096 elements) - func.func @TMAX_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%b : !pto.tile_buf) - - pto.tmax ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - return - } - - // Case 5: f16 64x64 (4096 elements) - func.func @TMAX_f16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - outs(%b : !pto.tile_buf) - - pto.tmax ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - return - } - - // Case 6: f32 64x64 tile with 60x60 valid region (padding with MIN for tmax) - func.func @TMAX_f32_64x64_v60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c60 = arith.constant 60 : index - %c64 = arith.constant 64 : index - %c3600 = arith.constant 3600 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c60, %c60], - strides = [%c3600, %c3600, %c3600, %c64, %c1] - : !pto.tensor_view<1x1x1x60x60xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c60, %c60], - strides = [%c3600, %c3600, %c3600, %c64, %c1] - : !pto.tensor_view<1x1x1x60x60xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c60, %c60], - strides = [%c3600, %c3600, %c3600, %c64, %c1] - : !pto.tensor_view<1x1x1x60x60xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c60] - : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c60] - : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c60] - : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) - outs(%b : !pto.tile_buf) - - pto.tmax ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) - return - } - - // Case 7: i32 64x64 tile with 60x60 valid region (padding with MIN for tmax) - func.func @TMAX_i32_64x64_v60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c60 = arith.constant 60 : index - %c64 = arith.constant 64 : index - %c3600 = arith.constant 3600 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c60, %c60], - strides = [%c3600, %c3600, %c3600, %c64, %c1] - : !pto.tensor_view<1x1x1x60x60xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c60, %c60], - strides = [%c3600, %c3600, %c3600, %c64, %c1] - : !pto.tensor_view<1x1x1x60x60xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c60, %c60], - strides = [%c3600, %c3600, %c3600, %c64, %c1] - : !pto.tensor_view<1x1x1x60x60xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c60] - : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c60] - : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c60] - : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) - outs(%b : !pto.tile_buf) - - pto.tmax ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) - return - } - - // Case 8: f16 2x4096 tile with 1x3600 valid region (padding with MIN for tmax) - func.func @TMAX_f16_2x4096_v1x3600(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c3600 = arith.constant 3600 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c1, %c3600], - strides = [%c3600, %c3600, %c3600, %c4096, %c1] - : !pto.tensor_view<1x1x1x1x3600xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c1, %c3600], - strides = [%c3600, %c3600, %c3600, %c4096, %c1] - : !pto.tensor_view<1x1x1x1x3600xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c1, %c3600], - strides = [%c3600, %c3600, %c3600, %c4096, %c1] - : !pto.tensor_view<1x1x1x1x3600xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c3600] - : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c3600] - : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c1, %c3600] - : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) - outs(%b : !pto.tile_buf) - - pto.tmax ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) - return - } - - // Case 9: i16 20x512 tile with 16x200 valid region (padding with MIN for tmax) - func.func @TMAX_i16_20x512_v16x200(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c200 = arith.constant 200 : index - %c512 = arith.constant 512 : index - %c3200 = arith.constant 3200 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c200], - strides = [%c3200, %c3200, %c3200, %c512, %c1] - : !pto.tensor_view<1x1x1x16x200xi16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c200], - strides = [%c3200, %c3200, %c3200, %c512, %c1] - : !pto.tensor_view<1x1x1x16x200xi16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c200], - strides = [%c3200, %c3200, %c3200, %c512, %c1] - : !pto.tensor_view<1x1x1x16x200xi16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c200] - : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c200] - : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c200] - : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) - outs(%b : !pto.tile_buf) - - pto.tmax ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) - return - } } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py index 4a075e55f..15bbb58ea 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py index e9cb62020..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py index 0722522c3..0c72ecbc9 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py index 9a978c0f3..2d3a70ce8 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -38,53 +37,4 @@ "valid_shape": (32, 32), "eps": 1e-6, }, - { - "name": "f32_64x64", - "dtype": np.float32, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 1e-6, - }, - { - "name": "i32_64x64", - "dtype": np.int32, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 0, - }, - { - "name": "i32_32x32", - "dtype": np.int32, - "shape": (32, 32), - "valid_shape": (32, 32), - "eps": 0, - }, - { - "name": "i16_64x64", - "dtype": np.int16, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 0, - }, - { - "name": "f16_16x16", - "dtype": np.float16, - "shape": (16, 16), - "valid_shape": (16, 16), - "eps": 1e-3, - }, - { - "name": "f16_64x64_v61x61", - "dtype": np.float16, - "shape": (64, 64), - "valid_shape": (61, 61), - "eps": 1e-3, - }, - { - "name": "i32_64x32_v60x30", - "dtype": np.int32, - "shape": (64, 32), - "valid_shape": (60, 30), - "eps": 0, - }, ] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py index b975718e6..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py index 1d8eddf3f..0cf58f73b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp index 4c08a0caf..1debfe140 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -15,62 +15,13 @@ // Case 0: f32 16x64 extern "C" __global__ AICORE void TMUL_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTMUL_f32_16x64(void *a, void *b, void *c, void *stream) { +void LaunchTMUL_f32_16x64(float *a, float *b, float *c, void *stream) { TMUL_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } // Case 1: f32 32x32 extern "C" __global__ AICORE void TMUL_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTMUL_f32_32x32(void *a, void *b, void *c, void *stream) { +void LaunchTMUL_f32_32x32(float *a, float *b, float *c, void *stream) { TMUL_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 2: f32 64x64 -extern "C" __global__ AICORE void TMUL_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTMUL_f32_64x64(void *a, void *b, void *c, void *stream) { - TMUL_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 3: i32 64x64 -extern "C" __global__ AICORE void TMUL_i32_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTMUL_i32_64x64(void *a, void *b, void *c, void *stream) { - TMUL_i32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 4: i32 32x32 -extern "C" __global__ AICORE void TMUL_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTMUL_i32_32x32(void *a, void *b, void *c, void *stream) { - TMUL_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 5: i16 64x64 -extern "C" __global__ AICORE void TMUL_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); - -void LaunchTMUL_i16_64x64(void *a, void *b, void *c, void *stream) { - TMUL_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); -} - -// Case 6: f16 16x16 -extern "C" __global__ AICORE void TMUL_f16_16x16(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTMUL_f16_16x16(void *a, void *b, void *c, void *stream) { - TMUL_f16_16x16<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); -} - -// Case 7: f16 64x64 v61x61 -extern "C" __global__ AICORE void TMUL_f16_64x64_v61x61(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTMUL_f16_64x64_v61x61(void *a, void *b, void *c, void *stream) { - TMUL_f16_64x64_v61x61<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); -} - -// Case 8: i32 64x32 v60x30 -extern "C" __global__ AICORE void TMUL_i32_64x32_v60x30(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTMUL_i32_64x32_v60x30(void *a, void *b, void *c, void *stream) { - TMUL_i32_64x32_v60x30<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp index 334486fd3..6e294af40 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -22,17 +22,10 @@ using namespace PtoTestCommon; // Kernel launch wrappers (defined in launch.cpp) -void LaunchTMUL_f32_16x64(void *a, void *b, void *c, void *stream); -void LaunchTMUL_f32_32x32(void *a, void *b, void *c, void *stream); -void LaunchTMUL_f32_64x64(void *a, void *b, void *c, void *stream); -void LaunchTMUL_i32_64x64(void *a, void *b, void *c, void *stream); -void LaunchTMUL_i32_32x32(void *a, void *b, void *c, void *stream); -void LaunchTMUL_i16_64x64(void *a, void *b, void *c, void *stream); -void LaunchTMUL_f16_16x16(void *a, void *b, void *c, void *stream); -void LaunchTMUL_f16_64x64_v61x61(void *a, void *b, void *c, void *stream); -void LaunchTMUL_i32_64x32_v60x30(void *a, void *b, void *c, void *stream); - -using LaunchFn = void (*)(void *, void *, void *, void *); +void LaunchTMUL_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTMUL_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); struct TestCase { const char *name; @@ -47,18 +40,10 @@ struct TestCase { static const TestCase kCases[] = { {"f32_16x64", LaunchTMUL_f32_16x64, 16, 64, 16, 64, sizeof(float)}, {"f32_32x32", LaunchTMUL_f32_32x32, 32, 32, 32, 32, sizeof(float)}, - {"f32_64x64", LaunchTMUL_f32_64x64, 64, 64, 64, 64, sizeof(float)}, - {"i32_64x64", LaunchTMUL_i32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, - {"i32_32x32", LaunchTMUL_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, - {"i16_64x64", LaunchTMUL_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, - {"f16_16x16", LaunchTMUL_f16_16x16, 16, 16, 16, 16, sizeof(uint16_t)}, - {"f16_64x64_v61x61", LaunchTMUL_f16_64x64_v61x61, 64, 64, 61, 61, sizeof(uint16_t)}, - {"i32_64x32_v60x30", LaunchTMUL_i32_64x32_v60x30, 64, 32, 60, 30, sizeof(int32_t)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { - (void)deviceId; int rc = 0; const size_t elemCount = tc.rows * tc.cols; const size_t fileSize = elemCount * tc.elemSize; @@ -71,16 +56,16 @@ static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { size_t src0FileSize = fileSize; size_t src1FileSize = fileSize; - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - aclrtMallocHost(&src0Host, fileSize); - aclrtMallocHost(&src1Host, fileSize); - aclrtMallocHost(&dstHost, fileSize); + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); - aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto index c358c738c..9916f6979 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -12,7 +12,7 @@ // to produce LLVM IR. module { - // Case 0: f32 16x64 + // Case 0: f32 16x64 (1024 elements) func.func @TMUL_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -76,7 +76,7 @@ module { return } - // Case 1: f32 32x32 + // Case 1: f32 32x32 (1024 elements) func.func @TMUL_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -138,449 +138,4 @@ module { outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) return } - - // Case 2: f32 64x64 - func.func @TMUL_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%b : !pto.tile_buf) - - pto.tmul ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - return - } - - // Case 3: i32 64x64 - func.func @TMUL_i32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - outs(%b : !pto.tile_buf) - - pto.tmul ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - return - } - - // Case 4: i32 32x32 - func.func @TMUL_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c1024 = arith.constant 1024 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c32, %c32], - strides = [%c1024, %c1024, %c1024, %c32, %c1] - : !pto.tensor_view<1x1x1x32x32xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c32, %c32], - strides = [%c1024, %c1024, %c1024, %c32, %c1] - : !pto.tensor_view<1x1x1x32x32xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c32, %c32], - strides = [%c1024, %c1024, %c1024, %c32, %c1] - : !pto.tensor_view<1x1x1x32x32xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c32] - : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c32] - : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c32] - : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) - outs(%b : !pto.tile_buf) - - pto.tmul ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) - return - } - - // Case 5: i16 64x64 - func.func @TMUL_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%b : !pto.tile_buf) - - pto.tmul ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - return - } - - // Case 6: f16 16x16 - func.func @TMUL_f16_16x16(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c256 = arith.constant 256 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c16, %c16], - strides = [%c256, %c256, %c256, %c16, %c1] - : !pto.tensor_view<1x1x1x16x16xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c16, %c16], - strides = [%c256, %c256, %c256, %c16, %c1] - : !pto.tensor_view<1x1x1x16x16xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c16, %c16], - strides = [%c256, %c256, %c256, %c16, %c1] - : !pto.tensor_view<1x1x1x16x16xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c16] - : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c16] - : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c16, %c16] - : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) - outs(%b : !pto.tile_buf) - - pto.tmul ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) - return - } - - // Case 7: f16 64x64 v61x61 (valid != tile) - func.func @TMUL_f16_64x64_v61x61(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c61 = arith.constant 61 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c61, %c61], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x61x61xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c61, %c61], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x61x61xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c61, %c61], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x61x61xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c61, %c61] - : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c61, %c61] - : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c61, %c61] - : !pto.tensor_view<1x1x1x61x61xf16> -> !pto.partition_tensor_view<1x1x1x61x61xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) - outs(%b : !pto.tile_buf) - - pto.tmul ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x61x61xf16>) - return - } - - // Case 8: i32 64x32 v60x30 (valid != tile) - func.func @TMUL_i32_64x32_v60x30(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c30 = arith.constant 30 : index - %c32 = arith.constant 32 : index - %c60 = arith.constant 60 : index - %c64 = arith.constant 64 : index - %c2048 = arith.constant 2048 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c60, %c30], - strides = [%c2048, %c2048, %c2048, %c32, %c1] - : !pto.tensor_view<1x1x1x60x30xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c60, %c30], - strides = [%c2048, %c2048, %c2048, %c32, %c1] - : !pto.tensor_view<1x1x1x60x30xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c60, %c30], - strides = [%c2048, %c2048, %c2048, %c32, %c1] - : !pto.tensor_view<1x1x1x60x30xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c30] - : !pto.tensor_view<1x1x1x60x30xi32> -> !pto.partition_tensor_view<1x1x1x60x30xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c30] - : !pto.tensor_view<1x1x1x60x30xi32> -> !pto.partition_tensor_view<1x1x1x60x30xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c60, %c30] - : !pto.tensor_view<1x1x1x60x30xi32> -> !pto.partition_tensor_view<1x1x1x60x30xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x30xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x30xi32>) - outs(%b : !pto.tile_buf) - - pto.tmul ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x60x30xi32>) - return - } } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py index bbcd9631b..736a5ff8f 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -12,11 +11,13 @@ """Single source of truth for tor ST test cases. Each case defines: - - name: case identifier - - dtype: numpy dtype (np.int32) - - shape: (rows, cols) — allocated tile dimensions - - valid_shape: (valid_rows, valid_cols) — effective computation region - - eps: tolerance for numpy.allclose (atol and rtol) + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. """ import numpy as np diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py index b975718e6..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py index ada52fdfc..c822c0be3 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto b/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto index 4f7960133..25a9c7f57 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto @@ -10,9 +10,6 @@ // Multiple cases with different shapes in a single module. // Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm // to produce LLVM IR. -// -// NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug -// See BUG_REPORT_UNSIGNED_BITOPS.md for details module { // Case 0: i32 16x64 (1024 elements) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py index 4b312c5e0..4bc308400 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -12,11 +11,13 @@ """Single source of truth for tshl ST test cases. Each case defines: - - name: case identifier - - dtype: numpy dtype (np.int32) - - shape: (rows, cols) — allocated tile dimensions - - valid_shape: (valid_rows, valid_cols) — effective computation region - - eps: tolerance for numpy.allclose (atol and rtol) + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. """ import numpy as np diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py index b975718e6..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py index fdeae9fa0..811ffc995 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto index 4526b8b3f..78a8925cd 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto @@ -10,9 +10,6 @@ // Multiple cases with different shapes in a single module. // Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm // to produce LLVM IR. -// -// NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug -// See BUG_REPORT_UNSIGNED_BITOPS.md for details module { // Case 0: i32 16x64 (1024 elements) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py index 2ce53b6c5..36075525b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -12,11 +11,13 @@ """Single source of truth for tshr ST test cases. Each case defines: - - name: case identifier - - dtype: numpy dtype (np.int32) - - shape: (rows, cols) — allocated tile dimensions - - valid_shape: (valid_rows, valid_cols) — effective computation region - - eps: tolerance for numpy.allclose (atol and rtol) + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. """ import numpy as np diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py index b975718e6..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py index 152667627..5737627f7 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto index dda93d3bf..a90e1f88f 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto @@ -10,9 +10,6 @@ // Multiple cases with different shapes in a single module. // Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm // to produce LLVM IR. -// -// NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug -// See BUG_REPORT_UNSIGNED_BITOPS.md for details module { // Case 0: i32 16x64 (1024 elements) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py index 20ef25491..b71da2e9b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -38,32 +37,4 @@ "valid_shape": (32, 32), "eps": 1e-6, }, - { - "name": "f32_64x64", - "dtype": np.float32, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 1e-6, - }, - { - "name": "i32_64x64", - "dtype": np.int32, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 0, - }, - { - "name": "i16_64x64", - "dtype": np.int16, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 0, - }, - { - "name": "f16_64x64", - "dtype": np.float16, - "shape": (64, 64), - "valid_shape": (64, 64), - "eps": 1e-3, - }, ] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py index b975718e6..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py index 234dd6cfc..95cccfd2a 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp index 8c00680c0..256c0ed07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -15,41 +15,13 @@ // Case 0: f32 16x64 extern "C" __global__ AICORE void TSUB_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTSUB_f32_16x64(void *a, void *b, void *c, void *stream) { +void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream) { TSUB_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } // Case 1: f32 32x32 extern "C" __global__ AICORE void TSUB_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); -void LaunchTSUB_f32_32x32(void *a, void *b, void *c, void *stream) { +void LaunchTSUB_f32_32x32(float *a, float *b, float *c, void *stream) { TSUB_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 2: f32 64x64 -extern "C" __global__ AICORE void TSUB_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); - -void LaunchTSUB_f32_64x64(void *a, void *b, void *c, void *stream) { - TSUB_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); -} - -// Case 3: i32 64x64 -extern "C" __global__ AICORE void TSUB_i32_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); - -void LaunchTSUB_i32_64x64(void *a, void *b, void *c, void *stream) { - TSUB_i32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); -} - -// Case 4: i16 64x64 -extern "C" __global__ AICORE void TSUB_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); - -void LaunchTSUB_i16_64x64(void *a, void *b, void *c, void *stream) { - TSUB_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); -} - -// Case 5: f16 64x64 -extern "C" __global__ AICORE void TSUB_f16_64x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); - -void LaunchTSUB_f16_64x64(void *a, void *b, void *c, void *stream) { - TSUB_f16_64x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp index 0827b91a3..b5e338d4b 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -22,14 +22,10 @@ using namespace PtoTestCommon; // Kernel launch wrappers (defined in launch.cpp) -void LaunchTSUB_f32_16x64(void *a, void *b, void *c, void *stream); -void LaunchTSUB_f32_32x32(void *a, void *b, void *c, void *stream); -void LaunchTSUB_f32_64x64(void *a, void *b, void *c, void *stream); -void LaunchTSUB_i32_64x64(void *a, void *b, void *c, void *stream); -void LaunchTSUB_i16_64x64(void *a, void *b, void *c, void *stream); -void LaunchTSUB_f16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTSUB_f32_32x32(float *a, float *b, float *c, void *stream); -using LaunchFn = void (*)(void *, void *, void *, void *); +using LaunchFn = void (*)(float *, float *, float *, void *); struct TestCase { const char *name; @@ -44,15 +40,10 @@ struct TestCase { static const TestCase kCases[] = { {"f32_16x64", LaunchTSUB_f32_16x64, 16, 64, 16, 64, sizeof(float)}, {"f32_32x32", LaunchTSUB_f32_32x32, 32, 32, 32, 32, sizeof(float)}, - {"f32_64x64", LaunchTSUB_f32_64x64, 64, 64, 64, 64, sizeof(float)}, - {"i32_64x64", LaunchTSUB_i32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, - {"i16_64x64", LaunchTSUB_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, - {"f16_64x64", LaunchTSUB_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { - (void)deviceId; int rc = 0; const size_t elemCount = tc.rows * tc.cols; const size_t fileSize = elemCount * tc.elemSize; @@ -65,16 +56,16 @@ static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { size_t src0FileSize = fileSize; size_t src1FileSize = fileSize; - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - aclrtMallocHost(&src0Host, fileSize); - aclrtMallocHost(&src1Host, fileSize); - aclrtMallocHost(&dstHost, fileSize); + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); - aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/output.ll b/test/tilelang_st/npu/a5/src/st/testcase/tsub/output.ll new file mode 100644 index 000000000..e69de29bb diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto index e64f6b9d9..43a3cea5e 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. @@ -12,7 +12,7 @@ // to produce LLVM IR. module { - // Case 0: f32 16x64 + // Case 0: f32 16x64 (1024 elements) func.func @TSUB_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -76,7 +76,7 @@ module { return } - // Case 1: f32 32x32 + // Case 1: f32 32x32 (1024 elements) func.func @TSUB_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -138,256 +138,4 @@ module { outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) return } - - // Case 2: f32 64x64 - func.func @TSUB_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - outs(%b : !pto.tile_buf) - - pto.tsub ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) - return - } - - // Case 3: i32 64x64 - func.func @TSUB_i32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi32> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - outs(%b : !pto.tile_buf) - - pto.tsub ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) - return - } - - // Case 4: i16 64x64 - func.func @TSUB_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%b : !pto.tile_buf) - - pto.tsub ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - return - } - - // Case 5: f16 64x64 - func.func @TSUB_f16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xf16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - outs(%b : !pto.tile_buf) - - pto.tsub ins(%a, %b : !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) - return - } } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py index 766ddf35a..c710ea612 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py @@ -1,8 +1,7 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. @@ -12,55 +11,30 @@ """Single source of truth for txor ST test cases. Each case defines: - - name: case identifier - - dtype: numpy dtype (np.int16, np.int8) - - dst_tile: (rows, cols) — dst tile buffer dimensions - - src0_tile: (rows, cols) — src0 tile buffer dimensions - - src1_tile: (rows, cols) — src1 tile buffer dimensions - - valid_shape: (valid_rows, valid_cols) — effective computation region - - eps: tolerance for numpy.allclose (atol and rtol) + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). -Note: src0/src1/dst tile buffer physical sizes can differ, - but valid_shape must be the same for all. +gen_data.py and compare.py both import this list to avoid redundant definitions. """ import numpy as np CASES = [ { - "name": "i16_64x64_64x64_64x64_64x64", - "dtype": np.int16, - "dst_tile": (64, 64), - "src0_tile": (64, 64), - "src1_tile": (64, 64), - "valid_shape": (64, 64), + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), "eps": 0, }, { - "name": "i16_32x128_32x128_32x256_32x128", - "dtype": np.int16, - "dst_tile": (32, 128), - "src0_tile": (32, 128), - "src1_tile": (32, 256), - "valid_shape": (32, 128), - "eps": 0, - }, - { - "name": "i16_32x128_32x128_32x256_32x127", - "dtype": np.int16, - "dst_tile": (32, 128), - "src0_tile": (32, 128), - "src1_tile": (32, 256), - "valid_shape": (32, 127), - "eps": 0, - }, - { - "name": "i8_32x128_32x128_32x256_32x127", - "dtype": np.int8, - "dst_tile": (32, 128), - "src0_tile": (32, 128), - "src1_tile": (32, 256), - "valid_shape": (32, 127), + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), "eps": 0, }, ] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py index 96ab2dda8..4eae3bc07 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). @@ -27,12 +26,11 @@ def main(): continue case_dir = case["name"] - dst_tile = case["dst_tile"] - valid_shape = case["valid_shape"] - vr, vc = valid_shape + shape = case["shape"] + vr, vc = case["valid_shape"] - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_tile) - output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_tile) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) if ok: diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py index 62e5cb32b..2d2fbe7b6 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). @@ -15,24 +14,19 @@ validate_cases(CASES) -np.random.seed(19) - for case in CASES: setup_case_rng(case) dtype = case["dtype"] - dst_tile = case["dst_tile"] - src0_tile = case["src0_tile"] - src1_tile = case["src1_tile"] + shape = case["shape"] valid_shape = case["valid_shape"] - dtype_info = np.iinfo(dtype) - input1 = np.random.randint(dtype_info.min, dtype_info.max, size=src0_tile).astype(dtype) - input2 = np.random.randint(dtype_info.min, dtype_info.max, size=src1_tile).astype(dtype) + input1 = np.random.randint(0, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 100, size=shape).astype(dtype) - golden = np.zeros(dst_tile, dtype=dtype) + golden = np.zeros(shape, dtype=dtype) vr, vc = valid_shape golden[:vr, :vc] = (input1[:vr, :vc] ^ input2[:vr, :vc]).astype(dtype, copy=False) save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) - print(f"[INFO] gen_data: {case['name']} dst={dst_tile} src0={src0_tile} src1={src1_tile} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp index 50bf8af26..90fd20459 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp @@ -12,30 +12,16 @@ #define AICORE [aicore] #endif -// Case 0: i16, dst=64x64, src0=64x64, src1=64x64, valid=64x64 -extern "C" __global__ AICORE void TXOR_i16_64x64_64x64_64x64_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TXOR_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); -void LaunchTXOR_i16_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream) { - TXOR_i16_64x64_64x64_64x64_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +void LaunchTXOR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TXOR_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); } -// Case 1: i16, dst=32x128, src0=32x128, src1=32x256, valid=32x128 -extern "C" __global__ AICORE void TXOR_i16_32x128_32x128_32x256_32x128(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TXOR_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); -void LaunchTXOR_i16_32x128_32x128_32x256_32x128(void *a, void *b, void *c, void *stream) { - TXOR_i16_32x128_32x128_32x256_32x128<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); -} - -// Case 2: i16, dst=32x128, src0=32x128, src1=32x256, valid=32x127 -extern "C" __global__ AICORE void TXOR_i16_32x128_32x128_32x256_32x127(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); - -void LaunchTXOR_i16_32x128_32x128_32x256_32x127(void *a, void *b, void *c, void *stream) { - TXOR_i16_32x128_32x128_32x256_32x127<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); -} - -// Case 3: i8, dst=32x128, src0=32x128, src1=32x256, valid=32x127 -extern "C" __global__ AICORE void TXOR_i8_32x128_32x128_32x256_32x127(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int8_t *c); - -void LaunchTXOR_i8_32x128_32x128_32x256_32x127(void *a, void *b, void *c, void *stream) { - TXOR_i8_32x128_32x128_32x256_32x127<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int8_t *)c); +void LaunchTXOR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TXOR_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); } \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp index 70729478e..838ff0de1 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp @@ -1,12 +1,14 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. // Host driver for TileLang txor ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. #include "acl/acl.h" #include "test_common.h" @@ -19,79 +21,72 @@ using namespace PtoTestCommon; -void LaunchTXOR_i16_64x64_64x64_64x64_64x64(void *a, void *b, void *c, void *stream); -void LaunchTXOR_i16_32x128_32x128_32x256_32x128(void *a, void *b, void *c, void *stream); -void LaunchTXOR_i16_32x128_32x128_32x256_32x127(void *a, void *b, void *c, void *stream); -void LaunchTXOR_i8_32x128_32x128_32x256_32x127(void *a, void *b, void *c, void *stream); +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTXOR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTXOR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); -using LaunchFn = void (*)(void *, void *, void *, void *); +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); struct TestCase { const char *name; LaunchFn launch; - size_t src0Rows; - size_t src0Cols; - size_t src1Rows; - size_t src1Cols; - size_t dstRows; - size_t dstCols; - size_t validRows; - size_t validCols; - size_t elemSize; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element }; static const TestCase kCases[] = { - {"i16_64x64_64x64_64x64_64x64", LaunchTXOR_i16_64x64_64x64_64x64_64x64, 64, 64, 64, 64, 64, 64, 64, 64, sizeof(int16_t)}, - {"i16_32x128_32x128_32x256_32x128", LaunchTXOR_i16_32x128_32x128_32x256_32x128, 32, 128, 32, 256, 32, 128, 32, 128, sizeof(int16_t)}, - {"i16_32x128_32x128_32x256_32x127", LaunchTXOR_i16_32x128_32x128_32x256_32x127, 32, 128, 32, 256, 32, 128, 32, 127, sizeof(int16_t)}, - {"i8_32x128_32x128_32x256_32x127", LaunchTXOR_i8_32x128_32x128_32x256_32x127, 32, 128, 32, 256, 32, 128, 32, 127, sizeof(int8_t)}, + {"i32_16x64", LaunchTXOR_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTXOR_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { int rc = 0; - const size_t src0Size = tc.src0Rows * tc.src0Cols * tc.elemSize; - const size_t src1Size = tc.src1Rows * tc.src1Cols * tc.elemSize; - const size_t dstSize = tc.dstRows * tc.dstCols * tc.elemSize; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; - std::printf("[INFO] === case: %s (dst=%zux%zu, src0=%zux%zu, src1=%zux%zu, valid=%zux%zu) ===\n", - tc.name, tc.dstRows, tc.dstCols, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.validRows, tc.validCols); + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + // Per-case data directory std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; - void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; - void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; - aclrtMallocHost(&src0Host, src0Size); - aclrtMallocHost(&src1Host, src1Size); - aclrtMallocHost(&dstHost, dstSize); + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); - aclrtMalloc(&src0Device, src0Size, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&src1Device, src1Size, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); - size_t fileSize = 0; - if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, src0Size)) { + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); rc = 1; } - fileSize = 0; - if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, src1Size)) { + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); rc = 1; } if (rc == 0) { - aclrtMemcpy(src0Device, src0Size, src0Host, src0Size, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(src1Device, src1Size, src1Host, src1Size, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); tc.launch(src0Device, src1Device, dstDevice, stream); aclrtSynchronizeStream(stream); - aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); } - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstSize)) { + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); rc = 1; } @@ -115,6 +110,7 @@ static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { } int main(int argc, char *argv[]) { + // Optional case filter: ./txor [case_name] const char *caseFilter = (argc > 1) ? argv[1] : nullptr; int rc = 0; diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto b/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto index d0a923d18..5e36ea0d5 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto @@ -6,282 +6,140 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// TileLang ST kernels for pto.txor: tload(a) + tload(b) + txor(a,b,c)->c + tstore(c). -// Aligned with pto-isa/tests/npu/a2a3/src/st/testcase/txor -// Cases have different src/dst tile buffer sizes but same valid_shape. -// -// NOTE: Unsigned integer types temporarily disabled due to HIVM backend bug -// See BUG_REPORT_UNSIGNED_BITOPS.md for details +// TileLang ST kernels for pto.txor: tload(a) + tload(b) + txor(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. module { - // Case 0: i16, dst=64x64, src0=64x64, src1=64x64, valid=64x64 - func.func @TXOR_i16_64x64_64x64_64x64_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // Case 0: i32 16x64 (1024 elements) + func.func @TXOR_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index %c64 = arith.constant 64 : index - %c4096 = arith.constant 4096 : index + %c1024 = arith.constant 1024 : index %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c64, %c64], - strides = [%c4096, %c4096, %c4096, %c64, %c1] - : !pto.tensor_view<1x1x1x64x64xi16> + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> %a_part = pto.partition_view %a_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> %b_part = pto.partition_view %b_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> %c_part = pto.partition_view %c_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c64, %c64] - : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> %a = pto.alloc_tile - : !pto.tile_buf %b = pto.alloc_tile - : !pto.tile_buf %c = pto.alloc_tile - : !pto.tile_buf - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%a : !pto.tile_buf) + outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) - outs(%b : !pto.tile_buf) + outs(%b : !pto.tile_buf) - pto.txor ins(%a, %b, %c : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) return } - // Case 1: i16, dst=32x128, src0=32x128, src1=32x256, valid=32x128 - func.func @TXOR_i16_32x128_32x128_32x256_32x128(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { + // Case 1: i32 32x32 (1024 elements) + func.func @TXOR_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - %c128 = arith.constant 128 : index - %c256 = arith.constant 256 : index - %c4096 = arith.constant 4096 : index - %c8192 = arith.constant 8192 : index + %c1024 = arith.constant 1024 : index %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c32, %c128], - strides = [%c4096, %c4096, %c4096, %c128, %c1] - : !pto.tensor_view<1x1x1x32x128xi16> + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c32, %c256], - strides = [%c8192, %c8192, %c8192, %c256, %c1] - : !pto.tensor_view<1x1x1x32x256xi16> + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c32, %c128], - strides = [%c4096, %c4096, %c4096, %c128, %c1] - : !pto.tensor_view<1x1x1x32x128xi16> + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> %a_part = pto.partition_view %a_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c128] - : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> %b_part = pto.partition_view %b_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c128] - : !pto.tensor_view<1x1x1x32x256xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> %c_part = pto.partition_view %c_view, offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c128] - : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> %a = pto.alloc_tile - : !pto.tile_buf %b = pto.alloc_tile - : !pto.tile_buf %c = pto.alloc_tile - : !pto.tile_buf - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) - outs(%a : !pto.tile_buf) + outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) - outs(%b : !pto.tile_buf) + outs(%b : !pto.tile_buf) - pto.txor ins(%a, %b, %c : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) - return - } - - // Case 2: i16, dst=32x128, src0=32x128, src1=32x256, valid=32x127 - func.func @TXOR_i16_32x128_32x128_32x256_32x127(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c127 = arith.constant 127 : index - %c128 = arith.constant 128 : index - %c256 = arith.constant 256 : index - %c4096 = arith.constant 4096 : index - %c8192 = arith.constant 8192 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c32, %c128], - strides = [%c4096, %c4096, %c4096, %c128, %c1] - : !pto.tensor_view<1x1x1x32x128xi16> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c32, %c256], - strides = [%c8192, %c8192, %c8192, %c256, %c1] - : !pto.tensor_view<1x1x1x32x256xi16> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c32, %c128], - strides = [%c4096, %c4096, %c4096, %c128, %c1] - : !pto.tensor_view<1x1x1x32x128xi16> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c127] - : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x127xi16> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c127] - : !pto.tensor_view<1x1x1x32x256xi16> -> !pto.partition_tensor_view<1x1x1x32x127xi16> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c127] - : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x127xi16> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x127xi16>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x127xi16>) - outs(%b : !pto.tile_buf) - - pto.txor ins(%a, %b, %c : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x32x127xi16>) - return - } - - // Case 3: i8, dst=32x128, src0=32x128, src1=32x256, valid=32x127 - func.func @TXOR_i8_32x128_32x128_32x256_32x127(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c127 = arith.constant 127 : index - %c128 = arith.constant 128 : index - %c256 = arith.constant 256 : index - %c4096 = arith.constant 4096 : index - %c8192 = arith.constant 8192 : index - - %a_view = pto.make_tensor_view %a_ptr, - shape = [%c1, %c1, %c1, %c32, %c128], - strides = [%c4096, %c4096, %c4096, %c128, %c1] - : !pto.tensor_view<1x1x1x32x128xi8> - %b_view = pto.make_tensor_view %b_ptr, - shape = [%c1, %c1, %c1, %c32, %c256], - strides = [%c8192, %c8192, %c8192, %c256, %c1] - : !pto.tensor_view<1x1x1x32x256xi8> - %c_view = pto.make_tensor_view %c_ptr, - shape = [%c1, %c1, %c1, %c32, %c128], - strides = [%c4096, %c4096, %c4096, %c128, %c1] - : !pto.tensor_view<1x1x1x32x128xi8> - - %a_part = pto.partition_view %a_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c127] - : !pto.tensor_view<1x1x1x32x128xi8> -> !pto.partition_tensor_view<1x1x1x32x127xi8> - %b_part = pto.partition_view %b_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c127] - : !pto.tensor_view<1x1x1x32x256xi8> -> !pto.partition_tensor_view<1x1x1x32x127xi8> - %c_part = pto.partition_view %c_view, - offsets = [%c0, %c0, %c0, %c0, %c0], - sizes = [%c1, %c1, %c1, %c32, %c127] - : !pto.tensor_view<1x1x1x32x128xi8> -> !pto.partition_tensor_view<1x1x1x32x127xi8> - - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %c = pto.alloc_tile - : !pto.tile_buf - - pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x127xi8>) - outs(%a : !pto.tile_buf) - pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x127xi8>) - outs(%b : !pto.tile_buf) - - pto.txor ins(%a, %b, %c : !pto.tile_buf, - !pto.tile_buf, - !pto.tile_buf) - outs(%c : !pto.tile_buf) - - pto.tstore ins(%c : !pto.tile_buf) - outs(%c_part : !pto.partition_tensor_view<1x1x1x32x127xi8>) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) return } } \ No newline at end of file From 1a6b535cd31f1ba2b6f5b5a04686aff1723fdc9b Mon Sep 17 00:00:00 2001 From: zwd060924 Date: Mon, 27 Apr 2026 17:46:21 +0800 Subject: [PATCH 186/192] [Fix] delete tcmp tfmod trem in Cmakelists --- test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 7571f1e4b..ab175d764 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -135,7 +135,6 @@ set(ALL_TESTCASES tdiv tmax tmin - tcmp tshl tshr tand @@ -201,8 +200,6 @@ set(ALL_TESTCASES tshrs tsubs txors - tfmod - trem ) if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) From 2ed48b060a994a9ed88d0d178fc7c69092557a1e Mon Sep 17 00:00:00 2001 From: zwd060924 Date: Mon, 27 Apr 2026 18:00:53 +0800 Subject: [PATCH 187/192] [Fix] CI error in tmin --- test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt | 2 +- test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp | 2 +- test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt index fa90b8804..f811b5f04 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You can not use this file except in compliance with the License. +# Please refer to the License for details. You may not use this file except in compliance with the License. # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp index 8757cc295..95247e512 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp @@ -1,7 +1,7 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp index 4fec9b639..214a35a22 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp @@ -1,12 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. // This program is free software, you can redistribute it and/or modify it under the terms and conditions of // CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You can not use this file except in compliance with the License. +// Please refer to the License for details. You may not use this file except in compliance with the License. // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Host driver for TileLang tmin ST — case-table driven. +// Host driver for TileLang tand ST — case-table driven. // Each case launches a different kernel variant, reads/writes from per-case subdirectory. // Numerical comparison is done externally by compare.py. From 4662a3e63f180d831aa8aa9dc97047daff1fce11 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 27 Apr 2026 19:20:37 +0800 Subject: [PATCH 188/192] Fix VPTO vcvt and vaxpy mask lowering --- docs/isa/13-dsa-sfu-ops.md | 6 +-- docs/vpto-spec.md | 2 +- include/PTO/IR/VPTOOps.td | 6 ++- lib/PTO/IR/VPTO.cpp | 40 +++++++++++++++---- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 28 +++---------- test/basic/expand_tile_op_tilelang_tcvt.pto | 12 +++--- .../tilelang_soft_vmod_backend_inline.pto | 16 ++++---- test/basic/vcvt_part_modes_verify_invalid.pto | 12 +----- ...cvt_part_modes_verify_invalid_even_odd.pto | 22 ++++++++++ test/basic/vcvt_part_modes_vpto_llvm.pto | 18 +++++---- .../conversion/vcvt-f16-special/kernel.pto | 3 +- .../vcvt-f16-to-f32-part-even/kernel.pto | 3 +- .../vcvt-f16-to-f32-part-odd/kernel.pto | 3 +- .../conversion/vcvt-f16-to-f32/kernel.pto | 3 +- .../conversion/vcvt-f32-special/kernel.pto | 6 ++- .../vcvt-f32-to-f16-pk-b32/kernel.pto | 2 +- .../conversion/vcvt-f32-to-f16/kernel.pto | 6 ++- .../vcvt-i32-to-i16-overflow/kernel.pto | 6 ++- .../conversion/vcvt-i64-to-f32/kernel.pto | 2 +- .../conversion/vcvt-tail-special/kernel.pto | 6 ++- .../micro-op/conversion/vcvt-tail/kernel.pto | 6 ++- .../vcvt-u32-to-u8-part-p0123/kernel.pto | 9 +++-- .../micro-op/dsa-sfu/vaxpy-f32/kernel.pto | 2 +- ...o_tilelang_inline_soft_divmod_fastpath.pto | 16 ++++---- .../docs/vpto_spec/vpto-spec-current.md | 8 ++-- tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md | 8 ++-- tilelang-dsl/python/tilelang_dsl/lowering.py | 5 ++- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 2 +- 28 files changed, 150 insertions(+), 108 deletions(-) create mode 100644 test/basic/vcvt_part_modes_verify_invalid_even_odd.pto diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md index 0196a559d..eeee43b70 100644 --- a/docs/isa/13-dsa-sfu-ops.md +++ b/docs/isa/13-dsa-sfu-ops.md @@ -84,7 +84,7 @@ for (int i = 0; i < N; i++) ### `pto.vaxpy` -- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha, %mask : !pto.vreg, !pto.vreg, T, !pto.mask -> !pto.vreg` - **A5 types:** f16, f32 - **semantics:** AXPY — scalar-vector multiply-add. @@ -93,8 +93,8 @@ for (int i = 0; i < N; i++) dst[i] = alpha * src0[i] + src1[i]; ``` -- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and - `%alpha` is the scalar multiplier. +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, + `%alpha` is the scalar multiplier, and `%mask` selects active lanes. - **outputs:** `%result` is the fused AXPY result. - **constraints and limitations:** Floating-point element types only on the current documented surface. diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 3ac469bd7..07f1000f6 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -877,7 +877,7 @@ pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vre **Conversion (one vector in, different-typed vector out):** ```mlir -%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +%result = pto.vcvt %input, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg, !pto.mask -> !pto.vreg ``` **Predicate construction:** diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 4963a7e07..b23d67a07 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -1330,6 +1330,7 @@ def PTO_VtrcOp : PTO_Op<"vtrc", [Pure]> { def PTO_VcvtOp : PTO_Op<"vcvt", [Pure]> { let arguments = (ins PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask, OptionalAttr:$rnd, OptionalAttr:$sat, OptionalAttr:$part @@ -1841,14 +1842,15 @@ def PTO_VaxpyOp : PTO_Op<"vaxpy", [Pure]> { let arguments = (ins PTO_VectorType:$src0, PTO_VectorType:$src1, - AnyType:$alpha + AnyType:$alpha, + PTO_MaskTypeConstraint:$mask ); let results = (outs PTO_VectorType:$result); let hasVerifier = 1; let assemblyFormat = [{ - $src0 `,` $src1 `,` $alpha attr-dict `:` type($src0) `,` type($src1) `,` type($alpha) `->` type($result) + $src0 `,` $src1 `,` $alpha `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($alpha) `,` type($mask) `->` type($result) }]; } diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 511e488f5..fbda9928c 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include #include using namespace mlir; @@ -2868,11 +2869,14 @@ LogicalResult VtrcOp::verify() { ParseResult VcvtOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand input; + OpAsmParser::UnresolvedOperand mask; NamedAttrList attrs; - Type inputType, resultType; + Type inputType, maskType, resultType; - if (parser.parseOperand(input) || parser.parseOptionalAttrDict(attrs) || - parser.parseColonType(inputType) || parser.parseArrow() || + if (parser.parseOperand(input) || parser.parseComma() || + parser.parseOperand(mask) || parser.parseOptionalAttrDict(attrs) || + parser.parseColonType(inputType) || parser.parseComma() || + parser.parseType(maskType) || parser.parseArrow() || parser.parseType(resultType)) return failure(); @@ -2910,16 +2914,18 @@ ParseResult VcvtOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); result.addAttributes(attrs); - if (parser.resolveOperand(input, inputType, result.operands)) + if (parser.resolveOperand(input, inputType, result.operands) || + parser.resolveOperand(mask, maskType, result.operands)) return failure(); result.addTypes(resultType); return success(); } void VcvtOp::print(OpAsmPrinter &printer) { - printer << ' ' << getInput(); + printer << ' ' << getInput() << ", " << getMask(); printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : " << getInput().getType() << " -> " << getResult().getType(); + printer << " : " << getInput().getType() << ", " << getMask().getType() + << " -> " << getResult().getType(); } LogicalResult VcvtOp::verify() { @@ -2927,6 +2933,8 @@ LogicalResult VcvtOp::verify() { auto resultType = dyn_cast(getResult().getType()); if (!inputType || !resultType) return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); VcvtElemKind inputElemKind = classifyVcvtElemType(inputType.getElementType()); VcvtElemKind resultElemKind = classifyVcvtElemType(resultType.getElementType()); @@ -2938,6 +2946,16 @@ LogicalResult VcvtOp::verify() { auto resultElemBits = getVcvtElemBitWidth(resultElemKind); if (!inputElemBits || !resultElemBits) return emitOpError("could not determine vcvt element bit width"); + unsigned maskBitWidth = std::min(*inputElemBits, 32u); + StringRef expectedMaskGranularity = maskBitWidth == 8 ? "b8" + : maskBitWidth == 16 ? "b16" + : maskBitWidth == 32 ? "b32" + : ""; + if (expectedMaskGranularity.empty()) + return emitOpError("could not determine vcvt mask granularity"); + if (failed(verifyMaskTypeWithGranularityLike( + *this, getMask().getType(), "mask type", expectedMaskGranularity))) + return failure(); if (inputType.getElementCount() * static_cast(*inputElemBits) != resultType.getElementCount() * static_cast(*resultElemBits)) { return emitOpError("requires source and result vectors to carry the same " @@ -3201,7 +3219,8 @@ LogicalResult VexpdifOp::verify() { LogicalResult VaxpyOp::verify() { if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || - failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) return failure(); auto src0Type = cast(getSrc0().getType()); auto src1Type = cast(getSrc1().getType()); @@ -3211,6 +3230,13 @@ LogicalResult VaxpyOp::verify() { Type elemType = src0Type.getElementType(); if (!elemType.isF16() && !elemType.isF32()) return emitOpError("requires f16 or f32 vector element type"); + auto expectedGranularity = getVdupMaskGranularity(elemType); + if (!expectedGranularity) + return emitOpError("requires element type with supported predicate granularity"); + if (failed(verifyMaskTypeWithGranularityLike(*this, getMask().getType(), + "mask type", + *expectedGranularity))) + return failure(); if (getAlpha().getType() != elemType) return emitOpError("requires alpha type to match vector element type"); return success(); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 7ec8d71fc..597a82f77 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -5691,17 +5691,10 @@ class LowerVaxpyOpPattern final : public OpConversionPattern { LogicalResult matchAndRewrite(pto::VaxpyOp op, pto::VaxpyOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto laneCount = getElementCountFromVectorLike(op.getResult().getType()); Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); - if (!laneCount || !elemType) + if (!elemType) return rewriter.notifyMatchFailure(op, "unsupported vaxpy signature"); - FailureOr mask = materializeDynamicPltMask( - rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), - elemType); - if (failed(mask)) - return rewriter.notifyMatchFailure(op, "failed to materialize vaxpy mask"); - Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vaxpy result type"); @@ -5713,12 +5706,12 @@ class LowerVaxpyOpPattern final : public OpConversionPattern { auto funcType = rewriter.getFunctionType( TypeRange{adaptor.getSrc1().getType(), adaptor.getSrc0().getType(), - adaptor.getAlpha().getType(), (*mask).getType()}, + adaptor.getAlpha().getType(), adaptor.getMask().getType()}, TypeRange{resultType}); auto call = rewriter.create( op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{adaptor.getSrc1(), adaptor.getSrc0(), adaptor.getAlpha(), - *mask}); + adaptor.getMask()}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.replaceOp(op, call.getResults()); return success(); @@ -5945,21 +5938,10 @@ class LowerVcvtOpPattern final : public OpConversionPattern { LogicalResult matchAndRewrite(pto::VcvtOp op, pto::VcvtOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto inputLanes = getElementCountFromVectorLike(op.getInput().getType()); - if (!inputLanes) - return rewriter.notifyMatchFailure(op, "unsupported vcvt input shape"); - FailureOr contract = buildVcvtContract(op); if (failed(contract)) return rewriter.notifyMatchFailure(op, "unsupported vcvt type pair"); - Type maskElemType = rewriter.getIntegerType((*contract).maskBitWidth); - FailureOr mask = materializeDynamicPltMask( - rewriter, state, op.getLoc(), - getI32Constant(rewriter, op.getLoc(), *inputLanes), maskElemType); - if (failed(mask)) - return rewriter.notifyMatchFailure(op, "failed to materialize vcvt mask"); - Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); @@ -5968,8 +5950,8 @@ class LowerVcvtOpPattern final : public OpConversionPattern { SmallVector argTypes; callArgs.push_back(adaptor.getInput()); argTypes.push_back(adaptor.getInput().getType()); - callArgs.push_back(*mask); - argTypes.push_back((*mask).getType()); + callArgs.push_back(adaptor.getMask()); + argTypes.push_back(adaptor.getMask().getType()); auto appendRndArg = [&]() -> LogicalResult { auto roundMode = diff --git a/test/basic/expand_tile_op_tilelang_tcvt.pto b/test/basic/expand_tile_op_tilelang_tcvt.pto index d3b2b890f..a178f5167 100644 --- a/test/basic/expand_tile_op_tilelang_tcvt.pto +++ b/test/basic/expand_tile_op_tilelang_tcvt.pto @@ -20,25 +20,25 @@ // CHECK-LABEL: func.func @TCVT_F32_TO_I32 // CHECK-NOT: pto.tcvt ins // CHECK: pto.vecscope -// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> -// CHECK: pto.vcvt {{.*}} {rnd = "A", sat = "SAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> // CHECK-LABEL: func.func @TCVT_I32_TO_F32 // CHECK-NOT: pto.tcvt ins // CHECK: pto.vecscope -// CHECK: pto.vcvt {{.*}} {rnd = "R"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> -// CHECK: pto.vcvt {{.*}} {rnd = "A"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {rnd = "R"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {rnd = "A"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> // CHECK-LABEL: func.func @TCVT_F16_TO_F32 // CHECK-NOT: pto.tcvt ins // CHECK: pto.vecscope // CHECK: pto.vlds {{.*}} {dist = "UNPK_B16"} : {{.*}} -> !pto.vreg<128xf16> -// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> // CHECK-LABEL: func.func @TCVT_F32_TO_F16 // CHECK-NOT: pto.tcvt ins // CHECK: pto.vecscope -// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vsts {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, {{.*}}, !pto.mask module { diff --git a/test/basic/tilelang_soft_vmod_backend_inline.pto b/test/basic/tilelang_soft_vmod_backend_inline.pto index 4c7b1e7fc..5f829ac65 100644 --- a/test/basic/tilelang_soft_vmod_backend_inline.pto +++ b/test/basic/tilelang_soft_vmod_backend_inline.pto @@ -31,14 +31,14 @@ module attributes {pto.target_arch = "a5"} { %active_mask_16 = pto.pnot %zero_mask_15, %arg2 : !pto.mask, !pto.mask -> !pto.mask %zero_u16_17 = pto.vbr %zero_10 : ui16 -> !pto.vreg<128xui16> %vy_lower_u16_18, %vy_higher_u16_19 = pto.vintlv %arg1, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> - %vy_lower_u32_20 = pto.vcvt %vy_lower_u16_18 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> - %vy_higher_u32_21 = pto.vcvt %vy_higher_u16_19 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %vy_lower_u32_20 = pto.vcvt %vy_lower_u16_18, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %vy_higher_u32_21 = pto.vcvt %vy_higher_u16_19, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> %active_low_22 = pto.vcmps %vy_lower_u32_20, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask %active_high_23 = pto.vcmps %vy_higher_u32_21, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask %tmp_2 = pto.vbitcast %vy_lower_u32_20 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> - %vy_lower_f32_24 = pto.vcvt %tmp_2 {rnd = "F"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %vy_lower_f32_24 = pto.vcvt %tmp_2, %active_low_22 {rnd = "F"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> %tmp_3 = pto.vbitcast %vy_higher_u32_21 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> - %vy_higher_f32_25 = pto.vcvt %tmp_3 {rnd = "F"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %vy_higher_f32_25 = pto.vcvt %tmp_3, %active_high_23 {rnd = "F"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> %tmp_4 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> %vy_rec_lower_26 = pto.vdiv %tmp_4, %vy_lower_f32_24, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %tmp_5 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> @@ -47,13 +47,13 @@ module attributes {pto.target_arch = "a5"} { %vy_scale_lower_28 = pto.vmul %vy_rec_lower_26, %tmp_6, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %tmp_7 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> %vy_scale_higher_29 = pto.vmul %vy_rec_higher_27, %tmp_7, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %v_lower_i32_30 = pto.vcvt %vy_scale_lower_28 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> - %v_higher_i32_31 = pto.vcvt %vy_scale_higher_29 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + %v_lower_i32_30 = pto.vcvt %vy_scale_lower_28, %active_low_22 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> + %v_higher_i32_31 = pto.vcvt %vy_scale_higher_29, %active_high_23 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> %v_lower_u32_32 = pto.vbitcast %v_lower_i32_30 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> %v_higher_u32_33 = pto.vbitcast %v_higher_i32_31 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> %vx_lower_u16_34, %vx_higher_u16_35 = pto.vintlv %arg0, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> - %vx_lower_u32_36 = pto.vcvt %vx_lower_u16_34 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> - %vx_higher_u32_37 = pto.vcvt %vx_higher_u16_35 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %vx_lower_u32_36 = pto.vcvt %vx_lower_u16_34, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %vx_higher_u32_37 = pto.vcvt %vx_higher_u16_35, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> %q_tmp_lower_38 = pto.vmul %v_lower_u32_32, %vx_lower_u32_36, %active_low_22 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> %q_tmp_higher_39 = pto.vmul %v_higher_u32_33, %vx_higher_u32_37, %active_high_23 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> %tmp_8 = pto.vbitcast %q_tmp_lower_38 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> diff --git a/test/basic/vcvt_part_modes_verify_invalid.pto b/test/basic/vcvt_part_modes_verify_invalid.pto index caa8cae08..6130436e1 100644 --- a/test/basic/vcvt_part_modes_verify_invalid.pto +++ b/test/basic/vcvt_part_modes_verify_invalid.pto @@ -9,21 +9,13 @@ // RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s 2>&1 | FileCheck %s // CHECK: error: 'pto.vcvt' op part must be P0, P1, P2, or P3 for 8/32 vcvt forms -// CHECK: error: 'pto.vcvt' op part must be EVEN or ODD for 8/16 and 16/32 vcvt forms module attributes {pto.target_arch = "a5"} { func.func @vcvt_u32_to_u8_rejects_even(%seed: ui32) attributes {pto.kernel_kind = #pto.kernel_kind} { pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask %src = pto.vbr %seed : ui32 -> !pto.vreg<64xui32> - %bad = pto.vcvt %src {sat = "SAT", part = "EVEN"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> - } - return - } - - func.func @vcvt_u16_to_u8_rejects_p0(%seed: ui16) attributes {pto.kernel_kind = #pto.kernel_kind} { - pto.vecscope { - %src = pto.vbr %seed : ui16 -> !pto.vreg<128xui16> - %bad = pto.vcvt %src {sat = "SAT", part = "P0"} : !pto.vreg<128xui16> -> !pto.vreg<256xui8> + %bad = pto.vcvt %src, %mask {sat = "SAT", part = "EVEN"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> } return } diff --git a/test/basic/vcvt_part_modes_verify_invalid_even_odd.pto b/test/basic/vcvt_part_modes_verify_invalid_even_odd.pto new file mode 100644 index 000000000..b47b77696 --- /dev/null +++ b/test/basic/vcvt_part_modes_verify_invalid_even_odd.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s 2>&1 | FileCheck %s + +// CHECK: error: 'pto.vcvt' op part must be EVEN or ODD for 8/16 and 16/32 vcvt forms + +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_u16_to_u8_rejects_p0(%seed: ui16) attributes {pto.kernel_kind = #pto.kernel_kind} { + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %src = pto.vbr %seed : ui16 -> !pto.vreg<128xui16> + %bad = pto.vcvt %src, %mask {sat = "SAT", part = "P0"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8> + } + return + } +} diff --git a/test/basic/vcvt_part_modes_vpto_llvm.pto b/test/basic/vcvt_part_modes_vpto_llvm.pto index a3dfc75f2..087c73d2d 100644 --- a/test/basic/vcvt_part_modes_vpto_llvm.pto +++ b/test/basic/vcvt_part_modes_vpto_llvm.pto @@ -13,11 +13,12 @@ module attributes {pto.target_arch = "a5"} { %c0 = arith.constant 0 : index pto.vecscope { %mask = pto.pset_b8 "PAT_ALL" : !pto.mask + %src_mask = pto.pset_b32 "PAT_ALL" : !pto.mask %src = pto.vbr %seed : ui32 -> !pto.vreg<64xui32> - %p0 = pto.vcvt %src {sat = "SAT", part = "P0"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> - %p1 = pto.vcvt %src {sat = "SAT", part = "P1"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> - %p2 = pto.vcvt %src {sat = "SAT", part = "P2"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> - %p3 = pto.vcvt %src {sat = "SAT", part = "P3"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %p0 = pto.vcvt %src, %src_mask {sat = "SAT", part = "P0"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %p1 = pto.vcvt %src, %src_mask {sat = "SAT", part = "P1"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %p2 = pto.vcvt %src, %src_mask {sat = "SAT", part = "P2"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %p3 = pto.vcvt %src, %src_mask {sat = "SAT", part = "P3"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> %m01 = pto.vor %p0, %p1, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> %m23 = pto.vor %p2, %p3, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> %out = pto.vor %m01, %m23, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> @@ -28,7 +29,8 @@ module attributes {pto.target_arch = "a5"} { } // CHECK-LABEL: define void @vcvt_u32_to_u8_packed_parts( -// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}} i32 0, i32 0) -// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}} i32 0, i32 1) -// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}} i32 0, i32 2) -// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}} i32 0, i32 3) +// CHECK: [[SRC_MASK:%[0-9]+]] = call <256 x i1> @llvm.hivm.pset.b32(i32 0) +// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}}, <256 x i1> [[SRC_MASK]], i32 0, i32 0) +// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}}, <256 x i1> [[SRC_MASK]], i32 0, i32 1) +// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}}, <256 x i1> [[SRC_MASK]], i32 0, i32 2) +// CHECK: call <256 x i8> @llvm.hivm.vcvtii.u322u8.x({{.*}}, <256 x i1> [[SRC_MASK]], i32 0, i32 3) diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto index e688fdfa5..a6dfe2904 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto @@ -24,9 +24,10 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b16 "PAT_ALL" : !pto.mask scf.for %offset = %c0 to %c1024 step %c64 { %loaded = pto.vlds %ub_in[%offset] {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> - %out = pto.vcvt %loaded {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %out = pto.vcvt %loaded, %cvt_mask {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto index a2eba0eb2..0edb62220 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto @@ -26,11 +26,12 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b16 "PAT_ALL" : !pto.mask // Use packed f16 load (no UNPK): PART_EVEN selects the lower 16-bit // element from each f16 pair inside a b32 lane. scf.for %offset = %c0 to %c512 step %c64 { %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> - %out = pto.vcvt %loaded {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %out = pto.vcvt %loaded, %cvt_mask {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto index c89d17f43..61eb89f3e 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto @@ -26,11 +26,12 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b16 "PAT_ALL" : !pto.mask // Use packed f16 load (no UNPK): PART_ODD then selects the upper 16-bit // element from each f16 pair inside a b32 lane. scf.for %offset = %c0 to %c512 step %c64 { %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> - %out = pto.vcvt %loaded {part = "ODD"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %out = pto.vcvt %loaded, %cvt_mask {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto index 9e9db6afb..7f3dfbb72 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto @@ -24,9 +24,10 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b16 "PAT_ALL" : !pto.mask scf.for %offset = %c0 to %c1024 step %c64 { %loaded = pto.vlds %ub_in[%offset] {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> - %out = pto.vcvt %loaded {part = "EVEN"} : !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %out = pto.vcvt %loaded, %cvt_mask {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto index 33a833c7b..c9a31df72 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto @@ -28,11 +28,13 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 %upper_offset = arith.addi %offset, %c64 : index %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> - %even = pto.vcvt %lower {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> - %odd = pto.vcvt %upper {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %even = pto.vcvt %lower, %lower_mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper, %upper_mask {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto index ce8ed8c09..5ffd48361 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto @@ -32,7 +32,7 @@ module attributes {pto.target_arch = "a5"} { %mask = pto.pset_b32 "PAT_ALL" : !pto.mask scf.for %offset = %c0 to %c1024 step %c64 { %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> - %converted = pto.vcvt %loaded {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %converted = pto.vcvt %loaded, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> pto.vsts %converted, %ub_out[%offset], %mask {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto index 110ea1e19..558541955 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto @@ -28,11 +28,13 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 %upper_offset = arith.addi %offset, %c64 : index %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> - %even = pto.vcvt %lower {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> - %odd = pto.vcvt %upper {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %even = pto.vcvt %lower, %lower_mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper, %upper_mask {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto index f635b34d7..7b7e516e6 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto @@ -33,11 +33,13 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 %upper_offset = arith.addi %offset, %c64 : index %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xi32> %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xi32> - %even = pto.vcvt %lower {sat = "SAT", part = "EVEN"} : !pto.vreg<64xi32> -> !pto.vreg<128xi16> - %odd = pto.vcvt %upper {sat = "SAT", part = "ODD"} : !pto.vreg<64xi32> -> !pto.vreg<128xi16> + %even = pto.vcvt %lower, %lower_mask {sat = "SAT", part = "EVEN"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xi16> + %odd = pto.vcvt %upper, %upper_mask {sat = "SAT", part = "ODD"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xi16> %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto index 6c8152203..80838d072 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto @@ -39,7 +39,7 @@ module attributes {pto.target_arch = "a5"} { scf.for %store_offset = %c0 to %c512 step %c16 { %input_offset = arith.muli %store_offset, %c2 : index %loaded = pto.vlds %ub_in[%input_offset] : !pto.ptr -> !pto.vreg<32xsi64> - %converted = pto.vcvt %loaded {rnd = "R", part = "EVEN"} : !pto.vreg<32xsi64> -> !pto.vreg<64xf32> + %converted = pto.vcvt %loaded, %mask {rnd = "R", part = "EVEN"} : !pto.vreg<32xsi64>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %converted, %ub_out[%store_offset], %mask {dist = "PK_B64"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto index e08e2523a..6c190e740 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto @@ -28,11 +28,13 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000_i32) -> (i32) { %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 %upper_offset = arith.addi %offset, %c64 : index %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> - %even = pto.vcvt %lower {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> - %odd = pto.vcvt %upper {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %even = pto.vcvt %lower, %lower_mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper, %upper_mask {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto index ec42555e4..118459f5c 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto @@ -28,11 +28,13 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000_i32) -> (i32) { %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 %upper_offset = arith.addi %offset, %c64 : index %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> - %even = pto.vcvt %lower {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> - %odd = pto.vcvt %upper {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32> -> !pto.vreg<128xf16> + %even = pto.vcvt %lower, %lower_mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper, %upper_mask {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto index d3d46469e..41c8859e9 100644 --- a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto @@ -34,6 +34,7 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %full_mask = pto.pset_b8 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b32 "PAT_ALL" : !pto.mask scf.for %offset = %c0 to %c1024 step %c256 { %offset_p1 = arith.addi %offset, %c64 : index %offset_p2 = arith.addi %offset, %c128 : index @@ -42,10 +43,10 @@ module attributes {pto.target_arch = "a5"} { %src_p1 = pto.vlds %ub_in[%offset_p1] : !pto.ptr -> !pto.vreg<64xui32> %src_p2 = pto.vlds %ub_in[%offset_p2] : !pto.ptr -> !pto.vreg<64xui32> %src_p3 = pto.vlds %ub_in[%offset_p3] : !pto.ptr -> !pto.vreg<64xui32> - %part_p0 = pto.vcvt %src_p0 {sat = "SAT", part = "P0"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> - %part_p1 = pto.vcvt %src_p1 {sat = "SAT", part = "P1"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> - %part_p2 = pto.vcvt %src_p2 {sat = "SAT", part = "P2"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> - %part_p3 = pto.vcvt %src_p3 {sat = "SAT", part = "P3"} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> + %part_p0 = pto.vcvt %src_p0, %cvt_mask {sat = "SAT", part = "P0"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %part_p1 = pto.vcvt %src_p1, %cvt_mask {sat = "SAT", part = "P1"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %part_p2 = pto.vcvt %src_p2, %cvt_mask {sat = "SAT", part = "P2"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %part_p3 = pto.vcvt %src_p3, %cvt_mask {sat = "SAT", part = "P3"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> %merged01 = pto.vor %part_p0, %part_p1, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> %merged23 = pto.vor %part_p2, %part_p3, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> %merged = pto.vor %merged01, %merged23, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto index ebc2e8226..6628c9aca 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto @@ -41,7 +41,7 @@ module attributes {pto.target_arch = "a5"} { scf.for %offset = %c0 to %c1024 step %c64 { %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %addend = pto.vlds %ub_addend[%offset] : !pto.ptr -> !pto.vreg<64xf32> - %sum = pto.vaxpy %vec, %addend, %alpha : !pto.vreg<64xf32>, !pto.vreg<64xf32>, f32 -> !pto.vreg<64xf32> + %sum = pto.vaxpy %vec, %addend, %alpha, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto_tilelang_inline_soft_divmod_fastpath.pto b/test/vpto_tilelang_inline_soft_divmod_fastpath.pto index 98a2c297b..a6ea11748 100644 --- a/test/vpto_tilelang_inline_soft_divmod_fastpath.pto +++ b/test/vpto_tilelang_inline_soft_divmod_fastpath.pto @@ -54,14 +54,14 @@ module attributes {pto.target_arch = "a5"} { %active_mask_16 = pto.pnot %zero_mask_15, %arg2 : !pto.mask, !pto.mask -> !pto.mask %zero_u16_17 = pto.vbr %zero_10 : ui16 -> !pto.vreg<128xui16> %vy_lower_u16_18, %vy_higher_u16_19 = pto.vintlv %arg1, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> - %vy_lower_u32_20 = pto.vcvt %vy_lower_u16_18 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> - %vy_higher_u32_21 = pto.vcvt %vy_higher_u16_19 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %vy_lower_u32_20 = pto.vcvt %vy_lower_u16_18, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %vy_higher_u32_21 = pto.vcvt %vy_higher_u16_19, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> %active_low_22 = pto.vcmps %vy_lower_u32_20, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask %active_high_23 = pto.vcmps %vy_higher_u32_21, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask %tmp_2 = pto.vbitcast %vy_lower_u32_20 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> - %vy_lower_f32_24 = pto.vcvt %tmp_2 {rnd = "F"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %vy_lower_f32_24 = pto.vcvt %tmp_2, %active_low_22 {rnd = "F"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> %tmp_3 = pto.vbitcast %vy_higher_u32_21 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> - %vy_higher_f32_25 = pto.vcvt %tmp_3 {rnd = "F"} : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %vy_higher_f32_25 = pto.vcvt %tmp_3, %active_high_23 {rnd = "F"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> %tmp_4 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> %vy_rec_lower_26 = pto.vdiv %tmp_4, %vy_lower_f32_24, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %tmp_5 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> @@ -70,13 +70,13 @@ module attributes {pto.target_arch = "a5"} { %vy_scale_lower_28 = pto.vmul %vy_rec_lower_26, %tmp_6, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %tmp_7 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> %vy_scale_higher_29 = pto.vmul %vy_rec_higher_27, %tmp_7, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %v_lower_i32_30 = pto.vcvt %vy_scale_lower_28 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> - %v_higher_i32_31 = pto.vcvt %vy_scale_higher_29 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + %v_lower_i32_30 = pto.vcvt %vy_scale_lower_28, %active_low_22 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> + %v_higher_i32_31 = pto.vcvt %vy_scale_higher_29, %active_high_23 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> %v_lower_u32_32 = pto.vbitcast %v_lower_i32_30 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> %v_higher_u32_33 = pto.vbitcast %v_higher_i32_31 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> %vx_lower_u16_34, %vx_higher_u16_35 = pto.vintlv %arg0, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> - %vx_lower_u32_36 = pto.vcvt %vx_lower_u16_34 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> - %vx_higher_u32_37 = pto.vcvt %vx_higher_u16_35 {part = "EVEN"} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + %vx_lower_u32_36 = pto.vcvt %vx_lower_u16_34, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %vx_higher_u32_37 = pto.vcvt %vx_higher_u16_35, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> %q_tmp_lower_38 = pto.vmul %v_lower_u32_32, %vx_lower_u32_36, %active_low_22 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> %q_tmp_higher_39 = pto.vmul %v_higher_u32_33, %vx_higher_u32_37, %active_high_23 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> %tmp_8 = pto.vbitcast %q_tmp_lower_38 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md index a8c11d7a5..0f2758187 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md @@ -674,7 +674,7 @@ pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vre **Conversion (one vector in, different-typed vector out):** ```mlir -%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +%result = pto.vcvt %input, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg, !pto.mask -> !pto.vreg ``` **Predicate construction:** @@ -4916,7 +4916,7 @@ for (int i = 0; i < N; i++) ##### `pto.vaxpy` -- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha, %mask : !pto.vreg, !pto.vreg, T, !pto.mask -> !pto.vreg` - **A5 types:** f16, f32 - **semantics:** AXPY — scalar-vector multiply-add. @@ -4925,8 +4925,8 @@ for (int i = 0; i < N; i++) dst[i] = alpha * src0[i] + src1[i]; ``` -- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and - `%alpha` is the scalar multiplier. +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, + `%alpha` is the scalar multiplier, and `%mask` selects active lanes. - **outputs:** `%result` is the fused AXPY result. - **constraints and limitations:** Floating-point element types only on the current documented surface. diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md index c2f10ab6d..f3dabff81 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md @@ -674,7 +674,7 @@ pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vre **Conversion (one vector in, different-typed vector out):** ```mlir -%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +%result = pto.vcvt %input, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg, !pto.mask -> !pto.vreg ``` **Predicate construction:** @@ -4882,7 +4882,7 @@ for (int i = 0; i < N; i++) ##### `pto.vaxpy` -- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha, %mask : !pto.vreg, !pto.vreg, T, !pto.mask -> !pto.vreg` - **A5 types:** f16, f32 - **semantics:** AXPY — scalar-vector multiply-add. @@ -4891,8 +4891,8 @@ for (int i = 0; i < N; i++) dst[i] = alpha * src0[i] + src1[i]; ``` -- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and - `%alpha` is the scalar multiplier. +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, + `%alpha` is the scalar multiplier, and `%mask` selects active lanes. - **outputs:** `%result` is the fused AXPY result. - **constraints and limitations:** Floating-point element types only on the current documented surface. diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index c2880d84c..5549f8eee 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -2960,6 +2960,7 @@ def _lower_call_expr( if expr.name == "vcvt": value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) attr_parts: list[str] = [] if self._has_optional_string_literal(expr.args[3]): attr_parts.append(f"rnd = {self._render_string_literal(expr.args[3])}") @@ -2970,8 +2971,8 @@ def _lower_call_expr( attr_suffix = f" {{{', '.join(attr_parts)}}}" if attr_parts else "" into.append( self._indent(indent) - + f"{result_name} = pto.vcvt {value.name}{attr_suffix} : " - + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + + f"{result_name} = pto.vcvt {value.name}, {mask.name}{attr_suffix} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index d29688f54..b7c5e373d 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3600,7 +3600,7 @@ def kernel(dst: pto.Tile, src: pto.Tile): self.assertIn('part = "ODD"', text) self.assertRegex( text, - r"= pto\.vcvt %[^,\s]+(?: \{[^}]+\})? : !pto\.vreg<[^>]+> -> !pto\.vreg<[^>]+>", + r"= pto\.vcvt %[^,\s]+, %[^,\s]+(?: \{[^}]+\})? : !pto\.vreg<[^>]+>, !pto\.mask -> !pto\.vreg<[^>]+>", ) def test_vcvt_supports_part_t_modes_with_enum(self) -> None: From 690e8bc258166a227f39e34c56d37f4c6e537c35 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Sat, 25 Apr 2026 13:10:13 +0800 Subject: [PATCH 189/192] Fix installed TileLang resources and disable wheel CI --- .github/workflows/build_wheel.yml | 16 +++-- .github/workflows/build_wheel_mac.yml | 24 +++---- docker/Dockerfile | 1 + docker/collect_ptoas_dist.sh | 29 ++++++++- docker/collect_ptoas_dist_mac.sh | 28 +++++++- docker/create_wheel.sh | 7 ++ tilelang-dsl/CMakeLists.txt | 16 +++++ tools/ptoas/CMakeLists.txt | 5 ++ tools/ptoas/ptoas.cpp | 92 ++++++++++++++++++++++++--- 9 files changed, 191 insertions(+), 27 deletions(-) diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index e137aea64..391f810c8 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -28,12 +28,14 @@ env: jobs: build_wheel: - name: Build wheel (Python ${{ matrix.python }}, ${{ matrix.arch }}) + # Wheel publication is paused, but this job still builds and uploads the + # packaged ptoas binary distribution. + name: Build ptoas-bin (${{ matrix.arch }}) runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-latest' || 'ubuntu-24.04-arm' }} strategy: fail-fast: false matrix: - python: ["3.10", "3.11", "3.12"] + python: ["3.11"] arch: ["x86_64", "aarch64"] container: @@ -159,12 +161,14 @@ jobs: ninja -C build install - name: Create Python wheel + if: false run: | export PATH="${PY_PATH}/bin:$PATH" export PTOAS_PYTHON_PACKAGE_VERSION="${PTOAS_VERSION}" bash $PTO_SOURCE_DIR/docker/create_wheel.sh - name: Repair wheel with auditwheel + if: false run: | export PATH="${PY_PATH}/bin:$PATH" export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core @@ -173,6 +177,7 @@ jobs: auditwheel repair --plat manylinux_2_34_${{ matrix.arch }} dist/ptoas*.whl -w wheelhouse - name: Test wheel installation + if: false run: | export PATH="${PY_PATH}/bin:$PATH" export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core @@ -185,32 +190,33 @@ jobs: bash $PTO_SOURCE_DIR/docker/test_ptoas_cli.sh - name: Copy wheel to workspace + if: false run: | export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core mkdir -p $GITHUB_WORKSPACE/wheelhouse cp $PY_PACKAGE_DIR/wheelhouse/ptoas*.whl $GITHUB_WORKSPACE/wheelhouse/ - name: Upload wheel artifact + if: false uses: actions/upload-artifact@v4 with: name: ptoas-wheel-py${{ matrix.python }}-${{ matrix.arch }} path: wheelhouse/*.whl - name: Collect ptoas binary and dependencies - if: matrix.python == '3.11' run: | bash $PTO_SOURCE_DIR/docker/collect_ptoas_dist.sh $GITHUB_WORKSPACE/ptoas-dist - name: Upload ptoas binary artifact - if: matrix.python == '3.11' uses: actions/upload-artifact@v4 with: name: ptoas-bin-${{ matrix.arch }} path: ptoas-dist/ upload_release_assets: + # Disabled together with build_wheel because wheel publication is paused. + if: false name: Upload release assets - if: github.event_name == 'release' || github.event_name == 'schedule' needs: build_wheel runs-on: ubuntu-latest diff --git a/.github/workflows/build_wheel_mac.yml b/.github/workflows/build_wheel_mac.yml index 0be6eb7a4..fd016df82 100644 --- a/.github/workflows/build_wheel_mac.yml +++ b/.github/workflows/build_wheel_mac.yml @@ -27,17 +27,15 @@ env: jobs: build_wheel: - name: Build wheel (Python ${{ matrix.python }}, ${{ matrix.arch }}) + # Wheel publication is paused, but this job still builds and uploads the + # packaged ptoas binary distribution. + name: Build ptoas-bin (macOS ${{ matrix.arch }}) runs-on: ${{ matrix.arch == 'x86_64' && 'macos-15-intel' || 'macos-26' }} strategy: fail-fast: false matrix: - python: ["3.10", "3.11", "3.12"] + python: ["3.11"] arch: ["x86_64", "aarch64"] - # Keep macOS matrix jobs at 5 to stay within GitHub Actions macOS job limits. - exclude: - - python: "3.10" - arch: "x86_64" steps: - name: Checkout repository @@ -161,6 +159,7 @@ jobs: ninja -C build install - name: Create Python wheel + if: false run: | export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core export PTOAS_PYTHON_PACKAGE_VERSION="${PTOAS_VERSION}" @@ -180,6 +179,7 @@ jobs: fi - name: Repair wheel with delocate + if: false run: | set -euo pipefail export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core @@ -254,6 +254,7 @@ jobs: ls -lh wheelhouse - name: Test wheel installation + if: false run: | export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core pip install $PY_PACKAGE_DIR/wheelhouse/ptoas*.whl @@ -266,31 +267,30 @@ jobs: bash $PTO_SOURCE_DIR/docker/test_ptoas_cli.sh - name: Copy wheel to workspace + if: false run: | export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core mkdir -p $GITHUB_WORKSPACE/wheelhouse cp $PY_PACKAGE_DIR/wheelhouse/ptoas*.whl $GITHUB_WORKSPACE/wheelhouse/ - name: Upload wheel artifact + if: false uses: actions/upload-artifact@v4 with: name: ptoas-wheel-macos-py${{ matrix.python }}-${{ matrix.arch }} path: wheelhouse/*.whl - name: Collect ptoas binary and dependencies - if: matrix.python == '3.11' run: | bash "$PTO_SOURCE_DIR/docker/collect_ptoas_dist_mac.sh" "$GITHUB_WORKSPACE/ptoas-dist" - name: Archive ptoas binary artifact - if: matrix.python == '3.11' run: | chmod +x "$GITHUB_WORKSPACE/ptoas-dist/ptoas" "$GITHUB_WORKSPACE/ptoas-dist/bin/ptoas" tar -czf "$GITHUB_WORKSPACE/ptoas-bin-macos-${{ matrix.arch }}.tar.gz" \ -C "$GITHUB_WORKSPACE/ptoas-dist" . - name: Smoke test archived ptoas binary artifact - if: matrix.python == '3.11' run: | TEST_DIR="$RUNNER_TEMP/ptoas-dist-smoke-${{ matrix.arch }}" rm -rf "$TEST_DIR" @@ -308,7 +308,7 @@ jobs: >/dev/null - name: Smoke test wheel imports after collecting artifacts - if: matrix.python == '3.11' + if: false run: | # Test the copied wheel artifact from an isolated env with no build-tree # paths, so post-collect/release regressions (e.g. ARM cs_invalid_page) @@ -344,15 +344,15 @@ jobs: fi - name: Upload ptoas binary artifact - if: matrix.python == '3.11' uses: actions/upload-artifact@v4 with: name: ptoas-bin-macos-${{ matrix.arch }} path: ptoas-bin-macos-${{ matrix.arch }}.tar.gz upload_release_assets: + # Disabled together with build_wheel because wheel publication is paused. + if: false name: Upload release assets - if: github.event_name == 'release' || github.event_name == 'schedule' needs: build_wheel runs-on: ubuntu-latest diff --git a/docker/Dockerfile b/docker/Dockerfile index efc6b84fa..ab0329be4 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -71,6 +71,7 @@ ENV PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core/ # copy pto.py, _pto_ops_gen.py RUN cp $PTO_INSTALL_DIR/mlir/dialects/*.py $PY_PACKAGE_DIR/mlir/dialects/ +RUN rm -rf $PY_PACKAGE_DIR/tilelang_dsl $PY_PACKAGE_DIR/TileOps && cp -R $PTO_INSTALL_DIR/tilelang_dsl $PY_PACKAGE_DIR/tilelang_dsl && cp -R $PTO_INSTALL_DIR/share/ptoas/TileOps $PY_PACKAGE_DIR/TileOps COPY ./setup.py $PY_PACKAGE_DIR/ diff --git a/docker/collect_ptoas_dist.sh b/docker/collect_ptoas_dist.sh index ca035cfc7..6c716208b 100755 --- a/docker/collect_ptoas_dist.sh +++ b/docker/collect_ptoas_dist.sh @@ -21,6 +21,8 @@ # ptoas - Wrapper script that sets up LD_LIBRARY_PATH # bin/ptoas - The actual ptoas binary # lib/*.so* - Required shared library dependencies +# share/ptoas/TileOps - TileLang template library +# tilelang_dsl/ - TileLang DSL Python package set -euo pipefail @@ -43,6 +45,10 @@ export LD_LIBRARY_PATH="${LLVM_BUILD_DIR}/lib:${PTO_INSTALL_DIR}/lib:${LD_LIBRAR PTOAS_BIN="${PTO_SOURCE_DIR}/build/tools/ptoas/ptoas" PTOAS_DEPS_DIR="${PTOAS_DIST_DIR}/lib" +PTOAS_TILEOPS_SRC_DIR="${PTO_INSTALL_DIR}/share/ptoas/TileOps" +PTOAS_TILEOPS_DIST_DIR="${PTOAS_DIST_DIR}/share/ptoas/TileOps" +PTOAS_TILELANG_DSL_SRC_DIR="${PTO_INSTALL_DIR}/tilelang_dsl" +PTOAS_TILELANG_DSL_DIST_DIR="${PTOAS_DIST_DIR}/tilelang_dsl" if [ ! -f "$PTOAS_BIN" ]; then echo "Error: ptoas binary not found at $PTOAS_BIN" >&2 @@ -119,7 +125,10 @@ harden_elf() { } # Create output directories -mkdir -p "${PTOAS_DIST_DIR}/bin" "${PTOAS_DEPS_DIR}" +mkdir -p \ + "${PTOAS_DIST_DIR}/bin" \ + "${PTOAS_DEPS_DIR}" \ + "$(dirname "${PTOAS_TILEOPS_DIST_DIR}")" # Copy ptoas binary echo "Copying ptoas binary..." @@ -149,6 +158,19 @@ while read -r packaged; do harden_elf "$packaged" done < <(find "${PTOAS_DIST_DIR}/bin" "${PTOAS_DEPS_DIR}" -type f | sort) +echo "Copying TileLang runtime resources..." +if [ ! -d "${PTOAS_TILEOPS_SRC_DIR}" ]; then + echo "Error: TileOps resource directory not found at ${PTOAS_TILEOPS_SRC_DIR}" >&2 + exit 1 +fi +if [ ! -d "${PTOAS_TILELANG_DSL_SRC_DIR}" ]; then + echo "Error: tilelang_dsl package directory not found at ${PTOAS_TILELANG_DSL_SRC_DIR}" >&2 + exit 1 +fi +rm -rf "${PTOAS_TILEOPS_DIST_DIR}" "${PTOAS_TILELANG_DSL_DIST_DIR}" +cp -R "${PTOAS_TILEOPS_SRC_DIR}" "${PTOAS_TILEOPS_DIST_DIR}" +cp -R "${PTOAS_TILELANG_DSL_SRC_DIR}" "${PTOAS_TILELANG_DSL_DIST_DIR}" + # Create wrapper script echo "Creating wrapper script..." cat > "${PTOAS_DIST_DIR}/ptoas" << 'WRAPPER_EOF' @@ -173,11 +195,16 @@ else echo "$VERSION_OUTPUT" | grep -Eq '^ptoas [0-9]+\.[0-9]+$' fi +test -d "${PTOAS_TILEOPS_DIST_DIR}" +test -f "${PTOAS_TILELANG_DSL_DIST_DIR}/__init__.py" + # Show collected files echo "" echo "=== ptoas distribution contents ===" ls -la "${PTOAS_DIST_DIR}/" ls -la "${PTOAS_DIST_DIR}/bin/" +ls -la "${PTOAS_DIST_DIR}/share/ptoas/" +ls -la "${PTOAS_TILELANG_DSL_DIST_DIR}" SO_COUNT=$(find "${PTOAS_DEPS_DIR}" -name "*.so*" 2>/dev/null | wc -l) echo "=== Collected .so dependencies (${SO_COUNT} files) ===" du -sh "${PTOAS_DEPS_DIR}/" diff --git a/docker/collect_ptoas_dist_mac.sh b/docker/collect_ptoas_dist_mac.sh index 91306349b..656061b55 100644 --- a/docker/collect_ptoas_dist_mac.sh +++ b/docker/collect_ptoas_dist_mac.sh @@ -21,6 +21,8 @@ # ptoas - Wrapper script that sets up DYLD_LIBRARY_PATH # bin/ptoas - The actual ptoas binary # lib/*.dylib - Required shared library dependencies +# share/ptoas/TileOps - TileLang template library +# tilelang_dsl/ - TileLang DSL Python package set -euo pipefail @@ -41,6 +43,10 @@ done PTOAS_BIN="${PTO_SOURCE_DIR}/build/tools/ptoas/ptoas" PTOAS_DEPS_DIR="${PTOAS_DIST_DIR}/lib" +PTOAS_TILEOPS_SRC_DIR="${PTO_INSTALL_DIR}/share/ptoas/TileOps" +PTOAS_TILEOPS_DIST_DIR="${PTOAS_DIST_DIR}/share/ptoas/TileOps" +PTOAS_TILELANG_DSL_SRC_DIR="${PTO_INSTALL_DIR}/tilelang_dsl" +PTOAS_TILELANG_DSL_DIST_DIR="${PTOAS_DIST_DIR}/tilelang_dsl" UNRESOLVED_NON_SYSTEM_COUNT=0 if [ ! -f "$PTOAS_BIN" ]; then @@ -48,7 +54,10 @@ if [ ! -f "$PTOAS_BIN" ]; then exit 1 fi -mkdir -p "${PTOAS_DIST_DIR}/bin" "${PTOAS_DEPS_DIR}" +mkdir -p \ + "${PTOAS_DIST_DIR}/bin" \ + "${PTOAS_DEPS_DIR}" \ + "$(dirname "${PTOAS_TILEOPS_DIST_DIR}")" cp -fL "$PTOAS_BIN" "${PTOAS_DIST_DIR}/bin/" chmod +x "${PTOAS_DIST_DIR}/bin/ptoas" @@ -236,6 +245,19 @@ PY echo "Collecting dylib dependencies..." collect_dylibs "${PTOAS_DIST_DIR}/bin/ptoas" +echo "Copying TileLang runtime resources..." +if [[ ! -d "${PTOAS_TILEOPS_SRC_DIR}" ]]; then + echo "Error: TileOps resource directory not found at ${PTOAS_TILEOPS_SRC_DIR}" >&2 + exit 1 +fi +if [[ ! -d "${PTOAS_TILELANG_DSL_SRC_DIR}" ]]; then + echo "Error: tilelang_dsl package directory not found at ${PTOAS_TILELANG_DSL_SRC_DIR}" >&2 + exit 1 +fi +rm -rf "${PTOAS_TILEOPS_DIST_DIR}" "${PTOAS_TILELANG_DSL_DIST_DIR}" +cp -R "${PTOAS_TILEOPS_SRC_DIR}" "${PTOAS_TILEOPS_DIST_DIR}" +cp -R "${PTOAS_TILELANG_DSL_SRC_DIR}" "${PTOAS_TILELANG_DSL_DIST_DIR}" + echo "Rewriting packaged install names..." rewrite_packaged_install_names @@ -328,6 +350,8 @@ if [ -n "${PTOAS_VERSION:-}" ]; then else echo "$VERSION_OUTPUT" | grep -Eq '^ptoas [0-9]+\.[0-9]+$' fi +test -d "${PTOAS_TILEOPS_DIST_DIR}" +test -f "${PTOAS_TILELANG_DSL_DIST_DIR}/__init__.py" env -u DYLD_LIBRARY_PATH -u LD_LIBRARY_PATH \ "${PTOAS_DIST_DIR}/ptoas" \ "${PTO_SOURCE_DIR}/test/lit/pto/kernel_kind_vector_scf_while_emitc.pto" \ @@ -337,6 +361,8 @@ echo "" echo "=== ptoas distribution contents ===" ls -la "${PTOAS_DIST_DIR}/" ls -la "${PTOAS_DIST_DIR}/bin/" +ls -la "${PTOAS_DIST_DIR}/share/ptoas/" +ls -la "${PTOAS_TILELANG_DSL_DIST_DIR}" DYLIB_COUNT=$(find "${PTOAS_DEPS_DIR}" -name "*.dylib" 2>/dev/null | wc -l) echo "=== Collected .dylib dependencies (${DYLIB_COUNT} files) ===" du -sh "${PTOAS_DEPS_DIR}/" diff --git a/docker/create_wheel.sh b/docker/create_wheel.sh index 4762e1abc..2145fb9e7 100755 --- a/docker/create_wheel.sh +++ b/docker/create_wheel.sh @@ -44,6 +44,13 @@ echo "Wheel package version: ${PTOAS_PYTHON_PACKAGE_VERSION}" echo "Copying PTO dialect files..." cp "${PTO_INSTALL_DIR}/mlir/dialects/"*.py "${PY_PACKAGE_DIR}/mlir/dialects/" +# Copy TileLang resources into the wheel staging tree so wheel installs keep +# the template library and Python DSL available. +echo "Copying TileLang resources..." +rm -rf "${PY_PACKAGE_DIR}/tilelang_dsl" "${PY_PACKAGE_DIR}/TileOps" +cp -R "${PTO_INSTALL_DIR}/tilelang_dsl" "${PY_PACKAGE_DIR}/tilelang_dsl" +cp -R "${PTO_INSTALL_DIR}/share/ptoas/TileOps" "${PY_PACKAGE_DIR}/TileOps" + # Copy platform-specific setup.py to package directory. # On macOS, use setup_mac.py and rename it to setup.py in the build dir. SETUP_TEMPLATE="${PTO_SOURCE_DIR}/docker/setup.py" diff --git a/tilelang-dsl/CMakeLists.txt b/tilelang-dsl/CMakeLists.txt index 445d920b2..2bbca7c3d 100644 --- a/tilelang-dsl/CMakeLists.txt +++ b/tilelang-dsl/CMakeLists.txt @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + # ========================================================= # TileLang DSL package wiring # ========================================================= @@ -25,3 +33,11 @@ install( PATTERN "__pycache__" EXCLUDE PATTERN "*.pyc" EXCLUDE ) + +install( + DIRECTORY "${CMAKE_SOURCE_DIR}/lib/TileOps" + DESTINATION "share/ptoas" + COMPONENT PTOAS_Runtime + PATTERN "__pycache__" EXCLUDE + PATTERN "*.pyc" EXCLUDE +) diff --git a/tools/ptoas/CMakeLists.txt b/tools/ptoas/CMakeLists.txt index 611c9c615..51e8c489a 100644 --- a/tools/ptoas/CMakeLists.txt +++ b/tools/ptoas/CMakeLists.txt @@ -75,3 +75,8 @@ target_link_libraries(pto-opt PRIVATE add_dependencies(pto-opt PTOOpsIncGen ) + +install(TARGETS pto-opt + RUNTIME DESTINATION bin + COMPONENT PTOAS_Runtime +) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 3eb5d43a7..7a41a5715 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -29,6 +29,7 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/FileSystem.h" // [Fix] Required for OF_None +#include "llvm/Support/Path.h" #include "ptobc/ptobc_decode.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -54,10 +55,67 @@ using namespace pto; #define PTOAS_RELEASE_VERSION "unknown" #endif +int main(int argc, char **argv); + static void printPTOASVersion(llvm::raw_ostream &os) { os << "ptoas " << PTOAS_RELEASE_VERSION << "\n"; } +static std::string getParentDir(llvm::StringRef path) { + llvm::SmallString<256> parent(path); + llvm::sys::path::remove_filename(parent); + llvm::sys::path::remove_dots(parent, true); + return std::string(parent); +} + +static bool pathExists(llvm::StringRef path) { + return !path.empty() && llvm::sys::fs::exists(path); +} + +static std::string joinPath(llvm::StringRef lhs, llvm::StringRef rhs) { + llvm::SmallString<256> joined(lhs); + llvm::sys::path::append(joined, rhs); + llvm::sys::path::remove_dots(joined, true); + return std::string(joined); +} + +static std::string detectInstalledTilelangPath(const char *argv0) { + std::string exePath = llvm::sys::fs::getMainExecutable(argv0, (void *)&main); + if (exePath.empty()) + return {}; + + const std::string exeDir = getParentDir(exePath); + const std::string prefixDir = getParentDir(exeDir); + const std::string installedTileOps = joinPath(prefixDir, "share/ptoas/TileOps"); + if (pathExists(installedTileOps)) + return installedTileOps; + return {}; +} + +static std::string detectInstalledTilelangPkgPath(const char *argv0) { + std::string exePath = llvm::sys::fs::getMainExecutable(argv0, (void *)&main); + if (exePath.empty()) + return {}; + + const std::string exeDir = getParentDir(exePath); + const std::string prefixDir = getParentDir(exeDir); + const std::string installedPkgRoot = prefixDir; + const std::string installedPkg = joinPath(installedPkgRoot, "tilelang_dsl"); + if (pathExists(installedPkg)) + return installedPkgRoot; + return {}; +} + +static bool hasCLIOption(int argc, char **argv, llvm::StringRef option) { + const std::string optionWithValue = (option + "=").str(); + for (int i = 1; i < argc; ++i) { + llvm::StringRef arg(argv[i]); + if (arg == option || arg.starts_with(optionWithValue)) + return true; + } + return false; +} + static LogicalResult applyConfiguredPassManagerCLOptions( PassManager &pm, llvm::StringRef pipelineName, llvm::raw_ostream &diagOS = llvm::errs()) { @@ -222,6 +280,27 @@ static llvm::cl::opt tilelangPkgPath( "(default: /tilelang-dsl/python, baked in at build time)"), llvm::cl::init(PTOAS_DEFAULT_TILELANG_PKG_PATH)); +static pto::ExpandTileOpOptions resolveExpandTileOpOptions(int argc, + char **argv) { + pto::ExpandTileOpOptions expandOpts; + expandOpts.tilelangPath = tilelangPath; + expandOpts.tilelangPkgPath = tilelangPkgPath; + + if (!hasCLIOption(argc, argv, "--tilelang-path")) { + std::string detectedTilelangPath = detectInstalledTilelangPath(argv[0]); + if (!detectedTilelangPath.empty()) + expandOpts.tilelangPath = detectedTilelangPath; + } + + if (!hasCLIOption(argc, argv, "--tilelang-pkg-path")) { + std::string detectedTilelangPkgPath = detectInstalledTilelangPkgPath(argv[0]); + if (!detectedTilelangPkgPath.empty()) + expandOpts.tilelangPkgPath = detectedTilelangPkgPath; + } + + return expandOpts; +} + static llvm::cl::opt disableInferLayout( "disable-infer-layout", llvm::cl::desc("Disable PTO layout inference pass (static-only)"), @@ -1151,7 +1230,8 @@ static LogicalResult prepareVPTOForEmission(ModuleOp module) { return success(); } -static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { +static LogicalResult lowerPTOToVPTOBackend(ModuleOp module, int argc, + char **argv) { PassManager backendPM(module.getContext()); // TileOp Expand path: // 1. MemrefToTileBuf: recover tile_buf from memref @@ -1162,9 +1242,7 @@ static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { // tile_valid_cols to concrete memref/constant values backendPM.addPass(pto::createMemrefToTileBufPass()); - pto::ExpandTileOpOptions expandOpts; - expandOpts.tilelangPath = tilelangPath; - expandOpts.tilelangPkgPath = tilelangPkgPath; + pto::ExpandTileOpOptions expandOpts = resolveExpandTileOpOptions(argc, argv); backendPM.addPass(pto::createExpandTileOpPass(expandOpts)); backendPM.addPass(pto::createPTOInlineLibCallPass()); @@ -1296,10 +1374,8 @@ int main(int argc, char **argv) { bool cliArchSpecified = false; for (int i = 1; i < argc; ++i) { llvm::StringRef arg(argv[i]); - if (arg == "--pto-arch" || arg.starts_with("--pto-arch=")) { + if (arg == "--pto-arch" || arg.starts_with("--pto-arch=")) cliArchSpecified = true; - break; - } } // Register all passes so that --mlir-print-ir-after/before can resolve @@ -1587,7 +1663,7 @@ int main(int argc, char **argv) { if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) return 1; - if (failed(lowerPTOToVPTOBackend(*module))) + if (failed(lowerPTOToVPTOBackend(*module, argc, argv))) return 1; return emitVPTOBackendResult(*module, outputFile); } From 9324eabeed39b1177980fe914c80f1029eeb5985 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 27 Apr 2026 20:55:49 +0800 Subject: [PATCH 190/192] Fix TileLang soft-math helper lookup in installed layout --- tilelang-dsl/python/tilelang_dsl/kernel.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py index e061f4567..ff814bdb7 100644 --- a/tilelang-dsl/python/tilelang_dsl/kernel.py +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -221,14 +221,27 @@ def _load_module_from_path(module_name: str, path: Path) -> Any: return module +def _find_internal_soft_math_path() -> Path | None: + module_path = Path(__file__).resolve() + candidate_suffixes = ( + ("lib", "TileOps", "math.py"), + ("share", "ptoas", "TileOps", "math.py"), + ) + for root in (module_path.parent, *module_path.parents): + for suffix in candidate_suffixes: + candidate = root.joinpath(*suffix) + if candidate.exists(): + return candidate + return None + + def _collect_internal_inline_procs() -> tuple[tuple[str, InlineProcDescriptor], ...]: global _INTERNAL_INLINE_PROC_CACHE if _INTERNAL_INLINE_PROC_CACHE is not None: return _INTERNAL_INLINE_PROC_CACHE - repo_root = Path(__file__).resolve().parents[3] - soft_math_path = repo_root / "lib" / "TileOps" / "math.py" - if not soft_math_path.exists(): + soft_math_path = _find_internal_soft_math_path() + if soft_math_path is None: _INTERNAL_INLINE_PROC_CACHE = () return _INTERNAL_INLINE_PROC_CACHE From 352427a190249e7e3d99d16482676b406b4a8e9e Mon Sep 17 00:00:00 2001 From: mly <978226558@qq.com> Date: Mon, 27 Apr 2026 21:48:10 +0800 Subject: [PATCH 191/192] bugfix: fixup the mask lack (#291) Co-authored-by: mouliangyu --- docs/isa/03-vector-load-store.md | 20 ++++++----- docs/isa/13-dsa-sfu-ops.md | 13 +++---- docs/vpto-spec.md | 2 +- include/PTO/IR/VPTOOps.td | 11 +++--- lib/PTO/IR/VPTO.cpp | 16 ++++++--- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 35 +++++-------------- lib/TileOps/tcolexpandexpdif_template.py | 4 +-- lib/TileOps/trowexpandexpdif_template.py | 4 +-- .../a5/src/st/testcase/softmax/softmax.pto | 8 ++--- .../kernels/online-softmax-update/kernel.pto | 8 ++--- .../kernel_tload_tstore.pto | 8 ++--- .../dsa-sfu/vexpdiff-boundary/kernel.pto | 2 +- .../dsa-sfu/vexpdiff-f16-part/kernel.pto | 5 +-- .../micro-op/dsa-sfu/vexpdiff-f32/kernel.pto | 2 +- .../vgather2-duplicate-index/golden.py | 6 +++- .../vgather2-duplicate-index/kernel.pto | 9 +++-- .../gather-scatter/vgather2/kernel.pto | 2 +- .../vscatter-out-of-order-index/golden.py | 4 ++- .../vscatter-out-of-order-index/kernel.pto | 9 +++-- .../gather-scatter/vscatter/kernel.pto | 2 +- .../11-vector-arithmetic-operations.md | 4 ++- .../docs/vpto_spec/vpto-spec-current.md | 6 ++-- tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md | 26 +++++++------- tilelang-dsl/python/tilelang_dsl/lowering.py | 15 ++++---- tilelang-dsl/python/tilelang_dsl/semantic.py | 17 ++++----- tilelang-dsl/tests/test_tilelang_dsl_v1.py | 12 ++++--- 26 files changed, 133 insertions(+), 117 deletions(-) diff --git a/docs/isa/03-vector-load-store.md b/docs/isa/03-vector-load-store.md index 4c199c286..902287147 100644 --- a/docs/isa/03-vector-load-store.md +++ b/docs/isa/03-vector-load-store.md @@ -292,22 +292,23 @@ for (int blk = 0; blk < 8; ++blk) { ### `pto.vgather2` -- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **syntax:** `%result = pto.vgather2 %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` - **semantics:** Indexed gather from UB. - **inputs:** `%source` is the UB base pointer, `%offsets` provides per-lane element - offsets, and `%active_lanes` bounds how many lanes participate. + offsets, and `%mask` selects the active requests. - **outputs:** `%result` is the gathered vector. - **constraints and limitations:** - Only the first `%active_lanes` indices participate. The index element width + Only masked-on indices participate. The index element width and interpretation MUST match the selected gather form, and each effective address must satisfy that form's alignment rules. - **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). ```c -for (int i = 0; i < active_lanes; i++) - dst[i] = UB[base + offsets[i] * sizeof(T)]; +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = UB[base + offsets[i] * sizeof(T)]; ``` --- @@ -464,11 +465,11 @@ for (int blk = 0; blk < 8; ++blk) { ### `pto.vscatter` -- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **syntax:** `pto.vscatter %value, %dest, %offsets, %mask : !pto.vreg, !pto.ptr, !pto.vreg, !pto.mask` - **semantics:** Indexed scatter to UB. - **inputs:** `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` - provides per-lane or per-block indices, and `%active_lanes` bounds the active + provides per-lane or per-block indices, and `%mask` selects the active requests. - **outputs:** This op writes UB memory and returns no SSA value. @@ -480,8 +481,9 @@ for (int blk = 0; blk < 8; ++blk) { - **Latency:** **~17** cycles for **`Dtype: B16`**. ```c -for (int i = 0; i < active_lanes; i++) - UB[base + offsets[i] * sizeof(T)] = src[i]; +for (int i = 0; i < N; i++) + if (mask[i]) + UB[base + offsets[i] * sizeof(T)] = src[i]; ``` --- diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md index eeee43b70..472f38e4b 100644 --- a/docs/isa/13-dsa-sfu-ops.md +++ b/docs/isa/13-dsa-sfu-ops.md @@ -59,7 +59,7 @@ for (int i = 0; i < N; i++) ### `pto.vexpdif` -- **syntax:** `%result = pto.vexpdif %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vexpdif %input, %max, %mask, "EVEN|ODD" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` - **A5 types:** input `f16` or `f32`, output `f32` - **semantics:** Fused exp(x - max) for numerically stable softmax. @@ -70,13 +70,14 @@ for (int i = 0; i < N; i++) **Use case:** Softmax numerator computation with numerical stability. -- **inputs:** `%input` is the source vector and `%max` is the broadcasted - subtraction term. `%part` selects `EVEN` or `ODD` for the - underlying hardware contract. +- **inputs:** `%input` is the source vector, `%max` is the broadcasted + subtraction term, `%mask` selects active source lanes, and `%part` selects + `EVEN` or `ODD` for the underlying hardware contract. - **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` elements. - **constraints and limitations:** Source vectors must be `f16` or `f32`, the - result vector must be `f32`, and source/result storage width must match. + result vector must be `f32`, the mask granularity must match the input + vector element width, and source/result storage width must match. --- @@ -223,7 +224,7 @@ for (int i = 0; i < N; i++) ```mlir // Softmax with fused expdiff %max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> -%exp_stable = pto.vexpdif %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // Leaky ReLU activation %activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md index 07f1000f6..ce35271ae 100644 --- a/docs/vpto-spec.md +++ b/docs/vpto-spec.md @@ -1180,7 +1180,7 @@ pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, ! %max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> // 2. exp(x - max) using fused op -%exp = pto.vexpdif %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp = pto.vexpdif %logits, %max_bc, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // 3. Sum %sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index b23d67a07..3cdaa2108 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -1416,14 +1416,14 @@ def PTO_Vgather2Op : PTO_Op<"vgather2", [ let arguments = (ins PTO_BufferType:$source, PTO_VectorType:$offsets, - Index:$active_lanes + PTO_MaskTypeConstraint:$mask ); let results = (outs PTO_VectorType:$result); let hasVerifier = 1; let assemblyFormat = [{ - $source `,` $offsets `,` $active_lanes attr-dict `:` type($source) `,` type($offsets) `,` type($active_lanes) `->` type($result) + $source `,` $offsets `,` $mask attr-dict `:` type($source) `,` type($offsets) `,` type($mask) `->` type($result) }]; } @@ -1518,7 +1518,7 @@ def PTO_VscatterOp : PTO_Op<"vscatter", [ PTO_VectorType:$value, PTO_BufferType:$destination, PTO_VectorType:$offsets, - Index:$active_lanes + PTO_MaskTypeConstraint:$mask ); let results = (outs); @@ -1526,7 +1526,7 @@ def PTO_VscatterOp : PTO_Op<"vscatter", [ let hasVerifier = 1; let assemblyFormat = [{ - $value `,` $destination `,` $offsets `,` $active_lanes attr-dict `:` type($value) `,` type($destination) `,` type($offsets) `,` type($active_lanes) + $value `,` $destination `,` $offsets `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($offsets) `,` type($mask) }]; } @@ -1827,6 +1827,7 @@ def PTO_VexpdifOp : PTO_Op<"vexpdif", [Pure]> { let arguments = (ins PTO_VectorType:$input, PTO_VectorType:$max, + PTO_MaskTypeConstraint:$mask, StrAttr:$part ); let results = (outs PTO_VectorType:$result); @@ -1834,7 +1835,7 @@ def PTO_VexpdifOp : PTO_Op<"vexpdif", [Pure]> { let hasVerifier = 1; let assemblyFormat = [{ - $input `,` $max `,` $part attr-dict `:` type($input) `,` type($max) `->` type($result) + $input `,` $max `,` $mask `,` $part attr-dict `:` type($input) `,` type($max) `,` type($mask) `->` type($result) }]; } diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index fbda9928c..54b80b966 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -1773,8 +1773,8 @@ LogicalResult Vgather2Op::verify() { return emitOpError("offset vector must use integer element type"); if (offsetsType.getElementCount() != resultType.getElementCount()) return emitOpError("offset and result vectors must have the same element count"); - if (!getActiveLanes().getType().isIndex()) - return emitOpError("active_lanes must be index"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); return success(); } @@ -3189,6 +3189,7 @@ LogicalResult VpreluOp::verify() { return verifyFloatBinaryVecMaskOp(*this); } LogicalResult VexpdifOp::verify() { if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input type")) || failed(verifyVRegTypeLike(*this, getMax().getType(), "max type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) return failure(); @@ -3201,6 +3202,13 @@ LogicalResult VexpdifOp::verify() { Type inputElemType = inputType.getElementType(); if (!inputElemType.isF16() && !inputElemType.isF32()) return emitOpError("requires f16 or f32 input vector element type"); + auto expectedGranularity = getVdupMaskGranularity(inputElemType); + if (!expectedGranularity) + return emitOpError("requires input element type with supported predicate granularity"); + if (failed(verifyMaskTypeWithGranularityLike(*this, getMask().getType(), + "mask type", + *expectedGranularity))) + return failure(); if (!resultType.getElementType().isF32()) return emitOpError("requires f32 result vector element type"); @@ -3407,11 +3415,11 @@ LogicalResult VscatterOp::verify() { return emitOpError("currently requires 32-bit offset vector elements"); if (offsetsType.getElementCount() != valueType.getElementCount()) return emitOpError("offset and value vectors must have the same element count"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); MemoryRole destinationRole = classifyMemoryRole(getDestination().getType()); if (destinationRole == MemoryRole::GM) return emitOpError("requires a UB-backed destination"); - if (!getActiveLanes().getType().isIndex()) - return emitOpError("active_lanes must be index"); return success(); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 597a82f77..ca7d483f8 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -5529,12 +5529,6 @@ class LowerVgather2OpPattern final return rewriter.notifyMatchFailure(op, "unexpected converted vgather2 operand types"); - FailureOr mask = materializeDynamicPltMask( - rewriter, state, op.getLoc(), adaptor.getActiveLanes(), elemType); - if (failed(mask)) - return rewriter.notifyMatchFailure(op, - "failed to materialize vgather2 mask"); - Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vgather2 result type"); @@ -5546,11 +5540,11 @@ class LowerVgather2OpPattern final auto funcType = rewriter.getFunctionType( TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), - (*mask).getType()}, + adaptor.getMask().getType()}, TypeRange{resultType}); auto call = rewriter.create( op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getSource(), adaptor.getOffsets(), *mask}); + ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.replaceOp(op, call.getResults()); return success(); @@ -5654,12 +5648,6 @@ class LowerVscatterOpPattern final return rewriter.notifyMatchFailure(op, "unexpected converted vscatter operand types"); - FailureOr mask = materializeDynamicPltMask( - rewriter, state, op.getLoc(), adaptor.getActiveLanes(), elemType); - if (failed(mask)) - return rewriter.notifyMatchFailure(op, - "failed to materialize vscatter mask"); - FailureOr calleeName = buildVscatterCallee(op.getContext(), op.getValue().getType()); if (failed(calleeName)) @@ -5667,12 +5655,12 @@ class LowerVscatterOpPattern final auto funcType = rewriter.getFunctionType( TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), - adaptor.getOffsets().getType(), (*mask).getType()}, + adaptor.getOffsets().getType(), adaptor.getMask().getType()}, TypeRange{}); rewriter.create( op.getLoc(), *calleeName, TypeRange{}, ValueRange{adaptor.getValue(), adaptor.getDestination(), - adaptor.getOffsets(), *mask}); + adaptor.getOffsets(), adaptor.getMask()}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.eraseOp(op); return success(); @@ -5787,18 +5775,10 @@ class LowerVexpdifOpPattern final LogicalResult matchAndRewrite(pto::VexpdifOp op, pto::VexpdifOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto laneCount = getElementCountFromVectorLike(op.getInput().getType()); - Type elemType = getElementTypeFromVectorLike(op.getInput().getType()); auto part = parsePartImmediate(op.getPart()); - if (!laneCount || !elemType || !part) + if (!part) return rewriter.notifyMatchFailure(op, "unsupported vexpdif signature"); - FailureOr mask = materializeDynamicPltMask( - rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), - elemType); - if (failed(mask)) - return rewriter.notifyMatchFailure(op, "failed to materialize vexpdif mask"); - Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!resultType) return rewriter.notifyMatchFailure(op, "failed to convert vexpdif result type"); @@ -5812,11 +5792,12 @@ class LowerVexpdifOpPattern final Value partValue = getI32Constant(rewriter, op.getLoc(), *part); auto funcType = rewriter.getFunctionType( TypeRange{adaptor.getInput().getType(), adaptor.getMax().getType(), - (*mask).getType(), partValue.getType()}, + adaptor.getMask().getType(), partValue.getType()}, TypeRange{resultType}); auto call = rewriter.create( op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getInput(), adaptor.getMax(), *mask, partValue}); + ValueRange{adaptor.getInput(), adaptor.getMax(), adaptor.getMask(), + partValue}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.replaceOp(op, call.getResults()); return success(); diff --git a/lib/TileOps/tcolexpandexpdif_template.py b/lib/TileOps/tcolexpandexpdif_template.py index 977fa690c..0ae28ee0c 100644 --- a/lib/TileOps/tcolexpandexpdif_template.py +++ b/lib/TileOps/tcolexpandexpdif_template.py @@ -52,6 +52,6 @@ def template_tcolexpandexpdif_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) mask, remained = pto.make_mask(dtype, remained) lhs = pto.vlds(src0[row, col:]) rhs = pto.vlds(src1[0, col:]) - result = pto.vexpdif(lhs, rhs, pto.VcvtPartMode.ODD) + result = pto.vexpdif(lhs, rhs, mask, pto.VcvtPartMode.ODD) pto.vsts(result, dst[row, col:], mask) - return \ No newline at end of file + return diff --git a/lib/TileOps/trowexpandexpdif_template.py b/lib/TileOps/trowexpandexpdif_template.py index 0e84294b9..fed08feab 100644 --- a/lib/TileOps/trowexpandexpdif_template.py +++ b/lib/TileOps/trowexpandexpdif_template.py @@ -44,7 +44,7 @@ def template_trowexpandexpdif_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) scalar_vec = pto.vlds(src1[row, :]) broadcasted = pto.vdup(scalar_vec, mask) lhs = pto.vlds(src0[row, col:]) - result = pto.vexpdif(lhs, broadcasted, pto.VcvtPartMode.EVEN) + result = pto.vexpdif(lhs, broadcasted, mask, pto.VcvtPartMode.EVEN) pto.vsts(result, dst[row, col:], mask) return @@ -75,4 +75,4 @@ def template_trowexpandexpdif_f16(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) diff = pto.vsub(lhs, broadcasted, mask) result = pto.vexp(diff, mask) pto.vsts(result, dst[row, col:], mask) - return \ No newline at end of file + return diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto index 1a90261c0..1f613b0e1 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto @@ -182,9 +182,9 @@ module attributes {pto.target_arch = "a5"} { %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %scaled_running = pto.vexpdif %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdif %running_max, %merged_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %chunk_exp = pto.vexpdif %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdif %vec, %merged_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> @@ -195,7 +195,7 @@ module attributes {pto.target_arch = "a5"} { scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> } - %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask @@ -210,7 +210,7 @@ module attributes {pto.target_arch = "a5"} { %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 %chunk_base = arith.addi %row_qk, %chunk : index %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> - %exp = pto.vexpdif %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel.pto b/test/vpto/cases/kernels/online-softmax-update/kernel.pto index 4cfc35a38..a59d2aa89 100644 --- a/test/vpto/cases/kernels/online-softmax-update/kernel.pto +++ b/test/vpto/cases/kernels/online-softmax-update/kernel.pto @@ -104,9 +104,9 @@ module attributes {pto.target_arch = "a5"} { %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %scaled_running = pto.vexpdif %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdif %running_max, %merged_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %chunk_exp = pto.vexpdif %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdif %vec, %merged_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> @@ -117,7 +117,7 @@ module attributes {pto.target_arch = "a5"} { scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> } - %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask @@ -132,7 +132,7 @@ module attributes {pto.target_arch = "a5"} { %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 %chunk_base = arith.addi %row_qk, %chunk : index %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> - %exp = pto.vexpdif %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto b/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto index cadd1b618..594f22fa7 100644 --- a/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto +++ b/test/vpto/cases/kernels/online-softmax-update/kernel_tload_tstore.pto @@ -184,9 +184,9 @@ module attributes {pto.target_arch = "a5"} { %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %scaled_running = pto.vexpdif %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdif %running_max, %merged_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - %chunk_exp = pto.vexpdif %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdif %vec, %merged_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> @@ -197,7 +197,7 @@ module attributes {pto.target_arch = "a5"} { scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> } - %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask @@ -212,7 +212,7 @@ module attributes {pto.target_arch = "a5"} { %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 %chunk_base = arith.addi %row_qk, %chunk : index %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> - %exp = pto.vexpdif %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto index e12d6718b..1834c6347 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto @@ -40,7 +40,7 @@ module attributes {pto.target_arch = "a5"} { scf.for %offset = %c0 to %c1024 step %c64 { %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %max = pto.vlds %ub_max[%offset] : !pto.ptr -> !pto.vreg<64xf32> - %sum = pto.vexpdif %vec, %max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %sum = pto.vexpdif %vec, %max, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask } } diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto index af2480002..79a4bb089 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto @@ -38,13 +38,14 @@ module attributes {pto.target_arch = "a5"} { pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] pto.vecscope { + %full_mask = pto.pset_b16 "PAT_ALL" : !pto.mask %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { %input = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> %max = pto.vlds %ub_max[%offset] : !pto.ptr -> !pto.vreg<128xf16> %even_mask, %remaining_after_even = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 %odd_mask, %next_remaining = pto.plt_b32 %remaining_after_even : i32 -> !pto.mask, i32 - %even = pto.vexpdif %input, %max, "EVEN" : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<64xf32> - %odd = pto.vexpdif %input, %max, "ODD" : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<64xf32> + %even = pto.vexpdif %input, %max, %full_mask, "EVEN" : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vexpdif %input, %max, %full_mask, "ODD" : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> %odd_offset = arith.addi %offset, %c64 : index pto.vsts %even, %ub_out[%offset], %even_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask pto.vsts %odd, %ub_out[%odd_offset], %odd_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto index 618e90e67..be8125ff5 100644 --- a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto @@ -34,7 +34,7 @@ module attributes {pto.target_arch = "a5"} { %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> - %sum = pto.vexpdif %vec, %vec, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %sum = pto.vexpdif %vec, %vec, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py index f27ecfd0b..4a5c343b6 100755 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py @@ -30,7 +30,11 @@ def generate(output_dir: Path, seed: int) -> None: flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) pair_ids = ((np.arange((ROWS * COLS) // 2, dtype=np.int32) * 29) + 5) % (ROWS * COLS) offsets = np.repeat(pair_ids, 2) - gathered = flat[offsets].reshape(ROWS, COLS) + gathered = np.zeros((ROWS * COLS,), dtype=np.float32) + for base in range(0, ROWS * COLS, 64): + lanes = np.arange(base + 8, base + 64, dtype=np.int32) + gathered[lanes] = flat[offsets[lanes]] + gathered = gathered.reshape(ROWS, COLS) v1 = flat.reshape(ROWS, COLS) v2 = offsets.reshape(ROWS, COLS) v3 = np.zeros((ROWS, COLS), dtype=np.float32) diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto index 5d750ea07..4fc3a4e76 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto @@ -18,6 +18,7 @@ module attributes {pto.target_arch = "a5"} { %c128_i64 = arith.constant 128 : i64 %c4096_i64 = arith.constant 4096 : i64 %c8192_i64 = arith.constant 8192 : i64 + %c8_i32 = arith.constant 8 : i32 %c1024_i32 = arith.constant 1024 : i32 %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr @@ -37,11 +38,13 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { - %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %prefix_mask, %next_remaining = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %suffix_mask = pto.pnot %prefix_mask, %full_mask : !pto.mask, !pto.mask -> !pto.mask %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> - %out = pto.vgather2 %ub_in, %offsets, %c64 : !pto.ptr, !pto.vreg<64xi32>, index -> !pto.vreg<64xf32> + %out = pto.vgather2 %ub_in, %offsets, %suffix_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - scf.yield %next_remaining : i32 + scf.yield %remaining : i32 } } diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto index fb04ffbbc..f3f16cda4 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto @@ -58,7 +58,7 @@ module attributes {pto.target_arch = "a5"} { %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> - %out = pto.vgather2 %ub_in, %offsets, %c64 : !pto.ptr, !pto.vreg<64xi32>, index -> !pto.vreg<64xf32> + %out = pto.vgather2 %ub_in, %offsets, %mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask scf.yield %next_remaining : i32 } diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py index 3bf886b5f..99761f514 100755 --- a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py @@ -30,7 +30,9 @@ def generate(output_dir: Path, seed: int) -> None: flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 43) + 11) % (ROWS * COLS) scattered = np.zeros((ROWS * COLS,), dtype=np.float32) - scattered[offsets] = flat + for base in range(0, ROWS * COLS, 64): + lanes = np.arange(base + 8, base + 64, dtype=np.int32) + scattered[offsets[lanes]] = flat[lanes] v1 = flat.reshape(ROWS, COLS) v2 = offsets.reshape(ROWS, COLS) v3 = np.zeros((ROWS, COLS), dtype=np.float32) diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto index b4de75a50..dab1b9ede 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto @@ -18,6 +18,7 @@ module attributes {pto.target_arch = "a5"} { %c128_i64 = arith.constant 128 : i64 %c4096_i64 = arith.constant 4096 : i64 %c8192_i64 = arith.constant 8192 : i64 + %c8_i32 = arith.constant 8 : i32 %c1024_i32 = arith.constant 1024 : i32 %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr @@ -40,11 +41,13 @@ module attributes {pto.target_arch = "a5"} { pto.vecscope { %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { - %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %prefix_mask, %next_remaining = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %suffix_mask = pto.pnot %prefix_mask, %full_mask : !pto.mask, !pto.mask -> !pto.mask %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> - pto.vscatter %vec, %ub_out, %offsets, %c64 : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, index - scf.yield %next_remaining : i32 + pto.vscatter %vec, %ub_out, %offsets, %suffix_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, !pto.mask + scf.yield %remaining : i32 } } diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto index 7fee06681..41883f16d 100644 --- a/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto @@ -62,7 +62,7 @@ module attributes {pto.target_arch = "a5"} { %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> - pto.vscatter %vec, %ub_out, %offsets, %c64 : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, index + pto.vscatter %vec, %ub_out, %offsets, %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, !pto.mask scf.yield %next_remaining : i32 } } diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md index e492cbdce..ebeb84b83 100644 --- a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -345,7 +345,7 @@ neg_vec = pto.vneg(vec_f32, mask32) **Constraints**: - Operates on integer vector types only -#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, part: pto.VcvtPartMode) -> VRegType` +#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, mask: MaskType, part: pto.VcvtPartMode) -> VRegType` **Description**: Fused exponential difference `exp(vec - max_vec)` for numerically stable softmax lowering. @@ -354,6 +354,7 @@ neg_vec = pto.vneg(vec_f32, mask32) |-----------|------|-------------| | `vec` | `VRegType` | Input vector | | `max_vec` | `VRegType` | Per-lane max vector subtracted before exponentiation | +| `mask` | `MaskType` | Predicate mask. Use `b16` for `f16` inputs and `b32` for `f32` inputs. | | `part` | `pto.VcvtPartMode` | Output part selector enum. Use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD`. | **Returns**: @@ -364,6 +365,7 @@ neg_vec = pto.vneg(vec_f32, mask32) **Constraints**: - Supports `f16` and `f32` input vectors only - `vec` and `max_vec` must use the same vector type +- `mask` granularity must match the input vector element width - `part` should use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD` - Canonical strings `"EVEN"` / `"ODD"` are still accepted for compatibility diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md index 0f2758187..39075df26 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md @@ -4891,7 +4891,7 @@ for (int i = 0; i < N; i++) ##### `pto.vexpdif` -- **syntax:** `%result = pto.vexpdif %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vexpdif %input, %max, %mask, "EVEN|ODD" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` - **A5 types:** input `f16` or `f32`, output `f32` - **semantics:** Fused exp(x - max) for numerically stable softmax. @@ -5050,7 +5050,7 @@ for (int i = 0; i < N; i++) ```mlir // Softmax with fused expdiff %max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> -%exp_stable = pto.vexpdif %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // Leaky ReLU activation %activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> @@ -5286,7 +5286,7 @@ pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, ! %max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> // 2. exp(x - max) using fused op -%exp = pto.vexpdif %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp = pto.vexpdif %logits, %max_bc, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // 3. Sum %sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md index f3dabff81..0a280204e 100644 --- a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md @@ -2396,22 +2396,23 @@ for (int blk = 0; blk < 8; ++blk) { ##### `pto.vgather2` -- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **syntax:** `%result = pto.vgather2 %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` - **semantics:** Indexed gather from UB. - **inputs:** `%source` is the UB base pointer, `%offsets` provides per-lane element - offsets, and `%active_lanes` bounds how many lanes participate. + offsets, and `%mask` selects the active requests. - **outputs:** `%result` is the gathered vector. - **constraints and limitations:** - Only the first `%active_lanes` indices participate. The index element width + Only masked-on indices participate. The index element width and interpretation MUST match the selected gather form, and each effective address must satisfy that form's alignment rules. - **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). ```c -for (int i = 0; i < active_lanes; i++) - dst[i] = UB[base + offsets[i] * sizeof(T)]; +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = UB[base + offsets[i] * sizeof(T)]; ``` --- @@ -2568,11 +2569,11 @@ for (int blk = 0; blk < 8; ++blk) { ##### `pto.vscatter` -- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **syntax:** `pto.vscatter %value, %dest, %offsets, %mask : !pto.vreg, !pto.ptr, !pto.vreg, !pto.mask` - **semantics:** Indexed scatter to UB. - **inputs:** `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` - provides per-lane or per-block indices, and `%active_lanes` bounds the active + provides per-lane or per-block indices, and `%mask` selects the active requests. - **outputs:** This op writes UB memory and returns no SSA value. @@ -2584,8 +2585,9 @@ for (int blk = 0; blk < 8; ++blk) { - **Latency:** **~17** cycles for **`Dtype: B16`**. ```c -for (int i = 0; i < active_lanes; i++) - UB[base + offsets[i] * sizeof(T)] = src[i]; +for (int i = 0; i < N; i++) + if (mask[i]) + UB[base + offsets[i] * sizeof(T)] = src[i]; ``` --- @@ -4857,7 +4859,7 @@ for (int i = 0; i < N; i++) ##### `pto.vexpdif` -- **syntax:** `%result = pto.vexpdif %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **syntax:** `%result = pto.vexpdif %input, %max, %mask, "EVEN|ODD" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` - **A5 types:** input `f16` or `f32`, output `f32` - **semantics:** Fused exp(x - max) for numerically stable softmax. @@ -5016,7 +5018,7 @@ for (int i = 0; i < N; i++) ```mlir // Softmax with fused expdiff %max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> -%exp_stable = pto.vexpdif %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // Leaky ReLU activation %activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> @@ -5252,7 +5254,7 @@ pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, ! %max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> // 2. exp(x - max) using fused op -%exp = pto.vexpdif %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%exp = pto.vexpdif %logits, %max_bc, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // 3. Sum %sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py index 5549f8eee..747cdae42 100644 --- a/tilelang-dsl/python/tilelang_dsl/lowering.py +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -326,7 +326,7 @@ def _collect_used_tile_buffers_from_stmt( self._collect_used_tile_buffers_from_expr(stmt.value, used) self._record_tile_buffer_use(stmt.destination, used) self._collect_used_tile_buffers_from_expr(stmt.offsets, used) - self._collect_used_tile_buffers_from_expr(stmt.active_lanes, used) + self._collect_used_tile_buffers_from_expr(stmt.mask, used) return if isinstance(stmt, SemanticPredicateStoreStmt): self._collect_used_tile_buffers_from_expr(stmt.value, used) @@ -1116,13 +1116,13 @@ def _render_vscatter( value = self._lower_expr(stmt.value, env, indent=indent, into=lines) destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) offsets = self._lower_expr(stmt.offsets, env, indent=indent, into=lines) - active_lanes = self._lower_to_index(stmt.active_lanes, env, indent=indent, into=lines) + mask = self._lower_expr(stmt.mask, env, indent=indent, into=lines) lines.append( self._indent(indent) + "pto.vscatter " - + f"{value.name}, {destination.name}, {offsets.name}, {active_lanes.name} : " + + f"{value.name}, {destination.name}, {offsets.name}, {mask.name} : " + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, " - + f"{self._render_type(offsets.type)}, {self._render_type(active_lanes.type)}" + + f"{self._render_type(offsets.type)}, {self._render_type(mask.type)}" ) return lines @@ -3084,11 +3084,12 @@ def _lower_call_expr( if expr.name == "vexpdif": lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) - part = self._render_string_literal(expr.args[2]) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[3]) into.append( self._indent(indent) - + f"{result_name} = pto.vexpdif {lhs.name}, {rhs.name}, {part} : " - + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)} -> {self._render_type(expr.type)}" + + f"{result_name} = pto.vexpdif {lhs.name}, {rhs.name}, {mask.name}, {part} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" ) return _RenderedValue(name=result_name, type=expr.type) diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py index 7882ad28b..88bb1415f 100644 --- a/tilelang-dsl/python/tilelang_dsl/semantic.py +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -571,7 +571,7 @@ class SemanticVScatterStmt(SemanticStmt): value: SemanticExpr destination: SemanticExpr offsets: SemanticExpr - active_lanes: SemanticExpr + mask: SemanticExpr @dataclass(frozen=True) @@ -1931,7 +1931,7 @@ def _analyze_vector_store_stmt( ) if len(args) != 4: raise TypeError("pto.vscatter expects exactly 4 positional arguments in TileLang DSL v1") - value, destination, offsets, active_lanes = args + value, destination, offsets, mask = args value_type = self._require_vreg_expr(value, "pto.vscatter value") self._require_vector_pointer_expr(destination, "pto.vscatter destination") offsets_type = self._require_vreg_expr(offsets, "pto.vscatter offsets") @@ -1941,14 +1941,14 @@ def _analyze_vector_store_stmt( raise TypeError("pto.vscatter currently requires i32 offset vectors in TileLang DSL v1") if value_type.lanes != offsets_type.lanes: raise TypeError("pto.vscatter value and offsets must use the same lane count in TileLang DSL v1") - self._require_i32_like_expr(active_lanes, "pto.vscatter active_lanes") self._require_matching_vector_pointer(value_type, destination.type, "pto.vscatter") + self._require_mask_for_vreg(mask, value_type, "pto.vscatter") return ( SemanticVScatterStmt( value=value, destination=destination, offsets=offsets, - active_lanes=active_lanes, + mask=mask, ), dict(env), ) @@ -4856,19 +4856,20 @@ def _analyze_vexpdif_op( self, args: tuple[SemanticExpr, ...], ) -> SemanticExpr: - if len(args) != 3: - raise TypeError("pto.vexpdif expects exactly 3 positional arguments in TileLang DSL v1") - input_expr, max_expr, part_expr = args + if len(args) != 4: + raise TypeError("pto.vexpdif expects exactly 4 positional arguments in TileLang DSL v1") + input_expr, max_expr, mask_expr, part_expr = args input_type = self._require_vreg_expr(input_expr, "pto.vexpdif input") max_type = self._require_vreg_expr(max_expr, "pto.vexpdif max") if input_type != max_type: raise TypeError("pto.vexpdif requires input/max vector types to match") self._validate_vexpdif_dtype(input_type.element_dtype) + self._require_mask_for_vreg(mask_expr, input_type, "pto.vexpdif") part = self._normalize_vexpdif_part(part_expr, "pto.vexpdif part") return SemanticCallExpr( namespace="pto", name="vexpdif", - args=(input_expr, max_expr, part), + args=(input_expr, max_expr, mask_expr, part), type=self._vexpdif_result_vreg_type(input_type), ) diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py index b7c5e373d..8c4402ab2 100644 --- a/tilelang-dsl/tests/test_tilelang_dsl_v1.py +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -3509,7 +3509,7 @@ def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): out = pto.vsqrt(out, all_mask) out = pto.vrec(out, all_mask) out = pto.vrsqrt(out, all_mask) - out = pto.vexpdif(out, vec1, pto.VcvtPartMode.ODD) + out = pto.vexpdif(out, vec1, all_mask, pto.VcvtPartMode.ODD) out = pto.vcadd(out, all_mask) out = pto.vcmax(out, all_mask) out = pto.vcmin(out, all_mask) @@ -3550,7 +3550,8 @@ def test_vexpdif_f16_surface_lowers_to_f32_half_lanes(self) -> None: def kernel(dst: pto.Tile, src: pto.Tile, max_src: pto.Tile): vec = pto.vlds(src, 0) max_vec = pto.vlds(max_src, 0) - out = pto.vexpdif(vec, max_vec, pto.VcvtPartMode.ODD) + mask = pto.make_mask(pto.f16, pto.PAT.ALL) + out = pto.vexpdif(vec, max_vec, mask, pto.VcvtPartMode.ODD) mask = pto.make_mask(pto.f32, pto.PAT.ALL) pto.vsts(out, dst, 0, mask) return None @@ -3564,7 +3565,7 @@ def kernel(dst: pto.Tile, src: pto.Tile, max_src: pto.Tile): text = specialized.mlir_text() self.assertRegex( text, - r'pto\.vexpdif %\w+_\d+, %\w+_\d+, "ODD" : !pto\.vreg<128xf16>, !pto\.vreg<128xf16> -> !pto\.vreg<64xf32>', + r'pto\.vexpdif %\w+_\d+, %\w+_\d+, %\w+_\d+, "ODD" : !pto\.vreg<128xf16>, !pto\.vreg<128xf16>, !pto\.mask -> !pto\.vreg<64xf32>', ) def test_vcvt_supports_keyword_attrs_with_enums(self) -> None: @@ -6721,7 +6722,8 @@ def kernel( ): vec = pto.vbr(1.0) offsets = pto.vlds(offsets_src, 0) - pto.vscatter(vec, dst, offsets, 64) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vscatter(vec, dst, offsets, mask) return None specialized = kernel.specialize() @@ -6733,11 +6735,13 @@ def kernel( self.assertEqual(scatter_stmt.destination.type.memory_space, "ub") self.assertEqual(scatter_stmt.value.type.element_dtype, pto.f32) self.assertEqual(scatter_stmt.offsets.type.element_dtype, pto.i32) + self.assertEqual(scatter_stmt.mask.type.granularity, "b32") text = specialized.mlir_text() self.assertIn("pto.vscatter", text) self.assertIn("!pto.vreg<64xf32>", text) self.assertIn("!pto.vreg<64xi32>", text) + self.assertIn("!pto.mask", text) def test_align_load_and_stateful_store_ops_lower_to_current_vpto_surface(self) -> None: @pto.vkernel( From d210f316c326a1462f7a091ecc3c45ca317e8efb Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Tue, 28 Apr 2026 23:07:49 +0800 Subject: [PATCH 192/192] Optimize CI validation time --- .github/workflows/ci.yml | 24 ++ .../npu/a5/src/st/testcase/CMakeLists.txt | 18 +- test/tilelang_st/script/run_all_st.py | 235 +++++++++++++----- test/tilelang_st/script/run_st.py | 169 ++++++++++--- test/tilelang_st/script/test_batch_runner.py | 140 +++++++++++ 5 files changed, 486 insertions(+), 100 deletions(-) create mode 100644 test/tilelang_st/script/test_batch_runner.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 48b362363..1320e1593 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -454,9 +454,33 @@ jobs: run: | set -euo pipefail mkdir -p "${TILELANG_DSL_WORKSPACE}" + # SIM supports case-granularity parallel scheduling; NPU runs stay serial. + split_by_case_testcases=( + trowargmax + trowargmin + trowprod + trowmin + trowmax + tpartmin + tpartmax + tsels + tfillpad + trowsum + tcolmin + tcolmax + tcolsum + tcolprod + tpartadd + tpartmul + ) + split_by_case_args=() + for testcase in "${split_by_case_testcases[@]}"; do + split_by_case_args+=(--split-by-case "${testcase}") + done ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ PTOAS_BIN="${PTOAS_BIN}" \ bash test/tilelang_st/script/run_ci.sh -r sim -v a5 --jobs 64 \ + "${split_by_case_args[@]}" \ 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/run_ci.log" - name: Upload TileLang DSL logs diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index ab175d764..9a4da3644 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -202,14 +202,24 @@ set(ALL_TESTCASES txors ) -if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) - message(STATUS "run: ${TEST_CASE}") +if(NOT DEFINED TEST_CASE OR TEST_CASE STREQUAL "" OR TEST_CASE STREQUAL "all") + set(SELECTED_TESTCASES ${ALL_TESTCASES}) else() - message(FATAL_ERROR "not found TEST_CASE: ${TEST_CASE}, supported: ${ALL_TESTCASES}") + set(SELECTED_TESTCASES ${TEST_CASE}) + foreach(SELECTED_TESTCASE ${SELECTED_TESTCASES}) + if(NOT SELECTED_TESTCASE IN_LIST ALL_TESTCASES) + message( + FATAL_ERROR + "not found TEST_CASE: ${SELECTED_TESTCASE}, supported: ${ALL_TESTCASES}" + ) + endif() + endforeach() endif() +message(STATUS "run: ${SELECTED_TESTCASES}") + foreach(TESTCASE ${ALL_TESTCASES}) - if((DEFINED TEST_CASE AND TEST_CASE STREQUAL TESTCASE) OR (NOT DEFINED TEST_CASE) OR (TEST_CASE STREQUAL "all")) + if(TESTCASE IN_LIST SELECTED_TESTCASES) add_subdirectory(${TESTCASE}) endif() endforeach() diff --git a/test/tilelang_st/script/run_all_st.py b/test/tilelang_st/script/run_all_st.py index b24fc5659..a0851f100 100755 --- a/test/tilelang_st/script/run_all_st.py +++ b/test/tilelang_st/script/run_all_st.py @@ -12,9 +12,10 @@ import argparse import concurrent.futures import os -import subprocess import sys +import time import traceback +from dataclasses import dataclass import run_st @@ -24,6 +25,25 @@ } +@dataclass(frozen=True) +class ExecutionUnit: + testcase: str + case: str | None = None + + @property + def label(self): + if self.case is None: + return self.testcase + return f"{self.testcase}::{self.case}" + + +@dataclass(frozen=True) +class ExecutionResult: + label: str + log_path: str | None + duration_seconds: float + + def discover_testcases(testcase_root): testcases = [] for entry in sorted(os.listdir(testcase_root)): @@ -56,6 +76,10 @@ def parse_args(): "-t", "--testcase", action="append", default=[], help="Run only selected testcase(s). Can be passed multiple times.", ) + parser.add_argument( + "-c", "--case", default=None, + help="Run only a specific case within the selected testcase. Useful for local debugging.", + ) parser.add_argument( "-w", "--without-build", action="store_true", help="Skip build and reuse the existing build directory.", @@ -70,7 +94,22 @@ def parse_args(): ) parser.add_argument( "-j", "--jobs", type=int, default=1, - help="Number of testcases to run in parallel after the shared build (default: 1).", + help=( + "Number of execution units to run in parallel after the shared build " + "(sim only; npu requires --jobs 1, default: 1)." + ), + ) + parser.add_argument( + "--split-by-case", action="append", default=[], + help="Split the specified testcase into per-case execution units. Can be passed multiple times.", + ) + parser.add_argument( + "--split-all-by-case", action="store_true", + help="Split all selected testcases into per-case execution units.", + ) + parser.add_argument( + "--list-cases", action="store_true", + help="List discovered case names for the selected testcase(s) and exit.", ) return parser.parse_args() @@ -95,26 +134,57 @@ def resolve_selected_testcases(all_testcases, requested): return requested_set -def run_testcase_subprocess(script_path, run_mode, soc_version, ptoas_bin, testcase): - command = [ - sys.executable, - script_path, - "-r", run_mode, - "-v", soc_version, - "-t", testcase, - "-p", ptoas_bin, - "-w", - ] - env = os.environ.copy() - result = subprocess.run( - command, - check=False, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - env=env, +def resolve_split_testcases(selected_testcases, requested, split_all_by_case): + if split_all_by_case: + return set(selected_testcases) + + resolved = [] + seen = set() + for testcase in requested: + if testcase in seen: + continue + if testcase not in selected_testcases: + raise ValueError( + f"Unsupported split-by-case testcase(s): {testcase}; " + f"selected: {', '.join(selected_testcases)}" + ) + resolved.append(testcase) + seen.add(testcase) + return set(resolved) + + +def build_execution_units(selected_testcases, split_testcases): + units = [] + for testcase in selected_testcases: + if testcase not in split_testcases: + units.append(ExecutionUnit(testcase)) + continue + + case_names = run_st.discover_case_names(testcase) + if not case_names: + raise ValueError(f"No cases discovered for testcase: {testcase}") + units.extend(ExecutionUnit(testcase, case_name) for case_name in case_names) + return units + + +def validate_execution_constraints(run_mode, jobs): + if run_mode == "npu" and jobs != 1: + raise ValueError("--jobs > 1 is not supported in npu mode") + + +def run_execution_unit(execution_unit, log_dir): + start_time = time.monotonic() + log_path = run_st.execute_execution_unit( + execution_unit.testcase, + execution_unit.case, + log_dir=log_dir, + ) + duration_seconds = time.monotonic() - start_time + return ExecutionResult( + label=execution_unit.label, + log_path=log_path, + duration_seconds=duration_seconds, ) - return testcase, result.returncode, result.stdout def main(): @@ -130,9 +200,19 @@ def main(): if args.jobs < 1: print("[ERROR] --jobs must be >= 1", file=sys.stderr) sys.exit(1) + if args.case is not None and len(args.testcase) != 1: + print("[ERROR] --case requires exactly one selected testcase", file=sys.stderr) + sys.exit(1) + if args.case is not None and (args.split_by_case or args.split_all_by_case): + print("[ERROR] --case cannot be combined with split-by-case options", file=sys.stderr) + sys.exit(1) + try: + validate_execution_constraints(args.run_mode, args.jobs) + except ValueError as exc: + print(f"[ERROR] {exc}", file=sys.stderr) + sys.exit(1) - script_path = os.path.abspath(__file__) - tilelang_st_root = os.path.dirname(os.path.dirname(script_path)) + tilelang_st_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) testcase_root = os.path.join( tilelang_st_root, "npu", args.soc_version, "src", "st", "testcase" ) @@ -179,66 +259,98 @@ def main(): failures = [] try: os.chdir(target_dir) + log_dir = os.path.join(target_dir, "build", "logs") + if args.list_cases: + for testcase in selected_testcases: + print(f"[INFO] testcase={testcase}") + for case_name in run_st.discover_case_names(testcase): + print(case_name) + return + + if args.case is not None: + case_names = run_st.discover_case_names(selected_testcases[0]) + if args.case not in case_names: + print( + f"[ERROR] Unsupported case: {args.case}; " + f"supported: {', '.join(case_names)}", + file=sys.stderr, + ) + sys.exit(1) + execution_units = [ExecutionUnit(selected_testcases[0], args.case)] + split_testcases = set() + else: + try: + split_testcases = resolve_split_testcases( + selected_testcases, + args.split_by_case, + args.split_all_by_case, + ) + execution_units = build_execution_units(selected_testcases, split_testcases) + except ValueError as exc: + print(f"[ERROR] {exc}", file=sys.stderr) + sys.exit(1) + run_st.set_env_variables(args.run_mode, default_soc_version) + if split_testcases: + print(f"[INFO] split_by_case_testcases={', '.join(sorted(split_testcases))}") + print(f"[INFO] execution_units={len(execution_units)}") + print(f"[INFO] log_dir={log_dir}") if not args.without_build: build_target = "all" if selected_testcases == all_testcases else ";".join(selected_testcases) print(f"[INFO] build requested for {build_target}") - run_st.build_project(args.run_mode, default_soc_version, "all", ptoas_bin) + run_st.build_project(args.run_mode, default_soc_version, build_target, ptoas_bin) - total = len(selected_testcases) + total = len(execution_units) if args.jobs == 1: - for index, testcase in enumerate(selected_testcases, start=1): - print(f"[INFO] [{index}/{total}] running testcase: {testcase}") + for index, execution_unit in enumerate(execution_units, start=1): + print(f"[INFO] [{index}/{total}] running testcase: {execution_unit.label}") try: - run_st.run_gen_data(testcase) - run_st.run_binary(testcase) - run_st.run_compare(testcase) + result = run_execution_unit(execution_unit, log_dir) except Exception as exc: # pragma: no cover - CI-side aggregation path - failures.append((testcase, str(exc))) - print(f"[ERROR] testcase failed: {testcase}") + failures.append((execution_unit.label, str(exc))) + print( + f"[ERROR] testcase failed: {execution_unit.label} " + f"(log: {os.path.join(log_dir, run_st.get_execution_log_name(execution_unit.testcase, execution_unit.case))})" + ) traceback.print_exc() if args.fail_fast: break + continue + + print( + f"[INFO] completed testcase: {result.label} " + f"duration={result.duration_seconds:.1f}s log={result.log_path}" + ) else: print(f"[INFO] running testcases in parallel with jobs={args.jobs}") max_workers = min(args.jobs, total) with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_testcase = {} - for index, testcase in enumerate(selected_testcases, start=1): - print(f"[INFO] [{index}/{total}] queue testcase: {testcase}") - future = executor.submit( - run_testcase_subprocess, - script_path, - args.run_mode, - args.soc_version, - ptoas_bin, - testcase, - ) - future_to_testcase[future] = testcase + for index, execution_unit in enumerate(execution_units, start=1): + print(f"[INFO] [{index}/{total}] queue testcase: {execution_unit.label}") + future = executor.submit(run_execution_unit, execution_unit, log_dir) + future_to_testcase[future] = execution_unit for future in concurrent.futures.as_completed(future_to_testcase): - testcase = future_to_testcase[future] + execution_unit = future_to_testcase[future] try: - _, returncode, output = future.result() + result = future.result() except Exception as exc: # pragma: no cover - executor/host failure - failures.append((testcase, str(exc))) - print(f"[ERROR] testcase runner crashed: {testcase}") + failures.append((execution_unit.label, str(exc))) + print( + f"[ERROR] testcase runner crashed: {execution_unit.label} " + f"(log: {os.path.join(log_dir, run_st.get_execution_log_name(execution_unit.testcase, execution_unit.case))})" + ) traceback.print_exc() if args.fail_fast: break continue - print(f"[INFO] ===== testcase {testcase} output begin =====") - if output: - print(output, end="" if output.endswith("\n") else "\n") - print(f"[INFO] ===== testcase {testcase} output end =====") - - if returncode != 0: - failures.append((testcase, f"subprocess exited with {returncode}")) - print(f"[ERROR] testcase failed: {testcase}") - if args.fail_fast: - break + print( + f"[INFO] completed testcase: {result.label} " + f"duration={result.duration_seconds:.1f}s log={result.log_path}" + ) except Exception as exc: print(f"[ERROR] batch run failed: {exc}", file=sys.stderr) @@ -246,9 +358,14 @@ def main(): finally: os.chdir(original_dir) - passed = len(selected_testcases) - len(failures) + passed = len(execution_units) - len(failures) print("[INFO] TileLang ST summary") - print(f"[INFO] passed={passed} failed={len(failures)} total={len(selected_testcases)}") + print(f"[INFO] passed={passed} failed={len(failures)} total={len(execution_units)}") + if len(execution_units) != len(selected_testcases): + print( + f"[INFO] selected_testcases={len(selected_testcases)} " + f"execution_units={len(execution_units)}" + ) if failures: for testcase, reason in failures: print(f"[INFO] failed testcase: {testcase} ({reason})") diff --git a/test/tilelang_st/script/run_st.py b/test/tilelang_st/script/run_st.py index 996135490..798fe9567 100755 --- a/test/tilelang_st/script/run_st.py +++ b/test/tilelang_st/script/run_st.py @@ -22,14 +22,28 @@ import subprocess import shutil import argparse +import re +import runpy +import traceback -def run_command(command, cwd=None, check=True): +def log_message(message, log_handle=None): + print(message, file=log_handle or sys.stdout, flush=True) + + +def run_command(command, cwd=None, check=True, log_handle=None): try: - print(f"run command: {' '.join(command)}") - subprocess.run(command, cwd=cwd, check=check, stdout=None, stderr=None, text=True) + log_message(f"run command: {' '.join(command)}", log_handle) + subprocess.run( + command, + cwd=cwd, + check=check, + stdout=log_handle, + stderr=log_handle, + text=True, + ) except subprocess.CalledProcessError as e: - print(f"run command failed with return code {e.returncode}") + log_message(f"run command failed with return code {e.returncode}", log_handle) raise @@ -51,6 +65,10 @@ def find_ptoas_bin(): return None +def sanitize_case_name(case_name): + return re.sub(r"[^0-9A-Za-z_.-]", "_", case_name) + + def set_env_variables(run_mode, soc_version): if run_mode == "sim": ld_lib_path = os.environ.get("LD_LIBRARY_PATH", "") @@ -94,8 +112,41 @@ def set_env_variables(run_mode, soc_version): ) -def get_testcase_work_dir(testcase): - return os.path.join("build", "testcase", testcase) +def get_testcase_source_dir(testcase): + return os.path.join("testcase", testcase) + + +def get_testcase_work_dir(testcase, case_filter=None): + work_dir = os.path.join("build", "testcase", testcase) + if case_filter is None: + return work_dir + return os.path.join(work_dir, "_case_runs", sanitize_case_name(case_filter)) + + +def get_testcase_binary_path(testcase): + return os.path.abspath(os.path.join("build", "bin", testcase)) + + +def load_testcase_cases(testcase): + cases_path = os.path.join(get_testcase_source_dir(testcase), "cases.py") + if not os.path.isfile(cases_path): + raise FileNotFoundError(f"cases.py not found for testcase: {testcase}") + + namespace = runpy.run_path(cases_path) + cases = namespace.get("CASES") + if not isinstance(cases, list): + raise ValueError(f"CASES is not a list in: {cases_path}") + return cases + + +def discover_case_names(testcase): + return [str(case["name"]) for case in load_testcase_cases(testcase)] + + +def get_execution_log_name(testcase, case_filter=None): + if case_filter is None: + return f"{sanitize_case_name(testcase)}.log" + return f"{sanitize_case_name(testcase)}__{sanitize_case_name(case_filter)}.log" def build_project(run_mode, soc_version, testcase, ptoas_bin): @@ -139,66 +190,110 @@ def build_project(run_mode, soc_version, testcase, ptoas_bin): raise -def _copy_testcase_scripts(testcase): +def _link_or_copy(src, dst): + src_abs = os.path.abspath(src) + if os.path.lexists(dst): + if os.path.islink(dst) and os.path.realpath(dst) == src_abs: + return + os.remove(dst) + + try: + os.symlink(src_abs, dst) + except OSError: + shutil.copyfile(src_abs, dst) + + +def _write_filtered_cases_wrapper(source_path, work_dir, case_filter): + filtered_cases_path = os.path.join(work_dir, "cases.py") + all_cases_path = os.path.join(work_dir, "_all_cases.py") + _link_or_copy(source_path, all_cases_path) + + with open(filtered_cases_path, "w", encoding="utf-8") as handle: + handle.write( + "# Auto-generated by run_st.py for single-case execution.\n" + "from _all_cases import CASES as _ALL_CASES\n\n" + f"CASES = [case for case in _ALL_CASES if case.get('name') == {case_filter!r}]\n" + "if not CASES:\n" + f" raise ValueError('unknown case filter: {case_filter}')\n" + ) + + +def _copy_testcase_scripts(testcase, case_filter=None): """Copy shared and per-testcase Python scripts into the build work directory.""" - work_dir = get_testcase_work_dir(testcase) + work_dir = get_testcase_work_dir(testcase, case_filter) os.makedirs(work_dir, exist_ok=True) # Shared scripts (testcase/ level). for name in ("st_common.py",): src = os.path.join("testcase", name) if os.path.isfile(src): - run_command(["cp", src, os.path.join(work_dir, name)]) + _link_or_copy(src, os.path.join(work_dir, name)) # Per-testcase scripts. - testcase_src = f"testcase/{testcase}" + testcase_src = get_testcase_source_dir(testcase) for name in ("cases.py", "gen_data.py", "compare.py"): src = os.path.join(testcase_src, name) - if os.path.isfile(src): - run_command(["cp", src, os.path.join(work_dir, name)]) + if not os.path.isfile(src): + continue + if name == "cases.py" and case_filter is not None: + _write_filtered_cases_wrapper(src, work_dir, case_filter) + continue + _link_or_copy(src, os.path.join(work_dir, name)) -def run_gen_data(testcase): - original_dir = os.getcwd() +def run_gen_data(testcase, case_filter=None, log_handle=None): try: - work_dir = get_testcase_work_dir(testcase) - _copy_testcase_scripts(testcase) - os.chdir(work_dir) - run_command([sys.executable, "gen_data.py"]) + work_dir = get_testcase_work_dir(testcase, case_filter) + _copy_testcase_scripts(testcase, case_filter) + run_command([sys.executable, "gen_data.py"], cwd=work_dir, log_handle=log_handle) except Exception as e: - print(f"gen golden failed: {e}") + log_message(f"gen golden failed: {e}", log_handle) raise - finally: - os.chdir(original_dir) -def run_binary(testcase, case_filter=None): - original_dir = os.getcwd() +def run_binary(testcase, case_filter=None, log_handle=None): try: - os.chdir(get_testcase_work_dir(testcase)) - cmd = [os.path.join("..", "..", "bin", testcase)] + work_dir = get_testcase_work_dir(testcase, case_filter) + cmd = [get_testcase_binary_path(testcase)] if case_filter: cmd.append(case_filter) - run_command(cmd) + run_command(cmd, cwd=work_dir, log_handle=log_handle) except Exception as e: - print(f"run binary failed: {e}") + log_message(f"run binary failed: {e}", log_handle) raise - finally: - os.chdir(original_dir) -def run_compare(testcase, case_filter=None): - original_dir = os.getcwd() +def run_compare(testcase, case_filter=None, log_handle=None): try: - work_dir = get_testcase_work_dir(testcase) - os.chdir(work_dir) + work_dir = get_testcase_work_dir(testcase, case_filter) cmd = [sys.executable, "compare.py"] if case_filter: cmd.append(case_filter) - run_command(cmd) + run_command(cmd, cwd=work_dir, log_handle=log_handle) except Exception as e: - print(f"compare failed: {e}") + log_message(f"compare failed: {e}", log_handle) + raise + + +def execute_execution_unit(testcase, case_filter=None, log_dir=None): + log_path = None + log_handle = None + if log_dir is not None: + os.makedirs(log_dir, exist_ok=True) + log_path = os.path.join(log_dir, get_execution_log_name(testcase, case_filter)) + log_handle = open(log_path, "w", encoding="utf-8") + + try: + log_message(f"[INFO] begin testcase={testcase} case={case_filter or ''}", log_handle) + run_gen_data(testcase, case_filter, log_handle=log_handle) + run_binary(testcase, case_filter, log_handle=log_handle) + run_compare(testcase, case_filter, log_handle=log_handle) + log_message("[INFO] execution unit passed", log_handle) + return log_path + except Exception: + traceback.print_exc(file=log_handle or sys.stderr) raise finally: - os.chdir(original_dir) + if log_handle is not None: + log_handle.close() def main(): @@ -254,7 +349,7 @@ def main(): build_project(args.run_mode, default_soc_version, testcase, ptoas_bin) # gen golden → run binary → compare - run_gen_data(testcase) + run_gen_data(testcase, args.case) run_binary(testcase, args.case) run_compare(testcase, args.case) diff --git a/test/tilelang_st/script/test_batch_runner.py b/test/tilelang_st/script/test_batch_runner.py new file mode 100644 index 000000000..b03d6205b --- /dev/null +++ b/test/tilelang_st/script/test_batch_runner.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import importlib.util +import os +import runpy +import sys +import tempfile +import textwrap +import unittest +from pathlib import Path +from unittest import mock + + +SCRIPT_DIR = Path(__file__).resolve().parent + + +def load_module(name, path): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +if str(SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(SCRIPT_DIR)) + +run_st = load_module("tilelang_run_st", SCRIPT_DIR / "run_st.py") +run_all_st = load_module("tilelang_run_all_st", SCRIPT_DIR / "run_all_st.py") + + +class BatchRunnerTest(unittest.TestCase): + def test_get_testcase_work_dir_uses_case_specific_subdir(self): + work_dir = run_st.get_testcase_work_dir("demo", "case/1") + self.assertTrue(work_dir.endswith("build/testcase/demo/_case_runs/case_1")) + + def test_discover_case_names_reads_cases_py(self): + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + testcase_dir = root / "testcase" / "demo" + testcase_dir.mkdir(parents=True) + (testcase_dir / "cases.py").write_text( + textwrap.dedent( + """ + CASES = [ + {"name": "alpha"}, + {"name": "beta"}, + ] + """ + ), + encoding="utf-8", + ) + + cwd = os.getcwd() + try: + os.chdir(root) + self.assertEqual(run_st.discover_case_names("demo"), ["alpha", "beta"]) + finally: + os.chdir(cwd) + + def test_copy_testcase_scripts_writes_filtered_cases_wrapper(self): + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + shared_dir = root / "testcase" + testcase_dir = shared_dir / "demo" + testcase_dir.mkdir(parents=True) + (shared_dir / "st_common.py").write_text("# shared\n", encoding="utf-8") + (testcase_dir / "gen_data.py").write_text("# gen\n", encoding="utf-8") + (testcase_dir / "compare.py").write_text("# compare\n", encoding="utf-8") + (testcase_dir / "cases.py").write_text( + textwrap.dedent( + """ + CASES = [ + {"name": "alpha"}, + {"name": "beta"}, + ] + """ + ), + encoding="utf-8", + ) + + cwd = os.getcwd() + try: + os.chdir(root) + run_st._copy_testcase_scripts("demo", "beta") + work_dir = Path(run_st.get_testcase_work_dir("demo", "beta")) + self.assertTrue((work_dir / "_all_cases.py").is_file()) + sys.path.insert(0, str(work_dir)) + try: + filtered_cases = runpy.run_path(str(work_dir / "cases.py"))["CASES"] + finally: + sys.path.pop(0) + self.assertEqual([case["name"] for case in filtered_cases], ["beta"]) + finally: + os.chdir(cwd) + + def test_build_execution_units_splits_selected_testcases(self): + with mock.patch.object(run_all_st.run_st, "discover_case_names", return_value=["c1", "c2"]): + units = run_all_st.build_execution_units( + ["tadd", "trowargmax"], + {"trowargmax"}, + ) + + labels = [unit.label for unit in units] + self.assertEqual(labels, ["tadd", "trowargmax::c1", "trowargmax::c2"]) + + def test_resolve_split_testcases_rejects_unselected_testcase(self): + with self.assertRaises(ValueError): + run_all_st.resolve_split_testcases(["tadd"], ["trowargmax"], False) + + def test_run_execution_unit_returns_log_path_and_duration(self): + execution_unit = run_all_st.ExecutionUnit("trowargmax", "case_a") + with tempfile.TemporaryDirectory() as temp_dir: + with mock.patch.object(run_all_st.run_st, "execute_execution_unit", return_value="/tmp/case.log"): + result = run_all_st.run_execution_unit(execution_unit, temp_dir) + + self.assertEqual(result.label, "trowargmax::case_a") + self.assertEqual(result.log_path, "/tmp/case.log") + self.assertGreaterEqual(result.duration_seconds, 0.0) + + def test_validate_execution_constraints_allows_parallel_sim(self): + run_all_st.validate_execution_constraints("sim", 64) + + def test_validate_execution_constraints_allows_serial_npu(self): + run_all_st.validate_execution_constraints("npu", 1) + + def test_validate_execution_constraints_rejects_parallel_npu(self): + with self.assertRaisesRegex(ValueError, "npu mode"): + run_all_st.validate_execution_constraints("npu", 2) + + +if __name__ == "__main__": + unittest.main()